from dataclasses import dataclass
from typing import Dict

from manager.utils.runtime_selector import detect_runtime_environment


@dataclass
class ModelInfo:
    name: str
    size: float
    tokens_sec: int
    type: str
    description: str
    create_cost: int = 0
    invoke_cost: int = 0
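
# Note: create_cost and invoke_cost default to 0 here and are populated by
# ModelRegistry below, either from the estimate_* formulas (local models) or
# from a flat rule (API models).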


class ModelRegistry:
    def __init__(self):
        self.env = detect_runtime_environment()
        self.models = self._build_model_registry()

    def estimate_create_cost(self, size: float, is_api: bool) -> int:
        return int(size * (10 if is_api else 5))

    def estimate_invoke_cost(self, tokens_sec: int, is_api: bool) -> int:
        base_cost = 40 if is_api else 20
        return base_cost + max(0, 60 - tokens_sec)
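
    # Worked example for the formulas above (derived from the code, not
    # benchmarked): a local 7B model at 50 tokens/sec costs
    # int(7 * 5) = 35 to create and 20 + max(0, 60 - 50) = 30 per invoke,
    # so larger and slower local models are priced higher.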

    def _build_model_registry(self) -> Dict[str, ModelInfo]:
        raw_models = {
            "llama3.2": {
                "size": 3,
                "tokens_sec": 30,
                "type": "local",
                "description": "3B lightweight local model"
            },
            "mistral": {
                "size": 7,
                "tokens_sec": 50,
                "type": "local",
                "description": "7B stronger local model"
            },
            "gemini-2.0-flash": {
                "size": 6,
                "tokens_sec": 170,
                "type": "api",
                "description": "Fast and efficient API model"
            },
            "gemini-2.5-pro-preview-03-25": {
                "size": 10,
                "tokens_sec": 148,
                "type": "api",
                "description": "High-reasoning API model"
            },
            "gemini-1.5-flash": {
                "size": 7,
                "tokens_sec": 190,
                "type": "api",
                "description": "Fast general-purpose model"
            },
            "gemini-2.0-flash-lite": {
                "size": 5,
                "tokens_sec": 208,
                "type": "api",
                "description": "Low-latency, cost-efficient API model"
            },
            "gemini-2.0-flash-live-001": {
                "size": 9,
                "tokens_sec": 190,
                "type": "api",
                "description": "Voice/video low-latency API model"
            }
        }

        models = {}
        for name, model in raw_models.items():
            is_api = model["type"] == "api"
            if is_api:
                # Flat cost for all API models
                create_cost, invoke_cost = 20, 50
            else:
                create_cost = self.estimate_create_cost(model["size"], is_api=False)
                invoke_cost = self.estimate_invoke_cost(model["tokens_sec"], is_api=False)
            models[name] = ModelInfo(
                name=name,
                size=model["size"],
                tokens_sec=model["tokens_sec"],
                type=model["type"],
                description=model["description"],
                create_cost=create_cost,
                invoke_cost=invoke_cost
            )
        return models

    def get_filtered_models(self) -> Dict[str, ModelInfo]:
        """Return only models that match the current runtime."""
        if self.env in ["gpu", "cpu-local"]:
            return {k: v for k, v in self.models.items() if v.type == "local"}
        else:
            return {k: v for k, v in self.models.items() if v.type == "api"}

    def get_all_models(self) -> Dict[str, ModelInfo]:
        """Return all models regardless of runtime."""
        return self.models
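

# For reference, a minimal sketch of what detect_runtime_environment might
# look like (an assumption: the real implementation lives in
# manager.utils.runtime_selector and is not shown here). get_filtered_models()
# above expects "gpu" or "cpu-local" for local runtimes and anything else for
# API-only runtimes:
#
#     def detect_runtime_environment() -> str:
#         try:
#             import torch  # hypothetical GPU probe
#             if torch.cuda.is_available():
#                 return "gpu"
#         except ImportError:
#             pass
#         import shutil
#         if shutil.which("ollama"):  # hypothetical local-runtime check
#             return "cpu-local"
#         return "api"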


if __name__ == "__main__":
    registry = ModelRegistry()
    print(f"[INFO] Detected runtime: {registry.env}\n")

    print("Filtered models based on environment:")
    for name, model in registry.get_filtered_models().items():
        print(f"{name}: create={model.create_cost}, invoke={model.invoke_cost}, type={model.type}")