Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,7 +17,7 @@ import logging
|
|
| 17 |
from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
|
| 18 |
from inference import (
|
| 19 |
predict_post, decode_post_predictions, build_post_prompt,
|
| 20 |
-
predict_pre, decode_pre_predictions, build_pre_prompt,
|
| 21 |
)
|
| 22 |
|
| 23 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -97,7 +97,7 @@ async def lifespan(app: FastAPI):
|
|
| 97 |
# ββ Post-workout model ββββββββββββββββββββββββββββββββββββ
|
| 98 |
logger.info("Loading post-workout model...")
|
| 99 |
post_model, post_tokenizer = load_post_model(
|
| 100 |
-
model_path=os.getenv("MODEL_PATH", "
|
| 101 |
device=device,
|
| 102 |
)
|
| 103 |
app_state["post_model"] = post_model
|
|
@@ -198,7 +198,7 @@ def health():
|
|
| 198 |
}
|
| 199 |
|
| 200 |
|
| 201 |
-
@app.post("/
|
| 202 |
dependencies=[Depends(verify_key)])
|
| 203 |
def post_classify_session(req: PostWorkoutRequest):
|
| 204 |
"""
|
|
@@ -255,7 +255,7 @@ def post_classify_session(req: PostWorkoutRequest):
|
|
| 255 |
)
|
| 256 |
|
| 257 |
|
| 258 |
-
@app.post("/
|
| 259 |
dependencies=[Depends(verify_key)])
|
| 260 |
def post_classify_labels_only(req: PostWorkoutRequest):
|
| 261 |
"""
|
|
@@ -308,8 +308,12 @@ class PreWorkoutBertLabels(BaseModel):
|
|
| 308 |
|
| 309 |
|
| 310 |
class PreWorkoutResponse(BaseModel):
|
| 311 |
-
bert_labels:
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
|
| 315 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -348,7 +352,7 @@ def pre_classify_session(req: PreWorkoutRequest):
|
|
| 348 |
)
|
| 349 |
|
| 350 |
# ββ Step 3: Optionally generate Claude workout plan βββββββ
|
| 351 |
-
|
| 352 |
if req.generate_plan:
|
| 353 |
prompt = build_pre_prompt(
|
| 354 |
bert_labels=bert_labels,
|
|
@@ -364,14 +368,19 @@ def pre_classify_session(req: PreWorkoutRequest):
|
|
| 364 |
max_tokens=800,
|
| 365 |
messages=[{"role": "user", "content": prompt}],
|
| 366 |
)
|
| 367 |
-
|
|
|
|
| 368 |
except Exception as e:
|
| 369 |
logger.error(f"Claude API error (pre-workout): {e}")
|
| 370 |
-
|
| 371 |
|
| 372 |
return PreWorkoutResponse(
|
| 373 |
bert_labels=PreWorkoutBertLabels(**bert_labels),
|
| 374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
)
|
| 376 |
|
| 377 |
|
|
|
|
| 17 |
from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
|
| 18 |
from inference import (
|
| 19 |
predict_post, decode_post_predictions, build_post_prompt,
|
| 20 |
+
predict_pre, decode_pre_predictions, build_pre_prompt, parse_workout_plan,
|
| 21 |
)
|
| 22 |
|
| 23 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 97 |
# ββ Post-workout model ββββββββββββββββββββββββββββββββββββ
|
| 98 |
logger.info("Loading post-workout model...")
|
| 99 |
post_model, post_tokenizer = load_post_model(
|
| 100 |
+
model_path=os.getenv("MODEL_PATH", "best_overall_model.pt"),
|
| 101 |
device=device,
|
| 102 |
)
|
| 103 |
app_state["post_model"] = post_model
|
|
|
|
| 198 |
}
|
| 199 |
|
| 200 |
|
| 201 |
+
@app.post("/classify", response_model=PostWorkoutSessionResponse,
|
| 202 |
dependencies=[Depends(verify_key)])
|
| 203 |
def post_classify_session(req: PostWorkoutRequest):
|
| 204 |
"""
|
|
|
|
| 255 |
)
|
| 256 |
|
| 257 |
|
| 258 |
+
@app.post("/classify/labels-only", response_model=PostWorkoutBertLabels,
|
| 259 |
dependencies=[Depends(verify_key)])
|
| 260 |
def post_classify_labels_only(req: PostWorkoutRequest):
|
| 261 |
"""
|
|
|
|
| 308 |
|
| 309 |
|
| 310 |
class PreWorkoutResponse(BaseModel):
|
| 311 |
+
bert_labels: PreWorkoutBertLabels
|
| 312 |
+
warm_up: Optional[str] = None
|
| 313 |
+
main_workout: Optional[str] = None
|
| 314 |
+
cool_down: Optional[str] = None
|
| 315 |
+
coaching_note: Optional[str] = None
|
| 316 |
+
raw_plan: Optional[str] = None # full unmodified response β fallback
|
| 317 |
|
| 318 |
|
| 319 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 352 |
)
|
| 353 |
|
| 354 |
# ββ Step 3: Optionally generate Claude workout plan βββββββ
|
| 355 |
+
parsed = {}
|
| 356 |
if req.generate_plan:
|
| 357 |
prompt = build_pre_prompt(
|
| 358 |
bert_labels=bert_labels,
|
|
|
|
| 368 |
max_tokens=800,
|
| 369 |
messages=[{"role": "user", "content": prompt}],
|
| 370 |
)
|
| 371 |
+
raw_plan = message.content[0].text
|
| 372 |
+
parsed = parse_workout_plan(raw_plan)
|
| 373 |
except Exception as e:
|
| 374 |
logger.error(f"Claude API error (pre-workout): {e}")
|
| 375 |
+
parsed = {}
|
| 376 |
|
| 377 |
return PreWorkoutResponse(
|
| 378 |
bert_labels=PreWorkoutBertLabels(**bert_labels),
|
| 379 |
+
warm_up=parsed.get("warm_up"),
|
| 380 |
+
main_workout=parsed.get("main_workout"),
|
| 381 |
+
cool_down=parsed.get("cool_down"),
|
| 382 |
+
coaching_note=parsed.get("coaching_note"),
|
| 383 |
+
raw_plan=parsed.get("raw"),
|
| 384 |
)
|
| 385 |
|
| 386 |
|