leygit commited on
Commit
2013b71
·
verified ·
1 Parent(s): c2df036

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -143,14 +143,14 @@ def generate_performance_metrics():
143
 
144
  with torch.no_grad():
145
  for batch in val_loader:
146
- inputs = tokenizer(email_text, padding=True, truncation=True, max_length=256, return_tensors="pt")
147
- inputs = {key: val.to(device) for key, val in inputs.items()}
148
 
149
  outputs = model(**inputs)
150
- prediction = torch.argmax(outputs.logits, dim=1).item()
151
 
152
- y_true.append(label)
153
- y_pred.append(prediction)
154
 
155
  # Compute accuracy and classification report
156
  accuracy = accuracy_score(y_true, y_pred)
@@ -164,6 +164,7 @@ def generate_performance_metrics():
164
  }
165
 
166
 
 
167
  # Gradio Interface
168
 
169
  def create_interface():
 
143
 
144
  with torch.no_grad():
145
  for batch in val_loader:
146
+ inputs = {key: val.to(device) for key, val in batch.items()}
147
+ labels = inputs.pop("labels").to(device) # Extract labels
148
 
149
  outputs = model(**inputs)
150
+ predictions = torch.argmax(outputs.logits, dim=1)
151
 
152
+ y_true.extend(labels.cpu().numpy())
153
+ y_pred.extend(predictions.cpu().numpy())
154
 
155
  # Compute accuracy and classification report
156
  accuracy = accuracy_score(y_true, y_pred)
 
164
  }
165
 
166
 
167
+
168
  # Gradio Interface
169
 
170
  def create_interface():