# commure-smislam — Upload handler.py with huggingface_hub (b52f440, verified)
import torch
import json
import os
from transformers import AutoModel, AutoTokenizer
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a two-head text classifier.

    Wraps a transformer backbone with two linear heads that predict a
    category and a subcategory for each input text. Expects the following
    artifacts under ``path``:

    - ``tokenizer/``               tokenizer files
    - ``backbone/``                backbone model files
    - ``classification_heads.pt``  dict with head state_dicts and label lists
    """

    def __init__(self, path=""):
        """
        Load tokenizer, backbone, classification heads, and label metadata.

        Called once when the endpoint starts up.

        Args:
            path: Directory containing the exported model artifacts.

        Raises:
            Exception: Re-raised (after logging) if any artifact fails to load.
        """
        print(f"Loading model from path: {path}")
        try:
            # Load tokenizer
            tokenizer_path = os.path.join(path, "tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            print("✅ Tokenizer loaded")
            # Load backbone model and freeze it in inference mode
            backbone_path = os.path.join(path, "backbone")
            self.backbone = AutoModel.from_pretrained(backbone_path)
            self.backbone.eval()
            print("✅ Backbone model loaded")
            # Load classification heads and metadata.
            # NOTE(review): torch.load unpickles arbitrary objects; this is
            # only safe because the checkpoint ships with the model artifacts.
            # Do not point this at untrusted files.
            heads_path = os.path.join(path, "classification_heads.pt")
            checkpoint = torch.load(heads_path, map_location="cpu")
            # Build heads sized to the backbone's hidden dim and label sets
            hidden_size = self.backbone.config.hidden_size
            num_categories = len(checkpoint['categories'])
            num_subcategories = len(checkpoint['subcategories'])
            self.category_head = torch.nn.Linear(hidden_size, num_categories)
            self.subcategory_head = torch.nn.Linear(hidden_size, num_subcategories)
            self.dropout = torch.nn.Dropout(0.1)
            # Load trained weights into the freshly constructed heads
            self.category_head.load_state_dict(checkpoint['category_head'])
            self.subcategory_head.load_state_dict(checkpoint['subcategory_head'])
            # Put everything in eval mode. BUG FIX: the Dropout module was
            # previously left in training mode, so it randomly zeroed features
            # at inference (torch.no_grad() does NOT disable dropout), making
            # predictions non-deterministic. In eval mode Dropout is identity.
            self.category_head.eval()
            self.subcategory_head.eval()
            self.dropout.eval()
            # Store label metadata for decoding predicted indices
            self.categories = checkpoint['categories']
            self.subcategories = checkpoint['subcategories']
            print(f"✅ Model fully loaded: {num_categories} categories, {num_subcategories} subcategories")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise  # bare raise preserves the original traceback

    def __call__(self, data):
        """
        Handle an inference request.

        Args:
            data: Dict with an 'inputs' key holding a string or list of strings.

        Returns:
            A result dict for a single input, a list of result dicts for a
            batch, or an {'error': ...} dict on invalid input or failure.
            Each result has 'text', 'category', and 'subcategory' entries;
            the latter two carry 'label' and 'confidence' (max softmax
            probability, rounded to 4 decimals).
        """
        try:
            inputs = data.get("inputs", "")
            # Normalize to a list and reject anything that is not text
            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                return {"error": "inputs must be a string or list of strings"}
            if not all(isinstance(text, str) for text in inputs):
                # Previously a non-string list element crashed the tokenizer
                # and surfaced as a generic "Prediction failed" error.
                return {"error": "inputs must be a string or list of strings"}
            if not inputs or inputs == [""]:
                return {"error": "No input text provided"}
            # Tokenize the whole batch at once
            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=256,
                return_tensors="pt"
            )
            # Predict
            with torch.no_grad():
                # Use the [CLS] token embedding as the pooled representation;
                # dropout is in eval mode, i.e. a no-op at inference.
                backbone_outputs = self.backbone(**encoded)
                pooled_output = backbone_outputs.last_hidden_state[:, 0]
                pooled_output = self.dropout(pooled_output)
                # Logits from each head
                category_logits = self.category_head(pooled_output)
                subcategory_logits = self.subcategory_head(pooled_output)
                # argmax gives the predicted index; max softmax prob is the
                # reported confidence
                category_preds = torch.argmax(category_logits, dim=1)
                subcategory_preds = torch.argmax(subcategory_logits, dim=1)
                category_probs = torch.softmax(category_logits, dim=1)
                subcategory_probs = torch.softmax(subcategory_logits, dim=1)
                category_confidence = torch.max(category_probs, dim=1)[0]
                subcategory_confidence = torch.max(subcategory_probs, dim=1)[0]
            # Decode indices back to human-readable labels
            results = []
            for i, text in enumerate(inputs):
                results.append({
                    "text": text,
                    "category": {
                        "label": self.categories[category_preds[i].item()],
                        "confidence": round(category_confidence[i].item(), 4)
                    },
                    "subcategory": {
                        "label": self.subcategories[subcategory_preds[i].item()],
                        "confidence": round(subcategory_confidence[i].item(), 4)
                    }
                })
            # Single input -> single dict; batch -> list of dicts
            return results[0] if len(results) == 1 else results
        except Exception as e:
            return {"error": f"Prediction failed: {str(e)}"}