import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from typing import Dict, List, Any, Union
import logging
import torch.nn.functional as F
# Configure logging: module-level logger so all handler activity is traceable
# in the endpoint's logs (INFO level captures model-load and per-request info).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """
    Simplified Hugging Face Inference Endpoint Handler for scoring-only models.
    Provides clean text-to-scores interface with standardized response format.
    Compatible with both art and cog models.
    """

    def __init__(self, path: str = ""):
        """Initialize the handler by loading the fine-tuned GPT-2 model and tokenizer.

        Args:
            path: Directory (or hub id) to load the model/tokenizer from.

        Raises:
            RuntimeError: If the model or tokenizer cannot be loaded.
        """
        logger.info(f"Loading model and tokenizer from path: {path}")
        try:
            self.model = GPT2LMHeadModel.from_pretrained(path)
            self.model.eval()
            self.tokenizer = GPT2Tokenizer.from_pretrained(path)
            # GPT-2's tokenizer ships without a pad token; tokenizing a batch
            # with padding=True raises unless one is set. Reuse EOS as pad —
            # padded positions are masked out of the loss below anyway.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model = self.model.to(self.device)
            self.vocab_size = self.model.config.vocab_size
            logger.info(f"Model loaded successfully on device: {self.device}")
            logger.info(f"Model vocab size: {self.vocab_size}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process scoring request with text input.

        Accepts either a dict payload ({"inputs": ..., "metric": ...}) or a
        bare string/list of strings.

        Standardized response format:
            {
                "success": bool,
                "data": {...},
                "metadata": {...},
                "error": str (only if success=False)
            }
        """
        try:
            # Only call dict methods when data actually is a dict; bare
            # string/list payloads are passed through with the default metric.
            if isinstance(data, dict):
                inputs = data.get("inputs", data)
                metric = data.get("metric", "nll")
            else:
                inputs = data
                metric = "nll"
            if inputs is None:
                raise ValueError("Missing 'inputs' key in request data")
            # Always compute scores (simplified - no compute_scores parameter)
            return self._score_text(inputs, metric)
        except Exception as e:
            logger.error(f"Request processing failed: {e}")
            return {
                "success": False,
                "data": {},
                "metadata": {},
                "error": str(e)
            }

    def _score_text(self, text_input: Union[str, List[str]], metric: str = "nll") -> Dict[str, Any]:
        """Score text inputs and return computed scores.

        Args:
            text_input: A single text or a list of texts to score.
            metric: "perplexity" for exp(mean NLL); any other value yields
                mean negative log-likelihood per token ("nll").

        Returns:
            Standardized response dict; "data.scores" holds one float per
            input text (float('inf') for sequences with no scorable tokens).
        """
        try:
            # Normalize to list
            if isinstance(text_input, str):
                text_inputs = [text_input]
            elif isinstance(text_input, list):
                text_inputs = text_input
            else:
                raise ValueError(f"Expected string or list of strings, got: {type(text_input)}")
            logger.info(f"Computing {metric} scores for {len(text_inputs)} texts")
            # Tokenize inputs as one padded batch
            encoded = self.tokenizer(
                text_inputs,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)
            scores = []
            with torch.no_grad():
                # Single forward pass for the whole batch
                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                # Per-sequence loss so each text gets its own score
                for i in range(len(text_inputs)):
                    seq_input_ids = input_ids[i:i+1]
                    seq_logits = logits[i:i+1]
                    seq_attention_mask = attention_mask[i:i+1]
                    # Next-token prediction: logits at position t predict token t+1,
                    # so shift targets left and drop the final logit position.
                    targets = seq_input_ids[:, 1:].clone()
                    logits_for_loss = seq_logits[:, :-1]
                    mask = seq_attention_mask[:, 1:] == 1
                    if mask.sum() == 0:
                        # Empty / single-token sequence: nothing to predict.
                        scores.append(float('inf'))
                        continue
                    # Compute loss only on valid (non-padded) token positions
                    masked_logits = logits_for_loss[mask]
                    masked_targets = targets[mask]
                    loss = F.cross_entropy(masked_logits, masked_targets, reduction='mean')
                    if metric == "perplexity":
                        score = torch.exp(loss).item()
                    else:  # nll
                        score = loss.item()
                    scores.append(score)
            return {
                "success": True,
                "data": {"scores": scores},
                "metadata": {
                    "metric": metric,
                    "num_sequences": len(text_inputs)
                }
            }
        except Exception as e:
            logger.error(f"Scoring failed: {e}")
            return {
                "success": False,
                "data": {},
                "metadata": {},
                "error": str(e)
            }