minilmail-classifier / example.py
Ippoboi's picture
Upload example.py with huggingface_hub
977ba08 verified
"""
Inference example for the MiniLM email classifier ONNX model.
Usage:
pip install onnxruntime transformers
python example.py
"""
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
# Category labels. classify_email() indexes this list with the argmax of the
# model's "category_probs" output, so the order presumably has to match the
# model's output columns -- TODO confirm against the training setup.
CATEGORIES = ["ALERT", "NEWSLETTER", "PERSONAL", "SOCIAL", "TRANSACTION"]
# Hugging Face Hub repository that main() fetches model.onnx and the tokenizer from.
REPO_ID = "Ippoboi/minilmail-classifier"
def classify_email(
    session: ort.InferenceSession,
    tokenizer: AutoTokenizer,
    subject: str,
    body: str,
    action_threshold: float = 0.5,
) -> dict:
    """Run the ONNX classifier on a single email.

    The subject and body are merged into one "Subject: ...\\n\\nBody: ..."
    string, tokenized to NumPy arrays (``max_length=256`` with truncation
    enabled) and passed to ``session.run``, requesting the
    ``category_probs`` and ``action_prob`` outputs.

    Args:
        session: ONNX Runtime session holding the classifier.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` when invoked with ``return_tensors="np"``.
        subject: Email subject line.
        body: Email body text.
        action_threshold: The email is flagged as requiring action when
            the action probability is strictly greater than this value.

    Returns:
        A dict with the keys ``category`` (label from ``CATEGORIES``),
        ``confidence`` (probability of that label), ``action_required``
        (bool), ``action_probability`` (float) and ``all_probabilities``
        (label -> probability mapping).
    """
    prompt = f"Subject: {subject}\n\nBody: {body}"
    encoded = tokenizer(prompt, return_tensors="np", max_length=256, truncation=True)

    # The model is fed int64 tensors; token_type_ids is all zeros because
    # the whole prompt is treated as a single segment.
    feed = {
        "input_ids": encoded["input_ids"].astype(np.int64),
        "attention_mask": encoded["attention_mask"].astype(np.int64),
        "token_type_ids": np.zeros_like(encoded["input_ids"], dtype=np.int64),
    }
    cat_probs, act_prob = session.run(["category_probs", "action_prob"], feed)

    # Only the first row of each output is read (one email per call).
    scores = cat_probs[0]
    best = int(np.argmax(scores))
    label = CATEGORIES[best]
    confidence = float(scores[best])
    action_score = float(act_prob[0][0])

    return {
        "category": label,
        "confidence": confidence,
        "action_required": action_score > action_threshold,
        "action_probability": action_score,
        "all_probabilities": dict(zip(CATEGORIES, map(float, scores))),
    }
def main():
    """Download the model and tokenizer, then classify a few sample emails.

    Fetches ``model.onnx`` and the tokenizer from ``REPO_ID``, builds an
    ONNX Runtime session, runs ``classify_email`` on each hard-coded
    sample, and prints the predicted category and action flag.
    """
    # Kept as a function-local import, as in the original script.
    from huggingface_hub import hf_hub_download

    # Fetch the ONNX file and the tokenizer from the Hub repository.
    onnx_file = hf_hub_download(REPO_ID, "model.onnx")
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
    session = ort.InferenceSession(onnx_file)

    # (subject, body) pairs used for the demo run.
    samples = [
        ("Your order has shipped", "Your order #12345 is on its way and will arrive by Monday."),
        ("Meeting tomorrow", "Hey, can we reschedule our 2pm meeting to 3pm? Let me know."),
        ("Weekly Newsletter", "Check out our latest deals! 50% off everything this weekend."),
        ("Security Alert", "A new device logged into your account from San Francisco, CA."),
        ("LinkedIn: New connection", "John Doe wants to connect with you on LinkedIn."),
    ]

    banner = "=" * 60
    print(banner)
    print("MiniLM Email Classifier")
    print(banner)

    for subject, body in samples:
        result = classify_email(session, tokenizer, subject, body)
        if result["action_required"]:
            action = "ACTION"
        else:
            action = "NO_ACTION"
        print(f"\n Subject: {subject}")
        print(f" → {result['category']} ({result['confidence']:.1%}) | {action} ({result['action_probability']:.1%})")
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()