Mr66 committed on
Commit
4b39830
·
verified ·
1 Parent(s): e2ca55c

Upload server/main.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/main.py +145 -0
server/main.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import asynccontextmanager
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ from typing import Optional
5
+ import os
6
+
7
+ from dotenv import load_dotenv
8
+
9
+ load_dotenv()
10
+
11
+ from server.env import MindReadEnv
12
+ from server.models import (
13
+ MindReadObservation,
14
+ StepResult,
15
+ SubmitResult,
16
+ TaskMeta,
17
+ Secret,
18
+ GenerateSecretRequest,
19
+ HealthResponse,
20
+ AskQuestionAction,
21
+ )
22
+ from server.secret_generator import generate_secret
23
+
24
# Module-level environment instance; the route handlers in this module call
# its methods (get_tasks, reset, step, submit, get_state, add_secret).
env = MindReadEnv()
25
+
26
+
27
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook: warm up the embedding model before the app serves traffic.

    Everything before ``yield`` runs at startup; there is no code after the
    ``yield``, so nothing is done at shutdown.
    """
    # Imported here rather than at module top, as in the original layout.
    from server.reward import get_embedder as _warm_embedder

    # The return value is not used; the call is made only for its side effect
    # (presumably loading/caching the model -- confirm in server.reward).
    _warm_embedder()
    yield
33
+
34
+
35
# The FastAPI application object. `lifespan` wires in the startup hook; the
# remaining keyword arguments are descriptive metadata for the API.
app = FastAPI(
    lifespan=lifespan,
    title="MindRead: Theory of Mind RL Environment",
    version="1.0.0",
    description=(
        "Interactive Theory of Mind training environment. "
        "An LLM agent (Detective) must infer a hidden mental state "
        "by asking strategic questions to an Oracle. "
        "Trains functional theory of mind — the ability to adapt questioning "
        "strategy based on Oracle responses."
    ),
)
47
+
48
+
49
@app.get("/health", response_model=HealthResponse)
def health():
    # Health endpoint: always answers with a fixed payload (status, service
    # version and the oracle backend identifier); no environment call is made.
    payload = HealthResponse(
        status="ok",
        version="1.0.0",
        oracle_backend="groq/llama-3.1-8b-instant",
    )
    return payload
56
+
57
+
58
@app.get("/tasks", response_model=list[TaskMeta])
def get_tasks():
    # Return whatever task metadata the shared environment reports; the
    # response is serialised as a list of TaskMeta.
    tasks = env.get_tasks()
    return tasks
61
+
62
+
63
class ResetRequest(BaseModel):
    # Request body for POST /reset.

    # Required; passed to env.reset as task_id.
    task_id: str

    # Optional; passed to env.reset as secret_id (None when omitted).
    secret_id: Optional[str] = None
66
+
67
+
68
@app.post("/reset", response_model=MindReadObservation)
def reset(request: ResetRequest):
    # Start a new episode via env.reset and return its observation.
    #
    # Error mapping:
    #   ValueError   -> 400 (detail is the exception message)
    #   RuntimeError -> 500 (detail is the exception message)
    try:
        return env.reset(task_id=request.task_id, secret_id=request.secret_id)
    except ValueError as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except RuntimeError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
76
+
77
+
78
class StepRequest(BaseModel):
    # Request body for POST /step.

    # Identifier of the episode; passed to env.step.
    episode_id: str

    # The action payload; /step reads its `action` and `question` attributes.
    action: AskQuestionAction
81
+
82
+
83
@app.post("/step", response_model=StepResult)
def step(request: StepRequest):
    # Forward one question to env.step for the given episode.
    #
    # Error mapping:
    #   action.action != "ask_question"  -> 400
    #   missing / whitespace-only question -> 400
    #   KeyError from env.step   -> 404 (presumably an unknown episode_id)
    #   ValueError from env.step -> 400
    action = request.action
    if action.action != "ask_question":
        raise HTTPException(
            status_code=400,
            detail="Use /submit to submit a hypothesis. /step only accepts ask_question.",
        )
    if not action.question or not action.question.strip():
        raise HTTPException(status_code=400, detail="Question must not be empty.")
    try:
        # The question is sent with surrounding whitespace removed.
        return env.step(request.episode_id, action.question.strip())
    except KeyError as e:
        # str() of a single-argument KeyError is the repr of that argument
        # (it gains surrounding quotes), so unwrap it to keep the detail clean.
        detail = str(e.args[0]) if len(e.args) == 1 else str(e)
        raise HTTPException(status_code=404, detail=detail) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
99
+
100
+
101
class SubmitRequest(BaseModel):
    # Request body for POST /submit.

    # Identifier of the episode; passed to env.submit.
    episode_id: str

    # Required hypothesis text; /submit rejects it when empty or blank.
    hypothesis: str

    # Optional; passed through to env.submit unchanged (None when omitted).
    category_prediction: Optional[str] = None
105
+
106
+
107
@app.post("/submit", response_model=SubmitResult)
def submit(request: SubmitRequest):
    # Submit a hypothesis for an episode via env.submit.
    #
    # Error mapping:
    #   empty / whitespace-only hypothesis -> 400
    #   KeyError from env.submit   -> 404 (presumably an unknown episode_id)
    #   ValueError from env.submit -> 400
    if not request.hypothesis or not request.hypothesis.strip():
        raise HTTPException(status_code=400, detail="Hypothesis must not be empty.")
    try:
        return env.submit(
            episode_id=request.episode_id,
            # The hypothesis is sent with surrounding whitespace removed.
            hypothesis=request.hypothesis.strip(),
            category_prediction=request.category_prediction,
        )
    except KeyError as e:
        # str() of a single-argument KeyError is the repr of that argument
        # (it gains surrounding quotes), so unwrap it to keep the detail clean.
        detail = str(e.args[0]) if len(e.args) == 1 else str(e)
        raise HTTPException(status_code=404, detail=detail) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
121
+
122
+
123
@app.get("/state/{episode_id}", response_model=MindReadObservation)
def get_state(episode_id: str):
    # Return env.get_state(episode_id); a KeyError from the environment is
    # reported as 404 (presumably an unknown episode_id).
    try:
        return env.get_state(episode_id)
    except KeyError as e:
        # str() of a single-argument KeyError is the repr of that argument
        # (it gains surrounding quotes), so unwrap it to keep the detail clean.
        detail = str(e.args[0]) if len(e.args) == 1 else str(e)
        raise HTTPException(status_code=404, detail=detail) from e
129
+
130
+
131
@app.post("/generate_secret")
def generate_secret_endpoint(request: GenerateSecretRequest):
    # Generate a secret from the request's category/difficulty/domain, add it
    # to the environment, then reset the environment onto that secret.
    #
    # Returns a dict with two keys:
    #   "secret"     -> the raw data returned by generate_secret
    #   "episode_id" -> episode_id of the observation returned by env.reset
    #
    # Any unexpected failure is reported as 500 with the exception message.
    try:
        secret_data = generate_secret(
            category=request.category,
            difficulty=request.difficulty,
            domain=request.domain,
        )
        secret = Secret(**secret_data)
        env.add_secret(secret)

        obs = env.reset(task_id=secret.task_id, secret_id=secret.id)
        return {"secret": secret_data, "episode_id": obs.episode_id}
    except HTTPException:
        # Already a well-formed HTTP error; do not rewrap it as a generic 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e