# DotCache-Arena / demo_controller.py
# Last commit 58aaef9 by Deano Calver: "Surface actionable live runner errors"
# (Repository web-page residue: Raw | History | Blame | Contribute | Delete | 21.9 kB)
from __future__ import annotations
import json
import os
import re
import sys
from dataclasses import asdict, dataclass, field, replace
from pathlib import Path
from typing import Any
from engines.cache import ResultCache
from engines.compare import build_summary_sentence, compute_agreement, compute_memory_reduction, compute_speedup
from engines.dense_runner import run_dense_live
from engines.dotcache_runner import run_dotcache_live
from engines.fixture_builder import FIXTURE_VERSION, build_fixture_result
from engines.llama_runner import run_llama_live_pair
from engines.live_request import (
BENCHMARK_SHORTLIST_POLICY,
canonicalize_benchmark_payload,
resolve_live_runtime_settings,
)
from engines.presets import (
BIT_WIDTHS,
CONTEXT_LENGTHS,
MODEL_BY_KEY,
MODE_HELP,
MODE_LABELS,
PAGE_SIZES,
PRESET_BY_KEY,
QUICK_START_BY_KEY,
SHORTLIST_POLICIES,
)
from engines.zero_gpu import live_request_duration, spaces
@dataclass
class DemoRequest:
model: str
preset: str | None
custom_prompt: str | None
context_length: int
mode: str
page_size: int
bits_k: int
bits_v: int
recent_window: int
sink_window: int
shortlist_policy: str
live_mode: bool = False
compare_against_dense: bool = True
def normalized_prompt(self) -> str | None:
if self.custom_prompt is None:
return None
cleaned = self.custom_prompt.strip()
return cleaned or None
def to_dict(self) -> dict[str, Any]:
payload = asdict(self)
payload["custom_prompt"] = self.normalized_prompt()
return payload
@dataclass
class ControllerResponse:
result: dict[str, Any]
run_badge: str | None
logs: list[str] = field(default_factory=list)
request_key: str | None = None
source_path: str | None = None
# Directory that contains this module; bundled resources are resolved against it.
REPO_ROOT = Path(__file__).resolve().parents[0]
# Location of the static assets read by load_asset().
ASSETS_DIR = REPO_ROOT.joinpath("assets")
def load_asset(name: str) -> str:
    """Return the UTF-8 text of the file called *name* inside ``ASSETS_DIR``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    asset_path = ASSETS_DIR.joinpath(name)
    with asset_path.open("r", encoding="utf-8") as handle:
        return handle.read()
def quick_start_payload(key: str) -> dict[str, Any]:
    """Return the request field values for the quick-start entry *key*.

    The payload copies the entry's settings and always carries an empty
    ``custom_prompt``.

    Raises:
        KeyError: If *key* is not present in ``QUICK_START_BY_KEY``.
    """
    entry = QUICK_START_BY_KEY[key]
    payload: dict[str, Any] = {
        "model": entry.model,
        "preset": entry.preset,
        "custom_prompt": "",
    }
    # The remaining keys mirror attributes of the entry one-to-one.
    mirrored_fields = (
        "context_length",
        "mode",
        "page_size",
        "bits_k",
        "bits_v",
        "recent_window",
        "sink_window",
        "shortlist_policy",
    )
    for field_name in mirrored_fields:
        payload[field_name] = getattr(entry, field_name)
    return payload
def execute_request(request: DemoRequest) -> ControllerResponse:
    """Serve *request*, preferring stored results over live execution.

    The request is normalized and validated first. Live-mode requests go
    straight to the live runner. Otherwise the lookup order is:

    1. a precomputed artifact from the result cache;
    2. a runtime-cache entry (for prompt-less preset requests it is ignored
       unless its comparison notes carry the current fixture-version marker);
    3. for prompt-less preset requests, a fixture built on demand and written
       to the runtime cache;
    4. the live runner.

    Raises:
        ValueError: If ``validate_request`` rejects the request.
    """
    request = _normalize_request(request)
    validate_request(request)
    cache = ResultCache(REPO_ROOT)
    request_dict = request.to_dict()

    def respond(wrapped: dict[str, Any], fixed_badge: str | None = None) -> ControllerResponse:
        # Shared packaging for every stored payload shaped {"result": ..., "meta": ...}.
        normalized = _normalize_result(wrapped["result"], request)
        meta = wrapped.get("meta", {})
        badge = fixed_badge
        if badge is None:
            badge = str(meta.get("run_badge") or "precomputed")
        return ControllerResponse(
            result=normalized,
            run_badge=badge,
            logs=list(meta.get("logs") or []),
            request_key=str(meta.get("request_key") or ""),
            source_path=str(meta.get("source_path") or ""),
        )

    if request.live_mode:
        # Live mode never consults the caches.
        return _execute_live_request(request, cache)

    precomputed = cache.load_precomputed(request_dict)
    if precomputed is not None:
        return respond(precomputed)

    is_plain_preset = request.preset is not None and request.normalized_prompt() is None

    runtime_cached = cache.load_runtime_cache(request_dict)
    if runtime_cached is not None:
        stale = False
        if is_plain_preset:
            # Fixture-backed entries are only trusted when built by the current fixture version.
            notes = list(runtime_cached.get("result", {}).get("comparison", {}).get("notes") or [])
            stale = f"fixture_version:{FIXTURE_VERSION}" not in notes
        if not stale:
            return respond(runtime_cached, "cached")

    if is_plain_preset:
        wrapped = _build_cached_fixture_payload(request)
        cache.write_runtime_cache(request_dict, wrapped)
        return respond(wrapped, "cached")

    return _execute_live_request(request, cache)
def validate_request(request: DemoRequest) -> None:
    """Check every field of *request* against the supported presets.

    Checks run in a fixed order and stop at the first failure.

    Raises:
        ValueError: Describing the first unsupported or missing setting.
    """
    if request.model not in MODEL_BY_KEY:
        raise ValueError(f"Unsupported model: {request.model}")

    # A preset is optional, but when given it must be a known one.
    preset_given = request.preset is not None
    if preset_given and request.preset not in PRESET_BY_KEY:
        raise ValueError(f"Unsupported preset: {request.preset}")

    if request.context_length not in CONTEXT_LENGTHS:
        raise ValueError(f"Unsupported context length: {request.context_length}")

    if request.mode not in MODE_LABELS:
        raise ValueError(f"Unsupported compression mode: {request.mode}")

    if request.page_size not in PAGE_SIZES:
        raise ValueError(f"Unsupported page size: {request.page_size}")

    if any(bits not in BIT_WIDTHS for bits in (request.bits_k, request.bits_v)):
        raise ValueError("bits_k and bits_v must be one of 2, 4, or 8")

    if request.shortlist_policy not in SHORTLIST_POLICIES:
        raise ValueError(f"Unsupported shortlist policy: {request.shortlist_policy}")

    if any(window < 0 for window in (request.recent_window, request.sink_window)):
        raise ValueError("recent_window and sink_window must be non-negative")

    # Without a preset the request must at least carry a non-blank prompt.
    if not preset_given and not request.normalized_prompt():
        raise ValueError("Select a preset or provide a custom prompt.")
def _normalize_engine_payload(payload: dict[str, Any]) -> dict[str, Any]:
    """Coerce one engine payload into the fixed five-key shape used in results.

    Missing or falsy entries fall back to an empty string, zero, or an empty
    list; keys other than the five known ones are dropped.
    """
    normalized: dict[str, Any] = {}
    normalized["text"] = str(payload.get("text") or "")
    for metric in ("tok_per_sec", "latency_ms_per_token"):
        normalized[metric] = float(payload.get(metric) or 0.0)
    normalized["kv_bytes"] = int(payload.get("kv_bytes") or 0)
    # list(...) copies, so the result never aliases the caller's trace.
    normalized["trace"] = list(payload.get("trace") or [])
    return normalized
def _normalize_result(payload: dict[str, Any], request: DemoRequest) -> dict[str, Any]:
    """Return *payload* reshaped into the canonical result structure.

    The result always has ``request``, ``baseline``, ``candidate`` and
    ``comparison`` sections. Comparison metrics that are missing or falsy
    (including a stored ``0.0`` or an empty summary) are recomputed from the
    normalized baseline/candidate payloads.
    """
    request_section = request.to_dict()
    baseline = _normalize_engine_payload(payload.get("baseline") or {})
    candidate = _normalize_engine_payload(payload.get("candidate") or {})
    # Copy so the caller's comparison dict is never mutated.
    comparison = dict(payload.get("comparison") or {})

    agreement = comparison.get("agreement") or compute_agreement(baseline["text"], candidate["text"])
    comparison["agreement"] = float(agreement)

    speedup = comparison.get("speedup") or compute_speedup(baseline, candidate)
    comparison["speedup"] = float(speedup)

    memory_reduction = comparison.get("memory_reduction") or compute_memory_reduction(baseline, candidate)
    comparison["memory_reduction"] = float(memory_reduction)

    summary = comparison.get("summary")
    if not summary:
        # Built from the metrics finalized just above.
        summary = build_summary_sentence(
            comparison["agreement"],
            comparison["speedup"],
            comparison["memory_reduction"],
        )
    comparison["summary"] = str(summary)

    comparison["notes"] = list(comparison.get("notes") or [])

    return {
        "request": request_section,
        "baseline": baseline,
        "candidate": candidate,
        "comparison": comparison,
    }
def _execute_live_request(request: DemoRequest, cache: ResultCache) -> ControllerResponse:
    """Run *request* through the live runners and package the outcome.

    On success the response carries the normalized baseline/candidate
    payloads, freshly computed comparison metrics and the ``"live"`` badge.
    Any exception raised while validating or running is caught here: its
    message is appended to the logs, an actionable summary line is logged and
    echoed to stderr, and a ``"live blocked"`` response is returned instead of
    propagating the error.

    Args:
        request: The request to execute.
        cache: Not used by the current implementation.

    Returns:
        A ``ControllerResponse`` with badge ``"live"`` or ``"live blocked"``.
    """
    request_payload = request.to_dict()
    logs = [
        "No precomputed artifact matched this request.",
        f"Mode: {MODE_LABELS[request.mode]}",
        f"Mode help: {MODE_HELP[request.mode]}",
    ]
    try:
        _validate_live_request(request_payload)
        live_payload = _run_live_request_with_optional_zero_gpu(request_payload)
        logs.extend(list(live_payload.get("logs") or []))
        baseline = _normalize_engine_payload(live_payload.get("baseline") or {})
        candidate = _normalize_engine_payload(live_payload.get("candidate") or {})
        comparison = {
            "agreement": compute_agreement(baseline["text"], candidate["text"]),
            "speedup": compute_speedup(baseline, candidate),
            "memory_reduction": compute_memory_reduction(baseline, candidate),
            "summary": "",
            "notes": [
                "live",
                "zerogpu-duration-seconds:" + str(live_request_duration(request_payload)),
            ],
        }
        comparison["summary"] = build_summary_sentence(
            comparison["agreement"],
            comparison["speedup"],
            comparison["memory_reduction"],
        )
        result = _normalize_result(
            {
                "baseline": baseline,
                "candidate": candidate,
                "comparison": comparison,
            },
            request,
        )
        return ControllerResponse(result=result, run_badge="live", logs=logs)
    except Exception as exc:
        # Deliberate catch-all boundary: every failure becomes a "live blocked" response.
        failure_text = str(exc)
        if not failure_text.strip():
            # Exceptions raised without a message (e.g. a bare ``assert``)
            # stringify to "". A blank entry is dropped when the logs are
            # flattened, so the summary would otherwise fall back to an
            # unrelated earlier line; log the exception type instead.
            failure_text = f"{type(exc).__name__}: (no message)"
        logs.append(failure_text)
        error_summary = _extract_runtime_error_summary(logs)
        if error_summary:
            logs.append(f"Actionable error: {error_summary}")
            print(f"[dotcache-live] {error_summary}", file=sys.stderr, flush=True)
        return ControllerResponse(
            result=_build_unavailable_result(request, logs),
            run_badge="live blocked",
            logs=logs,
        )
def _validate_live_request(request_payload: dict[str, Any]) -> None:
    """Pre-flight check of *request_payload* against the live runtime settings.

    Calls ``resolve_live_runtime_settings`` with zero decode steps and the
    context cap from ``DOTCACHE_SPACE_MAX_LIVE_CONTEXT`` (default 4096) and
    discards the return value.
    NOTE(review): presumably the resolver raises for requests the live lane
    cannot serve; confirm against engines.live_request.

    Raises:
        ValueError: If the environment variable is not an integer.
    """
    context_cap = int(os.getenv("DOTCACHE_SPACE_MAX_LIVE_CONTEXT", "4096"))
    resolve_live_runtime_settings(request_payload, decode_steps=0, max_live_context=context_cap)
# Regexes for log lines that carry no diagnostic value (pip cache warnings,
# local-URL banners, event-loop teardown tracebacks). A line is treated as
# noise when any of these patterns is found in it.
_LOG_NOISE_PATTERNS = (
    re.compile(r"^WARNING: The directory '/home/user/\.cache/pip'"),
    re.compile(r"^\* Running on local URL:"),
    re.compile(r"^\* To create a public link, set `share=True`"),
    re.compile(r"^Exception ignored in: <function BaseEventLoop\.__del__"),
    re.compile(r"^ValueError: Invalid file descriptor: -1$"),
    re.compile(r"asyncio/(base_events|unix_events|selector_events)\.py"),
    # Any CPython 3.x minor version, not only 3.10: the same teardown frame
    # shows up under whichever interpreter the runtime ships.
    re.compile(r"/python3\.\d+/selectors\.py"),
)
def _flatten_log_lines(logs: list[str]) -> list[str]:
    """Split every log entry into stripped, non-empty lines, keeping their order.

    Entries are passed through ``str()`` first, so non-string entries are
    accepted.
    """
    stripped_lines = (raw.strip() for entry in logs for raw in str(entry).splitlines())
    return [text for text in stripped_lines if text]
def _is_noise_log_line(line: str) -> bool:
    """Return True when *line* matches at least one entry of ``_LOG_NOISE_PATTERNS``."""
    for pattern in _LOG_NOISE_PATTERNS:
        if pattern.search(line) is not None:
            return True
    return False
def _meaningful_log_lines(logs: list[str]) -> list[str]:
    """Return the flattened log lines that are not classified as noise."""
    kept: list[str] = []
    for line in _flatten_log_lines(logs):
        if _is_noise_log_line(line):
            continue
        kept.append(line)
    return kept
def _extract_runtime_error_summary(logs: list[str]) -> str | None:
    """Pick the single log line that best explains a failed live run.

    Scanning from the newest meaningful line backwards, the first line that
    starts with a known exception prefix or contains a known failure phrase
    wins. If none qualifies, the last meaningful line is returned.

    Returns:
        The chosen line, or ``None`` when the logs hold no meaningful line.
    """
    candidates = _meaningful_log_lines(logs)
    if not candidates:
        return None

    exception_prefixes = (
        "OSError:",
        "RuntimeError:",
        "MemoryError:",
        "ModuleNotFoundError:",
        "ImportError:",
        "KeyError:",
        "ValueError:",
        "AssertionError:",
    )
    failure_phrases = (
        "CUDA out of memory",
        "No space left on device",
        "Permission denied",
        "does not appear to have a file named",
        "Model snapshot",
        "runner exit code:",
        "Live execution is currently capped",
    )

    def is_preferred(text: str) -> bool:
        if text.startswith(exception_prefixes):
            return True
        return any(phrase in text for phrase in failure_phrases)

    newest_first = reversed(candidates)
    return next((text for text in newest_first if is_preferred(text)), candidates[-1])
@spaces.GPU(duration=live_request_duration)
def _run_live_request_with_optional_zero_gpu(request_payload: dict[str, Any]) -> dict[str, Any]:
    """Run the baseline and candidate engines for *request_payload*.

    Dispatch depends on the model family:

    * ``llama``: one call to the paired llama runner yields both sides.
    * ``qwen35``: the DotCache runner produces the baseline (mode forced to
      ``"dense"``) and, unless dense was requested, the candidate too.
    * anything else: the dense runner produces the baseline and the DotCache
      runner the candidate.

    When the requested mode is dense (or unset), the baseline payload is
    reused as the candidate and no second run happens.

    Returns:
        A dict with ``baseline``, ``candidate`` and ``logs`` entries.

    Raises:
        KeyError: If the payload's model is not in ``MODEL_BY_KEY``.
    """
    model_key = str(request_payload.get("model") or "")
    family = str(MODEL_BY_KEY[model_key].family)

    if family == "llama":
        pair, pair_logs = run_llama_live_pair(request_payload)
        return {
            "baseline": pair.get("baseline") or {},
            "candidate": pair.get("candidate") or {},
            "logs": pair_logs,
        }

    wants_dense = str(request_payload.get("mode") or "dense") == "dense"

    # Baseline first; which runner and which payload depends on the family.
    if family == "qwen35":
        baseline, baseline_logs = run_dotcache_live({**request_payload, "mode": "dense"})
    elif wants_dense:
        # The payload is passed through untouched in this case.
        baseline, baseline_logs = run_dense_live(request_payload)
    else:
        baseline, baseline_logs = run_dense_live({**request_payload, "mode": "dense"})

    if wants_dense:
        return {
            "baseline": baseline,
            "candidate": baseline,
            "logs": baseline_logs,
        }

    candidate, candidate_logs = run_dotcache_live(request_payload)
    return {
        "baseline": baseline,
        "candidate": candidate,
        "logs": [*baseline_logs, *candidate_logs],
    }
def _build_cached_fixture_payload(request: DemoRequest) -> dict[str, Any]:
    """Build a fixture result for *request* wrapped in the cache envelope.

    Returns:
        ``{"result": ..., "meta": ...}`` where ``meta`` records the
        ``"cached"`` badge, a placeholder request key, the runtime-cache
        directory and explanatory log lines.
    """
    fixture_result = build_fixture_result(request.to_dict())
    runtime_cache_dir = REPO_ROOT / "data" / "runtime_cache"
    explanation = [
        "No exact precomputed artifact matched this preset configuration.",
        "Loaded a benchmark-backed fixture from the bundled Qwen matrix and cached it locally.",
        "Preset mode now reflects the bundled compact-task, backend-truth, and LongBench benchmark rows only.",
    ]
    meta = {
        "run_badge": "cached",
        "request_key": "generated-on-demand",
        "source_path": str(runtime_cache_dir),
        "logs": explanation,
    }
    return {"result": fixture_result, "meta": meta}
def _build_unavailable_result(request: DemoRequest, logs: list[str]) -> dict[str, Any]:
    """Build the placeholder result shown when a live request could not run.

    The most recent meaningful log lines are scanned for known failure
    markers; the first marker that matches decides the ``blocked_reason``,
    the summary sentence and the explanatory candidate text. Both engine
    payloads are zeroed placeholders.

    This function runs while an exception is being handled, so it must not
    raise for a bad environment: a malformed
    ``DOTCACHE_SPACE_MAX_LIVE_CONTEXT`` falls back to 4096 here.

    Args:
        request: The request that was blocked.
        logs: Log entries collected so far, including the failure message.

    Returns:
        A result dict with ``request``, ``baseline``, ``candidate`` and
        ``comparison`` sections, where ``comparison["status"]`` is
        ``"blocked"``.
    """
    prompt = request.normalized_prompt() or PRESET_BY_KEY[request.preset].dense_text
    meaningful_logs = _meaningful_log_lines(logs)
    # Prefer the last 8 meaningful lines; if filtering removed everything,
    # fall back to the last 4 raw entries.
    recent_logs = meaningful_logs[-8:] if meaningful_logs else [str(entry) for entry in logs[-4:]]
    error_summary = _extract_runtime_error_summary(logs)
    try:
        max_live_context = int(os.getenv("DOTCACHE_SPACE_MAX_LIVE_CONTEXT", "4096"))
    except ValueError:
        # Same value as the default used when the variable is unset.
        max_live_context = 4096

    def mentioned(marker: str) -> bool:
        # True when any recent log line contains *marker*.
        return any(marker in entry for entry in recent_logs)

    blocked_reason = "runtime_limit"
    baseline_text = "The reference row was not executed because this live request was blocked before the runner started."
    if mentioned("Live execution is currently capped"):
        summary = "Live execution is wired, but this request exceeds the current live context limit."
        blocked_reason = "hw_limit"
        candidate_text = (
            f"Live run blocked: selected context {int(request.context_length)} exceeds the current live limit "
            f"of {max_live_context} tokens. Lower Context length or raise DOTCACHE_SPACE_MAX_LIVE_CONTEXT."
        )
    elif mentioned("Live benchmark parity is not wired in this Space yet"):
        summary = "Live compare is disabled until the Space uses the same benchmark selector artifacts and profile settings as the paper runs."
        blocked_reason = "benchmark_parity"
        candidate_text = (
            "Live run blocked: this Space is not yet executing the benchmark-faithful Qwen lane. "
            "Preset-backed compare is still the valid paper/benchmark view until the learned selector artifacts "
            "and benchmark profile wiring are packaged for live execution."
        )
    elif mentioned("Custom prompts are disabled"):
        summary = "Live Qwen compare now replays the benchmark rows directly, so custom prompts are disabled."
        blocked_reason = "benchmark_only"
        candidate_text = (
            "Live run blocked because the Qwen live lane now replays benchmark rows only. "
            "Clear the custom prompt and use a bundled preset to run the faithful benchmark path."
        )
    elif mentioned("LongBench replay is not wired"):
        summary = "Preset-backed LongBench compare is valid, but live LongBench replay is not wired in this Space yet."
        blocked_reason = "longbench_live_unavailable"
        candidate_text = (
            "Live LongBench replay is currently unavailable in this Space. "
            "Use the preset-backed LongBench row for the valid benchmark view."
        )
    elif mentioned("requires its learned selector artifact"):
        summary = "This benchmark live lane needs a learned selector artifact that is not packaged in the Space."
        blocked_reason = "selector_artifact_missing"
        candidate_text = (
            "Live run blocked because this model's learned selector artifact is not packaged in the Space yet. "
            "Preset-backed compare is still valid for the bundled benchmark row."
        )
    elif mentioned("is not wired in this v1 Space yet"):
        summary = "This model is listed for the demo, but its live execution lane is not wired yet."
        blocked_reason = "model_family_not_wired"
        candidate_text = (
            "Live run blocked: this model family is available for preset-backed compare, but the v1 live runner "
            "is currently wired for Qwen3.5 lanes only."
        )
    elif mentioned("supports bits_"):
        summary = "Live execution is wired, but this compression setting is outside the current live runner limits."
        blocked_reason = "config_limit"
        candidate_text = (
            "Live run blocked: this live runner currently supports bits_k/bits_v in the runtime-safe range only. "
            "Try 4-bit settings first."
        )
    elif mentioned("requires HF_TOKEN"):
        summary = "Live execution for this gated model requires an HF token in the Space settings."
        blocked_reason = "auth_required"
        candidate_text = (
            "Live run blocked because this gated model needs HF_TOKEN in the Space settings, and the current runtime "
            "does not have one configured."
        )
    elif mentioned("Permission denied") or mentioned("snapshot_download"):
        summary = "Live execution could not fetch the selected model into the Space runtime cache."
        blocked_reason = "cache_download_limit"
        candidate_text = (
            "Live run blocked by the current Space storage/download path while fetching model weights. "
            "This is a runtime environment issue, not an output-quality failure."
        )
    elif error_summary and (
        "does not appear to have a file named" in error_summary
        or "Model snapshot" in error_summary
        or "No space left on device" in error_summary
    ):
        summary = "Live execution could not prepare a complete local model snapshot inside the Space runtime."
        blocked_reason = "cache_download_limit"
        candidate_text = f"Live run blocked while preparing model files: {error_summary}"
    else:
        # No known marker matched: keep the default "runtime_limit" reason.
        summary = "Preset-backed compare is ready; this live request could not be completed with the current runtime."
        candidate_text = (
            f"Live run failed: {error_summary}"
            if error_summary
            else "Live run blocked by the current runtime configuration. Check the Logs tab for the exact runner error."
        )
    baseline = {
        "text": baseline_text,
        "tok_per_sec": 0.0,
        "latency_ms_per_token": 0.0,
        "kv_bytes": 0,
        "trace": [],
    }
    candidate = {
        "text": candidate_text,
        "tok_per_sec": 0.0,
        "latency_ms_per_token": 0.0,
        "kv_bytes": 0,
        "trace": [],
    }
    comparison = {
        "status": "blocked",
        "blocked_reason": blocked_reason,
        "blocked_prompt": prompt,
        "blocked_context_length": int(request.context_length),
        "max_live_context": max_live_context,
        "agreement": 0.0,
        "speedup": 1.0,
        "memory_reduction": 1.0,
        "summary": summary,
        "error_summary": error_summary or "",
        "notes": [
            "The default live path now uses local wrapper scripts under scripts/.",
            "Use DOTCACHE_SPACE_MAX_LIVE_CONTEXT, DOTCACHE_SPACE_LIVE_DECODE_STEPS, and backend env vars to tune live runs.",
            "Set DOTCACHE_SPACE_DENSE_RUNNER_CMD and DOTCACHE_SPACE_DOTCACHE_RUNNER_CMD to override the default scripts.",
            *recent_logs,
        ],
    }
    return {
        "request": request.to_dict(),
        "baseline": baseline,
        "candidate": candidate,
        "comparison": comparison,
    }
def result_as_pretty_json(result: dict[str, Any]) -> str:
    """Serialize *result* as JSON with two-space indentation and sorted keys."""
    formatting = {"indent": 2, "sort_keys": True}
    return json.dumps(result, **formatting)
def _normalize_request(request: DemoRequest) -> DemoRequest:
    """Return *request* with its settings replaced by the canonical benchmark values.

    The request dict is passed through ``canonicalize_benchmark_payload``.
    When that changes nothing, the original object is returned; otherwise a
    copy is made with the page size, bit widths, windows and shortlist policy
    taken from the canonical payload (the policy falls back to
    ``BENCHMARK_SHORTLIST_POLICY`` when missing or empty).
    """
    canonical = canonicalize_benchmark_payload(request.to_dict())
    # Compare against a fresh dict in case the canonicalizer mutated its argument.
    if canonical != request.to_dict():
        integer_fields = ("page_size", "bits_k", "bits_v", "recent_window", "sink_window")
        overrides: dict[str, Any] = {name: int(canonical[name]) for name in integer_fields}
        overrides["shortlist_policy"] = str(canonical.get("shortlist_policy") or BENCHMARK_SHORTLIST_POLICY)
        return replace(request, **overrides)
    return request