Spaces:
Sleeping
Sleeping
| """FastAPI server exposing SchemaShiftEnvironment as OpenEnv HTTP service.""" | |
| from __future__ import annotations | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| from models import Action | |
| from scenarios import SCENARIOS | |
| from server.environment import SchemaShiftEnvironment | |
| app = FastAPI( | |
| title="SchemaShift OpenEnv", | |
| description="RL environment for training adaptive tool use under schema drift.", | |
| version="0.1.0", | |
| ) | |
| env = SchemaShiftEnvironment() | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Request / response models | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class ResetRequest(BaseModel): | |
| task_id: str | |
| seed: int = 0 | |
| class StepRequest(BaseModel): | |
| action: Action | |
| tokens_used: int = Field( | |
| default=0, ge=0, | |
| description="Tokens consumed by the agent on this step", | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Endpoints | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def root() -> dict: | |
| return { | |
| "name": "SchemaShift", | |
| "version": "0.1.0", | |
| "description": "Adaptive tool use under schema drift", | |
| "endpoints": ["/health", "/reset", "/step", "/state", "/tasks", "/grader"], | |
| } | |
| def health() -> dict: | |
| return {"status": "ok", "version": "0.1.0"} | |
| def reset(req: ResetRequest) -> dict: | |
| """Start new episode. Returns initial Observation as dict.""" | |
| try: | |
| obs = env.reset(req.task_id, req.seed) | |
| return obs.model_dump() | |
| except ValueError as e: | |
| raise HTTPException(status_code=400, detail=str(e)) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=f"Reset failed: {e}") | |
| def step(req: StepRequest) -> dict: | |
| """Submit action, get observation + reward.""" | |
| try: | |
| obs, reward = env.step(req.action, req.tokens_used) | |
| return { | |
| "observation": obs.model_dump(), | |
| "reward": reward.model_dump(), | |
| } | |
| except RuntimeError as e: | |
| raise HTTPException(status_code=400, detail=str(e)) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=f"Step failed: {e}") | |
| def get_state() -> dict: | |
| """Return full current episode state for debugging.""" | |
| if env._state is None: | |
| raise HTTPException(status_code=400, detail="No active episode. Call /reset first.") | |
| return env._state.model_dump() | |
| def get_tasks() -> dict: | |
| """List all available scenarios with metadata.""" | |
| tasks = [] | |
| for task_id, scenario in SCENARIOS.items(): | |
| desc = scenario["task_description"] | |
| trimmed = desc[:120] + ("..." if len(desc) > 120 else "") | |
| tasks.append({ | |
| "task_id": task_id, | |
| "difficulty": scenario["difficulty"], | |
| "max_steps": scenario["max_steps"], | |
| "required_tools": scenario["required_tools"], | |
| "description": trimmed, | |
| }) | |
| return {"tasks": tasks, "count": len(tasks)} | |
| def get_grader_breakdown() -> dict: | |
| """Return grader scoring for current episode state.""" | |
| if env._state is None: | |
| raise HTTPException(status_code=400, detail="No active episode.") | |
| reward = env._grader(env._state) | |
| return { | |
| "cumulative_reward": env._state.cumulative_reward, | |
| "current_breakdown": reward.model_dump(), | |
| "step": env._state.step, | |
| "max_steps": env._state.max_steps, | |
| "done": env._state.done, | |
| } | |