Update handler.py
handler.py CHANGED (+57 -4)
@@ -2,6 +2,7 @@ import torch
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from typing import Dict, List, Any, Union
 import logging
+import torch.nn.functional as F
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -57,14 +58,17 @@ class EndpointHandler:
 
         Expected input format:
         {
-            "inputs": "text string" or ["text1", "text2", ...]
+            "inputs": "text string" or ["text1", "text2", ...],
+            "compute_scores": false,  # Optional: if true, compute scores on the server side
+            "metric": "nll"  # Optional: "nll" or "perplexity" (only used if compute_scores=true)
         }
 
         Returns:
         {
             "logits": [[[vocab_logits]], [[vocab_logits]]],
             "input_ids": [[token_ids]],
-            "shape": [batch_size, sequence_length, vocab_size]
+            "shape": [batch_size, sequence_length, vocab_size],
+            "scores": [score1, score2, ...]  # Only if compute_scores=true
         }
 
         Args:
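For concreteness, a client call under the new contract might look like the sketch below. The endpoint URL and token are placeholders, not part of this change; only the inputs, compute_scores, and metric fields come from the docstring above.

    import requests

    # Hypothetical endpoint URL and token -- substitute your deployment's values
    API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
    headers = {"Authorization": "Bearer <token>"}

    payload = {
        "inputs": ["The quick brown fox", "Hello world"],
        "compute_scores": True,   # ask the server to score each sequence
        "metric": "perplexity",   # or "nll" (the default)
    }

    response = requests.post(API_URL, json=payload, headers=headers).json()
    print(response["scores"])     # one score per input sequence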
@@ -90,7 +94,12 @@ class EndpointHandler:
         else:
             raise ValueError(f"Expected string or list of strings, got: {type(inputs)}")
 
-
+        # Check if we should compute scores on the server side
+        compute_scores = data.get("compute_scores", False)
+        metric = data.get("metric", "nll")
+
+        logger.info(f"Processing batch of {len(text_inputs)} text inputs"
+                    + (f" (compute_scores={compute_scores}, metric={metric})" if compute_scores else ""))
 
         # Tokenize all inputs
         encoded = self.tokenizer(
@@ -101,12 +110,13 @@ class EndpointHandler:
         )
 
         input_ids = encoded["input_ids"].to(self.device)
+        attention_mask = encoded["attention_mask"].to(self.device)
 
         logger.info(f"Tokenized to shape: {input_ids.shape}")
 
         # Forward pass - no gradients needed
         with torch.no_grad():
-            outputs = self.model(input_ids)
+            outputs = self.model(input_ids, attention_mask=attention_mask)
             logits = outputs.logits
 
         # Convert to CPU and then to list for JSON serialization
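Note that GPT-2 ships without a padding token, so batched tokenization with padding presupposes one has been assigned (typically in __init__, which this diff does not show). A minimal sketch of the assumed setup:

    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token

    encoded = tokenizer(
        ["The quick brown fox", "Hi"],
        return_tensors="pt",
        padding=True,  # pad the shorter sequence to the batch max length
    )
    # encoded["attention_mask"] is 1 for real tokens and 0 for padding --
    # exactly what the forward pass above consumes.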
@@ -123,6 +133,49 @@ class EndpointHandler:
             "original_inputs": text_inputs
         }
 
+        # Optionally compute scores on the server side (GPU)
+        if compute_scores:
+            scores = []
+
+            # Process each sequence in the batch
+            for i in range(len(text_inputs)):
+                try:
+                    # Extract this sequence's data
+                    seq_input_ids = input_ids[i:i+1]  # Keep batch dimension
+                    seq_logits = logits[i:i+1]
+                    seq_attention_mask = attention_mask[i:i+1]
+
+                    # Prepare targets and logits for loss computation
+                    targets = seq_input_ids[:, 1:].clone()
+                    logits_for_loss = seq_logits[:, :-1]
+
+                    # Only consider non-padding tokens (use attention mask)
+                    mask = seq_attention_mask[:, 1:] == 1  # Skip first token, match targets shape
+
+                    if mask.sum() == 0:  # No valid tokens
+                        scores.append(float('inf'))
+                        continue
+
+                    # Flatten and mask
+                    masked_logits = logits_for_loss[mask]  # [num_valid_tokens, vocab_size]
+                    masked_targets = targets[mask]  # [num_valid_tokens]
+
+                    if metric == "perplexity":
+                        loss = F.cross_entropy(masked_logits, masked_targets, reduction='mean')
+                        perplexity = torch.exp(loss).item()
+                        scores.append(perplexity)
+                    else:  # nll
+                        nll = F.cross_entropy(masked_logits, masked_targets, reduction='mean')
+                        scores.append(nll.item())
+
+                except Exception as e:
+                    logger.warning(f"Failed to compute {metric} for sequence {i}: {e}")
+                    scores.append(float('inf'))
+
+            response["scores"] = scores
+            response["metric"] = metric
+            logger.info(f"Computed {metric} scores on server side for {len(scores)} sequences")
+
         logger.info(f"Successfully processed batch, output shape: {logits.shape}")
         return response
 
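The loop above averages the token-level cross-entropy over each sequence's real tokens; perplexity is simply exp of that average NLL. A standalone sketch of the same shift-by-one computation on toy tensors (illustrative shapes, not handler code):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab = 50257                              # GPT-2 vocabulary size
    logits = torch.randn(1, 5, vocab)          # [batch, seq_len, vocab]
    input_ids = torch.randint(0, vocab, (1, 5))

    # Logits at position t predict the token at position t+1
    targets = input_ids[:, 1:].reshape(-1)     # [4]
    preds = logits[:, :-1].reshape(-1, vocab)  # [4, vocab]

    nll = F.cross_entropy(preds, targets)      # mean NLL in nats
    ppl = torch.exp(nll)                       # perplexity == exp(mean NLL)
    print(nll.item(), ppl.item())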