jflo commited on
Commit
e90a0ff
·
1 Parent(s): 7c23ee3

BERT Model to classify gym sentiment

Browse files
Files changed (1) hide show
  1. app.py +48 -12
app.py CHANGED
@@ -63,15 +63,23 @@ class MultiHeadBERT(nn.Module):
63
 
64
  return workout_logits, feeling_logits, soreness_logits
65
 
66
- class ClassificationRequest(BaseModel):
67
- message: str
 
 
 
 
 
 
 
 
68
 
69
  @app.get("/")
70
  def greet_json():
71
  return {"Hello": "World!"}
72
 
73
- @app.post("/classify")
74
- def sentiment_analysis(payload: ClassificationRequest):
75
 
76
  model = MultiHeadBERT(
77
  num_workout_types=4,
@@ -86,14 +94,42 @@ def sentiment_analysis(payload: ClassificationRequest):
86
  model.to(device)
87
  model.eval()
88
 
89
- result = predict(
90
- text=payload.message,
91
- model=model,
92
- tokenizer=tokenizer,
93
- device=device
 
94
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- return {
97
- result
98
- }
 
 
 
 
 
99
 
 
63
 
64
  return workout_logits, feeling_logits, soreness_logits
65
 
66
+ class PredictRequest(BaseModel):
67
+ user_input: str
68
+
69
+ class PredictResponse(BaseModel):
70
+ workout: str
71
+ workout_conf: float
72
+ feeling: str
73
+ feeling_conf: float
74
+ soreness: str
75
+ soreness_conf: float
76
 
77
  @app.get("/")
78
  def greet_json():
79
  return {"Hello": "World!"}
80
 
81
+ @app.post("/predict",response_model=PredictResponse)
82
+ def predict(request: PredictRequest):
83
 
84
  model = MultiHeadBERT(
85
  num_workout_types=4,
 
94
  model.to(device)
95
  model.eval()
96
 
97
+ encoding = tokenizer(
98
+ request.user_input, # The single string the user types
99
+ max_length=128,
100
+ padding='max_length',
101
+ truncation=True,
102
+ return_tensors='pt'
103
  )
104
+
105
+ input_ids = encoding['input_ids'].to(device)
106
+ attention_mask = encoding['attention_mask'].to(device)
107
+
108
+ with torch.no_grad():
109
+ workout_logits, feeling_logits, soreness_logits = model(input_ids, attention_mask)
110
+
111
+ # Convert logits to probabilities
112
+ workout_probs = torch.softmax(workout_logits, dim=1)
113
+ feeling_probs = torch.softmax(feeling_logits, dim=1)
114
+ soreness_probs = torch.softmax(soreness_logits, dim=1)
115
+
116
+ # Get predicted class and confidence percentage for each head
117
+ workout_conf, workout_pred = workout_probs.max(dim=1)
118
+ feeling_conf, feeling_pred = feeling_probs.max(dim=1)
119
+ soreness_conf, soreness_pred = soreness_probs.max(dim=1)
120
+
121
+ # Map predictions to labels
122
+ predicted_workout = workout_label_map[workout_logits.argmax().item()]
123
+ predicted_feeling = feeling_label_map[feeling_logits.argmax().item()]
124
+ predicted_soreness = soreness_label_map[soreness_logits.argmax().item()]
125
+
126
 
127
+ return PredictResponse(
128
+ workout = predicted_workout,
129
+ workout_conf = round(workout_conf.item() * 100, 1),
130
+ feeling = predicted_feeling,
131
+ feeling_conf = round(feeling_conf.item() * 100, 1),
132
+ soreness = predicted_soreness,
133
+ soreness_conf = round(soreness_conf.item() * 100, 1),
134
+ )
135