# schemashift/server/app.py
# yashash04's picture
# Phase 7: FastAPI server + Dockerfile + openenv.yaml
# cb33205
"""FastAPI server exposing SchemaShiftEnvironment as OpenEnv HTTP service."""
from __future__ import annotations
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from models import Action
from scenarios import SCENARIOS
from server.environment import SchemaShiftEnvironment
# ASGI application object; the metadata below is handed straight to FastAPI.
app = FastAPI(
    title="SchemaShift OpenEnv",
    description="RL environment for training adaptive tool use under schema drift.",
    version="0.1.0",
)

# A single module-level environment instance; every endpoint in this file
# reads from and writes to this same object.
env = SchemaShiftEnvironment()
# ─────────────────────────────────────────────────────────────────
# Request / response models
# ─────────────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Request body accepted by ``POST /reset``."""

    # Forwarded as the first argument to ``env.reset``.
    # Presumably one of the ids listed by ``GET /tasks`` -- confirm against
    # the environment implementation.
    task_id: str
    # Forwarded as the second argument to ``env.reset``; 0 when omitted.
    seed: int = 0
class StepRequest(BaseModel):
    """Request body accepted by ``POST /step``."""

    # The agent's action, validated against the project's ``Action`` model.
    action: Action
    # Non-negative token count (``ge=0``); forwarded to ``env.step`` and
    # defaults to 0 when the client does not send it.
    tokens_used: int = Field(
        default=0,
        ge=0,
        description="Tokens consumed by the agent on this step",
    )
# ─────────────────────────────────────────────────────────────────
# Endpoints
# ─────────────────────────────────────────────────────────────────
@app.get("/")
def root() -> dict:
    """Return static service metadata plus the list of endpoint paths."""
    endpoint_paths = ["/health", "/reset", "/step", "/state", "/tasks", "/grader"]
    info = {
        "name": "SchemaShift",
        "version": "0.1.0",
        "description": "Adaptive tool use under schema drift",
        "endpoints": endpoint_paths,
    }
    return info
@app.get("/health")
def health() -> dict:
    """Liveness probe: always reports status ``ok`` and the service version."""
    return dict(status="ok", version="0.1.0")
@app.post("/reset")
def reset(req: ResetRequest) -> dict:
    """Start a new episode and return the initial observation as a dict.

    Args:
        req: Request body whose ``task_id`` and ``seed`` are forwarded to
            ``env.reset``.

    Returns:
        The observation returned by ``env.reset``, serialised with
        ``model_dump()``.

    Raises:
        HTTPException: status 400 when a ``ValueError`` is raised
            (presumably an unknown task id -- confirm against the
            environment), status 500 for any other exception.
    """
    try:
        obs = env.reset(req.task_id, req.seed)
        return obs.model_dump()
    except ValueError as e:
        # Chain the cause explicitly so the original traceback is kept
        # as ``__cause__`` instead of an implicit "during handling" context.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Top-level boundary: convert anything unexpected into a 500 while
        # preserving the underlying error for logs/debuggers.
        raise HTTPException(status_code=500, detail=f"Reset failed: {e}") from e
@app.post("/step")
def step(req: StepRequest) -> dict:
    """Submit one action and return the resulting observation and reward.

    Args:
        req: Request body whose ``action`` and ``tokens_used`` are forwarded
            to ``env.step``.

    Returns:
        A dict with two keys, ``"observation"`` and ``"reward"``, each the
        ``model_dump()`` of the corresponding value returned by ``env.step``.

    Raises:
        HTTPException: status 400 when a ``RuntimeError`` is raised
            (presumably stepping without an active episode -- confirm
            against the environment), status 500 for any other exception.
    """
    try:
        obs, reward = env.step(req.action, req.tokens_used)
        return {
            "observation": obs.model_dump(),
            "reward": reward.model_dump(),
        }
    except RuntimeError as e:
        # Chain the cause explicitly so the original traceback is kept
        # as ``__cause__`` instead of an implicit "during handling" context.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Top-level boundary: convert anything unexpected into a 500 while
        # preserving the underlying error for logs/debuggers.
        raise HTTPException(status_code=500, detail=f"Step failed: {e}") from e
@app.get("/state")
def get_state() -> dict:
    """Dump the current episode state (debugging aid).

    Reads the environment's non-public ``_state`` attribute directly and
    answers with HTTP 400 when it is ``None``.
    """
    current = env._state
    if current is None:
        raise HTTPException(
            status_code=400,
            detail="No active episode. Call /reset first.",
        )
    return current.model_dump()
@app.get("/tasks")
def get_tasks() -> dict:
    """List every entry of ``SCENARIOS`` with a short metadata summary.

    Each summary carries the task id, difficulty, step budget, required
    tools, and the task description cut to 120 characters (with ``...``
    appended when it was longer).
    """

    def summarize(task_id, scenario):
        # Build the (possibly truncated) description preview first.
        full_text = scenario["task_description"]
        preview = full_text[:120]
        if len(full_text) > 120:
            preview = preview + "..."
        return {
            "task_id": task_id,
            "difficulty": scenario["difficulty"],
            "max_steps": scenario["max_steps"],
            "required_tools": scenario["required_tools"],
            "description": preview,
        }

    catalog = [summarize(key, entry) for key, entry in SCENARIOS.items()]
    return {"tasks": catalog, "count": len(catalog)}
@app.get("/grader")
def get_grader_breakdown() -> dict:
    """Run the environment's grader on the current state and report it.

    Answers with HTTP 400 when the environment's non-public ``_state`` is
    ``None``. Otherwise calls ``env._grader`` with that state and returns its
    ``model_dump()`` alongside a few progress fields read from the state.
    """
    if env._state is None:
        raise HTTPException(status_code=400, detail="No active episode.")

    graded = env._grader(env._state)
    episode = env._state
    return dict(
        cumulative_reward=episode.cumulative_reward,
        current_breakdown=graded.model_dump(),
        step=episode.step,
        max_steps=episode.max_steps,
        done=episode.done,
    )