Mirrowel committed on
Commit
b6a47c9
·
1 Parent(s): bd8f638

feat(api): ✨ add model pricing and capabilities enrichment service

Browse files

Introduces a new model information service that fetches pricing and capability data from external catalogs (OpenRouter and Models.dev) to enrich the /v1/models endpoint and enable cost estimation.

- Implements ModelRegistry class with async background data fetching to avoid blocking proxy startup
- Adds fuzzy model ID matching with multi-source data aggregation
- Expands /v1/models endpoint with optional enriched response containing pricing, token limits, and capability flags
- Adds new endpoints: GET /v1/models/{model_id}, GET /v1/model-info/stats, POST /v1/cost-estimate
- Supports per-token pricing for input, output, cache read, and cache write operations
- Integrates with lifespan management for proper service initialization and cleanup
- Includes comprehensive backward compatibility layer for gradual migration

The service refreshes data every 6 hours (configurable via MODEL_INFO_REFRESH_INTERVAL) and runs asynchronously to maintain fast proxy initialization times.

src/proxy_app/main.py CHANGED
@@ -100,6 +100,7 @@ with _console.status("[dim]Initializing proxy core...", spinner="dots"):
100
  from rotator_library import RotatingClient
101
  from rotator_library.credential_manager import CredentialManager
102
  from rotator_library.background_refresher import BackgroundRefresher
 
103
  from proxy_app.request_logger import log_request_to_console
104
  from proxy_app.batch_manager import EmbeddingBatcher
105
  from proxy_app.detailed_logger import DetailedLogger
@@ -123,15 +124,59 @@ class EmbeddingRequest(BaseModel):
123
  user: Optional[str] = None
124
 
125
  class ModelCard(BaseModel):
 
126
  id: str
127
  object: str = "model"
128
  created: int = Field(default_factory=lambda: int(time.time()))
129
  owned_by: str = "Mirro-Proxy"
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  class ModelList(BaseModel):
 
132
  object: str = "list"
133
  data: List[ModelCard]
134
 
 
 
 
 
 
135
  # Calculate total loading time
136
  _elapsed = time.time() - _start_time
137
  print(f"✓ Server ready in {_elapsed:.2f}s ({_plugin_count} providers discovered in {_provider_time:.2f}s)")
@@ -470,6 +515,12 @@ async def lifespan(app: FastAPI):
470
  else:
471
  app.state.embedding_batcher = None
472
  logging.info("RotatingClient initialized (EmbeddingBatcher disabled).")
 
 
 
 
 
 
473
 
474
  yield
475
 
@@ -478,6 +529,10 @@ async def lifespan(app: FastAPI):
478
  await app.state.embedding_batcher.stop()
479
  await client.close()
480
 
 
 
 
 
481
  if app.state.embedding_batcher:
482
  logging.info("RotatingClient and EmbeddingBatcher closed.")
483
  else:
@@ -847,17 +902,73 @@ async def embeddings(
847
  def read_root():
848
  return {"Status": "API Key Proxy is running"}
849
 
850
- @app.get("/v1/models", response_model=ModelList)
851
  async def list_models(
 
852
  client: RotatingClient = Depends(get_rotating_client),
853
- _=Depends(verify_api_key)
 
854
  ):
855
  """
856
  Returns a list of available models in the OpenAI-compatible format.
 
 
 
 
857
  """
858
  model_ids = await client.get_all_available_models(grouped=False)
859
- model_cards = [ModelCard(id=model_id) for model_id in model_ids]
860
- return ModelList(data=model_cards)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
861
 
862
 
863
  @app.get("/v1/providers")
@@ -891,6 +1002,101 @@ async def token_count(
891
  logging.error(f"Token count failed: {e}")
892
  raise HTTPException(status_code=500, detail=str(e))
893
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
894
  if __name__ == "__main__":
895
  # Define ENV_FILE for onboarding checks
896
  ENV_FILE = Path.cwd() / ".env"
 
100
  from rotator_library import RotatingClient
101
  from rotator_library.credential_manager import CredentialManager
102
  from rotator_library.background_refresher import BackgroundRefresher
103
+ from rotator_library.model_info_service import init_model_info_service
104
  from proxy_app.request_logger import log_request_to_console
105
  from proxy_app.batch_manager import EmbeddingBatcher
106
  from proxy_app.detailed_logger import DetailedLogger
 
124
  user: Optional[str] = None
125
 
126
class ModelCard(BaseModel):
    """Minimal OpenAI-compatible model card returned by /v1/models."""
    id: str
    object: str = "model"
    # Unix epoch seconds; defaults to "now" since the proxy has no real creation date.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "Mirro-Proxy"
132
 
133
class ModelCapabilities(BaseModel):
    """Model capability flags reported in the enriched /v1/models response."""
    tool_choice: bool = False
    function_calling: bool = False
    reasoning: bool = False
    vision: bool = False
    # Defaults to True: most chat models accept a system prompt.
    system_messages: bool = True
    prompt_caching: bool = False
    assistant_prefill: bool = False
142
+
143
class EnrichedModelCard(BaseModel):
    """Extended model card with pricing, token limits, and capability flags."""
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "unknown"
    # Pricing (optional - may not be available for all models); USD per token
    input_cost_per_token: Optional[float] = None
    output_cost_per_token: Optional[float] = None
    cache_read_input_token_cost: Optional[float] = None
    cache_creation_input_token_cost: Optional[float] = None
    # Limits (optional)
    max_input_tokens: Optional[int] = None
    max_output_tokens: Optional[int] = None
    context_window: Optional[int] = None
    # Capabilities
    mode: str = "chat"
    supported_modalities: List[str] = Field(default_factory=lambda: ["text"])
    supported_output_modalities: List[str] = Field(default_factory=lambda: ["text"])
    capabilities: Optional[ModelCapabilities] = None
    # Debug info (optional)
    # NOTE(review): pydantic treats leading-underscore names as private
    # attributes rather than fields, so these will likely be dropped from
    # serialized output — confirm this is intended.
    _sources: Optional[List[str]] = None
    _match_type: Optional[str] = None

    class Config:
        extra = "allow"  # Allow extra fields from the service
169
+
170
class ModelList(BaseModel):
    """OpenAI-style list envelope for the minimal /v1/models response."""
    object: str = "list"
    data: List[ModelCard]
174
 
175
class EnrichedModelList(BaseModel):
    """List envelope for enriched model cards (pricing + capabilities)."""
    object: str = "list"
    data: List[EnrichedModelCard]
179
+
180
  # Calculate total loading time
181
  _elapsed = time.time() - _start_time
182
  print(f"✓ Server ready in {_elapsed:.2f}s ({_plugin_count} providers discovered in {_provider_time:.2f}s)")
 
515
  else:
516
  app.state.embedding_batcher = None
517
  logging.info("RotatingClient initialized (EmbeddingBatcher disabled).")
518
+
519
+ # Start model info service in background (fetches pricing/capabilities data)
520
+ # This runs asynchronously and doesn't block proxy startup
521
+ model_info_service = await init_model_info_service()
522
+ app.state.model_info_service = model_info_service
523
+ logging.info("Model info service started (fetching pricing data in background).")
524
 
525
  yield
526
 
 
529
  await app.state.embedding_batcher.stop()
530
  await client.close()
531
 
532
+ # Stop model info service
533
+ if hasattr(app.state, 'model_info_service') and app.state.model_info_service:
534
+ await app.state.model_info_service.stop()
535
+
536
  if app.state.embedding_batcher:
537
  logging.info("RotatingClient and EmbeddingBatcher closed.")
538
  else:
 
902
  def read_root():
903
  return {"Status": "API Key Proxy is running"}
904
 
905
@app.get("/v1/models")
async def list_models(
    request: Request,
    client: RotatingClient = Depends(get_rotating_client),
    _=Depends(verify_api_key),
    enriched: bool = True,
):
    """
    Returns a list of available models in the OpenAI-compatible format.

    Query Parameters:
        enriched: If True (default), returns detailed model info with pricing and capabilities.
                  If False, returns minimal OpenAI-compatible response.
    """
    model_ids = await client.get_all_available_models(grouped=False)

    if enriched and hasattr(request.app.state, 'model_info_service'):
        service = request.app.state.model_info_service
        if service.is_ready():
            # Enriched path: pricing/limits/capabilities merged in by the service.
            return {"object": "list", "data": service.enrich_model_list(model_ids)}

    # Minimal OpenAI-style cards when enrichment is disabled or unavailable.
    now = int(time.time())
    minimal_cards = [
        {"id": mid, "object": "model", "created": now, "owned_by": "Mirro-Proxy"}
        for mid in model_ids
    ]
    return {"object": "list", "data": minimal_cards}
931
+
932
+
933
@app.get("/v1/models/{model_id:path}")
async def get_model(
    model_id: str,
    request: Request,
    _=Depends(verify_api_key),
):
    """
    Returns detailed information about a specific model.

    Path Parameters:
        model_id: The model ID (e.g., "anthropic/claude-3-opus", "openrouter/openai/gpt-4")
    """
    if hasattr(request.app.state, 'model_info_service'):
        service = request.app.state.model_info_service
        if service.is_ready():
            record = service.get_model_info(model_id)
            if record:
                return record.to_dict()

    # Service not ready or model unknown: synthesize a minimal card. The owner
    # is guessed from the ID's leading path segment when one is present.
    owner = model_id.split("/")[0] if "/" in model_id else "unknown"
    return {
        "id": model_id,
        "object": "model",
        "created": int(time.time()),
        "owned_by": owner,
    }
959
+
960
+
961
@app.get("/v1/model-info/stats")
async def model_info_stats(
    request: Request,
    _=Depends(verify_api_key),
):
    """
    Returns statistics about the model info service (for monitoring/debugging).
    """
    state = request.app.state
    if not hasattr(state, 'model_info_service'):
        return {"error": "Model info service not initialized"}
    return state.model_info_service.get_stats()
972
 
973
 
974
  @app.get("/v1/providers")
 
1002
  logging.error(f"Token count failed: {e}")
1003
  raise HTTPException(status_code=500, detail=str(e))
1004
 
1005
+
1006
@app.post("/v1/cost-estimate")
async def cost_estimate(
    request: Request,
    _=Depends(verify_api_key)
):
    """
    Estimates the cost for a request based on token counts and model pricing.

    Request body:
        {
            "model": "anthropic/claude-3-opus",
            "prompt_tokens": 1000,
            "completion_tokens": 500,
            "cache_read_tokens": 0,       # optional
            "cache_creation_tokens": 0    # optional
        }

    Returns:
        {
            "model": "anthropic/claude-3-opus",
            "cost": 0.0375,
            "currency": "USD",
            "pricing": {
                "input_cost_per_token": 0.000015,
                "output_cost_per_token": 0.000075
            },
            "source": "model_info_service"  # or "litellm_fallback"
        }

    Raises:
        HTTPException 400 when 'model' is missing or a token count is invalid.
        HTTPException 500 on unexpected internal errors.
    """
    try:
        data = await request.json()
        model = data.get("model")
        prompt_tokens = data.get("prompt_tokens", 0)
        completion_tokens = data.get("completion_tokens", 0)
        cache_read_tokens = data.get("cache_read_tokens", 0)
        cache_creation_tokens = data.get("cache_creation_tokens", 0)

        if not model:
            raise HTTPException(status_code=400, detail="'model' is required.")

        # Validate token counts up front: negative or non-numeric values would
        # otherwise produce nonsense costs (or a TypeError) during arithmetic.
        token_fields = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "cache_read_tokens": cache_read_tokens,
            "cache_creation_tokens": cache_creation_tokens,
        }
        for field_name, value in token_fields.items():
            if not isinstance(value, (int, float)) or value < 0:
                raise HTTPException(
                    status_code=400,
                    detail=f"'{field_name}' must be a non-negative number.",
                )

        result = {
            "model": model,
            "cost": None,
            "currency": "USD",
            "pricing": {},
            "source": None
        }

        # Try model info service first
        if hasattr(request.app.state, 'model_info_service'):
            model_info_service = request.app.state.model_info_service
            if model_info_service.is_ready():
                cost = model_info_service.calculate_cost(
                    model, prompt_tokens, completion_tokens,
                    cache_read_tokens, cache_creation_tokens
                )
                if cost is not None:
                    cost_info = model_info_service.get_cost_info(model)
                    result["cost"] = cost
                    result["pricing"] = cost_info or {}
                    result["source"] = "model_info_service"
                    return result

        # Fallback to litellm's static pricing table
        try:
            import litellm
            model_info = litellm.get_model_info(model)
            input_cost = model_info.get("input_cost_per_token", 0)
            output_cost = model_info.get("output_cost_per_token", 0)

            if input_cost or output_cost:
                cost = (prompt_tokens * input_cost) + (completion_tokens * output_cost)
                result["cost"] = cost
                result["pricing"] = {
                    "input_cost_per_token": input_cost,
                    "output_cost_per_token": output_cost
                }
                result["source"] = "litellm_fallback"
                return result
        except Exception as fallback_err:
            # Best-effort fallback: record why it failed instead of hiding it.
            logging.debug(f"litellm pricing fallback failed for {model}: {fallback_err}")

        result["source"] = "unknown"
        result["error"] = "Pricing data not available for this model"
        return result

    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Cost estimate failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
1098
+
1099
+
1100
  if __name__ == "__main__":
1101
  # Define ENV_FILE for onboarding checks
1102
  ENV_FILE = Path.cwd() / ".env"
src/rotator_library/__init__.py CHANGED
@@ -7,12 +7,19 @@ from .client import RotatingClient
7
  if TYPE_CHECKING:
8
  from .providers import PROVIDER_PLUGINS
9
  from .providers.provider_interface import ProviderInterface
 
10
 
11
- __all__ = ["RotatingClient", "PROVIDER_PLUGINS"]
12
 
13
  def __getattr__(name):
14
- """Lazy-load PROVIDER_PLUGINS to speed up module import."""
15
  if name == "PROVIDER_PLUGINS":
16
  from .providers import PROVIDER_PLUGINS
17
  return PROVIDER_PLUGINS
 
 
 
 
 
 
18
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
7
  if TYPE_CHECKING:
8
  from .providers import PROVIDER_PLUGINS
9
  from .providers.provider_interface import ProviderInterface
10
+ from .model_info_service import ModelInfoService, ModelInfo
11
 
12
+ __all__ = ["RotatingClient", "PROVIDER_PLUGINS", "ModelInfoService", "ModelInfo"]
13
 
14
def __getattr__(name):
    """Resolve heavy attributes lazily so importing the package stays cheap."""
    if name == "PROVIDER_PLUGINS":
        from .providers import PROVIDER_PLUGINS
        return PROVIDER_PLUGINS
    if name in ("ModelInfoService", "ModelInfo"):
        # Both live in the same module; defer its import until first access.
        from . import model_info_service
        return getattr(model_info_service, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
src/rotator_library/model_info_service.py ADDED
@@ -0,0 +1,946 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Model Registry
3
+
4
+ Provides aggregated model metadata from external catalogs (OpenRouter, Models.dev)
5
+ for pricing calculations and the /v1/models endpoint.
6
+
7
+ Data retrieval happens asynchronously post-startup to keep initialization fast.
8
+ """
9
+
10
+ import asyncio
11
+ import json
12
+ import logging
13
+ import os
14
+ import time
15
+ from dataclasses import dataclass, field
16
+ from typing import Any, Dict, List, Optional, Tuple
17
+ from urllib.request import Request, urlopen
18
+ from urllib.error import URLError
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ # ============================================================================
24
+ # Data Structures
25
+ # ============================================================================
26
+
27
@dataclass
class ModelPricing:
    """Token-level pricing information (USD per token; None when unknown)."""
    prompt: Optional[float] = None          # cost per input token
    completion: Optional[float] = None      # cost per output token
    cached_input: Optional[float] = None    # cost per cache-read input token
    cache_write: Optional[float] = None     # cost per cache-creation input token
34
+
35
+
36
@dataclass
class ModelLimits:
    """Context and output token limits (None when unknown)."""
    context_window: Optional[int] = None  # max input/context tokens
    max_output: Optional[int] = None      # max completion tokens
41
+
42
+
43
@dataclass
class ModelCapabilities:
    """Feature flags for model capabilities."""
    tools: bool = False
    functions: bool = False
    reasoning: bool = False
    vision: bool = False
    # Defaults to True: most chat models accept a system prompt.
    system_prompt: bool = True
    caching: bool = False
    prefill: bool = False
53
+
54
+
55
@dataclass
class ModelMetadata:
    """Complete model information record aggregated from external catalogs."""

    model_id: str
    display_name: str = ""
    provider: str = ""
    category: str = "chat"  # chat, embedding, image, audio

    pricing: ModelPricing = field(default_factory=ModelPricing)
    limits: ModelLimits = field(default_factory=ModelLimits)
    capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)

    input_types: List[str] = field(default_factory=lambda: ["text"])
    output_types: List[str] = field(default_factory=lambda: ["text"])

    timestamp: int = field(default_factory=lambda: int(time.time()))
    origin: str = ""
    match_quality: str = "unknown"

    def as_api_response(self) -> Dict[str, Any]:
        """Format for the OpenAI-compatible /v1/models response."""
        out: Dict[str, Any] = self.as_minimal()

        # Pricing keys are emitted only when the value is known.
        optional_prices = (
            ("input_cost_per_token", self.pricing.prompt),
            ("output_cost_per_token", self.pricing.completion),
            ("cache_read_input_token_cost", self.pricing.cached_input),
            ("cache_creation_input_token_cost", self.pricing.cache_write),
        )
        for key, value in optional_prices:
            if value is not None:
                out[key] = value

        # Token limits (skipped when unknown or zero).
        if self.limits.context_window:
            out["max_input_tokens"] = self.limits.context_window
            out["context_window"] = self.limits.context_window
        if self.limits.max_output:
            out["max_output_tokens"] = self.limits.max_output

        # Category and modalities.
        out["mode"] = self.category
        out["supported_modalities"] = self.input_types
        out["supported_output_modalities"] = self.output_types

        # Capability flags mapped to their public field names.
        caps = self.capabilities
        out["capabilities"] = {
            "tool_choice": caps.tools,
            "function_calling": caps.functions,
            "reasoning": caps.reasoning,
            "vision": caps.vision,
            "system_messages": caps.system_prompt,
            "prompt_caching": caps.caching,
            "assistant_prefill": caps.prefill,
        }

        # Debug metadata, included only when an origin was recorded.
        if self.origin:
            out["_sources"] = [self.origin]
            out["_match_type"] = self.match_quality

        return out

    def as_minimal(self) -> Dict[str, Any]:
        """Minimal OpenAI-format card."""
        return {
            "id": self.model_id,
            "object": "model",
            "created": self.timestamp,
            "owned_by": self.provider or "proxy",
        }

    def to_dict(self) -> Dict[str, Any]:
        """Alias for as_api_response() - backward compatibility."""
        return self.as_api_response()

    def to_openai_format(self) -> Dict[str, Any]:
        """Alias for as_minimal() - backward compatibility."""
        return self.as_minimal()

    # ------------------------------------------------------------------
    # Backward-compatible property aliases (legacy attribute names)
    # ------------------------------------------------------------------
    @property
    def id(self) -> str:
        return self.model_id

    @property
    def name(self) -> str:
        return self.display_name

    @property
    def input_cost_per_token(self) -> Optional[float]:
        return self.pricing.prompt

    @property
    def output_cost_per_token(self) -> Optional[float]:
        return self.pricing.completion

    @property
    def cache_read_input_token_cost(self) -> Optional[float]:
        return self.pricing.cached_input

    @property
    def cache_creation_input_token_cost(self) -> Optional[float]:
        return self.pricing.cache_write

    @property
    def max_input_tokens(self) -> Optional[int]:
        return self.limits.context_window

    @property
    def max_output_tokens(self) -> Optional[int]:
        return self.limits.max_output

    @property
    def mode(self) -> str:
        return self.category

    @property
    def supported_modalities(self) -> List[str]:
        return self.input_types

    @property
    def supported_output_modalities(self) -> List[str]:
        return self.output_types

    @property
    def supports_tool_choice(self) -> bool:
        return self.capabilities.tools

    @property
    def supports_function_calling(self) -> bool:
        return self.capabilities.functions

    @property
    def supports_reasoning(self) -> bool:
        return self.capabilities.reasoning

    @property
    def supports_vision(self) -> bool:
        return self.capabilities.vision

    @property
    def supports_system_messages(self) -> bool:
        return self.capabilities.system_prompt

    @property
    def supports_prompt_caching(self) -> bool:
        return self.capabilities.caching

    @property
    def supports_assistant_prefill(self) -> bool:
        return self.capabilities.prefill

    @property
    def litellm_provider(self) -> str:
        return self.provider

    @property
    def created(self) -> int:
        return self.timestamp

    @property
    def _sources(self) -> List[str]:
        return [self.origin] if self.origin else []

    @property
    def _match_type(self) -> str:
        return self.match_quality
229
+
230
+
231
+ # ============================================================================
232
+ # Data Source Adapters
233
+ # ============================================================================
234
+
235
class DataSourceAdapter:
    """Abstract base for external model-catalog sources."""

    source_name: str = "unknown"
    endpoint: str = ""

    def fetch(self) -> Dict[str, Dict]:
        """Retrieve and normalize data. Returns {model_id: raw_data}."""
        raise NotImplementedError

    def _http_get(self, url: str, timeout: int = 30) -> Any:
        """GET *url* with a standard User-Agent and decode the JSON body."""
        http_request = Request(url, headers={"User-Agent": "ModelRegistry/1.0"})
        with urlopen(http_request, timeout=timeout) as response:
            payload = response.read()
        return json.loads(payload.decode("utf-8"))
250
+
251
+
252
class OpenRouterAdapter(DataSourceAdapter):
    """Fetches model data from OpenRouter's public API."""

    source_name = "openrouter"
    endpoint = "https://openrouter.ai/api/v1/models"

    def fetch(self) -> Dict[str, Dict]:
        """Download the OpenRouter catalog.

        Returns {"openrouter/<id>": normalized_data}.
        Raises ConnectionError when the endpoint is unreachable or returns bad JSON.
        """
        try:
            raw = self._http_get(self.endpoint)
            entries = raw.get("data", [])

            catalog = {}
            for entry in entries:
                mid = entry.get("id")
                if not mid:
                    continue

                full_id = f"openrouter/{mid}"
                catalog[full_id] = self._normalize(entry)

            return catalog
        except (URLError, json.JSONDecodeError, TimeoutError) as err:
            raise ConnectionError(f"OpenRouter unavailable: {err}") from err

    @staticmethod
    def _price(value: Any) -> float:
        """Coerce an API price (str/number/None) to float; malformed values become 0.0.

        Fix: the previous bare float() raised TypeError when a pricing field
        was present but null in the API payload.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    def _normalize(self, raw: Dict) -> Dict:
        """Transform OpenRouter schema to internal format."""
        # API sub-objects may be present-but-null; treat null the same as missing.
        prices = raw.get("pricing") or {}
        arch = raw.get("architecture") or {}
        top = raw.get("top_provider") or {}
        params = raw.get("supported_parameters") or []

        tokenizer = arch.get("tokenizer") or ""
        category = "embedding" if "embedding" in tokenizer.lower() else "chat"

        input_modalities = arch.get("input_modalities") or ["text"]

        return {
            "name": raw.get("name", ""),
            "prompt_cost": self._price(prices.get("prompt", 0)),
            "completion_cost": self._price(prices.get("completion", 0)),
            # 0.0 (free/unknown) is normalized to None for cache pricing.
            "cache_read_cost": self._price(prices.get("input_cache_read", 0)) or None,
            "context": top.get("context_length", 0),
            "max_out": top.get("max_completion_tokens", 0),
            "category": category,
            "inputs": input_modalities,
            "outputs": arch.get("output_modalities", ["text"]),
            "has_tools": "tool_choice" in params or "tools" in params,
            "has_functions": "tools" in params or "function_calling" in params,
            "has_reasoning": "reasoning" in params,
            "has_vision": "image" in input_modalities,
            "provider": "openrouter",
            "source": "openrouter",
        }
303
+
304
+
305
class ModelsDevAdapter(DataSourceAdapter):
    """Fetches model data from Models.dev catalog."""

    source_name = "modelsdev"
    endpoint = "https://models.dev/api.json"

    def __init__(self, skip_providers: Optional[List[str]] = None):
        # Providers to exclude from the catalog entirely.
        self.skip_providers = skip_providers or []

    def fetch(self) -> Dict[str, Dict]:
        """Download the Models.dev catalog.

        Returns {"<provider>/<model>": normalized_data}.
        Raises ConnectionError when the endpoint is unreachable or returns bad JSON.
        """
        try:
            raw = self._http_get(self.endpoint)

            catalog = {}
            for provider_key, provider_block in raw.items():
                if not isinstance(provider_block, dict):
                    continue
                if provider_key in self.skip_providers:
                    continue

                models_block = provider_block.get("models", {})
                if not isinstance(models_block, dict):
                    continue

                for model_key, model_data in models_block.items():
                    if not isinstance(model_data, dict):
                        continue

                    full_id = f"{provider_key}/{model_key}"
                    catalog[full_id] = self._normalize(model_data, provider_key)

            return catalog
        except (URLError, json.JSONDecodeError, TimeoutError) as err:
            raise ConnectionError(f"Models.dev unavailable: {err}") from err

    @staticmethod
    def _cost(value: Any, divisor: float) -> Optional[float]:
        """Safely convert an API cost (number/str/None) to per-token float.

        Fix: the previous bare float() raised TypeError when a cost field was
        present but null in the payload. Returns None for missing/malformed.
        """
        try:
            return float(value) / divisor
        except (TypeError, ValueError):
            return None

    def _normalize(self, raw: Dict, provider_key: str) -> Dict:
        """Transform Models.dev schema to internal format."""
        # Sub-objects may be present-but-null; treat null the same as missing.
        costs = raw.get("cost") or {}
        mods = raw.get("modalities") or {}
        lims = raw.get("limit") or {}

        outputs = mods.get("output", ["text"])
        if "image" in outputs:
            category = "image"
        elif "audio" in outputs:
            category = "audio"
        else:
            category = "chat"

        # Models.dev uses per-million pricing, convert to per-token
        divisor = 1_000_000

        return {
            "name": raw.get("name", ""),
            "prompt_cost": self._cost(costs.get("input", 0), divisor) or 0.0,
            "completion_cost": self._cost(costs.get("output", 0), divisor) or 0.0,
            # 0.0 / missing cache pricing is normalized to None.
            "cache_read_cost": self._cost(costs.get("cache_read"), divisor) or None,
            "cache_write_cost": self._cost(costs.get("cache_write"), divisor) or None,
            "context": lims.get("context", 0),
            "max_out": lims.get("output", 0),
            "category": category,
            "inputs": mods.get("input", ["text"]),
            "outputs": outputs,
            "has_tools": raw.get("tool_call", False),
            "has_functions": raw.get("tool_call", False),
            "has_reasoning": raw.get("reasoning", False),
            "has_vision": "image" in mods.get("input", []),
            "provider": provider_key,
            "source": "modelsdev",
        }
378
+
379
+
380
+ # ============================================================================
381
+ # Lookup Index
382
+ # ============================================================================
383
+
384
class ModelIndex:
    """Fast lookup structure for model ID resolution."""

    def __init__(self):
        self._by_full_id: Dict[str, str] = {}  # normalized_id -> canonical_id
        self._by_suffix: Dict[str, List[str]] = {}  # short_name -> [canonical_ids]

    def clear(self):
        """Reset the index."""
        self._by_full_id.clear()
        self._by_suffix.clear()

    def entry_count(self) -> int:
        """Return total number of suffix index entries."""
        return sum(len(ids) for ids in self._by_suffix.values())

    def add(self, canonical_id: str):
        """Index a canonical model ID for direct, partial, and tail lookups."""
        self._by_full_id[canonical_id] = canonical_id

        segments = canonical_id.split("/")
        if len(segments) >= 2:
            # Everything after the first segment, e.g. "openai/gpt-4".
            partial = "/".join(segments[1:])
            self._by_suffix.setdefault(partial, []).append(canonical_id)
        if len(segments) >= 3:
            # Final segment only, e.g. "gpt-4".
            self._by_suffix.setdefault(segments[-1], []).append(canonical_id)

    def resolve(self, query: str) -> List[str]:
        """Find all canonical IDs matching a query (exact first, then suffix)."""
        # Direct match
        if query in self._by_full_id:
            return [self._by_full_id[query]]

        # Try with openrouter prefix
        prefixed = f"openrouter/{query}"
        if prefixed in self._by_full_id:
            return [self._by_full_id[prefixed]]

        # Derive candidate suffix keys. dict.fromkeys dedupes while keeping
        # order: a two-part query like "openai/gpt-4" would otherwise yield
        # the same key twice and scan its bucket twice.
        parts = query.split("/")
        if len(parts) >= 2:
            search_keys = dict.fromkeys(("/".join(parts[1:]), parts[-1]))
        else:
            search_keys = {query: None}

        # Collect matches, preserving first-seen order across keys.
        matches: List[str] = []
        seen = set()
        for key in search_keys:
            for cid in self._by_suffix.get(key, []):
                if cid not in seen:
                    seen.add(cid)
                    matches.append(cid)

        return matches
444
+
445
+
446
+ # ============================================================================
447
+ # Data Merger
448
+ # ============================================================================
449
+
450
class DataMerger:
    """Combines data from multiple sources into a unified ModelMetadata."""

    @staticmethod
    def single(model_id: str, data: Dict, origin: str, quality: str) -> ModelMetadata:
        """Create ModelMetadata from a single source record.

        Args:
            model_id: Canonical model identifier being resolved.
            data: Raw normalized source record.
            origin: Provenance tag, e.g. "openrouter:exact:<id>".
            quality: Match quality ("exact" or "fuzzy").
        """
        return ModelMetadata(
            model_id=model_id,
            display_name=data.get("name", model_id),
            provider=data.get("provider", ""),
            category=data.get("category", "chat"),
            pricing=ModelPricing(
                prompt=data.get("prompt_cost"),
                completion=data.get("completion_cost"),
                cached_input=data.get("cache_read_cost"),
                cache_write=data.get("cache_write_cost"),
            ),
            limits=ModelLimits(
                # `or None` collapses falsy placeholders (0, "") to "unknown".
                context_window=data.get("context") or None,
                max_output=data.get("max_out") or None,
            ),
            capabilities=ModelCapabilities(
                tools=data.get("has_tools", False),
                functions=data.get("has_functions", False),
                reasoning=data.get("has_reasoning", False),
                vision=data.get("has_vision", False),
            ),
            input_types=data.get("inputs", ["text"]),
            output_types=data.get("outputs", ["text"]),
            origin=origin,
            match_quality=quality,
        )

    @staticmethod
    def combine(model_id: str, records: List[Tuple[Dict, str]], quality: str) -> ModelMetadata:
        """Merge multiple source records into one ModelMetadata.

        Pricing fields are averaged, limits take the most frequent value,
        capabilities use OR logic, modalities are unioned, and category is
        decided by majority vote.
        """
        if len(records) == 1:
            data, origin = records[0]
            return DataMerger.single(model_id, data, origin, quality)

        # Aggregate pricing - average across sources that report a cost.
        # NOTE(review): truthiness filtering also drops explicit 0.0 prices
        # (free models) from the average — confirm that is intended.
        prompt_costs = [r[0]["prompt_cost"] for r in records if r[0].get("prompt_cost")]
        comp_costs = [r[0]["completion_cost"] for r in records if r[0].get("completion_cost")]
        cache_read_costs = [r[0]["cache_read_cost"] for r in records if r[0].get("cache_read_cost")]
        # Bug fix: cache-write pricing was previously dropped entirely when
        # merging multiple records (single() kept it, combine() did not).
        cache_write_costs = [r[0]["cache_write_cost"] for r in records if r[0].get("cache_write_cost")]

        # Aggregate limits - use most common value across sources.
        contexts = [r[0]["context"] for r in records if r[0].get("context")]
        max_outs = [r[0]["max_out"] for r in records if r[0].get("max_out")]

        # Capabilities - OR logic (any source supporting = supported).
        has_tools = any(r[0].get("has_tools") for r in records)
        has_funcs = any(r[0].get("has_functions") for r in records)
        has_reason = any(r[0].get("has_reasoning") for r in records)
        has_vis = any(r[0].get("has_vision") for r in records)

        # Modalities - union across sources.
        all_inputs = set()
        all_outputs = set()
        for r in records:
            all_inputs.update(r[0].get("inputs", ["text"]))
            all_outputs.update(r[0].get("outputs", ["text"]))

        # Category - majority vote.
        categories = [r[0].get("category", "chat") for r in records]
        category = max(set(categories), key=categories.count)

        # Name - first record that supplies a non-empty name wins.
        name = model_id
        for r in records:
            if r[0].get("name"):
                name = r[0]["name"]
                break

        origins = [r[1] for r in records]

        return ModelMetadata(
            model_id=model_id,
            display_name=name,
            provider=records[0][0].get("provider", ""),
            category=category,
            pricing=ModelPricing(
                prompt=DataMerger._avg(prompt_costs),
                completion=DataMerger._avg(comp_costs),
                cached_input=DataMerger._avg(cache_read_costs),
                cache_write=DataMerger._avg(cache_write_costs),
            ),
            limits=ModelLimits(
                context_window=DataMerger._mode(contexts),
                max_output=DataMerger._mode(max_outs),
            ),
            capabilities=ModelCapabilities(
                tools=has_tools,
                functions=has_funcs,
                reasoning=has_reason,
                vision=has_vis,
            ),
            input_types=list(all_inputs) or ["text"],
            output_types=list(all_outputs) or ["text"],
            origin=",".join(origins),
            match_quality=quality,
        )

    @staticmethod
    def _avg(values: List[float]) -> Optional[float]:
        """Arithmetic mean of *values*, or None for an empty list."""
        return sum(values) / len(values) if values else None

    @staticmethod
    def _mode(values: List[int]) -> Optional[int]:
        """Return the most frequent value, or None for an empty list."""
        if not values:
            return None
        return max(set(values), key=values.count)
557
+
558
+
559
+ # ============================================================================
560
+ # Main Registry Service
561
+ # ============================================================================
562
+
563
class ModelRegistry:
    """
    Central registry for model metadata from external catalogs.

    Fetches raw model records from each configured adapter on a background
    task, indexes them for exact/fuzzy lookup, and exposes pricing,
    capability, and cost-estimation queries used by the API endpoints.
    """

    # Used when neither the constructor argument nor the
    # MODEL_INFO_REFRESH_INTERVAL env var supplies a cadence.
    REFRESH_INTERVAL_DEFAULT = 6 * 60 * 60  # 6 hours

    def __init__(
        self,
        refresh_seconds: Optional[int] = None,
        skip_modelsdev_providers: Optional[List[str]] = None,
    ):
        """
        Args:
            refresh_seconds: Refresh cadence in seconds; takes precedence
                over the MODEL_INFO_REFRESH_INTERVAL environment variable.
            skip_modelsdev_providers: Provider names the Models.dev adapter
                should skip.
        """
        interval_env = os.getenv("MODEL_INFO_REFRESH_INTERVAL")
        env_interval: Optional[int] = None
        if interval_env:
            try:
                env_interval = int(interval_env)
            except ValueError:
                # A malformed env var must not prevent proxy startup.
                logger.warning(
                    "Invalid MODEL_INFO_REFRESH_INTERVAL=%r; using default",
                    interval_env,
                )
        self._refresh_interval = (
            refresh_seconds or env_interval or self.REFRESH_INTERVAL_DEFAULT
        )

        # Configure adapters (one per upstream catalog).
        self._adapters: List[DataSourceAdapter] = [
            OpenRouterAdapter(),
            ModelsDevAdapter(skip_providers=skip_modelsdev_providers or []),
        ]

        # Raw per-source data stores, keyed by canonical model ID.
        self._openrouter_store: Dict[str, Dict] = {}
        self._modelsdev_store: Dict[str, Dict] = {}

        # Lookup infrastructure: suffix index plus memoized lookup results.
        self._index = ModelIndex()
        self._result_cache: Dict[str, ModelMetadata] = {}

        # Async coordination.
        self._ready = asyncio.Event()  # set after the first load attempt
        self._mutex = asyncio.Lock()   # guards store/index swaps
        self._worker: Optional[asyncio.Task] = None
        self._last_refresh: float = 0

    # ---------- Lifecycle ----------

    async def start(self):
        """Begin the background refresh worker (idempotent)."""
        if self._worker is None:
            self._worker = asyncio.create_task(self._refresh_worker())
            logger.info(
                "ModelRegistry started (refresh every %ds)",
                self._refresh_interval
            )

    async def stop(self):
        """Cancel the background worker and wait for it to finish."""
        if self._worker:
            self._worker.cancel()
            try:
                await self._worker
            except asyncio.CancelledError:
                pass
            self._worker = None
            logger.info("ModelRegistry stopped")

    async def await_ready(self, timeout_secs: float = 30.0) -> bool:
        """Block until the initial data load completes.

        Returns:
            True if the registry became ready within the timeout,
            False otherwise.
        """
        try:
            await asyncio.wait_for(self._ready.wait(), timeout=timeout_secs)
            return True
        except asyncio.TimeoutError:
            logger.warning("ModelRegistry ready timeout after %.1fs", timeout_secs)
            return False

    @property
    def is_ready(self) -> bool:
        """True once the first load attempt has finished."""
        return self._ready.is_set()

    # ---------- Background Worker ----------

    async def _refresh_worker(self):
        """Periodic refresh loop: initial load, then refresh on an interval."""
        try:
            await self._load_all_sources()
        except Exception as ex:
            # Bug fix: a failing initial load used to kill the worker before
            # _ready was ever set, forcing await_ready() callers to block to
            # timeout. Log, mark ready (with empty stores), and let the next
            # scheduled cycle retry.
            logger.error("Initial registry load failed: %s", ex)
        self._ready.set()

        while True:
            try:
                await asyncio.sleep(self._refresh_interval)
                logger.info("Scheduled registry refresh...")
                await self._load_all_sources()
                logger.info("Registry refresh complete")
            except asyncio.CancelledError:
                break
            except Exception as ex:
                logger.error("Registry refresh error: %s", ex)

    async def _load_all_sources(self):
        """Fetch from all adapters concurrently and swap in the results."""
        # Adapter fetches are blocking (synchronous HTTP), so run them in the
        # default thread pool. get_running_loop() is the non-deprecated way
        # to obtain the loop from inside a coroutine.
        loop = asyncio.get_running_loop()

        tasks = [
            loop.run_in_executor(None, adapter.fetch)
            for adapter in self._adapters
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        async with self._mutex:
            for adapter, result in zip(self._adapters, results):
                if isinstance(result, Exception):
                    # One failing source must not discard the other's data.
                    logger.error("%s fetch failed: %s", adapter.source_name, result)
                    continue

                if adapter.source_name == "openrouter":
                    self._openrouter_store = result
                    logger.info("OpenRouter: %d models loaded", len(result))
                elif adapter.source_name == "modelsdev":
                    self._modelsdev_store = result
                    logger.info("Models.dev: %d models loaded", len(result))

            self._rebuild_index()
            self._last_refresh = time.time()

    def _rebuild_index(self):
        """Reconstruct the lookup index (and drop memoized lookups)."""
        self._index.clear()
        self._result_cache.clear()

        for model_id in self._openrouter_store:
            self._index.add(model_id)

        for model_id in self._modelsdev_store:
            self._index.add(model_id)

    # ---------- Query API ----------

    def lookup(self, model_id: str) -> Optional[ModelMetadata]:
        """
        Retrieve model metadata by ID.

        Matching strategy:
        1. Exact match against known IDs
        2. Fuzzy match by model name suffix
        3. Aggregate if multiple sources match

        Results are memoized until the next index rebuild.
        """
        if model_id in self._result_cache:
            return self._result_cache[model_id]

        metadata = self._resolve_model(model_id)
        if metadata:
            self._result_cache[model_id] = metadata
        return metadata

    def _resolve_model(self, model_id: str) -> Optional[ModelMetadata]:
        """Build ModelMetadata by matching source data; None if unmatched."""
        records: List[Tuple[Dict, str]] = []
        quality = "none"

        # Check exact matches first (OpenRouter keys carry the prefix).
        or_key = f"openrouter/{model_id}" if not model_id.startswith("openrouter/") else model_id
        if or_key in self._openrouter_store:
            records.append((self._openrouter_store[or_key], f"openrouter:exact:{or_key}"))
            quality = "exact"

        if model_id in self._modelsdev_store:
            records.append((self._modelsdev_store[model_id], f"modelsdev:exact:{model_id}"))
            quality = "exact"

        # Fall back to suffix-index search.
        if not records:
            candidates = self._index.resolve(model_id)
            for cid in candidates:
                if cid in self._openrouter_store:
                    records.append((self._openrouter_store[cid], f"openrouter:fuzzy:{cid}"))
                elif cid in self._modelsdev_store:
                    records.append((self._modelsdev_store[cid], f"modelsdev:fuzzy:{cid}"))

            if records:
                quality = "fuzzy"

        if not records:
            return None

        return DataMerger.combine(model_id, records, quality)

    def get_pricing(self, model_id: str) -> Optional[Dict[str, float]]:
        """Extract just pricing info (per-token USD rates).

        Returns None when the model is unknown or has no pricing at all.
        """
        meta = self.lookup(model_id)
        if not meta:
            return None

        result = {}
        if meta.pricing.prompt is not None:
            result["input_cost_per_token"] = meta.pricing.prompt
        if meta.pricing.completion is not None:
            result["output_cost_per_token"] = meta.pricing.completion
        if meta.pricing.cached_input is not None:
            result["cache_read_input_token_cost"] = meta.pricing.cached_input
        if meta.pricing.cache_write is not None:
            result["cache_creation_input_token_cost"] = meta.pricing.cache_write

        return result if result else None

    def compute_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
        cache_hit_tokens: int = 0,
        cache_miss_tokens: int = 0,
    ) -> Optional[float]:
        """
        Calculate total request cost in USD.

        NOTE(review): cache token costs are added ON TOP of the full
        input-token cost; confirm cached tokens should not instead be
        excluded from `input_tokens`.

        Returns:
            Total cost, or None if input/output pricing is unavailable.
        """
        pricing = self.get_pricing(model_id)
        if not pricing:
            return None

        in_rate = pricing.get("input_cost_per_token")
        out_rate = pricing.get("output_cost_per_token")

        if in_rate is None or out_rate is None:
            return None

        total = (input_tokens * in_rate) + (output_tokens * out_rate)

        cache_read_rate = pricing.get("cache_read_input_token_cost")
        if cache_read_rate and cache_hit_tokens:
            total += cache_hit_tokens * cache_read_rate

        cache_write_rate = pricing.get("cache_creation_input_token_cost")
        if cache_write_rate and cache_miss_tokens:
            total += cache_miss_tokens * cache_write_rate

        return total

    def enrich_models(self, model_ids: List[str]) -> List[Dict[str, Any]]:
        """
        Attach metadata to a list of model IDs.

        Used by the /v1/models endpoint. Unknown models get a minimal
        OpenAI-style fallback entry.
        """
        enriched = []
        for mid in model_ids:
            meta = self.lookup(mid)
            if meta:
                enriched.append(meta.as_api_response())
            else:
                # Fallback minimal entry
                enriched.append({
                    "id": mid,
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": mid.split("/")[0] if "/" in mid else "unknown",
                })
        return enriched

    def all_raw_models(self) -> Dict[str, Dict]:
        """Return all raw source data (for debugging).

        Models.dev entries overwrite OpenRouter entries on key collision.
        """
        combined = {}
        combined.update(self._openrouter_store)
        combined.update(self._modelsdev_store)
        return combined

    def diagnostics(self) -> Dict[str, Any]:
        """Return service health/stats for the /v1/model-info/stats endpoint."""
        return {
            "ready": self._ready.is_set(),
            "last_refresh": self._last_refresh,
            "openrouter_count": len(self._openrouter_store),
            "modelsdev_count": len(self._modelsdev_store),
            "cached_lookups": len(self._result_cache),
            "index_entries": self._index.entry_count(),
            "refresh_interval": self._refresh_interval,
        }

    # ---------- Backward Compatibility Methods ----------

    def get_model_info(self, model_id: str) -> Optional[ModelMetadata]:
        """Alias for lookup() - backward compatibility."""
        return self.lookup(model_id)

    def get_cost_info(self, model_id: str) -> Optional[Dict[str, float]]:
        """Alias for get_pricing() - backward compatibility."""
        return self.get_pricing(model_id)

    def calculate_cost(
        self,
        model_id: str,
        prompt_tokens: int,
        completion_tokens: int,
        cache_read_tokens: int = 0,
        cache_creation_tokens: int = 0,
    ) -> Optional[float]:
        """Alias for compute_cost() - backward compatibility."""
        return self.compute_cost(
            model_id, prompt_tokens, completion_tokens,
            cache_read_tokens, cache_creation_tokens
        )

    def enrich_model_list(self, model_ids: List[str]) -> List[Dict[str, Any]]:
        """Alias for enrich_models() - backward compatibility."""
        return self.enrich_models(model_ids)

    def get_all_source_models(self) -> Dict[str, Dict]:
        """Alias for all_raw_models() - backward compatibility."""
        return self.all_raw_models()

    def get_stats(self) -> Dict[str, Any]:
        """Alias for diagnostics() - backward compatibility."""
        return self.diagnostics()

    def wait_for_ready(self, timeout: float = 30.0):
        """Compatibility alias for await_ready().

        NOTE: despite the historical name this is NOT synchronous — it
        returns the await_ready() coroutine, which the caller must await.
        """
        return self.await_ready(timeout)
876
+
877
+
878
# ============================================================================
# Backward Compatibility Layer
# ============================================================================

# Legacy class names kept as aliases so existing imports continue to work.
ModelInfo = ModelMetadata
ModelInfoService = ModelRegistry

# Process-wide singleton; lazily instantiated by get_model_info_service().
_registry_instance: Optional[ModelRegistry] = None
888
+
889
+
890
def get_model_info_service() -> ModelRegistry:
    """Lazily create and return the process-wide ModelRegistry singleton."""
    global _registry_instance
    registry = _registry_instance
    if registry is None:
        registry = ModelRegistry()
        _registry_instance = registry
    return registry
896
+
897
+
898
async def init_model_info_service() -> ModelRegistry:
    """Create (if needed) the global registry, start its worker, return it."""
    service = get_model_info_service()
    await service.start()
    return service
903
+
904
+
905
+ # Compatibility shim - map old method names to new
906
+ class _CompatibilityWrapper:
907
+ """Provides old API method names for gradual migration."""
908
+
909
+ def __init__(self, registry: ModelRegistry):
910
+ self._reg = registry
911
+
912
+ def get_model_info(self, model_id: str) -> Optional[ModelMetadata]:
913
+ return self._reg.lookup(model_id)
914
+
915
+ def get_cost_info(self, model_id: str) -> Optional[Dict[str, float]]:
916
+ return self._reg.get_pricing(model_id)
917
+
918
+ def calculate_cost(
919
+ self, model_id: str, prompt_tokens: int, completion_tokens: int,
920
+ cache_read_tokens: int = 0, cache_creation_tokens: int = 0
921
+ ) -> Optional[float]:
922
+ return self._reg.compute_cost(
923
+ model_id, prompt_tokens, completion_tokens,
924
+ cache_read_tokens, cache_creation_tokens
925
+ )
926
+
927
+ def enrich_model_list(self, model_ids: List[str]) -> List[Dict[str, Any]]:
928
+ return self._reg.enrich_models(model_ids)
929
+
930
+ def get_all_source_models(self) -> Dict[str, Dict]:
931
+ return self._reg.all_raw_models()
932
+
933
+ def get_stats(self) -> Dict[str, Any]:
934
+ return self._reg.diagnostics()
935
+
936
+ async def start(self):
937
+ await self._reg.start()
938
+
939
+ async def stop(self):
940
+ await self._reg.stop()
941
+
942
+ async def wait_for_ready(self, timeout: float = 30.0) -> bool:
943
+ return await self._reg.await_ready(timeout)
944
+
945
+ def is_ready(self) -> bool:
946
+ return self._reg.is_ready