the-apprentice / oracles /llm_client.py
AndrewRqy
Pre-warm Modal LLM at app startup + document cold-start behavior in README
a777ed3
Raw
History Blame Contribute Delete
11.3 kB
"""LLM client for The Wizard's Oracles.
Thin wrapper over the OpenAI SDK pointed at a Modal-hosted vLLM endpoint, with
a mock fallback for offline/demo runs. Mirrors the pattern used by
``forest/llm_client.py`` and ``apprentice_app/apprentice/llm_client.py``.
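Env vars:
    MODAL_URL              base URL of the vLLM endpoint
    MODAL_KEY              optional bearer / Modal-Key header value
    MODAL_SECRET           optional Modal-Secret header value
    ORACLES_LLM_MODEL      served-model name to request
                           (default: "oracle-wizard-lora")
    ORACLES_FORCE_MOCK     if "1", forces mock mode even when MODAL_URL is set
                           (critical for the demo: lets the app run mock-only)
    ORACLES_TRACE_DIR      directory for LLM-call traces
                           (default: <app_root>/traces)
    ORACLES_TRACE_DISABLE  if "1", turns LLM-call tracing off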
Callers handle the mock fallback themselves; this client raises
``RuntimeError("LLM not configured (mock mode)")`` whenever a network call is
attempted while ``using_mock`` is True.
"""
from __future__ import annotations
import json
import os
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
try:
from openai import OpenAI
_HAS_OPENAI = True
except ImportError: # pragma: no cover
_HAS_OPENAI = False
def _force_mock_env() -> bool:
    """Report whether ``ORACLES_FORCE_MOCK`` is set to exactly ``"1"``.

    Whitespace around the value is ignored; an unset variable or any other
    value yields ``False``.
    """
    raw_flag = os.environ.get("ORACLES_FORCE_MOCK", "")
    return raw_flag.strip() == "1"
def _trace_dir() -> Optional[Path]:
    """Return the directory LLM-call traces are appended to, or ``None``.

    Resolution order:

    1. ``ORACLES_TRACE_DISABLE=1`` -> ``None`` (tracing is off).
    2. A non-empty ``ORACLES_TRACE_DIR`` -> that path (``~`` is expanded).
    3. Otherwise -> ``traces/`` inside the directory two levels above this
       file.

    The directory is created on demand, parents included. If it cannot be
    created (``OSError``) the function also returns ``None``, so callers
    treat an unusable location the same as disabled tracing.

    Tracing is on by default; set ``ORACLES_TRACE_DISABLE=1`` to keep
    prompts and responses off local disk.
    """
    env = os.environ
    if env.get("ORACLES_TRACE_DISABLE", "").strip() == "1":
        return None

    configured = env.get("ORACLES_TRACE_DIR", "").strip()
    if configured:
        target = Path(configured)
    else:
        # Default location sits beside the application root rather than in
        # a temp directory, so the traces are easy to find.
        target = Path(__file__).resolve().parent.parent / "traces"
    target = target.expanduser()

    try:
        target.mkdir(parents=True, exist_ok=True)
    except OSError:
        return None
    return target
# Random 12-hex-character id generated once when this module is imported.
# It names the trace file (``oracles-trace-<id>.jsonl``) and is stored in
# the "session" field of every trace record written by this process.
_TRACE_SESSION_ID = uuid.uuid4().hex[:12]
def _announce_trace_dir() -> None:
    """Print a one-line notice to stderr about where LLM-call traces go.

    When a trace directory is available, the notice names the JSONL file
    this session appends to. When ``_trace_dir()`` returns ``None`` the
    notice says tracing is disabled and names the actual cause: the
    ``ORACLES_TRACE_DISABLE=1`` opt-out, or a trace directory that could
    not be created.
    """
    import sys

    trace_root = _trace_dir()
    if trace_root is None:
        # _trace_dir() yields None both for the explicit opt-out and for a
        # failed mkdir; only blame the env var when it is really set.
        opted_out = os.environ.get("ORACLES_TRACE_DISABLE", "").strip() == "1"
        reason = (
            "ORACLES_TRACE_DISABLE=1"
            if opted_out
            else "trace directory could not be created"
        )
        print(f"[trace] tracing disabled ({reason})", file=sys.stderr)
        return

    trace_file = trace_root / f"oracles-trace-{_TRACE_SESSION_ID}.jsonl"
    print(
        f"[trace] LLM calls will be appended to {trace_file}",
        file=sys.stderr,
    )
# Module-level call: runs when this module is imported, so the trace notice
# is printed to stderr (and, via _trace_dir(), the trace directory is
# created) before any LLM call is made.
_announce_trace_dir()
def _write_trace(record: dict) -> None:
    """Append ``record`` as one JSON line to this session's trace file.

    Does nothing when ``_trace_dir()`` returns ``None``. The file is
    ``oracles-trace-<session id>.jsonl`` inside the trace directory, opened
    in append mode as UTF-8; non-ASCII characters are written unescaped.
    ``OSError`` raised while opening or writing is swallowed.
    """
    trace_root = _trace_dir()
    if trace_root is not None:
        trace_file = trace_root / f"oracles-trace-{_TRACE_SESSION_ID}.jsonl"
        try:
            with trace_file.open("a", encoding="utf-8") as sink:
                serialized = json.dumps(record, ensure_ascii=False)
                sink.write(f"{serialized}\n")
        except OSError:
            # Best effort: a failed trace write is ignored.
            pass
@dataclass
class LLMConfig:
    """Connection settings for the Modal-hosted vLLM endpoint.

    Build one directly, or use ``from_env()`` to read the ``MODAL_*`` and
    ``ORACLES_LLM_MODEL`` environment variables.
    """

    # Endpoint base URL; None when no endpoint is configured.
    base_url: Optional[str]
    # ``from_env`` fills this with the MODAL_KEY value.
    api_key: Optional[str]
    # Extra HTTP headers; ``from_env`` fills in Modal-Key / Modal-Secret.
    extra_headers: dict = field(default_factory=dict)
    # Which served-model name to request. The deployed vLLM container
    # co-serves "llm" (bare Qwen2.5-14B) and "oracle-wizard-lora" (the
    # fine-tune). The fine-tune is the default so the app uses the adapter
    # out of the box; set ORACLES_LLM_MODEL=llm to A/B against the base
    # model.
    model_alias: str = "oracle-wizard-lora"

    @classmethod
    def from_env(cls) -> "LLMConfig":
        """Create a config from the process environment.

        Reads ``MODAL_URL``, ``MODAL_KEY``, ``MODAL_SECRET`` and
        ``ORACLES_LLM_MODEL``. A blank or whitespace-only model name falls
        back to ``"oracle-wizard-lora"``. Header entries are added only for
        credentials that are set and non-empty.
        """
        env = os.environ
        key = env.get("MODAL_KEY")
        secret = env.get("MODAL_SECRET")

        alias = env.get("ORACLES_LLM_MODEL", "oracle-wizard-lora").strip()
        if not alias:
            alias = "oracle-wizard-lora"

        credentials = (("Modal-Key", key), ("Modal-Secret", secret))
        headers: dict = {name: value for name, value in credentials if value}

        return cls(
            base_url=env.get("MODAL_URL"),
            api_key=key,
            extra_headers=headers,
            model_alias=alias,
        )

    @property
    def is_configured(self) -> bool:
        """True when a real endpoint can be used.

        Requires that mock mode is not forced through the environment, that
        ``base_url`` is non-empty, and that the ``openai`` package imported
        successfully.
        """
        return (
            not _force_mock_env()
            and bool(self.base_url)
            and bool(_HAS_OPENAI)
        )
class LLMClient:
    """Thin wrapper over the OpenAI SDK pointed at the Modal vLLM endpoint.

    When the config is not fully populated (or ``ORACLES_FORCE_MOCK=1`` is
    set), ``using_mock`` is True and both completion methods raise
    ``RuntimeError("LLM not configured (mock mode)")`` — the caller is
    expected to swap in mock content instead.

    Attributes:
        config: The ``LLMConfig`` in use.
        last_requested_model: Model name sent with the most recent request
            ("" before the first request).
        last_returned_model: Model name echoed back by the server for the
            most recent request; "" before the first request and whenever
            the most recent request failed.
    """

    _MOCK_ERROR = "LLM not configured (mock mode)"

    def __init__(self, config: "Optional[LLMConfig]" = None) -> None:
        """Create the client and, when configured, start a background warmup.

        Args:
            config: Settings to use; defaults to ``LLMConfig.from_env()``.
        """
        self.config: LLMConfig = config if config is not None else LLMConfig.from_env()
        self._client: Optional[OpenAI] = None  # type: ignore[valid-type]
        # Diagnostics: record what the most-recent call requested vs. what
        # the server actually echoed back. Lets callers tell base-vs-LoRA
        # from an error message after the fact.
        self.last_requested_model: str = ""
        self.last_returned_model: str = ""
        if self.config.is_configured:
            # is_configured already requires a truthy base_url; the assert
            # only narrows the Optional for type checkers.
            assert self.config.base_url is not None
            self._client = OpenAI(
                base_url=self.config.base_url.rstrip("/") + "/v1",
                api_key=self.config.api_key or "not-used",
                default_headers=dict(self.config.extra_headers),
                timeout=60,  # resolution call can be longer than apprentice
            )
            # Fire-and-forget warmup: hit /v1/models in a daemon thread so
            # Modal's scaled-to-zero container starts spinning up while the
            # player is still inscribing oracles. By the time they click
            # "let the journey begin" the container should be warm.
            self._kick_warmup()

    def _kick_warmup(self) -> None:
        """Send a non-blocking GET /v1/models to wake a cold Modal container.

        Runs in a daemon thread so app startup never blocks on the warmup.
        Failures are swallowed — the real call later will surface any
        connectivity issues with a proper error message.
        """
        import threading
        import urllib.request

        if not self.config.base_url:
            return
        url = self.config.base_url.rstrip("/") + "/v1/models"
        headers = dict(self.config.extra_headers)

        def _ping() -> None:
            try:
                req = urllib.request.Request(url, headers=headers, method="GET")
                # Long timeout — vLLM cold start with cached weights is
                # 30-90s. We don't care if it eventually succeeds; we just
                # want to trigger the container allocation.
                with urllib.request.urlopen(req, timeout=180) as resp:
                    resp.read()
            except Exception:
                # Deliberate best-effort: the warmup must never surface
                # errors; the first real completion call reports them.
                pass

        threading.Thread(target=_ping, daemon=True, name="llm-warmup").start()

    @property
    def using_mock(self) -> bool:
        """True when no SDK client was created, i.e. network calls will raise."""
        return self._client is None

    def _complete(
        self,
        *,
        mode: str,
        system: str,
        user: str,
        max_tokens: int,
        temperature: float,
        model: str,
        response_format: Optional[dict] = None,
        strip_response: bool = False,
    ) -> str:
        """Run one chat completion, update diagnostics, and write a trace.

        Shared implementation behind ``complete_json`` and ``complete_text``.

        Args:
            mode: Value stored in the trace record's "mode" field.
            system: System message content.
            user: User message content, sent (and traced) as given.
            max_tokens: Forwarded to the completion request.
            temperature: Forwarded to the completion request.
            model: Served-model name; "" selects ``config.model_alias``.
            response_format: Forwarded to the request only when not None.
            strip_response: Strip surrounding whitespace from the reply
                before tracing and returning it.

        Returns:
            The first choice's message content ("" when the server sent
            none), stripped when ``strip_response`` is set.

        Raises:
            RuntimeError: In mock mode (no SDK client).
            Exception: Anything raised by the SDK call propagates unchanged.
        """
        if self._client is None:
            raise RuntimeError(self._MOCK_ERROR)

        requested_model = model or self.config.model_alias
        self.last_requested_model = requested_model
        # Clear the previous answer first so a failed request cannot leave
        # a stale "returned" model paired with the new "requested" one.
        self.last_returned_model = ""

        request: dict = {
            "model": requested_model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if response_format is not None:
            request["response_format"] = response_format

        # perf_counter is monotonic, so the latency cannot be skewed by
        # wall-clock adjustments made while the request is in flight.
        started = time.perf_counter()
        r = self._client.chat.completions.create(**request)
        latency_ms = int((time.perf_counter() - started) * 1000)

        content = r.choices[0].message.content or ""
        if strip_response:
            content = content.strip()
        returned_model = getattr(r, "model", "") or ""
        self.last_returned_model = returned_model
        usage = getattr(r, "usage", None)

        _write_trace({
            "ts": time.time(),
            "session": _TRACE_SESSION_ID,
            "mode": mode,
            # Both sides of the model contract — the alias we asked vLLM
            # for ("oracle-wizard-lora" / "llm") AND the model id vLLM
            # echoed back. They should normally match; logging both lets a
            # trace consumer detect server-side fallbacks (e.g. a LoRA
            # request that ended up served by the base) and identify
            # exactly which model produced the response.
            "model": requested_model,  # legacy field, alias = requested
            "model_requested": requested_model,
            "model_returned": returned_model,
            "using_lora": "lora" in returned_model.lower(),
            "temperature": temperature,
            "max_tokens": max_tokens,
            "system": system,
            "user": user,
            "response": content,
            "latency_ms": latency_ms,
            "usage": usage.model_dump() if usage is not None else None,
        })
        return content

    def complete_json(
        self,
        system: str,
        user: str,
        max_tokens: int = 700,
        temperature: float = 0.9,
        model: str = "",
    ) -> dict:
        """Request a JSON-object completion and return it parsed.

        A "Respond with valid JSON only." reminder is appended to ``user``
        and the request asks for the ``json_object`` response format.

        Args:
            system: System message content.
            user: User message content (the JSON reminder is appended).
            max_tokens: Completion token limit.
            temperature: Sampling temperature.
            model: Served-model name; "" selects ``config.model_alias``.

        Returns:
            The reply parsed with ``json.loads``.

        Raises:
            RuntimeError: In mock mode.
            json.JSONDecodeError: If the reply is not valid JSON (including
                an empty reply).
        """
        content = self._complete(
            mode="json",
            system=system,
            user=user + "\n\nRespond with valid JSON only.",
            max_tokens=max_tokens,
            temperature=temperature,
            model=model,
            response_format={"type": "json_object"},
        )
        return json.loads(content)

    def complete_text(
        self,
        system: str,
        user: str,
        max_tokens: int = 700,
        temperature: float = 0.9,
        model: str = "",
    ) -> str:
        """Request a plain-text completion.

        Args:
            system: System message content.
            user: User message content.
            max_tokens: Completion token limit.
            temperature: Sampling temperature.
            model: Served-model name; "" selects ``config.model_alias``.

        Returns:
            The reply with surrounding whitespace stripped ("" when the
            server sent no content).

        Raises:
            RuntimeError: In mock mode.
        """
        return self._complete(
            mode="text",
            system=system,
            user=user,
            max_tokens=max_tokens,
            temperature=temperature,
            model=model,
            strip_response=True,
        )