import json

import torch
from transformers import AutoModel, AutoTokenizer


class EmailClassifierInference:
    """Inference wrapper for a two-head email classifier.

    Loads a pretrained transformer backbone plus two linear classification
    heads (category and subcategory) from ``model_path`` and exposes
    :meth:`predict` for single-text classification with confidence scores.
    """

    def __init__(self, model_path):
        """Load tokenizer, backbone, classification heads, and label lists.

        Args:
            model_path: Directory containing ``tokenizer/``, ``backbone/``,
                and ``classification_heads.pt``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(f"{model_path}/tokenizer")
        self.backbone = AutoModel.from_pretrained(f"{model_path}/backbone")

        # SECURITY NOTE: torch.load deserializes via pickle — only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(
            f"{model_path}/classification_heads.pt", map_location="cpu"
        )

        hidden_size = self.backbone.config.hidden_size
        self.category_head = torch.nn.Linear(
            hidden_size, len(checkpoint["categories"])
        )
        self.subcategory_head = torch.nn.Linear(
            hidden_size, len(checkpoint["subcategories"])
        )
        self.dropout = torch.nn.Dropout(0.1)

        self.category_head.load_state_dict(checkpoint["category_head"])
        self.subcategory_head.load_state_dict(checkpoint["subcategory_head"])
        self.categories = checkpoint["categories"]
        self.subcategories = checkpoint["subcategories"]

        # Put every module in eval mode. BUG FIX: the original never called
        # .eval() on self.dropout, so dropout stayed in training mode and
        # randomly zeroed activations during inference, making predictions
        # nondeterministic. In eval mode Dropout is the identity.
        self.backbone.eval()
        self.category_head.eval()
        self.subcategory_head.eval()
        self.dropout.eval()

    def predict(self, text):
        """Classify ``text`` into a category and subcategory.

        Args:
            text: Raw email text.

        Returns:
            Dict with the input text and, for each head, the predicted
            label and its softmax confidence.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )

        with torch.no_grad():
            outputs = self.backbone(**inputs)
            # First token ([CLS]-style) embedding as the pooled representation.
            pooled = outputs.last_hidden_state[:, 0]
            pooled = self.dropout(pooled)  # identity in eval mode
            cat_logits = self.category_head(pooled)
            sub_logits = self.subcategory_head(pooled)

        # Compute softmax once per head and derive both the predicted index
        # and its confidence from it (the original ran softmax twice).
        cat_probs = torch.softmax(cat_logits, dim=1)
        sub_probs = torch.softmax(sub_logits, dim=1)
        cat_conf, cat_pred = cat_probs.max(dim=1)
        sub_conf, sub_pred = sub_probs.max(dim=1)

        return {
            "text": text,
            "category": {
                "label": self.categories[cat_pred.item()],
                "confidence": cat_conf.item(),
            },
            "subcategory": {
                "label": self.subcategories[sub_pred.item()],
                "confidence": sub_conf.item(),
            },
        }


# Usage:
# classifier = EmailClassifierInference("./path/to/model")
# result = classifier.predict("Your email text here")
# print(result)