Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,7 +16,7 @@ import logging
|
|
| 16 |
|
| 17 |
from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
|
| 18 |
from inference import (
|
| 19 |
-
predict_post, decode_post_predictions, build_post_prompt,
|
| 20 |
predict_pre, decode_pre_predictions, build_pre_prompt, parse_workout_plan,
|
| 21 |
)
|
| 22 |
|
|
@@ -97,7 +97,7 @@ async def lifespan(app: FastAPI):
|
|
| 97 |
# ββ Post-workout model ββββββββββββββββββββββββββββββββββββ
|
| 98 |
logger.info("Loading post-workout model...")
|
| 99 |
post_model, post_tokenizer = load_post_model(
|
| 100 |
-
model_path=os.getenv("MODEL_PATH", "
|
| 101 |
device=device,
|
| 102 |
)
|
| 103 |
app_state["post_model"] = post_model
|
|
@@ -171,8 +171,11 @@ class PostWorkoutBertLabels(BaseModel):
|
|
| 171 |
|
| 172 |
|
| 173 |
class PostWorkoutSessionResponse(BaseModel):
|
| 174 |
-
bert_labels:
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
|
| 178 |
class HealthResponse(BaseModel):
|
|
@@ -229,7 +232,7 @@ def post_classify_session(req: PostWorkoutRequest):
|
|
| 229 |
)
|
| 230 |
|
| 231 |
# ββ Step 3: Optionally generate Claude debrief ββββββββββββ
|
| 232 |
-
|
| 233 |
if req.generate_debrief:
|
| 234 |
prompt = build_post_prompt(
|
| 235 |
bert_labels=bert_labels,
|
|
@@ -244,14 +247,18 @@ def post_classify_session(req: PostWorkoutRequest):
|
|
| 244 |
max_tokens=400,
|
| 245 |
messages=[{"role": "user", "content": prompt}],
|
| 246 |
)
|
| 247 |
-
|
|
|
|
| 248 |
except Exception as e:
|
| 249 |
logger.error(f"Claude API error (post-workout): {e}")
|
| 250 |
-
|
| 251 |
|
| 252 |
return PostWorkoutSessionResponse(
|
| 253 |
bert_labels=PostWorkoutBertLabels(**bert_labels),
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
| 255 |
)
|
| 256 |
|
| 257 |
|
|
|
|
| 16 |
|
| 17 |
from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
|
| 18 |
from inference import (
|
| 19 |
+
predict_post, decode_post_predictions, build_post_prompt, parse_debrief,
|
| 20 |
predict_pre, decode_pre_predictions, build_pre_prompt, parse_workout_plan,
|
| 21 |
)
|
| 22 |
|
|
|
|
| 97 |
# ββ Post-workout model ββββββββββββββββββββββββββββββββββββ
|
| 98 |
logger.info("Loading post-workout model...")
|
| 99 |
post_model, post_tokenizer = load_post_model(
|
| 100 |
+
model_path=os.getenv("MODEL_PATH", "best_overall_model.pt"),
|
| 101 |
device=device,
|
| 102 |
)
|
| 103 |
app_state["post_model"] = post_model
|
|
|
|
| 171 |
|
| 172 |
|
| 173 |
class PostWorkoutSessionResponse(BaseModel):
|
| 174 |
+
bert_labels: PostWorkoutBertLabels
|
| 175 |
+
acknowledgement: Optional[str] = None
|
| 176 |
+
highlights: Optional[str] = None
|
| 177 |
+
next_session: Optional[str] = None
|
| 178 |
+
raw_debrief: Optional[str] = None # full unmodified response β fallback
|
| 179 |
|
| 180 |
|
| 181 |
class HealthResponse(BaseModel):
|
|
|
|
| 232 |
)
|
| 233 |
|
| 234 |
# ββ Step 3: Optionally generate Claude debrief ββββββββββββ
|
| 235 |
+
parsed = {}
|
| 236 |
if req.generate_debrief:
|
| 237 |
prompt = build_post_prompt(
|
| 238 |
bert_labels=bert_labels,
|
|
|
|
| 247 |
max_tokens=400,
|
| 248 |
messages=[{"role": "user", "content": prompt}],
|
| 249 |
)
|
| 250 |
+
raw_debrief = message.content[0].text
|
| 251 |
+
parsed = parse_debrief(raw_debrief)
|
| 252 |
except Exception as e:
|
| 253 |
logger.error(f"Claude API error (post-workout): {e}")
|
| 254 |
+
parsed = {}
|
| 255 |
|
| 256 |
return PostWorkoutSessionResponse(
|
| 257 |
bert_labels=PostWorkoutBertLabels(**bert_labels),
|
| 258 |
+
acknowledgement=parsed.get("acknowledgement"),
|
| 259 |
+
highlights=parsed.get("highlights"),
|
| 260 |
+
next_session=parsed.get("next_session"),
|
| 261 |
+
raw_debrief=parsed.get("raw"),
|
| 262 |
)
|
| 263 |
|
| 264 |
|