# personal_trainer / inference.py
# jflo's picture
# Update inference.py
# e6d9ec1 verified
"""
inference.py — predict, decode, and prompt building utilities
"""
import torch
from typing import Dict, List, Optional
# ─────────────────────────────────────────────
# POST-WORKOUT FUNCTIONS
# ─────────────────────────────────────────────
def predict_post(
    text: str,
    model,
    tokenizer,
    device: torch.device,
    max_len: int = 128,
) -> Dict[str, int]:
    """
    Classify one post-workout text and return the winning class per head.

    The text is tokenized to a fixed length of ``max_len`` (padded and
    truncated), the resulting tensors are moved to ``device``, and ``model``
    is called as ``model(input_ids, attention_mask)`` with gradients
    disabled.

    Args:
        text: The user's free-text post-workout note.
        model: The PostWorkoutDistilBERT classifier. It must return exactly
            five logits tensors, in this order: mood, exertion, soreness
            region, soreness severity, completion.
        tokenizer: Callable accepting ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` keyword arguments and
            returning a mapping with "input_ids" and "attention_mask".
        device: Device the input tensors are moved to before the forward pass.
        max_len: Token length the input is padded / truncated to.

    Returns:
        Dict with keys "mood", "exertion", "soreness_region",
        "soreness_severity" and "completion", each holding the argmax over
        dim 1 of the matching logits as a plain Python int.
    """
    head_names = (
        "mood",
        "exertion",
        "soreness_region",
        "soreness_severity",
        "completion",
    )

    batch = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=max_len,
        return_tensors="pt",
    )
    ids = batch["input_ids"].to(device)
    mask = batch["attention_mask"].to(device)

    with torch.no_grad():
        # Unpacking into five names keeps the strict "exactly five heads"
        # check: any other number of outputs raises ValueError here.
        mood, exertion, region, severity, completion = model(ids, mask)

    head_logits = (mood, exertion, region, severity, completion)
    # .item() only works on a single-element tensor, i.e. a batch of one text.
    return {
        name: logits.argmax(dim=1).item()
        for name, logits in zip(head_names, head_logits)
    }
def decode_post_predictions(
    preds: Dict[str, int],
    mood_map: Dict[int, str],
    exertion_map: Dict[int, str],
    soreness_region_map: Dict[int, str],
    soreness_severity_map: Dict[int, str],
    completion_map: Dict[int, str],
) -> Dict[str, str]:
    """
    Translate post-workout class indices into their label strings.

    Args:
        preds: Class index per head, keyed by "mood", "exertion",
            "soreness_region", "soreness_severity" and "completion".
        mood_map, exertion_map, soreness_region_map, soreness_severity_map,
        completion_map: Index-to-label lookup table for the matching head.

    Returns:
        Dict with the same five keys as ``preds``, each mapped to its label.

    Raises:
        KeyError: If ``preds`` lacks a head or a table lacks the index.
    """
    # Pair every head name with the table that decodes it; the order here is
    # also the key order of the returned dict.
    tables = (
        ("mood", mood_map),
        ("exertion", exertion_map),
        ("soreness_region", soreness_region_map),
        ("soreness_severity", soreness_severity_map),
        ("completion", completion_map),
    )
    return {head: table[preds[head]] for head, table in tables}
def build_post_prompt(
    bert_labels: Dict[str, str],
    user_text: str,
    duration_minutes: int,
    workout_type: str,
    user_goal: str,
) -> str:
    """
    Build the Claude prompt for post-workout debrief generation.

    Plain text output only — no Markdown, no asterisks, no symbols.
    Section delimiters are ALL-CAPS labels so parse_debrief() can
    split the response into named sections reliably.

    Args:
        bert_labels: Decoded classifier labels. Reads the keys
            "soreness_region", "soreness_severity", "completion",
            "exertion" and "mood".
        user_text: What the user wrote after the session; quoted verbatim
            in the prompt.
        duration_minutes: Session length, inserted as "<n> minutes".
        workout_type: Workout type shown in the session summary.
        user_goal: The user's goal shown in the session summary.

    Returns:
        The complete prompt string.

    Raises:
        KeyError: If ``bert_labels`` is missing one of the keys listed above.
    """
    region = bert_labels["soreness_region"]
    severity = bert_labels["soreness_severity"]
    # Either head saying "none" means there is nothing to report, even if
    # the other head predicted a concrete region / severity.
    if region == "none" or severity == "none":
        soreness_str = "no soreness"
    else:
        soreness_str = f"{severity} {region} soreness"

    # Any completion label other than "full" is described as partial.
    completion_str = (
        "completed the full session"
        if bert_labels["completion"] == "full"
        else "partially completed the session"
    )

    # The dashes below are real em dashes; the previous text carried
    # mis-decoded bytes that would have been sent to the model verbatim.
    prompt = f"""You are an encouraging personal fitness coach writing a post-workout debrief for a user.
Use plain text only — no Markdown, no asterisks, no bold, no bullet points, no special symbols.
Session summary:
- Workout type: {workout_type}
- Duration: {duration_minutes} minutes
- User goal: {user_goal}
- Completion: {completion_str}
- Exertion level: {bert_labels['exertion']}
- Post-workout mood: {bert_labels['mood']}
- Soreness: {soreness_str}
What the user wrote after their session:
"{user_text}"
Write a personalized debrief using the exact section labels below as delimiters. \
Do not add any text before ACKNOWLEDGEMENT or after NEXT SESSION. \
Write one short paragraph per section — warm, concise, and actionable.
ACKNOWLEDGEMENT
[Acknowledge how they felt and what they did — validate their effort regardless of how the session went]
HIGHLIGHTS
[Highlight what went well and give honest context for any soreness or struggles]
NEXT SESSION
[Set them up positively for their next session with one specific actionable tip]"""
    return prompt
# ── Section keys returned by parse_debrief() ─────────────────
DEBRIEF_SECTIONS = ["ACKNOWLEDGEMENT", "HIGHLIGHTS", "NEXT SESSION"]
def parse_debrief(raw: str) -> Dict[str, str]:
"""
Split a plain-text post-workout debrief into named sections.
Returns a dict with keys:
"acknowledgement" β€” how they felt / what they did paragraph
"highlights" β€” what went well / soreness context paragraph
"next_session" β€” forward-looking actionable tip paragraph
"raw" β€” original unmodified response (fallback)
If a section is missing the key maps to "".
"""
result = {
"acknowledgement": "",
"highlights": "",
"next_session": "",
"raw": raw,
}
text = raw.replace("\r\n", "\n").strip()
# Build a map of {section_label: start_index} for every label found
indices: Dict[str, int] = {}
for label in DEBRIEF_SECTIONS:
idx = text.find(label)
if idx != -1:
indices[label] = idx
ordered = sorted(indices.items(), key=lambda x: x[1])
for i, (label, start) in enumerate(ordered):
content_start = start + len(label)
content_end = ordered[i + 1][1] if i + 1 < len(ordered) else len(text)
content = text[content_start:content_end].strip()
key_map = {
"ACKNOWLEDGEMENT": "acknowledgement",
"HIGHLIGHTS": "highlights",
"NEXT SESSION": "next_session",
}
result[key_map[label]] = content
return result
# ─────────────────────────────────────────────
# PRE-WORKOUT FUNCTIONS
# ─────────────────────────────────────────────
def predict_pre(
    text: str,
    model,
    tokenizer,
    device: "torch.device",
    max_len: int = 128,
) -> Dict[str, int]:
    """
    Classify one pre-workout text and return the winning class per head.

    The text is tokenized to a fixed length of ``max_len`` (padded and
    truncated), the resulting tensors are moved to ``device``, and ``model``
    is called as ``model(input_ids, attention_mask)`` with gradients
    disabled.

    Args:
        text: The user's free-text pre-workout note.
        model: The PreWorkoutDistilBERT classifier. It must return exactly
            six logits tensors, in this order: mood, energy, motivation,
            stress, soreness region, soreness severity.
        tokenizer: Callable accepting ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` keyword arguments and
            returning a mapping with "input_ids" and "attention_mask".
        device: Device the input tensors are moved to before the forward pass.
        max_len: Token length the input is padded / truncated to.

    Returns:
        Dict with keys "mood", "energy", "motivation", "stress",
        "soreness_region" and "soreness_severity", each holding the argmax
        over dim 1 of the matching logits as a plain Python int.
    """
    tokens = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_len,
    )
    ids = tokens["input_ids"].to(device)
    mask = tokens["attention_mask"].to(device)

    with torch.no_grad():
        # Six-way unpacking doubles as a check on the number of heads: any
        # other count raises ValueError here.
        mood, energy, motivation, stress, region, severity = model(ids, mask)

    per_head = (
        ("mood", mood),
        ("energy", energy),
        ("motivation", motivation),
        ("stress", stress),
        ("soreness_region", region),
        ("soreness_severity", severity),
    )
    # .item() only works on a single-element tensor, i.e. a batch of one text.
    return {name: logits.argmax(dim=1).item() for name, logits in per_head}
def decode_pre_predictions(
    preds: Dict[str, int],
    mood_map: Dict[int, str],
    energy_map: Dict[int, str],
    motivation_map: Dict[int, str],
    stress_map: Dict[int, str],
    soreness_region_map: Dict[int, str],
    soreness_severity_map: Dict[int, str],
) -> Dict[str, str]:
    """
    Translate pre-workout class indices into their label strings.

    Args:
        preds: Class index per head, keyed by "mood", "energy",
            "motivation", "stress", "soreness_region" and
            "soreness_severity".
        mood_map, energy_map, motivation_map, stress_map,
        soreness_region_map, soreness_severity_map: Index-to-label lookup
            table for the matching head.

    Returns:
        Dict with the same six keys as ``preds``, each mapped to its label.

    Raises:
        KeyError: If ``preds`` lacks a head or a table lacks the index.
    """
    # Head name -> decoding table; insertion order fixes both the lookup
    # order and the key order of the result.
    decoders = {
        "mood": mood_map,
        "energy": energy_map,
        "motivation": motivation_map,
        "stress": stress_map,
        "soreness_region": soreness_region_map,
        "soreness_severity": soreness_severity_map,
    }
    decoded = {}
    for head, table in decoders.items():
        decoded[head] = table[preds[head]]
    return decoded
def build_pre_prompt(
    bert_labels: Dict[str, str],
    user_text: str,
    workout_type: str,
    duration_minutes: int,
    user_goal: str,
    equipment: List[str],
) -> str:
    """
    Build the Claude prompt that generates a structured pre-workout plan.

    Plain text output only — no Markdown, no asterisks, no symbols.
    Section delimiters are ALL-CAPS labels so parse_workout_plan()
    can split the response into named sections reliably.

    Args:
        bert_labels: Decoded classifier labels. Reads the keys "mood",
            "energy", "motivation", "stress", "soreness_region" and
            "soreness_severity".
        user_text: What the user wrote before the session; quoted verbatim
            in the prompt.
        workout_type: Workout type selected in the app.
        duration_minutes: Session length, inserted as "<n> minutes".
        user_goal: The user's goal.
        equipment: Names of the available equipment. An empty list is
            rendered as "bodyweight only".

    Returns:
        The complete prompt string.

    Raises:
        KeyError: If ``bert_labels`` is missing one of the keys listed above.
    """
    region = bert_labels["soreness_region"]
    severity = bert_labels["soreness_severity"]
    # Either head saying "none" means there is nothing to report, even if
    # the other head predicted a concrete region / severity.
    if region == "none" or severity == "none":
        soreness_str = "no existing soreness"
    else:
        soreness_str = f"{severity} {region} soreness going in"

    equipment_str = ", ".join(equipment) if equipment else "bodyweight only"

    # The dashes below are real em dashes; the previous text carried
    # mis-decoded bytes that would have been sent to the model verbatim.
    # "Available equipment: " now has the same "label: value" spacing as
    # every other field line.
    prompt = f"""You are an expert personal trainer generating a structured pre-workout plan.
The user has described how they feel before training. Use their physical and mental state \
to prescribe the most appropriate session for them right now.
User state (classified from what they wrote):
- Mood: {bert_labels['mood']}
- Energy level: {bert_labels['energy']}
- Motivation: {bert_labels['motivation']}
- Stress level: {bert_labels['stress']}
- Existing soreness: {soreness_str}
Session parameters (selected in app):
- Workout type: {workout_type}
- Duration: {duration_minutes} minutes
- Goal: {user_goal}
- Available equipment: {equipment_str}
What the user wrote before their session:
"{user_text}"
Generate a complete structured workout plan. Use plain text only — no Markdown, \
no asterisks, no bold, no hyphens as bullet points, no special symbols of any kind.
Use the exact section labels below as delimiters. Do not add any text before \
WARM UP or after COACHING NOTE.
WARM UP
[list each exercise on its own line as: Exercise Name | sets/reps or duration]
MAIN WORKOUT
[list each exercise on its own line as: Exercise Name | sets x reps | rest period]
COOL DOWN
[list each exercise on its own line as: Exercise Name | duration]
COACHING NOTE
[2-3 sentences acknowledging their current state, explaining why you prescribed \
this session, and one actionable tip for today]
Important guidelines:
- If energy is low, reduce volume and intensity — fewer sets, lighter loads
- If stress is high, favour controlled movements over maximal effort
- If motivation is low, keep the session achievable and end on a win
- If soreness is present, programme around that muscle group entirely
- If motivation is high and energy is high, push appropriate intensity
- Match total volume to the duration specified"""
    return prompt
# ── Section keys returned by parse_workout_plan() ────────────
PLAN_SECTIONS = ["WARM UP", "MAIN WORKOUT", "COOL DOWN", "COACHING NOTE"]
def parse_workout_plan(raw: str) -> Dict[str, str]:
"""
Split a plain-text workout plan response into named sections.
Returns a dict with keys:
"warm_up" β€” warm up exercises, one per line
"main_workout" β€” main workout exercises, one per line
"cool_down" β€” cool down exercises, one per line
"coaching_note" β€” the coaching note paragraph
"raw" β€” original unmodified response (fallback)
Each exercise line uses pipe-separated fields:
"Exercise Name | sets x reps | rest period"
which PreWorkoutView splits on "|" to style each field independently.
If a section is missing from the response the key maps to "".
"""
result = {
"warm_up": "",
"main_workout": "",
"cool_down": "",
"coaching_note": "",
"raw": raw,
}
# Normalise line endings
text = raw.replace("\r\n", "\n").strip()
# Build a map of {section_label: start_index} for every label found
indices: Dict[str, int] = {}
for label in PLAN_SECTIONS:
idx = text.find(label)
if idx != -1:
indices[label] = idx
# Extract the text between each found label and the next
ordered = sorted(indices.items(), key=lambda x: x[1])
for i, (label, start) in enumerate(ordered):
# Content starts after the label and its newline
content_start = start + len(label)
content_end = ordered[i + 1][1] if i + 1 < len(ordered) else len(text)
content = text[content_start:content_end].strip()
key_map = {
"WARM UP": "warm_up",
"MAIN WORKOUT": "main_workout",
"COOL DOWN": "cool_down",
"COACHING NOTE": "coaching_note",
}
result[key_map[label]] = content
return result