Geetesh
fix: add /metadata /schema /mcp endpoints for OpenEnv runtime validation
b060493
"""
FastAPI server for the ClinicalTrialEnv OpenEnv environment.
Endpoints (OpenEnv spec compliant):
GET /health -> {"status": "healthy"}
GET /metadata -> name, description, version, tasks
GET /schema -> action, observation, state schemas
GET /tasks -> list tasks with graders
POST /reset -> reset episode, return initial observation
POST /step -> take action, return step result
GET /state -> return current internal state
POST /mcp -> JSON-RPC 2.0 endpoint (OpenEnv runtime check)
"""
from __future__ import annotations
import os
import sys
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
sys.path.insert(0, os.path.dirname(__file__))
from environment import ClinicalTrialEnv
from models import ClinicalTrialAction
from tasks import TASKS
# The ASGI application. Title, description and version are FastAPI app
# metadata (shown in its generated OpenAPI docs).
app = FastAPI(
    version="1.0.0",
    title="ClinicalTrialEnv",
    description="OpenEnv environment for Clinical Trial Protocol Review.",
)

# Fully permissive CORS: every origin, HTTP method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Single process-wide environment shared by every episode endpoint. It is
# ``None`` until POST /reset creates one, so this server hosts one episode at
# a time.
_env: Optional[ClinicalTrialEnv] = None
class ResetRequest(BaseModel):
    """Request body for POST /reset.

    Attributes:
        task: Key of the task to start; ``/reset`` rejects names that are not
            in ``TASKS``. Defaults to the eligibility-screening task.
    """

    task: str = "eligibility_screening"
class StepRequest(BaseModel):
    """Request body for POST /step: the agent's submission for one turn.

    Attributes:
        findings: Free-form finding dicts. They are passed as-is to
            ``ClinicalTrialAction``, which presumably validates their
            structure (TODO confirm against models.py).
        rationale: Free-text justification for the findings.
    """

    # Pydantic copies field defaults per instance, so this list literal is
    # not shared between requests.
    findings: List[Dict[str, Any]] = []
    rationale: str = ""
class ResetResponse(BaseModel):
    """Response of POST /reset.

    Uses the same observation/reward/done/info shape as ``StepResponse`` so
    clients can treat both uniformly.

    Attributes:
        observation: ``model_dump()`` of the environment's initial observation.
        reward: Reward reported by the environment's reset result.
        done: Episode-finished flag from the reset result.
        info: Extra diagnostics from the reset result.
    """

    observation: Dict[str, Any]
    reward: float
    done: bool
    info: Dict[str, Any]
class StepResponse(BaseModel):
    """Response of POST /step.

    Attributes:
        observation: ``model_dump()`` of the observation after the action.
        reward: Reward produced by this step.
        done: Whether the episode has ended after this step.
        info: Extra diagnostics from the environment's step result.
    """

    observation: Dict[str, Any]
    reward: float
    done: bool
    info: Dict[str, Any]
# ---------------------------------------------------------------------------
# OpenEnv required runtime endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    """Liveness probe.

    OpenEnv only requires ``{"status": "healthy"}``. The environment name and
    version are extra, informational keys.
    """
    return dict(status="healthy", environment="ClinicalTrialEnv", version="1.0.0")
@app.get("/metadata")
def metadata():
    """OpenEnv spec: return the environment's name and description.

    Also reports the version and the task catalogue. The catalogue is taken
    from ``list_tasks()`` (the GET /tasks handler) rather than from a second
    copy of the same comprehension, so /metadata and /tasks can never
    disagree about which tasks exist or how they are described.
    """
    return {
        "name": "clinical-trial-env",
        "description": (
            "OpenEnv environment for Clinical Trial Protocol Review. "
            "AI agents act as medical monitors — screening patients for eligibility "
            "violations, classifying adverse event severity (CTCAE), and reviewing "
            "protocol amendments for GCP/ICH compliance."
        ),
        "version": "1.0.0",
        # list_tasks is defined further down this module; the name is resolved
        # when this handler runs, so the forward reference is safe.
        "tasks": list_tasks()["tasks"],
    }
@app.get("/schema")
def schema():
    """OpenEnv spec: describe the action, observation and state payloads.

    The schemas are hand-written JSON-Schema-style dicts. NOTE(review): they
    are not derived from the model classes, so keep them in sync with
    ``ClinicalTrialAction`` and the environment's observation/state by hand.
    """
    finding_types = [
        "protocol_deviation",
        "adverse_event",
        "eligibility_violation",
        "safety_concern",
        "amendment_recommendation",
    ]
    severity_levels = ["critical", "major", "minor", "informational"]

    # Shape of a single entry in the ``findings`` array of an action.
    finding_item = {
        "type": "object",
        "properties": {
            "finding_type": {"type": "string", "enum": finding_types},
            "severity": {"type": "string", "enum": severity_levels},
            "subject_id": {"type": "string", "nullable": True},
            "description": {"type": "string"},
            "recommendation": {"type": "string"},
        },
        "required": ["finding_type", "severity", "description", "recommendation"],
    }

    action_schema = {
        "type": "object",
        "description": "Structured clinical review findings",
        "properties": {
            "findings": {"type": "array", "items": finding_item},
            "rationale": {"type": "string"},
        },
        "required": ["findings", "rationale"],
    }

    observation_fields = {
        "task_name": {"type": "string"},
        "protocol_summary": {"type": "string"},
        "patient_records": {"type": "array"},
        "adverse_events": {"type": "array"},
        "protocol_text": {"type": "string"},
        "step": {"type": "integer"},
        "feedback": {"type": "string"},
        "partial_score": {"type": "number"},
    }
    observation_schema = {
        "type": "object",
        "description": "Clinical trial data presented to the agent",
        "properties": observation_fields,
    }

    state_fields = {
        "task_name": {"type": "string"},
        "step": {"type": "integer"},
        "done": {"type": "boolean"},
        "current_score": {"type": "number"},
        "n_findings_accumulated": {"type": "integer"},
        "history": {"type": "array"},
        "elapsed_seconds": {"type": "number"},
    }
    state_schema = {
        "type": "object",
        "description": "Internal environment state",
        "properties": state_fields,
    }

    return {
        "action": action_schema,
        "observation": observation_schema,
        "state": state_schema,
    }
@app.post("/mcp")
def mcp(payload: Optional[Dict[str, Any]] = None):
    """JSON-RPC 2.0 endpoint required by the OpenEnv runtime validator.

    Handled methods:

    * ``initialize`` -- MCP handshake (protocol version, capabilities,
      server info).
    * ``tools/list`` -- advertises the ``reset`` and ``step`` tools.
    * ``tools/call`` -- invokes the ``reset`` / ``step`` handlers with
      ``params.arguments`` and returns their JSON-serialized response as MCP
      text content. Environment errors (unknown task, episode already done,
      no episode yet, ...) are returned with ``isError: true``; an unknown
      tool name or malformed arguments yield a JSON-RPC ``-32602`` error.
    * anything else (including an empty body) -- a summary of the
      environment and its task names.

    Args:
        payload: The JSON-RPC request object. Optional so that a POST without
            a body still receives a well-formed response.

    Returns:
        A JSON-RPC 2.0 response dict echoing the request ``id`` (``1`` when
        the request carries none).
    """
    from pydantic import ValidationError

    # ``None`` instead of a ``{}`` literal default avoids a shared mutable
    # default; a missing body is treated as an empty request.
    rpc = payload or {}
    method = rpc.get("method", "")
    req_id = rpc.get("id", 1)

    def _result(result: Dict[str, Any]) -> Dict[str, Any]:
        """Wrap *result* in a JSON-RPC 2.0 success envelope."""
        return {"jsonrpc": "2.0", "id": req_id, "result": result}

    def _error(code: int, message: str) -> Dict[str, Any]:
        """Build a JSON-RPC 2.0 error envelope."""
        return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}

    if method == "initialize":
        return _result(
            {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "serverInfo": {"name": "clinical-trial-env", "version": "1.0.0"},
            }
        )

    if method == "tools/list":
        return _result(
            {
                "tools": [
                    {
                        "name": "reset",
                        "description": "Reset the environment to a new task episode",
                        "inputSchema": {"type": "object", "properties": {"task": {"type": "string"}}},
                    },
                    {
                        "name": "step",
                        "description": "Submit findings and advance the episode",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "findings": {"type": "array"},
                                "rationale": {"type": "string"},
                            },
                        },
                    },
                ]
            }
        )

    if method == "tools/call":
        # Previously this fell through to the generic summary, so the tools
        # advertised by tools/list could never actually be invoked.
        params = rpc.get("params")
        if not isinstance(params, dict):
            params = {}
        tool = params.get("name")
        arguments = params.get("arguments")
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, dict):
            return _error(-32602, "'params.arguments' must be an object.")

        tools = {"reset": (reset, ResetRequest), "step": (step, StepRequest)}
        # isinstance guard: a non-string (possibly unhashable) name from the
        # JSON body must not raise inside the membership test.
        if not isinstance(tool, str) or tool not in tools:
            return _error(-32602, f"Unknown tool '{tool}'. Available: {list(tools)}")
        handler, request_model = tools[tool]

        try:
            tool_request = request_model(**arguments)
        except ValidationError as exc:
            return _error(-32602, f"Invalid arguments for tool '{tool}': {exc}")

        try:
            outcome = handler(tool_request)
        except HTTPException as exc:
            # Environment-level failures are tool errors, not protocol errors.
            return _result({"content": [{"type": "text", "text": str(exc.detail)}], "isError": True})

        return _result({"content": [{"type": "text", "text": outcome.model_dump_json()}], "isError": False})

    return _result(
        {
            "environment": "clinical-trial-env",
            "version": "1.0.0",
            "tasks": list(TASKS.keys()),
        }
    )
@app.get("/tasks")
def list_tasks():
    """List every task in ``TASKS`` with its difficulty, description and step budget.

    ``has_grader`` and ``score_range`` are the same constant values for every
    task. NOTE(review): this assumes every task ships a grader scoring in
    [0.0, 1.0]; confirm against tasks.py.
    """

    def summarize(task: Dict[str, Any]) -> Dict[str, Any]:
        """Project one TASKS entry onto the public task-summary shape."""
        return {
            "name": task["name"],
            "difficulty": task["difficulty"],
            "description": task["description"],
            "max_steps": task["max_steps"],
            "has_grader": True,
            "score_range": [0.0, 1.0],
        }

    return {"tasks": [summarize(task) for task in TASKS.values()]}
@app.post("/reset")
def reset(req: Optional[ResetRequest] = None) -> ResetResponse:
    """Start a new episode and return its initial observation.

    Args:
        req: Optional body selecting the task. A missing body or an empty
            ``task`` falls back to ``ResetRequest``'s default task.

    Returns:
        The initial observation, reward, done flag and info of the new episode.

    Raises:
        HTTPException: 400 if the requested task is not a key of ``TASKS``.
    """
    global _env
    if req is None:
        req = ResetRequest()
    # An explicit empty string is treated like "not provided". Reuse the
    # model's own default instead of repeating the literal here.
    task = req.task or ResetRequest().task
    if task not in TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task '{task}'. Available: {list(TASKS.keys())}")
    # Publish the new environment only after reset() succeeds. Otherwise a
    # failing reset would discard the current episode and leave /step running
    # against an environment that was never reset.
    env = ClinicalTrialEnv(task_name=task)
    result = env.reset()
    _env = env
    return ResetResponse(
        observation=result.observation.model_dump(),
        reward=result.reward,
        done=result.done,
        info=result.info,
    )
@app.post("/step")
def step(req: StepRequest) -> StepResponse:
    """Apply the agent's findings to the current episode.

    Args:
        req: The findings and rationale to submit.

    Returns:
        The observation, reward, done flag and info produced by the step.

    Raises:
        HTTPException: 400 if ``/reset`` has not been called yet or the
            current episode is already finished; 422 if the findings cannot
            be converted into a ``ClinicalTrialAction``.
    """
    from pydantic import ValidationError

    # Read the shared environment once (no ``global`` needed for reading) so
    # the checks and the step below all act on the same instance.
    env = _env
    if env is None:
        raise HTTPException(status_code=400, detail="Call /reset first.")
    # NOTE(review): relies on the environment's private ``_done`` flag.
    if env._done:
        raise HTTPException(status_code=400, detail="Episode done. Call /reset.")
    try:
        action = ClinicalTrialAction(findings=req.findings, rationale=req.rationale)
    except ValidationError as exc:
        # StepRequest accepts free-form dicts, so malformed findings are only
        # detected here. Report them as a client error instead of an HTTP 500.
        # (Assumes ClinicalTrialAction is a pydantic model; if not, this
        # handler is simply never hit.)
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    result = env.step(action)
    return StepResponse(
        observation=result.observation.model_dump(),
        reward=result.reward,
        done=result.done,
        info=result.info,
    )
@app.get("/state")
def state() -> Dict[str, Any]:
    """Return the current environment's internal state.

    Before the first POST /reset there is no environment, so a
    ``{"status": "not_initialized"}`` placeholder is returned instead.
    """
    if _env is not None:
        return _env.state()
    return {"status": "not_initialized"}
def main():
    """Serve the app with uvicorn on all interfaces (0.0.0.0).

    The port is read from the ``PORT`` environment variable and defaults to
    7860.
    """
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))


if __name__ == "__main__":
    main()