# abdur0001's picture
# fix: add mcp for hf
# b562fdb
"""
FastAPI server for the SRE Incident Response OpenEnv environment.
"""
import sys
import os
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, List, Optional
from models import Action, Observation, State
from env.environment import IncidentResponseEnv
from tasks import SCENARIOS
# ASGI application object; every route in this module is registered on it.
app = FastAPI(
    title="SRE Incident Response Environment",
    description="An OpenEnv environment for training AI agents on production incident response.",
    version="1.0.0",
)
# Single module-level environment instance shared by all request handlers.
env = IncidentResponseEnv()
# ── Request/Response models ────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Request body accepted by the reset endpoints."""

    # Forwarded to ``env.reset`` as ``task_id``; defaults to "easy".
    task_id: str = "easy"
    # Forwarded to ``env.reset`` as ``seed``; defaults to 0.
    seed: int = 0
class ResetResponse(BaseModel):
    """Response body returned by the reset endpoints."""

    # First value of the pair returned by ``env.reset``.
    observation: Observation
    # Second value of the pair returned by ``env.reset``; clients send it
    # back in ``StepRequest.session_id`` and in the state URL.
    session_id: str
class StepRequest(BaseModel):
    """Request body accepted by the step endpoints."""

    # Session to advance; passed to ``env.step`` as its first argument.
    session_id: str
    # Action to apply; passed to ``env.step`` as its second argument.
    action: Action
class StepResponse(BaseModel):
    """Response body returned by the step endpoints.

    The four fields mirror, in order, the 4-tuple returned by ``env.step``.
    """

    observation: Observation
    reward: float
    done: bool
    # Untyped dict; the handler copies the mapping returned by ``env.step``.
    info: Dict
class TaskInfo(BaseModel):
    """One entry of the task listing returned by the tasks endpoints."""

    # Key of the scenario in ``SCENARIOS``.
    task_id: str
    name: str
    difficulty: str
    max_steps: int
    # Filled from the scenario's ``incident_summary`` attribute.
    description: str
# ── OpenEnv spec endpoints ─────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe: always report a healthy status."""
    payload = {"status": "healthy"}
    return payload
@app.get("/metadata")
def metadata():
    """Return static metadata describing this environment.

    Returns:
        A dict with the environment ``name``, a human-readable
        ``description`` and the ``version`` string.
    """
    return {
        "name": "sre-incident-response",
        # The separator is an em dash; it was previously stored as the
        # mis-decoded byte sequence and served to clients as mojibake.
        "description": "SRE Incident Response environment — train AI agents to diagnose and fix production incidents",
        "version": "1.0.0",
    }
@app.get("/schema")
def schema():
    """Return the JSON schemas of the Action, Observation and State models.

    Returns:
        A dict keyed by ``"action"``, ``"observation"`` and ``"state"``
        whose values are the result of each model's ``model_json_schema()``.
    """
    models_by_key = {
        "action": Action,
        "observation": Observation,
        "state": State,
    }
    return {key: model.model_json_schema() for key, model in models_by_key.items()}
@app.get("/state")
def state_no_session():
    """Return the state of the last session in ``env.sessions``.

    When there are no sessions, a default-constructed ``State`` is returned
    instead. "Last" means the final key yielded by ``env.sessions.keys()``
    (the most recently created session, assuming ``env.sessions`` is an
    insertion-ordered dict — confirm against the environment).
    """
    if not env.sessions:
        return State()
    # Unpack to pick the final key of the mapping.
    *_, newest_sid = env.sessions.keys()
    return env.state(newest_sid)
@app.post("/mcp")
def mcp_endpoint(body: Optional[dict] = None):
    """Minimal MCP JSON-RPC endpoint for OpenEnv spec compliance.

    Args:
        body: Parsed JSON-RPC request object. May be omitted, in which case
            it is treated as an empty request.

    Returns:
        A JSON-RPC 2.0 response dict that echoes the request ``id`` (``1``
        when the request carries none). The ``initialize`` method receives
        protocol version, server info and empty capabilities; every other
        method receives an empty ``result``.
    """
    # A ``{}`` default would be a single dict object shared by every call;
    # default to None and normalise here instead.
    request = body if body is not None else {}
    method = request.get("method", "")
    req_id = request.get("id", 1)

    if method == "initialize":
        result = {
            "protocolVersion": "2024-11-05",
            "serverInfo": {"name": "sre-incident-response", "version": "1.0.0"},
            "capabilities": {},
        }
    else:
        # Unknown methods are acknowledged with an empty result.
        result = {}

    return {
        "jsonrpc": "2.0",
        "id": req_id,
        "result": result,
    }
# ── Endpoints ──────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Return a short index of this service.

    Returns:
        A dict with the service ``name``, ``version`` and the list of
        primary ``endpoints``. The ``/openenv/...`` aliases are not listed.
    """
    return {
        "name": "SRE Incident Response Environment",
        "version": "1.0.0",
        # Kept in sync with the routes registered in this module; "/state"
        # and "/mcp" were previously missing from the listing.
        "endpoints": [
            "/reset",
            "/step",
            "/state",
            "/state/{session_id}",
            "/tasks",
            "/health",
            "/metadata",
            "/schema",
            "/mcp",
        ],
    }
@app.post("/reset", response_model=ResetResponse)
def reset(request: ResetRequest):
    """Start a new session and return its first observation.

    Args:
        request: Task id and seed, both forwarded to ``env.reset``.

    Returns:
        A ``ResetResponse`` holding the observation and session id
        returned by ``env.reset``.

    Raises:
        HTTPException: 400 with the error text when a ``ValueError`` is
            raised while resetting or while building the response.
    """
    try:
        first_obs, sid = env.reset(task_id=request.task_id, seed=request.seed)
        # Built inside the try so a ValueError raised here is mapped to
        # the same 400 response as one raised by env.reset.
        response = ResetResponse(observation=first_obs, session_id=sid)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return response
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest):
    """Apply one action to a session and return the outcome.

    Args:
        request: Session id and action, both forwarded to ``env.step``.

    Returns:
        A ``StepResponse`` built from the ``(observation, reward, done,
        info)`` tuple returned by ``env.step``.

    Raises:
        HTTPException: 400 with the error text when a ``ValueError`` is
            raised while stepping or while building the response.
    """
    try:
        observation, reward, done, info = env.step(request.session_id, request.action)
        # Shallow copy of the info mapping into a plain dict; values are
        # passed through unchanged.
        info_copy = dict(info.items())
        return StepResponse(
            observation=observation,
            reward=reward,
            done=done,
            info=info_copy,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
@app.get("/state/{session_id}", response_model=State)
def state(session_id: str):
    """Return the current state of the given session.

    Args:
        session_id: Session identifier taken from the URL path.

    Raises:
        HTTPException: 404 with the error text when ``env.state`` raises
            ``ValueError``.
    """
    try:
        current = env.state(session_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc))
    return current
@app.get("/tasks", response_model=List[TaskInfo])
def tasks():
    """List every scenario in ``SCENARIOS`` as a ``TaskInfo`` entry.

    Returns:
        One ``TaskInfo`` per scenario, in the iteration order of
        ``SCENARIOS``; ``description`` carries the scenario's
        ``incident_summary``.
    """
    return [
        TaskInfo(
            task_id=scenario_id,
            name=spec.name,
            difficulty=spec.difficulty,
            max_steps=spec.max_steps,
            description=spec.incident_summary,
        )
        for scenario_id, spec in SCENARIOS.items()
    ]
# ── OpenEnv-prefixed aliases ───────────────────────────────────────────
@app.post("/openenv/reset", response_model=ResetResponse)
def openenv_reset(request: ResetRequest):
    """Alias of ``/reset`` under the ``/openenv`` prefix; delegates to ``reset``."""
    return reset(request=request)
@app.post("/openenv/step", response_model=StepResponse)
def openenv_step(request: StepRequest):
    """Alias of ``/step`` under the ``/openenv`` prefix; delegates to ``step``."""
    return step(request=request)
@app.get("/openenv/state/{session_id}", response_model=State)
def openenv_state(session_id: str):
    """Alias of ``/state/{session_id}`` under the ``/openenv`` prefix; delegates to ``state``."""
    return state(session_id=session_id)
@app.get("/openenv/tasks", response_model=List[TaskInfo])
def openenv_tasks():
    """Alias of ``/tasks`` under the ``/openenv`` prefix; delegates to ``tasks``."""
    task_list = tasks()
    return task_list
# ── Main ───────────────────────────────────────────────────────────────
def main():
    """Serve ``app`` with uvicorn on all interfaces.

    The port is read from the ``PORT`` environment variable and defaults
    to 8000 when it is unset.
    """
    # Imported here so uvicorn is only required when running as a script.
    from uvicorn import run as serve

    listen_port = int(os.environ.get("PORT", "8000"))
    serve(app, host="0.0.0.0", port=listen_port)


if __name__ == "__main__":
    main()