jflo commited on
Commit
9fbf83a
Β·
verified Β·
1 Parent(s): c334a21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -33
app.py CHANGED
@@ -4,12 +4,12 @@ from contextlib import asynccontextmanager
4
 
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
- from typing import List
8
 
9
- from supabase import create_client, Client
10
  import torch
11
  import torch.nn as nn
12
  from transformers import DistilBertModel, DistilBertTokenizer
 
13
 
14
  # ── Logging setup ─────────────────────────────────────────────────────────────
15
  logging.basicConfig(level=logging.INFO)
@@ -30,29 +30,24 @@ soreness_label_map = {
30
  0: "None", 1: "Mild", 2: "Severe"
31
  }
32
 
 
33
  class MultiHeadDistilBERT(nn.Module):
34
  def __init__(self, num_workout_types, num_moods, num_soreness_levels):
35
  super(MultiHeadDistilBERT, self).__init__()
36
-
37
- # Shared BERT backbone
38
- self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased',token=os.getenv('HF_TOKEN'))
39
- hidden_size = self.bert.config.hidden_size # 768
40
-
41
- self.dropout = nn.Dropout(0.3)
42
-
43
- # Task-specific classification heads
44
- self.workout_head = nn.Linear(hidden_size, num_workout_types)
45
- self.mood_head = nn.Linear(hidden_size, num_moods)
46
  self.soreness_head = nn.Linear(hidden_size, num_soreness_levels)
47
-
48
 
49
  def forward(self, input_ids, attention_mask):
50
- outputs = self.bert(input_ids=input_ids,attention_mask=attention_mask)
51
-
52
- # Use [CLS] token representation. DistilBERT uses last_hidden_state instead of pooler_output like BERT
53
- cls_output = self.dropout(outputs.last_hidden_state[:, 0, :]) # [CLS] token is first token in sequence
54
-
55
- # Each head produces its own logits
56
  return (
57
  self.workout_head(cls_output),
58
  self.mood_head(cls_output),
@@ -72,7 +67,6 @@ state = AppState()
72
  @asynccontextmanager
73
  async def lifespan(app: FastAPI):
74
  # ── Startup ───────────────────────────────────────────────────────────────
75
-
76
  logger.info("Loading model, tokenizer and Supabase client...")
77
 
78
  state.device = torch.device('cpu')
@@ -82,7 +76,6 @@ async def lifespan(app: FastAPI):
82
  'distilbert-base-uncased',
83
  token=os.getenv('HF_TOKEN')
84
  )
85
-
86
  logger.info("Tokenizer loaded")
87
 
88
  # Load model once
@@ -111,9 +104,9 @@ async def lifespan(app: FastAPI):
111
  # ── Shutdown ──────────────────────────────────────────────────────────────
112
  logger.info("Shutting down API")
113
 
114
-
115
  app = FastAPI(lifespan=lifespan)
116
 
 
117
  class PredictRequest(BaseModel):
118
  user_input: str
119
 
@@ -125,16 +118,15 @@ class ExerciseResponse(BaseModel):
125
  notes: str
126
  suitable_moods: List[int]
127
  suitable_soreness: List[int]
128
-
129
  class PredictResponse(BaseModel):
130
- workout: str
131
- workout_conf: float
132
- mood: str
133
- mood_conf: float
134
- soreness: str
135
  soreness_conf: float
136
- exercises: List[ExerciseResponse]
137
-
138
 
139
  # ── Supabase Helper ───────────────────────────────────────────────────────────
140
  def get_suitable_exercises(workout_type: int, mood: int, soreness: int) -> List[ExerciseResponse]:
@@ -152,7 +144,7 @@ def get_suitable_exercises(workout_type: int, mood: int, soreness: int) -> List[
152
  except Exception as e:
153
  logger.error(f"Supabase query failed: {e}")
154
  raise HTTPException(status_code=503, detail="Failed to fetch exercises from database")
155
-
156
  # ─�� Health Check ──────────────────────────────────────────────────────────────
157
  @app.get("/")
158
  def health_check():
@@ -160,7 +152,7 @@ def health_check():
160
  "status": "ok",
161
  "model": "MultiHeadDistilBERT",
162
  "device": str(state.device)
163
- }
164
 
165
  # ── Predict Endpoint ──────────────────────────────────────────────────────────
166
  @app.post("/predict", response_model=PredictResponse)
@@ -232,4 +224,4 @@ def predict(request: PredictRequest):
232
  except Exception as e:
233
  logger.error(f"Prediction failed: {e}")
234
  raise HTTPException(status_code=500, detail="Prediction failed. Please try again.")
235
-
 
4
 
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
+ from typing import List
8
 
 
9
  import torch
10
  import torch.nn as nn
11
  from transformers import DistilBertModel, DistilBertTokenizer
12
+ from supabase import create_client, Client
13
 
14
  # ── Logging setup ─────────────────────────────────────────────────────────────
15
  logging.basicConfig(level=logging.INFO)
 
30
  0: "None", 1: "Mild", 2: "Severe"
31
  }
32
 
33
+ # ── Model Definition ──────────────────────────────────────────────────────────
34
  class MultiHeadDistilBERT(nn.Module):
35
  def __init__(self, num_workout_types, num_moods, num_soreness_levels):
36
  super(MultiHeadDistilBERT, self).__init__()
37
+
38
+ self.bert = DistilBertModel.from_pretrained(
39
+ 'distilbert-base-uncased',
40
+ token=os.getenv('HF_TOKEN')
41
+ )
42
+ hidden_size = self.bert.config.hidden_size
43
+ self.dropout = nn.Dropout(0.3)
44
+ self.workout_head = nn.Linear(hidden_size, num_workout_types)
45
+ self.mood_head = nn.Linear(hidden_size, num_moods)
 
46
  self.soreness_head = nn.Linear(hidden_size, num_soreness_levels)
 
47
 
48
  def forward(self, input_ids, attention_mask):
49
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
50
+ cls_output = self.dropout(outputs.last_hidden_state[:, 0, :])
 
 
 
 
51
  return (
52
  self.workout_head(cls_output),
53
  self.mood_head(cls_output),
 
67
  @asynccontextmanager
68
  async def lifespan(app: FastAPI):
69
  # ── Startup ───────────────────────────────────────────────────────────────
 
70
  logger.info("Loading model, tokenizer and Supabase client...")
71
 
72
  state.device = torch.device('cpu')
 
76
  'distilbert-base-uncased',
77
  token=os.getenv('HF_TOKEN')
78
  )
 
79
  logger.info("Tokenizer loaded")
80
 
81
  # Load model once
 
104
  # ── Shutdown ──────────────────────────────────────────────────────────────
105
  logger.info("Shutting down API")
106
 
 
107
  app = FastAPI(lifespan=lifespan)
108
 
109
+ # ── Schemas ───────────────────────────────────────────────────────────────────
110
  class PredictRequest(BaseModel):
111
  user_input: str
112
 
 
118
  notes: str
119
  suitable_moods: List[int]
120
  suitable_soreness: List[int]
121
+
122
  class PredictResponse(BaseModel):
123
+ workout: str
124
+ workout_conf: float
125
+ mood: str
126
+ mood_conf: float
127
+ soreness: str
128
  soreness_conf: float
129
+ exercises: List[ExerciseResponse]
 
130
 
131
  # ── Supabase Helper ───────────────────────────────────────────────────────────
132
  def get_suitable_exercises(workout_type: int, mood: int, soreness: int) -> List[ExerciseResponse]:
 
144
  except Exception as e:
145
  logger.error(f"Supabase query failed: {e}")
146
  raise HTTPException(status_code=503, detail="Failed to fetch exercises from database")
147
+
148
  # ─�� Health Check ──────────────────────────────────────────────────────────────
149
  @app.get("/")
150
  def health_check():
 
152
  "status": "ok",
153
  "model": "MultiHeadDistilBERT",
154
  "device": str(state.device)
155
+ }
156
 
157
  # ── Predict Endpoint ──────────────────────────────────────────────────────────
158
  @app.post("/predict", response_model=PredictResponse)
 
224
  except Exception as e:
225
  logger.error(f"Prediction failed: {e}")
226
  raise HTTPException(status_code=500, detail="Prediction failed. Please try again.")
227
+