import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Dict, List, Any, Union


class EndpointHandler():
    def __init__(self, model_id: str):
        """
        Initializes the handler by loading the model and tokenizer.

        Args:
            model_id (str): The Hugging Face model ID or local path to the model
                repository (e.g., "MoritzLaurer/DeBERTa-v3-base-mnli"). The Inference
                Endpoints infrastructure passes this automatically.
        """
        print(f"Loading model '{model_id}'...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load the tokenizer and model weights.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)

        # Move the model to the target device and switch to inference mode.
        self.model.to(self.device)
        self.model.eval()
        print("Model and tokenizer loaded successfully.")

        # Derive the output label order from the model config (id2label), falling
        # back to the standard NLI order if the config does not provide it.
        try:
            sorted_labels = sorted(self.model.config.id2label.items())
            self.label_names = [label for _, label in sorted_labels]
            print(f"Using label names from model config: {self.label_names}")

            if len(self.label_names) != 3:
                print(f"Warning: Expected 3 labels for NLI, but model config has {len(self.label_names)}. Proceeding with model's labels.")
            if not any("entail" in name.lower() for name in self.label_names) or \
               not any("neutral" in name.lower() for name in self.label_names) or \
               not any("contra" in name.lower() for name in self.label_names):
                print(f"Warning: Model labels {self.label_names} might not match standard NLI labels ('entailment', 'neutral', 'contradiction').")

        except AttributeError:
            # The config exposes no usable id2label mapping; assume the common MNLI order.
            self.label_names = ["entailment", "neutral", "contradiction"]
            print(f"Warning: Could not read labels from model config. Falling back to default: {self.label_names}")
            print("Ensure this order matches the actual output order of the model!")

        print(f"Configured label order for output: {self.label_names}")

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Handles inference requests.

        Args:
            data (Dict[str, Any]): The input data payload from the request.
                Expected keys: "premise" (str) and "hypothesis" (str).
                These may optionally be nested under "inputs".

        Returns:
            Union[Dict[str, Any], List[Dict[str, Any]]]: A dictionary containing
                error information, or a list of dictionaries, each mapping a label
                name to its probability score.
        """

        # Accept either a flat payload or one nested under "inputs".
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            return {"error": "Invalid payload. Expected a JSON object with 'premise' and 'hypothesis' fields."}
        premise = inputs.get("premise")
        hypothesis = inputs.get("hypothesis")

        # Validate that both fields are present and are non-empty strings.
        if not premise or not isinstance(premise, str):
            return {"error": "Missing or invalid 'premise' key in input. Expected a string."}
        if not hypothesis or not isinstance(hypothesis, str):
            return {"error": "Missing or invalid 'hypothesis' key in input. Expected a string."}

        # Tokenize the premise/hypothesis pair for sequence classification.
        try:
            tokenized_inputs = self.tokenizer(
                premise,
                hypothesis,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=self.tokenizer.model_max_length
            )
        except Exception as e:
            print(f"Error during tokenization: {e}")
            return {"error": f"Failed to tokenize input: {e}"}

        # Move the input tensors to the same device as the model.
        tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}

        # Run the forward pass without tracking gradients.
        try:
            with torch.no_grad():
                outputs = self.model(**tokenized_inputs)
                logits = outputs.logits

            # Convert logits to probabilities over the NLI labels.
            probabilities = torch.softmax(logits, dim=-1)

            # A single premise/hypothesis pair was passed, so take the first (only) row.
            scores = probabilities.cpu().numpy()[0].tolist()

            # Pair each configured label name with its probability score.
            result = [{"label": label, "score": score} for label, score in zip(self.label_names, scores)]

            return result

        except Exception as e:
            print(f"Error during model inference: {e}")
            return {"error": f"Model inference failed: {e}"}