# Smart-Email-Sorter / backend / compare_models.py
# Surya8663
# Final version, database correctly ignored
# 4ded330
import joblib
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch
# -----------------------------
# 1. Baseline classifier (TF-IDF), restored from its pickle
# -----------------------------
baseline_model = joblib.load("models/baseline_folder_clf.pkl")

# -----------------------------
# 2. Fine-tuned DistilBERT artifacts, all stored under one directory
# -----------------------------
model_path = "models/transformer"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
# eval() returns the module itself, so loading and switching to
# inference mode can be written as a single expression.
model = DistilBertForSequenceClassification.from_pretrained(model_path).eval()
# `le` is the object whose inverse_transform() maps predicted class ids
# back to labels (presumably a fitted label encoder -- confirm against
# the training script).
le = joblib.load(model_path + "/le.pkl")
# -----------------------------
# 3. Hand-written sample emails used for the side-by-side comparison
# -----------------------------
# Each entry is expanded into a {"subject": ..., "body": ...} dict.
test_emails = [
    {"subject": subject, "body": body}
    for subject, body in (
        ("Team Standup Reminder", "Please join the daily standup meeting at 10 AM."),
        ("50% Off on Shoes", "Grab the latest offer on sneakers."),
        ("Mom's Birthday", "Don't forget to call mom today."),
    )
]
# -----------------------------
# 4. Run both models on every sample and print the results side by side
# -----------------------------
for message in test_emails:
    # Both models see the same input: subject and body joined by one space.
    text = " ".join((message["subject"], message["body"]))

    # Baseline: predict() takes a list of documents; keep the single result.
    baseline_pred = baseline_model.predict([text])[0]

    # Transformer: tokenize to PyTorch tensors, then run a forward pass
    # with gradient tracking disabled.
    encoded = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**encoded).logits

    # Index of the largest logit along dim 1, converted to a Python scalar
    # and mapped back to its label through `le`.
    class_index = logits.argmax(dim=1).item()
    (transformer_pred,) = le.inverse_transform([class_index])

    print(f"\nEmail: {text}")
    print(f"Baseline prediction: {baseline_pred}")
    print(f"Transformer prediction: {transformer_pred}")