Yong Liu committed on
Commit
1b1b06a
·
1 Parent(s): fd19926

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +96 -9
handler.py CHANGED
@@ -76,6 +76,33 @@ class EndpointHandler:
76
  logger.error(f"Error during model initialization: {e}")
77
  raise
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, str], Generator]:
80
  """
81
  Process the input data and generate a response using the Phi-4 model.
@@ -91,8 +118,45 @@ class EndpointHandler:
91
  if "inputs" not in data:
92
  logger.warning("No 'inputs' field in request data")
93
  return {"error": "Missing 'inputs' field in request"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- prompt = data.get("inputs", "")
 
 
96
  parameters = data.get("parameters", {})
97
 
98
  logger.info(f"Processing input with {len(prompt)} characters")
@@ -117,19 +181,19 @@ class EndpointHandler:
117
  attention_mask = torch.ones_like(input_ids)
118
 
119
  # Perform safe generation with error handling for out-of-vocabulary issues
120
- return self._safe_generate(input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample)
121
 
122
  except Exception as e:
123
  logger.error(f"Error during generation: {e}")
124
  return {"error": str(e)}
125
 
126
- def _safe_generate(self, input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample):
127
  """Safely generate text handling potential token index errors"""
128
  try:
129
  with torch.no_grad():
130
  # Get the input text to exclude from final output
131
- input_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
132
- logger.info(f"Input decoded text: '{input_text}'")
133
 
134
  # Generate one token at a time to avoid index errors
135
  max_steps = min(max_new_tokens, 100) # Limit to 100 tokens for testing
@@ -181,10 +245,13 @@ class EndpointHandler:
181
  # Decode the generated sequence
182
  generated_text = self.tokenizer.decode(current_ids[0], skip_special_tokens=True)
183
 
184
- # Return only the newly generated text (without the prompt)
185
- if generated_text.startswith(input_text):
186
- response_text = generated_text[len(input_text):]
 
187
  else:
 
 
188
  response_text = generated_text
189
 
190
  logger.info(f"Generated {len(response_text)} characters")
@@ -283,5 +350,25 @@ class EndpointHandler:
283
  if __name__ == "__main__":
284
  # Example usage
285
  handler = EndpointHandler()
286
- result = handler({"inputs": "What are the major features of Phi-4?"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  print(result)
 
76
  logger.error(f"Error during model initialization: {e}")
77
  raise
78
 
79
+ def format_prompt_with_system(self, user_message, system_message=None):
80
+ """
81
+ Format the prompt with system and user messages according to Phi-4 format.
82
+
83
+ Args:
84
+ user_message (str): The user's message
85
+ system_message (str, optional): The system message/instruction
86
+
87
+ Returns:
88
+ str: Formatted prompt ready for the model
89
+ """
90
+ # Format using Phi-4's expected chat template:
91
+ # <|system|>
92
+ # {system_message}
93
+ # <|user|>
94
+ # {user_message}
95
+ # <|assistant|>
96
+
97
+ if system_message:
98
+ prompt = f"<|system|>\n{system_message}\n<|user|>\n{user_message}\n<|assistant|>"
99
+ else:
100
+ # If no system message, just use user message with assistant tag
101
+ prompt = f"<|user|>\n{user_message}\n<|assistant|>"
102
+
103
+ logger.info(f"Formatted prompt with {'system message and ' if system_message else ''}user message")
104
+ return prompt
105
+
106
  def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, str], Generator]:
107
  """
108
  Process the input data and generate a response using the Phi-4 model.
 
118
  if "inputs" not in data:
119
  logger.warning("No 'inputs' field in request data")
120
  return {"error": "Missing 'inputs' field in request"}
121
+
122
+ # Handle different input formats
123
+ # 1. Direct string input
124
+ if isinstance(data["inputs"], str):
125
+ user_message = data["inputs"]
126
+ system_message = data.get("parameters", {}).get("system_message", None)
127
+ # 2. Dict with messages format
128
+ elif isinstance(data["inputs"], dict) and "messages" in data["inputs"]:
129
+ messages = data["inputs"]["messages"]
130
+ # Extract system and user messages
131
+ system_message = None
132
+ user_message = ""
133
+
134
+ # Process messages in order, using the last user message
135
+ for msg in messages:
136
+ if msg.get("role") == "system":
137
+ system_message = msg.get("content", "")
138
+ elif msg.get("role") == "user":
139
+ user_message = msg.get("content", "")
140
+ # 3. Direct messages list format
141
+ elif isinstance(data["inputs"], list):
142
+ messages = data["inputs"]
143
+ # Extract system and user messages
144
+ system_message = None
145
+ user_message = ""
146
+
147
+ # Process messages in order, using the last user message
148
+ for msg in messages:
149
+ if msg.get("role") == "system":
150
+ system_message = msg.get("content", "")
151
+ elif msg.get("role") == "user":
152
+ user_message = msg.get("content", "")
153
+ else:
154
+ logger.warning("Unsupported input format")
155
+ return {"error": "Unsupported input format. Expected string or messages object."}
156
 
157
+ # Format the prompt with system and user messages
158
+ prompt = self.format_prompt_with_system(user_message, system_message)
159
+
160
  parameters = data.get("parameters", {})
161
 
162
  logger.info(f"Processing input with {len(prompt)} characters")
 
181
  attention_mask = torch.ones_like(input_ids)
182
 
183
  # Perform safe generation with error handling for out-of-vocabulary issues
184
+ return self._safe_generate(input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample, prompt)
185
 
186
  except Exception as e:
187
  logger.error(f"Error during generation: {e}")
188
  return {"error": str(e)}
189
 
190
+ def _safe_generate(self, input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample, prompt):
191
  """Safely generate text handling potential token index errors"""
192
  try:
193
  with torch.no_grad():
194
  # Get the input text to exclude from final output
195
+ input_text = prompt
196
+ logger.info(f"Input prompt length: {len(input_text)} characters")
197
 
198
  # Generate one token at a time to avoid index errors
199
  max_steps = min(max_new_tokens, 100) # Limit to 100 tokens for testing
 
245
  # Decode the generated sequence
246
  generated_text = self.tokenizer.decode(current_ids[0], skip_special_tokens=True)
247
 
248
+ # Return only the newly generated text (after the assistant tag)
249
+ split_text = generated_text.split("<|assistant|>")
250
+ if len(split_text) > 1:
251
+ response_text = split_text[1].strip()
252
  else:
253
+ # Fallback if the expected format is not found
254
+ logger.warning("Could not find assistant tag in generated text")
255
  response_text = generated_text
256
 
257
  logger.info(f"Generated {len(response_text)} characters")
 
350
if __name__ == "__main__":
    # Example usage: exercise the handler locally with the two supported
    # request shapes (plain string input vs. chat-style messages input).
    handler = EndpointHandler()

    # Shape 1: plain string in "inputs", system message passed via "parameters".
    string_request = {
        "inputs": "What are the major features of Phi-4?",
        "parameters": {
            "system_message": "You are an AI assistant that provides helpful, accurate, and concise information about AI models."
        }
    }

    # Shape 2: chat-style "messages" list nested under "inputs".
    messages_request = {
        "inputs": {
            "messages": [
                {"role": "system", "content": "You are an AI assistant that provides helpful, accurate, and concise information about AI models."},
                {"role": "user", "content": "What are the major features of Phi-4?"}
            ]
        }
    }

    # Run the string-input variant (swap in messages_request to try the other).
    result = handler(string_request)
    print(result)