FlakyTestSleuthOpenEnvRL / server /inference_runner.py
vedkdev's picture
Deploy FlakyGym UI + inference updates (minimal upload)
f53d90b verified
from __future__ import annotations
import os
import subprocess
import sys
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class InferenceJob:
    """Mutable record of one inference run: its launch settings, state and output."""

    job_id: str  # short identifier for the run
    status: str  # lifecycle state; this file uses "starting", "running", "completed", "stopped", "failed"
    started_at: float  # time.time() timestamp taken when the job was created
    command: list[str]  # argv used to launch the subprocess
    config: dict[str, Any]  # run settings echoed back in status snapshots
    logs: list[str] = field(default_factory=list)  # captured output lines plus "[UI]" notes
    return_code: int | None = None  # subprocess exit code; None until the process has exited
    finished_at: float | None = None  # time.time() timestamp set when the run ends
    error: str | None = None  # text of the exception when the runner itself failed
    stop_requested: bool = False  # set once a stop has been requested for this job
    summaries: list[dict[str, Any]] = field(default_factory=list)  # parsed "[END]" summary lines
class InferenceRunner:
    """Run inference.py in the background and expose live status.

    Only the most recent job is tracked. All access to the job record and to
    the subprocess handle goes through ``self._lock``.
    """

    def __init__(self, repo_root: Path):
        """Bind the runner to *repo_root*, used as the subprocess working directory."""
        self._repo_root = repo_root.resolve()
        self._lock = threading.Lock()
        self._job: InferenceJob | None = None
        self._proc: subprocess.Popen[str] | None = None

    def _raise_if_busy(self) -> None:
        """Raise RuntimeError if a job is still active. Caller must hold the lock."""
        if self._job and self._job.status in {"starting", "running"}:
            raise RuntimeError("An inference run is already in progress.")

    def start(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Validate *payload*, launch inference.py on a worker thread, return a snapshot.

        Recognised payload keys: ``dataset_path``, ``episodes_per_task``,
        ``max_steps``, ``task_types``, ``benchmark_name``, and the optional
        ``api_base_url``, ``model_name`` and ``api_key``.

        Raises:
            RuntimeError: if another run is starting or running.
            ValueError: if a setting is empty, out of range, or the dataset
                path is outside the repository or not a file.
            FileNotFoundError: if the dataset file does not exist.
        """
        # Fail fast before doing any validation or filesystem work.
        with self._lock:
            self._raise_if_busy()

        dataset_rel = str(payload.get("dataset_path", "dataset/py_tasks.csv")).strip()
        episodes = int(payload.get("episodes_per_task", 1))
        max_steps = int(payload.get("max_steps", 20))
        task_types = str(payload.get("task_types", "classify,root_cause,fix_proposal")).strip()
        benchmark_name = str(payload.get("benchmark_name", "flakysleuth")).strip()

        if not dataset_rel:
            raise ValueError("dataset_path must not be empty.")
        if episodes < 1 or episodes > 100:
            raise ValueError("episodes_per_task must be between 1 and 100.")
        if max_steps < 1 or max_steps > 100:
            raise ValueError("max_steps must be between 1 and 100.")
        if not task_types:
            raise ValueError("task_types must not be empty.")
        if not benchmark_name:
            raise ValueError("benchmark_name must not be empty.")

        dataset_path = self._resolve_dataset_path(dataset_rel)
        # The subprocess runs with cwd=repo_root, so pass a repo-relative path.
        dataset_arg = os.path.relpath(dataset_path, self._repo_root)

        command = [
            sys.executable,
            "inference.py",
            "--dataset-path",
            dataset_arg,
            "--episodes-per-task",
            str(episodes),
            "--task-types",
            task_types,
            "--max-steps",
            str(max_steps),
            "--benchmark-name",
            benchmark_name,
        ]
        job = InferenceJob(
            job_id=uuid.uuid4().hex[:12],
            status="starting",
            started_at=time.time(),
            command=command,
            config={
                "dataset_path": dataset_arg,
                "episodes_per_task": episodes,
                "task_types": task_types,
                "max_steps": max_steps,
                "benchmark_name": benchmark_name,
                "api_base_url": _clean_optional_text(payload.get("api_base_url")),
                "model_name": _clean_optional_text(payload.get("model_name")),
                # Record only whether a key was supplied, never the key itself.
                "api_key_provided": bool(_clean_optional_text(payload.get("api_key"))),
            },
        )
        self._append_log(job, f"[UI] Starting run {job.job_id}")
        self._append_log(job, f"[UI] Command: {' '.join(command)}")

        with self._lock:
            # Re-check while publishing: the lock was released during
            # validation, so a concurrent start() may have registered a job.
            self._raise_if_busy()
            self._job = job

        worker = threading.Thread(
            target=self._run_job,
            args=(job, payload),
            daemon=True,
        )
        worker.start()
        return self.snapshot(tail=300)

    def stop(self) -> bool:
        """Request termination of the active run.

        Returns:
            True if an active job with a spawned subprocess was found and
            asked to stop; False if there is no active job or its subprocess
            has not been spawned yet.
        """
        with self._lock:
            job = self._job
            proc = self._proc
            if not job or not proc or job.status not in {"starting", "running"}:
                return False
            job.stop_requested = True

        # Signal outside the lock so the worker thread can keep logging.
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=8)
            except subprocess.TimeoutExpired:
                # Escalate when the process ignores terminate().
                proc.kill()
                proc.wait(timeout=8)
        return True

    def snapshot(self, tail: int = 300) -> dict[str, Any]:
        """Return the tracked job's state as a dict.

        Args:
            tail: number of trailing log lines wanted; clamped to 20..2000.

        Returns:
            ``{"has_job": False, "status": "idle", "logs": []}`` when no job
            has been started, otherwise the job's fields plus the log tail.
            Containers in the result are shallow copies, so the caller can
            read or serialize them while the worker keeps appending.
        """
        with self._lock:
            job = self._job
            if job is None:
                return {
                    "has_job": False,
                    "status": "idle",
                    "logs": [],
                }
            log_count = max(20, min(tail, 2000))
            return {
                "has_job": True,
                "job_id": job.job_id,
                "status": job.status,
                "started_at": job.started_at,
                "finished_at": job.finished_at,
                "return_code": job.return_code,
                "error": job.error,
                "config": dict(job.config),
                "command": list(job.command),
                "summaries": list(job.summaries),
                "logs": job.logs[-log_count:],
            }

    def _run_job(self, job: InferenceJob, payload: dict[str, Any]) -> None:
        """Worker-thread body: spawn the subprocess, stream its output, record the result."""
        env = os.environ.copy()
        api_key = _clean_optional_text(payload.get("api_key"))
        api_base_url = _clean_optional_text(payload.get("api_base_url"))
        model_name = _clean_optional_text(payload.get("model_name"))
        # Override inherited environment values only when the payload supplies one.
        if api_key:
            env["API_KEY"] = api_key
        if api_base_url:
            env["API_BASE_URL"] = api_base_url
        if model_name:
            env["MODEL_NAME"] = model_name

        with self._lock:
            job.status = "running"

        process: subprocess.Popen[str] | None = None
        try:
            process = subprocess.Popen(
                job.command,
                cwd=self._repo_root,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                # An undecodable byte in the child's output must not abort the run.
                errors="replace",
                bufsize=1,
                env=env,
            )
            with self._lock:
                self._proc = process

            assert process.stdout is not None
            for raw_line in process.stdout:
                line = raw_line.rstrip("\n")
                if not line:
                    continue
                self._append_log(job, line)
                summary = _parse_end_line(line)
                if summary:
                    with self._lock:
                        job.summaries.append(summary)

            return_code = process.wait()
            extra_log: str | None = None
            with self._lock:
                job.return_code = return_code
                job.finished_at = time.time()
                if job.stop_requested:
                    job.status = "stopped"
                    extra_log = "[UI] Run stopped by user request."
                elif return_code == 0:
                    job.status = "completed"
                else:
                    job.status = "failed"
                    extra_log = f"[UI] Process exited with code {return_code}."
                self._proc = None
            # _append_log takes the lock itself, so call it after releasing it.
            if extra_log:
                self._append_log(job, extra_log)
        except Exception as exc:
            extra_log = f"[UI] Runner failed: {exc}"
            with self._lock:
                job.error = str(exc)
                job.finished_at = time.time()
                job.status = "failed"
                self._proc = None
            self._append_log(job, extra_log)
        finally:
            if process is not None:
                # Still running only if something raised after the spawn;
                # reap the child so a failed job does not leave it orphaned.
                if process.poll() is None:
                    process.kill()
                    process.wait()
                if process.stdout:
                    process.stdout.close()

    def _append_log(self, job: InferenceJob, line: str) -> None:
        """Append *line* to the job log, keeping only the newest 3000 lines."""
        with self._lock:
            job.logs.append(line)
            if len(job.logs) > 3000:
                del job.logs[: len(job.logs) - 3000]

    def _resolve_dataset_path(self, dataset_path: str) -> Path:
        """Resolve *dataset_path* to an existing file inside the repository.

        Relative paths are interpreted against the repository root.

        Raises:
            ValueError: if the resolved path is outside the repository or is
                not a regular file.
            FileNotFoundError: if the resolved path does not exist.
        """
        candidate = Path(dataset_path)
        if not candidate.is_absolute():
            candidate = (self._repo_root / candidate).resolve()
        else:
            candidate = candidate.resolve()

        # Keep data access bounded to the repository.
        if os.path.commonpath([str(self._repo_root), str(candidate)]) != str(self._repo_root):
            raise ValueError("dataset_path must point to a file inside the repository.")
        if not candidate.exists():
            raise FileNotFoundError(f"Dataset file not found: {dataset_path}")
        if not candidate.is_file():
            raise ValueError(f"dataset_path is not a file: {dataset_path}")
        return candidate
def _clean_optional_text(value: Any) -> str | None:
    """Return ``str(value)`` stripped of surrounding whitespace, or None.

    None is returned both for a None input and for text that is empty after
    stripping.
    """
    stripped = "" if value is None else str(value).strip()
    return stripped if stripped else None
def _parse_end_line(line: str) -> dict[str, Any] | None:
    """Parse an ``[END] key=value ...`` summary line into a typed dict.

    Example input::

        [END] success=true steps=3 score=1.00 rewards=0.00,0.20,1.00

    Returns:
        A dict with ``success`` (bool), ``steps`` (int), ``score`` (float)
        and ``rewards`` (list of float), or None when the line lacks the
        ``[END] `` prefix, misses one of the three required keys, or carries
        a non-numeric ``steps`` / ``score``. Unparseable reward entries are
        dropped rather than rejecting the line.
    """
    marker = "[END] "
    if not line.startswith(marker):
        return None

    # Space-separated key=value tokens; tokens without "=" are ignored and a
    # repeated key keeps its last value.
    fields: dict[str, str] = {}
    for chunk in line[len(marker) :].split(" "):
        name, separator, raw = chunk.partition("=")
        if separator:
            fields[name.strip()] = raw.strip()

    if any(required not in fields for required in ("success", "steps", "score")):
        return None

    rewards: list[float] = []
    for piece in fields.get("rewards", "").split(","):
        piece = piece.strip()
        if piece:
            try:
                rewards.append(float(piece))
            except ValueError:
                pass

    try:
        steps = int(fields["steps"])
        score = float(fields["score"])
    except ValueError:
        return None

    return {
        "success": fields["success"].lower() == "true",
        "steps": steps,
        "score": score,
        "rewards": rewards,
    }