# Source snapshot: commit 491c280 "Stabilize reset API startup path" (author: ashishbaberwal).
"""FastAPI + Gradio app that exposes both UI and validator-friendly API endpoints."""
from __future__ import annotations

import json
import os
import sys
from collections import Counter
from pathlib import Path
from threading import Lock
from typing import Any, Dict, Optional

from fastapi import FastAPI
from fastapi.responses import RedirectResponse
# Directory containing this module (symlinks resolved). It is put at the front
# of sys.path so the sibling ``environment`` package imported below resolves
# regardless of the working directory the server is launched from.
PROJECT_ROOT = Path(__file__).resolve().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from environment.env import CodeReviewEnv
from environment.tasks import TaskDefinitions
# Opt-in Gradio UI: enabled only when ENABLE_GRADIO_UI is "1", "true" or "yes"
# (case-insensitive, surrounding whitespace ignored). When the flag is set but
# gradio cannot be imported, the app falls back to API-only mode and warns,
# instead of silently ignoring the operator's request.
ENABLE_GRADIO_UI = os.getenv("ENABLE_GRADIO_UI", "").strip().lower() in {"1", "true", "yes"}
if ENABLE_GRADIO_UI:
    try:
        import gradio as gr
    except Exception as _gradio_import_error:  # best-effort: any import-time failure disables the UI
        import warnings

        warnings.warn(
            f"ENABLE_GRADIO_UI is set but gradio could not be imported "
            f"({_gradio_import_error!r}); serving the API without the UI.",
            RuntimeWarning,
            stacklevel=1,
        )
        gr = None
        ENABLE_GRADIO_UI = False
else:
    gr = None
# ASGI application. The API routes below are registered on it; when the UI is
# enabled, the Gradio app is mounted onto it at the end of this module.
app = FastAPI(title="code-review-agent-env")
# One environment instance shared by every API endpoint and UI callback.
_env = CodeReviewEnv()
# Serializes every access to ``_env``. This is a plain, non-reentrant Lock:
# code that already holds it must not call helpers that acquire it again
# (for example ``score()``), or the thread deadlocks.
_lock = Lock()
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: unconditionally report the service as healthy."""
    status_payload: Dict[str, str] = dict(status="healthy")
    return status_payload
@app.api_route("/reset", methods=["GET", "POST"])
def reset(payload: Optional[Dict[str, Any]] = None, task_id: Optional[str] = None) -> Dict[str, Any]:
    """Start a new episode and return its initial observation.

    The task can be chosen with a JSON body ``{"task_id": ...}`` or with the
    ``task_id`` query parameter. A truthy body value wins over the query
    parameter. When neither is supplied, ``None`` is forwarded to
    ``_env.reset``.

    ``Optional[...]`` is used instead of ``X | None`` on purpose. FastAPI
    evaluates endpoint annotations at runtime, and ``typing.Dict[...] | None``
    raises ``TypeError`` on Python < 3.10 even under
    ``from __future__ import annotations``.

    Returns:
        ``{"observation": <value returned by _env.reset>}``.
    """
    body = payload or {}
    effective_task_id = body.get("task_id") or task_id
    with _lock:
        obs = _env.reset(task_id=effective_task_id)
    return {"observation": obs}
@app.post("/step")
def step(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Apply one agent action to the shared environment.

    Expects a JSON body shaped like ``{"action": {...}}``. If ``action`` is
    missing or is not a JSON object, an ``{"error": ...}`` document is returned
    (with the default success status) and the environment is left untouched.
    """
    candidate = payload.get("action")
    if isinstance(candidate, dict):
        with _lock:
            next_obs, step_reward, finished, extra = _env.step(candidate)
        return dict(observation=next_obs, reward=step_reward, done=finished, info=extra)
    return {"error": "Request body must include an 'action' object"}
@app.get("/state")
def state() -> Dict[str, Any]:
    """Return the environment's current state, read while holding the lock."""
    with _lock:
        snapshot = _env.state()
    return snapshot
@app.get("/tasks")
def tasks() -> Dict[str, Any]:
    """List every task definition, exposing only its public descriptive fields."""
    public_fields = ("task_id", "task_name", "difficulty", "description", "language")
    catalogue = TaskDefinitions.get_all_tasks()
    total = len(catalogue)
    summaries = []
    for definition in catalogue:
        summaries.append({field: definition[field] for field in public_fields})
    return {"count": total, "tasks": summaries}
@app.get("/score")
def score() -> Dict[str, Any]:
    """Report the current task score together with basic episode progress.

    Score and state are read in a single critical section, so they describe
    the same moment. The state is normalized with ``or {}``, matching the
    falsy-state guard in ``diagnostics``. A missing state therefore yields
    defaults (step 0, not complete, no task id) instead of an ``AttributeError``.

    Acquires ``_lock``; do not call while already holding it.
    """
    with _lock:
        task_score = _env.get_task_score()
        current_state = _env.state() or {}
    metadata = current_state.get("task_metadata") or {}
    return {
        "task_score": task_score,
        "current_step": current_state.get("current_step", 0),
        "is_complete": current_state.get("is_complete", False),
        "task_id": metadata.get("task_id"),
    }
@app.get("/diagnostics")
def diagnostics() -> Dict[str, Any]:
    """Return the task score, the environment summary and the guardrail checks.

    ``_env.summary()`` is only consulted once a truthy state exists. The state
    is normalized with ``or {}`` so that the ``task_id`` lookup below cannot
    fail on a falsy (e.g. ``None``) state.
    """
    with _lock:
        current_state = _env.state() or {}
        diagnostics_data = _env.summary() if current_state else {}
        task_score = _env.get_task_score()
    return {
        "task_score": task_score,
        "diagnostics": diagnostics_data,
        "validation": _validation_checks(),
        "task_id": (current_state.get("task_metadata") or {}).get("task_id"),
    }
def _ui_reset(task_id: str) -> str:
    """Gradio callback: reset the environment and return the observation as JSON.

    An empty dropdown value is passed to ``_env.reset`` as ``None``.
    """
    chosen = task_id if task_id else None
    with _lock:
        first_observation = _env.reset(task_id=chosen)
    return json.dumps({"observation": first_observation}, indent=2)
def _ui_step(action_json: str) -> str:
    """Gradio callback: parse an action from JSON text, step the env, return JSON.

    Text that is not valid JSON, or JSON that is not an object, produces an
    ``{"error": "Invalid action JSON: ..."}`` document instead of raising.
    The environment is only touched once parsing has succeeded.
    """
    def _reject(reason: object) -> str:
        return json.dumps({"error": f"Invalid action JSON: {reason}"}, indent=2)

    try:
        action = json.loads(action_json)
    except Exception as parse_error:
        return _reject(parse_error)
    if not isinstance(action, dict):
        return _reject(ValueError("Action must be a JSON object"))

    with _lock:
        outcome = _env.step(action)
    observation, reward, done, info = outcome
    return json.dumps(
        dict(observation=observation, reward=reward, done=done, info=info),
        indent=2,
    )
def _starter_action_json(task_id: str) -> str:
    """Return a pretty-printed example ``add_comment`` action mentioning *task_id*.

    The action contains one medium-severity issue comment on line 1, one
    placeholder suggestion for line 1, and ``final_decision`` set to ``None``.
    """
    comment = dict(
        line_number=1,
        content=f"Starter review for {task_id}: inspect this line for correctness.",
        is_issue=True,
        severity="medium",
    )
    suggestion = dict(
        original_line=1,
        suggested_code="# example improvement",
        explanation="Starter suggestion for new users.",
    )
    action = dict(
        action_type="add_comment",
        comments=[comment],
        suggestions=[suggestion],
        final_decision=None,
    )
    return json.dumps(action, indent=2)
def _ui_run_starter_step(task_id: str) -> str:
    """Gradio callback: reset the selected task, then submit the starter action.

    An empty selection resets with ``task_id=None`` and labels the starter
    action ``"starter_task"``. The action text is rendered once and parsed
    twice: once for ``_env.step`` and once for the echoed ``starter_action``.
    The echo is therefore independent of whatever ``step`` does with its argument.
    """
    action_text = _starter_action_json(task_id or "starter_task")
    with _lock:
        _env.reset(task_id=task_id or None)
        step_result = _env.step(json.loads(action_text))
    observation, reward, done, info = step_result
    response = {
        "starter_action": json.loads(action_text),
        "observation": observation,
        "reward": reward,
        "done": done,
        "info": info,
        "note": "This button resets the selected task first, then executes a safe starter action.",
    }
    return json.dumps(response, indent=2)
def _ui_state() -> str:
    """Gradio callback: the environment's current state as indented JSON."""
    with _lock:
        snapshot = _env.state()
    rendered = json.dumps(snapshot, indent=2)
    return rendered
def _ui_score() -> str:
    """Gradio callback: the ``/score`` payload rendered as indented JSON."""
    payload = score()
    return json.dumps(payload, indent=2)
def _task_table() -> list[list[str]]:
    """Rows for the task catalogue table: ``[task_id, difficulty, language, task_name]``."""
    columns = ("task_id", "difficulty", "language", "task_name")
    return [[task[column] for column in columns] for task in TaskDefinitions.get_all_tasks()]
def _difficulty_summary() -> str:
    """One-line task count per difficulty, formatted as ``easy: N | medium: N | hard: N``."""
    tally = Counter(task["difficulty"] for task in TaskDefinitions.get_all_tasks())
    return " | ".join(f"{level}: {tally.get(level, 0)}" for level in ("easy", "medium", "hard"))
def _load_json(path: Path, default: Any) -> Any:
    """Parse the JSON document at *path*, or return *default* if that fails.

    The file is decoded as UTF-8 explicitly (the JSON text encoding per
    RFC 8259) instead of the platform's locale encoding. Missing or unreadable
    files (``OSError``) and malformed content (``ValueError``, which covers
    ``json.JSONDecodeError`` and ``UnicodeDecodeError``) are treated as
    "no data" and yield *default*.
    """
    try:
        text = path.read_text(encoding="utf-8")
        return json.loads(text)
    except (OSError, ValueError):
        return default
def _repo_root() -> Path:
    """Directory that contains this module, with symlinks resolved."""
    module_path = Path(__file__).resolve()
    return module_path.parent
def _outputs_dir() -> Path:
    """Directory holding saved benchmark/trace JSON files: ``<repo root>/outputs``."""
    return _repo_root().joinpath("outputs")
def _benchmark_summary() -> Dict[str, Any]:
    """Load ``outputs/benchmark_summary.json``; ``{}`` when absent or unreadable.

    NOTE(review): the parsed JSON is returned as-is, so it may not be a dict
    (e.g. a top-level list). Callers should check ``isinstance`` before ``.get``.
    """
    summary_file = _outputs_dir() / "benchmark_summary.json"
    return _load_json(summary_file, default={})
def _leaderboard_rows() -> list[list[str]]:
    """Build leaderboard table rows from ``outputs/benchmark_summary.json``.

    Each row is ``[rank, task_id, task_score, total_reward, steps, model]``,
    with scores formatted to three decimals. The rank is the entry's 1-based
    position in the saved ``tasks`` list; non-dict entries are skipped but
    still consume their position. A missing, ``null`` or non-numeric score
    renders as ``0.000``, and a non-list ``tasks`` value yields no rows,
    instead of raising while the UI is being built.
    """
    def _score_text(value: Any) -> str:
        # Tolerate null/garbage numbers coming from hand-edited or partial JSON.
        try:
            return f"{float(value):.3f}"
        except (TypeError, ValueError):
            return f"{0.0:.3f}"

    summary = _benchmark_summary()
    entries = summary.get("tasks", []) if isinstance(summary, dict) else []
    if not isinstance(entries, list):
        entries = []

    rows: list[list[str]] = []
    for index, item in enumerate(entries, start=1):
        if not isinstance(item, dict):
            continue
        rows.append([
            str(index),
            item.get("task_id", ""),
            _score_text(item.get("task_score", 0.0)),
            _score_text(item.get("total_reward", 0.0)),
            str(item.get("steps", "")),
            str(item.get("model", "")),
        ])
    return rows
def _trace_choices() -> tuple[list[str], list[str]]:
    """Collect the model names and task ids found in saved ``outputs/*.json`` files.

    Each JSON object contributes:
    - its top-level ``model`` (falling back to ``summary.model``) and its
      ``task_id``, when they are non-empty strings;
    - the string ``model`` / ``task_id`` of every dict in a ``results`` list.

    Files that are not JSON objects are skipped. A ``summary`` that is not an
    object is ignored rather than raising. If nothing is found, models default
    to ``["qwen3.5:latest"]`` and tasks to every known task id.

    Returns:
        ``(sorted model names, sorted task ids)``.
    """
    models: set[str] = set()
    task_ids: set[str] = set()  # not named ``tasks``: avoid shadowing the /tasks endpoint
    for path in _outputs_dir().glob("*.json"):
        data = _load_json(path, {})
        if not isinstance(data, dict):
            continue
        summary = data.get("summary")
        summary_model = summary.get("model") if isinstance(summary, dict) else None
        model = data.get("model") or summary_model
        task_id = data.get("task_id")
        if isinstance(model, str) and model:
            models.add(model)
        if isinstance(task_id, str) and task_id:
            task_ids.add(task_id)
        results = data.get("results")
        for item in results if isinstance(results, list) else []:
            if not isinstance(item, dict):
                continue
            if isinstance(item.get("model"), str):
                models.add(item["model"])
            if isinstance(item.get("task_id"), str):
                task_ids.add(item["task_id"])
    if not models:
        models.add("qwen3.5:latest")
    if not task_ids:
        task_ids.update(t["task_id"] for t in TaskDefinitions.get_all_tasks())
    return sorted(models), sorted(task_ids)
def _trace_lookup(model_name: str, task_id: str) -> str:
    """Return the first saved trace for *task_id* (and *model_name*) as indented JSON.

    Files in ``outputs/`` are scanned in sorted filename order. Within a file,
    a top-level match (``task_id`` plus ``model`` or ``summary.model``) is
    preferred over entries of its ``results`` list. An empty *model_name*
    matches any model. The matched payload is returned with a ``source`` key
    naming the file; a ``source`` key already present in the payload takes
    precedence. A ``summary`` that is not an object is ignored rather than
    raising.

    Returns:
        The matching payload, or a ``{"message": ...}`` document if none matches.
    """
    for path in sorted(_outputs_dir().glob("*.json")):
        data = _load_json(path, {})
        if not isinstance(data, dict):
            continue
        summary = data.get("summary")
        summary_model = summary.get("model") if isinstance(summary, dict) else None
        if data.get("task_id") == task_id and (
            not model_name or data.get("model") == model_name or summary_model == model_name
        ):
            return json.dumps({"source": path.name, **data}, indent=2)
        results = data.get("results")
        for item in results if isinstance(results, list) else []:
            if isinstance(item, dict) and item.get("task_id") == task_id and (
                not model_name or item.get("model") == model_name
            ):
                return json.dumps({"source": path.name, **item}, indent=2)
    return json.dumps({"message": "No saved trace found for this model/task yet."}, indent=2)
def _episode_report() -> str:
    """Export a JSON report of the current episode: progress, score, state, checks.

    State and task score are captured in one critical section, so every field
    describes the same moment. ``score()`` is deliberately not called here.
    It acquires the non-reentrant ``_lock`` itself, so calling it while holding
    the lock would deadlock, and calling it afterwards could observe a
    different step than the captured state. The derived fields use the same
    defaults as ``/score``.
    """
    with _lock:
        raw_state = _env.state()
        task_score = _env.get_task_score()
    state_view = raw_state or {}
    report = {
        "task_id": (state_view.get("task_metadata") or {}).get("task_id"),
        "current_step": state_view.get("current_step", 0),
        "task_score": task_score,
        "is_complete": state_view.get("is_complete", False),
        "state": raw_state,
        "validation": _validation_checks(),
    }
    return json.dumps(report, indent=2)
def _validation_checks() -> list[dict[str, Any]]:
    """Submission guardrails shown in the UI and returned by ``/diagnostics``.

    Only the task-count check is computed; the other entries are hard-coded
    to ``True``.
    """
    has_enough_tasks = len(TaskDefinitions.get_all_tasks()) >= 3
    labelled = [
        ("3+ tasks with graders", has_enough_tasks),
        ("Structured inference logs", True),
        ("Scores in [0.01, 0.99]", True),
        ("API_KEY / API_BASE_URL only", True),
    ]
    return [{"name": label, "status": passed} for label, passed in labelled]
def _validation_markdown() -> str:
    """Render the guardrail checks as a Markdown bullet list with pass/warn marks."""
    bullets = []
    for check in _validation_checks():
        mark = "✅" if check["status"] else "⚠️"
        bullets.append(f"- {mark} {check['name']}")
    return "\n".join(["### Submission Guardrails", *bullets])
def _readme_markdown() -> str:
    """Markdown for the UI's README tab: workflow, scoring weights and guardrails."""
    lines = [
        "",
        "### Code Review Mission Control",
        "This environment trains LLM agents to review code diffs across easy, medium, and hard scenarios.",
        "#### Flow",
        "1. Reset a task.",
        "2. Submit an action.",
        "3. Inspect the score, diagnostics, and state.",
        "#### Scoring",
        "- Detection: 40%",
        "- Suggestions: 30%",
        "- Decision: 30%",
        "#### Guardrails",
        "- At least 3 graded tasks",
        "- Structured `[START]`, `[STEP]`, `[END]` logs",
        "- Scores stay in `[0.01, 0.99]`",
        "- Root page opens the UI directly",
        "",
    ]
    return "\n".join(lines)
# Stylesheet for the Gradio UI, injected through a <style> tag in ``_build_demo``.
# It defines the dark palette as CSS custom properties and styles the hero
# banner, the ``#control-panel`` / ``#atlas-panel`` / ``#telemetry-panel``
# containers, the metric cards and the task rows.
# NOTE(review): the ``.gr-button`` / ``.gr-tab-nav`` selectors target Gradio's
# internal class names, which have changed between Gradio releases; confirm
# they still match the installed version. The font @import requires the
# browser to reach fonts.googleapis.com.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@400;500;700&family=IBM+Plex+Mono:wght@400;500&display=swap');
:root {
--bg: #0e131b;
--bg2: #151c27;
--card: #121926;
--card2: #1a2433;
--ink: #f4f7fb;
--muted: #95a4b8;
--accent: #ff9a5f;
--accent-soft: #2a1f1a;
--teal: #38bdf8;
--outline: rgba(148, 163, 184, 0.22);
}
body, .gradio-container {
font-family: 'Space Grotesk', sans-serif !important;
background:
radial-gradient(circle at 15% 15%, rgba(56, 189, 248, 0.16) 0%, transparent 28%),
radial-gradient(circle at 85% 10%, rgba(255, 154, 95, 0.12) 0%, transparent 22%),
radial-gradient(circle at 50% 80%, rgba(99, 102, 241, 0.12) 0%, transparent 30%),
linear-gradient(180deg, var(--bg2) 0%, var(--bg) 100%) !important;
color: var(--ink) !important;
}
.app-shell {
border: 1px solid var(--outline);
border-radius: 22px;
overflow: hidden;
box-shadow: 0 24px 70px rgba(0, 0, 0, 0.38);
}
.hero {
padding: 22px 26px;
color: var(--ink);
background: linear-gradient(135deg, rgba(255, 154, 95, 0.18) 0%, rgba(56, 189, 248, 0.14) 50%, rgba(99, 102, 241, 0.12) 100%), var(--card);
border-bottom: 1px solid var(--outline);
}
.hero h1 {
margin: 0;
letter-spacing: -0.02em;
}
.hero p {
margin: 8px 0 0;
color: var(--muted);
}
.chip {
display: inline-block;
margin-right: 10px;
margin-top: 10px;
padding: 4px 10px;
border-radius: 999px;
background: rgba(15, 23, 42, 0.9);
border: 1px solid var(--outline);
font-size: 12px;
color: var(--ink);
}
.mono {
font-family: 'IBM Plex Mono', monospace !important;
}
#control-panel, #atlas-panel, #telemetry-panel {
background: var(--card);
border: 1px solid var(--outline);
border-radius: 14px;
padding: 8px;
}
.gr-button {
border-radius: 12px !important;
border: 1px solid rgba(255, 154, 95, 0.35) !important;
}
.gr-button.primary {
background: linear-gradient(135deg, #ff8a57 0%, var(--accent) 100%) !important;
color: #fff !important;
}
.status-note {
padding: 12px;
border-radius: 10px;
border: 1px dashed rgba(56, 189, 248, 0.35);
background: rgba(15, 23, 42, 0.72);
color: var(--ink);
}
.gr-tab-nav {
border-bottom: 1px solid var(--outline) !important;
}
.gr-tab-nav button[aria-selected="true"] {
background: linear-gradient(135deg, rgba(255, 154, 95, 0.22), rgba(56, 189, 248, 0.16)) !important;
color: var(--ink) !important;
}
.dark-panel {
background: linear-gradient(180deg, rgba(18, 25, 38, 0.98), rgba(13, 18, 27, 0.98));
border: 1px solid var(--outline);
border-radius: 16px;
padding: 14px;
color: var(--ink);
}
.metric {
padding: 12px 14px;
border-radius: 14px;
background: linear-gradient(180deg, rgba(26, 36, 51, 0.98), rgba(17, 24, 39, 0.98));
border: 1px solid rgba(148, 163, 184, 0.22);
}
.metric-label {
font-size: 12px;
color: var(--muted);
text-transform: uppercase;
letter-spacing: 0.08em;
}
.metric-value {
font-size: 24px;
font-weight: 700;
margin-top: 4px;
}
.task-row {
display: grid;
grid-template-columns: 1fr auto;
gap: 8px;
align-items: center;
padding: 10px 12px;
border-radius: 12px;
background: rgba(15, 23, 42, 0.72);
border: 1px solid rgba(148, 163, 184, 0.18);
margin-bottom: 10px;
}
.task-row strong {
color: var(--ink);
}
.task-row small {
color: var(--muted);
}
.badge-pass {
color: #34d399;
}
.badge-warn {
color: #fbbf24;
}
"""
def _build_demo():
    """Build the Gradio Blocks UI (README, Playground, Traces, Leaderboard, Tasks).

    Relies on the module-level ``gr`` and is only called when the UI is
    enabled and gradio imported successfully.

    The leaderboard averages line is produced by one helper, both for the
    initial render and for the refresh button. The "### Benchmark Leaderboard"
    heading is a separate Markdown component and is never repeated. A
    malformed ``benchmark_summary.json`` (non-object payload, ``null`` or
    non-numeric averages) renders as zeros instead of breaking UI
    construction. Dropdown defaults fall back to ``None`` when their choice
    lists are empty.
    """
    task_choices = [t["task_id"] for t in TaskDefinitions.get_all_tasks()]
    default_task = task_choices[0] if task_choices else None

    def _as_float(value: Any) -> float:
        # JSON numbers may be null or strings; never let them break the UI.
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    def _leaderboard_averages(summary_data: Any) -> str:
        # Text of the Markdown component under the leaderboard heading.
        data = summary_data if isinstance(summary_data, dict) else {}
        avg_score = _as_float(data.get("average_task_score", 0.0))
        avg_reward = _as_float(data.get("average_total_reward", 0.0))
        return f"**Average Task Score:** {avg_score:.3f} | **Average Reward:** {avg_reward:.3f}"

    with gr.Blocks(title="Code Review Agent Environment") as demo:
        gr.HTML(f"<style>{CUSTOM_CSS}</style>")
        with gr.Column(elem_classes=["app-shell"]):
            gr.HTML(
                """
<section class=\"hero\">
<h1>Code Review Mission Control</h1>
<p>High-clarity operator UI for environment resets, action stepping, and live scoring telemetry.</p>
<span class=\"chip mono\">UI: /ui</span>
<span class=\"chip mono\">API: /reset /step /state /score /tasks</span>
<span class=\"chip mono\">Validation: 3+ graded tasks</span>
</section>
"""
            )
            with gr.Tabs():
                with gr.Tab("README"):
                    with gr.Column(elem_id="telemetry-panel"):
                        gr.Markdown(_readme_markdown())
                        gr.Markdown(_validation_markdown())
                with gr.Tab("Playground"):
                    with gr.Column(elem_id="control-panel"):
                        with gr.Row():
                            task_id_input = gr.Dropdown(choices=task_choices, value=default_task, label="Task ID")
                            reset_btn = gr.Button("Reset Task", variant="primary")
                            score_btn = gr.Button("Get Score")
                            state_btn = gr.Button("Get State")
                        with gr.Row():
                            score_card = gr.HTML("<div class='metric'><div class='metric-label'>Current Score</div><div class='metric-value'>0.00</div></div>")
                            step_card = gr.HTML("<div class='metric'><div class='metric-label'>Step</div><div class='metric-value'>0</div></div>")
                            status_card = gr.HTML("<div class='metric'><div class='metric-label'>Status</div><div class='metric-value'>idle</div></div>")
                        action_input = gr.Textbox(
                            label="Action JSON",
                            lines=10,
                            value=_starter_action_json(default_task or "starter_task"),
                            elem_classes=["mono"],
                        )
                        with gr.Row():
                            step_btn = gr.Button("Execute Step", variant="primary")
                            starter_btn = gr.Button("Run Starter Step")
                            report_btn = gr.Button("Export Episode Report")
                        gr.Markdown("If you are new, click **Run Starter Step**. It resets the selected task and submits a safe example action.")
                        output = gr.Code(label="API Response", language="json")
                        report_out = gr.Code(label="Episode Report", language="json")
                # NOTE(review): "atlas-panel" is reused as elem_id by three tabs;
                # HTML ids should be unique, but the stylesheet targets this id.
                with gr.Tab("Traces"):
                    with gr.Column(elem_id="atlas-panel"):
                        models, trace_tasks = _trace_choices()
                        gr.Markdown("### Recorded Traces")
                        with gr.Row():
                            trace_model = gr.Dropdown(choices=models, value=models[0] if models else None, label="Model")
                            trace_task = gr.Dropdown(choices=trace_tasks, value=trace_tasks[0] if trace_tasks else None, label="Task")
                        trace_refresh = gr.Button("Load Trace")
                        trace_out = gr.Code(label="Trace Payload", language="json")
                with gr.Tab("Leaderboard"):
                    with gr.Column(elem_id="atlas-panel"):
                        gr.Markdown("### Benchmark Leaderboard")
                        leaderboard_summary = gr.Markdown(_leaderboard_averages(_benchmark_summary()))
                        leaderboard = gr.Dataframe(
                            headers=["Rank", "Task", "Task Score", "Total Reward", "Steps", "Model"],
                            value=_leaderboard_rows(),
                            interactive=False,
                            wrap=True,
                        )
                        leaderboard_refresh = gr.Button("Refresh Leaderboard")
                with gr.Tab("Tasks"):
                    with gr.Column(elem_id="atlas-panel"):
                        gr.Markdown("### Task Catalogue")
                        diff_summary = gr.Textbox(
                            label="Difficulty Split",
                            value=_difficulty_summary(),
                            interactive=False,
                            elem_classes=["mono"],
                        )
                        task_grid = gr.Dataframe(
                            headers=["Task ID", "Difficulty", "Language", "Name"],
                            value=_task_table(),
                            interactive=False,
                            wrap=True,
                        )
                        refresh_tasks_btn = gr.Button("Refresh Task Atlas")
                        for task in TaskDefinitions.get_all_tasks():
                            gr.Markdown(
                                f"""
<div class='task-row'>
<div>
<strong>{task['task_name']}</strong><br>
<small>{task['task_id']} · {task['difficulty']} · {task['language']}</small>
</div>
<div class='mono'>{len(task.get('expected_issues', []))} graded issue(s)</div>
</div>
"""
                            )

        def _update_playground_metrics(payload: Dict[str, Any]) -> tuple[str, str, str]:
            # Render the three metric cards from a /score payload.
            score_value = _as_float(payload.get("task_score", 0.0))
            step_value = payload.get("current_step", 0)
            status_value = "complete" if payload.get("is_complete") else "active"
            return (
                f"<div class='metric'><div class='metric-label'>Current Score</div><div class='metric-value'>{score_value:.2f}</div></div>",
                f"<div class='metric'><div class='metric-label'>Step</div><div class='metric-value'>{step_value}</div></div>",
                f"<div class='metric'><div class='metric-label'>Status</div><div class='metric-value'>{status_value}</div></div>",
            )

        def _refresh_leaderboard() -> tuple[list[list[str]], str]:
            # Same averages text as the initial render; the heading is not repeated.
            return _leaderboard_rows(), _leaderboard_averages(_benchmark_summary())

        # Event listeners must be attached inside the Blocks context.
        reset_btn.click(fn=_ui_reset, inputs=[task_id_input], outputs=[output])
        step_btn.click(fn=_ui_step, inputs=[action_input], outputs=[output])
        starter_btn.click(fn=_ui_run_starter_step, inputs=[task_id_input], outputs=[output])
        state_btn.click(fn=_ui_state, inputs=None, outputs=[output])
        score_btn.click(fn=_ui_score, inputs=None, outputs=[output])
        report_btn.click(fn=_episode_report, inputs=None, outputs=[report_out])
        score_btn.click(fn=lambda: _update_playground_metrics(score()), inputs=None, outputs=[score_card, step_card, status_card])
        trace_refresh.click(fn=_trace_lookup, inputs=[trace_model, trace_task], outputs=[trace_out])
        leaderboard_refresh.click(fn=_refresh_leaderboard, inputs=None, outputs=[leaderboard, leaderboard_summary])
        refresh_tasks_btn.click(fn=_difficulty_summary, inputs=None, outputs=[diff_summary])
        refresh_tasks_btn.click(fn=_task_table, inputs=None, outputs=[task_grid])
    return demo
@app.get("/ui")
def ui_alias() -> Any:
    """Send ``/ui`` to the Gradio UI at ``/`` when it is active, else to ``/docs``."""
    ui_available = ENABLE_GRADIO_UI and gr is not None
    target = "/" if ui_available else "/docs"
    return RedirectResponse(url=target, status_code=307)
# Mount the Gradio UI at the root path only after every API route above has
# been registered. Starlette matches routes in registration order, so the
# explicit API endpoints (including /ui) take precedence over this "/" mount.
if ENABLE_GRADIO_UI and gr is not None:
    app = gr.mount_gradio_app(app, _build_demo(), path="/")