jflo commited on
Commit
df41ebe
Β·
verified Β·
1 Parent(s): b3e94c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -16,7 +16,7 @@ import logging
16
 
17
  from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
18
  from inference import (
19
- predict_post, decode_post_predictions, build_post_prompt,
20
  predict_pre, decode_pre_predictions, build_pre_prompt, parse_workout_plan,
21
  )
22
 
@@ -97,7 +97,7 @@ async def lifespan(app: FastAPI):
97
  # ── Post-workout model ────────────────────────────────────
98
  logger.info("Loading post-workout model...")
99
  post_model, post_tokenizer = load_post_model(
100
- model_path=os.getenv("MODEL_PATH", "post_best_overall_model.pt"),
101
  device=device,
102
  )
103
  app_state["post_model"] = post_model
@@ -171,8 +171,11 @@ class PostWorkoutBertLabels(BaseModel):
171
 
172
 
173
  class PostWorkoutSessionResponse(BaseModel):
174
- bert_labels: PostWorkoutBertLabels
175
- debrief: Optional[str] = None
 
 
 
176
 
177
 
178
  class HealthResponse(BaseModel):
@@ -229,7 +232,7 @@ def post_classify_session(req: PostWorkoutRequest):
229
  )
230
 
231
  # ── Step 3: Optionally generate Claude debrief ────────────
232
- debrief = None
233
  if req.generate_debrief:
234
  prompt = build_post_prompt(
235
  bert_labels=bert_labels,
@@ -244,14 +247,18 @@ def post_classify_session(req: PostWorkoutRequest):
244
  max_tokens=400,
245
  messages=[{"role": "user", "content": prompt}],
246
  )
247
- debrief = message.content[0].text
 
248
  except Exception as e:
249
  logger.error(f"Claude API error (post-workout): {e}")
250
- debrief = None
251
 
252
  return PostWorkoutSessionResponse(
253
  bert_labels=PostWorkoutBertLabels(**bert_labels),
254
- debrief=debrief,
 
 
 
255
  )
256
 
257
 
 
16
 
17
  from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
18
  from inference import (
19
+ predict_post, decode_post_predictions, build_post_prompt, parse_debrief,
20
  predict_pre, decode_pre_predictions, build_pre_prompt, parse_workout_plan,
21
  )
22
 
 
97
  # ── Post-workout model ────────────────────────────────────
98
  logger.info("Loading post-workout model...")
99
  post_model, post_tokenizer = load_post_model(
100
+ model_path=os.getenv("MODEL_PATH", "best_overall_model.pt"),
101
  device=device,
102
  )
103
  app_state["post_model"] = post_model
 
171
 
172
 
173
  class PostWorkoutSessionResponse(BaseModel):
174
+ bert_labels: PostWorkoutBertLabels
175
+ acknowledgement: Optional[str] = None
176
+ highlights: Optional[str] = None
177
+ next_session: Optional[str] = None
178
+ raw_debrief: Optional[str] = None # full unmodified response β€” fallback
179
 
180
 
181
  class HealthResponse(BaseModel):
 
232
  )
233
 
234
  # ── Step 3: Optionally generate Claude debrief ────────────
235
+ parsed = {}
236
  if req.generate_debrief:
237
  prompt = build_post_prompt(
238
  bert_labels=bert_labels,
 
247
  max_tokens=400,
248
  messages=[{"role": "user", "content": prompt}],
249
  )
250
+ raw_debrief = message.content[0].text
251
+ parsed = parse_debrief(raw_debrief)
252
  except Exception as e:
253
  logger.error(f"Claude API error (post-workout): {e}")
254
+ parsed = {}
255
 
256
  return PostWorkoutSessionResponse(
257
  bert_labels=PostWorkoutBertLabels(**bert_labels),
258
+ acknowledgement=parsed.get("acknowledgement"),
259
+ highlights=parsed.get("highlights"),
260
+ next_session=parsed.get("next_session"),
261
+ raw_debrief=parsed.get("raw"),
262
  )
263
 
264