import torch
import json
import os
from transformers import AutoModel, AutoTokenizer


class EndpointHandler:
    """Inference endpoint handler for a dual-head text classifier.

    Wraps a transformer backbone plus two linear classification heads
    (category and subcategory) whose artifacts live under one root path:

        <path>/tokenizer/               HF tokenizer files
        <path>/backbone/                HF backbone model files
        <path>/classification_heads.pt  torch checkpoint holding the head
                                        state dicts and the label lists
    """

    def __init__(self, path=""):
        """Load tokenizer, backbone, and classification heads.

        Called once when the endpoint starts up.

        Args:
            path: Root directory containing the model artifacts.

        Raises:
            Exception: Re-raised (after logging) if any artifact fails to load.
        """
        print(f"Loading model from path: {path}")
        try:
            # Load tokenizer
            tokenizer_path = os.path.join(path, "tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            print("✅ Tokenizer loaded")

            # Load backbone model (eval mode: no dropout / batch-norm updates)
            backbone_path = os.path.join(path, "backbone")
            self.backbone = AutoModel.from_pretrained(backbone_path)
            self.backbone.eval()
            print("✅ Backbone model loaded")

            # Load classification heads and metadata.
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from a trusted source. weights_only=True would be
            # safer if the deployed torch version supports it; confirm before
            # changing, as older versions lack the parameter.
            heads_path = os.path.join(path, "classification_heads.pt")
            checkpoint = torch.load(heads_path, map_location="cpu")

            # Build heads sized to the backbone's hidden width and the
            # label vocabularies stored in the checkpoint.
            hidden_size = self.backbone.config.hidden_size
            num_categories = len(checkpoint['categories'])
            num_subcategories = len(checkpoint['subcategories'])

            self.category_head = torch.nn.Linear(hidden_size, num_categories)
            self.subcategory_head = torch.nn.Linear(hidden_size, num_subcategories)
            self.dropout = torch.nn.Dropout(0.1)

            # Load trained weights
            self.category_head.load_state_dict(checkpoint['category_head'])
            self.subcategory_head.load_state_dict(checkpoint['subcategory_head'])

            # Put every module into eval mode.
            # BUGFIX: the dropout module was previously left in training mode
            # (a fresh nn.Dropout defaults to training=True), so dropout
            # stayed active at inference time, scaling activations and making
            # predictions non-deterministic across identical requests.
            self.category_head.eval()
            self.subcategory_head.eval()
            self.dropout.eval()

            # Index -> label lookup tables used when formatting predictions.
            self.categories = checkpoint['categories']
            self.subcategories = checkpoint['subcategories']

            print(f"✅ Model fully loaded: {num_categories} categories, {num_subcategories} subcategories")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise  # bare raise preserves the original traceback

    def __call__(self, data):
        """Handle inference requests.

        Args:
            data: Dictionary with 'inputs' key containing text or list of texts

        Returns:
            A prediction dict for a single input, a list of such dicts for a
            batch, or an {"error": ...} dict on invalid input or failure.
        """
        try:
            # Extract inputs
            inputs = data.get("inputs", "")

            # Normalize: accept a single string or a list of strings.
            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                return {"error": "inputs must be a string or list of strings"}

            if not inputs or inputs == [""]:
                return {"error": "No input text provided"}

            # Tokenize the batch (padded/truncated to a fixed max length).
            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=256,
                return_tensors="pt"
            )

            # Predict
            with torch.no_grad():
                # Pool with the [CLS] token embedding (first position).
                backbone_outputs = self.backbone(**encoded)
                pooled_output = backbone_outputs.last_hidden_state[:, 0]  # [CLS] token
                # No-op in eval mode; kept to mirror the training-time graph.
                pooled_output = self.dropout(pooled_output)

                # Get logits
                category_logits = self.category_head(pooled_output)
                subcategory_logits = self.subcategory_head(pooled_output)

                # Argmax predictions plus softmax confidences per head.
                category_preds = torch.argmax(category_logits, dim=1)
                subcategory_preds = torch.argmax(subcategory_logits, dim=1)

                category_probs = torch.softmax(category_logits, dim=1)
                subcategory_probs = torch.softmax(subcategory_logits, dim=1)

                category_confidence = torch.max(category_probs, dim=1)[0]
                subcategory_confidence = torch.max(subcategory_probs, dim=1)[0]

            # Format one result dict per input text.
            results = []
            for i in range(len(inputs)):
                result = {
                    "text": inputs[i],
                    "category": {
                        "label": self.categories[category_preds[i].item()],
                        "confidence": round(category_confidence[i].item(), 4)
                    },
                    "subcategory": {
                        "label": self.subcategories[subcategory_preds[i].item()],
                        "confidence": round(subcategory_confidence[i].item(), 4)
                    }
                }
                results.append(result)

            # Return single result if single input, otherwise return list
            return results[0] if len(results) == 1 else results

        except Exception as e:
            return {"error": f"Prediction failed: {str(e)}"}