multi-agent-lab / src /models /litellm_provider.py
agharsallah
feat(observability): instrument models, conductor, memory, ledger (Units 2-6)
a196e34
Raw
History Blame Contribute Delete
15.1 kB
"""LiteLLM-backed provider — one gateway, every logical profile.
This is the *transport* the :class:`~src.models.router.ModelRouter` uses on the
live path: it replaces hand-rolled per-vendor SDK calls with a single idiomatic
``litellm.completion(...)`` call. Routing (profile → concrete model + endpoint)
is unchanged and still lives in the router; this class only knows how to *call* a
model and report what it cost.
Two things it adds over the plain OpenAI-compatible provider:
* **Real cost.** LiteLLM prices the call from its model database, so the
Governor's ``hourly_budget_usd`` becomes real on the live path. Cost is
exposed on ``last_usage["cost_usd"]`` (and ``last_cost``); offline it is 0.
* **One model string for any endpoint.** An OpenAI-compatible custom endpoint
(the Modal/vLLM servers in ``modal/``) is reached with the LiteLLM model
string ``openai/<served_model_id>`` plus an ``api_base`` — no per-vendor
branching.
``litellm`` is imported lazily so ``import src.models.*`` (and ``import app``)
work with the package not installed; offline never touches this class. The call
is kept thin and standard so the structured layer can wrap it
(``instructor.from_litellm(litellm.completion)``) without fighting this code.
See ADR-0015.
On top of the plain :meth:`complete`, this gateway also offers
:meth:`complete_structured`: it wraps the same ``litellm.completion`` with
Instructor to return a *validated* Pydantic instance (kind constrained to the
agent's grant, retried on validation failure), reading usage and cost from the
raw completion exactly as :meth:`complete` does. ``instructor`` is likewise
lazy-imported, so the offline path needs neither it nor ``litellm``. See
ADR-0016.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from src import observability as obs
from src.models.openai_compat import OpenAICompatProvider
from src.models.provider import ModelProvider, model_error
if TYPE_CHECKING:
from pydantic import BaseModel
@dataclass
class LiteLLMProvider(ModelProvider):
    """Route one logical profile through the LiteLLM gateway.

    ``model`` is a LiteLLM model string. For an OpenAI-compatible custom
    endpoint (Modal/vLLM) it is ``openai/<served_model_id>`` and ``api_base``
    points at the endpoint's ``/v1`` URL. Decoding (``temperature`` /
    ``max_tokens``) and the binding come from the router's per-profile spec.

    After every call the instance records that call's token usage, metered USD
    cost and separated reasoning (all zeroed when the call fails). They are
    exposed through :attr:`last_cost`, :attr:`last_reasoning` and the
    ``_last_usage`` dict.
    """

    model: str
    api_base: str | None = None
    api_key: str | None = None
    temperature: float = 0.8
    max_tokens: int = 256
    max_retries: int = 2
    """Validation retries for :meth:`complete_structured` (live structured output)."""
    num_retries: int = 2
    """Transport retries LiteLLM makes on a transient call failure — a dropped
    connection, a timeout, a 5xx. Lets a flaky endpoint self-heal mid-demo before the
    call gives up (:meth:`complete` then returns the failure sentinel;
    :meth:`complete_structured` re-raises)."""
    structured_mode: str = "json_schema"
    """Instructor mode for :meth:`complete_structured` (an ``instructor.Mode`` member name,
    case-insensitive). Defaults to ``json_schema`` — vLLM **guided decoding** via
    ``response_format``, which is parser-independent: it constrains the output to the schema
    (``kind`` can't be an unauthorised value) without needing a tool-call parser. This is
    deliberate: not every served model ships a tool parser (e.g. MiniCPM4.1 emits a custom
    ``<|tool_call_start|>`` format vLLM 0.21.0 has no parser for), so Instructor's default
    ``tools`` mode 400s there. ``json`` (plain ``json_object`` + schema-in-prompt) is the
    fallback if a backend rejects ``json_schema``; ``tools`` restores the old behaviour.
    An unrecognised name logs a warning and falls back to ``json_schema``."""
    _last_usage: dict = field(default_factory=dict, init=False, repr=False)
    _last_cost: float = field(default=0.0, init=False, repr=False)
    _last_reasoning: str = field(default="", init=False, repr=False)

    def complete(self, role: str, prompt: str) -> str:
        """Return the model's plain-text reply to *prompt* for *role*.

        Sends the role's system message plus *prompt* as the user message
        through ``litellm.completion`` and returns the stripped reply text.
        Usage, cost and reasoning are recorded on the instance, and telemetry
        is emitted inside an ``llm.call`` span.

        Any exception raised while calling the model, reading the response or
        emitting telemetry is caught. In that case the usage is zeroed,
        ``llm.error`` is logged, and ``model_error(exc)`` is returned instead
        of raising. A missing ``litellm`` package still raises
        :class:`ImportError`.
        """
        litellm = self._litellm()
        with obs.span("llm.call", **self._span_request_attrs(role)):
            try:
                response = litellm.completion(**self._completion_kwargs(role, prompt))
                # ``content`` may be None; treat that as an empty reply.
                text = (response.choices[0].message.content or "").strip()
                self._capture_usage(litellm, response, prompt, text)
                self._emit_telemetry(role, prompt, text, structured=False)
                return text
            except Exception as exc:
                self._zero_usage()
                obs.log("llm.error", level="warning", model=self.model, role=role, error=str(exc))
                return model_error(exc)

    def complete_structured(
        self,
        role: str,
        prompt: str,
        response_model: type["BaseModel"],
    ) -> "BaseModel":
        """Return a validated *response_model* instance via Instructor.

        Wraps the same ``litellm.completion`` with ``instructor.from_litellm``
        (in the mode resolved from ``structured_mode``) and asks for
        *response_model*. Instructor re-prompts up to ``max_retries`` times on
        validation failure. Because the model is constrained and re-prompted
        until it validates, the caller gets a typed, schema-valid object — the
        live path never falls back to wrapping malformed prose (see ADR-0016).
        Usage and cost are read from the raw completion exactly as
        :meth:`complete` does, so token/cost metering is unchanged.

        ``instructor`` is imported lazily; offline never reaches this method.

        Raises:
            ImportError: if ``litellm`` or ``instructor`` is not installed.
            Exception: whatever the call or validation raised. Usage is zeroed
                and ``llm.error`` is logged first, so the caller can fall back
                to the prompt-and-parse path.
        """
        litellm = self._litellm()
        try:
            import instructor
        except ImportError as exc:  # pragma: no cover - exercised only when unset
            raise ImportError(
                "instructor package is required for complete_structured(). Install it with: uv pip install instructor"
            ) from exc
        # Guided-JSON by default (see ``structured_mode``): constrain the output to the
        # schema via vLLM's ``response_format`` rather than tool calling, so a model with no
        # tool-call parser still returns a validated payload instead of a 400.
        mode = self._instructor_mode(instructor)
        client = instructor.from_litellm(litellm.completion, mode=mode)
        # Record the mode actually used, not the (possibly unrecognised) configured name.
        span_attrs = {**self._span_request_attrs(role), "llm.mode": mode.name.lower()}
        with obs.span("llm.structured", **span_attrs):
            try:
                result, response = client.create_with_completion(
                    response_model=response_model,
                    max_retries=self.max_retries,
                    **self._completion_kwargs(role, prompt),
                )
                # Token estimation and the exchange log use the result's ``text``
                # field when it has one; "" otherwise.
                text = getattr(result, "text", "") or ""
                # NOTE(review): cost is read from the single completion Instructor
                # returns; confirm whether that covers earlier validation retries.
                self._capture_usage(litellm, response, prompt, text)
                self._emit_telemetry(role, prompt, text, structured=True)
                return result
            except Exception as exc:
                self._zero_usage()
                obs.log("llm.error", level="warning", model=self.model, role=role, structured=True, error=str(exc))
                raise

    @property
    def last_reasoning(self) -> str:
        """The model's separated thinking from the most recent call, or "".

        Reasoning models served on vLLM (e.g. the gemma4 / qwen3 reasoning parsers)
        return their chain-of-thought in ``message.reasoning_content``, leaving
        ``content`` for the answer. We surface it so the UI can show it under the
        mind-reader toggle — it is never fed back into any agent's prompt."""
        return self._last_reasoning

    @property
    def last_cost(self) -> float:
        """Metered USD cost of the most recent call (0.0 offline)."""
        return self._last_cost

    # ── telemetry (shared by complete / complete_structured) ────────────────────
    def _span_request_attrs(self, role: str) -> dict:
        """GenAI request attributes for an LLM span — never includes the api key."""
        return {
            "gen_ai.system": "litellm",
            "gen_ai.request.model": self.model,
            "gen_ai.request.temperature": self.temperature,
            "gen_ai.request.max_tokens": self.max_tokens,
            "llm.api_base": self.api_base or "",
            "mal.role": role,
        }

    def _emit_telemetry(self, role: str, prompt: str, text: str, *, structured: bool) -> None:
        """Attach usage/cost/prompt to the active span, count the call, and log it.

        The full prompt + completion + reasoning ride on the span (truncated in the
        UI store) and on a DEBUG ``llm.exchange`` log, so a reviewer can read exactly
        what was sent to each model. INFO ``llm.call`` carries the metered summary.
        """
        usage = self._last_usage or {}
        prompt_tokens = int(usage.get("prompt_tokens", 0) or 0)
        completion_tokens = int(usage.get("completion_tokens", 0) or 0)
        obs.add_span_attrs(
            **{
                "gen_ai.usage.input_tokens": prompt_tokens,
                "gen_ai.usage.output_tokens": completion_tokens,
                "llm.cost_usd": self._last_cost,
                "llm.structured": structured,
                "llm.prompt": prompt,
                "llm.completion": text,
                "llm.reasoning": self._last_reasoning or "",
            }
        )
        obs.record_llm_call(
            self.model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cost_usd=self._last_cost,
        )
        obs.log(
            "llm.call",
            role=role,
            model=self.model,
            structured=structured,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cost_usd=round(self._last_cost, 6),
        )
        obs.log(
            "llm.exchange",
            level="debug",
            role=role,
            model=self.model,
            prompt=prompt,
            completion=text,
            reasoning=self._last_reasoning or "",
        )

    # ── call helpers (shared by complete / complete_structured) ─────────────────
    @staticmethod
    def _litellm():
        """Lazy-import litellm; raise a clear install hint if it is missing."""
        try:
            import litellm
        except ImportError as exc:  # pragma: no cover - exercised only when unset
            raise ImportError(
                "litellm package is required for LiteLLMProvider. Install it with: uv pip install litellm"
            ) from exc
        return litellm

    def _completion_kwargs(self, role: str, prompt: str) -> dict:
        """Keyword arguments shared by both call paths.

        :meth:`complete` passes these straight to ``litellm.completion``.
        :meth:`complete_structured` passes them to Instructor's
        ``create_with_completion``, which forwards them to the wrapped
        ``litellm.completion``. Building them in one place keeps the endpoint,
        credentials, messages, decoding and transport retries identical on
        both paths.
        """
        return {
            "model": self.model,
            "api_base": self.api_base,
            "api_key": self._resolved_api_key(),
            "messages": self._messages(role, prompt),
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "num_retries": self.num_retries,
        }

    def _instructor_mode(self, instructor):
        """Resolve ``structured_mode`` to an ``instructor.Mode`` member.

        Matches by member name, case-insensitively. An unrecognised name (e.g. a
        typo) logs a warning and falls back to ``Mode.JSON_SCHEMA`` rather than
        failing the call, so the misconfiguration is visible instead of silent.
        """
        mode = getattr(instructor.Mode, self.structured_mode.upper(), None)
        if isinstance(mode, instructor.Mode):
            return mode
        obs.log(
            "llm.structured_mode_unknown",
            level="warning",
            model=self.model,
            structured_mode=self.structured_mode,
            fallback="json_schema",
        )
        return instructor.Mode.JSON_SCHEMA

    @staticmethod
    def _messages(role: str, prompt: str) -> list[dict[str, str]]:
        """Chat messages for one call: the role's system prompt, then *prompt* as the user turn."""
        return [
            {"role": "system", "content": OpenAICompatProvider._system_for_role(role)},
            {"role": "user", "content": prompt},
        ]

    def _resolved_api_key(self) -> str | None:
        """The configured key, or the ``"EMPTY"`` placeholder when only ``api_base`` is set."""
        # A self-served vLLM endpoint accepts any token; default to the conventional
        # placeholder so a configured custom endpoint never trips on a missing key.
        return self.api_key or ("EMPTY" if self.api_base else None)

    def _capture_usage(self, litellm, response, prompt: str, text: str) -> None:
        """Record tokens + cost + reasoning from *response* onto the instance.

        Uses the response's ``usage`` when present. Otherwise tokens are
        estimated from *prompt* and *text* with ``estimate_tokens``. A zero or
        missing ``total_tokens`` is replaced by prompt + completion tokens.
        """
        from src.models.provider import estimate_tokens

        usage = getattr(response, "usage", None)
        if usage is not None:
            prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
            completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
            total_tokens = int(getattr(usage, "total_tokens", 0) or 0) or (prompt_tokens + completion_tokens)
        else:
            prompt_tokens, completion_tokens = estimate_tokens(prompt), estimate_tokens(text)
            total_tokens = prompt_tokens + completion_tokens
        cost = self._extract_cost(litellm, response)
        self._last_cost = cost
        self._last_reasoning = self._extract_reasoning(response)
        self._last_usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "cost_usd": cost,
        }

    def _zero_usage(self) -> None:
        """Reset usage, cost and reasoning after a failed call."""
        self._last_cost = 0.0
        self._last_reasoning = ""
        self._last_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost_usd": 0.0,
        }

    @staticmethod
    def _extract_reasoning(response) -> str:
        """Pull the model's separated thinking from *response*, or "".

        vLLM reasoning parsers surface it as ``message.reasoning_content`` (some
        providers as ``reasoning`` or under ``provider_specific_fields``). All
        access is defensive — a non-reasoning model simply yields "". The first
        non-blank candidate wins and is returned stripped."""
        try:
            message = response.choices[0].message
        except (AttributeError, IndexError, TypeError):
            return ""
        candidates = [getattr(message, "reasoning_content", None), getattr(message, "reasoning", None)]
        psf = getattr(message, "provider_specific_fields", None)
        if isinstance(psf, dict):
            candidates += [psf.get("reasoning_content"), psf.get("reasoning")]
        for value in candidates:
            if isinstance(value, str) and value.strip():
                return value.strip()
        return ""

    @staticmethod
    def _extract_cost(litellm, response) -> float:
        """Best-effort USD cost for *response*; 0.0 if the model is unpriced.

        Prefers the value LiteLLM already attached during the call
        (``_hidden_params["response_cost"]``); falls back to pricing the response
        directly. Both paths are guarded — an unknown/custom model (e.g. a
        self-served vLLM endpoint) simply yields 0.0 rather than raising.
        """
        hidden = getattr(response, "_hidden_params", None)
        if isinstance(hidden, dict):
            cost = hidden.get("response_cost")
            if isinstance(cost, (int, float)):
                return float(cost)
        try:
            cost = litellm.completion_cost(completion_response=response)
            return float(cost or 0.0)
        except Exception:
            return 0.0