Mirrowel committed on
Commit
c4e297f
·
1 Parent(s): ed4dd55

feat(model-info): ✨ add provider priority system and enhanced model metadata

Browse files

This commit significantly improves the model lookup and metadata capabilities of the ModelInfoService:

- Introduces a provider priority system that prefers native/authoritative providers (anthropic, openai, google) over proxy/aggregator providers (openrouter, azure, requesty) when multiple sources exist for the same model
- Adds provider alias mapping to support custom provider names (e.g., nvidia_nim -> nvidia, gemini_cli -> google) for direct model ID resolution
- Extends ModelMetadata with new ModelInfo dataclass containing family, description, knowledge_cutoff, release_date, open_weights, status, tokenizer, and huggingface_id fields
- Adds extended capabilities support: structured_output, temperature, attachments, and interleaved content
- Implements version pattern normalization to handle different version formats (e.g., claude-opus-4-5 matches claude-opus-4.5)
- Refactors data merger to use best-source selection instead of averaging multiple sources, preserving queried model identity while inheriting technical specs from best matching native provider
- Restructures API response format with clear separation between core OpenAI fields, extended fields, legacy compatibility fields, and debug metadata
- Adds parent model tracking in origin field for transparency when fuzzy matching occurs
- Exports ModelMetadata in __all__ for public API access

The provider priority system ensures that when looking up custom provider models (e.g., antigravity/claude-opus-4-5), the service intelligently inherits accurate pricing and capabilities from the native provider source (anthropic/claude-opus-4.5) while preserving the queried model's provider identity.

src/rotator_library/__init__.py CHANGED
@@ -7,19 +7,33 @@ from .client import RotatingClient
7
  if TYPE_CHECKING:
8
  from .providers import PROVIDER_PLUGINS
9
  from .providers.provider_interface import ProviderInterface
10
- from .model_info_service import ModelInfoService, ModelInfo
 
 
 
 
 
 
 
 
11
 
12
- __all__ = ["RotatingClient", "PROVIDER_PLUGINS", "ModelInfoService", "ModelInfo"]
13
 
14
  def __getattr__(name):
15
  """Lazy-load PROVIDER_PLUGINS and ModelInfoService to speed up module import."""
16
  if name == "PROVIDER_PLUGINS":
17
  from .providers import PROVIDER_PLUGINS
 
18
  return PROVIDER_PLUGINS
19
  if name == "ModelInfoService":
20
  from .model_info_service import ModelInfoService
 
21
  return ModelInfoService
22
  if name == "ModelInfo":
23
  from .model_info_service import ModelInfo
 
24
  return ModelInfo
 
 
 
 
25
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
7
  if TYPE_CHECKING:
8
  from .providers import PROVIDER_PLUGINS
9
  from .providers.provider_interface import ProviderInterface
10
+ from .model_info_service import ModelInfoService, ModelInfo, ModelMetadata
11
+
12
+ __all__ = [
13
+ "RotatingClient",
14
+ "PROVIDER_PLUGINS",
15
+ "ModelInfoService",
16
+ "ModelInfo",
17
+ "ModelMetadata",
18
+ ]
19
 
 
20
 
21
  def __getattr__(name):
22
  """Lazy-load PROVIDER_PLUGINS and ModelInfoService to speed up module import."""
23
  if name == "PROVIDER_PLUGINS":
24
  from .providers import PROVIDER_PLUGINS
25
+
26
  return PROVIDER_PLUGINS
27
  if name == "ModelInfoService":
28
  from .model_info_service import ModelInfoService
29
+
30
  return ModelInfoService
31
  if name == "ModelInfo":
32
  from .model_info_service import ModelInfo
33
+
34
  return ModelInfo
35
+ if name == "ModelMetadata":
36
+ from .model_info_service import ModelMetadata
37
+
38
+ return ModelMetadata
39
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
src/rotator_library/model_info_service.py CHANGED
@@ -20,13 +20,114 @@ from urllib.error import URLError
20
  logger = logging.getLogger(__name__)
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # ============================================================================
24
  # Data Structures
25
  # ============================================================================
26
 
 
27
  @dataclass
28
  class ModelPricing:
29
  """Token-level pricing information."""
 
30
  prompt: Optional[float] = None
31
  completion: Optional[float] = None
32
  cached_input: Optional[float] = None
@@ -36,13 +137,15 @@ class ModelPricing:
36
  @dataclass
37
  class ModelLimits:
38
  """Context and output token limits."""
 
39
  context_window: Optional[int] = None
40
  max_output: Optional[int] = None
41
 
42
 
43
- @dataclass
44
  class ModelCapabilities:
45
  """Feature flags for model capabilities."""
 
46
  tools: bool = False
47
  functions: bool = False
48
  reasoning: bool = False
@@ -50,60 +153,91 @@ class ModelCapabilities:
50
  system_prompt: bool = True
51
  caching: bool = False
52
  prefill: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
 
55
  @dataclass
56
  class ModelMetadata:
57
  """Complete model information record."""
58
-
59
  model_id: str
60
  display_name: str = ""
61
  provider: str = ""
62
  category: str = "chat" # chat, embedding, image, audio
63
-
64
  pricing: ModelPricing = field(default_factory=ModelPricing)
65
  limits: ModelLimits = field(default_factory=ModelLimits)
66
  capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
67
-
 
68
  input_types: List[str] = field(default_factory=lambda: ["text"])
69
  output_types: List[str] = field(default_factory=lambda: ["text"])
70
-
 
 
 
71
  timestamp: int = field(default_factory=lambda: int(time.time()))
72
  origin: str = ""
73
  match_quality: str = "unknown"
74
-
75
  def as_api_response(self) -> Dict[str, Any]:
76
- """Format for OpenAI-compatible /v1/models response."""
 
 
 
 
 
 
77
  response = {
78
  "id": self.model_id,
79
  "object": "model",
80
  "created": self.timestamp,
81
  "owned_by": self.provider or "proxy",
82
  }
83
-
84
- # Pricing fields
85
- if self.pricing.prompt is not None:
86
- response["input_cost_per_token"] = self.pricing.prompt
87
- if self.pricing.completion is not None:
88
- response["output_cost_per_token"] = self.pricing.completion
89
- if self.pricing.cached_input is not None:
90
- response["cache_read_input_token_cost"] = self.pricing.cached_input
91
- if self.pricing.cache_write is not None:
92
- response["cache_creation_input_token_cost"] = self.pricing.cache_write
93
-
94
- # Limits
95
  if self.limits.context_window:
96
- response["max_input_tokens"] = self.limits.context_window
97
- response["context_window"] = self.limits.context_window
98
  if self.limits.max_output:
99
- response["max_output_tokens"] = self.limits.max_output
100
-
101
- # Category and modalities
102
- response["mode"] = self.category
103
- response["supported_modalities"] = self.input_types
104
- response["supported_output_modalities"] = self.output_types
105
-
106
- # Capability flags
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  response["capabilities"] = {
108
  "tool_choice": self.capabilities.tools,
109
  "function_calling": self.capabilities.functions,
@@ -112,117 +246,168 @@ class ModelMetadata:
112
  "system_messages": self.capabilities.system_prompt,
113
  "prompt_caching": self.capabilities.caching,
114
  "assistant_prefill": self.capabilities.prefill,
 
 
 
 
115
  }
116
-
117
- # Debug metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  if self.origin:
119
- response["_sources"] = [self.origin]
 
 
 
120
  response["_match_type"] = self.match_quality
121
-
 
 
 
 
 
122
  return response
123
-
124
  def as_minimal(self) -> Dict[str, Any]:
125
  """Minimal OpenAI format."""
126
  return {
127
  "id": self.model_id,
128
- "object": "model",
129
  "created": self.timestamp,
130
  "owned_by": self.provider or "proxy",
131
  }
132
-
133
  def to_dict(self) -> Dict[str, Any]:
134
  """Alias for as_api_response() - backward compatibility."""
135
  return self.as_api_response()
136
-
137
  def to_openai_format(self) -> Dict[str, Any]:
138
  """Alias for as_minimal() - backward compatibility."""
139
  return self.as_minimal()
140
-
141
  # Backward-compatible property aliases
142
  @property
143
  def id(self) -> str:
144
  return self.model_id
145
-
146
  @property
147
  def name(self) -> str:
148
  return self.display_name
149
-
150
  @property
151
  def input_cost_per_token(self) -> Optional[float]:
152
  return self.pricing.prompt
153
-
154
  @property
155
  def output_cost_per_token(self) -> Optional[float]:
156
  return self.pricing.completion
157
-
158
  @property
159
  def cache_read_input_token_cost(self) -> Optional[float]:
160
  return self.pricing.cached_input
161
-
162
  @property
163
  def cache_creation_input_token_cost(self) -> Optional[float]:
164
  return self.pricing.cache_write
165
-
166
  @property
167
  def max_input_tokens(self) -> Optional[int]:
168
  return self.limits.context_window
169
-
170
  @property
171
  def max_output_tokens(self) -> Optional[int]:
172
  return self.limits.max_output
173
-
174
  @property
175
  def mode(self) -> str:
176
  return self.category
177
-
178
  @property
179
  def supported_modalities(self) -> List[str]:
180
  return self.input_types
181
-
182
  @property
183
  def supported_output_modalities(self) -> List[str]:
184
  return self.output_types
185
-
186
  @property
187
  def supports_tool_choice(self) -> bool:
188
  return self.capabilities.tools
189
-
190
  @property
191
  def supports_function_calling(self) -> bool:
192
  return self.capabilities.functions
193
-
194
  @property
195
  def supports_reasoning(self) -> bool:
196
  return self.capabilities.reasoning
197
-
198
  @property
199
  def supports_vision(self) -> bool:
200
  return self.capabilities.vision
201
-
202
  @property
203
  def supports_system_messages(self) -> bool:
204
  return self.capabilities.system_prompt
205
-
206
  @property
207
  def supports_prompt_caching(self) -> bool:
208
  return self.capabilities.caching
209
-
210
  @property
211
  def supports_assistant_prefill(self) -> bool:
212
  return self.capabilities.prefill
213
-
214
  @property
215
  def litellm_provider(self) -> str:
216
  return self.provider
217
-
218
  @property
219
  def created(self) -> int:
220
  return self.timestamp
221
-
222
  @property
223
  def _sources(self) -> List[str]:
224
  return [self.origin] if self.origin else []
225
-
226
  @property
227
  def _match_type(self) -> str:
228
  return self.match_quality
@@ -232,16 +417,17 @@ class ModelMetadata:
232
  # Data Source Adapters
233
  # ============================================================================
234
 
 
235
  class DataSourceAdapter:
236
  """Base interface for external data sources."""
237
-
238
  source_name: str = "unknown"
239
  endpoint: str = ""
240
-
241
  def fetch(self) -> Dict[str, Dict]:
242
  """Retrieve and normalize data. Returns {model_id: raw_data}."""
243
  raise NotImplementedError
244
-
245
  def _http_get(self, url: str, timeout: int = 30) -> Any:
246
  """Execute HTTP GET with standard headers."""
247
  req = Request(url, headers={"User-Agent": "ModelRegistry/1.0"})
@@ -251,98 +437,125 @@ class DataSourceAdapter:
251
 
252
  class OpenRouterAdapter(DataSourceAdapter):
253
  """Fetches model data from OpenRouter's public API."""
254
-
255
  source_name = "openrouter"
256
  endpoint = "https://openrouter.ai/api/v1/models"
257
-
258
  def fetch(self) -> Dict[str, Dict]:
259
  try:
260
  raw = self._http_get(self.endpoint)
261
  entries = raw.get("data", [])
262
-
263
  catalog = {}
264
  for entry in entries:
265
  mid = entry.get("id")
266
  if not mid:
267
  continue
268
-
269
  full_id = f"openrouter/{mid}"
270
  catalog[full_id] = self._normalize(entry)
271
-
272
  return catalog
273
  except (URLError, json.JSONDecodeError, TimeoutError) as err:
274
  raise ConnectionError(f"OpenRouter unavailable: {err}") from err
275
-
276
  def _normalize(self, raw: Dict) -> Dict:
277
  """Transform OpenRouter schema to internal format."""
278
  prices = raw.get("pricing", {})
279
  arch = raw.get("architecture", {})
280
  top = raw.get("top_provider", {})
281
  params = raw.get("supported_parameters", [])
282
-
283
  tokenizer = arch.get("tokenizer", "")
284
  category = "embedding" if "embedding" in tokenizer.lower() else "chat"
285
-
 
 
 
 
286
  return {
 
287
  "name": raw.get("name", ""),
 
 
 
 
 
288
  "prompt_cost": float(prices.get("prompt", 0)),
289
  "completion_cost": float(prices.get("completion", 0)),
290
- "cache_read_cost": float(prices.get("input_cache_read", 0)) or None,
 
 
291
  "context": top.get("context_length", 0),
292
  "max_out": top.get("max_completion_tokens", 0),
293
- "category": category,
294
  "inputs": arch.get("input_modalities", ["text"]),
295
  "outputs": arch.get("output_modalities", ["text"]),
 
296
  "has_tools": "tool_choice" in params or "tools" in params,
297
  "has_functions": "tools" in params or "function_calling" in params,
298
- "has_reasoning": "reasoning" in params,
299
  "has_vision": "image" in arch.get("input_modalities", []),
300
- "provider": "openrouter",
301
- "source": "openrouter",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  }
303
 
304
 
305
  class ModelsDevAdapter(DataSourceAdapter):
306
  """Fetches model data from Models.dev catalog."""
307
-
308
  source_name = "modelsdev"
309
  endpoint = "https://models.dev/api.json"
310
-
311
  def __init__(self, skip_providers: Optional[List[str]] = None):
312
  self.skip_providers = skip_providers or []
313
-
314
  def fetch(self) -> Dict[str, Dict]:
315
  try:
316
  raw = self._http_get(self.endpoint)
317
-
318
  catalog = {}
319
  for provider_key, provider_block in raw.items():
320
  if not isinstance(provider_block, dict):
321
  continue
322
  if provider_key in self.skip_providers:
323
  continue
324
-
325
  models_block = provider_block.get("models", {})
326
  if not isinstance(models_block, dict):
327
  continue
328
-
329
  for model_key, model_data in models_block.items():
330
  if not isinstance(model_data, dict):
331
  continue
332
-
333
  full_id = f"{provider_key}/{model_key}"
334
  catalog[full_id] = self._normalize(model_data, provider_key)
335
-
336
  return catalog
337
  except (URLError, json.JSONDecodeError, TimeoutError) as err:
338
  raise ConnectionError(f"Models.dev unavailable: {err}") from err
339
-
340
  def _normalize(self, raw: Dict, provider_key: str) -> Dict:
341
  """Transform Models.dev schema to internal format."""
342
  costs = raw.get("cost", {})
343
  mods = raw.get("modalities", {})
344
  lims = raw.get("limit", {})
345
-
346
  outputs = mods.get("output", ["text"])
347
  if "image" in outputs:
348
  category = "image"
@@ -350,30 +563,46 @@ class ModelsDevAdapter(DataSourceAdapter):
350
  category = "audio"
351
  else:
352
  category = "chat"
353
-
354
  # Models.dev uses per-million pricing, convert to per-token
355
  divisor = 1_000_000
356
-
357
  cache_read = costs.get("cache_read")
358
  cache_write = costs.get("cache_write")
359
-
360
  return {
 
361
  "name": raw.get("name", ""),
 
 
 
 
 
362
  "prompt_cost": float(costs.get("input", 0)) / divisor,
363
  "completion_cost": float(costs.get("output", 0)) / divisor,
364
  "cache_read_cost": float(cache_read) / divisor if cache_read else None,
365
  "cache_write_cost": float(cache_write) / divisor if cache_write else None,
 
366
  "context": lims.get("context", 0),
367
  "max_out": lims.get("output", 0),
368
- "category": category,
369
  "inputs": mods.get("input", ["text"]),
370
  "outputs": outputs,
 
371
  "has_tools": raw.get("tool_call", False),
372
  "has_functions": raw.get("tool_call", False),
373
  "has_reasoning": raw.get("reasoning", False),
374
  "has_vision": "image" in mods.get("input", []),
375
- "provider": provider_key,
376
- "source": "modelsdev",
 
 
 
 
 
 
 
 
377
  }
378
 
379
 
@@ -381,48 +610,82 @@ class ModelsDevAdapter(DataSourceAdapter):
381
  # Lookup Index
382
  # ============================================================================
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  class ModelIndex:
385
  """Fast lookup structure for model ID resolution."""
386
-
387
  def __init__(self):
388
  self._by_full_id: Dict[str, str] = {} # normalized_id -> canonical_id
389
  self._by_suffix: Dict[str, List[str]] = {} # short_name -> [canonical_ids]
390
-
 
 
 
391
  def clear(self):
392
  """Reset the index."""
393
  self._by_full_id.clear()
394
  self._by_suffix.clear()
395
-
 
396
  def entry_count(self) -> int:
397
  """Return total number of suffix index entries."""
398
  return sum(len(v) for v in self._by_suffix.values())
399
-
400
  def add(self, canonical_id: str):
401
  """Index a canonical model ID for various lookup patterns."""
402
  self._by_full_id[canonical_id] = canonical_id
403
-
404
  segments = canonical_id.split("/")
405
  if len(segments) >= 2:
406
  # Index by everything after first segment
407
  partial = "/".join(segments[1:])
408
  self._by_suffix.setdefault(partial, []).append(canonical_id)
409
-
410
  # Index by final segment only
411
  if len(segments) >= 3:
412
  tail = segments[-1]
413
  self._by_suffix.setdefault(tail, []).append(canonical_id)
414
-
 
 
 
 
 
 
 
 
415
  def resolve(self, query: str) -> List[str]:
416
  """Find all canonical IDs matching a query."""
417
  # Direct match
418
  if query in self._by_full_id:
419
  return [self._by_full_id[query]]
420
-
421
  # Try with openrouter prefix
422
  prefixed = f"openrouter/{query}"
423
  if prefixed in self._by_full_id:
424
  return [self._by_full_id[prefixed]]
425
-
426
  # Extract search terms from query
427
  search_keys = []
428
  parts = query.split("/")
@@ -431,7 +694,8 @@ class ModelIndex:
431
  search_keys.append(parts[-1])
432
  else:
433
  search_keys.append(query)
434
- # Find matches
 
435
  matches = []
436
  seen = set()
437
  for key in search_keys:
@@ -439,7 +703,24 @@ class ModelIndex:
439
  if cid not in seen:
440
  seen.add(cid)
441
  matches.append(cid)
442
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  return matches
444
 
445
 
@@ -447,128 +728,181 @@ class ModelIndex:
447
  # Data Merger
448
  # ============================================================================
449
 
 
450
  class DataMerger:
451
- """Combines data from multiple sources into unified ModelMetadata."""
452
-
 
 
 
 
 
 
453
  @staticmethod
454
- def single(model_id: str, data: Dict, origin: str, quality: str) -> ModelMetadata:
455
- """Create ModelMetadata from a single source record."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
  return ModelMetadata(
457
- model_id=model_id,
458
- display_name=data.get("name", model_id),
459
- provider=data.get("provider", ""),
460
- category=data.get("category", "chat"),
 
461
  pricing=ModelPricing(
462
- prompt=data.get("prompt_cost"),
463
- completion=data.get("completion_cost"),
464
- cached_input=data.get("cache_read_cost"),
465
- cache_write=data.get("cache_write_cost"),
466
  ),
467
  limits=ModelLimits(
468
- context_window=data.get("context") or None,
469
- max_output=data.get("max_out") or None,
470
  ),
471
  capabilities=ModelCapabilities(
472
- tools=data.get("has_tools", False),
473
- functions=data.get("has_functions", False),
474
- reasoning=data.get("has_reasoning", False),
475
- vision=data.get("has_vision", False),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  ),
477
- input_types=data.get("inputs", ["text"]),
478
- output_types=data.get("outputs", ["text"]),
479
- origin=origin,
 
 
 
480
  match_quality=quality,
481
  )
482
-
483
  @staticmethod
484
- def combine(model_id: str, records: List[Tuple[Dict, str]], quality: str) -> ModelMetadata:
485
- """Merge multiple source records into one ModelMetadata."""
 
 
 
 
 
 
 
 
 
 
486
  if len(records) == 1:
487
- data, origin = records[0]
488
- return DataMerger.single(model_id, data, origin, quality)
489
-
490
- # Aggregate pricing - use average
491
- prompt_costs = [r[0]["prompt_cost"] for r in records if r[0].get("prompt_cost")]
492
- comp_costs = [r[0]["completion_cost"] for r in records if r[0].get("completion_cost")]
493
- cache_costs = [r[0]["cache_read_cost"] for r in records if r[0].get("cache_read_cost")]
494
-
495
- # Aggregate limits - use most common value
496
- contexts = [r[0]["context"] for r in records if r[0].get("context")]
497
- max_outs = [r[0]["max_out"] for r in records if r[0].get("max_out")]
498
-
499
- # Capabilities - OR logic (any source supporting = supported)
500
- has_tools = any(r[0].get("has_tools") for r in records)
501
- has_funcs = any(r[0].get("has_functions") for r in records)
502
- has_reason = any(r[0].get("has_reasoning") for r in records)
503
- has_vis = any(r[0].get("has_vision") for r in records)
504
-
505
- # Modalities - union
506
- all_inputs = set()
507
- all_outputs = set()
508
- for r in records:
509
- all_inputs.update(r[0].get("inputs", ["text"]))
510
- all_outputs.update(r[0].get("outputs", ["text"]))
511
-
512
- # Category - majority vote
513
- categories = [r[0].get("category", "chat") for r in records]
514
- category = max(set(categories), key=categories.count)
515
-
516
- # Name - first non-empty
517
- name = model_id
518
- for r in records:
519
- if r[0].get("name"):
520
- name = r[0]["name"]
521
- break
522
-
523
- origins = [r[1] for r in records]
524
-
525
- return ModelMetadata(
526
- model_id=model_id,
527
- display_name=name,
528
- provider=records[0][0].get("provider", ""),
529
- category=category,
530
- pricing=ModelPricing(
531
- prompt=sum(prompt_costs) / len(prompt_costs) if prompt_costs else None,
532
- completion=sum(comp_costs) / len(comp_costs) if comp_costs else None,
533
- cached_input=sum(cache_costs) / len(cache_costs) if cache_costs else None,
534
- ),
535
- limits=ModelLimits(
536
- context_window=DataMerger._mode(contexts),
537
- max_output=DataMerger._mode(max_outs),
538
- ),
539
- capabilities=ModelCapabilities(
540
- tools=has_tools,
541
- functions=has_funcs,
542
- reasoning=has_reason,
543
- vision=has_vis,
544
- ),
545
- input_types=list(all_inputs) or ["text"],
546
- output_types=list(all_outputs) or ["text"],
547
- origin=",".join(origins),
548
- match_quality=quality,
549
- )
550
-
551
  @staticmethod
552
- def _mode(values: List[int]) -> Optional[int]:
553
- """Return most frequent value."""
554
- if not values:
 
 
 
 
 
 
555
  return None
556
- return max(set(values), key=values.count)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
 
558
 
559
  # ============================================================================
560
  # Main Registry Service
561
  # ============================================================================
562
 
 
563
  class ModelRegistry:
564
  """
565
  Central registry for model metadata from external catalogs.
566
-
567
  Manages background data refresh and provides lookup/pricing APIs.
568
  """
569
-
570
  REFRESH_INTERVAL_DEFAULT = 6 * 60 * 60 # 6 hours
571
-
572
  def __init__(
573
  self,
574
  refresh_seconds: Optional[int] = None,
@@ -578,38 +912,37 @@ class ModelRegistry:
578
  self._refresh_interval = refresh_seconds or (
579
  int(interval_env) if interval_env else self.REFRESH_INTERVAL_DEFAULT
580
  )
581
-
582
  # Configure adapters
583
  self._adapters: List[DataSourceAdapter] = [
584
  OpenRouterAdapter(),
585
  ModelsDevAdapter(skip_providers=skip_modelsdev_providers or []),
586
  ]
587
-
588
  # Raw data stores
589
  self._openrouter_store: Dict[str, Dict] = {}
590
  self._modelsdev_store: Dict[str, Dict] = {}
591
-
592
  # Lookup infrastructure
593
  self._index = ModelIndex()
594
  self._result_cache: Dict[str, ModelMetadata] = {}
595
-
596
  # Async coordination
597
  self._ready = asyncio.Event()
598
  self._mutex = asyncio.Lock()
599
  self._worker: Optional[asyncio.Task] = None
600
  self._last_refresh: float = 0
601
-
602
  # ---------- Lifecycle ----------
603
-
604
  async def start(self):
605
  """Begin background refresh worker."""
606
  if self._worker is None:
607
  self._worker = asyncio.create_task(self._refresh_worker())
608
  logger.info(
609
- "ModelRegistry started (refresh every %ds)",
610
- self._refresh_interval
611
  )
612
-
613
  async def stop(self):
614
  """Halt background worker."""
615
  if self._worker:
@@ -620,7 +953,7 @@ class ModelRegistry:
620
  pass
621
  self._worker = None
622
  logger.info("ModelRegistry stopped")
623
-
624
  async def await_ready(self, timeout_secs: float = 30.0) -> bool:
625
  """Block until initial data load completes."""
626
  try:
@@ -629,18 +962,18 @@ class ModelRegistry:
629
  except asyncio.TimeoutError:
630
  logger.warning("ModelRegistry ready timeout after %.1fs", timeout_secs)
631
  return False
632
-
633
  @property
634
  def is_ready(self) -> bool:
635
  return self._ready.is_set()
636
-
637
  # ---------- Background Worker ----------
638
-
639
  async def _refresh_worker(self):
640
  """Periodic refresh loop."""
641
  await self._load_all_sources()
642
  self._ready.set()
643
-
644
  while True:
645
  try:
646
  await asyncio.sleep(self._refresh_interval)
@@ -651,51 +984,50 @@ class ModelRegistry:
651
  break
652
  except Exception as ex:
653
  logger.error("Registry refresh error: %s", ex)
654
-
655
  async def _load_all_sources(self):
656
  """Fetch from all adapters concurrently."""
657
  loop = asyncio.get_event_loop()
658
-
659
  tasks = [
660
- loop.run_in_executor(None, adapter.fetch)
661
- for adapter in self._adapters
662
  ]
663
-
664
  results = await asyncio.gather(*tasks, return_exceptions=True)
665
-
666
  async with self._mutex:
667
  for adapter, result in zip(self._adapters, results):
668
  if isinstance(result, Exception):
669
  logger.error("%s fetch failed: %s", adapter.source_name, result)
670
  continue
671
-
672
  if adapter.source_name == "openrouter":
673
  self._openrouter_store = result
674
  logger.info("OpenRouter: %d models loaded", len(result))
675
  elif adapter.source_name == "modelsdev":
676
  self._modelsdev_store = result
677
  logger.info("Models.dev: %d models loaded", len(result))
678
-
679
  self._rebuild_index()
680
  self._last_refresh = time.time()
681
-
682
  def _rebuild_index(self):
683
  """Reconstruct lookup index from current stores."""
684
  self._index.clear()
685
  self._result_cache.clear()
686
-
687
  for model_id in self._openrouter_store:
688
  self._index.add(model_id)
689
-
690
  for model_id in self._modelsdev_store:
691
  self._index.add(model_id)
692
-
693
  # ---------- Query API ----------
694
-
695
  def lookup(self, model_id: str) -> Optional[ModelMetadata]:
696
  """
697
  Retrieve model metadata by ID.
698
-
699
  Matching strategy:
700
  1. Exact match against known IDs
701
  2. Fuzzy match by model name suffix
@@ -703,50 +1035,111 @@ class ModelRegistry:
703
  """
704
  if model_id in self._result_cache:
705
  return self._result_cache[model_id]
706
-
707
  metadata = self._resolve_model(model_id)
708
  if metadata:
709
  self._result_cache[model_id] = metadata
710
  return metadata
711
-
712
  def _resolve_model(self, model_id: str) -> Optional[ModelMetadata]:
713
  """Build ModelMetadata by matching source data."""
714
  records: List[Tuple[Dict, str]] = []
715
  quality = "none"
716
-
717
- # Check exact matches first
718
- or_key = f"openrouter/{model_id}" if not model_id.startswith("openrouter/") else model_id
 
 
 
 
719
  if or_key in self._openrouter_store:
720
- records.append((self._openrouter_store[or_key], f"openrouter:exact:{or_key}"))
 
 
721
  quality = "exact"
722
-
723
  if model_id in self._modelsdev_store:
724
- records.append((self._modelsdev_store[model_id], f"modelsdev:exact:{model_id}"))
 
 
725
  quality = "exact"
726
-
727
- # Fall back to index search
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728
  if not records:
729
  candidates = self._index.resolve(model_id)
730
  for cid in candidates:
731
  if cid in self._openrouter_store:
732
- records.append((self._openrouter_store[cid], f"openrouter:fuzzy:{cid}"))
 
 
733
  elif cid in self._modelsdev_store:
734
- records.append((self._modelsdev_store[cid], f"modelsdev:fuzzy:{cid}"))
735
-
 
 
736
  if records:
737
  quality = "fuzzy"
738
-
739
  if not records:
740
  return None
741
-
742
  return DataMerger.combine(model_id, records, quality)
743
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744
  def get_pricing(self, model_id: str) -> Optional[Dict[str, float]]:
745
  """Extract just pricing info for cost calculations."""
746
  meta = self.lookup(model_id)
747
  if not meta:
748
  return None
749
-
750
  result = {}
751
  if meta.pricing.prompt is not None:
752
  result["input_cost_per_token"] = meta.pricing.prompt
@@ -756,9 +1149,9 @@ class ModelRegistry:
756
  result["cache_read_input_token_cost"] = meta.pricing.cached_input
757
  if meta.pricing.cache_write is not None:
758
  result["cache_creation_input_token_cost"] = meta.pricing.cache_write
759
-
760
  return result if result else None
761
-
762
  def compute_cost(
763
  self,
764
  model_id: str,
@@ -769,35 +1162,35 @@ class ModelRegistry:
769
  ) -> Optional[float]:
770
  """
771
  Calculate total request cost.
772
-
773
  Returns None if pricing unavailable.
774
  """
775
  pricing = self.get_pricing(model_id)
776
  if not pricing:
777
  return None
778
-
779
  in_rate = pricing.get("input_cost_per_token")
780
  out_rate = pricing.get("output_cost_per_token")
781
-
782
  if in_rate is None or out_rate is None:
783
  return None
784
-
785
  total = (input_tokens * in_rate) + (output_tokens * out_rate)
786
-
787
  cache_read_rate = pricing.get("cache_read_input_token_cost")
788
  if cache_read_rate and cache_hit_tokens:
789
  total += cache_hit_tokens * cache_read_rate
790
-
791
  cache_write_rate = pricing.get("cache_creation_input_token_cost")
792
  if cache_write_rate and cache_miss_tokens:
793
  total += cache_miss_tokens * cache_write_rate
794
-
795
  return total
796
-
797
  def enrich_models(self, model_ids: List[str]) -> List[Dict[str, Any]]:
798
  """
799
  Attach metadata to a list of model IDs.
800
-
801
  Used by /v1/models endpoint.
802
  """
803
  enriched = []
@@ -807,21 +1200,23 @@ class ModelRegistry:
807
  enriched.append(meta.as_api_response())
808
  else:
809
  # Fallback minimal entry
810
- enriched.append({
811
- "id": mid,
812
- "object": "model",
813
- "created": int(time.time()),
814
- "owned_by": mid.split("/")[0] if "/" in mid else "unknown",
815
- })
 
 
816
  return enriched
817
-
818
  def all_raw_models(self) -> Dict[str, Dict]:
819
  """Return all raw source data (for debugging)."""
820
  combined = {}
821
  combined.update(self._openrouter_store)
822
  combined.update(self._modelsdev_store)
823
  return combined
824
-
825
  def diagnostics(self) -> Dict[str, Any]:
826
  """Return service health/stats."""
827
  return {
@@ -833,17 +1228,17 @@ class ModelRegistry:
833
  "index_entries": self._index.entry_count(),
834
  "refresh_interval": self._refresh_interval,
835
  }
836
-
837
  # ---------- Backward Compatibility Methods ----------
838
-
839
  def get_model_info(self, model_id: str) -> Optional[ModelMetadata]:
840
  """Alias for lookup() - backward compatibility."""
841
  return self.lookup(model_id)
842
-
843
  def get_cost_info(self, model_id: str) -> Optional[Dict[str, float]]:
844
  """Alias for get_pricing() - backward compatibility."""
845
  return self.get_pricing(model_id)
846
-
847
  def calculate_cost(
848
  self,
849
  model_id: str,
@@ -854,22 +1249,25 @@ class ModelRegistry:
854
  ) -> Optional[float]:
855
  """Alias for compute_cost() - backward compatibility."""
856
  return self.compute_cost(
857
- model_id, prompt_tokens, completion_tokens,
858
- cache_read_tokens, cache_creation_tokens
 
 
 
859
  )
860
-
861
  def enrich_model_list(self, model_ids: List[str]) -> List[Dict[str, Any]]:
862
  """Alias for enrich_models() - backward compatibility."""
863
  return self.enrich_models(model_ids)
864
-
865
  def get_all_source_models(self) -> Dict[str, Dict]:
866
  """Alias for all_raw_models() - backward compatibility."""
867
  return self.all_raw_models()
868
-
869
  def get_stats(self) -> Dict[str, Any]:
870
  """Alias for diagnostics() - backward compatibility."""
871
  return self.diagnostics()
872
-
873
  def wait_for_ready(self, timeout: float = 30.0):
874
  """Sync wrapper for await_ready() - for compatibility."""
875
  return self.await_ready(timeout)
@@ -880,7 +1278,8 @@ class ModelRegistry:
880
  # ============================================================================
881
 
882
  # Alias for backward compatibility
883
- ModelInfo = ModelMetadata
 
884
  ModelInfoService = ModelRegistry
885
 
886
  # Global singleton
@@ -905,42 +1304,49 @@ async def init_model_info_service() -> ModelRegistry:
905
  # Compatibility shim - map old method names to new
906
  class _CompatibilityWrapper:
907
  """Provides old API method names for gradual migration."""
908
-
909
  def __init__(self, registry: ModelRegistry):
910
  self._reg = registry
911
-
912
  def get_model_info(self, model_id: str) -> Optional[ModelMetadata]:
913
  return self._reg.lookup(model_id)
914
-
915
  def get_cost_info(self, model_id: str) -> Optional[Dict[str, float]]:
916
  return self._reg.get_pricing(model_id)
917
-
918
  def calculate_cost(
919
- self, model_id: str, prompt_tokens: int, completion_tokens: int,
920
- cache_read_tokens: int = 0, cache_creation_tokens: int = 0
 
 
 
 
921
  ) -> Optional[float]:
922
  return self._reg.compute_cost(
923
- model_id, prompt_tokens, completion_tokens,
924
- cache_read_tokens, cache_creation_tokens
 
 
 
925
  )
926
-
927
  def enrich_model_list(self, model_ids: List[str]) -> List[Dict[str, Any]]:
928
  return self._reg.enrich_models(model_ids)
929
-
930
  def get_all_source_models(self) -> Dict[str, Dict]:
931
  return self._reg.all_raw_models()
932
-
933
  def get_stats(self) -> Dict[str, Any]:
934
  return self._reg.diagnostics()
935
-
936
  async def start(self):
937
  await self._reg.start()
938
-
939
  async def stop(self):
940
  await self._reg.stop()
941
-
942
  async def wait_for_ready(self, timeout: float = 30.0) -> bool:
943
  return await self._reg.await_ready(timeout)
944
-
945
  def is_ready(self) -> bool:
946
  return self._reg.is_ready
 
20
  logger = logging.getLogger(__name__)
21
 
22
 
23
+ # ============================================================================
24
+ # Provider Priority Configuration
25
+ # ============================================================================
26
+
27
+ # Native/authoritative providers - prefer these over proxy/aggregator providers
28
+ # Lower index = higher priority
29
+ NATIVE_PROVIDER_PRIORITY = [
30
+ "anthropic",
31
+ "openai",
32
+ "google",
33
+ "google-vertex",
34
+ "mistral",
35
+ "mistralai",
36
+ "cohere",
37
+ "deepseek",
38
+ "deepseek-ai", # Used in nvidia_nim/deepseek-ai/model format
39
+ "qwen",
40
+ "alibaba",
41
+ "alibaba-cn",
42
+ "meta-llama",
43
+ "nvidia",
44
+ "moonshotai", # Used in nvidia_nim/moonshotai/model format
45
+ "iflow",
46
+ "iflowcn",
47
+ # These are aggregators/proxies - lower priority
48
+ "openrouter",
49
+ "azure",
50
+ "azure-cognitive-services",
51
+ "aws-bedrock",
52
+ "github-copilot",
53
+ "opencode",
54
+ "requesty",
55
+ "helicone",
56
+ "vercel",
57
+ "aihubmix",
58
+ "venice",
59
+ "poe",
60
+ "cortecs",
61
+ "fastrouter",
62
+ "ollama-cloud",
63
+ "nebius",
64
+ "fireworks-ai",
65
+ "groq",
66
+ "sap-ai-core",
67
+ "zenmux",
68
+ ]
69
+
70
+ # ============================================================================
71
+ # Provider Alias Mapping (for direct lookup)
72
+ # ============================================================================
73
+ #
74
+ # Maps custom/proxy provider names to their canonical equivalents in data sources.
75
+ # When looking up "nvidia_nim/org/model", we first try "nvidia/org/model" directly.
76
+ # This allows direct matches before falling back to fuzzy suffix matching.
77
+ #
78
+ # Format: "custom_provider": ["canonical_provider1", "canonical_provider2", ...]
79
+ # Multiple aliases are tried in order until a match is found.
80
+ #
81
+ PROVIDER_ALIASES = {
82
+ "nvidia_nim": ["nvidia"],
83
+ "gemini_cli": ["google"],
84
+ "gemini": ["google"],
85
+ "iflow": ["iflow", "iflowcn"], # iflow may exist as either
86
+ }
87
+
88
+
89
+ def _get_provider_priority(provider: str) -> int:
90
+ """
91
+ Get priority score for a provider (lower = better).
92
+ Native providers get priority over proxy/aggregator providers.
93
+ """
94
+ try:
95
+ return NATIVE_PROVIDER_PRIORITY.index(provider.lower())
96
+ except ValueError:
97
+ # Unknown providers get lowest priority
98
+ return len(NATIVE_PROVIDER_PRIORITY) + 1
99
+
100
+
101
+ def _extract_provider_from_source_id(source_id: str) -> str:
102
+ """
103
+ Extract the actual data provider from a source model ID.
104
+
105
+ Examples:
106
+ "anthropic/claude-opus-4.5" -> "anthropic"
107
+ "openrouter/google/gemini-2.5-pro" -> "google" (skip openrouter prefix)
108
+ "nvidia/mistralai/mistral-large" -> "mistralai" (3-segment, use middle)
109
+ """
110
+ parts = source_id.split("/")
111
+ if len(parts) >= 2:
112
+ # Skip openrouter prefix if present
113
+ if parts[0].lower() == "openrouter" and len(parts) >= 3:
114
+ return parts[1].lower()
115
+ # For 3-segment IDs like nvidia/mistralai/model, use middle segment
116
+ if len(parts) == 3:
117
+ return parts[1].lower()
118
+ return parts[0].lower()
119
+ return source_id.lower()
120
+
121
+
122
  # ============================================================================
123
  # Data Structures
124
  # ============================================================================
125
 
126
+
127
  @dataclass
128
  class ModelPricing:
129
  """Token-level pricing information."""
130
+
131
  prompt: Optional[float] = None
132
  completion: Optional[float] = None
133
  cached_input: Optional[float] = None
 
137
  @dataclass
138
  class ModelLimits:
139
  """Context and output token limits."""
140
+
141
  context_window: Optional[int] = None
142
  max_output: Optional[int] = None
143
 
144
 
145
+ @dataclass
146
  class ModelCapabilities:
147
  """Feature flags for model capabilities."""
148
+
149
  tools: bool = False
150
  functions: bool = False
151
  reasoning: bool = False
 
153
  system_prompt: bool = True
154
  caching: bool = False
155
  prefill: bool = False
156
+ # Extended capabilities from Models.dev
157
+ structured_output: bool = False
158
+ temperature: bool = True # Most models support temperature
159
+ attachments: bool = False # File/document attachments
160
+ interleaved: bool = False # Interleaved content support
161
+
162
+
163
+ @dataclass
164
+ class ModelInfo:
165
+ """Extended model information and metadata."""
166
+
167
+ family: str = "" # Model family (e.g., "claude-opus", "gpt-4")
168
+ description: str = "" # Model description
169
+ knowledge_cutoff: str = "" # Knowledge cutoff date (e.g., "2025-03-31")
170
+ release_date: str = "" # Model release date
171
+ open_weights: bool = False # Whether model weights are open
172
+ status: str = "active" # Model status: active, deprecated, preview
173
+ tokenizer: str = "" # Tokenizer type
174
+ huggingface_id: str = "" # HuggingFace model ID
175
 
176
 
177
  @dataclass
178
  class ModelMetadata:
179
  """Complete model information record."""
180
+
181
  model_id: str
182
  display_name: str = ""
183
  provider: str = ""
184
  category: str = "chat" # chat, embedding, image, audio
185
+
186
  pricing: ModelPricing = field(default_factory=ModelPricing)
187
  limits: ModelLimits = field(default_factory=ModelLimits)
188
  capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
189
+ info: ModelInfo = field(default_factory=ModelInfo) # Extended info
190
+
191
  input_types: List[str] = field(default_factory=lambda: ["text"])
192
  output_types: List[str] = field(default_factory=lambda: ["text"])
193
+ supported_parameters: List[str] = field(
194
+ default_factory=list
195
+ ) # Supported API params
196
+
197
  timestamp: int = field(default_factory=lambda: int(time.time()))
198
  origin: str = ""
199
  match_quality: str = "unknown"
200
+
201
  def as_api_response(self) -> Dict[str, Any]:
202
+ """
203
+ Format for OpenAI-compatible /v1/models response.
204
+
205
+ Standard OpenAI fields come first, then extended fields,
206
+ then debug/meta fields prefixed with underscore.
207
+ """
208
+ # === Core OpenAI-compatible fields ===
209
  response = {
210
  "id": self.model_id,
211
  "object": "model",
212
  "created": self.timestamp,
213
  "owned_by": self.provider or "proxy",
214
  }
215
+
216
+ # === Token limits (standard) ===
 
 
 
 
 
 
 
 
 
 
217
  if self.limits.context_window:
218
+ response["context_length"] = self.limits.context_window
 
219
  if self.limits.max_output:
220
+ response["max_completion_tokens"] = self.limits.max_output
221
+
222
+ # === Pricing fields (extended but common) ===
223
+ if self.pricing.prompt is not None:
224
+ response["pricing"] = {"prompt": self.pricing.prompt}
225
+ if self.pricing.completion is not None:
226
+ response["pricing"]["completion"] = self.pricing.completion
227
+ if self.pricing.cached_input is not None:
228
+ response["pricing"]["cached_input"] = self.pricing.cached_input
229
+ if self.pricing.cache_write is not None:
230
+ response["pricing"]["cache_write"] = self.pricing.cache_write
231
+
232
+ # === Architecture/modalities (OpenRouter-style) ===
233
+ response["architecture"] = {
234
+ "input_modalities": self.input_types,
235
+ "output_modalities": self.output_types,
236
+ }
237
+ if self.info.tokenizer:
238
+ response["architecture"]["tokenizer"] = self.info.tokenizer
239
+
240
+ # === Capabilities (extended) ===
241
  response["capabilities"] = {
242
  "tool_choice": self.capabilities.tools,
243
  "function_calling": self.capabilities.functions,
 
246
  "system_messages": self.capabilities.system_prompt,
247
  "prompt_caching": self.capabilities.caching,
248
  "assistant_prefill": self.capabilities.prefill,
249
+ "structured_output": self.capabilities.structured_output,
250
+ "temperature": self.capabilities.temperature,
251
+ "attachments": self.capabilities.attachments,
252
+ "interleaved": self.capabilities.interleaved,
253
  }
254
+
255
+ # === Supported parameters (if available) ===
256
+ if self.supported_parameters:
257
+ response["supported_parameters"] = self.supported_parameters
258
+
259
+ # === Extended model info ===
260
+ if self.info.family:
261
+ response["family"] = self.info.family
262
+ if self.info.description:
263
+ response["description"] = self.info.description
264
+ if self.info.knowledge_cutoff:
265
+ response["knowledge_cutoff"] = self.info.knowledge_cutoff
266
+ if self.info.release_date:
267
+ response["release_date"] = self.info.release_date
268
+ if self.info.open_weights:
269
+ response["open_weights"] = self.info.open_weights
270
+ if self.info.status and self.info.status != "active":
271
+ response["status"] = self.info.status
272
+ if self.info.huggingface_id:
273
+ response["huggingface_id"] = self.info.huggingface_id
274
+
275
+ # === Legacy fields for backward compatibility ===
276
+ # Some tools may expect these field names
277
+ if self.limits.context_window:
278
+ response["max_input_tokens"] = self.limits.context_window
279
+ response["context_window"] = self.limits.context_window
280
+ if self.limits.max_output:
281
+ response["max_output_tokens"] = self.limits.max_output
282
+ if self.pricing.prompt is not None:
283
+ response["input_cost_per_token"] = self.pricing.prompt
284
+ if self.pricing.completion is not None:
285
+ response["output_cost_per_token"] = self.pricing.completion
286
+ if self.pricing.cached_input is not None:
287
+ response["cache_read_input_token_cost"] = self.pricing.cached_input
288
+ if self.pricing.cache_write is not None:
289
+ response["cache_creation_input_token_cost"] = self.pricing.cache_write
290
+ response["mode"] = self.category
291
+ response["supported_modalities"] = self.input_types
292
+ response["supported_output_modalities"] = self.output_types
293
+
294
+ # === Debug/meta fields (underscore prefix) ===
295
  if self.origin:
296
+ origin_parts = self.origin.split("|")
297
+ main_origin = origin_parts[0]
298
+
299
+ response["_sources"] = [main_origin]
300
  response["_match_type"] = self.match_quality
301
+
302
+ for part in origin_parts[1:]:
303
+ if part.startswith("parent:"):
304
+ response["_parent_model"] = part[len("parent:") :]
305
+ break
306
+
307
  return response
308
+
309
  def as_minimal(self) -> Dict[str, Any]:
310
  """Minimal OpenAI format."""
311
  return {
312
  "id": self.model_id,
313
+ "object": "model",
314
  "created": self.timestamp,
315
  "owned_by": self.provider or "proxy",
316
  }
317
+
318
  def to_dict(self) -> Dict[str, Any]:
319
  """Alias for as_api_response() - backward compatibility."""
320
  return self.as_api_response()
321
+
322
  def to_openai_format(self) -> Dict[str, Any]:
323
  """Alias for as_minimal() - backward compatibility."""
324
  return self.as_minimal()
325
+
326
  # Backward-compatible property aliases
327
  @property
328
  def id(self) -> str:
329
  return self.model_id
330
+
331
  @property
332
  def name(self) -> str:
333
  return self.display_name
334
+
335
  @property
336
  def input_cost_per_token(self) -> Optional[float]:
337
  return self.pricing.prompt
338
+
339
  @property
340
  def output_cost_per_token(self) -> Optional[float]:
341
  return self.pricing.completion
342
+
343
  @property
344
  def cache_read_input_token_cost(self) -> Optional[float]:
345
  return self.pricing.cached_input
346
+
347
  @property
348
  def cache_creation_input_token_cost(self) -> Optional[float]:
349
  return self.pricing.cache_write
350
+
351
  @property
352
  def max_input_tokens(self) -> Optional[int]:
353
  return self.limits.context_window
354
+
355
  @property
356
  def max_output_tokens(self) -> Optional[int]:
357
  return self.limits.max_output
358
+
359
  @property
360
  def mode(self) -> str:
361
  return self.category
362
+
363
  @property
364
  def supported_modalities(self) -> List[str]:
365
  return self.input_types
366
+
367
  @property
368
  def supported_output_modalities(self) -> List[str]:
369
  return self.output_types
370
+
371
  @property
372
  def supports_tool_choice(self) -> bool:
373
  return self.capabilities.tools
374
+
375
  @property
376
  def supports_function_calling(self) -> bool:
377
  return self.capabilities.functions
378
+
379
  @property
380
  def supports_reasoning(self) -> bool:
381
  return self.capabilities.reasoning
382
+
383
  @property
384
  def supports_vision(self) -> bool:
385
  return self.capabilities.vision
386
+
387
  @property
388
  def supports_system_messages(self) -> bool:
389
  return self.capabilities.system_prompt
390
+
391
  @property
392
  def supports_prompt_caching(self) -> bool:
393
  return self.capabilities.caching
394
+
395
  @property
396
  def supports_assistant_prefill(self) -> bool:
397
  return self.capabilities.prefill
398
+
399
  @property
400
  def litellm_provider(self) -> str:
401
  return self.provider
402
+
403
  @property
404
  def created(self) -> int:
405
  return self.timestamp
406
+
407
  @property
408
  def _sources(self) -> List[str]:
409
  return [self.origin] if self.origin else []
410
+
411
  @property
412
  def _match_type(self) -> str:
413
  return self.match_quality
 
417
  # Data Source Adapters
418
  # ============================================================================
419
 
420
+
421
  class DataSourceAdapter:
422
  """Base interface for external data sources."""
423
+
424
  source_name: str = "unknown"
425
  endpoint: str = ""
426
+
427
  def fetch(self) -> Dict[str, Dict]:
428
  """Retrieve and normalize data. Returns {model_id: raw_data}."""
429
  raise NotImplementedError
430
+
431
  def _http_get(self, url: str, timeout: int = 30) -> Any:
432
  """Execute HTTP GET with standard headers."""
433
  req = Request(url, headers={"User-Agent": "ModelRegistry/1.0"})
 
437
 
438
  class OpenRouterAdapter(DataSourceAdapter):
439
  """Fetches model data from OpenRouter's public API."""
440
+
441
  source_name = "openrouter"
442
  endpoint = "https://openrouter.ai/api/v1/models"
443
+
444
  def fetch(self) -> Dict[str, Dict]:
445
  try:
446
  raw = self._http_get(self.endpoint)
447
  entries = raw.get("data", [])
448
+
449
  catalog = {}
450
  for entry in entries:
451
  mid = entry.get("id")
452
  if not mid:
453
  continue
454
+
455
  full_id = f"openrouter/{mid}"
456
  catalog[full_id] = self._normalize(entry)
457
+
458
  return catalog
459
  except (URLError, json.JSONDecodeError, TimeoutError) as err:
460
  raise ConnectionError(f"OpenRouter unavailable: {err}") from err
461
+
462
  def _normalize(self, raw: Dict) -> Dict:
463
  """Transform OpenRouter schema to internal format."""
464
  prices = raw.get("pricing", {})
465
  arch = raw.get("architecture", {})
466
  top = raw.get("top_provider", {})
467
  params = raw.get("supported_parameters", [])
468
+
469
  tokenizer = arch.get("tokenizer", "")
470
  category = "embedding" if "embedding" in tokenizer.lower() else "chat"
471
+
472
+ # Extract cache pricing
473
+ cache_read = prices.get("input_cache_read", 0)
474
+ cache_write = prices.get("input_cache_write", 0)
475
+
476
  return {
477
+ # Basic info
478
  "name": raw.get("name", ""),
479
+ "original_id": raw.get("id", ""),
480
+ "provider": "openrouter",
481
+ "source": "openrouter",
482
+ "category": category,
483
+ # Pricing (already per-token from OpenRouter)
484
  "prompt_cost": float(prices.get("prompt", 0)),
485
  "completion_cost": float(prices.get("completion", 0)),
486
+ "cache_read_cost": float(cache_read) if cache_read else None,
487
+ "cache_write_cost": float(cache_write) if cache_write else None,
488
+ # Limits
489
  "context": top.get("context_length", 0),
490
  "max_out": top.get("max_completion_tokens", 0),
491
+ # Modalities
492
  "inputs": arch.get("input_modalities", ["text"]),
493
  "outputs": arch.get("output_modalities", ["text"]),
494
+ # Capabilities
495
  "has_tools": "tool_choice" in params or "tools" in params,
496
  "has_functions": "tools" in params or "function_calling" in params,
497
+ "has_reasoning": "reasoning" in params or "include_reasoning" in params,
498
  "has_vision": "image" in arch.get("input_modalities", []),
499
+ "has_structured_output": "structured_outputs" in params
500
+ or "response_format" in params,
501
+ "has_temperature": "temperature" in params,
502
+ "has_attachments": "file" in arch.get("input_modalities", []),
503
+ "has_interleaved": False, # Not available from OpenRouter
504
+ # Extended model info
505
+ "description": raw.get("description", ""),
506
+ "tokenizer": tokenizer,
507
+ "huggingface_id": raw.get("hugging_face_id", ""),
508
+ "supported_parameters": params,
509
+ # OpenRouter doesn't provide these, leave empty
510
+ "family": "",
511
+ "knowledge_cutoff": "",
512
+ "release_date": "",
513
+ "open_weights": False,
514
+ "status": "active",
515
  }
516
 
517
 
518
  class ModelsDevAdapter(DataSourceAdapter):
519
  """Fetches model data from Models.dev catalog."""
520
+
521
  source_name = "modelsdev"
522
  endpoint = "https://models.dev/api.json"
523
+
524
  def __init__(self, skip_providers: Optional[List[str]] = None):
525
  self.skip_providers = skip_providers or []
526
+
527
  def fetch(self) -> Dict[str, Dict]:
528
  try:
529
  raw = self._http_get(self.endpoint)
530
+
531
  catalog = {}
532
  for provider_key, provider_block in raw.items():
533
  if not isinstance(provider_block, dict):
534
  continue
535
  if provider_key in self.skip_providers:
536
  continue
537
+
538
  models_block = provider_block.get("models", {})
539
  if not isinstance(models_block, dict):
540
  continue
541
+
542
  for model_key, model_data in models_block.items():
543
  if not isinstance(model_data, dict):
544
  continue
545
+
546
  full_id = f"{provider_key}/{model_key}"
547
  catalog[full_id] = self._normalize(model_data, provider_key)
548
+
549
  return catalog
550
  except (URLError, json.JSONDecodeError, TimeoutError) as err:
551
  raise ConnectionError(f"Models.dev unavailable: {err}") from err
552
+
553
  def _normalize(self, raw: Dict, provider_key: str) -> Dict:
554
  """Transform Models.dev schema to internal format."""
555
  costs = raw.get("cost", {})
556
  mods = raw.get("modalities", {})
557
  lims = raw.get("limit", {})
558
+
559
  outputs = mods.get("output", ["text"])
560
  if "image" in outputs:
561
  category = "image"
 
563
  category = "audio"
564
  else:
565
  category = "chat"
566
+
567
  # Models.dev uses per-million pricing, convert to per-token
568
  divisor = 1_000_000
569
+
570
  cache_read = costs.get("cache_read")
571
  cache_write = costs.get("cache_write")
572
+
573
  return {
574
+ # Basic info
575
  "name": raw.get("name", ""),
576
+ "original_id": raw.get("id", ""),
577
+ "provider": provider_key,
578
+ "source": "modelsdev",
579
+ "category": category,
580
+ # Pricing (converted to per-token)
581
  "prompt_cost": float(costs.get("input", 0)) / divisor,
582
  "completion_cost": float(costs.get("output", 0)) / divisor,
583
  "cache_read_cost": float(cache_read) / divisor if cache_read else None,
584
  "cache_write_cost": float(cache_write) / divisor if cache_write else None,
585
+ # Limits
586
  "context": lims.get("context", 0),
587
  "max_out": lims.get("output", 0),
588
+ # Modalities
589
  "inputs": mods.get("input", ["text"]),
590
  "outputs": outputs,
591
+ # Capabilities
592
  "has_tools": raw.get("tool_call", False),
593
  "has_functions": raw.get("tool_call", False),
594
  "has_reasoning": raw.get("reasoning", False),
595
  "has_vision": "image" in mods.get("input", []),
596
+ "has_structured_output": raw.get("structured_output", False),
597
+ "has_temperature": raw.get("temperature", True),
598
+ "has_attachments": raw.get("attachment", False),
599
+ "has_interleaved": raw.get("interleaved", False),
600
+ # Extended model info
601
+ "family": raw.get("family", ""),
602
+ "knowledge_cutoff": raw.get("knowledge", ""),
603
+ "release_date": raw.get("release_date", ""),
604
+ "open_weights": raw.get("open_weights", False),
605
+ "status": raw.get("status", "active"),
606
  }
607
 
608
 
 
610
  # Lookup Index
611
  # ============================================================================
612
 
613
+
614
+ def _normalize_version_pattern(name: str) -> str:
615
+ """
616
+ Normalize version patterns in model names for fuzzy matching.
617
+
618
+ Converts various version formats to a canonical form:
619
+ - claude-opus-4-5 -> claude-opus-4.5
620
+ - claude-opus-4.5 -> claude-opus-4.5
621
+ - gemini-2-0-flash -> gemini-2.0-flash
622
+ - gemini-2-5-pro -> gemini-2.5-pro
623
+
624
+ Only applies to patterns that look like versions (digit-digit at end).
625
+ """
626
+ import re
627
+
628
+ # Pattern matches: -X-Y at end of string or before another dash/segment
629
+ # where X and Y are digits (like -4-5, -2-0, -2-5)
630
+ # This converts 4-5 to 4.5, 2-0 to 2.0, etc.
631
+ normalized = re.sub(r"-(\d+)-(\d+)(?=-|$)", r"-\1.\2", name)
632
+ return normalized
633
+
634
+
635
  class ModelIndex:
636
  """Fast lookup structure for model ID resolution."""
637
+
638
  def __init__(self):
639
  self._by_full_id: Dict[str, str] = {} # normalized_id -> canonical_id
640
  self._by_suffix: Dict[str, List[str]] = {} # short_name -> [canonical_ids]
641
+ self._by_normalized: Dict[
642
+ str, List[str]
643
+ ] = {} # normalized_name -> [canonical_ids]
644
+
645
  def clear(self):
646
  """Reset the index."""
647
  self._by_full_id.clear()
648
  self._by_suffix.clear()
649
+ self._by_normalized.clear()
650
+
651
  def entry_count(self) -> int:
652
  """Return total number of suffix index entries."""
653
  return sum(len(v) for v in self._by_suffix.values())
654
+
655
  def add(self, canonical_id: str):
656
  """Index a canonical model ID for various lookup patterns."""
657
  self._by_full_id[canonical_id] = canonical_id
658
+
659
  segments = canonical_id.split("/")
660
  if len(segments) >= 2:
661
  # Index by everything after first segment
662
  partial = "/".join(segments[1:])
663
  self._by_suffix.setdefault(partial, []).append(canonical_id)
664
+
665
  # Index by final segment only
666
  if len(segments) >= 3:
667
  tail = segments[-1]
668
  self._by_suffix.setdefault(tail, []).append(canonical_id)
669
+
670
+ # Index by normalized version pattern (e.g., claude-opus-4.5)
671
+ # This allows 4-5 queries to match 4.5 entries and vice versa
672
+ normalized_partial = _normalize_version_pattern(partial)
673
+ if normalized_partial != partial:
674
+ self._by_normalized.setdefault(normalized_partial, []).append(
675
+ canonical_id
676
+ )
677
+
678
  def resolve(self, query: str) -> List[str]:
679
  """Find all canonical IDs matching a query."""
680
  # Direct match
681
  if query in self._by_full_id:
682
  return [self._by_full_id[query]]
683
+
684
  # Try with openrouter prefix
685
  prefixed = f"openrouter/{query}"
686
  if prefixed in self._by_full_id:
687
  return [self._by_full_id[prefixed]]
688
+
689
  # Extract search terms from query
690
  search_keys = []
691
  parts = query.split("/")
 
694
  search_keys.append(parts[-1])
695
  else:
696
  search_keys.append(query)
697
+
698
+ # Find matches in suffix index
699
  matches = []
700
  seen = set()
701
  for key in search_keys:
 
703
  if cid not in seen:
704
  seen.add(cid)
705
  matches.append(cid)
706
+
707
+ # If no matches, try normalized version pattern matching
708
+ # This allows claude-opus-4-5 to match claude-opus-4.5
709
+ if not matches:
710
+ for key in search_keys:
711
+ normalized_key = _normalize_version_pattern(key)
712
+ # Check in normalized index
713
+ for cid in self._by_normalized.get(normalized_key, []):
714
+ if cid not in seen:
715
+ seen.add(cid)
716
+ matches.append(cid)
717
+ # Also check if normalized key matches regular suffix
718
+ # (for when source has 4-5 and query uses 4.5)
719
+ for cid in self._by_suffix.get(normalized_key, []):
720
+ if cid not in seen:
721
+ seen.add(cid)
722
+ matches.append(cid)
723
+
724
  return matches
725
 
726
 
 
728
  # Data Merger
729
  # ============================================================================
730
 
731
+
732
  class DataMerger:
733
+ """
734
+ Selects best source and creates ModelMetadata for queried model.
735
+
736
+ Key principle: For custom provider models (like antigravity/claude-opus-4-5),
737
+ we inherit technical specs from the best matching native provider source
738
+ (like anthropic/claude-opus-4.5), but keep the queried model's identity.
739
+ """
740
+
741
  @staticmethod
742
+ def create_metadata(
743
+ queried_model_id: str,
744
+ records: List[Tuple[Dict, str]],
745
+ quality: str,
746
+ ) -> ModelMetadata:
747
+ """
748
+ Create ModelMetadata for the queried model.
749
+
750
+ For fuzzy matches, picks the best source based on provider priority
751
+ rather than merging multiple sources (which would average pricing incorrectly).
752
+
753
+ The queried model's provider is preserved in owned_by, while technical
754
+ specs come from the best matching source.
755
+ """
756
+ if not records:
757
+ raise ValueError("No records to create metadata from")
758
+
759
+ # Extract the queried provider from the model ID
760
+ queried_parts = queried_model_id.split("/")
761
+ queried_provider = queried_parts[0] if queried_parts else "unknown"
762
+
763
+ # Pick the best source based on provider priority
764
+ best_record, best_origin = DataMerger._select_best_source(records)
765
+
766
+ # Extract parent model ID from origin for transparency
767
+ parent_model_id = DataMerger._extract_model_id_from_origin(best_origin)
768
+
769
  return ModelMetadata(
770
+ model_id=queried_model_id,
771
+ display_name=best_record.get("name", queried_model_id.split("/")[-1]),
772
+ # Use QUERIED provider, not source provider
773
+ provider=queried_provider,
774
+ category=best_record.get("category", "chat"),
775
  pricing=ModelPricing(
776
+ prompt=best_record.get("prompt_cost"),
777
+ completion=best_record.get("completion_cost"),
778
+ cached_input=best_record.get("cache_read_cost"),
779
+ cache_write=best_record.get("cache_write_cost"),
780
  ),
781
  limits=ModelLimits(
782
+ context_window=best_record.get("context") or None,
783
+ max_output=best_record.get("max_out") or None,
784
  ),
785
  capabilities=ModelCapabilities(
786
+ tools=best_record.get("has_tools", False),
787
+ functions=best_record.get("has_functions", False),
788
+ reasoning=best_record.get("has_reasoning", False),
789
+ vision=best_record.get("has_vision", False),
790
+ # Extended capabilities
791
+ structured_output=best_record.get("has_structured_output", False),
792
+ temperature=best_record.get("has_temperature", True),
793
+ attachments=best_record.get("has_attachments", False),
794
+ interleaved=best_record.get("has_interleaved", False),
795
+ ),
796
+ info=ModelInfo(
797
+ family=best_record.get("family", ""),
798
+ description=best_record.get("description", ""),
799
+ knowledge_cutoff=best_record.get("knowledge_cutoff", ""),
800
+ release_date=best_record.get("release_date", ""),
801
+ open_weights=best_record.get("open_weights", False),
802
+ status=best_record.get("status", "active"),
803
+ tokenizer=best_record.get("tokenizer", ""),
804
+ huggingface_id=best_record.get("huggingface_id", ""),
805
  ),
806
+ input_types=best_record.get("inputs", ["text"]),
807
+ output_types=best_record.get("outputs", ["text"]),
808
+ supported_parameters=best_record.get("supported_parameters", []),
809
+ origin=f"{best_origin}|parent:{parent_model_id}"
810
+ if parent_model_id
811
+ else best_origin,
812
  match_quality=quality,
813
  )
814
+
815
  @staticmethod
816
+ def _select_best_source(records: List[Tuple[Dict, str]]) -> Tuple[Dict, str]:
817
+ """
818
+ Select the best source from multiple candidates based on provider priority.
819
+
820
+ Prefers native providers (anthropic, openai, google) over proxy/aggregator
821
+ providers (azure, openrouter, requesty, etc.).
822
+
823
+ When multiple sources have the same extracted provider (e.g., both
824
+ requesty/anthropic/model and anthropic/model extract to anthropic),
825
+ prefer the source where the first segment is the native provider
826
+ (i.e., anthropic/model is preferred over requesty/anthropic/model).
827
+ """
828
  if len(records) == 1:
829
+ return records[0]
830
+
831
+ def get_sort_key(record_tuple: Tuple[Dict, str]) -> Tuple[int, int, int]:
832
+ data, origin = record_tuple
833
+ # Extract source_id from origin string like "modelsdev:fuzzy:anthropic/claude-opus-4.5"
834
+ source_id = origin.split(":")[-1] if ":" in origin else origin
835
+
836
+ # Primary: priority of extracted provider (handles nested paths)
837
+ provider = _extract_provider_from_source_id(source_id)
838
+ primary_priority = _get_provider_priority(provider)
839
+
840
+ # Secondary: prefer sources where first segment is a native provider
841
+ # This ensures anthropic/model wins over requesty/anthropic/model
842
+ parts = source_id.split("/")
843
+ first_segment = parts[0].lower() if parts else ""
844
+ first_segment_priority = _get_provider_priority(first_segment)
845
+
846
+ # Tertiary: prefer shorter paths (2-segment over 3-segment)
847
+ # This is a tiebreaker when both have same first segment priority
848
+ path_length = len(parts)
849
+
850
+ return (primary_priority, first_segment_priority, path_length)
851
+
852
+ # Sort by priority tuple (lower is better) and return the best
853
+ sorted_records = sorted(records, key=get_sort_key)
854
+ return sorted_records[0]
855
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
856
  @staticmethod
857
+ def _extract_model_id_from_origin(origin: str) -> Optional[str]:
858
+ """
859
+ Extract the source model ID from an origin string.
860
+
861
+ Examples:
862
+ "modelsdev:fuzzy:anthropic/claude-opus-4.5" -> "anthropic/claude-opus-4.5"
863
+ "openrouter:exact:openrouter/google/gemini-2.5-pro" -> "google/gemini-2.5-pro"
864
+ """
865
+ if ":" not in origin:
866
  return None
867
+
868
+ parts = origin.split(":")
869
+ if len(parts) >= 3:
870
+ source_id = parts[-1]
871
+ # Remove openrouter prefix if present
872
+ if source_id.startswith("openrouter/"):
873
+ source_id = source_id[len("openrouter/") :]
874
+ return source_id
875
+ return None
876
+
877
+ # Legacy method for backward compatibility
878
+ @staticmethod
879
+ def single(model_id: str, data: Dict, origin: str, quality: str) -> ModelMetadata:
880
+ """Create ModelMetadata from a single source record. Legacy method."""
881
+ return DataMerger.create_metadata(model_id, [(data, origin)], quality)
882
+
883
+ # Legacy method for backward compatibility
884
+ @staticmethod
885
+ def combine(
886
+ model_id: str, records: List[Tuple[Dict, str]], quality: str
887
+ ) -> ModelMetadata:
888
+ """Create ModelMetadata from records. Now uses best-source selection."""
889
+ return DataMerger.create_metadata(model_id, records, quality)
890
 
891
 
892
  # ============================================================================
893
  # Main Registry Service
894
  # ============================================================================
895
 
896
+
897
  class ModelRegistry:
898
  """
899
  Central registry for model metadata from external catalogs.
900
+
901
  Manages background data refresh and provides lookup/pricing APIs.
902
  """
903
+
904
  REFRESH_INTERVAL_DEFAULT = 6 * 60 * 60 # 6 hours
905
+
906
  def __init__(
907
  self,
908
  refresh_seconds: Optional[int] = None,
 
912
  self._refresh_interval = refresh_seconds or (
913
  int(interval_env) if interval_env else self.REFRESH_INTERVAL_DEFAULT
914
  )
915
+
916
  # Configure adapters
917
  self._adapters: List[DataSourceAdapter] = [
918
  OpenRouterAdapter(),
919
  ModelsDevAdapter(skip_providers=skip_modelsdev_providers or []),
920
  ]
921
+
922
  # Raw data stores
923
  self._openrouter_store: Dict[str, Dict] = {}
924
  self._modelsdev_store: Dict[str, Dict] = {}
925
+
926
  # Lookup infrastructure
927
  self._index = ModelIndex()
928
  self._result_cache: Dict[str, ModelMetadata] = {}
929
+
930
  # Async coordination
931
  self._ready = asyncio.Event()
932
  self._mutex = asyncio.Lock()
933
  self._worker: Optional[asyncio.Task] = None
934
  self._last_refresh: float = 0
935
+
936
  # ---------- Lifecycle ----------
937
+
938
  async def start(self):
939
  """Begin background refresh worker."""
940
  if self._worker is None:
941
  self._worker = asyncio.create_task(self._refresh_worker())
942
  logger.info(
943
+ "ModelRegistry started (refresh every %ds)", self._refresh_interval
 
944
  )
945
+
946
  async def stop(self):
947
  """Halt background worker."""
948
  if self._worker:
 
953
  pass
954
  self._worker = None
955
  logger.info("ModelRegistry stopped")
956
+
957
  async def await_ready(self, timeout_secs: float = 30.0) -> bool:
958
  """Block until initial data load completes."""
959
  try:
 
962
  except asyncio.TimeoutError:
963
  logger.warning("ModelRegistry ready timeout after %.1fs", timeout_secs)
964
  return False
965
+
966
  @property
967
  def is_ready(self) -> bool:
968
  return self._ready.is_set()
969
+
970
  # ---------- Background Worker ----------
971
+
972
  async def _refresh_worker(self):
973
  """Periodic refresh loop."""
974
  await self._load_all_sources()
975
  self._ready.set()
976
+
977
  while True:
978
  try:
979
  await asyncio.sleep(self._refresh_interval)
 
984
  break
985
  except Exception as ex:
986
  logger.error("Registry refresh error: %s", ex)
987
+
988
  async def _load_all_sources(self):
989
  """Fetch from all adapters concurrently."""
990
  loop = asyncio.get_event_loop()
991
+
992
  tasks = [
993
+ loop.run_in_executor(None, adapter.fetch) for adapter in self._adapters
 
994
  ]
995
+
996
  results = await asyncio.gather(*tasks, return_exceptions=True)
997
+
998
  async with self._mutex:
999
  for adapter, result in zip(self._adapters, results):
1000
  if isinstance(result, Exception):
1001
  logger.error("%s fetch failed: %s", adapter.source_name, result)
1002
  continue
1003
+
1004
  if adapter.source_name == "openrouter":
1005
  self._openrouter_store = result
1006
  logger.info("OpenRouter: %d models loaded", len(result))
1007
  elif adapter.source_name == "modelsdev":
1008
  self._modelsdev_store = result
1009
  logger.info("Models.dev: %d models loaded", len(result))
1010
+
1011
  self._rebuild_index()
1012
  self._last_refresh = time.time()
1013
+
1014
  def _rebuild_index(self):
1015
  """Reconstruct lookup index from current stores."""
1016
  self._index.clear()
1017
  self._result_cache.clear()
1018
+
1019
  for model_id in self._openrouter_store:
1020
  self._index.add(model_id)
1021
+
1022
  for model_id in self._modelsdev_store:
1023
  self._index.add(model_id)
1024
+
1025
  # ---------- Query API ----------
1026
+
1027
  def lookup(self, model_id: str) -> Optional[ModelMetadata]:
1028
  """
1029
  Retrieve model metadata by ID.
1030
+
1031
  Matching strategy:
1032
  1. Exact match against known IDs
1033
  2. Fuzzy match by model name suffix
 
1035
  """
1036
  if model_id in self._result_cache:
1037
  return self._result_cache[model_id]
1038
+
1039
  metadata = self._resolve_model(model_id)
1040
  if metadata:
1041
  self._result_cache[model_id] = metadata
1042
  return metadata
1043
+
1044
  def _resolve_model(self, model_id: str) -> Optional[ModelMetadata]:
1045
  """Build ModelMetadata by matching source data."""
1046
  records: List[Tuple[Dict, str]] = []
1047
  quality = "none"
1048
+
1049
+ # Step 1: Check exact matches first
1050
+ or_key = (
1051
+ f"openrouter/{model_id}"
1052
+ if not model_id.startswith("openrouter/")
1053
+ else model_id
1054
+ )
1055
  if or_key in self._openrouter_store:
1056
+ records.append(
1057
+ (self._openrouter_store[or_key], f"openrouter:exact:{or_key}")
1058
+ )
1059
  quality = "exact"
1060
+
1061
  if model_id in self._modelsdev_store:
1062
+ records.append(
1063
+ (self._modelsdev_store[model_id], f"modelsdev:exact:{model_id}")
1064
+ )
1065
  quality = "exact"
1066
+
1067
+ # Step 2: Try provider alias substitution for direct match
1068
+ # This handles cases like nvidia_nim/org/model -> nvidia/org/model
1069
+ if not records:
1070
+ alias_candidates = self._get_alias_candidates(model_id)
1071
+ for alias_id in alias_candidates:
1072
+ # Try Models.dev first (usually more complete)
1073
+ if alias_id in self._modelsdev_store:
1074
+ records.append(
1075
+ (self._modelsdev_store[alias_id], f"modelsdev:alias:{alias_id}")
1076
+ )
1077
+ quality = "alias"
1078
+ break # Take first match
1079
+ # Try OpenRouter with prefix
1080
+ or_alias = f"openrouter/{alias_id}"
1081
+ if or_alias in self._openrouter_store:
1082
+ records.append(
1083
+ (
1084
+ self._openrouter_store[or_alias],
1085
+ f"openrouter:alias:{or_alias}",
1086
+ )
1087
+ )
1088
+ quality = "alias"
1089
+ break
1090
+
1091
+ # Step 3: Fall back to fuzzy index search
1092
  if not records:
1093
  candidates = self._index.resolve(model_id)
1094
  for cid in candidates:
1095
  if cid in self._openrouter_store:
1096
+ records.append(
1097
+ (self._openrouter_store[cid], f"openrouter:fuzzy:{cid}")
1098
+ )
1099
  elif cid in self._modelsdev_store:
1100
+ records.append(
1101
+ (self._modelsdev_store[cid], f"modelsdev:fuzzy:{cid}")
1102
+ )
1103
+
1104
  if records:
1105
  quality = "fuzzy"
1106
+
1107
  if not records:
1108
  return None
1109
+
1110
  return DataMerger.combine(model_id, records, quality)
1111
+
1112
+ def _get_alias_candidates(self, model_id: str) -> List[str]:
1113
+ """
1114
+ Generate alternative model IDs by substituting provider aliases.
1115
+
1116
+ Examples:
1117
+ nvidia_nim/mistralai/model -> nvidia/mistralai/model
1118
+ gemini_cli/gemini-2.5-flash -> google/gemini-2.5-flash
1119
+ gemini/gemini-2.5-pro -> google/gemini-2.5-pro
1120
+ """
1121
+ parts = model_id.split("/")
1122
+ if len(parts) < 2:
1123
+ return []
1124
+
1125
+ provider = parts[0]
1126
+ rest = "/".join(parts[1:])
1127
+
1128
+ candidates = []
1129
+
1130
+ # Check if provider has aliases defined
1131
+ if provider in PROVIDER_ALIASES:
1132
+ for alias in PROVIDER_ALIASES[provider]:
1133
+ candidates.append(f"{alias}/{rest}")
1134
+
1135
+ return candidates
1136
+
1137
  def get_pricing(self, model_id: str) -> Optional[Dict[str, float]]:
1138
  """Extract just pricing info for cost calculations."""
1139
  meta = self.lookup(model_id)
1140
  if not meta:
1141
  return None
1142
+
1143
  result = {}
1144
  if meta.pricing.prompt is not None:
1145
  result["input_cost_per_token"] = meta.pricing.prompt
 
1149
  result["cache_read_input_token_cost"] = meta.pricing.cached_input
1150
  if meta.pricing.cache_write is not None:
1151
  result["cache_creation_input_token_cost"] = meta.pricing.cache_write
1152
+
1153
  return result if result else None
1154
+
1155
  def compute_cost(
1156
  self,
1157
  model_id: str,
 
1162
  ) -> Optional[float]:
1163
  """
1164
  Calculate total request cost.
1165
+
1166
  Returns None if pricing unavailable.
1167
  """
1168
  pricing = self.get_pricing(model_id)
1169
  if not pricing:
1170
  return None
1171
+
1172
  in_rate = pricing.get("input_cost_per_token")
1173
  out_rate = pricing.get("output_cost_per_token")
1174
+
1175
  if in_rate is None or out_rate is None:
1176
  return None
1177
+
1178
  total = (input_tokens * in_rate) + (output_tokens * out_rate)
1179
+
1180
  cache_read_rate = pricing.get("cache_read_input_token_cost")
1181
  if cache_read_rate and cache_hit_tokens:
1182
  total += cache_hit_tokens * cache_read_rate
1183
+
1184
  cache_write_rate = pricing.get("cache_creation_input_token_cost")
1185
  if cache_write_rate and cache_miss_tokens:
1186
  total += cache_miss_tokens * cache_write_rate
1187
+
1188
  return total
1189
+
1190
  def enrich_models(self, model_ids: List[str]) -> List[Dict[str, Any]]:
1191
  """
1192
  Attach metadata to a list of model IDs.
1193
+
1194
  Used by /v1/models endpoint.
1195
  """
1196
  enriched = []
 
1200
  enriched.append(meta.as_api_response())
1201
  else:
1202
  # Fallback minimal entry
1203
+ enriched.append(
1204
+ {
1205
+ "id": mid,
1206
+ "object": "model",
1207
+ "created": int(time.time()),
1208
+ "owned_by": mid.split("/")[0] if "/" in mid else "unknown",
1209
+ }
1210
+ )
1211
  return enriched
1212
+
1213
  def all_raw_models(self) -> Dict[str, Dict]:
1214
  """Return all raw source data (for debugging)."""
1215
  combined = {}
1216
  combined.update(self._openrouter_store)
1217
  combined.update(self._modelsdev_store)
1218
  return combined
1219
+
1220
  def diagnostics(self) -> Dict[str, Any]:
1221
  """Return service health/stats."""
1222
  return {
 
1228
  "index_entries": self._index.entry_count(),
1229
  "refresh_interval": self._refresh_interval,
1230
  }
1231
+
1232
  # ---------- Backward Compatibility Methods ----------
1233
+
1234
  def get_model_info(self, model_id: str) -> Optional[ModelMetadata]:
1235
  """Alias for lookup() - backward compatibility."""
1236
  return self.lookup(model_id)
1237
+
1238
  def get_cost_info(self, model_id: str) -> Optional[Dict[str, float]]:
1239
  """Alias for get_pricing() - backward compatibility."""
1240
  return self.get_pricing(model_id)
1241
+
1242
  def calculate_cost(
1243
  self,
1244
  model_id: str,
 
1249
  ) -> Optional[float]:
1250
  """Alias for compute_cost() - backward compatibility."""
1251
  return self.compute_cost(
1252
+ model_id,
1253
+ prompt_tokens,
1254
+ completion_tokens,
1255
+ cache_read_tokens,
1256
+ cache_creation_tokens,
1257
  )
1258
+
1259
  def enrich_model_list(self, model_ids: List[str]) -> List[Dict[str, Any]]:
1260
  """Alias for enrich_models() - backward compatibility."""
1261
  return self.enrich_models(model_ids)
1262
+
1263
  def get_all_source_models(self) -> Dict[str, Dict]:
1264
  """Alias for all_raw_models() - backward compatibility."""
1265
  return self.all_raw_models()
1266
+
1267
  def get_stats(self) -> Dict[str, Any]:
1268
  """Alias for diagnostics() - backward compatibility."""
1269
  return self.diagnostics()
1270
+
1271
    def wait_for_ready(self, timeout: float = 30.0):
        """Compatibility alias for await_ready().

        NOTE(review): despite the historical "sync wrapper" naming, this is
        NOT synchronous — await_ready() is a coroutine function, so this
        returns an un-started coroutine that the caller must still await.
        """
        return self.await_ready(timeout)
 
1278
  # ============================================================================
1279
 
1280
  # Alias for backward compatibility
1281
+ # Note: ModelInfo is now a real dataclass for extended model metadata
1282
+ # The old alias (ModelInfo = ModelMetadata) has been removed
1283
  ModelInfoService = ModelRegistry
1284
 
1285
  # Global singleton
 
1304
  # Compatibility shim - map old method names to new
1305
class _CompatibilityWrapper:
    """Provides old API method names for gradual migration.

    Thin delegation layer: every method forwards to the wrapped
    ModelRegistry, preserving the legacy method names.
    """

    def __init__(self, registry: ModelRegistry):
        # The wrapped registry; all calls below delegate to it.
        self._reg = registry

    def get_model_info(self, model_id: str) -> Optional[ModelMetadata]:
        # Legacy name for ModelRegistry.lookup().
        return self._reg.lookup(model_id)

    def get_cost_info(self, model_id: str) -> Optional[Dict[str, float]]:
        # Legacy name for ModelRegistry.get_pricing().
        return self._reg.get_pricing(model_id)

    def calculate_cost(
        self,
        model_id: str,
        prompt_tokens: int,
        completion_tokens: int,
        cache_read_tokens: int = 0,
        cache_creation_tokens: int = 0,
    ) -> Optional[float]:
        # Legacy name for ModelRegistry.compute_cost(); positional forwarding
        # maps prompt/completion counts onto input/output token rates.
        return self._reg.compute_cost(
            model_id,
            prompt_tokens,
            completion_tokens,
            cache_read_tokens,
            cache_creation_tokens,
        )

    def enrich_model_list(self, model_ids: List[str]) -> List[Dict[str, Any]]:
        # Legacy name for ModelRegistry.enrich_models().
        return self._reg.enrich_models(model_ids)

    def get_all_source_models(self) -> Dict[str, Dict]:
        # Legacy name for ModelRegistry.all_raw_models().
        return self._reg.all_raw_models()

    def get_stats(self) -> Dict[str, Any]:
        # Legacy name for ModelRegistry.diagnostics().
        return self._reg.diagnostics()

    async def start(self):
        # Starts the registry's background refresh worker.
        await self._reg.start()

    async def stop(self):
        # Stops the registry's background refresh worker.
        await self._reg.stop()

    async def wait_for_ready(self, timeout: float = 30.0) -> bool:
        # Unlike ModelRegistry.wait_for_ready, this variant actually awaits.
        return await self._reg.await_ready(timeout)

    def is_ready(self) -> bool:
        # Note: a plain method here, wrapping the registry's is_ready property.
        return self._reg.is_ready