# metrollm / tools.py
# Remco Hendriks
# sync: HF Space ↔ /simulate prompt-construction parity
# bb7b69a
"""In-process tool dispatch.
We launch `harness/mock_server` (FastAPI) as a uvicorn server on an internal
port from a background thread of this same Python process. Gradio's main
thread keeps the GPU-decorated model + UI; the mock_server thread services
HTTP tool dispatch over localhost. From HF Spaces' perspective this is a
single container with one externally-exposed port (Gradio 7860) — the
mock_server port is internal-only.
An earlier approach using httpx's ASGITransport was abandoned: as of httpx
0.27 the ASGI transport is async-only, which is incompatible with Gradio's
synchronous generator handlers and the synchronous `@spaces.GPU` decorator
pattern.
"""
from __future__ import annotations
import os
import re
import socket
import sys
import threading
import time
from collections.abc import Mapping
from pathlib import Path
# Seed SYSTEM_NAME before harness.mock_server is imported: the mock server
# reads it at import time to choose its default system.
os.environ.setdefault("SYSTEM_NAME", "marta")
_HERE = Path(__file__).resolve().parent
# Locate the directory that contains the `harness` package:
#   * Space deploy: harness/ sits next to this file.
#   * Monorepo:     harness/ sits one level up.
if (_HERE / "harness").is_dir():
    REPO_ROOT = _HERE
else:
    REPO_ROOT = _HERE.parent
# Put it at the front of sys.path (once) so `import harness...` resolves.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
import httpx # noqa: E402
import uvicorn # noqa: E402
import harness.mock_server as _mock_module # noqa: E402
from harness.mock_server import app as _mock_app # noqa: E402
# harness.mock_server assigns `_system_name` only in its argparse `__main__`
# path. Imported programmatically (as here), nothing sets it, so do it
# explicitly. Otherwise the `_system_for_case` fallback would have no value
# for case_ids it hasn't seen yet.
_mock_module._system_name = os.getenv("SYSTEM_NAME", "marta")
def _free_port() -> int:
    """Return a TCP port on 127.0.0.1 that was free at the time of the call.

    The OS picks the port (bind to port 0). The probe socket is closed via
    ``with`` even if ``bind`` fails. Because it is closed before returning,
    another process could in principle take the port before uvicorn binds
    it; that race is inherent to this technique.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


# MOCK_SERVER_PORT pins the port. Unset, empty/whitespace, or "0" means
# "pick a free one". (The `or ""` / strip guard keeps MOCK_SERVER_PORT=""
# from crashing int() at import time.)
_PORT = int((os.environ.get("MOCK_SERVER_PORT") or "").strip() or "0") or _free_port()
_BASE = f"http://127.0.0.1:{_PORT}"
def _serve():
    """Thread target: run the mock_server FastAPI app under uvicorn.

    Binds the loopback interface only (127.0.0.1:_PORT), so the port is not
    reachable from other hosts. Access logging and the ASGI lifespan protocol
    are disabled, and the log level is ``warning``. Returns only when the
    server stops.
    """
    uvicorn.Server(
        uvicorn.Config(
            _mock_app,
            host="127.0.0.1",
            port=_PORT,
            log_level="warning",
            access_log=False,
            lifespan="off",
        )
    ).run()
# Run the server on a daemon thread so it never keeps the interpreter alive at
# exit; the main thread stays with Gradio.
_thread = threading.Thread(target=_serve, daemon=True, name="mock_server")
_thread.start()
# Block briefly until the port is accepting connections so the first dispatch
# never sees a connection-refused. Stop early if the server thread has already
# exited (e.g. uvicorn failed to start) instead of spinning until the deadline.
# `monotonic` keeps the timeout immune to wall-clock adjustments.
_ready = False
_deadline = time.monotonic() + 10.0
while time.monotonic() < _deadline and _thread.is_alive():
    try:
        with socket.create_connection(("127.0.0.1", _PORT), timeout=0.5):
            _ready = True
            break
    except OSError:
        time.sleep(0.1)
# Report the real outcome; previously "up" was printed even after the WARN.
if _ready:
    print(f"[tools] mock_server up on {_BASE}", flush=True)
elif not _thread.is_alive():
    print(f"[tools] WARN: mock_server thread exited before binding {_PORT}", flush=True)
else:
    print(f"[tools] WARN: mock_server didn't bind on {_PORT} within 10s", flush=True)
# Shared client for every call below; 15s per-request timeout.
_client = httpx.Client(base_url=_BASE, timeout=15.0)
def dispatch(
name: str,
arguments: dict,
session_id: str,
system_context: dict | None = None,
) -> dict:
"""POST tool call to in-process mock server. Surface 4xx/5xx response
bodies so the model can self-correct on validation errors.
Mirrors `harness.runner._call_mock_tool`: when `system_context` carries
a `temporal_context.current_time` (or top-level `current_time`), it's
injected into `disruption_feed` calls so the mock server's valid_from/
valid_until filtering matches what /simulate would see.
"""
body = {**arguments, "case_id": session_id}
if name == "disruption_feed" and system_context:
current_time = (
(system_context.get("temporal_context") or {}).get("current_time")
or system_context.get("current_time")
)
if current_time:
body.setdefault("current_time", current_time)
try:
r = _client.post(f"/{name}", json=body)
except Exception as e:
return {"error": str(e), "tool": name}
if r.status_code >= 400:
try:
detail = r.json()
except Exception:
detail = r.text
return {"error": detail, "tool": name, "status_code": r.status_code}
try:
return r.json()
except Exception as e:
return {"error": f"non-JSON response: {e}", "tool": name}
def set_disruptions(session_id: str, system: str, disruptions: list[dict]) -> dict:
    """Stage per-session disruptions on the mock server before generation starts.

    Subsequent `disruption_feed` calls scoped by the same case_id return them.

    Args:
        session_id: Case id the disruptions are staged under.
        system: System name forwarded to the server.
        disruptions: Disruption records, forwarded unchanged.

    Returns:
        The server's decoded JSON reply on success. On failure, a dict with an
        ``"error"`` string. For HTTP 4xx/5xx it also carries ``"status_code"``
        and the response body as ``"detail"`` (e.g. a validation error), which
        the exception message alone does not include. This matches how
        `dispatch` surfaces error bodies.
    """
    try:
        r = _client.post(
            "/set_disruptions",
            json={"case_id": session_id, "system": system, "disruptions": disruptions},
        )
        r.raise_for_status()
        return r.json()
    except httpx.HTTPStatusError as e:
        try:
            detail = e.response.json()
        except ValueError:
            detail = e.response.text
        return {"error": str(e), "status_code": e.response.status_code, "detail": detail}
    except Exception as e:  # best-effort: report any other failure to the caller
        return {"error": str(e)}
def health() -> bool:
    """Report whether the mock server answers ``GET /health`` with HTTP 200.

    Any failure while making the request (connection refused, timeout, ...)
    counts as unhealthy rather than being raised.
    """
    try:
        resp = _client.get("/health")
    except Exception:
        return False
    return resp.status_code == 200