File size: 2,669 Bytes
977ba08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Inference example for the MiniLM email classifier ONNX model.

Usage:
    pip install onnxruntime transformers numpy huggingface_hub
    python example.py
"""

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Category labels. classify_email() indexes this list with the argmax position
# of the model's "category_probs" output, so the order here presumably mirrors
# the label order the model was exported with -- verify before editing.
CATEGORIES = ["ALERT", "NEWSLETTER", "PERSONAL", "SOCIAL", "TRANSACTION"]
# Hugging Face Hub repository id from which main() fetches "model.onnx" and the
# tokenizer.
REPO_ID = "Ippoboi/minilmail-classifier"


def classify_email(
    session: ort.InferenceSession,
    tokenizer: AutoTokenizer,
    subject: str,
    body: str,
    action_threshold: float = 0.5,
) -> dict:
    """Run the classifier on a single email and summarise both model outputs.

    The subject and body are merged into one prompt, tokenized (truncated to
    256 tokens) and fed to the ONNX session, which is asked for the
    ``category_probs`` and ``action_prob`` outputs.

    Args:
        session: ONNX Runtime session holding the classifier model.
        tokenizer: Callable tokenizer returning NumPy ``input_ids`` and
            ``attention_mask`` arrays.
        subject: Email subject line.
        body: Email body text.
        action_threshold: The email is flagged as requiring action when the
            action probability is strictly greater than this value.

    Returns:
        A dict with the keys ``category``, ``confidence``,
        ``action_required``, ``action_probability`` and
        ``all_probabilities`` (label -> probability for every entry of
        ``CATEGORIES``).
    """
    prompt = f"Subject: {subject}\n\nBody: {body}"
    encoded = tokenizer(prompt, return_tensors="np", max_length=256, truncation=True)

    token_ids = encoded["input_ids"]
    feeds = {
        "input_ids": token_ids.astype(np.int64),
        "attention_mask": encoded["attention_mask"].astype(np.int64),
        # Single-segment input: every token gets segment id 0.
        "token_type_ids": np.zeros_like(token_ids, dtype=np.int64),
    }
    category_scores, action_scores = session.run(
        ["category_probs", "action_prob"], feeds
    )

    # Only the first row of each output is read (one email per call).
    best = int(np.argmax(category_scores[0]))
    label = CATEGORIES[best]
    confidence = float(category_scores[0][best])
    needs_action = float(action_scores[0][0]) > action_threshold
    action_probability = float(action_scores[0][0])
    per_category = dict(zip(CATEGORIES, map(float, category_scores[0])))

    return {
        "category": label,
        "confidence": confidence,
        "action_required": needs_action,
        "action_probability": action_probability,
        "all_probabilities": per_category,
    }


def main():
    """Download the model, classify a few sample emails and print the results."""
    from huggingface_hub import hf_hub_download

    # Fetch the ONNX graph and tokenizer from the Hub, then open a session.
    onnx_file = hf_hub_download(REPO_ID, "model.onnx")
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
    session = ort.InferenceSession(onnx_file)

    # (subject, body) pairs used for the demonstration.
    samples = (
        ("Your order has shipped", "Your order #12345 is on its way and will arrive by Monday."),
        ("Meeting tomorrow", "Hey, can we reschedule our 2pm meeting to 3pm? Let me know."),
        ("Weekly Newsletter", "Check out our latest deals! 50% off everything this weekend."),
        ("Security Alert", "A new device logged into your account from San Francisco, CA."),
        ("LinkedIn: New connection", "John Doe wants to connect with you on LinkedIn."),
    )

    rule = "=" * 60
    print(rule)
    print("MiniLM Email Classifier")
    print(rule)

    for subject, body in samples:
        prediction = classify_email(session, tokenizer, subject, body)
        if prediction["action_required"]:
            action_label = "ACTION"
        else:
            action_label = "NO_ACTION"
        category = prediction["category"]
        confidence = prediction["confidence"]
        action_probability = prediction["action_probability"]
        print(f"\n  Subject: {subject}")
        print(f"  → {category} ({confidence:.1%}) | {action_label} ({action_probability:.1%})")


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()