jflo commited on
Commit
ae68031
Β·
verified Β·
1 Parent(s): a6ef2d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -96
app.py CHANGED
@@ -9,33 +9,23 @@ import torch
9
  import torch.nn as nn
10
  from transformers import DistilBertModel, DistilBertTokenizer
11
 
12
- app = FastAPI()
13
-
14
- device = torch.device('cpu') # Hugging Face Space with no GPU
15
 
 
16
  workout_label_map = {
17
- 0: "Chest",
18
- 1: "Back",
19
- 2: "Legs",
20
- 3: "Shoulders",
21
- 4: "Arms",
22
- 5: "Core",
23
- 6: "Full Body",
24
- 7: "Cardio"
25
  }
26
 
27
  mood_label_map = {
28
- 0: "Energized",
29
- 1: "Tired",
30
- 2: "Stressed",
31
- 3: "Motivated",
32
- 4: "Neutral"
33
  }
34
 
35
  soreness_label_map = {
36
- 0: "None",
37
- 1: "Mild",
38
- 2: "Severe"
39
  }
40
 
41
  class MultiHeadDistilBERT(nn.Module):
@@ -46,12 +36,13 @@ class MultiHeadDistilBERT(nn.Module):
46
  self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased',token=os.getenv('HF_TOKEN'))
47
  hidden_size = self.bert.config.hidden_size # 768
48
 
 
 
49
  # Task-specific classification heads
50
  self.workout_head = nn.Linear(hidden_size, num_workout_types)
51
  self.mood_head = nn.Linear(hidden_size, num_moods)
52
  self.soreness_head = nn.Linear(hidden_size, num_soreness_levels)
53
 
54
- self.dropout = nn.Dropout(0.3)
55
 
56
  def forward(self, input_ids, attention_mask):
57
  outputs = self.bert(input_ids=input_ids,attention_mask=attention_mask)
@@ -60,8 +51,65 @@ class MultiHeadDistilBERT(nn.Module):
60
  cls_output = self.dropout(outputs.last_hidden_state[:, 0, :]) # [CLS] token is first token in sequence
61
 
62
  # Each head produces its own logits
63
- return (self.workout_head(cls_output), self.mood_head(cls_output), self.soreness_head(cls_output))
64
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  class PredictRequest(BaseModel):
66
  user_input: str
67
 
@@ -84,83 +132,100 @@ class PredictResponse(BaseModel):
84
  exercises: List[ExerciseResponse]
85
 
86
 
87
- def get_suitable_exercises(workout_type: int, mood: int, soreness: int):
88
- supabase = create_client(os.getenv('SUPA_URL'), os.getenv('SUPA_KEY'))
89
-
90
- supabase_response = (
91
- supabase.table('exerciseai')
92
- .select('*')
93
- .eq('workout_type', workout_type)
94
- .contains('suitable_moods', [str(mood)])
95
- .contains('suitable_soreness', [str(soreness)])
96
- .execute()
97
- )
98
-
99
- # Parse Supabase response into ExerciseResponse objects
100
- exercises = [ExerciseResponse(**exercise) for exercise in supabase_response.data]
101
- return exercises
 
102
 
 
103
  @app.get("/")
104
- def greet_json():
105
- return {"Hello": "World!"}
 
 
 
 
106
 
107
- @app.post("/predict",response_model=PredictResponse)
 
108
  def predict(request: PredictRequest):
109
-
110
- model = MultiHeadDistilBERT(
111
- num_workout_types=8,
112
- num_moods=5,
113
- num_soreness_levels=3
114
- )
115
-
116
- model.load_state_dict(torch.load('best_DistilBERT_model.pt', map_location=torch.device('cpu')))
117
- model.to(device)
118
- model.eval()
119
-
120
- tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased',token=os.getenv('HF_TOKEN'))
121
-
122
- encoding = tokenizer(
123
- request.user_input, # The single string the user types
124
- max_length=128,
125
- padding='max_length',
126
- truncation=True,
127
- return_tensors='pt'
128
- )
129
-
130
- input_ids = encoding['input_ids'].to(device)
131
- attention_mask = encoding['attention_mask'].to(device)
132
 
133
- with torch.no_grad():
134
- workout_logits, mood_logits, soreness_logits = model(input_ids, attention_mask)
135
-
136
- # Convert logits to probabilities
137
- workout_probs = torch.softmax(workout_logits, dim=1)
138
- mood_probs = torch.softmax(mood_logits, dim=1)
139
- soreness_probs = torch.softmax(soreness_logits, dim=1)
140
-
141
- # Get predicted class and confidence percentage for each head
142
- workout_conf, workout_pred = workout_probs.max(dim=1)
143
- mood_conf, mood_pred = mood_probs.max(dim=1)
144
- soreness_conf, soreness_pred = soreness_probs.max(dim=1)
145
-
146
- # Map predictions to labels
147
- predicted_workout = workout_label_map[workout_logits.argmax().item()]
148
- predicted_mood = mood_label_map[mood_logits.argmax().item()]
149
- predicted_soreness = soreness_label_map[soreness_logits.argmax().item()]
150
-
151
- # Fetch suitable exercises from Supabase
152
- suitable_exercises = get_suitable_exercises(
153
- workout_type = workout_logits.argmax().item(),
154
- mood = mood_logits.argmax().item(),
155
- soreness = soreness_logits.argmax().item()
156
- )
157
- return PredictResponse(
158
- workout = predicted_workout,
159
- workout_conf = round(workout_conf.item() * 100, 1),
160
- mood = predicted_mood,
161
- mood_conf = round(mood_conf.item() * 100, 1),
162
- soreness = predicted_soreness,
163
- soreness_conf = round(soreness_conf.item() * 100, 1),
164
- exercises = suitable_exercises
165
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
 
9
  import torch.nn as nn
10
  from transformers import DistilBertModel, DistilBertTokenizer
11
 
12
+ # ── Logging setup ─────────────────────────────────────────────────────────────
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
 
16
+ # ── Label Maps ────────────────────────────────────────────────────────────────
17
  workout_label_map = {
18
+ 0: "Chest", 1: "Back", 2: "Legs", 3: "Shoulders",
19
+ 4: "Arms", 5: "Core", 6: "Full Body", 7: "Cardio"
 
 
 
 
 
 
20
  }
21
 
22
  mood_label_map = {
23
+ 0: "Energized", 1: "Tired", 2: "Stressed",
24
+ 3: "Motivated", 4: "Neutral"
 
 
 
25
  }
26
 
27
  soreness_label_map = {
28
+ 0: "None", 1: "Mild", 2: "Severe"
 
 
29
  }
30
 
31
  class MultiHeadDistilBERT(nn.Module):
 
36
  self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased',token=os.getenv('HF_TOKEN'))
37
  hidden_size = self.bert.config.hidden_size # 768
38
 
39
+ self.dropout = nn.Dropout(0.3)
40
+
41
  # Task-specific classification heads
42
  self.workout_head = nn.Linear(hidden_size, num_workout_types)
43
  self.mood_head = nn.Linear(hidden_size, num_moods)
44
  self.soreness_head = nn.Linear(hidden_size, num_soreness_levels)
45
 
 
46
 
47
  def forward(self, input_ids, attention_mask):
48
  outputs = self.bert(input_ids=input_ids,attention_mask=attention_mask)
 
51
  cls_output = self.dropout(outputs.last_hidden_state[:, 0, :]) # [CLS] token is first token in sequence
52
 
53
  # Each head produces its own logits
54
+ return (
55
+ self.workout_head(cls_output),
56
+ self.mood_head(cls_output),
57
+ self.soreness_head(cls_output)
58
+ )
59
+
60
+ # ── App State β€” loaded once at startup ───────────────────────────────────────
61
+ class AppState:
62
+ model: MultiHeadDistilBERT = None
63
+ tokenizer: DistilBertTokenizer = None
64
+ supabase: Client = None
65
+ device: torch.device = None
66
+
67
+ state = AppState()
68
+
69
+ # ── Lifespan β€” runs once on startup and shutdown ──────────────────────────────
70
+ @asynccontextmanager
71
+ async def lifespan(app: FastAPI):
72
+ # ── Startup ───────────────────────────────────────────────────────────────
73
+ logger.info("Loading model, tokenizer and Supabase client...")
74
+
75
+ state.device = torch.device('cpu')
76
+
77
+ # Load tokenizer once
78
+ state.tokenizer = DistilBertTokenizer.from_pretrained(
79
+ 'distilbert-base-uncased',
80
+ token=os.getenv('HF_TOKEN')
81
+ )
82
+ logger.info("Tokenizer loaded")
83
+
84
+ # Load model once
85
+ state.model = MultiHeadDistilBERT(
86
+ num_workout_types=8,
87
+ num_moods=5,
88
+ num_soreness_levels=3
89
+ )
90
+ state.model.load_state_dict(
91
+ torch.load('best_DistilBERT_model.pt', map_location=state.device)
92
+ )
93
+ state.model.to(state.device)
94
+ state.model.eval()
95
+ logger.info("Model loaded")
96
+
97
+ # Create Supabase client once
98
+ state.supabase = create_client(
99
+ os.getenv('SUPA_URL'),
100
+ os.getenv('SUPA_KEY')
101
+ )
102
+ logger.info("Supabase client created")
103
+ logger.info("Startup complete β€” API is ready")
104
+
105
+ yield # ← API runs here
106
+
107
+ # ── Shutdown ──────────────────────────────────────────────────────────────
108
+ logger.info("Shutting down API")
109
+
110
+
111
+ app = FastAPI(lifespan=lifespan)
112
+
113
  class PredictRequest(BaseModel):
114
  user_input: str
115
 
 
132
  exercises: List[ExerciseResponse]
133
 
134
 
135
+ # ── Supabase Helper ───────────────────────────────────────────────────────────
136
+ def get_suitable_exercises(workout_type: int, mood: int, soreness: int) -> List[ExerciseResponse]:
137
+ try:
138
+ response = (
139
+ state.supabase.table('exerciseai')
140
+ .select('*')
141
+ .eq('workout_type', workout_type)
142
+ .contains('suitable_moods', [mood])
143
+ .contains('suitable_soreness', [soreness])
144
+ .execute()
145
+ )
146
+ return [ExerciseResponse(**exercise) for exercise in response.data]
147
+
148
+ except Exception as e:
149
+ logger.error(f"Supabase query failed: {e}")
150
+ raise HTTPException(status_code=503, detail="Failed to fetch exercises from database")
151
 
152
+ # ── Health Check ──────────────────────────────────────────────────────────────
153
  @app.get("/")
154
+ def health_check():
155
+ return {
156
+ "status": "ok",
157
+ "model": "MultiHeadDistilBERT",
158
+ "device": str(state.device)
159
+ }
160
 
161
+ # ── Predict Endpoint ──────────────────────────────────────────────────────────
162
+ @app.post("/predict", response_model=PredictResponse)
163
  def predict(request: PredictRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ # ── Input validation ──────────────────────────────────────────────────────
166
+ if not request.user_input.strip():
167
+ raise HTTPException(status_code=400, detail="user_input cannot be empty")
168
+
169
+ try:
170
+ # ── Tokenize ──────────────────────────────────────────────────────────
171
+ encoding = state.tokenizer(
172
+ request.user_input,
173
+ max_length=128,
174
+ padding='max_length',
175
+ truncation=True,
176
+ return_tensors='pt'
177
+ )
178
+
179
+ input_ids = encoding['input_ids'].to(state.device)
180
+ attention_mask = encoding['attention_mask'].to(state.device)
181
+
182
+ # ── Inference ─────────────────────────────────────��───────────────────
183
+ with torch.no_grad():
184
+ workout_logits, mood_logits, soreness_logits = state.model(
185
+ input_ids, attention_mask
186
+ )
187
+
188
+ # ── Softmax + confidence ──────────────────────────────────────────────
189
+ workout_probs = torch.softmax(workout_logits, dim=1)
190
+ mood_probs = torch.softmax(mood_logits, dim=1)
191
+ soreness_probs = torch.softmax(soreness_logits, dim=1)
192
+
193
+ workout_conf, workout_pred = workout_probs.max(dim=1)
194
+ mood_conf, mood_pred = mood_probs.max(dim=1)
195
+ soreness_conf, soreness_pred = soreness_probs.max(dim=1)
196
+
197
+ # ── Map to labels β€” reuse pred variables, no redundant argmax ─────────
198
+ predicted_workout = workout_label_map[workout_pred.item()]
199
+ predicted_mood = mood_label_map[mood_pred.item()]
200
+ predicted_soreness = soreness_label_map[soreness_pred.item()]
201
+
202
+ logger.info(
203
+ f"Prediction β€” Workout: {predicted_workout} ({workout_conf.item()*100:.1f}%) | "
204
+ f"Mood: {predicted_mood} ({mood_conf.item()*100:.1f}%) | "
205
+ f"Soreness: {predicted_soreness} ({soreness_conf.item()*100:.1f}%)"
206
+ )
207
+
208
+ # ── Fetch exercises ───────────────────────────────────────────────────
209
+ suitable_exercises = get_suitable_exercises(
210
+ workout_type = workout_pred.item(),
211
+ mood = mood_pred.item(),
212
+ soreness = soreness_pred.item()
213
+ )
214
+
215
+ return PredictResponse(
216
+ workout = predicted_workout,
217
+ workout_conf = round(workout_conf.item() * 100, 1),
218
+ mood = predicted_mood,
219
+ mood_conf = round(mood_conf.item() * 100, 1),
220
+ soreness = predicted_soreness,
221
+ soreness_conf = round(soreness_conf.item() * 100, 1),
222
+ exercises = suitable_exercises
223
+ )
224
+
225
+ except HTTPException:
226
+ raise # ← re-raise HTTP exceptions from get_suitable_exercises
227
+
228
+ except Exception as e:
229
+ logger.error(f"Prediction failed: {e}")
230
+ raise HTTPException(status_code=500, detail="Prediction failed. Please try again.")
231