from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
import json
import uvicorn
import uuid
from dotenv import load_dotenv
load_dotenv()
from .tasks import TASKS
from .grader import grade_repair
# Application object; the metadata below is what FastAPI publishes in its OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="JSON Repair Environment",
    description="OpenEnv environment for training agents to repair malformed JSON",
)
# --- Pydantic Models ---
# Request body for POST /step and POST /test_custom: one repair attempt.
class Action(BaseModel):
    # Candidate JSON text: graded by grade_repair in /step, parsed with json.loads in /test_custom.
    repaired_json: str
    # Free-text rationale; not read by any endpoint visible in this file.
    explanation: Optional[str] = ""
# What the agent sees for the current task; returned by /reset and /step.
class Observation(BaseModel):
    broken_json: str  # malformed JSON to repair ("" once the episode is complete)
    target_schema: Dict[str, Any]  # the task's "schema" entry ({} once the episode is complete)
    hint: str  # task hint, or the "All tasks completed!" message at episode end
    task_name: str  # task["name"], or "episode_complete" after the last task
    step_number: int  # number of /step calls made in the current episode
    total_tasks: int  # len(TASKS)
# Response model of POST /step.
class StepResult(BaseModel):
    observation: Observation  # next task, or the "episode_complete" placeholder
    reward: float  # grader reward; clamped to [0.11, 0.89] only on the final step
    done: bool  # True once every task has been attempted
    info: Dict[str, Any]  # diagnostics returned by grade_repair
# Response model of POST /reset.
class ResetResult(BaseModel):
    observation: Observation  # the first task of the new episode
    done: bool  # always False from /reset
    reward: float  # always 0.1 from /reset (kept strictly inside (0, 1))
# --- In-memory State ---
# One module-level episode record, read and mutated by the endpoints below.
state: Dict[str, Any] = dict(
    session_id="",      # set to a fresh uuid4 string by /reset
    task_index=0,       # index into TASKS of the task currently being attempted
    step=0,             # /step calls made this episode
    total_reward=0.0,   # running sum of raw grader rewards
    done=False,         # True once every task has been attempted
    history=[],         # one entry per /step call
)
def build_observation() -> Observation:
    """Return the Observation for the task at ``state["task_index"]``.

    Indexing TASKS past its end fails. Callers therefore only use this while a
    task remains (/reset at index 0, /step when the episode is not done).
    """
    current = TASKS[state["task_index"]]
    fields = dict(
        broken_json=current["broken_json"],
        target_schema=current["schema"],
        hint=current["hint"],
        task_name=current["name"],
        step_number=state["step"],
        total_tasks=len(TASKS),
    )
    return Observation(**fields)
# --- Endpoints ---
from fastapi.responses import HTMLResponse
# --- Activity Logs (In-Memory) ---
# Activity feed: the endpoints append to it, and GET /logs returns it as-is.
logs: List[str] = [
    "[SYSTEM] Environment Initialized.",
    "[SYSTEM] Ready for incoming agent connections.",
    "[DASHBOARD] UI v2.0.0 (High-Performance) loaded.",
]
@app.get("/", response_class=HTMLResponse)
async def root():
    """Render the HTML dashboard.

    The page shows:
    - every task (name, upper-cased difficulty, description);
    - the in-memory activity log;
    - a small "custom repair lab" form that POSTs its text to /test_custom.

    All interpolated text is HTML-escaped so task data or log lines cannot
    inject markup into the page.
    """
    import html  # local import keeps this block self-contained

    esc = html.escape
    task_cards = "".join(
        f"""
    <div class="task-card">
      <h3>{esc(str(task['name']))}</h3>
      <span class="difficulty">{esc(str(task['difficulty']).upper())}</span>
      <p>{esc(str(task['description']))}</p>
    </div>"""
        for task in TASKS
    )
    log_items = "".join(f'<div class="log-line">{esc(line)}</div>\n' for line in logs)
    # Literal braces in the inline script are doubled because this is an f-string.
    html_content = f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>JSON-REPAIR | Neural Environment</title>
</head>
<body>
  <h1>JSON-REPAIR | Neural Environment</h1>
  <section id="custom-lab">
    <h2>⚡ CUSTOM REPAIR LAB</h2>
    <p>Others can use this to validate their own JSON logic instantly.</p>
    <textarea id="custom-json" rows="8" cols="80"></textarea>
    <button type="button" onclick="testCustom()">Validate</button>
    <pre id="custom-result"></pre>
  </section>
  <section id="tasks">
    <h2>Tasks</h2>
    {task_cards}
  </section>
  <section id="logs">
    <h2>Activity Log</h2>
    {log_items}
  </section>
  <script>
    async function testCustom() {{
      const res = await fetch("/test_custom", {{
        method: "POST",
        headers: {{ "Content-Type": "application/json" }},
        body: JSON.stringify({{ repaired_json: document.getElementById("custom-json").value }})
      }});
      document.getElementById("custom-result").textContent =
        JSON.stringify(await res.json(), null, 2);
    }}
  </script>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
@app.get("/health")
async def health():
    """Liveness probe: always answers with a static OK payload."""
    payload = dict(status="ok")
    return payload
@app.post("/reset", response_model=ResetResult)
async def reset():
    """Start a fresh episode at task 0 under a newly generated session id."""
    fresh_session = str(uuid.uuid4())
    state.update(
        session_id=fresh_session,
        task_index=0,
        step=0,
        total_reward=0.0,
        done=False,
        history=[],
    )
    logs.append(f"[EVENT] Environment Reset. Session: {fresh_session[:8]}")
    if len(logs) > 50:
        logs.pop(0)
    # 0.1 rather than 0.0: rewards are required to lie strictly between 0 and 1.
    return ResetResult(observation=build_observation(), done=False, reward=0.1)
@app.post("/step", response_model=StepResult)
async def step(action: Action):
    """Grade the submitted repair for the current task, then advance to the next one.

    Raises:
        HTTPException: 400 when the episode has already finished.
    """
    if state["done"]:
        raise HTTPException(status_code=400, detail="Episode is done. Call /reset to start again.")

    current = TASKS[state["task_index"]]
    reward, info = grade_repair(action.repaired_json, current)

    # Book-keeping records the raw grader reward (before any final-step clamp).
    state["step"] = state["step"] + 1
    state["total_reward"] = state["total_reward"] + reward
    state["history"].append(
        dict(task=current["name"], action=action.repaired_json[:100], reward=reward, info=info)
    )
    state["task_index"] = state["task_index"] + 1

    logs.append(f"[AGENT] Action submitted for {current['name']}. Reward: {reward}")
    if len(logs) > 50:
        logs.pop(0)

    finished = state["task_index"] >= len(TASKS)
    state["done"] = finished
    if not finished:
        return StepResult(observation=build_observation(), reward=reward, done=finished, info=info)

    closing = Observation(
        broken_json="",
        target_schema={},
        hint="All tasks completed! Call /reset to start a new episode.",
        task_name="episode_complete",
        step_number=state["step"],
        total_tasks=len(TASKS),
    )
    # Final-step reward is kept strictly inside (0, 1) by clamping to [0.11, 0.89].
    # NOTE(review): intermediate rewards are returned unclamped — confirm grade_repair
    # already stays strictly within (0, 1), otherwise apply the same clamp there.
    return StepResult(observation=closing, reward=min(max(reward, 0.11), 0.89), done=finished, info=info)
@app.get("/state")
async def get_state():
    """Expose the module-level episode state dict (session, progress, rewards, history)."""
    current_state = state
    return current_state
@app.get("/tasks")
async def list_tasks():
    """Summarize every task; the broken_json and schema entries are not included."""
    public_keys = ("name", "difficulty", "description", "hint")
    return [{key: task[key] for key in public_keys} for task in TASKS]
@app.get("/logs")
async def get_logs():
    """Return the in-memory activity log, oldest entry first."""
    recent = logs
    return recent
@app.post("/test_custom")
async def test_custom(action: Action):
    """Check whether user-supplied text is valid JSON (no task or schema grading).

    Returns:
        ``{"status": "valid", "data": <parsed value>}`` when the text parses,
        otherwise ``{"status": "invalid", "error": <parser message>}``.
    """

    def _reject_constant(name: str):
        # NaN/Infinity are not valid RFC 8259 JSON. Echoing them back would also
        # break the response (Starlette's JSONResponse renders with allow_nan=False).
        raise ValueError(f"Invalid JSON constant: {name}")

    try:
        data = json.loads(action.repaired_json, parse_constant=_reject_constant)
    except (ValueError, RecursionError) as e:
        # json.JSONDecodeError subclasses ValueError; RecursionError covers
        # pathologically deep nesting.
        logs.append(f"[USER] Manual repair test: INVALID. Error: {str(e)}")
        result = {"status": "invalid", "error": str(e)}
    else:
        logs.append("[USER] Manual repair test: VALID JSON")
        result = {"status": "valid", "data": data}
    # Same 50-entry cap as /reset and /step. Without it, this public endpoint
    # grows the log list without bound.
    if len(logs) > 50:
        logs.pop(0)
    return result
def main():
    """Serve the API with uvicorn on 0.0.0.0:7860 with auto-reload disabled."""
    server_options = {"host": "0.0.0.0", "port": 7860, "reload": False}
    # The app is passed as an import string, so uvicorn imports it as `server.app`.
    uvicorn.run("server.app:app", **server_options)


if __name__ == "__main__":
    main()