"""Custom inference handler for a two-head email classification model."""

import json
import logging
import os

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


class EmailClassificationModel(nn.Module):
    """Transformer backbone with two linear classification heads.

    Both heads read the same pooled representation: the (dropout-regularised)
    hidden state of the first token of every sequence.

    Args:
        model_name: Identifier or path passed to ``AutoModel.from_pretrained``.
        num_categories: Output size of the category head.
        num_subcategories: Output size of the subcategory head.
        dropout: Dropout probability applied to the pooled representation.
    """

    def __init__(self, model_name, num_categories, num_subcategories, dropout=0.1):
        super().__init__()
        # Attribute names are part of the checkpoint's state-dict keys;
        # renaming them would break ``load_state_dict``.
        self.backbone = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        hidden_size = self.backbone.config.hidden_size
        self.category_head = nn.Linear(hidden_size, num_categories)
        self.subcategory_head = nn.Linear(hidden_size, num_subcategories)

    def forward(self, input_ids, attention_mask):
        """Return the raw logits of both heads.

        Args:
            input_ids: Token-id tensor forwarded to the backbone.
            attention_mask: Attention-mask tensor forwarded to the backbone.

        Returns:
            A dict with keys ``'category_logits'`` and ``'subcategory_logits'``.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first position of every sequence as the pooled representation.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0])
        return {
            'category_logits': self.category_head(pooled_output),
            'subcategory_logits': self.subcategory_head(pooled_output),
        }


class EndpointHandler:
    """Load a checkpoint from ``path`` and classify text on ``__call__``.

    Expected layout under ``path``:
        * ``model_checkpoint.pt`` -- a dict with the keys ``model_config``,
          ``model_state_dict``, ``categories`` and ``subcategories``.
        * ``tokenizer/`` -- a directory loadable by ``AutoTokenizer``.
    """

    def __init__(self, path=""):
        # NOTE(security): ``torch.load`` unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source; consider
        # ``weights_only=True`` if the checkpoint holds only tensors and plain
        # containers.
        checkpoint = torch.load(
            os.path.join(path, "model_checkpoint.pt"), map_location="cpu"
        )

        self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(path, "tokenizer"))

        # Rebuild the architecture, then overwrite its weights with the
        # fine-tuned ones stored in the checkpoint.
        config = checkpoint['model_config']
        self.model = EmailClassificationModel(
            model_name=config['model_name'],
            num_categories=config['num_categories'],
            num_subcategories=config['num_subcategories'],
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()  # disable dropout for inference

        # Index -> label lookups for each head, plus the truncation length.
        self.categories = checkpoint['categories']
        self.subcategories = checkpoint['subcategories']
        self.max_length = config['max_length']

    @staticmethod
    def _decode_head(logits, labels):
        """Turn one head's logits into ``{"label", "confidence"}`` dicts, one per row."""
        predicted = torch.argmax(logits, dim=1).tolist()
        confidence = torch.max(torch.softmax(logits, dim=1), dim=1)[0].tolist()
        return [
            {"label": labels[index], "confidence": round(score, 4)}
            for index, score in zip(predicted, confidence)
        ]

    def __call__(self, data):
        """Classify the text(s) found under ``data["inputs"]``.

        Args:
            data: Mapping whose ``"inputs"`` entry is a string or a list of
                strings. A missing key is treated as the empty string.

        Returns:
            For exactly one input text, a dict with ``"text"``, ``"category"``
            and ``"subcategory"`` entries (the latter two each holding
            ``"label"`` and ``"confidence"``). For several texts, a list of
            such dicts. On any failure, ``{"error": "Prediction failed: ..."}``.
        """
        try:
            inputs = data.get("inputs", "")
            if isinstance(inputs, str):
                inputs = [inputs]
            if not inputs:
                # Reject empty/None batches up front with a clear message
                # instead of an opaque tokenizer failure.
                return {
                    "error": "Prediction failed: 'inputs' must be a non-empty "
                             "string or list of strings"
                }

            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt",
            )

            with torch.no_grad():
                # Pass only the tensors ``forward`` accepts: some tokenizers
                # also emit ``token_type_ids``, which would raise a TypeError
                # if the whole encoding were unpacked into the call.
                outputs = self.model(
                    input_ids=encoded["input_ids"],
                    attention_mask=encoded["attention_mask"],
                )

            categories = self._decode_head(
                outputs['category_logits'], self.categories
            )
            subcategories = self._decode_head(
                outputs['subcategory_logits'], self.subcategories
            )

            results = [
                {"text": text, "category": category, "subcategory": subcategory}
                for text, category, subcategory in zip(
                    inputs, categories, subcategories
                )
            ]

            # Kept for backward compatibility: a single result is returned
            # unwrapped, even if the caller passed a one-element list.
            return results[0] if len(results) == 1 else results

        except Exception as e:
            # Top-level request boundary: keep the error-dict contract, but
            # record the traceback so failures are diagnosable.
            logger.exception("Prediction failed")
            return {"error": f"Prediction failed: {str(e)}"}