Spaces:
Sleeping
Sleeping
"""
inference.py — predict, decode, and prompt-building utilities
"""
import re
from typing import Dict, List, Optional

import torch
# ─────────────────────────────────────────────
# POST-WORKOUT FUNCTIONS
# ─────────────────────────────────────────────
def predict_post(
    text: str,
    model,
    tokenizer,
    device: torch.device,
    max_len: int = 128,
) -> Dict[str, int]:
    """
    Classify a single post-workout text with PostWorkoutDistilBERT.

    The text is tokenized (padded / truncated to ``max_len`` tokens), both
    tensors are moved to ``device``, and ``model(input_ids, attention_mask)``
    is evaluated with autograd disabled. The model must return exactly five
    logit tensors, in the order: mood, exertion, soreness region, soreness
    severity, completion.

    Returns:
        Dict mapping each head name to the arg-max class index (a plain int)
        for the single input text.
    """
    batch = tokenizer(
        text,
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    ids = batch["input_ids"].to(device)
    mask = batch["attention_mask"].to(device)

    # Pure inference: no gradient bookkeeping needed.
    with torch.no_grad():
        mood, exertion, region, severity, completion = model(ids, mask)

    heads = (
        ("mood", mood),
        ("exertion", exertion),
        ("soreness_region", region),
        ("soreness_severity", severity),
        ("completion", completion),
    )
    return {name: logits.argmax(dim=1).item() for name, logits in heads}
def decode_post_predictions(
    preds: Dict[str, int],
    mood_map: Dict[int, str],
    exertion_map: Dict[int, str],
    soreness_region_map: Dict[int, str],
    soreness_severity_map: Dict[int, str],
    completion_map: Dict[int, str],
) -> Dict[str, str]:
    """
    Translate post-workout class indices into their label strings.

    Each head's index in ``preds`` is looked up in the matching map.
    A missing head or an index absent from its map raises ``KeyError``.
    """
    tables = {
        "mood": mood_map,
        "exertion": exertion_map,
        "soreness_region": soreness_region_map,
        "soreness_severity": soreness_severity_map,
        "completion": completion_map,
    }
    return {head: table[preds[head]] for head, table in tables.items()}
def build_post_prompt(
    bert_labels: Dict[str, str],
    user_text: str,
    duration_minutes: int,
    workout_type: str,
    user_goal: str,
) -> str:
    """
    Build the Claude prompt for post-workout debrief generation.

    Plain text output only: no Markdown, no asterisks, no symbols.
    Section delimiters are ALL-CAPS labels so parse_debrief() can
    split the response into named sections reliably.

    Args:
        bert_labels: decoded classifier labels; must contain the keys
            "soreness_region", "soreness_severity", "completion",
            "exertion" and "mood" (a missing key raises KeyError).
        user_text: the user's free-text note, embedded verbatim in quotes.
        duration_minutes: session length shown in the summary.
        workout_type: workout type shown in the summary.
        user_goal: the user's goal shown in the summary.

    Returns:
        The complete prompt string. The instruction text is kept pure ASCII
        so the prompt does not itself contain the special symbols it asks
        the model to avoid.
    """
    region = bert_labels["soreness_region"]
    severity = bert_labels["soreness_severity"]
    # Either half being "none" means there is nothing meaningful to report.
    if region == "none" or severity == "none":
        soreness_str = "no soreness"
    else:
        soreness_str = f"{severity} {region} soreness"

    completion_str = (
        "completed the full session"
        if bert_labels["completion"] == "full"
        else "partially completed the session"
    )

    # A trailing backslash inside the literal joins the next line onto it.
    prompt = f"""You are an encouraging personal fitness coach writing a post-workout debrief for a user.
Use plain text only: no Markdown, no asterisks, no bold, no bullet points, no special symbols.
Session summary:
- Workout type: {workout_type}
- Duration: {duration_minutes} minutes
- User goal: {user_goal}
- Completion: {completion_str}
- Exertion level: {bert_labels['exertion']}
- Post-workout mood: {bert_labels['mood']}
- Soreness: {soreness_str}
What the user wrote after their session:
"{user_text}"
Write a personalized debrief using the exact section labels below as delimiters. \
Do not add any text before ACKNOWLEDGEMENT or after NEXT SESSION. \
Write one short paragraph per section: warm, concise, and actionable.
ACKNOWLEDGEMENT
[Acknowledge how they felt and what they did; validate their effort regardless of how the session went]
HIGHLIGHTS
[Highlight what went well and give honest context for any soreness or struggles]
NEXT SESSION
[Set them up positively for their next session with one specific actionable tip]"""
    return prompt
# ── ALL-CAPS labels build_post_prompt() requests and parse_debrief() splits on ──
DEBRIEF_SECTIONS = [
    "ACKNOWLEDGEMENT",
    "HIGHLIGHTS",
    "NEXT SESSION",
]
def parse_debrief(raw: str) -> Dict[str, str]:
    """
    Split a plain-text post-workout debrief into named sections.

    Returns a dict with keys:
        "acknowledgement" - how they felt / what they did paragraph
        "highlights"      - what went well / soreness context paragraph
        "next_session"    - forward-looking actionable tip paragraph
        "raw"             - original unmodified response (fallback)
    If a section is missing the key maps to "".

    Label matching: each label is first looked for at the start of a line,
    tolerating decorations models often add despite the instructions
    ("HIGHLIGHTS:", "**ACKNOWLEDGEMENT**", "## NEXT SESSION"). Only when no
    such line exists is the first occurrence anywhere in the text used. As a
    result, an all-caps mention inside a paragraph (e.g. "ready for your
    NEXT SESSION") does not truncate the section that contains it.
    """
    # Label -> result key, in the order the prompt requests them
    # (the same labels as DEBRIEF_SECTIONS).
    label_to_key = {
        "ACKNOWLEDGEMENT": "acknowledgement",
        "HIGHLIGHTS": "highlights",
        "NEXT SESSION": "next_session",
    }
    result = {key: "" for key in label_to_key.values()}
    result["raw"] = raw

    # Normalise line endings
    text = raw.replace("\r\n", "\n").strip()

    # (label, header_start, content_start) for every label found
    found = []
    for label in label_to_key:
        heading = re.search(
            r"^[ \t#*]*" + re.escape(label) + r"\b[ \t*]*:?[ \t*]*",
            text,
            flags=re.MULTILINE,
        )
        if heading is not None:
            found.append((label, heading.start(), heading.end()))
        else:
            idx = text.find(label)
            if idx != -1:
                found.append((label, idx, idx + len(label)))

    # Each section runs from the end of its label to the start of the next.
    found.sort(key=lambda item: item[1])
    for i, (label, _, content_start) in enumerate(found):
        content_end = found[i + 1][1] if i + 1 < len(found) else len(text)
        content = text[content_start:content_end].strip()
        # A fallback match leaves any "LABEL:" separator in place; drop it.
        result[label_to_key[label]] = content.lstrip(":").strip()

    return result
# ─────────────────────────────────────────────
# PRE-WORKOUT FUNCTIONS
# ─────────────────────────────────────────────
def predict_pre(
    text: str,
    model,
    tokenizer,
    device: "torch.device",
    max_len: int = 128,
) -> Dict[str, int]:
    """
    Classify a single pre-workout text with PreWorkoutDistilBERT.

    The text is tokenized (padded / truncated to ``max_len`` tokens), both
    tensors are moved to ``device``, and ``model(input_ids, attention_mask)``
    is evaluated with autograd disabled. The model must return exactly six
    logit tensors, in the order: mood, energy, motivation, stress, soreness
    region, soreness severity.

    Returns:
        Dict mapping each head name to the arg-max class index (a plain int)
        for the single input text.
    """
    batch = tokenizer(
        text,
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    ids = batch["input_ids"].to(device)
    mask = batch["attention_mask"].to(device)

    # Pure inference: no gradient bookkeeping needed.
    with torch.no_grad():
        mood, energy, motivation, stress, region, severity = model(ids, mask)

    heads = (
        ("mood", mood),
        ("energy", energy),
        ("motivation", motivation),
        ("stress", stress),
        ("soreness_region", region),
        ("soreness_severity", severity),
    )
    return {name: logits.argmax(dim=1).item() for name, logits in heads}
def decode_pre_predictions(
    preds: Dict[str, int],
    mood_map: Dict[int, str],
    energy_map: Dict[int, str],
    motivation_map: Dict[int, str],
    stress_map: Dict[int, str],
    soreness_region_map: Dict[int, str],
    soreness_severity_map: Dict[int, str],
) -> Dict[str, str]:
    """
    Translate pre-workout class indices into their label strings.

    Each head's index in ``preds`` is looked up in the matching map.
    A missing head or an index absent from its map raises ``KeyError``.
    """
    tables = {
        "mood": mood_map,
        "energy": energy_map,
        "motivation": motivation_map,
        "stress": stress_map,
        "soreness_region": soreness_region_map,
        "soreness_severity": soreness_severity_map,
    }
    return {head: table[preds[head]] for head, table in tables.items()}
def build_pre_prompt(
    bert_labels: Dict[str, str],
    user_text: str,
    workout_type: str,
    duration_minutes: int,
    user_goal: str,
    equipment: List[str],
) -> str:
    """
    Build the Claude prompt that generates a structured pre-workout plan.

    Plain text output only: no Markdown, no asterisks, no symbols.
    Section delimiters are ALL-CAPS labels so parse_workout_plan()
    can split the response into named sections reliably.

    Args:
        bert_labels: decoded classifier labels; must contain the keys
            "mood", "energy", "motivation", "stress", "soreness_region"
            and "soreness_severity" (a missing key raises KeyError).
        user_text: the user's free-text note, embedded verbatim in quotes.
        workout_type: workout type selected in the app.
        duration_minutes: session length selected in the app.
        user_goal: the user's goal selected in the app.
        equipment: available equipment names; an empty (or otherwise
            falsy) value is reported as "bodyweight only".

    Returns:
        The complete prompt string. The instruction text is kept pure ASCII
        so the prompt does not itself contain the special symbols it asks
        the model to avoid.
    """
    region = bert_labels["soreness_region"]
    severity = bert_labels["soreness_severity"]
    # Either half being "none" means there is nothing meaningful to report.
    if region == "none" or severity == "none":
        soreness_str = "no existing soreness"
    else:
        soreness_str = f"{severity} {region} soreness going in"

    equipment_str = (
        ", ".join(equipment)
        if equipment
        else "bodyweight only"
    )

    # A trailing backslash inside the literal joins the next line onto it.
    prompt = f"""You are an expert personal trainer generating a structured pre-workout plan.
The user has described how they feel before training. Use their physical and mental state \
to prescribe the most appropriate session for them right now.
User state (classified from what they wrote):
- Mood: {bert_labels['mood']}
- Energy level: {bert_labels['energy']}
- Motivation: {bert_labels['motivation']}
- Stress level: {bert_labels['stress']}
- Existing soreness: {soreness_str}
Session parameters (selected in app):
- Workout type: {workout_type}
- Duration: {duration_minutes} minutes
- Goal: {user_goal}
- Available equipment: {equipment_str}
What the user wrote before their session:
"{user_text}"
Generate a complete structured workout plan. Use plain text only: no Markdown, \
no asterisks, no bold, no hyphens as bullet points, no special symbols of any kind.
Use the exact section labels below as delimiters. Do not add any text before \
WARM UP or after COACHING NOTE.
WARM UP
[list each exercise on its own line as: Exercise Name | sets/reps or duration]
MAIN WORKOUT
[list each exercise on its own line as: Exercise Name | sets x reps | rest period]
COOL DOWN
[list each exercise on its own line as: Exercise Name | duration]
COACHING NOTE
[2-3 sentences acknowledging their current state, explaining why you prescribed \
this session, and one actionable tip for today]
Important guidelines:
- If energy is low, reduce volume and intensity: fewer sets, lighter loads
- If stress is high, favour controlled movements over maximal effort
- If motivation is low, keep the session achievable and end on a win
- If soreness is present, programme around that muscle group entirely
- If motivation is high and energy is high, push appropriate intensity
- Match total volume to the duration specified"""
    return prompt
# ── ALL-CAPS labels build_pre_prompt() requests and parse_workout_plan() splits on ──
PLAN_SECTIONS = [
    "WARM UP",
    "MAIN WORKOUT",
    "COOL DOWN",
    "COACHING NOTE",
]
def parse_workout_plan(raw: str) -> Dict[str, str]:
    """
    Split a plain-text workout plan response into named sections.

    Returns a dict with keys:
        "warm_up"       - warm up exercises, one per line
        "main_workout"  - main workout exercises, one per line
        "cool_down"     - cool down exercises, one per line
        "coaching_note" - the coaching note paragraph
        "raw"           - original unmodified response (fallback)
    Each exercise line uses pipe-separated fields:
        "Exercise Name | sets x reps | rest period"
    which PreWorkoutView splits on "|" to style each field independently.
    If a section is missing from the response the key maps to "".

    Label matching: each label is first looked for at the start of a line,
    tolerating decorations models often add despite the instructions
    ("COACHING NOTE:", "**WARM UP**", "## MAIN WORKOUT"). Only when no such
    line exists is the first occurrence anywhere in the text used. As a
    result, an all-caps mention inside a section (e.g. "... then COOL DOWN")
    does not cut that section short.
    """
    # Label -> result key, in the order the prompt requests them
    # (the same labels as PLAN_SECTIONS).
    label_to_key = {
        "WARM UP": "warm_up",
        "MAIN WORKOUT": "main_workout",
        "COOL DOWN": "cool_down",
        "COACHING NOTE": "coaching_note",
    }
    result = {key: "" for key in label_to_key.values()}
    result["raw"] = raw

    # Normalise line endings
    text = raw.replace("\r\n", "\n").strip()

    # (label, header_start, content_start) for every label found
    found = []
    for label in label_to_key:
        heading = re.search(
            r"^[ \t#*]*" + re.escape(label) + r"\b[ \t*]*:?[ \t*]*",
            text,
            flags=re.MULTILINE,
        )
        if heading is not None:
            found.append((label, heading.start(), heading.end()))
        else:
            idx = text.find(label)
            if idx != -1:
                found.append((label, idx, idx + len(label)))

    # Each section runs from the end of its label to the start of the next.
    found.sort(key=lambda item: item[1])
    for i, (label, _, content_start) in enumerate(found):
        content_end = found[i + 1][1] if i + 1 < len(found) else len(text)
        content = text[content_start:content_end].strip()
        # A fallback match leaves any "LABEL:" separator in place; drop it.
        result[label_to_key[label]] = content.lstrip(":").strip()

    return result