Update app.py
Browse files
app.py
CHANGED
|
@@ -72,32 +72,20 @@ model = load_model()
|
|
| 72 |
|
| 73 |
# Classification function
|
| 74 |
def classify_email(email_text):
|
| 75 |
-
model.eval()
|
| 76 |
|
| 77 |
with torch.no_grad():
|
| 78 |
-
# Tokenize and convert input text to tensor
|
| 79 |
inputs = tokenizer(email_text, padding=True, truncation=True, max_length=256, return_tensors="pt")
|
| 80 |
-
|
| 81 |
-
# Move inputs to the appropriate device
|
| 82 |
inputs = {key: val.to(device) for key, val in inputs.items()}
|
| 83 |
-
|
| 84 |
-
# Get model predictions
|
| 85 |
outputs = model(**inputs)
|
| 86 |
logits = outputs.logits
|
| 87 |
-
|
| 88 |
-
# Convert logits to predicted class
|
| 89 |
predictions = torch.argmax(logits, dim=1)
|
| 90 |
-
|
| 91 |
-
# Convert logits to probabilities using softmax
|
| 92 |
probs = F.softmax(logits, dim=1)
|
| 93 |
-
confidence = torch.max(probs).item() * 100
|
| 94 |
|
| 95 |
-
# Convert numeric prediction to label
|
| 96 |
result = "Spam" if predictions.item() == 1 else "Ham"
|
|
|
|
| 97 |
|
| 98 |
-
return {
|
| 99 |
-
"result": result,
|
| 100 |
-
"confidence": f"{confidence:.2f}%",
|
| 101 |
}
|
| 102 |
|
| 103 |
# Evaluation function with detailed classification report
|
|
@@ -195,6 +183,8 @@ def create_interface():
|
|
| 195 |
fn=classify_email,
|
| 196 |
inputs=email_input,
|
| 197 |
outputs=[result_output, confidence_output]
|
|
|
|
|
|
|
| 198 |
)
|
| 199 |
|
| 200 |
gr.Markdown("## 📊 Model Performance Analytics")
|
|
|
|
| 72 |
|
| 73 |
# Classification function
|
| 74 |
def classify_email(email_text):
|
| 75 |
+
model.eval()
|
| 76 |
|
| 77 |
with torch.no_grad():
|
|
|
|
| 78 |
inputs = tokenizer(email_text, padding=True, truncation=True, max_length=256, return_tensors="pt")
|
|
|
|
|
|
|
| 79 |
inputs = {key: val.to(device) for key, val in inputs.items()}
|
|
|
|
|
|
|
| 80 |
outputs = model(**inputs)
|
| 81 |
logits = outputs.logits
|
|
|
|
|
|
|
| 82 |
predictions = torch.argmax(logits, dim=1)
|
|
|
|
|
|
|
| 83 |
probs = F.softmax(logits, dim=1)
|
| 84 |
+
confidence = torch.max(probs).item() * 100
|
| 85 |
|
|
|
|
| 86 |
result = "Spam" if predictions.item() == 1 else "Ham"
|
| 87 |
+
return result, f"{confidence:.2f}%"
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
}
|
| 90 |
|
| 91 |
# Evaluation function with detailed classification report
|
|
|
|
| 183 |
fn=classify_email,
|
| 184 |
inputs=email_input,
|
| 185 |
outputs=[result_output, confidence_output]
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
)
|
| 189 |
|
| 190 |
gr.Markdown("## 📊 Model Performance Analytics")
|