commure-smislam committed on
Commit
b52f440
·
verified ·
1 Parent(s): cc00fd1

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +131 -0
handler.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import json
4
+ import os
5
+ from transformers import AutoModel, AutoTokenizer
6
+
7
class EndpointHandler:
    """Inference handler for a two-head text classifier.

    Loads a transformer backbone plus two linear classification heads
    (category and subcategory) from ``path`` and serves predictions for
    single strings or batches of strings.
    """

    def __init__(self, path=""):
        """
        Initialize the handler with the model path.
        This gets called when the endpoint starts up.

        Expects under ``path``:
          - ``tokenizer/``                : a saved AutoTokenizer
          - ``backbone/``                 : a saved AutoModel
          - ``classification_heads.pt``   : dict with 'category_head',
            'subcategory_head' state dicts and 'categories',
            'subcategories' label lists
        """
        print(f"Loading model from path: {path}")

        try:
            # Load tokenizer
            tokenizer_path = os.path.join(path, "tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            print("✅ Tokenizer loaded")

            # Load backbone model
            backbone_path = os.path.join(path, "backbone")
            self.backbone = AutoModel.from_pretrained(backbone_path)
            self.backbone.eval()
            print("✅ Backbone model loaded")

            # Load classification heads and metadata.
            # NOTE(review): torch.load unpickles arbitrary objects; if this
            # checkpoint could ever come from an untrusted source, pass
            # weights_only=True (payload here is tensors + string lists).
            heads_path = os.path.join(path, "classification_heads.pt")
            checkpoint = torch.load(heads_path, map_location="cpu")

            # Initialize classification heads sized to the checkpoint labels.
            hidden_size = self.backbone.config.hidden_size
            num_categories = len(checkpoint['categories'])
            num_subcategories = len(checkpoint['subcategories'])

            self.category_head = torch.nn.Linear(hidden_size, num_categories)
            self.subcategory_head = torch.nn.Linear(hidden_size, num_subcategories)
            self.dropout = torch.nn.Dropout(0.1)

            # Load weights
            self.category_head.load_state_dict(checkpoint['category_head'])
            self.subcategory_head.load_state_dict(checkpoint['subcategory_head'])

            # Set to eval mode.
            # BUG FIX: a freshly constructed nn.Module starts in training
            # mode, and torch.no_grad() does NOT disable dropout, so the
            # original handler randomly zeroed ~10% of the pooled features
            # on every request, making predictions nondeterministic.
            # eval() turns Dropout into the identity at inference time.
            self.category_head.eval()
            self.subcategory_head.eval()
            self.dropout.eval()

            # Store label metadata for decoding argmax indices.
            self.categories = checkpoint['categories']
            self.subcategories = checkpoint['subcategories']

            print(f"✅ Model fully loaded: {num_categories} categories, {num_subcategories} subcategories")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise  # bare raise preserves the original traceback

    def __call__(self, data):
        """
        Handle inference requests.

        Args:
            data: Dictionary with 'inputs' key containing text or list of texts

        Returns:
            Dictionary with predictions for a single input, or a list of
            such dictionaries for a batch; on failure, {"error": ...}
        """
        try:
            # Extract inputs
            inputs = data.get("inputs", "")

            # Handle both single string and list
            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                return {"error": "inputs must be a string or list of strings"}

            # Reject empty requests. FIX: the original only caught the exact
            # list [""], letting batches like ["", ""] reach the model;
            # reject any batch whose entries are all empty.
            if not inputs or all(not text for text in inputs):
                return {"error": "No input text provided"}

            # Tokenize
            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=256,
                return_tensors="pt"
            )

            # Predict
            with torch.no_grad():
                # Get backbone features; use the [CLS] token embedding
                backbone_outputs = self.backbone(**encoded)
                pooled_output = backbone_outputs.last_hidden_state[:, 0]
                # Identity in eval mode; kept to mirror the training graph.
                pooled_output = self.dropout(pooled_output)

                # Get logits
                category_logits = self.category_head(pooled_output)
                subcategory_logits = self.subcategory_head(pooled_output)

                # Get predictions and confidence scores
                category_preds = torch.argmax(category_logits, dim=1)
                subcategory_preds = torch.argmax(subcategory_logits, dim=1)

                category_probs = torch.softmax(category_logits, dim=1)
                subcategory_probs = torch.softmax(subcategory_logits, dim=1)

                category_confidence = torch.max(category_probs, dim=1)[0]
                subcategory_confidence = torch.max(subcategory_probs, dim=1)[0]

            # Format results, one entry per input text
            results = []
            for i in range(len(inputs)):
                result = {
                    "text": inputs[i],
                    "category": {
                        "label": self.categories[category_preds[i].item()],
                        "confidence": round(category_confidence[i].item(), 4)
                    },
                    "subcategory": {
                        "label": self.subcategories[subcategory_preds[i].item()],
                        "confidence": round(subcategory_confidence[i].item(), 4)
                    }
                }
                results.append(result)

            # Return single result if single input, otherwise return list
            return results[0] if len(results) == 1 else results

        except Exception as e:
            return {"error": f"Prediction failed: {str(e)}"}