Spaces:
Running
Running
File size: 4,488 Bytes
"""Provider model-list response parsing helpers."""
from __future__ import annotations
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any
from providers.exceptions import ModelListResponseError
@dataclass(frozen=True, slots=True)
class ProviderModelInfo:
"""Internal provider model metadata used for gateway model-list shaping."""
model_id: str
supports_thinking: bool | None = None
def model_infos_from_ids(
    model_ids: Iterable[str], *, supports_thinking: bool | None = None
) -> frozenset[ProviderModelInfo]:
    """Wrap plain provider model ids in ``ProviderModelInfo`` records.

    Ids that are empty or whitespace-only are skipped. Every record receives
    the same ``supports_thinking`` value, which defaults to ``None``.
    """
    infos: set[ProviderModelInfo] = set()
    for raw_id in model_ids:
        if not raw_id.strip():
            # Blank ids carry no information; leave them out.
            continue
        infos.add(
            ProviderModelInfo(model_id=raw_id, supports_thinking=supports_thinking)
        )
    return frozenset(infos)
def extract_openai_model_ids(payload: Any, *, provider_name: str) -> frozenset[str]:
    """Pull the set of model ids out of an OpenAI-compatible ``/models`` payload.

    The payload must expose a ``data`` sequence whose every entry carries a
    non-blank string ``id``. Ids are returned exactly as received.

    Raises:
        ModelListResponseError: if ``data`` is not a sequence, if any entry
            lacks a usable ``id``, or if no ids were found at all.
    """
    entries = _field(payload, "data")
    if not _is_sequence(entries):
        raise _malformed(provider_name, "expected top-level data array")

    collected: set[str] = set()
    for entry in entries:
        candidate = _field(entry, "id")
        if isinstance(candidate, str) and candidate.strip():
            collected.add(candidate)
            continue
        # The first entry without a usable id invalidates the whole response.
        raise _malformed(provider_name, "expected every data item to include id")

    if collected:
        return frozenset(collected)
    raise _malformed(provider_name, "response did not include any model ids")
def extract_openrouter_tool_model_ids(
    payload: Any, *, provider_name: str
) -> frozenset[str]:
    """Return only the ids of the tool-capable OpenRouter models in ``payload``.

    Parsing and filtering are delegated to
    ``extract_openrouter_tool_model_infos``; everything except each record's
    ``model_id`` is discarded.
    """
    tool_infos = extract_openrouter_tool_model_infos(
        payload, provider_name=provider_name
    )
    return frozenset({entry.model_id for entry in tool_infos})
def extract_openrouter_tool_model_infos(
    payload: Any, *, provider_name: str
) -> frozenset[ProviderModelInfo]:
    """Collect tool-capable OpenRouter models together with a thinking flag.

    An entry is kept when its ``supported_parameters`` sequence names
    ``tools`` or ``tool_choice``; ``supports_thinking`` records whether that
    same sequence also names ``reasoning``. Entries whose
    ``supported_parameters`` is missing or not a sequence are skipped, so the
    result may be empty.

    Raises:
        ModelListResponseError: if ``data`` is not a sequence, or if any
            entry lacks a non-blank string ``id``.
    """
    entries = _field(payload, "data")
    if not _is_sequence(entries):
        raise _malformed(provider_name, "expected top-level data array")

    tool_capable: set[ProviderModelInfo] = set()
    for entry in entries:
        # The id is validated for every entry, including ones filtered out below.
        ident = _field(entry, "id")
        if not (isinstance(ident, str) and ident.strip()):
            raise _malformed(provider_name, "expected every data item to include id")

        declared = _field(entry, "supported_parameters")
        if _is_sequence(declared):
            # Non-string parameter entries are ignored.
            names = {name for name in declared if isinstance(name, str)}
            if "tools" in names or "tool_choice" in names:
                tool_capable.add(
                    ProviderModelInfo(
                        model_id=ident,
                        supports_thinking="reasoning" in names,
                    )
                )
    return frozenset(tool_capable)
def extract_ollama_model_ids(payload: Any, *, provider_name: str) -> frozenset[str]:
    """Pull the set of model ids out of an Ollama ``/api/tags`` payload.

    Each entry of the top-level ``models`` sequence contributes its ``model``
    and ``name`` values; whichever of the two are non-blank strings are both
    included in the result.

    Raises:
        ModelListResponseError: if ``models`` is not a sequence, if an entry
            provides neither a usable ``model`` nor ``name``, or if no ids
            were found at all.
    """
    entries = _field(payload, "models")
    if not _is_sequence(entries):
        raise _malformed(provider_name, "expected top-level models array")

    collected: set[str] = set()
    for entry in entries:
        candidates = (_field(entry, key) for key in ("model", "name"))
        found = [
            value
            for value in candidates
            if isinstance(value, str) and value.strip()
        ]
        if not found:
            raise _malformed(
                provider_name,
                "expected every models item to include model or name",
            )
        collected.update(found)

    if collected:
        return frozenset(collected)
    raise _malformed(provider_name, "response did not include any model ids")
def _field(item: Any, name: str) -> Any:
    """Read ``name`` from ``item``: by key for mappings, by attribute otherwise.

    Returns ``None`` when the key or attribute is absent.
    """
    if not isinstance(item, Mapping):
        return getattr(item, name, None)
    return item.get(name)
def _is_sequence(value: Any) -> bool:
    """Tell whether ``value`` is a ``Sequence`` other than text or raw bytes."""
    # str, bytes and bytearray are Sequences too, but never count as a list here.
    if isinstance(value, (str, bytes, bytearray)):
        return False
    return isinstance(value, Sequence)
def _malformed(provider_name: str, reason: str) -> ModelListResponseError:
    """Build (without raising) the error for a malformed model-list response.

    The message names the provider and the specific ``reason``; callers are
    expected to ``raise`` the returned exception themselves.
    """
    message = f"{provider_name} model-list response is malformed: {reason}"
    return ModelListResponseError(message)
|