# handler.py — Hugging Face Inference Endpoint handler (rev 98d1cad)
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from typing import Dict, List, Any, Union
import logging
import torch.nn.functional as F
# Configure module-level logging so handler activity (model load, request
# processing, failures) is visible in the inference service logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """
    Simplified Hugging Face Inference Endpoint Handler for scoring-only models.

    Provides a clean text-to-scores interface with a standardized response
    format. Compatible with both art and cog models.

    Every entry point returns the same envelope:
        {
            "success": bool,
            "data": {...},
            "metadata": {...},
            "error": str   # present only when success is False
        }
    """

    def __init__(self, path: str = ""):
        """Initialize the handler by loading the fine-tuned GPT-2 model and tokenizer.

        Args:
            path: Directory (or hub id) passed to ``from_pretrained``.

        Raises:
            RuntimeError: if the model or tokenizer fails to load.
        """
        logger.info(f"Loading model and tokenizer from path: {path}")
        try:
            self.model = GPT2LMHeadModel.from_pretrained(path)
            self.model.eval()
            self.tokenizer = GPT2Tokenizer.from_pretrained(path)
            # BUGFIX: GPT-2 tokenizers ship without a pad token, so
            # tokenizer(..., padding=True) raises ValueError on any batch of
            # unequal-length texts. Reuse EOS as the pad token — padded
            # positions are masked out of the loss via the attention mask,
            # so scores are unaffected.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model = self.model.to(self.device)
            self.vocab_size = self.model.config.vocab_size
            logger.info(f"Model loaded successfully on device: {self.device}")
            logger.info(f"Model vocab size: {self.vocab_size}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a scoring request.

        Args:
            data: Either a dict ``{"inputs": str | list[str], "metric": str}``
                or a bare string / list of strings.

        Returns:
            The standardized response envelope (see class docstring) with
            ``data["scores"]`` holding one float per input text.
        """
        try:
            inputs = data.get("inputs", data) if isinstance(data, dict) else data
            if inputs is None:
                raise ValueError("Missing 'inputs' key in request data")
            # BUGFIX: only dict payloads can carry options. The previous code
            # called data.get("metric", ...) unconditionally, which raised
            # AttributeError for the bare str/list payloads explicitly
            # supported above. Non-dict payloads use the default metric.
            metric = data.get("metric", "nll") if isinstance(data, dict) else "nll"
            # Always compute scores (simplified - no compute_scores parameter)
            return self._score_text(inputs, metric)
        except Exception as e:
            logger.error(f"Request processing failed: {e}")
            return {
                "success": False,
                "data": {},
                "metadata": {},
                "error": str(e)
            }

    def _score_text(self, text_input: Union[str, List[str]], metric: str = "nll") -> Dict[str, Any]:
        """Score text inputs and return computed scores.

        Args:
            text_input: A single text or a list of texts to score.
            metric: ``"perplexity"`` for exp(mean NLL); anything else is
                treated as ``"nll"`` (mean negative log-likelihood per token).

        Returns:
            The standardized response envelope; ``data["scores"]`` is a list
            of floats aligned with the input order. A sequence with no
            scoreable tokens gets ``float('inf')``.
        """
        try:
            # Normalize to a list of strings, validating element types early
            # so the tokenizer does not fail with an opaque error.
            if isinstance(text_input, str):
                text_inputs = [text_input]
            elif isinstance(text_input, list):
                if not all(isinstance(t, str) for t in text_input):
                    raise ValueError("All elements of the input list must be strings")
                text_inputs = text_input
            else:
                raise ValueError(f"Expected string or list of strings, got: {type(text_input)}")
            # ROBUSTNESS: an empty batch would crash inside the tokenizer;
            # reject it with a clear message instead.
            if not text_inputs:
                raise ValueError("Empty input: nothing to score")
            logger.info(f"Computing {metric} scores for {len(text_inputs)} texts")
            # Tokenize the whole batch at once (padding requires the pad
            # token configured in __init__).
            encoded = self.tokenizer(
                text_inputs,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)
            scores = []
            with torch.no_grad():
                # Single forward pass over the batch; per-sequence losses are
                # extracted below so each text gets its own mean.
                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                for i in range(len(text_inputs)):
                    seq_input_ids = input_ids[i:i+1]
                    seq_logits = logits[i:i+1]
                    seq_attention_mask = attention_mask[i:i+1]
                    # Shift for next-token prediction: position t predicts
                    # token t+1, so drop the last logit and the first target.
                    targets = seq_input_ids[:, 1:].clone()
                    logits_for_loss = seq_logits[:, :-1]
                    # A target counts only where it is a real (non-pad) token.
                    mask = seq_attention_mask[:, 1:] == 1
                    if mask.sum() == 0:
                        # e.g. empty/whitespace-only text tokenized to <=1
                        # token — no next-token prediction is possible.
                        scores.append(float('inf'))
                        continue
                    # Mean cross-entropy over the valid tokens only.
                    masked_logits = logits_for_loss[mask]
                    masked_targets = targets[mask]
                    loss = F.cross_entropy(masked_logits, masked_targets, reduction='mean')
                    if metric == "perplexity":
                        score = torch.exp(loss).item()
                    else:  # nll (default for any unrecognized metric)
                        score = loss.item()
                    scores.append(score)
            return {
                "success": True,
                "data": {"scores": scores},
                "metadata": {
                    "metric": metric,
                    "num_sequences": len(text_inputs)
                }
            }
        except Exception as e:
            logger.error(f"Scoring failed: {e}")
            return {
                "success": False,
                "data": {},
                "metadata": {},
                "error": str(e)
            }