jflo committed on
Commit
71b68a0
·
1 Parent(s): 47dcaa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -79
app.py CHANGED
@@ -1,100 +1,223 @@
1
  """
2
- model.py — DistilBERT multi-head model definition and loader
 
 
3
  """
4
 
 
 
 
 
 
5
  import torch
6
- import torch.nn as nn
7
- from transformers import DistilBertModel, DistilBertTokenizer
8
  import logging
9
 
 
 
 
 
 
 
 
 
10
  logger = logging.getLogger(__name__)
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
class MultiHeadDistilBERT(nn.Module):
    """
    One DistilBERT encoder feeding four independent classification heads.

    Heads and their default class counts:
        - mood        (8)
        - exertion    (3)
        - soreness    (17 — combined region + severity)
        - completion  (2)
    """

    def __init__(
        self,
        num_moods: int = 8,
        num_exertion_levels: int = 3,
        num_soreness_classes: int = 17,
        num_completion_statuses: int = 2,
    ):
        super().__init__()

        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        width = self.bert.config.hidden_size  # 768

        self.dropout = nn.Dropout(0.3)
        self.head_dropout = nn.Dropout(0.1)

        # Single linear layer for the easy tasks.
        self.mood_head = nn.Linear(width, num_moods)
        self.completion_head = nn.Linear(width, num_completion_statuses)

        # Two-layer heads for the harder tasks; soreness (17 classes) gets
        # the widest hidden layer.
        self.exertion_head = self._mlp_head(width, 128, 0.2, num_exertion_levels)
        self.soreness_head = self._mlp_head(width, 256, 0.3, num_soreness_classes)

    @staticmethod
    def _mlp_head(in_features, hidden_features, drop_p, out_features):
        """Build a Linear -> ReLU -> Dropout -> Linear classification head."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, input_ids, attention_mask):
        """Return (mood, exertion, soreness, completion) head outputs."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Only the hidden state at sequence position 0 is fed to the heads.
        first_token = encoded.last_hidden_state[:, 0, :]
        features = self.head_dropout(self.dropout(first_token))

        mood = self.mood_head(features)
        exertion = self.exertion_head(features)
        soreness = self.soreness_head(features)
        completion = self.completion_head(features)
        return (mood, exertion, soreness, completion)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
def load_model(
    model_path: str,
    device: torch.device,
    num_moods: int = 8,
    num_exertion_levels: int = 3,
    num_soreness_classes: int = 17,
    num_completion_statuses: int = 2,
):
    """
    Build the multi-head model, restore its saved weights onto ``device``
    and switch it to eval mode.

    Returns (model, tokenizer).
    """
    logger.info(f"Loading model weights from: {model_path}")

    head_sizes = {
        "num_moods": num_moods,
        "num_exertion_levels": num_exertion_levels,
        "num_soreness_classes": num_soreness_classes,
        "num_completion_statuses": num_completion_statuses,
    }
    model = MultiHeadDistilBERT(**head_sizes)

    weights = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(weights)
    model.to(device)
    model.eval()

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

    logger.info("Model loaded and set to eval mode.")
    return model, tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Workout Coach — FastAPI Inference App
3
+ Runs DistilBERT classification + Claude debrief generation
4
+ Designed for Hugging Face Spaces with Docker
5
  """
6
 
7
+ from fastapi import FastAPI, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel, Field
10
+ from contextlib import asynccontextmanager
11
+ from typing import Optional, Dict
12
  import torch
13
+ import anthropic
14
+ import os
15
  import logging
16
 
17
+ from model import MultiHeadDistilBERT, load_model
18
+ from inference import predict, decode_predictions, build_prompt
19
+
20
+ # ─────────────────────────────────────────────
21
+ # LOGGING
22
+ # ─────────────────────────────────────────────
23
+
24
+ logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
27
+ # ─────────────────────────────────────────────
28
+ # LABEL MAPS
29
+ # ─────────────────────────────────────────────
30
+
31
# Class-index -> label lookups for each classification head.
# The position of a label in its tuple is the class index it decodes.

MOOD_MAP = dict(enumerate((
    "accomplished", "anxious", "distracted",
    "energized", "fatigued", "frustrated",
    "neutral", "positive",
)))

EXERTION_MAP = dict(enumerate(("low", "moderate", "high")))

COMPLETION_MAP = dict(enumerate(("partial", "full")))

SORENESS_MAP = dict(enumerate((
    "none",
    "biceps_mild", "biceps_moderate",
    "back_mild", "back_moderate", "back_severe",
    "chest_mild", "chest_moderate", "chest_severe",
    "legs_mild", "legs_moderate", "legs_severe",
    "shoulder_mild", "shoulder_moderate", "shoulder_severe",
    "triceps_mild", "triceps_moderate",
)))
50
+
51
+ # ─────────────────────────────────────────────
52
+ # APP STATE β€” model loaded once at startup
53
+ # ─────────────────────────────────────────────
54
+
55
# Process-wide registry filled at startup: "model", "tokenizer", "device"
# and "anthropic_client". Route handlers read from it on every request.
app_state = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Load the model, tokenizer and Anthropic client once at startup and
    release them at shutdown.

    The weights path comes from the MODEL_PATH environment variable
    (default "best_overall_model.pt"); the Anthropic key comes from
    ANTHROPIC_API_KEY. Any exception raised while loading propagates and
    prevents the application from starting.
    """
    logger.info("Loading DistilBERT model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    model, tokenizer = load_model(
        model_path=os.getenv("MODEL_PATH", "best_overall_model.pt"),
        device=device,
    )

    app_state["model"] = model
    app_state["tokenizer"] = tokenizer
    app_state["device"] = device

    # Anthropic client — reads ANTHROPIC_API_KEY from environment
    app_state["anthropic_client"] = anthropic.Anthropic(
        api_key=os.getenv("ANTHROPIC_API_KEY")
    )

    logger.info("Model and clients loaded successfully.")
    try:
        yield
    finally:
        # Runs on normal shutdown and also when an exception or task
        # cancellation is raised at the yield point, so the references to
        # the model and client are always dropped.
        app_state.clear()
        logger.info("App shutdown complete.")
 
84
 
85
+
86
# Application object; `lifespan` performs model loading and teardown.
app = FastAPI(
    lifespan=lifespan,
    title="Workout Coach Inference API",
    description="DistilBERT classification + Claude debrief generation",
    version="1.0.0",
)

# CORS is fully open: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
)
99
+
100
+ # ─────────────────────────────────────────────
101
+ # REQUEST / RESPONSE SCHEMAS
102
+ # ─────────────────────────────────────────────
103
+
104
class SessionRequest(BaseModel):
    """
    Request body shared by the /classify endpoints.

    Field examples are declared with `examples=[...]`. The singular
    `example=` keyword is not a `Field` parameter in Pydantic v2; it is
    treated as a deprecated extra keyword argument and triggers a
    deprecation warning.
    """

    # Free-text input from the user — fed into DistilBERT
    user_text: str = Field(
        ...,
        min_length=5,
        max_length=500,
        examples=["That was really tough, chest is killing me but I feel accomplished."],
    )

    # UI form fields — collected separately in the app
    duration_minutes: int = Field(..., ge=1, le=300, examples=[45])
    workout_type: str = Field(..., examples=["upper_body_push"])
    user_goal: str = Field(..., examples=["muscle_gain"])

    # Optional — whether to generate the Claude debrief
    generate_debrief: bool = Field(default=True)
116
 
117
 
118
class BertLabels(BaseModel):
    """Decoded string label for each of the four classification heads."""

    mood: str
    exertion: str
    soreness: str
    completion: str
123
+
124
+
125
class SessionResponse(BaseModel):
    """Response of POST /classify: the labels plus an optional debrief."""

    bert_labels: BertLabels
    # None when no debrief was requested or when generating it failed.
    debrief: Optional[str] = None
128
+
129
+
130
class HealthResponse(BaseModel):
    """Response of GET /health."""

    # An empty protected_namespaces tuple silences Pydantic's warning that
    # the `model_loaded` field collides with its reserved "model_" prefix.
    model_config = dict(protected_namespaces=())

    status: str
    device: str
    model_loaded: bool
138
+
139
+ # ─────────────────────────────────────────────
140
+ # ROUTES
141
+ # ─────────────────────────────────────────────
142
+
143
@app.get("/health", response_model=HealthResponse)
def health():
    """Health check: report the active device and whether a model is registered."""
    active_device = app_state.get("device", "unknown")
    has_model = "model" in app_state
    return dict(status="ok", device=str(active_device), model_loaded=has_model)
151
+
152
+
153
@app.post("/classify", response_model=SessionResponse)
def classify_session(req: SessionRequest):
    """
    Run DistilBERT inference on `req.user_text` and, when
    `req.generate_debrief` is true, ask Claude for a debrief built from the
    classified labels and the session form data.

    Returns:
        SessionResponse with the decoded labels; `debrief` is None when it
        was not requested or could not be generated.

    Raises:
        HTTPException(500): inference or label decoding failed.
    """
    model = app_state["model"]
    tokenizer = app_state["tokenizer"]
    device = app_state["device"]

    # ── Steps 1-2: DistilBERT inference, then integer labels → strings ──
    # Decoding shares the handler with inference so that a decode failure
    # is logged and reported the same way as on /classify/labels-only.
    try:
        raw_preds = predict(req.user_text, model, tokenizer, device)
        bert_labels = decode_predictions(
            raw_preds, MOOD_MAP, EXERTION_MAP, SORENESS_MAP, COMPLETION_MAP
        )
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Inference error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Inference failed: {str(e)}"
        ) from e

    # ── Step 3: Optionally generate Claude debrief ──
    debrief = None
    if req.generate_debrief:
        client = app_state["anthropic_client"]
        try:
            # Prompt construction is inside the try as well: any debrief
            # failure is non-fatal, not only the API call.
            prompt = build_prompt(
                bert_labels=bert_labels,
                user_text=req.user_text,
                duration_minutes=req.duration_minutes,
                workout_type=req.workout_type,
                user_goal=req.user_goal,
            )
            message = client.messages.create(
                model="claude-sonnet-4-6",
                max_tokens=400,
                messages=[{"role": "user", "content": prompt}],
            )
            debrief = message.content[0].text
        except Exception as e:
            # Debrief failure is non-fatal — return labels without debrief
            logger.exception("Claude API error: %s", e)
            debrief = None

    return SessionResponse(
        bert_labels=BertLabels(**bert_labels),
        debrief=debrief,
    )
203
 
 
204
 
205
@app.post("/classify/labels-only", response_model=BertLabels)
def classify_labels_only(req: SessionRequest):
    """
    Run only DistilBERT inference on `req.user_text`. Skips Claude.
    Useful for storing labels to DB without generating a debrief yet.

    Raises:
        HTTPException(500): inference, decoding or response construction
            failed.
    """
    model = app_state["model"]
    tokenizer = app_state["tokenizer"]
    device = app_state["device"]

    try:
        raw_preds = predict(req.user_text, model, tokenizer, device)
        bert_labels = decode_predictions(
            raw_preds, MOOD_MAP, EXERTION_MAP, SORENESS_MAP, COMPLETION_MAP
        )
        labels = BertLabels(**bert_labels)
    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the cause.
        logger.exception("Inference error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Inference failed: {str(e)}"
        ) from e
    return labels