File size: 5,444 Bytes
7b484a4
 
 
 
2d2608e
7b484a4
 
 
 
 
 
 
 
98d1cad
7b484a4
98d1cad
 
7b484a4
 
 
98d1cad
7b484a4
 
 
 
98d1cad
7b484a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee438ff
7b484a4
98d1cad
 
 
 
 
 
 
7b484a4
 
 
 
 
 
98d1cad
2d2608e
98d1cad
ee438ff
 
 
 
98d1cad
 
 
 
ee438ff
 
98d1cad
 
ee438ff
 
 
 
 
 
 
 
2d2608e
98d1cad
7b484a4
98d1cad
7b484a4
 
 
 
 
 
 
 
2d2608e
7b484a4
98d1cad
 
7b484a4
98d1cad
2d2608e
7b484a4
2d2608e
98d1cad
 
 
 
 
ee438ff
98d1cad
 
 
 
ee438ff
98d1cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d2608e
98d1cad
 
 
 
 
 
 
 
7b484a4
 
98d1cad
7b484a4
98d1cad
 
 
 
8d31a2b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from typing import Dict, List, Any, Union
import logging
import torch.nn.functional as F

# Configure logging
# NOTE(review): basicConfig runs at import time and configures the root logger;
# it does nothing if the hosting process has already attached root handlers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """
    Simplified Hugging Face Inference Endpoint Handler for scoring-only models.

    Provides clean text-to-scores interface with standardized response format.
    Compatible with both art and cog models.

    Each input text is scored by the mean negative log-likelihood ("nll") of
    its tokens under the causal language model (every token after the first is
    predicted from its prefix), or by the exponential of that mean
    ("perplexity"). Texts with fewer than two tokens cannot be scored and get
    ``float('inf')``.
    """

    # Metrics accepted by ``_score_text``. Anything else is rejected instead of
    # silently falling back to NLL while reporting the unknown name.
    SUPPORTED_METRICS = ("nll", "perplexity")

    def __init__(self, path: str = ""):
        """Load the fine-tuned GPT-2 model and tokenizer from ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                and the tokenizer.

        Raises:
            RuntimeError: If the model or tokenizer cannot be loaded.
        """
        logger.info(f"Loading model and tokenizer from path: {path}")

        try:
            self.model = GPT2LMHeadModel.from_pretrained(path)
            self.model.eval()

            self.tokenizer = GPT2Tokenizer.from_pretrained(path)
            # The stock GPT-2 tokenizer defines no pad token, and batch
            # tokenization with padding=True requires one. EOS is a safe filler
            # because padded positions are excluded from scoring by the mask.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Scoring assumes real tokens start at position 0 (right padding).
            # With left padding, the first real token would be "predicted" from
            # a pad position and position ids would be shifted.
            self.tokenizer.padding_side = "right"

            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model = self.model.to(self.device)

            self.vocab_size = self.model.config.vocab_size

            logger.info(f"Model loaded successfully on device: {self.device}")
            logger.info(f"Model vocab size: {self.vocab_size}")

        except Exception as e:
            logger.exception(f"Failed to load model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process scoring request with text input.

        ``data`` is normally a dict with an ``"inputs"`` key (a string or a
        list of strings). The metric is read from a top-level ``"metric"`` key,
        falling back to ``data["parameters"]["metric"]``, then to ``"nll"``.
        A non-dict payload is treated as the inputs themselves.

        Standardized response format:
        {
            "success": bool,
            "data": {...},
            "metadata": {...},
            "error": str (only if success=False)
        }
        """
        try:
            if isinstance(data, dict):
                if "inputs" not in data:
                    raise ValueError("Missing 'inputs' key in request data")
                inputs = data["inputs"]
                parameters = data.get("parameters")
                if not isinstance(parameters, dict):
                    parameters = {}
                metric = data.get("metric", parameters.get("metric", "nll"))
            else:
                # Raw payload: previously this crashed on ``data.get``.
                inputs = data
                metric = "nll"

            if inputs is None:
                raise ValueError("Missing 'inputs' key in request data")

            return self._score_text(inputs, metric)

        except Exception as e:
            logger.exception(f"Request processing failed: {e}")
            return self._error_response(e)

    def _score_text(self, text_input: Union[str, List[str]], metric: str = "nll") -> Dict[str, Any]:
        """Score text inputs and return computed scores.

        Args:
            text_input: A single string or a list of strings.
            metric: One of ``SUPPORTED_METRICS``.

        Returns:
            The standardized response dict. ``data["scores"]`` holds one float
            per input, in input order. Failures are returned as an error
            response rather than raised.
        """
        try:
            if metric not in self.SUPPORTED_METRICS:
                raise ValueError(
                    f"Unsupported metric {metric!r}; expected one of {list(self.SUPPORTED_METRICS)}"
                )

            text_inputs = self._normalize_inputs(text_input)

            logger.info("Computing %s scores for %d texts", metric, len(text_inputs))

            # An empty batch has nothing to tokenize or score.
            scores = self._compute_scores(text_inputs, metric) if text_inputs else []

            return {
                "success": True,
                "data": {"scores": scores},
                "metadata": {
                    "metric": metric,
                    "num_sequences": len(text_inputs)
                }
            }

        except Exception as e:
            logger.exception(f"Scoring failed: {e}")
            return self._error_response(e)

    @staticmethod
    def _normalize_inputs(text_input: Union[str, List[str]]) -> List[str]:
        """Return ``text_input`` as a list of strings, or raise ValueError."""
        if isinstance(text_input, str):
            return [text_input]
        if isinstance(text_input, list):
            for item in text_input:
                if not isinstance(item, str):
                    raise ValueError(f"Expected list elements to be strings, got: {type(item)}")
            return text_input
        raise ValueError(f"Expected string or list of strings, got: {type(text_input)}")

    def _compute_scores(self, text_inputs: List[str], metric: str) -> List[float]:
        """Compute one NLL/perplexity score per text in a single forward pass.

        Texts with no predictable token (fewer than two tokens after
        tokenization and truncation) score ``float('inf')``.
        NOTE(review): ``inf`` is not valid strict JSON; confirm the serving
        layer serializes it as intended.
        """
        encoded = self.tokenizer(
            text_inputs,
            return_tensors="pt",
            padding=True,
            truncation=True
        )

        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # With fewer than two positions, no sequence has a token to predict,
        # so skip the forward pass entirely.
        if input_ids.size(1) < 2:
            return [float("inf")] * len(text_inputs)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask=attention_mask).logits

            # Causal LM: logits at position t predict the token at t + 1.
            target_mask = attention_mask[:, 1:].bool()
            # Padded targets get ignore_index, so they contribute 0 loss and
            # their ids are never used as class indices.
            targets = input_ids[:, 1:].masked_fill(~target_mask, -100)

            # cross_entropy expects the class dimension second: (batch, vocab, seq - 1).
            token_nll = F.cross_entropy(
                logits[:, :-1].transpose(1, 2),
                targets,
                ignore_index=-100,
                reduction="none",
            )

            token_counts = target_mask.sum(dim=1)
            # clamp avoids 0/0 for rows with no scorable tokens; those rows are
            # overwritten with inf below.
            mean_nll = token_nll.sum(dim=1) / token_counts.clamp(min=1)

            values = torch.exp(mean_nll) if metric == "perplexity" else mean_nll
            values = values.masked_fill(token_counts == 0, float("inf"))

        return values.tolist()

    @staticmethod
    def _error_response(error: Exception) -> Dict[str, Any]:
        """Build the standardized failure payload for ``error``."""
        return {
            "success": False,
            "data": {},
            "metadata": {},
            "error": str(error)
        }