Spaces:
Running
Running
File size: 4,324 Bytes
4b39830 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import os
from dotenv import load_dotenv
load_dotenv()
from server.env import MindReadEnv
from server.models import (
MindReadObservation,
StepResult,
SubmitResult,
TaskMeta,
Secret,
GenerateSecretRequest,
HealthResponse,
AskQuestionAction,
)
from server.secret_generator import generate_secret
# Single module-level environment instance. Every endpoint below reads and
# mutates episode/secret state through this one object, so that state
# presumably lives in this process's memory (not shared across workers) —
# confirm in MindReadEnv.
# NOTE(review): the plain `def` handlers may run concurrently in FastAPI's
# threadpool; confirm MindReadEnv is safe for concurrent access.
env = MindReadEnv()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook registered on the FastAPI app.

    Before the server starts accepting requests, this calls the reward
    module's ``get_embedder()`` once so the embedding model is loaded up
    front (presumably cached by ``get_embedder`` for later calls — confirm).
    Nothing is released on shutdown.
    """
    # The import happens inside the hook, so server.reward is loaded when the
    # app starts up rather than when this module is imported.
    import server.reward as reward

    reward.get_embedder()  # return value unused; called only to warm up
    yield
# Long-form service description shown in the generated OpenAPI docs.
_APP_DESCRIPTION = (
    "Interactive Theory of Mind training environment. "
    "An LLM agent (Detective) must infer a hidden mental state "
    "by asking strategic questions to an Oracle. "
    "Trains functional theory of mind — the ability to adapt questioning "
    "strategy based on Oracle responses."
)

# The ASGI application; `lifespan` warms up the embedder before serving.
app = FastAPI(
    lifespan=lifespan,
    title="MindRead: Theory of Mind RL Environment",
    version="1.0.0",
    description=_APP_DESCRIPTION,
)
@app.get("/health", response_model=HealthResponse)
def health():
    """Liveness check that reports static service metadata.

    The version is read from the FastAPI app object rather than repeated as
    a literal, so it cannot drift from the version advertised in the
    OpenAPI schema.
    """
    return HealthResponse(
        status="ok",
        version=app.version,
        # NOTE(review): hard-coded label; confirm it matches the Oracle model
        # the environment actually calls.
        oracle_backend="groq/llama-3.1-8b-instant",
    )
@app.get("/tasks", response_model=list[TaskMeta])
def get_tasks():
    """Return the task descriptors provided by the shared environment."""
    available = env.get_tasks()
    return available
class ResetRequest(BaseModel):
    """Request body for ``POST /reset``."""

    # Task to start a new episode for; forwarded to env.reset(task_id=...).
    task_id: str
    # Optional specific secret to use. When None, env.reset presumably picks
    # one itself — confirm in MindReadEnv.reset.
    secret_id: Optional[str] = None
@app.post("/reset", response_model=MindReadObservation)
def reset(request: ResetRequest):
    """Start a new episode and return its initial observation.

    Args:
        request: ``task_id`` plus an optional ``secret_id``; both are passed
            straight to ``env.reset``.

    Raises:
        HTTPException: 400 when the environment raises ``ValueError``,
            500 when it raises ``RuntimeError``. The original exception is
            chained as ``__cause__`` so server-side tracebacks keep it.
    """
    try:
        return env.reset(task_id=request.task_id, secret_id=request.secret_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except RuntimeError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
class StepRequest(BaseModel):
    """Request body for ``POST /step``."""

    # Episode id as returned by /reset (observation.episode_id) or
    # /generate_secret.
    episode_id: str
    # The Detective's action; the /step handler rejects anything whose
    # `action` field is not "ask_question" and requires a non-blank question.
    action: AskQuestionAction
@app.post("/step", response_model=StepResult)
def step(request: StepRequest):
    """Ask the Oracle one question within an existing episode.

    Only ``ask_question`` actions are accepted here; hypotheses go through
    ``/submit``. The question is whitespace-stripped before it is forwarded
    to ``env.step``.

    Raises:
        HTTPException: 400 for a non-``ask_question`` action, a missing or
            blank question, or a ``ValueError`` from the environment; 404
            when the environment raises ``KeyError``.
    """
    action = request.action
    if action.action != "ask_question":
        raise HTTPException(
            status_code=400,
            detail="Use /submit to submit a hypothesis. /step only accepts ask_question.",
        )
    # Normalize once: covers a missing (None) or whitespace-only question.
    question = (action.question or "").strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question must not be empty.")
    try:
        return env.step(request.episode_id, question)
    except KeyError as e:
        # str(KeyError("msg")) is "'msg'" (repr-quoted); expose the raw argument.
        raise HTTPException(
            status_code=404, detail=e.args[0] if e.args else str(e)
        ) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
class SubmitRequest(BaseModel):
    """Request body for ``POST /submit``."""

    # Episode the hypothesis is submitted for.
    episode_id: str
    # Free-text guess at the hidden mental state; the /submit handler rejects
    # blank values and strips whitespace before forwarding.
    hypothesis: str
    # Optional guess, presumably at the secret's category; forwarded
    # unchanged to env.submit.
    category_prediction: Optional[str] = None
@app.post("/submit", response_model=SubmitResult)
def submit(request: SubmitRequest):
    """Submit the Detective's hypothesis for an episode and return the result.

    The hypothesis is whitespace-stripped before being passed to
    ``env.submit``; ``category_prediction`` is forwarded unchanged.

    Raises:
        HTTPException: 400 for a blank hypothesis or a ``ValueError`` from
            the environment; 404 when the environment raises ``KeyError``.
    """
    # Normalize once and reuse the stripped value.
    hypothesis = (request.hypothesis or "").strip()
    if not hypothesis:
        raise HTTPException(status_code=400, detail="Hypothesis must not be empty.")
    try:
        return env.submit(
            episode_id=request.episode_id,
            hypothesis=hypothesis,
            category_prediction=request.category_prediction,
        )
    except KeyError as e:
        # str(KeyError("msg")) is "'msg'" (repr-quoted); expose the raw argument.
        raise HTTPException(
            status_code=404, detail=e.args[0] if e.args else str(e)
        ) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.get("/state/{episode_id}", response_model=MindReadObservation)
def get_state(episode_id: str):
    """Return the environment's current observation for ``episode_id``.

    Raises:
        HTTPException: 404 when the environment raises ``KeyError``.
    """
    try:
        return env.get_state(episode_id)
    except KeyError as e:
        # str(KeyError("msg")) is "'msg'" (repr-quoted); expose the raw argument.
        raise HTTPException(
            status_code=404, detail=e.args[0] if e.args else str(e)
        ) from e
@app.post("/generate_secret")
def generate_secret_endpoint(request: GenerateSecretRequest):
    """Generate a new secret, register it, and start an episode on it.

    Steps: call ``generate_secret`` with the requested category, difficulty
    and domain; validate the result into a ``Secret``; add it to the shared
    environment; then ``env.reset`` on that secret's task and id.

    Returns:
        ``{"secret": <raw generator output>, "episode_id": <new episode id>}``.

    Raises:
        HTTPException: 500 if any step fails. The failure is logged with its
            traceback and chained as ``__cause__``. Without this, converting
            it to an HTTPException would leave no server-side record.
    """
    try:
        secret_data = generate_secret(
            category=request.category,
            difficulty=request.difficulty,
            domain=request.domain,
        )
        secret = Secret(**secret_data)
        env.add_secret(secret)
        obs = env.reset(task_id=secret.task_id, secret_id=secret.id)
    except Exception as e:
        # Deliberately broad: this is the request boundary for a multi-step
        # pipeline. Log before translating so the traceback is not lost.
        import logging

        logging.getLogger(__name__).exception("Secret generation failed")
        # NOTE(review): str(e) is returned to the client; confirm upstream
        # errors cannot leak sensitive details (e.g. API responses).
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"secret": secret_data, "episode_id": obs.episode_id}
|