| """ |
| Inference example for the MiniLM email classifier ONNX model. |
| |
| Usage: |
| pip install onnxruntime transformers |
| python example.py |
| """ |
|
|
| import numpy as np |
| import onnxruntime as ort |
| from transformers import AutoTokenizer |
|
|
# Class labels looked up by the arg-max index of the model's category output
# in classify_email(), so the order here has to line up with that output.
CATEGORIES = [
    "ALERT",
    "NEWSLETTER",
    "PERSONAL",
    "SOCIAL",
    "TRANSACTION",
]

# Hugging Face Hub repository that main() pulls the ONNX model and the
# tokenizer from.
REPO_ID = "Ippoboi/minilmail-classifier"
|
|
|
|
def classify_email(
    session: ort.InferenceSession,
    tokenizer: AutoTokenizer,
    subject: str,
    body: str,
    action_threshold: float = 0.5,
) -> dict:
    """Run the classifier on one email and summarise its two outputs.

    The subject and body are joined into a single prompt, tokenized to NumPy
    arrays (truncation enabled, ``max_length=256``) and fed to ``session``.

    Args:
        session: ONNX Runtime session; it is asked for the outputs
            ``category_probs`` and ``action_prob``.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` arrays.
        subject: Email subject line.
        body: Email body text.
        action_threshold: The email is flagged as requiring action when the
            action probability is strictly greater than this value.

    Returns:
        A dict with the keys ``category`` (highest-scoring label),
        ``confidence`` (that label's score), ``action_required`` (bool),
        ``action_probability`` (float) and ``all_probabilities`` (label ->
        score for every entry of ``CATEGORIES``).
    """
    encoded = tokenizer(
        f"Subject: {subject}\n\nBody: {body}",
        return_tensors="np",
        max_length=256,
        truncation=True,
    )

    token_ids = encoded["input_ids"]
    feed = {
        "input_ids": token_ids.astype(np.int64),
        "attention_mask": encoded["attention_mask"].astype(np.int64),
        # Single-segment input: the segment ids are all zeros.
        "token_type_ids": np.zeros_like(token_ids, dtype=np.int64),
    }
    category_output, action_output = session.run(
        ["category_probs", "action_prob"],
        feed,
    )

    # Only the first row of each output is read (one email per call).
    scores = category_output[0]
    best = int(np.argmax(scores))
    action_probability = float(action_output[0][0])

    return {
        "category": CATEGORIES[best],
        "confidence": float(scores[best]),
        "action_required": action_probability > action_threshold,
        "action_probability": action_probability,
        "all_probabilities": dict(zip(CATEGORIES, map(float, scores))),
    }
|
|
|
|
def main():
    """Fetch the model and tokenizer, then print predictions for sample emails."""
    from huggingface_hub import hf_hub_download

    # Resolve the ONNX file and tokenizer from the Hub, then open the session.
    onnx_file = hf_hub_download(REPO_ID, "model.onnx")
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
    session = ort.InferenceSession(onnx_file)

    # (subject, body) pairs used as a quick smoke test.
    samples = (
        ("Your order has shipped", "Your order #12345 is on its way and will arrive by Monday."),
        ("Meeting tomorrow", "Hey, can we reschedule our 2pm meeting to 3pm? Let me know."),
        ("Weekly Newsletter", "Check out our latest deals! 50% off everything this weekend."),
        ("Security Alert", "A new device logged into your account from San Francisco, CA."),
        ("LinkedIn: New connection", "John Doe wants to connect with you on LinkedIn."),
    )

    rule = "=" * 60
    print(rule, "MiniLM Email Classifier", rule, sep="\n")

    for subject, body in samples:
        result = classify_email(session, tokenizer, subject, body)
        if result["action_required"]:
            action = "ACTION"
        else:
            action = "NO_ACTION"
        print(f"\n Subject: {subject}")
        print(f" → {result['category']} ({result['confidence']:.1%}) | {action} ({result['action_probability']:.1%})")


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|