File size: 3,910 Bytes
4798a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102

import torch
import json
import os
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn

class EmailClassificationModel(nn.Module):
    """Transformer backbone with two independent linear classification heads.

    The final hidden state of the first token (``last_hidden_state[:, 0]``) is
    passed through dropout and then fed, separately, to a category head and a
    subcategory head.

    Args:
        model_name: name or path handed to ``AutoModel.from_pretrained``.
        num_categories: output width of the category head.
        num_subcategories: output width of the subcategory head.
        dropout: dropout probability applied to the first-token vector.
    """

    def __init__(self, model_name, num_categories, num_subcategories, dropout=0.1):
        super().__init__()
        # Keep these attribute names unchanged: they become the state_dict key
        # prefixes that saved weights are matched against by load_state_dict.
        self.backbone = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        width = self.backbone.config.hidden_size
        self.category_head = nn.Linear(width, num_categories)
        self.subcategory_head = nn.Linear(width, num_subcategories)

    def forward(self, input_ids, attention_mask):
        """Run the backbone and both heads.

        Returns:
            dict with ``"category_logits"`` and ``"subcategory_logits"``, each the
            output of its head applied to the (dropped-out) first-token vector.
        """
        hidden = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # The first position's vector serves as the whole-sequence summary.
        summary = self.dropout(hidden[:, 0])
        return dict(
            category_logits=self.category_head(summary),
            subcategory_logits=self.subcategory_head(summary),
        )

class EndpointHandler:
    """Inference wrapper around a saved ``EmailClassificationModel`` checkpoint.

    Loads ``model_checkpoint.pt`` and the ``tokenizer`` directory from ``path``.
    When called with ``{"inputs": ...}``, it returns the top category and
    subcategory label for each input text, each with its softmax confidence.

    NOTE(review): the class name and the ``__init__(path)`` / ``__call__(data)``
    shape look like the Hugging Face Inference Endpoints custom-handler
    convention — confirm against the deployment.
    """

    def __init__(self, path=""):
        """Rebuild the tokenizer and model from artifacts stored under ``path``.

        The checkpoint is read as a dict with the keys ``model_config`` (which
        must provide ``model_name``, ``num_categories``, ``num_subcategories``
        and ``max_length``), ``model_state_dict``, ``categories`` and
        ``subcategories``. A missing key raises ``KeyError``.
        """
        # torch.load unpickles the file; only point `path` at trusted artifacts.
        checkpoint = torch.load(os.path.join(path, "model_checkpoint.pt"), map_location="cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(path, "tokenizer"))

        # Recreate the architecture, then load the trained weights into it.
        config = checkpoint['model_config']
        self.model = EmailClassificationModel(
            model_name=config['model_name'],
            num_categories=config['num_categories'],
            num_subcategories=config['num_subcategories']
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # Inference mode for layers such as dropout.
        self.model.eval()

        # Label lookups indexed by predicted class id (presumably lists ordered
        # to match the head outputs — verify against the training script).
        self.categories = checkpoint['categories']
        self.subcategories = checkpoint['subcategories']
        self.max_length = config['max_length']

    @staticmethod
    def _top1(logits):
        """Return (argmax ids, max softmax probability) per row as Python lists."""
        preds = torch.argmax(logits, dim=1)
        confidence = torch.max(torch.softmax(logits, dim=1), dim=1)[0]
        return preds.tolist(), confidence.tolist()

    def __call__(self, data):
        """Classify one text or a list of texts.

        Args:
            data: mapping whose ``"inputs"`` entry is a string or a list of
                strings. A missing key is treated as ``""``.

        Returns:
            A single result dict when exactly one text was given. Otherwise a
            list of result dicts, which is empty for an empty input list.
            Each result has the keys ``"text"``, ``"category"`` and
            ``"subcategory"``. The last two each hold
            ``{"label": ..., "confidence": float}``, with the confidence
            rounded to 4 decimal places. Any exception is caught and reported
            as ``{"error": "Prediction failed: ..."}`` instead of being raised.
        """
        try:
            inputs = data.get("inputs", "")
            if isinstance(inputs, str):
                inputs = [inputs]
            if len(inputs) == 0:
                # Nothing to classify; skip tokenizing an empty batch.
                return []

            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt"
            )

            with torch.no_grad():
                # Pass only the two tensors forward() accepts. Some tokenizers
                # (e.g. BERT-style ones) also return token_type_ids, and
                # unpacking the whole encoding would raise TypeError.
                outputs = self.model(
                    input_ids=encoded["input_ids"],
                    attention_mask=encoded["attention_mask"],
                )
                cat_ids, cat_conf = self._top1(outputs['category_logits'])
                sub_ids, sub_conf = self._top1(outputs['subcategory_logits'])

            results = [
                {
                    "text": text,
                    "category": {
                        "label": self.categories[ci],
                        "confidence": round(cc, 4)
                    },
                    "subcategory": {
                        "label": self.subcategories[si],
                        "confidence": round(sc, 4)
                    }
                }
                for text, ci, cc, si, sc in zip(inputs, cat_ids, cat_conf, sub_ids, sub_conf)
            ]

            return results[0] if len(results) == 1 else results

        except Exception as e:
            return {"error": f"Prediction failed: {str(e)}"}