File size: 5,151 Bytes
b52f440 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
import json
import os
from transformers import AutoModel, AutoTokenizer
class EndpointHandler:
    """Inference Endpoints handler for a dual-head text classifier.

    Wraps a transformer backbone plus two linear classification heads
    (category and subcategory) whose weights and label vocabularies are
    stored in a separate ``classification_heads.pt`` checkpoint alongside
    the tokenizer and backbone directories.
    """

    def __init__(self, path=""):
        """Load tokenizer, backbone, and classification heads from ``path``.

        Expected artifact layout under ``path``:
            tokenizer/                -- HF tokenizer files
            backbone/                 -- HF backbone model files
            classification_heads.pt   -- dict with keys 'category_head',
                                         'subcategory_head', 'categories',
                                         'subcategories'

        Raises:
            Exception: re-raised (after logging) if any artifact fails to load,
                so the endpoint fails fast at startup instead of serving a
                half-initialized model.
        """
        print(f"Loading model from path: {path}")
        try:
            # Load tokenizer
            tokenizer_path = os.path.join(path, "tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            print("✅ Tokenizer loaded")

            # Load backbone model (eval mode: inference only)
            backbone_path = os.path.join(path, "backbone")
            self.backbone = AutoModel.from_pretrained(backbone_path)
            self.backbone.eval()
            print("✅ Backbone model loaded")

            # Load classification heads and metadata.
            # NOTE(security): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources. Consider weights_only=True
            # (torch >= 1.13) since this checkpoint holds only state dicts and
            # string lists.
            heads_path = os.path.join(path, "classification_heads.pt")
            checkpoint = torch.load(heads_path, map_location="cpu")

            # Head dimensions are derived from the backbone config and the
            # label vocabularies stored in the checkpoint.
            hidden_size = self.backbone.config.hidden_size
            num_categories = len(checkpoint["categories"])
            num_subcategories = len(checkpoint["subcategories"])

            self.category_head = torch.nn.Linear(hidden_size, num_categories)
            self.subcategory_head = torch.nn.Linear(hidden_size, num_subcategories)
            self.category_head.load_state_dict(checkpoint["category_head"])
            self.subcategory_head.load_state_dict(checkpoint["subcategory_head"])
            self.category_head.eval()
            self.subcategory_head.eval()

            # BUGFIX: the original code created this Dropout and applied it
            # during inference while it was still in training mode (eval() was
            # never called on it), randomly zeroing ~10% of pooled features and
            # making predictions non-deterministic. Dropout is a training-only
            # regularizer; the attribute is kept (in eval mode) for backward
            # compatibility but is no longer applied in __call__.
            self.dropout = torch.nn.Dropout(0.1)
            self.dropout.eval()

            # Index -> label lookup tables used when formatting predictions.
            self.categories = checkpoint["categories"]
            self.subcategories = checkpoint["subcategories"]
            print(f"✅ Model fully loaded: {num_categories} categories, {num_subcategories} subcategories")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise e

    def __call__(self, data):
        """Handle an inference request.

        Args:
            data: Dictionary with an 'inputs' key containing a single text
                string or a list of text strings.

        Returns:
            For a single input, one result dict with 'text', 'category', and
            'subcategory' entries (each prediction has a 'label' and a
            'confidence' rounded to 4 decimals). For a list of inputs, a list
            of such dicts. On invalid input or any internal failure, a dict
            with a single 'error' key (errors never raise to the caller).
        """
        try:
            # Normalize input to a list of strings; reject anything else.
            inputs = data.get("inputs", "")
            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                return {"error": "inputs must be a string or list of strings"}
            if not inputs or inputs == [""]:
                return {"error": "No input text provided"}

            # Tokenize the whole batch at once (padded to the longest item).
            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=256,
                return_tensors="pt",
            )

            with torch.no_grad():
                backbone_outputs = self.backbone(**encoded)
                # Use the [CLS] token embedding as the pooled representation.
                pooled_output = backbone_outputs.last_hidden_state[:, 0]

                # Dropout is intentionally NOT applied here — see __init__.
                category_logits = self.category_head(pooled_output)
                subcategory_logits = self.subcategory_head(pooled_output)

                # torch.max over softmax gives both the confidence and the
                # predicted index in one call (argmax of softmax == argmax of
                # logits, so predictions are unchanged).
                category_probs = torch.softmax(category_logits, dim=1)
                subcategory_probs = torch.softmax(subcategory_logits, dim=1)
                category_confidence, category_preds = torch.max(category_probs, dim=1)
                subcategory_confidence, subcategory_preds = torch.max(subcategory_probs, dim=1)

            # Format one result dict per input text.
            results = []
            for i in range(len(inputs)):
                result = {
                    "text": inputs[i],
                    "category": {
                        "label": self.categories[category_preds[i].item()],
                        "confidence": round(category_confidence[i].item(), 4),
                    },
                    "subcategory": {
                        "label": self.subcategories[subcategory_preds[i].item()],
                        "confidence": round(subcategory_confidence[i].item(), 4),
                    },
                }
                results.append(result)

            # Return single result if single input, otherwise return list.
            return results[0] if len(results) == 1 else results
        except Exception as e:
            return {"error": f"Prediction failed: {str(e)}"}
|