""" inference.py — predict, decode, and prompt building utilities """ import torch from typing import Dict, List, Optional # ───────────────────────────────────────────── # POST-WORKOUT FUNCTIONS # ───────────────────────────────────────────── def predict_post( text: str, model, tokenizer, device: torch.device, max_len: int = 128, ) -> Dict[str, int]: """ Run PostWorkoutDistilBERT inference on a single text string. Returns raw integer predictions for each head. """ encoding = tokenizer( text, max_length=max_len, padding="max_length", truncation=True, return_tensors="pt", ) input_ids = encoding["input_ids"].to(device) attention_mask = encoding["attention_mask"].to(device) with torch.no_grad(): mood_logits, exertion_logits, soreness_region_logits, soreness_severity_logits, completion_logits = model( input_ids, attention_mask ) return { "mood": mood_logits.argmax(dim=1).item(), "exertion": exertion_logits.argmax(dim=1).item(), "soreness_region": soreness_region_logits.argmax(dim=1).item(), "soreness_severity": soreness_severity_logits.argmax(dim=1).item(), "completion": completion_logits.argmax(dim=1).item(), } def decode_post_predictions( preds: Dict[str, int], mood_map: Dict[int, str], exertion_map: Dict[int, str], soreness_region_map: Dict[int, str], soreness_severity_map:Dict[int, str], completion_map: Dict[int, str], ) -> Dict[str, str]: """ Decode post-workout integer predictions back to human-readable label strings. """ return { "mood": mood_map[preds["mood"]], "exertion": exertion_map[preds["exertion"]], "soreness_region": soreness_region_map[preds["soreness_region"]], "soreness_severity": soreness_severity_map[preds["soreness_severity"]], "completion": completion_map[preds["completion"]], } def build_post_prompt( bert_labels: Dict[str, str], user_text: str, duration_minutes: int, workout_type: str, user_goal: str, ) -> str: """ Build the Claude prompt for post-workout debrief generation. Plain text output only — no Markdown, no asterisks, no symbols. Section delimiters are ALL-CAPS labels so parse_debrief() can split the response into named sections reliably. """ region = bert_labels["soreness_region"] severity = bert_labels["soreness_severity"] if region == "none" or severity == "none": soreness_str = "no soreness" else: soreness_str = f"{severity} {region} soreness" completion_str = ( "completed the full session" if bert_labels["completion"] == "full" else "partially completed the session" ) prompt = f"""You are an encouraging personal fitness coach writing a post-workout debrief for a user. Use plain text only — no Markdown, no asterisks, no bold, no bullet points, no special symbols. Session summary: - Workout type: {workout_type} - Duration: {duration_minutes} minutes - User goal: {user_goal} - Completion: {completion_str} - Exertion level: {bert_labels['exertion']} - Post-workout mood: {bert_labels['mood']} - Soreness: {soreness_str} What the user wrote after their session: "{user_text}" Write a personalized debrief using the exact section labels below as delimiters. \ Do not add any text before ACKNOWLEDGEMENT or after NEXT SESSION. \ Write one short paragraph per section — warm, concise, and actionable. ACKNOWLEDGEMENT [Acknowledge how they felt and what they did — validate their effort regardless of how the session went] HIGHLIGHTS [Highlight what went well and give honest context for any soreness or struggles] NEXT SESSION [Set them up positively for their next session with one specific actionable tip]""" return prompt # ── Section keys returned by parse_debrief() ───────────────── DEBRIEF_SECTIONS = ["ACKNOWLEDGEMENT", "HIGHLIGHTS", "NEXT SESSION"] def parse_debrief(raw: str) -> Dict[str, str]: """ Split a plain-text post-workout debrief into named sections. Returns a dict with keys: "acknowledgement" — how they felt / what they did paragraph "highlights" — what went well / soreness context paragraph "next_session" — forward-looking actionable tip paragraph "raw" — original unmodified response (fallback) If a section is missing the key maps to "". """ result = { "acknowledgement": "", "highlights": "", "next_session": "", "raw": raw, } text = raw.replace("\r\n", "\n").strip() # Build a map of {section_label: start_index} for every label found indices: Dict[str, int] = {} for label in DEBRIEF_SECTIONS: idx = text.find(label) if idx != -1: indices[label] = idx ordered = sorted(indices.items(), key=lambda x: x[1]) for i, (label, start) in enumerate(ordered): content_start = start + len(label) content_end = ordered[i + 1][1] if i + 1 < len(ordered) else len(text) content = text[content_start:content_end].strip() key_map = { "ACKNOWLEDGEMENT": "acknowledgement", "HIGHLIGHTS": "highlights", "NEXT SESSION": "next_session", } result[key_map[label]] = content return result # ───────────────────────────────────────────── # PRE-WORKOUT FUNCTIONS # ───────────────────────────────────────────── def predict_pre( text: str, model, tokenizer, device: "torch.device", max_len: int = 128, ) -> Dict[str, int]: """ Run PreWorkoutDistilBERT inference on a single text string. Returns raw integer predictions for each of the 6 heads. """ encoding = tokenizer( text, max_length=max_len, padding="max_length", truncation=True, return_tensors="pt", ) input_ids = encoding["input_ids"].to(device) attention_mask = encoding["attention_mask"].to(device) with torch.no_grad(): ( mood_logits, energy_logits, motivation_logits, stress_logits, soreness_region_logits, soreness_severity_logits, ) = model(input_ids, attention_mask) return { "mood": mood_logits.argmax(dim=1).item(), "energy": energy_logits.argmax(dim=1).item(), "motivation": motivation_logits.argmax(dim=1).item(), "stress": stress_logits.argmax(dim=1).item(), "soreness_region": soreness_region_logits.argmax(dim=1).item(), "soreness_severity": soreness_severity_logits.argmax(dim=1).item(), } def decode_pre_predictions( preds: Dict[str, int], mood_map: Dict[int, str], energy_map: Dict[int, str], motivation_map: Dict[int, str], stress_map: Dict[int, str], soreness_region_map: Dict[int, str], soreness_severity_map:Dict[int, str], ) -> Dict[str, str]: """ Decode pre-workout integer predictions back to human-readable strings. """ return { "mood": mood_map[preds["mood"]], "energy": energy_map[preds["energy"]], "motivation": motivation_map[preds["motivation"]], "stress": stress_map[preds["stress"]], "soreness_region": soreness_region_map[preds["soreness_region"]], "soreness_severity": soreness_severity_map[preds["soreness_severity"]], } def build_pre_prompt( bert_labels: Dict[str, str], user_text: str, workout_type: str, duration_minutes: int, user_goal: str, equipment: List[str], ) -> str: """ Build the Claude prompt that generates a structured pre-workout plan. Plain text output only — no Markdown, no asterisks, no symbols. Section delimiters are ALL-CAPS labels so parse_workout_plan() can split the response into named sections reliably. """ region = bert_labels["soreness_region"] severity = bert_labels["soreness_severity"] if region == "none" or severity == "none": soreness_str = "no existing soreness" else: soreness_str = f"{severity} {region} soreness going in" equipment_str = ( ", ".join(equipment) if equipment else "bodyweight only" ) prompt = f"""You are an expert personal trainer generating a structured pre-workout plan. The user has described how they feel before training. Use their physical and mental state \ to prescribe the most appropriate session for them right now. User state (classified from what they wrote): - Mood: {bert_labels['mood']} - Energy level: {bert_labels['energy']} - Motivation: {bert_labels['motivation']} - Stress level: {bert_labels['stress']} - Existing soreness: {soreness_str} Session parameters (selected in app): - Workout type: {workout_type} - Duration: {duration_minutes} minutes - Goal: {user_goal} - Available equipment:{equipment_str} What the user wrote before their session: "{user_text}" Generate a complete structured workout plan. Use plain text only — no Markdown, \ no asterisks, no bold, no hyphens as bullet points, no special symbols of any kind. Use the exact section labels below as delimiters. Do not add any text before \ WARM UP or after COACHING NOTE. WARM UP [list each exercise on its own line as: Exercise Name | sets/reps or duration] MAIN WORKOUT [list each exercise on its own line as: Exercise Name | sets x reps | rest period] COOL DOWN [list each exercise on its own line as: Exercise Name | duration] COACHING NOTE [2-3 sentences acknowledging their current state, explaining why you prescribed \ this session, and one actionable tip for today] Important guidelines: - If energy is low, reduce volume and intensity — fewer sets, lighter loads - If stress is high, favour controlled movements over maximal effort - If motivation is low, keep the session achievable and end on a win - If soreness is present, programme around that muscle group entirely - If motivation is high and energy is high, push appropriate intensity - Match total volume to the duration specified""" return prompt # ── Section keys returned by parse_workout_plan() ──────────── PLAN_SECTIONS = ["WARM UP", "MAIN WORKOUT", "COOL DOWN", "COACHING NOTE"] def parse_workout_plan(raw: str) -> Dict[str, str]: """ Split a plain-text workout plan response into named sections. Returns a dict with keys: "warm_up" — warm up exercises, one per line "main_workout" — main workout exercises, one per line "cool_down" — cool down exercises, one per line "coaching_note" — the coaching note paragraph "raw" — original unmodified response (fallback) Each exercise line uses pipe-separated fields: "Exercise Name | sets x reps | rest period" which PreWorkoutView splits on "|" to style each field independently. If a section is missing from the response the key maps to "". """ result = { "warm_up": "", "main_workout": "", "cool_down": "", "coaching_note": "", "raw": raw, } # Normalise line endings text = raw.replace("\r\n", "\n").strip() # Build a map of {section_label: start_index} for every label found indices: Dict[str, int] = {} for label in PLAN_SECTIONS: idx = text.find(label) if idx != -1: indices[label] = idx # Extract the text between each found label and the next ordered = sorted(indices.items(), key=lambda x: x[1]) for i, (label, start) in enumerate(ordered): # Content starts after the label and its newline content_start = start + len(label) content_end = ordered[i + 1][1] if i + 1 < len(ordered) else len(text) content = text[content_start:content_end].strip() key_map = { "WARM UP": "warm_up", "MAIN WORKOUT": "main_workout", "COOL DOWN": "cool_down", "COACHING NOTE": "coaching_note", } result[key_map[label]] = content return result