File size: 3,910 Bytes
4798a4f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import torch
import json
import os
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
class EmailClassificationModel(nn.Module):
    """Dual-head transformer classifier for emails.

    Wraps a pretrained encoder (loaded via ``AutoModel``) and attaches two
    independent linear heads on top of the [CLS] representation: one for the
    coarse category, one for the subcategory.
    """

    def __init__(self, model_name, num_categories, num_subcategories, dropout=0.1):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden = self.backbone.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.category_head = nn.Linear(hidden, num_categories)
        self.subcategory_head = nn.Linear(hidden, num_subcategories)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and both heads; returns a dict of raw logits."""
        encoder_out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Pool with the first ([CLS]) token, then apply dropout before the heads.
        cls_repr = self.dropout(encoder_out.last_hidden_state[:, 0])
        return {
            'category_logits': self.category_head(cls_repr),
            'subcategory_logits': self.subcategory_head(cls_repr),
        }
class EndpointHandler:
    """Inference-endpoint handler for joint email category/subcategory prediction.

    Loads a fine-tuned ``EmailClassificationModel``, its tokenizer, and label
    metadata from a checkpoint directory, then serves predictions through
    ``__call__`` using the standard handler contract (``data["inputs"]`` in,
    JSON-serializable result out).
    """

    def __init__(self, path=""):
        """Restore model, tokenizer, and label metadata from *path*.

        Expects ``model_checkpoint.pt`` (with keys ``model_config``,
        ``model_state_dict``, ``categories``, ``subcategories``) and a
        ``tokenizer`` subdirectory inside *path*.
        """
        # NOTE(security): torch.load is pickle-based — only load checkpoints
        # from trusted sources (consider weights_only=True where supported).
        checkpoint = torch.load(os.path.join(path, "model_checkpoint.pt"), map_location="cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(path, "tokenizer"))
        # Rebuild the architecture from the saved config, then restore weights.
        config = checkpoint['model_config']
        self.model = EmailClassificationModel(
            model_name=config['model_name'],
            num_categories=config['num_categories'],
            num_subcategories=config['num_subcategories']
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        # Label lists map argmax indices -> human-readable names.
        self.categories = checkpoint['categories']
        self.subcategories = checkpoint['subcategories']
        self.max_length = config['max_length']

    def __call__(self, data):
        """Classify one or more email texts.

        Args:
            data: Dict with key ``"inputs"`` holding a string or list of strings.

        Returns:
            A result dict (single input) or list of result dicts (batch), each
            with ``text``, ``category`` and ``subcategory`` entries carrying the
            predicted ``label`` and softmax ``confidence``; on failure, a dict
            with an ``"error"`` message.
        """
        try:
            inputs = data.get("inputs", "")
            if isinstance(inputs, str):
                inputs = [inputs]
            # Tokenize the batch to fixed-length tensors.
            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
            # BUG FIX: pass only the tensors forward() accepts. BERT-style fast
            # tokenizers also emit token_type_ids, and `self.model(**encoded)`
            # raised TypeError on the unexpected keyword, turning every request
            # into an {"error": ...} response.
            with torch.no_grad():
                outputs = self.model(
                    input_ids=encoded["input_ids"],
                    attention_mask=encoded["attention_mask"],
                )
            category_preds = torch.argmax(outputs['category_logits'], dim=1)
            subcategory_preds = torch.argmax(outputs['subcategory_logits'], dim=1)
            category_probs = torch.softmax(outputs['category_logits'], dim=1)
            subcategory_probs = torch.softmax(outputs['subcategory_logits'], dim=1)
            category_confidence = torch.max(category_probs, dim=1)[0]
            subcategory_confidence = torch.max(subcategory_probs, dim=1)[0]
            # Format one result entry per input text.
            results = []
            for i in range(len(inputs)):
                result = {
                    "text": inputs[i],
                    "category": {
                        "label": self.categories[category_preds[i].item()],
                        "confidence": round(category_confidence[i].item(), 4)
                    },
                    "subcategory": {
                        "label": self.subcategories[subcategory_preds[i].item()],
                        "confidence": round(subcategory_confidence[i].item(), 4)
                    }
                }
                results.append(result)
            # Single-input requests get a bare dict for backward compatibility.
            return results[0] if len(results) == 1 else results
        except Exception as e:
            # Endpoint boundary: surface the failure as a JSON error payload.
            return {"error": f"Prediction failed: {str(e)}"}
|