jflo commited on
Commit
64bc6bc
Β·
verified Β·
1 Parent(s): d8610aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -17,7 +17,7 @@ import logging
17
  from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
18
  from inference import (
19
  predict_post, decode_post_predictions, build_post_prompt,
20
- predict_pre, decode_pre_predictions, build_pre_prompt,
21
  )
22
 
23
  # ─────────────────────────────────────────────
@@ -97,7 +97,7 @@ async def lifespan(app: FastAPI):
97
  # ── Post-workout model ────────────────────────────────────
98
  logger.info("Loading post-workout model...")
99
  post_model, post_tokenizer = load_post_model(
100
- model_path=os.getenv("MODEL_PATH", "post_best_overall_model.pt"),
101
  device=device,
102
  )
103
  app_state["post_model"] = post_model
@@ -198,7 +198,7 @@ def health():
198
  }
199
 
200
 
201
- @app.post("/post-classify", response_model=PostWorkoutSessionResponse,
202
  dependencies=[Depends(verify_key)])
203
  def post_classify_session(req: PostWorkoutRequest):
204
  """
@@ -255,7 +255,7 @@ def post_classify_session(req: PostWorkoutRequest):
255
  )
256
 
257
 
258
- @app.post("/post-classify/labels-only", response_model=PostWorkoutBertLabels,
259
  dependencies=[Depends(verify_key)])
260
  def post_classify_labels_only(req: PostWorkoutRequest):
261
  """
@@ -308,8 +308,12 @@ class PreWorkoutBertLabels(BaseModel):
308
 
309
 
310
  class PreWorkoutResponse(BaseModel):
311
- bert_labels: PreWorkoutBertLabels
312
- workout_plan: Optional[str] = None
 
 
 
 
313
 
314
 
315
  # ─────────────────────────────────────────────
@@ -348,7 +352,7 @@ def pre_classify_session(req: PreWorkoutRequest):
348
  )
349
 
350
  # ── Step 3: Optionally generate Claude workout plan ───────
351
- workout_plan = None
352
  if req.generate_plan:
353
  prompt = build_pre_prompt(
354
  bert_labels=bert_labels,
@@ -364,14 +368,19 @@ def pre_classify_session(req: PreWorkoutRequest):
364
  max_tokens=800,
365
  messages=[{"role": "user", "content": prompt}],
366
  )
367
- workout_plan = message.content[0].text
 
368
  except Exception as e:
369
  logger.error(f"Claude API error (pre-workout): {e}")
370
- workout_plan = None
371
 
372
  return PreWorkoutResponse(
373
  bert_labels=PreWorkoutBertLabels(**bert_labels),
374
- workout_plan=workout_plan,
 
 
 
 
375
  )
376
 
377
 
 
17
  from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
18
  from inference import (
19
  predict_post, decode_post_predictions, build_post_prompt,
20
+ predict_pre, decode_pre_predictions, build_pre_prompt, parse_workout_plan,
21
  )
22
 
23
  # ─────────────────────────────────────────────
 
97
  # ── Post-workout model ────────────────────────────────────
98
  logger.info("Loading post-workout model...")
99
  post_model, post_tokenizer = load_post_model(
100
+ model_path=os.getenv("MODEL_PATH", "best_overall_model.pt"),
101
  device=device,
102
  )
103
  app_state["post_model"] = post_model
 
198
  }
199
 
200
 
201
+ @app.post("/classify", response_model=PostWorkoutSessionResponse,
202
  dependencies=[Depends(verify_key)])
203
  def post_classify_session(req: PostWorkoutRequest):
204
  """
 
255
  )
256
 
257
 
258
+ @app.post("/classify/labels-only", response_model=PostWorkoutBertLabels,
259
  dependencies=[Depends(verify_key)])
260
  def post_classify_labels_only(req: PostWorkoutRequest):
261
  """
 
308
 
309
 
310
  class PreWorkoutResponse(BaseModel):
311
+ bert_labels: PreWorkoutBertLabels
312
+ warm_up: Optional[str] = None
313
+ main_workout: Optional[str] = None
314
+ cool_down: Optional[str] = None
315
+ coaching_note: Optional[str] = None
316
+ raw_plan: Optional[str] = None # full unmodified response β€” fallback
317
 
318
 
319
  # ─────────────────────────────────────────────
 
352
  )
353
 
354
  # ── Step 3: Optionally generate Claude workout plan ───────
355
+ parsed = {}
356
  if req.generate_plan:
357
  prompt = build_pre_prompt(
358
  bert_labels=bert_labels,
 
368
  max_tokens=800,
369
  messages=[{"role": "user", "content": prompt}],
370
  )
371
+ raw_plan = message.content[0].text
372
+ parsed = parse_workout_plan(raw_plan)
373
  except Exception as e:
374
  logger.error(f"Claude API error (pre-workout): {e}")
375
+ parsed = {}
376
 
377
  return PreWorkoutResponse(
378
  bert_labels=PreWorkoutBertLabels(**bert_labels),
379
+ warm_up=parsed.get("warm_up"),
380
+ main_workout=parsed.get("main_workout"),
381
+ cool_down=parsed.get("cool_down"),
382
+ coaching_note=parsed.get("coaching_note"),
383
+ raw_plan=parsed.get("raw"),
384
  )
385
 
386