jflo committed on
Commit
036993c
·
1 Parent(s): 2e256f2

Changed to DistilBERT model

Browse files
Files changed (1) hide show
  1. app.py +13 -24
app.py CHANGED
@@ -5,7 +5,7 @@ from pydantic import BaseModel
5
 
6
  import torch
7
  import torch.nn as nn
8
- from transformers import BertModel, BertTokenizer
9
 
10
  app = FastAPI()
11
 
@@ -36,12 +36,12 @@ soreness_label_map = {
36
  2: "Severe"
37
  }
38
 
39
- class MultiHeadBERT(nn.Module):
40
  def __init__(self, num_workout_types, num_moods, num_soreness_levels):
41
- super(MultiHeadBERT, self).__init__()
42
 
43
  # Shared BERT backbone
44
- self.bert = BertModel.from_pretrained('bert-base-uncased',token=os.getenv('HF_TOKEN'))
45
  hidden_size = self.bert.config.hidden_size # 768
46
 
47
  # Task-specific classification heads
@@ -51,22 +51,14 @@ class MultiHeadBERT(nn.Module):
51
 
52
  self.dropout = nn.Dropout(0.3)
53
 
54
- def forward(self, input_ids, attention_mask, token_type_ids=None):
55
- outputs = self.bert(
56
- input_ids=input_ids,
57
- attention_mask=attention_mask,
58
- token_type_ids=token_type_ids
59
- )
60
 
61
- # Use [CLS] token representation
62
- cls_output = self.dropout(outputs.pooler_output)
63
 
64
- # Each head produces its own logits
65
- workout_logits = self.workout_head(cls_output)
66
- mood_logits = self.mood_head(cls_output)
67
- soreness_logits = self.soreness_head(cls_output)
68
-
69
- return workout_logits, mood_logits, soreness_logits
70
 
71
  class PredictRequest(BaseModel):
72
  user_input: str
@@ -86,20 +78,17 @@ def greet_json():
86
  @app.post("/predict",response_model=PredictResponse)
87
  def predict(request: PredictRequest):
88
 
89
- model = MultiHeadBERT(
90
  num_workout_types=8,
91
  num_moods=5,
92
  num_soreness_levels=3
93
  )
94
 
95
- model.load_state_dict(
96
- torch.load('best_model.pt', map_location=torch.device('cpu'))
97
- )
98
-
99
  model.to(device)
100
  model.eval()
101
 
102
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',token=os.getenv('HF_TOKEN'))
103
 
104
  encoding = tokenizer(
105
  request.user_input, # The single string the user types
 
5
 
6
  import torch
7
  import torch.nn as nn
8
+ from transformers import DistilBertModel, DistilBertTokenizer
9
 
10
  app = FastAPI()
11
 
 
36
  2: "Severe"
37
  }
38
 
39
+ class MultiHeadDistilBERT(nn.Module):
40
  def __init__(self, num_workout_types, num_moods, num_soreness_levels):
41
+ super(MultiHeadDistilBERT, self).__init__()
42
 
43
  # Shared BERT backbone
44
+ self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased',token=os.getenv('HF_TOKEN'))
45
  hidden_size = self.bert.config.hidden_size # 768
46
 
47
  # Task-specific classification heads
 
51
 
52
  self.dropout = nn.Dropout(0.3)
53
 
54
+ def forward(self, input_ids, attention_mask):
55
+ outputs = self.bert(input_ids=input_ids,attention_mask=attention_mask)
 
 
 
 
56
 
57
+ # Use [CLS] token representation. DistilBERT uses last_hidden_state instead of pooler_output like BERT
58
+ cls_output = self.dropout(outputs.last_hidden_state[:, 0, :]) # [CLS] token is first token in sequence
59
 
60
+ # Each head produces its own logits
61
+ return (self.workout_head(cls_output), self.mood_head(cls_output), self.soreness_head(cls_output))
 
 
 
 
62
 
63
  class PredictRequest(BaseModel):
64
  user_input: str
 
78
  @app.post("/predict",response_model=PredictResponse)
79
  def predict(request: PredictRequest):
80
 
81
+ model = MultiHeadDistilBERT(
82
  num_workout_types=8,
83
  num_moods=5,
84
  num_soreness_levels=3
85
  )
86
 
87
+ model.load_state_dict(torch.load('best_model.pt', map_location=torch.device('cpu')))
 
 
 
88
  model.to(device)
89
  model.eval()
90
 
91
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased',token=os.getenv('HF_TOKEN'))
92
 
93
  encoding = tokenizer(
94
  request.user_input, # The single string the user types