File size: 2,249 Bytes
213f3e0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import torch
import json
from transformers import AutoModel, AutoTokenizer
class EmailClassifierInference:
    """Inference wrapper for a dual-head (category + subcategory) email classifier.

    Loads a transformer backbone, its tokenizer, and two linear
    classification heads from the directory layout produced at training time:
        {model_path}/tokenizer
        {model_path}/backbone
        {model_path}/classification_heads.pt
    """

    def __init__(self, model_path):
        """Load tokenizer, backbone, heads, and label metadata from *model_path*."""
        self.tokenizer = AutoTokenizer.from_pretrained(f"{model_path}/tokenizer")
        self.backbone = AutoModel.from_pretrained(f"{model_path}/backbone")
        # Classification heads + label lists saved as one checkpoint.
        # NOTE(review): torch.load can unpickle arbitrary objects; only load
        # checkpoints from trusted sources (or pass weights_only=True if the
        # installed torch version supports it for this checkpoint layout).
        checkpoint = torch.load(f"{model_path}/classification_heads.pt", map_location="cpu")
        hidden = self.backbone.config.hidden_size
        self.category_head = torch.nn.Linear(hidden, len(checkpoint['categories']))
        self.subcategory_head = torch.nn.Linear(hidden, len(checkpoint['subcategories']))
        self.dropout = torch.nn.Dropout(0.1)
        self.category_head.load_state_dict(checkpoint['category_head'])
        self.subcategory_head.load_state_dict(checkpoint['subcategory_head'])
        self.categories = checkpoint['categories']
        self.subcategories = checkpoint['subcategories']
        # Put ALL modules in eval mode. Bug fix: the original never called
        # self.dropout.eval(), so dropout stayed in training mode and randomly
        # zeroed 10% of pooled features at inference time, making predictions
        # non-deterministic. In eval mode, Dropout is the identity.
        self.backbone.eval()
        self.category_head.eval()
        self.subcategory_head.eval()
        self.dropout.eval()

    def predict(self, text):
        """Classify *text*.

        Returns a dict with the input text and, for each head, the argmax
        label and its softmax confidence.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
        with torch.no_grad():
            outputs = self.backbone(**inputs)
            # [CLS]-token pooling: first position of the last hidden layer.
            pooled = outputs.last_hidden_state[:, 0]
            pooled = self.dropout(pooled)  # identity in eval mode
            cat_logits = self.category_head(pooled)
            sub_logits = self.subcategory_head(pooled)
        # Softmax once per head; argmax over probabilities equals argmax
        # over logits, so this preserves the original predictions while
        # avoiding a second softmax call.
        cat_probs = torch.softmax(cat_logits, dim=1)
        sub_probs = torch.softmax(sub_logits, dim=1)
        cat_pred = int(cat_probs.argmax(dim=1).item())
        sub_pred = int(sub_probs.argmax(dim=1).item())
        return {
            "text": text,
            "category": {"label": self.categories[cat_pred],
                         "confidence": cat_probs.max().item()},
            "subcategory": {"label": self.subcategories[sub_pred],
                            "confidence": sub_probs.max().item()},
        }
# Usage:
# classifier = EmailClassifierInference("./path/to/model")
# result = classifier.predict("Your email text here")
# print(result)
|