File size: 4,146 Bytes
cb33205
 
d4ab0f1
cb33205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""FastAPI server exposing SchemaShiftEnvironment as OpenEnv HTTP service."""
from __future__ import annotations

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from models import Action
from scenarios import SCENARIOS
from server.environment import SchemaShiftEnvironment


# ASGI application object; the title/description/version given here are the
# metadata FastAPI is constructed with.
app = FastAPI(
    title="SchemaShift OpenEnv",
    description="RL environment for training adaptive tool use under schema drift.",
    version="0.1.0",
)
# One module-level environment instance, created at import time. Every
# endpoint in this module reads and mutates this same object, so all HTTP
# clients share a single episode.
env = SchemaShiftEnvironment()


# ─────────────────────────────────────────────────────────────────
# Request / response models
# ─────────────────────────────────────────────────────────────────

class ResetRequest(BaseModel):
    # Request body for POST /reset.
    # Documented with comments rather than a class docstring: pydantic copies
    # a model's docstring into its generated JSON schema description.
    task_id: str  # forwarded to env.reset() as the first argument
    seed: int = 0  # forwarded to env.reset() as the second argument; defaults to 0


class StepRequest(BaseModel):
    # Request body for POST /step.
    # Documented with comments rather than a class docstring: pydantic copies
    # a model's docstring into its generated JSON schema description.
    action: Action  # forwarded to env.step() as the first argument
    # Forwarded to env.step() as the second argument. `ge=0` makes validation
    # reject negative values before the endpoint body runs.
    tokens_used: int = Field(
        default=0, ge=0,
        description="Tokens consumed by the agent on this step",
    )


# ─────────────────────────────────────────────────────────────────
# Endpoints
# ─────────────────────────────────────────────────────────────────

@app.get("/")
def root() -> dict:
    # Service banner: static name/description, the service version, and the
    # list of public routes.
    # Kept as comments rather than a docstring, since FastAPI publishes an
    # endpoint's docstring as its OpenAPI route description.
    return {
        "name": "SchemaShift",
        # Read from the app object so the banner cannot drift from the
        # version passed to FastAPI(...) when the app was constructed.
        "version": app.version,
        "description": "Adaptive tool use under schema drift",
        "endpoints": ["/health", "/reset", "/step", "/state", "/tasks", "/grader"],
    }


@app.get("/health")
def health() -> dict:
    # Liveness probe: always answers "ok" and does not touch the environment.
    # Kept as comments rather than a docstring, since FastAPI publishes an
    # endpoint's docstring as its OpenAPI route description.
    # The version comes from the app object so it stays in sync with the
    # value passed to FastAPI(...).
    return {"status": "ok", "version": app.version}


@app.post("/reset")
def reset(req: ResetRequest) -> dict:
    """Start new episode. Returns initial Observation as dict."""
    # Error mapping: a ValueError raised while resetting or serialising is
    # reported as 400 with its message; any other exception becomes a 500
    # with a "Reset failed: ..." detail.
    # NOTE(review): ValueError presumably signals an unknown task_id —
    # confirm against SchemaShiftEnvironment.reset.
    try:
        obs = env.reset(req.task_id, req.seed)
        return obs.model_dump()
    except ValueError as e:
        # `from e` keeps the original exception attached as __cause__ so the
        # real failure is visible in tracebacks and logs.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Reset failed: {e}") from e


@app.post("/step")
def step(req: StepRequest) -> dict:
    """Submit action, get observation + reward."""
    # Error mapping: a RuntimeError raised while stepping or serialising is
    # reported as 400 with its message; any other exception becomes a 500
    # with a "Step failed: ..." detail.
    # NOTE(review): RuntimeError presumably means no active episode or an
    # already-finished one — confirm against SchemaShiftEnvironment.step.
    try:
        obs, reward = env.step(req.action, req.tokens_used)
        return {
            "observation": obs.model_dump(),
            "reward": reward.model_dump(),
        }
    except RuntimeError as e:
        # `from e` keeps the original exception attached as __cause__ so the
        # real failure is visible in tracebacks and logs.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Step failed: {e}") from e


@app.get("/state")
def get_state() -> dict:
    """Return full current episode state for debugging."""
    # Read the shared attribute once: FastAPI runs plain `def` endpoints on a
    # threadpool, so the None check and the dump must use the same object
    # rather than re-reading env._state between them.
    state = env._state
    if state is None:
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
    return state.model_dump()


@app.get("/tasks")
def get_tasks() -> dict:
    """List all available scenarios with metadata."""
    preview_len = 120  # maximum description characters kept before "..."

    def _summarize(name, spec):
        # Build the public summary for one scenario. The description is
        # shortened first, then the remaining fields are copied as-is.
        text = spec["task_description"]
        preview = text[:preview_len] + ("..." if len(text) > preview_len else "")
        return {
            "task_id": name,
            "difficulty": spec["difficulty"],
            "max_steps": spec["max_steps"],
            "required_tools": spec["required_tools"],
            "description": preview,
        }

    catalog = [_summarize(name, spec) for name, spec in SCENARIOS.items()]
    return {"tasks": catalog, "count": len(catalog)}


@app.get("/grader")
def get_grader_breakdown() -> dict:
    """Return grader scoring for current episode state."""
    # Snapshot the shared state once: FastAPI runs plain `def` endpoints on a
    # threadpool, so re-reading env._state for every field could mix values
    # from two different state objects (or hit None after the check) if the
    # attribute is replaced mid-request.
    state = env._state
    if state is None:
        raise HTTPException(status_code=400, detail="No active episode.")
    # Grade the snapshot; the result is reported next to the running totals
    # stored on that same state object.
    reward = env._grader(state)
    return {
        "cumulative_reward": state.cumulative_reward,
        "current_breakdown": reward.model_dump(),
        "step": state.step,
        "max_steps": state.max_steps,
        "done": state.done,
    }
    }