# Team_Sparks / server.py
# Uploaded by KeithXD via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 4702dbb (verified).
"""
server.py -- AuditRepairEnv++ OpenEnv Server
=============================================
FastAPI server: /reset, /step, /state, /health
OpenEnv-compliant, HuggingFace-ready, port 7860.
"""
import os
import time
import uuid
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from tasks import TASK_CONFIGS, TASK_IDS, LedgerEnvironment, AuditObservation
# ────────────────────────────────────────
# REQUEST / RESPONSE MODELS
# ────────────────────────────────────────
class ResetRequest(BaseModel):
    # JSON body accepted by POST /reset.
    # task_id selects an entry of TASK_CONFIGS; when it is None (the default)
    # or empty, _do_reset falls back to the "easy" task.
    task_id: Optional[str] = Field(default=None, description="easy | medium | hard")
class StepAction(BaseModel):
    # JSON body accepted by POST /step.
    # message is required and is passed verbatim to env.step_with_message().
    message: str = Field(..., description="Agent action text, e.g. 'FIX_ENTRY 1'")
class StepResponse(BaseModel):
    # Schema of the POST /step response (the handler returns it via model_dump()).
    observation: AuditObservation
    # The /step handler fills this with the episode's current normalized score
    # (env.compute_final_score()), not with the raw per-step reward.
    reward: float
    done: bool
    # Free-form extras set by the handler (e.g. total_reward, episode_id,
    # result text, final_score).
    info: Dict[str, Any] = Field(default_factory=dict)
    # Error reported by the environment for the last action, or None.
    last_action_error: Optional[str] = None
class StateResponse(BaseModel):
    # Snapshot returned by GET /state; each field is copied from the active
    # EpisodeState or from its environment.
    episode_id: str
    task_id: str
    step: int  # the environment's `step` attribute
    max_steps: int
    # Latest normalized score: /step overwrites this value rather than summing.
    total_reward: float
    done: bool
    remaining_budget: int
    initial_budget: int
    errors_count: int  # len(env.get_errors())
    # One dict per /step call (step, action, reward, step_score, done, info).
    history: List[Dict[str, Any]]
    started_at: float  # time.time() when the episode was created
# ────────────────────────────────────────
# EPISODE STATE
# ────────────────────────────────────────
class EpisodeState:
    """Server-side record for one episode and the environment driving it.

    Every attribute is initialised here; the /step handler later
    overwrites ``total_reward`` and appends entries to ``history``.
    """

    def __init__(self, env: LedgerEnvironment):
        # Random identifier echoed back to clients by /step and /state.
        self.episode_id = str(uuid.uuid4())
        self.env = env
        # Normalized score so far; zero until the first step is taken.
        self.total_reward = 0.0
        # Chronological list of per-step records.
        self.history: List[Dict[str, Any]] = list()
        # Creation time in seconds since the epoch.
        self.started_at = time.time()


# The only episode this process serves; None until /reset has been called.
_current_episode: Optional[EpisodeState] = None
# ────────────────────────────────────────
# FASTAPI APP
# ────────────────────────────────────────
app = FastAPI(version="1.0.0", title="AuditRepairEnv++")

# Permissive CORS: any origin may call any route with any method/header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.get("/", include_in_schema=False)
async def root():
    # Landing route (hidden from the OpenAPI schema) confirming the server
    # is up and pointing at the interactive docs.
    payload = {
        "name": "AuditRepairEnv++",
        "status": "running",
        "docs": "/docs",
        "message": "API is live.",
    }
    return payload
# ────────────────────────────────────────
# OPENENV ENDPOINTS
# ────────────────────────────────────────
async def _do_reset(task_id: Optional[str] = None):
    """Start a fresh episode and return its initial observation as a dict.

    A falsy ``task_id`` (None or "") selects the ``"easy"`` task. The new
    episode replaces whatever episode was active before.

    Raises:
        HTTPException: 400 when the task id is not a key of ``TASK_CONFIGS``.
    """
    global _current_episode
    selected = task_id if task_id else "easy"
    if selected not in TASK_CONFIGS:
        raise HTTPException(400, f"Unknown task '{selected}'. Available: {TASK_IDS}")

    task_config = TASK_CONFIGS[selected]
    fresh_env = task_config.create_env()
    _current_episode = EpisodeState(fresh_env)

    first_obs = fresh_env.get_observation(
        echoed_message=f"Environment reset. Task: {task_config.name}"
    )
    return first_obs.model_dump()
@app.post("/reset")
async def reset_post(request: ResetRequest = ResetRequest()):
    # JSON-body variant of /reset; when no body is sent, the declared default
    # ResetRequest (task_id=None) is used, which _do_reset maps to "easy".
    requested_task = request.task_id
    return await _do_reset(requested_task)
@app.get("/reset")
async def reset_get(task_id: Optional[str] = None):
    # Query-string variant of /reset, e.g. GET /reset?task_id=hard.
    return await _do_reset(task_id=task_id)
@app.post("/step")
async def step(action: StepAction):
    """Apply one agent action to the active episode.

    The top-level ``reward`` in the response is the episode's current
    normalized score (``env.compute_final_score()``), not the per-step
    reward. The raw per-step reward returned by the environment is exposed
    as ``info["step_reward"]`` and recorded in the episode history.
    ``info["final_score"]`` is set only once the episode is done.

    Raises:
        HTTPException: 400 if no episode is active or the active one has
            already finished.
    """
    # Read-only access to the module-level episode: no ``global`` needed.
    ep = _current_episode
    if ep is None:
        raise HTTPException(400, "No active episode. Call /reset first.")
    if ep.env.done:
        raise HTTPException(400, "Episode finished. Call /reset to start a new one.")

    result = ep.env.step_with_message(action.message)
    # Raw reward the environment reported for this single step.
    step_reward = float(result.get("reward", 0))
    done = bool(result.get("done", False))
    error = result.get("error")
    result_text = result.get("result", "")

    # Running normalized score; stored (not accumulated) as total_reward.
    current_score = ep.env.compute_final_score()
    ep.total_reward = current_score

    ep.history.append({
        "step": ep.env.step,
        "action": action.message[:200],  # only the first 200 chars are kept
        "reward": step_reward,
        "step_score": current_score,
        "done": done,
        "info": result_text,
    })

    return StepResponse(
        observation=result["observation"],
        reward=current_score,  # normalized score, not the raw step reward
        done=done,
        info={
            "total_reward": ep.total_reward,
            "episode_id": ep.episode_id,
            "result": result_text,
            "final_score": current_score if done else None,
            # New key: lets clients see the per-step reward as well.
            "step_reward": step_reward,
        },
        last_action_error=error,
    ).model_dump()
@app.get("/state")
async def state():
    # Snapshot of the active episode and its environment; 400 if /reset has
    # never been called.
    episode = _current_episode
    if episode is None:
        raise HTTPException(400, "No active episode. Call /reset first.")

    env = episode.env
    snapshot = StateResponse(
        episode_id=episode.episode_id,
        task_id=env.task_id,
        step=env.step,
        max_steps=env.max_steps,
        total_reward=episode.total_reward,
        done=env.done,
        remaining_budget=env.remaining_budget,
        initial_budget=env.initial_budget,
        errors_count=len(env.get_errors()),
        history=episode.history,
        started_at=episode.started_at,
    )
    return snapshot.model_dump()
@app.get("/health")
async def health():
    # Static health check that also lists the task ids clients may reset to.
    return dict(
        status="ok",
        environment="AuditRepairEnv++",
        tasks=TASK_IDS,
    )
if __name__ == "__main__":
    import uvicorn

    # Serve on 7860 (the port this module documents) unless a PORT
    # environment variable overrides it; os is already imported at file top.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))