|
|
| import torch |
| import json |
| import os |
| from transformers import AutoModel, AutoTokenizer |
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a two-head text classifier.

    Loads a tokenizer, a transformer backbone, and two linear classification
    heads (category + subcategory) from the model repository, then serves
    predictions via ``__call__``.
    """

    def __init__(self, path=""):
        """
        Initialize the handler with the model path.
        This gets called when the endpoint starts up.

        Expected layout under ``path`` (inferred from the load calls below —
        TODO confirm against the export script):
            tokenizer/               saved tokenizer
            backbone/                saved transformer backbone
            classification_heads.pt  dict with 'category_head' and
                'subcategory_head' state dicts plus 'categories' and
                'subcategories' label lists

        Raises:
            Exception: re-raised after logging if any artifact fails to load.
        """
        print(f"Loading model from path: {path}")

        try:
            tokenizer_path = os.path.join(path, "tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            print("✅ Tokenizer loaded")

            backbone_path = os.path.join(path, "backbone")
            self.backbone = AutoModel.from_pretrained(backbone_path)
            self.backbone.eval()
            print("✅ Backbone model loaded")

            # NOTE(review): torch.load unpickles arbitrary objects; since this
            # checkpoint holds only state dicts and label lists, prefer
            # weights_only=True (torch >= 1.13) to remove the arbitrary-code
            # execution risk from a tampered file — confirm deployed torch version.
            heads_path = os.path.join(path, "classification_heads.pt")
            checkpoint = torch.load(heads_path, map_location="cpu")

            # Head dimensions are derived from the backbone config and the
            # label lists stored in the checkpoint.
            hidden_size = self.backbone.config.hidden_size
            num_categories = len(checkpoint['categories'])
            num_subcategories = len(checkpoint['subcategories'])

            self.category_head = torch.nn.Linear(hidden_size, num_categories)
            self.subcategory_head = torch.nn.Linear(hidden_size, num_subcategories)
            # Dropout is a train-time regularizer, kept only so the module set
            # mirrors training. It is deliberately NOT applied in __call__:
            # a fresh nn.Dropout defaults to train mode (torch.no_grad() does
            # not disable it), so applying it at inference randomly zeroed
            # 10% of the pooled features and corrupted predictions.
            self.dropout = torch.nn.Dropout(0.1)
            self.dropout.eval()

            self.category_head.load_state_dict(checkpoint['category_head'])
            self.subcategory_head.load_state_dict(checkpoint['subcategory_head'])

            self.category_head.eval()
            self.subcategory_head.eval()

            # Index -> label lookup tables used to decode argmax predictions.
            self.categories = checkpoint['categories']
            self.subcategories = checkpoint['subcategories']

            print(f"✅ Model fully loaded: {num_categories} categories, {num_subcategories} subcategories")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Bare raise preserves the original traceback (vs `raise e`).
            raise

    def __call__(self, data):
        """
        Handle inference requests.

        Args:
            data: Dictionary with 'inputs' key containing text or list of texts

        Returns:
            A single result dict for a single input, a list of result dicts
            for a batch, or an ``{"error": ...}`` dict on invalid input or
            prediction failure. Each result contains the original text plus
            category/subcategory labels and softmax confidences.
        """
        try:
            inputs = data.get("inputs", "")

            # Normalize a single string to a one-element batch.
            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                return {"error": "inputs must be a string or list of strings"}

            # Reject non-string items early with a clear message instead of
            # letting the tokenizer raise deep inside the try block.
            if not all(isinstance(text, str) for text in inputs):
                return {"error": "inputs must be a string or list of strings"}

            if not inputs or inputs == [""]:
                return {"error": "No input text provided"}

            encoded = self.tokenizer(
                inputs,
                truncation=True,
                padding=True,
                max_length=256,
                return_tensors="pt"
            )

            with torch.no_grad():
                backbone_outputs = self.backbone(**encoded)
                # CLS pooling: first token of the last hidden state.
                # Assumes a BERT-style encoder with a [CLS] position — TODO confirm.
                pooled_output = backbone_outputs.last_hidden_state[:, 0]
                # (bug fix) dropout is no longer applied here — it is a
                # train-time-only op and was randomly corrupting predictions.

                category_logits = self.category_head(pooled_output)
                subcategory_logits = self.subcategory_head(pooled_output)

                category_preds = torch.argmax(category_logits, dim=1)
                subcategory_preds = torch.argmax(subcategory_logits, dim=1)

                category_probs = torch.softmax(category_logits, dim=1)
                subcategory_probs = torch.softmax(subcategory_logits, dim=1)

                # Confidence = probability of the argmax class.
                category_confidence = torch.max(category_probs, dim=1)[0]
                subcategory_confidence = torch.max(subcategory_probs, dim=1)[0]

            results = []
            for i, text in enumerate(inputs):
                results.append({
                    "text": text,
                    "category": {
                        "label": self.categories[category_preds[i].item()],
                        "confidence": round(category_confidence[i].item(), 4)
                    },
                    "subcategory": {
                        "label": self.subcategories[subcategory_preds[i].item()],
                        "confidence": round(subcategory_confidence[i].item(), 4)
                    }
                })

            # API convenience: single input -> single dict; batch -> list.
            return results[0] if len(results) == 1 else results

        except Exception as e:
            # Top-level boundary: surface the failure as a JSON-able error
            # payload rather than a 500 with no body.
            return {"error": f"Prediction failed: {str(e)}"}
|
|