import logging
from typing import Any, Dict, List, Union

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """
    Simplified Hugging Face Inference Endpoint Handler for scoring-only models.
    Provides clean text-to-scores interface with standardized response format.
    Compatible with both art and cog models.
    """

    def __init__(self, path: str = ""):
        """Initialize the handler by loading the fine-tuned GPT-2 model and tokenizer.

        Args:
            path: Directory (or hub id) holding the model and tokenizer files.

        Raises:
            RuntimeError: If the model or tokenizer fails to load.
        """
        logger.info(f"Loading model and tokenizer from path: {path}")
        try:
            self.model = GPT2LMHeadModel.from_pretrained(path)
            self.model.eval()
            self.tokenizer = GPT2Tokenizer.from_pretrained(path)
            # BUG FIX: GPT-2's tokenizer has no pad token by default, so the
            # padding=True call in _score_text raises for batches of unequal
            # lengths. Reuse EOS as padding; padded positions are excluded
            # from the loss via the attention mask, so scores are unaffected.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model = self.model.to(self.device)
            self.vocab_size = self.model.config.vocab_size
            logger.info(f"Model loaded successfully on device: {self.device}")
            logger.info(f"Model vocab size: {self.vocab_size}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process scoring request with text input.

        Standardized response format:
        {
            "success": bool,
            "data": {...},
            "metadata": {...},
            "error": str (only if success=False)
        }
        """
        try:
            # BUG FIX: the old code called data.get("metric", ...)
            # unconditionally, so any non-dict payload (raw string / list —
            # which the inputs extraction explicitly supported) raised
            # AttributeError. Branch on the container type exactly once.
            if isinstance(data, dict):
                inputs = data.get("inputs")
                metric = data.get("metric", "nll")
            else:
                inputs = data
                metric = "nll"
            if inputs is None:
                raise ValueError("Missing 'inputs' key in request data")
            # Always compute scores (simplified - no compute_scores parameter)
            return self._score_text(inputs, metric)
        except Exception as e:
            logger.error(f"Request processing failed: {e}")
            return {
                "success": False,
                "data": {},
                "metadata": {},
                "error": str(e)
            }

    def _score_text(self, text_input: Union[str, List[str]], metric: str = "nll") -> Dict[str, Any]:
        """Score text inputs and return computed scores.

        Args:
            text_input: One text or a list of texts to score.
            metric: "nll" (mean per-token negative log-likelihood, the
                default) or "perplexity" (exp of the NLL).

        Returns:
            Standardized response dict; on success data["scores"] holds one
            float per input sequence. Sequences with no scorable tokens
            (fewer than two tokens) score float('inf').
        """
        try:
            # Normalize to list
            if isinstance(text_input, str):
                text_inputs = [text_input]
            elif isinstance(text_input, list):
                text_inputs = text_input
            else:
                raise ValueError(f"Expected string or list of strings, got: {type(text_input)}")

            logger.info(f"Computing {metric} scores for {len(text_inputs)} texts")

            # Tokenize inputs (padding requires the pad token set in __init__)
            encoded = self.tokenizer(
                text_inputs,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)

            scores = []
            with torch.no_grad():
                # One forward pass for the whole batch
                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                # Score each sequence separately so padding in one sequence
                # never dilutes another sequence's mean loss.
                for i in range(len(text_inputs)):
                    seq_input_ids = input_ids[i:i+1]
                    seq_logits = logits[i:i+1]
                    seq_attention_mask = attention_mask[i:i+1]

                    # Next-token objective: token t is predicted from the
                    # logits at position t-1, so shift targets left and drop
                    # the final logit column.
                    targets = seq_input_ids[:, 1:].clone()
                    logits_for_loss = seq_logits[:, :-1]
                    mask = seq_attention_mask[:, 1:] == 1

                    # A 0- or 1-token sequence has nothing to predict.
                    if mask.sum() == 0:
                        scores.append(float('inf'))
                        continue

                    # Compute loss only on valid (non-padding) positions
                    masked_logits = logits_for_loss[mask]
                    masked_targets = targets[mask]
                    loss = F.cross_entropy(masked_logits, masked_targets, reduction='mean')

                    if metric == "perplexity":
                        score = torch.exp(loss).item()
                    else:  # nll
                        score = loss.item()
                    scores.append(score)

            return {
                "success": True,
                "data": {"scores": scores},
                "metadata": {
                    "metric": metric,
                    "num_sequences": len(text_inputs)
                }
            }
        except Exception as e:
            logger.error(f"Scoring failed: {e}")
            return {
                "success": False,
                "data": {},
                "metadata": {},
                "error": str(e)
            }