Yong Liu committed
Commit bb64432 · 1 Parent(s): cec06c5

update handler.py

Files changed (1): handler.py (+69 -151)
handler.py CHANGED
@@ -1,6 +1,8 @@
 import os
 import torch
 import logging
+import time
+import traceback
 from typing import Dict, List, Any, Union, Generator
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
@@ -21,7 +23,7 @@ class EndpointHandler:
             path (str): Path to the model directory
         """
         # Set default parameters for inference
-        self.max_new_tokens = 2048  # Increased from 1024 to handle longer outputs
+        self.max_new_tokens = 1024  # Keep at 1024 to avoid timeouts
         self.temperature = 0.7
         self.top_p = 0.9
         self.do_sample = True
@@ -73,7 +75,8 @@ class EndpointHandler:
                 logger.warning(f"Tokenizer vocab size ({len(self.tokenizer)}) doesn't match model vocab size ({self.model.config.vocab_size})")

         except Exception as e:
-            logger.error(f"Error during model initialization: {e}")
+            logger.error(f"Error during model initialization: {str(e)}")
+            logger.error(traceback.format_exc())
             raise

    def format_prompt_with_system(self, user_message, system_message=None):
@@ -103,7 +106,7 @@ class EndpointHandler:
        logger.info(f"Formatted prompt with {'system message and ' if system_message else ''}user message")
        return prompt

-    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], Generator]:
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and generate a response using the Phi-4 model.

@@ -111,33 +114,32 @@ class EndpointHandler:
            data (Dict[str, Any]): Input data containing the prompt and generation parameters

        Returns:
-            Dict[str, Any] or Generator: Model response
+            Dict[str, Any]: Model response
        """
+        start_time = time.time()
+        logger.info(f"Starting request processing: {data}")
+
        try:
            # Extract input parameters with defaults
            if "inputs" not in data:
                logger.warning("No 'inputs' field in request data")
-                return {"error": "Missing 'inputs' field in request"}
+                return {"generated_text": "Error: Missing 'inputs' field in request"}
+
+            # Track user and system messages
+            user_message = ""
+            system_message = None

            # Handle different input formats
            # 1. Direct string input
            if isinstance(data["inputs"], str):
                user_message = data["inputs"]
                system_message = data.get("parameters", {}).get("system_message", None)
-                messages = [
-                    {"role": "system", "content": system_message} if system_message else None,
-                    {"role": "user", "content": user_message}
-                ]
-                messages = [m for m in messages if m is not None]  # Remove None values

            # 2. Dict with messages format
            elif isinstance(data["inputs"], dict) and "messages" in data["inputs"]:
                messages = data["inputs"]["messages"]

                # Extract system and user messages for prompt formatting
-                system_message = None
-                user_message = ""
-
                for msg in messages:
                    if msg.get("role") == "system":
                        system_message = msg.get("content", "")
@@ -149,17 +151,17 @@ class EndpointHandler:
                messages = data["inputs"]

                # Extract system and user messages for prompt formatting
-                system_message = None
-                user_message = ""
-
                for msg in messages:
                    if msg.get("role") == "system":
                        system_message = msg.get("content", "")
                    elif msg.get("role") == "user":
                        user_message = msg.get("content", "")
            else:
-                logger.warning("Unsupported input format")
-                return {"error": "Unsupported input format. Expected string or messages object."}
+                logger.warning(f"Unsupported input format: {type(data['inputs'])}")
+                return {"generated_text": "Error: Unsupported input format. Expected string or messages object."}
+
+            logger.info(f"Extracted user message: '{user_message}'")
+            logger.info(f"Extracted system message: '{system_message}'")

            # Format the prompt with system and user messages
            prompt = self.format_prompt_with_system(user_message, system_message)
@@ -169,72 +171,70 @@ class EndpointHandler:
            logger.info(f"Processing input with {len(prompt)} characters")

            # Get generation parameters with fallbacks to defaults
-            max_new_tokens = min(parameters.get("max_new_tokens", self.max_new_tokens), 4096)  # Increased to 4096
+            max_new_tokens = min(parameters.get("max_new_tokens", self.max_new_tokens), 2048)
            temperature = parameters.get("temperature", self.temperature)
            top_p = parameters.get("top_p", self.top_p)
            do_sample = parameters.get("do_sample", self.do_sample)
-            stream = parameters.get("stream", False)

-            # CRITICAL FIX: Use manual generation approach for Phi models with vocabulary mismatches
-            # This bypasses the token indexing issues
-            if stream:
-                return {"error": "Streaming temporarily disabled while fixing token indexing issues"}
+            logger.info(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_p={top_p}, do_sample={do_sample}")

            # Manually implement generation to avoid token index errors
-            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
-            logger.info(f"Input tokens shape: {input_ids.shape}")
-
-            # Create attention mask
-            attention_mask = torch.ones_like(input_ids)
-
-            # Perform safe generation with error handling for out-of-vocabulary issues
-            response_text = self._safe_generate(input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample, prompt)
-
-            # Format response in OpenAI-style format
-            if isinstance(response_text, dict) and "error" in response_text:
-                return response_text
-            else:
-                # OpenAI-style response format
-                openai_response = {
-                    "id": f"phi4-{int(torch.randint(10000, 99999, (1,)).item())}",
-                    "object": "chat.completion",
-                    "created": int(torch.cuda.current_stream().cuda_stream if torch.cuda.is_available() else 0),
-                    "model": "phi-4-mini",
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": response_text
-                            },
-                            "finish_reason": "stop"
-                        }
-                    ],
-                    "usage": {
-                        "prompt_tokens": len(input_ids[0]),
-                        "completion_tokens": len(self.tokenizer.encode(response_text)) if isinstance(response_text, str) else 0,
-                        "total_tokens": len(input_ids[0]) + (len(self.tokenizer.encode(response_text)) if isinstance(response_text, str) else 0)
-                    }
-                }
-                return openai_response
+            try:
+                input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
+                logger.info(f"Input tokens shape: {input_ids.shape}")
+
+                # Create attention mask
+                attention_mask = torch.ones_like(input_ids)
+
+                # Perform safe generation with error handling for out-of-vocabulary issues
+                response_text = self._safe_generate(
+                    input_ids,
+                    attention_mask,
+                    max_new_tokens,
+                    temperature,
+                    top_p,
+                    do_sample,
+                    prompt
+                )
+
+                logger.info(f"Response generation completed, text length: {len(response_text) if isinstance(response_text, str) else 'N/A'}")
+
+                # Return response - keeping it simple for debug purposes
+                return {"generated_text": response_text}
+
+            except RuntimeError as e:
+                logger.error(f"Runtime Error during generation: {str(e)}")
+                logger.error(traceback.format_exc())
+                return {"generated_text": f"Error during generation: {str(e)}"}

        except Exception as e:
-            logger.error(f"Error during generation: {e}")
-            return {"error": str(e)}
+            logger.error(f"Unexpected error during request processing: {str(e)}")
+            logger.error(traceback.format_exc())
+            return {"generated_text": f"Unexpected error: {str(e)}"}
+        finally:
+            duration = time.time() - start_time
+            logger.info(f"Request processing completed in {duration:.2f} seconds")

    def _safe_generate(self, input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample, prompt):
        """Safely generate text handling potential token index errors"""
        try:
            with torch.no_grad():
+                logger.info("Starting safe generation")
+
                # Get the input text to exclude from final output
                input_text = prompt
                logger.info(f"Input prompt length: {len(input_text)} characters")

                # Generate one token at a time to avoid index errors
-                max_steps = max_new_tokens  # Allow for full generation length
+                max_steps = min(max_new_tokens, 250)  # Limit to 250 tokens for reliability
                current_ids = input_ids.clone()

-                for _ in range(max_steps):
+                logger.info(f"Generating up to {max_steps} tokens")
+
+                for i in range(max_steps):
+                    if i % 50 == 0:
+                        logger.info(f"Generated {i} tokens so far")
+
                    # Get logits for next token
                    outputs = self.model(
                        input_ids=current_ids,
@@ -275,111 +275,29 @@ class EndpointHandler:

                    # Check if we've generated an EOS token
                    if next_token[0, 0].item() == self.tokenizer.eos_token_id:
+                        logger.info(f"EOS token generated after {i+1} tokens")
                        break

                # Decode the generated sequence
                generated_text = self.tokenizer.decode(current_ids[0], skip_special_tokens=True)
+                logger.info(f"Decoded generated text: {len(generated_text)} characters")

                # Return only the newly generated text (after the assistant tag)
                split_text = generated_text.split("<|assistant|>")
                if len(split_text) > 1:
                    response_text = split_text[1].strip()
+                    logger.info(f"Extracted assistant response: {len(response_text)} characters")
                else:
                    # Fallback if the expected format is not found
                    logger.warning("Could not find assistant tag in generated text")
                    response_text = generated_text

-                logger.info(f"Generated {len(response_text)} characters")
                return response_text

        except Exception as e:
            logger.error(f"Error in _safe_generate: {str(e)}")
-            return {"error": f"Generation error: {str(e)}. Please try a simpler input."}
-
-    def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
-        """Generate text non-streaming mode"""
-        try:
-            with torch.no_grad():
-                generation_config = {
-                    "max_new_tokens": max_new_tokens,
-                    "temperature": temperature,
-                    "top_p": top_p,
-                    "do_sample": do_sample,
-                    "pad_token_id": self.tokenizer.eos_token_id
-                }
-
-                logger.info(f"Generating with config: {generation_config}")
-
-                # Fix: inputs is a dictionary, not an object with attributes
-                outputs = self.model.generate(
-                    inputs["input_ids"],
-                    attention_mask=inputs.get("attention_mask", None),
-                    **generation_config
-                )
-
-                # Decode the generated text
-                generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-                # Return only the newly generated text (without the prompt)
-                input_text = self.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
-
-                if generated_text.startswith(input_text):
-                    response_text = generated_text[len(input_text):]
-                else:
-                    # Fallback if the decoded text doesn't start with the input
-                    response_text = generated_text
-
-                logger.info(f"Generated {len(response_text)} characters")
-                return response_text
-
-        except Exception as e:
-            logger.error(f"Error in _generate: {e}")
-            return {"error": str(e)}
-
-    def _generate_stream(self, inputs, max_new_tokens, temperature, top_p, do_sample):
-        """Generate text in streaming mode"""
-        try:
-            # Create a streamer object
-            streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
-
-            # Set up generation in a separate thread
-            generation_kwargs = {
-                "input_ids": inputs["input_ids"],
-                "attention_mask": inputs.get("attention_mask", None),
-                "streamer": streamer,
-                "max_new_tokens": max_new_tokens,
-                "temperature": temperature,
-                "top_p": top_p,
-                "do_sample": do_sample,
-                "pad_token_id": self.tokenizer.eos_token_id
-            }
-
-            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
-            thread.start()
-
-            # Determine input text length to strip it from outputs
-            input_text = self.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
-
-            # Stream the output
-            def generate_stream():
-                # Skip the prompt part in the first chunk
-                full_text = ""
-                for text in streamer:
-                    full_text += text
-                    # Only return the part after the prompt
-                    if full_text.startswith(input_text):
-                        current_response = full_text[len(input_text):]
-                    else:
-                        current_response = full_text
-                    yield {"generated_text": current_response}
-
-            return generate_stream()
-
-        except Exception as e:
-            logger.error(f"Error in _generate_stream: {e}")
-            def error_stream():
-                yield {"error": str(e)}
-            return error_stream()

 # For local testing
 if __name__ == "__main__":
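The diff elides the body of the token loop (old lines ~241-274), where the logits for the last position are turned into next_token. For orientation, a typical manual temperature/top-p sampling step looks like the sketch below; this is a generic illustration of the usual nucleus-sampling recipe, not the commit's exact code, and sample_next_token is a hypothetical helper name.

    import torch

    def sample_next_token(logits, temperature=0.7, top_p=0.9, do_sample=True):
        """Pick the next token id from logits of shape (batch, seq, vocab)."""
        logits = logits[:, -1, :]  # keep only the final position
        if not do_sample:
            return torch.argmax(logits, dim=-1, keepdim=True)  # greedy decoding
        probs = torch.softmax(logits / temperature, dim=-1)
        # Nucleus (top-p) filtering: keep the smallest set of tokens whose
        # cumulative probability exceeds top_p, then renormalize and sample.
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        sorted_probs[cumulative - sorted_probs > top_p] = 0.0
        sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(sorted_probs, num_samples=1)
        return sorted_idx.gather(-1, choice)  # map back to vocabulary ids

The loop in _safe_generate would then append the result with something like current_ids = torch.cat([current_ids, next_token], dim=-1), which is consistent with the next_token[0, 0] indexing in the EOS check above.

The body of the if __name__ == "__main__": block falls outside the diff context. A minimal local smoke test consistent with the handler contract shown above might look like the following sketch; it assumes the usual custom-handler convention of constructing EndpointHandler(path) and calling it with an {"inputs": ..., "parameters": ...} dict, and "./model" is a placeholder path, not taken from the commit.

    from handler import EndpointHandler

    handler = EndpointHandler(path="./model")  # placeholder model directory

    # 1. Direct string input, optional system message via "parameters"
    result = handler({
        "inputs": "What does an attention mask do?",
        "parameters": {"system_message": "You are a helpful assistant.",
                       "max_new_tokens": 64},
    })
    print(result["generated_text"])

    # 2. Chat-style "messages" input, also accepted by __call__
    result = handler({
        "inputs": {"messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ]},
    })
    print(result["generated_text"])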