Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from abc import abstractmethod | |
| import asyncio | |
| from os.path import realpath, dirname, join | |
| from loguru import logger as log | |
| import numpy as np | |
| from api_types import InferenceRequest, InferenceResult, SeedingRequest | |
# Repository root: the directory three levels above this file, passed
# through `realpath` so symlinks are resolved.
ROOT_DIR = realpath(
    dirname(dirname(dirname(__file__)))
)
# The `data` directory directly under ROOT_DIR.
DATA_DIR = join(ROOT_DIR, "data")
class InferenceModel:
    """
    Base class for models that can be served by the inference server
    defined in `server.py`.

    Methods in the "inference model interface" section that raise
    `NotImplementedError` here are meant to be overridden by subclasses.
    The "requests handling" section implements the shared bookkeeping:
    starting inference tasks, collecting and caching their results, and
    evicting old results.
    """

    def __init__(self, data_path: str | None = None, checkpoint_path: str | None = None,
                 fake_delay_ms: float = 0, inference_cache_size: int = 15,
                 compress_inference_results: bool = True) -> None:
        """
        Args:
            data_path: Optional data path (may be unused by some model types).
            checkpoint_path: Optional checkpoint path (may be unused by some model types).
            fake_delay_ms: Stored as-is; not used by this base class
                (presumably an artificial latency for subclasses -- confirm).
            inference_cache_size: Default number of most recent results kept
                by `evict_results`.
            compress_inference_results: Whether results should be compressed
                (e.g. as MP4 video) when the model supports it.
        """
        # These paths may be unused by certain inference server types.
        self.data_path = data_path
        self.checkpoint_path = checkpoint_path
        self.fake_delay_ms = fake_delay_ms
        self.inference_cache_size = inference_cache_size
        # Started inference tasks whose outcome has not been collected yet, by request id.
        self.inference_tasks: dict[str, asyncio.Task] = {}
        # Collected results by request id, in insertion order (oldest first).
        self.inference_results: dict[str, InferenceResult] = {}
        # Every request id ever submitted; used to report a clearer error for
        # results that are no longer available.
        self.request_history: set[str] = set()
        # If supported by the model and relevant, compress inference results,
        # e.g. as MP4 video, before returning from the server.
        self.compress_inference_results: bool = compress_inference_results
        # Can be acquired before starting inference
        # if the model can only handle one request at a time.
        self.inference_lock = asyncio.Lock()
        # The generative model may need to be seeded with one or more initial frames.
        self.model_seeded = False

    # ----------- Inference model interface

    async def make_test_image(self):
        """Evaluate one default inference request, if possible.
        Helps ensure that the model has been loaded correctly."""
        raise NotImplementedError("make_test_image")

    async def seed_model(self, req: SeedingRequest) -> None:
        """By default, no seeding is required, so this only marks the model as seeded."""
        self.model_seeded = True

    async def run_inference(self, req: InferenceRequest) -> InferenceResult:
        """Evaluate the actual inference model to produce an inference result."""
        raise NotImplementedError("run_inference")

    def metadata(self) -> dict:
        """Returns metadata about this inference server."""
        raise NotImplementedError("metadata")

    def min_frames_per_request(self) -> int:
        """Minimum number of frames that can be produced in one inference batch."""
        raise NotImplementedError("min_frames_per_request")

    def max_frames_per_request(self) -> int:
        """Maximum number of frames that can be produced in one inference batch."""
        raise NotImplementedError("max_frames_per_request")

    def inference_time_per_frame(self) -> float:
        """Estimated average inference time per frame (not per batch!) in seconds."""
        raise NotImplementedError("inference_time_per_frame")

    def inference_resolution(self) -> list[tuple[int, int]] | None:
        """
        The supported inference resolutions (width, height) in pixels,
        or None if any resolution is supported.
        """
        return None

    def default_framerate(self) -> float | None:
        """
        The model's preferred framerate when generating video.
        Returns None when not applicable.
        """
        return None

    def requires_seeding(self) -> bool:
        """Whether this model must be seeded with images before inference."""
        return False

    # ----------- Requests handling

    def request_inference(self, req: InferenceRequest) -> asyncio.Task:
        """
        Validate `req` and start its inference as a background task.

        Must be called while an event loop is running (uses `asyncio.create_task`).

        Returns:
            The task running `run_inference(req)`; it is also tracked in
            `inference_tasks` under `req.request_id`.

        Raises:
            ValueError: if the model was not seeded, the request id is already
                pending or cached, or `check_valid_request` rejects the request.
        """
        request_id = req.request_id
        if not self.model_seeded:
            raise ValueError(f"Received request id '{request_id}', but the model was not seeded.")
        if request_id in self.inference_tasks or request_id in self.inference_results:
            raise ValueError(f"Invalid request id '{request_id}': request already exists.")
        self.check_valid_request(req)

        task = asyncio.create_task(self.run_inference(req))
        self.inference_tasks[request_id] = task
        self.request_history.add(request_id)
        return task

    async def request_inference_sync(self, req: InferenceRequest) -> InferenceResult:
        """
        Start inference for `req` and wait for its result.

        Raises:
            ValueError: see `request_inference`.
            Exception: whatever the inference task raised (logged, and the
                failed task is no longer tracked).
            KeyError: if the outcome was already collected by a concurrent
                caller and is no longer available.
        """
        task = self.request_inference(req)
        try:
            await task
        except Exception:
            # Not swallowed: `inference_result_or_none` below logs the failure,
            # stops tracking the finished task and re-raises its exception.
            pass
        result = self.inference_result_or_none(req.request_id)
        assert isinstance(result, InferenceResult)
        return result

    def inference_result_or_none(self, request_id: str) -> InferenceResult | None:
        """
        Return the result of request `request_id`, or None if it is still running.

        Any finished task is removed from `inference_tasks`. On success, its
        result is moved into the result cache, which may evict older results.

        Raises:
            Exception: whatever the inference task raised. The failed task is
                dropped, so later calls for the same id raise KeyError.
            asyncio.CancelledError: if the task was cancelled (task dropped too).
            KeyError: if the request is unknown, or known but its result is no
                longer available (evicted from the cache or failed).
        """
        task = self.inference_tasks.get(request_id)
        if task is not None:
            if not task.done():
                # Inference result not ready yet.
                return None
            # The task is finished one way or another. Stop tracking it first so
            # that failed or cancelled tasks do not stay in `inference_tasks` forever.
            del self.inference_tasks[request_id]
            try:
                result = task.result()
            except Exception as e:
                log.error(f"Task for request '{request_id}' failed with exception {e}")
                raise
            # Inference result ready: cache it and return it.
            self.inference_results[request_id] = result
            self.evict_results()
            return result

        if request_id in self.inference_results:
            # Inference result was ready and cached: return it directly.
            return self.inference_results[request_id]
        if request_id in self.request_history:
            raise KeyError(f"Request with id '{request_id}' was known, but does not have any result. Perhaps it was evicted from the cache or failed.")
        raise KeyError(f"Invalid request id '{request_id}': request not known.")

    def evict_results(self, keep_max: int | None = None):
        """
        Evict all results that were added before the last `keep_max` entries.

        Args:
            keep_max: Number of most recently added results to keep. Defaults
                to `self.inference_cache_size`; a value <= 0 evicts everything.
        """
        if keep_max is None:
            keep_max = self.inference_cache_size
        # Dicts preserve insertion order, so the oldest results come first.
        n_evict = max(0, len(self.inference_results) - max(keep_max, 0))
        for key in list(self.inference_results)[:n_evict]:
            del self.inference_results[key]

    def get_latest_rgb(self) -> np.ndarray | None:
        """
        Returns the latest generated RGB image, if any. Useful for debugging.

        This is `images[-1, ...]` of the most recently cached result, i.e.
        presumably its last frame (assumes frames along axis 0 -- confirm).
        """
        if not self.inference_results:
            return None
        latest = next(reversed(self.inference_results.values()))
        return latest.images[-1, ...]

    def check_valid_request(self, req: InferenceRequest):
        """
        Check that the number of frames requested (`len(req)`) is within the
        model's [min, max] frames-per-request bounds.

        Returns:
            True if the request is valid.

        Raises:
            ValueError: if `len(req)` is outside the supported range.
        """
        min_frames = self.min_frames_per_request()
        max_frames = self.max_frames_per_request()
        n_frames = len(req)
        if not (min_frames <= n_frames <= max_frames):
            raise ValueError(f"This model can produce between {min_frames} and"
                             f" {max_frames} frames per request, but the request"
                             f" specified {n_frames} camera poses.")
        return True

    # ----------- Resource management

    def cleanup(self):
        """Release resources held by the model. Nothing to do in the base class."""
        pass