"""LLM client factory. Critical infrastructure: every other module in this codebase imports ``get_llm`` from here and never instantiates an LLM directly. The ``MODEL_BACKEND`` env var picks one of three backends: * ``minimax`` — ChatOpenAI-compatible client at the MiniMax API. * ``vllm`` — ChatOpenAI-compatible client at a local vLLM endpoint. * ``replay`` — Reads pre-recorded JSON from ``replays/.json``. """ from __future__ import annotations import json import logging import os from pathlib import Path from typing import Any from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool from langchain_openai import ChatOpenAI from pydantic import Field from agent.constants import ( BACKEND_MINIMAX, BACKEND_REPLAY, BACKEND_VLLM, ENV_MINIMAX_API_KEY, ENV_MODEL_BACKEND, ENV_SCENARIO_ID, ENV_VLLM_ENDPOINT, MINIMAX_BASE_URL, MINIMAX_MODEL, VALID_BACKENDS, VLLM_MODEL, ) logger = logging.getLogger(__name__) REPLAYS_DIR = Path(__file__).resolve().parent.parent / "replays" def get_llm(scenario_id: str | None = None) -> BaseChatModel: """Build an LLM client based on the ``MODEL_BACKEND`` env var. Args: scenario_id: Required when ``MODEL_BACKEND=replay``. Names the JSON file under ``replays/`` to play back. May also come from the ``REPLAY_SCENARIO_ID`` env var. Returns: A ``BaseChatModel`` instance suitable for use by LangGraph. Raises: ValueError: If ``MODEL_BACKEND`` is unset, unknown, or the backend's required env vars are missing. """ backend = os.getenv(ENV_MODEL_BACKEND, "").strip().lower() if not backend: raise ValueError( f"{ENV_MODEL_BACKEND} is not set. Expected one of: {sorted(VALID_BACKENDS)}." ) if backend not in VALID_BACKENDS: raise ValueError( f"{ENV_MODEL_BACKEND}={backend!r} is not recognized. " f"Expected one of: {sorted(VALID_BACKENDS)}." ) if backend == BACKEND_MINIMAX: return _build_minimax() if backend == BACKEND_VLLM: return _build_vllm() return _build_replay(scenario_id) def _build_minimax() -> ChatOpenAI: api_key = os.getenv(ENV_MINIMAX_API_KEY) if not api_key: raise ValueError( f"{ENV_MINIMAX_API_KEY} is required when {ENV_MODEL_BACKEND}={BACKEND_MINIMAX}." ) return ChatOpenAI( model=MINIMAX_MODEL, api_key=api_key, base_url=MINIMAX_BASE_URL, ) def _build_vllm() -> ChatOpenAI: endpoint = os.getenv(ENV_VLLM_ENDPOINT) if not endpoint: raise ValueError( f"{ENV_VLLM_ENDPOINT} is required when {ENV_MODEL_BACKEND}={BACKEND_VLLM}." ) return ChatOpenAI( model=VLLM_MODEL, api_key="not-needed", base_url=endpoint, ) def _build_replay(scenario_id: str | None) -> ReplayClient: sid = scenario_id or os.getenv(ENV_SCENARIO_ID) if not sid: raise ValueError( "A scenario_id is required when " f"{ENV_MODEL_BACKEND}={BACKEND_REPLAY}. " f"Pass it via get_llm(scenario_id=...) or set {ENV_SCENARIO_ID}." ) return ReplayClient(scenario_id=sid) class ReplayClient(BaseChatModel): """Plays back a recorded LLM session — no network calls. Reads ``replays/.json`` and yields the ``AIMessage`` / tool-call sequence that the live agent produced when recorded. Used by the public demo when the GPU is off. File format (one record per turn): [{"content": "...", "tool_calls": [...]}, ...] """ scenario_id: str = Field(..., description="Replay file stem under replays/.") replays_dir: Path = Field(default_factory=lambda: REPLAYS_DIR) _turns: list[dict[str, Any]] | None = None _index: int = 0 @property def _llm_type(self) -> str: return "replay" def bind_tools( self, tools: list[BaseTool | dict[str, Any] | type] | None = None, **kwargs: Any, ) -> Runnable: """No-op: tool-call sequences are already encoded in the replay file.""" return self def _load(self) -> list[dict[str, Any]]: if self._turns is not None: return self._turns path = self.replays_dir / f"{self.scenario_id}.json" if not path.exists(): raise FileNotFoundError(f"Replay file not found: {path}") self._turns = json.loads(path.read_text()) return self._turns def _generate( self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: turns = self._load() if self._index >= len(turns): raise IndexError(f"Replay {self.scenario_id!r} exhausted after {len(turns)} turns.") turn = turns[self._index] self._index += 1 msg = AIMessage( content=turn.get("content", ""), tool_calls=turn.get("tool_calls", []), ) return ChatResult(generations=[ChatGeneration(message=msg)])