Developer-Amar commited on
Commit
52eb44f
Β·
1 Parent(s): 519736d

Fix: make reset body optional, fix inference.py env var format

Browse files
Files changed (2) hide show
  1. inference.py +8 -4
  2. main.py +36 -1
inference.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  """
2
  Inference Script β€” SocraticEnv
3
  ================================
@@ -19,10 +23,10 @@ from dotenv import load_dotenv
19
  load_dotenv()
20
 
21
  # ── Config ────────────────────────────────────────────────
22
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
23
- MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.3")
24
- HF_TOKEN = os.getenv("HF_TOKEN", "")
25
- ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
26
 
27
  MAX_TURNS = 10
28
  TEMPERATURE = 0.3
 
1
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/novita/v3/openai")
2
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/llama-3.1-8b-instruct")
3
+ HF_TOKEN = os.getenv("HF_TOKEN")
4
+
5
  """
6
  Inference Script β€” SocraticEnv
7
  ================================
 
23
  load_dotenv()
24
 
25
  # ── Config ────────────────────────────────────────────────
26
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/novita/v3/openai")
27
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/llama-3.1-8b-instruct")
28
+ HF_TOKEN = os.getenv("HF_TOKEN")
29
+
30
 
31
  MAX_TURNS = 10
32
  TEMPERATURE = 0.3
main.py CHANGED
@@ -44,6 +44,16 @@ env = SocraticEnvironment()
44
  class ResetRequest(BaseModel):
45
  task_id: str = "factual_recall"
46
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  class StepRequest(BaseModel):
49
  response: str
@@ -141,7 +151,32 @@ def list_tasks():
141
 
142
 
143
  @app.post("/reset")
144
- def reset(req: ResetRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  """
146
  Start a new episode for the given task.
147
  Returns the first observation (tutor's opening question).
 
44
  class ResetRequest(BaseModel):
45
  task_id: str = "factual_recall"
46
 
47
+ @classmethod
48
+ def __get_validators__(cls):
49
+ yield cls._validate
50
+
51
+ @classmethod
52
+ def _validate(cls, v):
53
+ if v is None:
54
+ return cls()
55
+ return cls(**v) if isinstance(v, dict) else v
56
+
57
 
58
  class StepRequest(BaseModel):
59
  response: str
 
151
 
152
 
153
  @app.post("/reset")
154
+ def reset(req: Optional[ResetRequest] = None):
155
+ """
156
+ Start a new episode for the given task.
157
+ Returns the first observation (tutor's opening question).
158
+ Accepts empty body β€” defaults to factual_recall.
159
+ """
160
+ if req is None:
161
+ req = ResetRequest()
162
+
163
+ valid_tasks = [
164
+ "factual_recall", "socratic_dialogue", "misconception_trap",
165
+ "debate_mode", "analogy_challenge"
166
+ ]
167
+ if req.task_id not in valid_tasks:
168
+ raise HTTPException(
169
+ status_code=400,
170
+ detail=f"Invalid task_id '{req.task_id}'. Choose from: {valid_tasks}",
171
+ )
172
+ try:
173
+ obs = env.reset(req.task_id)
174
+ return {
175
+ "observation": obs.model_dump(),
176
+ "message": f"Episode started for task: {req.task_id}",
177
+ }
178
+ except Exception as e:
179
+ raise HTTPException(status_code=500, detail=str(e))
180
  """
181
  Start a new episode for the given task.
182
  Returns the first observation (tutor's opening question).