File size: 10,510 Bytes
7c23ee3
9a95b8a
f2119eb
7c23ee3
f2119eb
83e8c6c
9fbf83a
eecfa33
7c23ee3
 
036993c
9fbf83a
7a1bfc8
ae68031
9a95b8a
 
7c23ee3
ae68031
7c23ee3
ae68031
 
7c23ee3
 
254343a
ae68031
 
7c23ee3
 
 
ae68031
7c23ee3
 
9fbf83a
036993c
254343a
036993c
9fbf83a
 
 
 
 
 
 
 
 
7c23ee3
 
036993c
9fbf83a
 
ae68031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a95b8a
ae68031
 
 
 
 
 
 
 
9a95b8a
ae68031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fbf83a
e90a0ff
 
 
bd8428e
 
 
 
 
 
 
 
9fbf83a
e90a0ff
9fbf83a
 
 
 
 
254343a
9fbf83a
a5c872c
ad20b81
 
 
 
 
ae68031
 
 
ad20b81
 
ff9254c
ae68031
 
 
 
ad20b81
 
ae68031
 
ad20b81
 
 
ae68031
 
 
 
 
9fbf83a
ae68031
7a1bfc8
ae68031
 
 
 
 
9fbf83a
7a1bfc8
ae68031
 
e90a0ff
ad20b81
ae68031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad20b81
ae68031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fbf83a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import logging
import os
from contextlib import asynccontextmanager
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

import torch
import torch.nn as nn
from supabase import Client, create_client
from transformers import DistilBertModel, DistilBertTokenizer

# ── Logging setup ─────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ── Label Maps ────────────────────────────────────────────────────────────────
# Integer class index -> human-readable label, one map per model head.
# Sizes must match the head dimensions used when constructing the model
# (8 workout types, 5 moods, 3 soreness levels).
workout_label_map = {
    0: "Chest",
    1: "Back",
    2: "Legs",
    3: "Shoulders",
    4: "Arms",
    5: "Core",
    6: "Full Body",
    7: "Cardio",
}

mood_label_map = {
    0: "Energized",
    1: "Tired",
    2: "Stressed",
    3: "Motivated",
    4: "Neutral",
}

soreness_label_map = {
    0: "None",
    1: "Mild",
    2: "Severe",
}

# ── Model Definition ──────────────────────────────────────────────────────────
class MultiHeadDistilBERT(nn.Module):
    """DistilBERT encoder with three independent linear classification heads.

    A single shared encoder produces the [CLS] representation; separate
    heads score workout type, mood, and soreness level from it.
    """

    def __init__(self, num_workout_types, num_moods, num_soreness_levels):
        super().__init__()

        # Shared pretrained encoder (HF_TOKEN used for gated/private access).
        self.bert = DistilBertModel.from_pretrained(
            'distilbert-base-uncased',
            token=os.getenv('HF_TOKEN')
        )
        width = self.bert.config.hidden_size

        self.dropout = nn.Dropout(0.3)
        # One linear head per prediction task, all fed from the same embedding.
        self.workout_head = nn.Linear(width, num_workout_types)
        self.mood_head = nn.Linear(width, num_moods)
        self.soreness_head = nn.Linear(width, num_soreness_levels)

    def forward(self, input_ids, attention_mask):
        """Return (workout_logits, mood_logits, soreness_logits) for a batch."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token ([CLS]) embedding as the pooled representation.
        pooled = self.dropout(encoded.last_hidden_state[:, 0, :])

        workout_logits = self.workout_head(pooled)
        mood_logits = self.mood_head(pooled)
        soreness_logits = self.soreness_head(pooled)
        return workout_logits, mood_logits, soreness_logits

# ── App State β€” loaded once at startup ───────────────────────────────────────
class AppState:
    """Singletons initialized once in the FastAPI lifespan and shared by all requests.

    Attributes are ``None`` until startup completes, hence the ``Optional``
    annotations (the original ``X = None`` annotations were mistyped: ``None``
    is not a valid value of a bare ``MultiHeadDistilBERT`` etc.).
    """
    # All four are populated by lifespan() before the API serves traffic.
    model:     Optional[MultiHeadDistilBERT] = None
    tokenizer: Optional[DistilBertTokenizer] = None
    supabase:  Optional[Client]              = None
    device:    Optional[torch.device]        = None

state = AppState()

# ── Lifespan β€” runs once on startup and shutdown ──────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    # ── Startup ───────────────────────────────────────────────────────────────
    logger.info("Loading model, tokenizer and Supabase client...")

    state.device = torch.device('cpu')

    # Load tokenizer once
    state.tokenizer = DistilBertTokenizer.from_pretrained(
        'distilbert-base-uncased',
        token=os.getenv('HF_TOKEN')
    )
    logger.info("Tokenizer loaded")

    # Load model once
    state.model = MultiHeadDistilBERT(
        num_workout_types=8,
        num_moods=5,
        num_soreness_levels=3
    )
    state.model.load_state_dict(
        torch.load('best_DistilBERT_model.pt', map_location=state.device)
    )
    state.model.to(state.device)
    state.model.eval()
    logger.info("Model loaded")

    # Create Supabase client once
    state.supabase = create_client(
        os.getenv('SUPA_URL'),
        os.getenv('SUPA_KEY')
    )
    logger.info("Supabase client created")
    logger.info("Startup complete β€” API is ready")

    yield  # ← API runs here

    # ── Shutdown ──────────────────────────────────────────────────────────────
    logger.info("Shutting down API")

app = FastAPI(lifespan=lifespan)

# ── Schemas ───────────────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    """Request body for POST /predict."""
    # Free-form text describing the user's current state; must be non-blank.
    user_input: str

class ExerciseResponse(BaseModel):
    """One exercise row from the 'exerciseai' table, as returned to clients."""
    id: int
    name: str
    # Integer class index matching workout_label_map.
    workout_type: int
    difficulty: str
    notes: str
    # Class indices (mood/soreness) for which this exercise is appropriate.
    suitable_moods: List[int]
    suitable_soreness: List[int]

class PredictResponse(BaseModel):
    """Response body for POST /predict: three labeled predictions plus exercises."""
    workout: str
    # Confidences are percentages rounded to one decimal place (0.0–100.0).
    workout_conf: float
    mood: str
    mood_conf: float
    soreness: str
    soreness_conf: float
    exercises: List[ExerciseResponse]

def format_pg_array(values: list) -> str:
    """Render a Python list as a PostgreSQL array literal, e.g. [1, 2] -> '{1,2}'."""
    inner = ','.join(map(str, values))
    return f'{{{inner}}}'

    
# ── Supabase Helper ───────────────────────────────────────────────────────────
def get_suitable_exercises(workout_type: int, mood: int, soreness: int) -> List[ExerciseResponse]:
    """Fetch exercises matching the predicted workout type, mood and soreness.

    Queries the 'exerciseai' table for rows whose workout_type equals the
    prediction and whose suitable_moods / suitable_soreness arrays contain
    the predicted mood / soreness indices.

    Raises:
        HTTPException: 503 if the Supabase query fails for any reason.
    """
    try:
        logger.info(f"Querying exercises — workout_type: {workout_type}, mood: {mood}, soreness: {soreness}")

        # Fix: removed leftover debug print("DATA").
        response = (
            state.supabase.table('exerciseai')
            .select('*')
            .eq('workout_type', workout_type)
            # contains() needs PostgreSQL array-literal syntax, hence format_pg_array.
            .contains('suitable_moods',    format_pg_array([mood]))
            .contains('suitable_soreness', format_pg_array([soreness]))
            .execute()
        )

        logger.info(f"Supabase returned {len(response.data)} exercises")

        return [ExerciseResponse(**exercise) for exercise in response.data]

    except Exception as e:
        logger.error(f"Supabase query failed: {e}")
        # Chain the cause so the original error survives in tracebacks/logs.
        raise HTTPException(status_code=503, detail="Failed to fetch exercises from database") from e

# ── Health Check ──────────────────────────────────────────────────────────────
@app.get("/")
def health_check():
    return {
        "status":  "ok",
        "model":   "MultiHeadDistilBERT",
        "device":  str(state.device)
    }

# ── Predict Endpoint ──────────────────────────────────────────────────────────
@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    print("HERE")
    # ── Input validation ──────────────────────────────────────────────────────
    if not request.user_input.strip():
        raise HTTPException(status_code=400, detail="user_input cannot be empty")

    try:
        # ── Tokenize ──────────────────────────────────────────────────────────
        encoding = state.tokenizer(
            request.user_input,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        input_ids      = encoding['input_ids'].to(state.device)
        attention_mask = encoding['attention_mask'].to(state.device)

        # ── Inference ─────────────────────────────────────────────────────────
        with torch.no_grad():
            workout_logits, mood_logits, soreness_logits = state.model(
                input_ids, attention_mask
            )
        print("HERE2")
        # ── Softmax + confidence ──────────────────────────────────────────────
        workout_probs  = torch.softmax(workout_logits,  dim=1)
        mood_probs     = torch.softmax(mood_logits,     dim=1)
        soreness_probs = torch.softmax(soreness_logits, dim=1)

        workout_conf,  workout_pred  = workout_probs.max(dim=1)
        mood_conf,     mood_pred     = mood_probs.max(dim=1)
        soreness_conf, soreness_pred = soreness_probs.max(dim=1)

        # ── Map to labels β€” reuse pred variables, no redundant argmax ─────────
        predicted_workout  = workout_label_map[workout_pred.item()]
        predicted_mood     = mood_label_map[mood_pred.item()]
        predicted_soreness = soreness_label_map[soreness_pred.item()]

        logger.info(
            f"Prediction β€” Workout: {predicted_workout} ({workout_conf.item()*100:.1f}%) | "
            f"Mood: {predicted_mood} ({mood_conf.item()*100:.1f}%) | "
            f"Soreness: {predicted_soreness} ({soreness_conf.item()*100:.1f}%)"
        )

        # ── Fetch exercises ───────────────────────────────────────────────────
        suitable_exercises = get_suitable_exercises(
            workout_type = workout_pred.item(),
            mood         = mood_pred.item(),
            soreness     = soreness_pred.item()
        )

        return PredictResponse(
            workout       = predicted_workout,
            workout_conf  = round(workout_conf.item()  * 100, 1),
            mood          = predicted_mood,
            mood_conf     = round(mood_conf.item()     * 100, 1),
            soreness      = predicted_soreness,
            soreness_conf = round(soreness_conf.item() * 100, 1),
            exercises     = suitable_exercises
        )

    except HTTPException:
        raise  # ← re-raise HTTP exceptions from get_suitable_exercises

    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail="Prediction failed. Please try again.")