File size: 2,249 Bytes
213f3e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

import torch
import json
from transformers import AutoModel, AutoTokenizer

class EmailClassifierInference:
    """Inference wrapper for a fine-tuned dual-head email classifier.

    Loads a shared transformer backbone plus two linear classification
    heads (category and subcategory) from disk and classifies single
    email texts on CPU.
    """

    def __init__(self, model_path):
        """Load tokenizer, backbone, and classification heads from disk.

        Args:
            model_path: Directory containing ``tokenizer/``, ``backbone/``,
                and ``classification_heads.pt`` — a checkpoint dict with
                ``category_head``/``subcategory_head`` state dicts and
                ``categories``/``subcategories`` label lists.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(f"{model_path}/tokenizer")
        self.backbone = AutoModel.from_pretrained(f"{model_path}/backbone")

        # Checkpoint holds the head weights and the label vocabularies.
        # NOTE(review): torch.load with default weights_only unpickles
        # arbitrary objects — only load checkpoints from trusted sources.
        checkpoint = torch.load(f"{model_path}/classification_heads.pt", map_location="cpu")

        hidden_size = self.backbone.config.hidden_size
        self.category_head = torch.nn.Linear(hidden_size, len(checkpoint['categories']))
        self.subcategory_head = torch.nn.Linear(hidden_size, len(checkpoint['subcategories']))
        self.dropout = torch.nn.Dropout(0.1)

        self.category_head.load_state_dict(checkpoint['category_head'])
        self.subcategory_head.load_state_dict(checkpoint['subcategory_head'])

        self.categories = checkpoint['categories']
        self.subcategories = checkpoint['subcategories']

        # Put every module in eval mode. BUG FIX: self.dropout was
        # previously left in its default training mode, so predict()
        # applied active dropout at inference time and produced
        # non-deterministic outputs. In eval mode dropout is the identity.
        self.backbone.eval()
        self.category_head.eval()
        self.subcategory_head.eval()
        self.dropout.eval()

    def predict(self, text):
        """Classify a single email text.

        Args:
            text: Raw email text to classify.

        Returns:
            dict with the input ``text`` plus ``category`` and
            ``subcategory`` entries, each of the form
            ``{"label": str, "confidence": float}`` where confidence is
            the softmax probability of the argmax class.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)

        with torch.no_grad():
            outputs = self.backbone(**inputs)
            # CLS-style pooling: first token of the last hidden state.
            pooled = outputs.last_hidden_state[:, 0]
            pooled = self.dropout(pooled)  # identity — module is in eval mode

            cat_logits = self.category_head(pooled)
            sub_logits = self.subcategory_head(pooled)

            cat_pred = torch.argmax(cat_logits, dim=1).item()
            sub_pred = torch.argmax(sub_logits, dim=1).item()

            # Confidence = max softmax probability for each head.
            cat_conf = torch.softmax(cat_logits, dim=1).max().item()
            sub_conf = torch.softmax(sub_logits, dim=1).max().item()

            return {
                "text": text,
                "category": {"label": self.categories[cat_pred], "confidence": cat_conf},
                "subcategory": {"label": self.subcategories[sub_pred], "confidence": sub_conf}
            }

# Usage:
# classifier = EmailClassifierInference("./path/to/model")
# result = classifier.predict("Your email text here")
# print(result)