leygit commited on
Commit
32062ea
·
verified ·
1 Parent(s): c9f11fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -136,16 +136,34 @@ def evaluate_model_with_report(val_loader):
136
 
137
  # Performance metrics
138
  def generate_performance_metrics():
139
- y_pred = model.predict(X_test)
140
- accuracy = evaluate_model_with_report(val_loader)
141
- report = classification_report(y_true, y_pred, target_names=["Ham", "Spam"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  return {
143
  "accuracy": f"{accuracy:.2%}",
144
  "precision": f"{report['1']['precision']:.2%}",
145
  "recall": f"{report['1']['recall']:.2%}",
146
- "f1_score": f"{report['1']['f1-score']:.2%}"
147
  }
148
 
 
149
  # Gradio Interface
150
 
151
  def create_interface():
 
136
 
137
  # Performance metrics
138
  def generate_performance_metrics():
139
+ model.eval() # Set model to evaluation mode
140
+
141
+ y_true = [] # True labels
142
+ y_pred = [] # Predicted labels
143
+
144
+ with torch.no_grad():
145
+ for batch in val_loader:
146
+ inputs = tokenizer(email, padding=True, truncation=True, max_length=256, return_tensors="pt")
147
+ inputs = {key: val.to(device) for key, val in inputs.items()}
148
+
149
+ outputs = model(**inputs)
150
+ prediction = torch.argmax(outputs.logits, dim=1).item()
151
+
152
+ y_true.append(label)
153
+ y_pred.append(prediction)
154
+
155
+ # Compute accuracy and classification report
156
+ accuracy = accuracy_score(y_true, y_pred)
157
+ report = classification_report(y_true, y_pred, output_dict=True)
158
+
159
  return {
160
  "accuracy": f"{accuracy:.2%}",
161
  "precision": f"{report['1']['precision']:.2%}",
162
  "recall": f"{report['1']['recall']:.2%}",
163
+ "f1_score": f"{report['1']['f1-score']:.2%}",
164
  }
165
 
166
+
167
  # Gradio Interface
168
 
169
  def create_interface():