# Provenance: uploaded by commure-smislam via huggingface_hub
# ("Upload inference.py with huggingface_hub"), commit 213f3e0 (verified).
import torch
import json
from transformers import AutoModel, AutoTokenizer
class EmailClassifierInference:
    """Email classifier: a Hugging Face backbone followed by two linear heads.

    The backbone's first-token hidden state (``last_hidden_state[:, 0]``) is
    fed to two independent linear heads, one predicting a category and one a
    subcategory. Nothing here enforces that the predicted subcategory belongs
    to the predicted category.

    Expected layout of ``model_path`` (derived from the paths loaded below):
        tokenizer/                  -- loadable with ``AutoTokenizer``
        backbone/                   -- loadable with ``AutoModel``
        classification_heads.pt     -- dict with keys ``category_head`` and
                                       ``subcategory_head`` (state dicts) and
                                       ``categories`` / ``subcategories``
                                       (label sequences indexed by class id)
    """

    def __init__(self, model_path):
        """Load tokenizer, backbone and classification heads from *model_path*.

        Args:
            model_path: Directory laid out as described in the class docstring.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(f"{model_path}/tokenizer")
        self.backbone = AutoModel.from_pretrained(f"{model_path}/backbone")

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources (consider
        # weights_only=True on torch versions that support it).
        checkpoint = torch.load(f"{model_path}/classification_heads.pt", map_location="cpu")

        hidden_size = self.backbone.config.hidden_size
        self.categories = checkpoint['categories']
        self.subcategories = checkpoint['subcategories']
        self.category_head = torch.nn.Linear(hidden_size, len(self.categories))
        self.subcategory_head = torch.nn.Linear(hidden_size, len(self.subcategories))
        self.category_head.load_state_dict(checkpoint['category_head'])
        self.subcategory_head.load_state_dict(checkpoint['subcategory_head'])

        # Kept only for backward compatibility with code that reads the
        # attribute. Dropout is a training-time regulariser. Leaving it in
        # training mode made predictions random, so it is switched to eval
        # mode (identity) and the scoring path does not apply it at all.
        self.dropout = torch.nn.Dropout(0.1)

        for module in (self.backbone, self.category_head, self.subcategory_head, self.dropout):
            module.eval()

    def _tokenize(self, text, max_length):
        """Tokenize *text* (a string or a list of strings) into PyTorch tensors."""
        return self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=max_length,
        )

    def _score(self, inputs):
        """Run the backbone and both heads on tokenized *inputs*.

        Returns:
            Per-row tensors ``(cat_idx, cat_conf, sub_idx, sub_conf)``: the
            argmax class id of each head and that head's highest softmax
            probability.
        """
        with torch.no_grad():
            outputs = self.backbone(**inputs)
            # The first token's hidden state serves as the pooled representation.
            pooled = outputs.last_hidden_state[:, 0]
            # Intentionally no dropout: it would make inference non-deterministic.
            cat_logits = self.category_head(pooled)
            sub_logits = self.subcategory_head(pooled)
        cat_idx = torch.argmax(cat_logits, dim=1)
        sub_idx = torch.argmax(sub_logits, dim=1)
        cat_conf = torch.softmax(cat_logits, dim=1).max(dim=1).values
        sub_conf = torch.softmax(sub_logits, dim=1).max(dim=1).values
        return cat_idx, cat_conf, sub_idx, sub_conf

    def _format_result(self, text, cat_idx, cat_conf, sub_idx, sub_conf):
        """Build the public result dict for one text from plain Python scalars."""
        return {
            "text": text,
            "category": {"label": self.categories[cat_idx], "confidence": cat_conf},
            "subcategory": {"label": self.subcategories[sub_idx], "confidence": sub_conf},
        }

    def predict(self, text, max_length=256):
        """Classify a single email text.

        Args:
            text: The email text to classify.
            max_length: Maximum token length passed to the tokenizer; longer
                inputs are truncated. Defaults to 256.

        Returns:
            ``{"text": text, "category": {"label": ..., "confidence": float},
            "subcategory": {"label": ..., "confidence": float}}``, where each
            confidence is the highest softmax probability of that head.

        This method expects the tokenizer to yield exactly one row; ``.item()``
        raises otherwise. Use ``predict_batch`` for several texts.
        """
        inputs = self._tokenize(text, max_length)
        cat_idx, cat_conf, sub_idx, sub_conf = self._score(inputs)
        return self._format_result(
            text, cat_idx.item(), cat_conf.item(), sub_idx.item(), sub_conf.item()
        )

    def predict_batch(self, texts, max_length=256):
        """Classify several email texts in one padded forward pass.

        Args:
            texts: Iterable of email strings.
            max_length: Same meaning as in ``predict``.

        Returns:
            A list of result dicts shaped like ``predict``'s, in input order;
            an empty list when *texts* is empty.
        """
        texts = list(texts)
        if not texts:
            return []
        inputs = self._tokenize(texts, max_length)
        cat_idx, cat_conf, sub_idx, sub_conf = self._score(inputs)
        return [
            self._format_result(t, ci, cc, si, sc)
            for t, ci, cc, si, sc in zip(
                texts,
                cat_idx.tolist(),
                cat_conf.tolist(),
                sub_idx.tolist(),
                sub_conf.tolist(),
            )
        ]
# Usage:
# classifier = EmailClassifierInference("./path/to/model")
# result = classifier.predict("Your email text here")
# print(result)