Yong Liu committed on
Commit
cec06c5
·
1 Parent(s): 1b1b06a

update the handler.py

Browse files
Files changed (1) hide show
  1. handler.py +50 -23
handler.py CHANGED
@@ -21,7 +21,7 @@ class EndpointHandler:
21
  path (str): Path to the model directory
22
  """
23
  # Set default parameters for inference
24
- self.max_new_tokens = 1024 # Reduced from 4096 to avoid memory issues
25
  self.temperature = 0.7
26
  self.top_p = 0.9
27
  self.do_sample = True
@@ -103,7 +103,7 @@ class EndpointHandler:
103
  logger.info(f"Formatted prompt with {'system message and ' if system_message else ''}user message")
104
  return prompt
105
 
106
- def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, str], Generator]:
107
  """
108
  Process the input data and generate a response using the Phi-4 model.
109
 
@@ -111,7 +111,7 @@ class EndpointHandler:
111
  data (Dict[str, Any]): Input data containing the prompt and generation parameters
112
 
113
  Returns:
114
- Dict[str, str] or Generator: Model response or stream
115
  """
116
  try:
117
  # Extract input parameters with defaults
@@ -124,27 +124,34 @@ class EndpointHandler:
124
  if isinstance(data["inputs"], str):
125
  user_message = data["inputs"]
126
  system_message = data.get("parameters", {}).get("system_message", None)
 
 
 
 
 
 
127
  # 2. Dict with messages format
128
  elif isinstance(data["inputs"], dict) and "messages" in data["inputs"]:
129
  messages = data["inputs"]["messages"]
130
- # Extract system and user messages
 
131
  system_message = None
132
  user_message = ""
133
 
134
- # Process messages in order, using the last user message
135
  for msg in messages:
136
  if msg.get("role") == "system":
137
  system_message = msg.get("content", "")
138
  elif msg.get("role") == "user":
139
  user_message = msg.get("content", "")
 
140
  # 3. Direct messages list format
141
  elif isinstance(data["inputs"], list):
142
  messages = data["inputs"]
143
- # Extract system and user messages
 
144
  system_message = None
145
  user_message = ""
146
 
147
- # Process messages in order, using the last user message
148
  for msg in messages:
149
  if msg.get("role") == "system":
150
  system_message = msg.get("content", "")
@@ -153,7 +160,7 @@ class EndpointHandler:
153
  else:
154
  logger.warning("Unsupported input format")
155
  return {"error": "Unsupported input format. Expected string or messages object."}
156
-
157
  # Format the prompt with system and user messages
158
  prompt = self.format_prompt_with_system(user_message, system_message)
159
 
@@ -162,7 +169,7 @@ class EndpointHandler:
162
  logger.info(f"Processing input with {len(prompt)} characters")
163
 
164
  # Get generation parameters with fallbacks to defaults
165
- max_new_tokens = min(parameters.get("max_new_tokens", self.max_new_tokens), 2048)
166
  temperature = parameters.get("temperature", self.temperature)
167
  top_p = parameters.get("top_p", self.top_p)
168
  do_sample = parameters.get("do_sample", self.do_sample)
@@ -181,7 +188,35 @@ class EndpointHandler:
181
  attention_mask = torch.ones_like(input_ids)
182
 
183
  # Perform safe generation with error handling for out-of-vocabulary issues
184
- return self._safe_generate(input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  except Exception as e:
187
  logger.error(f"Error during generation: {e}")
@@ -196,7 +231,7 @@ class EndpointHandler:
196
  logger.info(f"Input prompt length: {len(input_text)} characters")
197
 
198
  # Generate one token at a time to avoid index errors
199
- max_steps = min(max_new_tokens, 100) # Limit to 100 tokens for testing
200
  current_ids = input_ids.clone()
201
 
202
  for _ in range(max_steps):
@@ -255,7 +290,7 @@ class EndpointHandler:
255
  response_text = generated_text
256
 
257
  logger.info(f"Generated {len(response_text)} characters")
258
- return {"generated_text": response_text}
259
 
260
  except Exception as e:
261
  logger.error(f"Error in _safe_generate: {str(e)}")
@@ -295,7 +330,7 @@ class EndpointHandler:
295
  response_text = generated_text
296
 
297
  logger.info(f"Generated {len(response_text)} characters")
298
- return {"generated_text": response_text}
299
 
300
  except Exception as e:
301
  logger.error(f"Error in _generate: {e}")
@@ -351,14 +386,6 @@ if __name__ == "__main__":
351
  # Example usage
352
  handler = EndpointHandler()
353
 
354
- # Test with system message
355
- test_with_system = {
356
- "inputs": "What are the major features of Phi-4?",
357
- "parameters": {
358
- "system_message": "You are an AI assistant that provides helpful, accurate, and concise information about AI models."
359
- }
360
- }
361
-
362
  # Test with messages format
363
  test_with_messages = {
364
  "inputs": {
@@ -369,6 +396,6 @@ if __name__ == "__main__":
369
  }
370
  }
371
 
372
- # Choose which test to run
373
- result = handler(test_with_system)
374
  print(result)
 
21
  path (str): Path to the model directory
22
  """
23
  # Set default parameters for inference
24
+ self.max_new_tokens = 2048 # Increased from 1024 to handle longer outputs
25
  self.temperature = 0.7
26
  self.top_p = 0.9
27
  self.do_sample = True
 
103
  logger.info(f"Formatted prompt with {'system message and ' if system_message else ''}user message")
104
  return prompt
105
 
106
+ def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], Generator]:
107
  """
108
  Process the input data and generate a response using the Phi-4 model.
109
 
 
111
  data (Dict[str, Any]): Input data containing the prompt and generation parameters
112
 
113
  Returns:
114
+ Dict[str, Any] or Generator: Model response
115
  """
116
  try:
117
  # Extract input parameters with defaults
 
124
  if isinstance(data["inputs"], str):
125
  user_message = data["inputs"]
126
  system_message = data.get("parameters", {}).get("system_message", None)
127
+ messages = [
128
+ {"role": "system", "content": system_message} if system_message else None,
129
+ {"role": "user", "content": user_message}
130
+ ]
131
+ messages = [m for m in messages if m is not None] # Remove None values
132
+
133
  # 2. Dict with messages format
134
  elif isinstance(data["inputs"], dict) and "messages" in data["inputs"]:
135
  messages = data["inputs"]["messages"]
136
+
137
+ # Extract system and user messages for prompt formatting
138
  system_message = None
139
  user_message = ""
140
 
 
141
  for msg in messages:
142
  if msg.get("role") == "system":
143
  system_message = msg.get("content", "")
144
  elif msg.get("role") == "user":
145
  user_message = msg.get("content", "")
146
+
147
  # 3. Direct messages list format
148
  elif isinstance(data["inputs"], list):
149
  messages = data["inputs"]
150
+
151
+ # Extract system and user messages for prompt formatting
152
  system_message = None
153
  user_message = ""
154
 
 
155
  for msg in messages:
156
  if msg.get("role") == "system":
157
  system_message = msg.get("content", "")
 
160
  else:
161
  logger.warning("Unsupported input format")
162
  return {"error": "Unsupported input format. Expected string or messages object."}
163
+
164
  # Format the prompt with system and user messages
165
  prompt = self.format_prompt_with_system(user_message, system_message)
166
 
 
169
  logger.info(f"Processing input with {len(prompt)} characters")
170
 
171
  # Get generation parameters with fallbacks to defaults
172
+ max_new_tokens = min(parameters.get("max_new_tokens", self.max_new_tokens), 4096) # Increased to 4096
173
  temperature = parameters.get("temperature", self.temperature)
174
  top_p = parameters.get("top_p", self.top_p)
175
  do_sample = parameters.get("do_sample", self.do_sample)
 
188
  attention_mask = torch.ones_like(input_ids)
189
 
190
  # Perform safe generation with error handling for out-of-vocabulary issues
191
+ response_text = self._safe_generate(input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample, prompt)
192
+
193
+ # Format response in OpenAI-style format
194
+ if isinstance(response_text, dict) and "error" in response_text:
195
+ return response_text
196
+ else:
197
+ # OpenAI-style response format
198
+ openai_response = {
199
+ "id": f"phi4-{int(torch.randint(10000, 99999, (1,)).item())}",
200
+ "object": "chat.completion",
201
+ "created": int(torch.cuda.current_stream().cuda_stream if torch.cuda.is_available() else 0),
202
+ "model": "phi-4-mini",
203
+ "choices": [
204
+ {
205
+ "index": 0,
206
+ "message": {
207
+ "role": "assistant",
208
+ "content": response_text
209
+ },
210
+ "finish_reason": "stop"
211
+ }
212
+ ],
213
+ "usage": {
214
+ "prompt_tokens": len(input_ids[0]),
215
+ "completion_tokens": len(self.tokenizer.encode(response_text)) if isinstance(response_text, str) else 0,
216
+ "total_tokens": len(input_ids[0]) + (len(self.tokenizer.encode(response_text)) if isinstance(response_text, str) else 0)
217
+ }
218
+ }
219
+ return openai_response
220
 
221
  except Exception as e:
222
  logger.error(f"Error during generation: {e}")
 
231
  logger.info(f"Input prompt length: {len(input_text)} characters")
232
 
233
  # Generate one token at a time to avoid index errors
234
+ max_steps = max_new_tokens # Allow for full generation length
235
  current_ids = input_ids.clone()
236
 
237
  for _ in range(max_steps):
 
290
  response_text = generated_text
291
 
292
  logger.info(f"Generated {len(response_text)} characters")
293
+ return response_text
294
 
295
  except Exception as e:
296
  logger.error(f"Error in _safe_generate: {str(e)}")
 
330
  response_text = generated_text
331
 
332
  logger.info(f"Generated {len(response_text)} characters")
333
+ return response_text
334
 
335
  except Exception as e:
336
  logger.error(f"Error in _generate: {e}")
 
386
  # Example usage
387
  handler = EndpointHandler()
388
 
 
 
 
 
 
 
 
 
389
  # Test with messages format
390
  test_with_messages = {
391
  "inputs": {
 
396
  }
397
  }
398
 
399
+ # Run the test
400
+ result = handler(test_with_messages)
401
  print(result)