yol146 committed on
Commit
fd19926
·
1 Parent(s): 290cf25

modify the handler

Browse files
Files changed (1) hide show
  1. handler.py +83 -10
handler.py CHANGED
@@ -104,23 +104,96 @@ class EndpointHandler:
104
  do_sample = parameters.get("do_sample", self.do_sample)
105
  stream = parameters.get("stream", False)
106
 
107
- # Tokenize the input safely
108
- inputs = self.tokenizer(prompt, return_tensors="pt")
109
- logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
 
110
 
111
- # Move to device
112
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
113
 
114
- # Handle streaming if requested
115
- if stream:
116
- return self._generate_stream(inputs, max_new_tokens, temperature, top_p, do_sample)
117
- else:
118
- return self._generate(inputs, max_new_tokens, temperature, top_p, do_sample)
119
 
120
  except Exception as e:
121
  logger.error(f"Error during generation: {e}")
122
  return {"error": str(e)}
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
125
  """Generate text non-streaming mode"""
126
  try:
 
104
  do_sample = parameters.get("do_sample", self.do_sample)
105
  stream = parameters.get("stream", False)
106
 
107
+ # CRITICAL FIX: Use manual generation approach for Phi models with vocabulary mismatches
108
+ # This bypasses the token indexing issues
109
+ if stream:
110
+ return {"error": "Streaming temporarily disabled while fixing token indexing issues"}
111
 
112
+ # Manually implement generation to avoid token index errors
113
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
114
+ logger.info(f"Input tokens shape: {input_ids.shape}")
115
 
116
+ # Create attention mask
117
+ attention_mask = torch.ones_like(input_ids)
118
+
119
+ # Perform safe generation with error handling for out-of-vocabulary issues
120
+ return self._safe_generate(input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample)
121
 
122
  except Exception as e:
123
  logger.error(f"Error during generation: {e}")
124
  return {"error": str(e)}
125
 
126
+ def _safe_generate(self, input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample):
127
+ """Safely generate text handling potential token index errors"""
128
+ try:
129
+ with torch.no_grad():
130
+ # Get the input text to exclude from final output
131
+ input_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
132
+ logger.info(f"Input decoded text: '{input_text}'")
133
+
134
+ # Generate one token at a time to avoid index errors
135
+ max_steps = min(max_new_tokens, 100) # Limit to 100 tokens for testing
136
+ current_ids = input_ids.clone()
137
+
138
+ for _ in range(max_steps):
139
+ # Get logits for next token
140
+ outputs = self.model(
141
+ input_ids=current_ids,
142
+ attention_mask=attention_mask,
143
+ return_dict=True
144
+ )
145
+
146
+ next_token_logits = outputs.logits[:, -1, :]
147
+
148
+ # Apply temperature and sampling
149
+ if temperature > 0:
150
+ next_token_logits = next_token_logits / temperature
151
+
152
+ if do_sample:
153
+ # Apply top_p sampling
154
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
155
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
156
+
157
+ # Remove tokens with cumulative probability above the threshold
158
+ sorted_indices_to_remove = cumulative_probs > top_p
159
+ # Shift the indices to the right to keep also the first token above the threshold
160
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
161
+ sorted_indices_to_remove[..., 0] = 0
162
+
163
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
164
+ next_token_logits[indices_to_remove] = -float('Inf')
165
+
166
+ # Sample from the filtered distribution
167
+ probs = torch.softmax(next_token_logits, dim=-1)
168
+ next_token = torch.multinomial(probs, num_samples=1)
169
+ else:
170
+ # Take the token with highest probability
171
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
172
+
173
+ # Add the predicted token to the sequence
174
+ current_ids = torch.cat([current_ids, next_token], dim=-1)
175
+ attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
176
+
177
+ # Check if we've generated an EOS token
178
+ if next_token[0, 0].item() == self.tokenizer.eos_token_id:
179
+ break
180
+
181
+ # Decode the generated sequence
182
+ generated_text = self.tokenizer.decode(current_ids[0], skip_special_tokens=True)
183
+
184
+ # Return only the newly generated text (without the prompt)
185
+ if generated_text.startswith(input_text):
186
+ response_text = generated_text[len(input_text):]
187
+ else:
188
+ response_text = generated_text
189
+
190
+ logger.info(f"Generated {len(response_text)} characters")
191
+ return {"generated_text": response_text}
192
+
193
+ except Exception as e:
194
+ logger.error(f"Error in _safe_generate: {str(e)}")
195
+ return {"error": f"Generation error: {str(e)}. Please try a simpler input."}
196
+
197
  def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
198
  """Generate text non-streaming mode"""
199
  try: