File size: 5,151 Bytes
b52f440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132

import torch
import json
import os
from transformers import AutoModel, AutoTokenizer

class EndpointHandler:
    """Inference endpoint handler for a dual-head text classifier.

    Loads a transformer backbone, its tokenizer, and two linear
    classification heads (category and subcategory) from ``path``,
    then serves predictions for a single string or a batch of strings.
    """

    def __init__(self, path=""):
        """
        Initialize the handler with the model path.
        This gets called when the endpoint starts up.

        Args:
            path: Directory containing ``tokenizer/``, ``backbone/`` and
                ``classification_heads.pt``.

        Raises:
            Exception: Any loading failure is logged and re-raised so the
                endpoint fails to start instead of serving a broken model.
        """
        print(f"Loading model from path: {path}")

        try:
            # Load tokenizer
            tokenizer_path = os.path.join(path, "tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            print("✅ Tokenizer loaded")

            # Load backbone model
            backbone_path = os.path.join(path, "backbone")
            self.backbone = AutoModel.from_pretrained(backbone_path)
            self.backbone.eval()
            print("✅ Backbone model loaded")

            # Load classification heads and metadata.
            # weights_only=False is explicit because the checkpoint holds a
            # plain dict with label lists, and PyTorch >= 2.6 changed the
            # default to True; this is a trusted local artifact, not
            # untrusted input.
            heads_path = os.path.join(path, "classification_heads.pt")
            checkpoint = torch.load(heads_path, map_location="cpu", weights_only=False)

            # Initialize classification heads sized from the backbone's
            # hidden dimension and the label vocabularies in the checkpoint.
            hidden_size = self.backbone.config.hidden_size
            num_categories = len(checkpoint['categories'])
            num_subcategories = len(checkpoint['subcategories'])

            self.category_head = torch.nn.Linear(hidden_size, num_categories)
            self.subcategory_head = torch.nn.Linear(hidden_size, num_subcategories)
            # Kept for checkpoint/interface parity; dropout is the identity
            # at inference, so it is NOT applied in __call__ (applying it in
            # training mode made predictions nondeterministic).
            self.dropout = torch.nn.Dropout(0.1)

            # Load weights
            self.category_head.load_state_dict(checkpoint['category_head'])
            self.subcategory_head.load_state_dict(checkpoint['subcategory_head'])

            # Set every module to eval mode (including dropout).
            self.category_head.eval()
            self.subcategory_head.eval()
            self.dropout.eval()

            # Store metadata: index -> label lookup lists.
            self.categories = checkpoint['categories']
            self.subcategories = checkpoint['subcategories']

            print(f"✅ Model fully loaded: {num_categories} categories, {num_subcategories} subcategories")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Bare raise preserves the original traceback for the logs.
            raise

    def __call__(self, data):
        """
        Handle inference requests.

        Args:
            data: Dictionary with 'inputs' key containing text or list of texts

        Returns:
            Dictionary with predictions for a single input, a list of such
            dictionaries for a batch, or a dictionary with an 'error' key
            on invalid input or prediction failure.
        """
        try:
            # Extract inputs
            inputs = data.get("inputs", "")

            # Handle both single string and list
            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                return {"error": "inputs must be a string or list of strings"}

            if not inputs or inputs == [""]:
                return {"error": "No input text provided"}

            # Tokenize
            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=256,
                return_tensors="pt"
            )

            # Predict
            with torch.no_grad():
                # Get backbone features; use the [CLS] token embedding as the
                # pooled sentence representation. Dropout is deliberately not
                # applied here: at inference it is the identity, and applying
                # a training-mode dropout made outputs nondeterministic.
                backbone_outputs = self.backbone(**encoded)
                pooled_output = backbone_outputs.last_hidden_state[:, 0]

                # Get logits
                category_logits = self.category_head(pooled_output)
                subcategory_logits = self.subcategory_head(pooled_output)

                # Get predictions and confidence scores
                category_preds = torch.argmax(category_logits, dim=1)
                subcategory_preds = torch.argmax(subcategory_logits, dim=1)

                category_probs = torch.softmax(category_logits, dim=1)
                subcategory_probs = torch.softmax(subcategory_logits, dim=1)

                category_confidence = torch.max(category_probs, dim=1)[0]
                subcategory_confidence = torch.max(subcategory_probs, dim=1)[0]

            # Format results: one dict per input text with label + confidence
            # for each head.
            results = []
            for i in range(len(inputs)):
                result = {
                    "text": inputs[i],
                    "category": {
                        "label": self.categories[category_preds[i].item()],
                        "confidence": round(category_confidence[i].item(), 4)
                    },
                    "subcategory": {
                        "label": self.subcategories[subcategory_preds[i].item()],
                        "confidence": round(subcategory_confidence[i].item(), 4)
                    }
                }
                results.append(result)

            # Return single result if single input, otherwise return list
            return results[0] if len(results) == 1 else results

        except Exception as e:
            # Boundary handler: surface the failure as a JSON-safe payload
            # rather than crashing the endpoint worker.
            return {"error": f"Prediction failed: {str(e)}"}