alejandrohdez committed
Commit 2d2608e · verified · 1 Parent(s): e6fd632

Update handler.py

Files changed (1): handler.py +57 -4
handler.py CHANGED
@@ -2,6 +2,7 @@ import torch
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from typing import Dict, List, Any, Union
 import logging
+import torch.nn.functional as F
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -57,14 +58,17 @@ class EndpointHandler:
 
         Expected input format:
         {
-            "inputs": "text string" or ["text1", "text2", ...]
+            "inputs": "text string" or ["text1", "text2", ...],
+            "compute_scores": false,  # Optional: if true, compute scores on server side
+            "metric": "nll"  # Optional: "nll" or "perplexity" (only used if compute_scores=true)
         }
 
         Returns:
         {
             "logits": [[[vocab_logits]], [[vocab_logits]]],
             "input_ids": [[token_ids]],
-            "shape": [batch_size, sequence_length, vocab_size]
+            "shape": [batch_size, sequence_length, vocab_size],
+            "scores": [score1, score2, ...]  # Only if compute_scores=true
         }
 
         Args:
@@ -90,7 +94,12 @@ class EndpointHandler:
         else:
             raise ValueError(f"Expected string or list of strings, got: {type(inputs)}")
 
-        logger.info(f"Processing batch of {len(text_inputs)} text inputs")
+        # Check if we should compute scores on server side
+        compute_scores = data.get("compute_scores", False)
+        metric = data.get("metric", "nll")
+
+        logger.info(f"Processing batch of {len(text_inputs)} text inputs" +
+                    (f" (compute_scores={compute_scores}, metric={metric})" if compute_scores else ""))
 
         # Tokenize all inputs
         encoded = self.tokenizer(
@@ -101,12 +110,13 @@ class EndpointHandler:
         )
 
         input_ids = encoded["input_ids"].to(self.device)
+        attention_mask = encoded["attention_mask"].to(self.device)
 
         logger.info(f"Tokenized to shape: {input_ids.shape}")
 
         # Forward pass - no gradients needed
         with torch.no_grad():
-            outputs = self.model(input_ids)
+            outputs = self.model(input_ids, attention_mask=attention_mask)
             logits = outputs.logits
 
         # Convert to CPU and then to list for JSON serialization
@@ -123,6 +133,49 @@ class EndpointHandler:
             "original_inputs": text_inputs
         }
 
+        # Optionally compute scores on server side (GPU)
+        if compute_scores:
+            scores = []
+
+            # Process each sequence in the batch
+            for i in range(len(text_inputs)):
+                try:
+                    # Extract this sequence's data
+                    seq_input_ids = input_ids[i:i+1]  # Keep batch dimension
+                    seq_logits = logits[i:i+1]
+                    seq_attention_mask = attention_mask[i:i+1]
+
+                    # Prepare targets and logits for loss computation
+                    targets = seq_input_ids[:, 1:].clone()
+                    logits_for_loss = seq_logits[:, :-1]
+
+                    # Only consider non-padding tokens (use attention mask)
+                    mask = seq_attention_mask[:, 1:] == 1  # Skip first token, match targets shape
+
+                    if mask.sum() == 0:  # No valid tokens
+                        scores.append(float('inf'))
+                        continue
+
+                    # Flatten and mask
+                    masked_logits = logits_for_loss[mask]  # [num_valid_tokens, vocab_size]
+                    masked_targets = targets[mask]  # [num_valid_tokens]
+
+                    if metric == "perplexity":
+                        loss = F.cross_entropy(masked_logits, masked_targets, reduction='mean')
+                        perplexity = torch.exp(loss).item()
+                        scores.append(perplexity)
+                    else:  # nll
+                        nll = F.cross_entropy(masked_logits, masked_targets, reduction='mean')
+                        scores.append(nll.item())
+
+                except Exception as e:
+                    logger.warning(f"Failed to compute {metric} for sequence {i}: {e}")
+                    scores.append(float('inf'))
+
+            response["scores"] = scores
+            response["metric"] = metric
+            logger.info(f"Computed {metric} scores on server side for {len(scores)} sequences")
+
         logger.info(f"Successfully processed batch, output shape: {logits.shape}")
         return response
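
For context, here is a minimal client-side sketch of how the new scoring path could be exercised. This example is an assumption, not part of the commit: API_URL and the bearer token are hypothetical placeholders for a deployed inference endpoint, and only the payload and response keys ("inputs", "compute_scores", "metric", "scores") come from the handler's docstring above.

import requests

# Hypothetical endpoint URL and token -- substitute your own deployment values.
API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HEADERS = {"Authorization": "Bearer <API_TOKEN>", "Content-Type": "application/json"}

# Request server-side perplexity scores for a small batch of texts.
payload = {
    "inputs": ["The quick brown fox", "jumps over the lazy dog"],
    "compute_scores": True,
    "metric": "perplexity",
}

result = requests.post(API_URL, headers=HEADERS, json=payload).json()
print(result["metric"])  # "perplexity"
print(result["scores"])  # one float per input text

Note that both metrics share one computation in the handler: the mean token-level cross-entropy (NLL) over non-padding target tokens, with "perplexity" simply being exp(NLL). Requesting scores server-side therefore gives consistent values without reimplementing the shift-and-mask logic on the downloaded logits.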