jflo committed on
Commit
7c23ee3
·
1 Parent(s): 19fb43f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -19
app.py CHANGED
@@ -1,12 +1,68 @@
 
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
 
4
- import joblib
5
- import pandas as pd
6
- import maven_text_preprocessing
7
 
8
  app = FastAPI()
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class ClassificationRequest(BaseModel):
11
  message: str
12
 
@@ -17,25 +73,27 @@ def greet_json():
17
  @app.post("/classify")
18
  def sentiment_analysis(payload: ClassificationRequest):
19
 
20
- model = joblib.load("naive_bayes.joblib")
21
- vectorizer = joblib.load("vectorizer.joblib")
22
-
23
- clean_text = maven_text_preprocessing.clean_and_normalize(pd.Series([payload.message]))
24
-
25
- X = vectorizer.transform(clean_text) # ⚠️ transform, NOT fit_transform
26
-
27
- category_list = ["Politics", "Sport", "Technology", "Entertainment", "Business"]
28
 
29
- predictions = model.predict(X)
 
 
30
 
31
- pred_prob = model.predict_proba(X)
32
- pred_prob = pred_prob.tolist()[0]
 
 
 
 
 
 
 
33
 
34
  return {
35
- category_list[0]: pred_prob[0],
36
- category_list[1]: pred_prob[1],
37
- category_list[2]: pred_prob[2],
38
- category_list[3]: pred_prob[3],
39
- category_list[4]: pred_prob[4]
40
  }
41
 
 
1
+ import os
2
+
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
 
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import BertModel, BertTokenizer
9
 
10
  app = FastAPI()
11
 
12
+ device = torch.device('cpu') # Hugging Face Space with no GPU
13
+
14
+ workout_label_map = {
15
+ 0: "Cardio",
16
+ 1: "Strength",
17
+ 2: "Yoga",
18
+ 3: "HIIT"
19
+ }
20
+
21
+ feeling_label_map = {
22
+ 0: "Energized",
23
+ 1: "Tired",
24
+ 2: "Stressed",
25
+ 3: "Motivated"
26
+ }
27
+
28
+ soreness_label_map = {
29
+ 0: "None",
30
+ 1: "Mild",
31
+ 2: "Severe"
32
+ }
33
+
class MultiHeadBERT(nn.Module):
    """Shared BERT encoder followed by three independent linear heads.

    One forward pass through the encoder produces a pooled sentence
    representation; each head maps that representation to logits for its
    own task (workout type, feeling, soreness level).

    Args:
        num_workout_types: Number of output classes for the workout head.
        num_feelings: Number of output classes for the feeling head.
        num_soreness_levels: Number of output classes for the soreness head.
    """

    def __init__(self, num_workout_types, num_feelings, num_soreness_levels):
        super().__init__()

        # Shared BERT backbone. os.getenv returns None when HF_TOKEN is not
        # set, so the token is simply omitted in that case.
        self.bert = BertModel.from_pretrained(
            'bert-base-uncased',
            token=os.getenv('HF_TOKEN'),
        )
        hidden_size = self.bert.config.hidden_size  # 768

        # Task-specific classification heads
        self.workout_head = nn.Linear(hidden_size, num_workout_types)
        self.feeling_head = nn.Linear(hidden_size, num_feelings)
        self.soreness_head = nn.Linear(hidden_size, num_soreness_levels)

        self.dropout = nn.Dropout(0.3)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return ``(workout_logits, feeling_logits, soreness_logits)``.

        The three arguments are passed straight through to the BERT encoder.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Pooled [CLS] representation from the encoder, with dropout applied
        # (dropout is a no-op once the module is in eval mode).
        cls_output = self.dropout(outputs.pooler_output)

        # Each head produces its own logits from the same shared features.
        workout_logits = self.workout_head(cls_output)
        feeling_logits = self.feeling_head(cls_output)
        soreness_logits = self.soreness_head(cls_output)

        return workout_logits, feeling_logits, soreness_logits
65
+
66
  class ClassificationRequest(BaseModel):
67
  message: str
68
 
 
# Holds the loaded model so the weights are read from disk only once per
# process instead of on every request.
_MODEL_CACHE = {}


def _get_model():
    """Build MultiHeadBERT, load 'best_model.pt' into it, and reuse it afterwards."""
    if "model" not in _MODEL_CACHE:
        model = MultiHeadBERT(
            num_workout_types=4,
            num_feelings=4,
            num_soreness_levels=3,
        )
        model.load_state_dict(
            torch.load('best_model.pt', map_location=device)
        )
        model.to(device)
        model.eval()  # inference only: disables dropout
        _MODEL_CACHE["model"] = model
    return _MODEL_CACHE["model"]


@app.post("/classify")
def sentiment_analysis(payload: ClassificationRequest):
    """Classify ``payload.message`` and return the prediction.

    Args:
        payload: Request body carrying the text to classify in ``message``.

    Returns:
        Whatever ``predict`` returns for the message, passed through
        unchanged as the response body.
    """
    # NOTE(review): `predict` and `tokenizer` are not defined in the visible
    # part of this file -- confirm they exist at module level, otherwise this
    # endpoint raises NameError on the first request.
    result = predict(
        text=payload.message,
        model=_get_model(),
        tokenizer=tokenizer,
        device=device,
    )

    # Return the result itself. Wrapping it in braces builds a set literal,
    # which raises TypeError if the result is a dict (presumably it is) and
    # otherwise yields a one-element collection rather than the prediction.
    return result
99