|
|
import torch |
|
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
|
from typing import Dict, List, Any, Union |
|
|
import logging |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
# Configure root logging once at import time so handler activity is visible
# in the endpoint container's logs.
logging.basicConfig(level=logging.INFO)


# Module-level logger shared by EndpointHandler below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Simplified Hugging Face Inference Endpoint Handler for scoring-only models.

    Provides clean text-to-scores interface with standardized response format.
    Compatible with both art and cog models.
    """

    def __init__(self, path: str = ""):
        """Initialize the handler by loading the fine-tuned GPT-2 model and tokenizer.

        Args:
            path: Directory (or hub id) holding the model weights and tokenizer
                files, as supplied by the Inference Endpoints runtime.

        Raises:
            RuntimeError: If the model or tokenizer cannot be loaded.
        """
        logger.info("Loading model and tokenizer from path: %s", path)

        try:
            self.model = GPT2LMHeadModel.from_pretrained(path)
            self.model.eval()  # inference only: disable dropout etc.

            self.tokenizer = GPT2Tokenizer.from_pretrained(path)
            # GPT-2 ships without a pad token, and the tokenizer raises as soon
            # as padding=True is requested without one. Reusing EOS as padding
            # is the standard workaround and is safe here because padded
            # positions are excluded from scoring via the attention mask.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model = self.model.to(self.device)

            # Cached for callers/diagnostics; not used in scoring itself.
            self.vocab_size = self.model.config.vocab_size

            logger.info("Model loaded successfully on device: %s", self.device)
            logger.info("Model vocab size: %s", self.vocab_size)

        except Exception as e:
            logger.error("Failed to load model: %s", e)
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Model initialization failed: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process scoring request with text input.

        Accepts either ``{"inputs": <str | list[str]>, "metric": <str>}`` or a
        raw string / list payload (in which case the metric defaults to "nll").

        Standardized response format:
        {
            "success": bool,
            "data": {...},
            "metadata": {...},
            "error": str (only if success=False)
        }
        """
        try:
            if isinstance(data, dict):
                inputs = data.get("inputs", data)
                metric = data.get("metric", "nll")
            else:
                # Raw (non-dict) payloads are accepted as-is; the original
                # code crashed here calling .get() on a non-dict.
                inputs = data
                metric = "nll"

            if inputs is None:
                raise ValueError("Missing 'inputs' key in request data")

            return self._score_text(inputs, metric)

        except Exception as e:
            logger.error("Request processing failed: %s", e)
            return {
                "success": False,
                "data": {},
                "metadata": {},
                "error": str(e)
            }

    def _score_text(self, text_input: Union[str, List[str]], metric: str = "nll") -> Dict[str, Any]:
        """Score text inputs and return computed scores.

        Args:
            text_input: A single text or a list of texts to score.
            metric: "perplexity" for exp(mean NLL); anything else yields the
                mean per-token negative log-likelihood ("nll").

        Returns:
            The standardized response dict described in ``__call__``; on
            success ``data["scores"]`` holds one float per input text.
        """
        try:
            if isinstance(text_input, str):
                text_inputs = [text_input]
            elif isinstance(text_input, list):
                text_inputs = text_input
            else:
                raise ValueError(f"Expected string or list of strings, got: {type(text_input)}")

            # Fail fast with a clear message instead of an opaque tokenizer
            # error on an empty batch.
            if not text_inputs:
                raise ValueError("Received an empty list of texts to score")

            logger.info("Computing %s scores for %d texts", metric, len(text_inputs))

            encoded = self.tokenizer(
                text_inputs,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )

            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)

            scores: List[float] = []

            with torch.no_grad():
                logits = self.model(input_ids, attention_mask=attention_mask).logits

                for i in range(len(text_inputs)):
                    seq_input_ids = input_ids[i:i + 1]
                    seq_logits = logits[i:i + 1]
                    seq_attention_mask = attention_mask[i:i + 1]

                    # Causal LM shift: logits at position t predict token t+1,
                    # so drop the last logit and the first target token.
                    targets = seq_input_ids[:, 1:]
                    logits_for_loss = seq_logits[:, :-1]
                    # Only score real (non-padded) target positions.
                    mask = seq_attention_mask[:, 1:] == 1

                    if mask.sum() == 0:
                        # Zero/one-token sequence: nothing to predict.
                        # NOTE(review): inf is not strict-JSON serializable —
                        # confirm downstream consumers handle it.
                        scores.append(float('inf'))
                        continue

                    loss = F.cross_entropy(
                        logits_for_loss[mask],
                        targets[mask],
                        reduction='mean',
                    )

                    score = torch.exp(loss).item() if metric == "perplexity" else loss.item()
                    scores.append(score)

            return {
                "success": True,
                "data": {"scores": scores},
                "metadata": {
                    "metric": metric,
                    "num_sequences": len(text_inputs)
                }
            }

        except Exception as e:
            logger.error("Scoring failed: %s", e)
            return {
                "success": False,
                "data": {},
                "metadata": {},
                "error": str(e)
            }