Yong Liu committed on
Commit
0b6ae9b
·
1 Parent(s): bb64432

updated the handler.py

Browse files
Files changed (1) hide show
  1. handler.py +102 -10
handler.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import logging
4
  import time
5
  import traceback
 
6
  from typing import Dict, List, Any, Union, Generator
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
  from threading import Thread
@@ -117,13 +118,14 @@ class EndpointHandler:
117
  Dict[str, Any]: Model response
118
  """
119
  start_time = time.time()
120
- logger.info(f"Starting request processing: {data}")
121
 
122
  try:
123
  # Extract input parameters with defaults
124
  if "inputs" not in data:
125
  logger.warning("No 'inputs' field in request data")
126
- return {"generated_text": "Error: Missing 'inputs' field in request"}
 
127
 
128
  # Track user and system messages
129
  user_message = ""
@@ -158,10 +160,12 @@ class EndpointHandler:
158
  user_message = msg.get("content", "")
159
  else:
160
  logger.warning(f"Unsupported input format: {type(data['inputs'])}")
161
- return {"generated_text": "Error: Unsupported input format. Expected string or messages object."}
 
162
 
163
- logger.info(f"Extracted user message: '{user_message}'")
164
- logger.info(f"Extracted system message: '{system_message}'")
 
165
 
166
  # Format the prompt with system and user messages
167
  prompt = self.format_prompt_with_system(user_message, system_message)
@@ -199,18 +203,25 @@ class EndpointHandler:
199
 
200
  logger.info(f"Response generation completed, text length: {len(response_text) if isinstance(response_text, str) else 'N/A'}")
201
 
202
- # Return response - keeping it simple for debug purposes
203
- return {"generated_text": response_text}
 
 
 
 
 
 
 
204
 
205
  except RuntimeError as e:
206
  logger.error(f"Runtime Error during generation: {str(e)}")
207
  logger.error(traceback.format_exc())
208
- return {"generated_text": f"Error during generation: {str(e)}"}
209
 
210
  except Exception as e:
211
  logger.error(f"Unexpected error during request processing: {str(e)}")
212
  logger.error(traceback.format_exc())
213
- return {"generated_text": f"Unexpected error: {str(e)}"}
214
  finally:
215
  duration = time.time() - start_time
216
  logger.info(f"Request processing completed in {duration:.2f} seconds")
@@ -298,6 +309,87 @@ class EndpointHandler:
298
  logger.error(f"Error in _safe_generate: {str(e)}")
299
  logger.error(traceback.format_exc())
300
  return f"Generation error: {str(e)}. Please try a simpler input."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
  # For local testing
303
  if __name__ == "__main__":
@@ -316,4 +408,4 @@ if __name__ == "__main__":
316
 
317
  # Run the test
318
  result = handler(test_with_messages)
319
- print(result)
 
3
  import logging
4
  import time
5
  import traceback
6
+ import json
7
  from typing import Dict, List, Any, Union, Generator
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
  from threading import Thread
 
118
  Dict[str, Any]: Model response
119
  """
120
  start_time = time.time()
121
+ logger.info(f"Starting request processing")
122
 
123
  try:
124
  # Extract input parameters with defaults
125
  if "inputs" not in data:
126
  logger.warning("No 'inputs' field in request data")
127
+ error_msg = "Missing 'inputs' field in request"
128
+ return self._format_error_response(error_msg)
129
 
130
  # Track user and system messages
131
  user_message = ""
 
160
  user_message = msg.get("content", "")
161
  else:
162
  logger.warning(f"Unsupported input format: {type(data['inputs'])}")
163
+ error_msg = "Unsupported input format. Expected string or messages object."
164
+ return self._format_error_response(error_msg)
165
 
166
+ logger.info(f"Extracted user message length: {len(user_message)} characters")
167
+ if system_message:
168
+ logger.info(f"Extracted system message length: {len(system_message)} characters")
169
 
170
  # Format the prompt with system and user messages
171
  prompt = self.format_prompt_with_system(user_message, system_message)
 
203
 
204
  logger.info(f"Response generation completed, text length: {len(response_text) if isinstance(response_text, str) else 'N/A'}")
205
 
206
+ # Format and return response in OpenAI format
207
+ if isinstance(response_text, str):
208
+ return self._format_openai_response(
209
+ response_text,
210
+ input_ids.shape[1],
211
+ len(self.tokenizer.encode(response_text)) if response_text else 0
212
+ )
213
+ else:
214
+ return self._format_error_response(f"Error during generation: {response_text}")
215
 
216
  except RuntimeError as e:
217
  logger.error(f"Runtime Error during generation: {str(e)}")
218
  logger.error(traceback.format_exc())
219
+ return self._format_error_response(f"Error during generation: {str(e)}")
220
 
221
  except Exception as e:
222
  logger.error(f"Unexpected error during request processing: {str(e)}")
223
  logger.error(traceback.format_exc())
224
+ return self._format_error_response(f"Unexpected error: {str(e)}")
225
  finally:
226
  duration = time.time() - start_time
227
  logger.info(f"Request processing completed in {duration:.2f} seconds")
 
309
  logger.error(f"Error in _safe_generate: {str(e)}")
310
  logger.error(traceback.format_exc())
311
  return f"Generation error: {str(e)}. Please try a simpler input."
312
+
313
+ def _format_openai_response(self, response_text, prompt_tokens, completion_tokens):
314
+ """Format the response in OpenAI-style format"""
315
+ try:
316
+ # Create a response ID
317
+ response_id = f"phi4-{int(time.time())}"
318
+
319
+ # Build OpenAI-compatible response
320
+ openai_response = {
321
+ "id": response_id,
322
+ "object": "chat.completion",
323
+ "created": int(time.time()),
324
+ "model": "phi-4-mini",
325
+ "choices": [
326
+ {
327
+ "index": 0,
328
+ "message": {
329
+ "role": "assistant",
330
+ "content": response_text
331
+ },
332
+ "finish_reason": "stop"
333
+ }
334
+ ],
335
+ "usage": {
336
+ "prompt_tokens": prompt_tokens,
337
+ "completion_tokens": completion_tokens,
338
+ "total_tokens": prompt_tokens + completion_tokens
339
+ }
340
+ }
341
+
342
+ # For compatibility with Hugging Face UI, include the generated_text field
343
+ openai_response["generated_text"] = response_text
344
+
345
+ logger.info(f"Formatted OpenAI-style response: {len(json.dumps(openai_response))} bytes")
346
+ return openai_response
347
+
348
+ except Exception as e:
349
+ logger.error(f"Error formatting OpenAI response: {str(e)}")
350
+ # Fall back to simple response
351
+ return {"generated_text": response_text}
352
+
353
+ def _format_error_response(self, error_message):
354
+ """Format an error response in OpenAI-style format"""
355
+ try:
356
+ error_response = {
357
+ "id": f"phi4-error-{int(time.time())}",
358
+ "object": "chat.completion",
359
+ "created": int(time.time()),
360
+ "model": "phi-4-mini",
361
+ "choices": [
362
+ {
363
+ "index": 0,
364
+ "message": {
365
+ "role": "assistant",
366
+ "content": f"Error: {error_message}"
367
+ },
368
+ "finish_reason": "error"
369
+ }
370
+ ],
371
+ "usage": {
372
+ "prompt_tokens": 0,
373
+ "completion_tokens": 0,
374
+ "total_tokens": 0
375
+ },
376
+ "error": {
377
+ "message": error_message,
378
+ "type": "invalid_request_error",
379
+ "code": "error"
380
+ }
381
+ }
382
+
383
+ # For compatibility with Hugging Face UI, include the generated_text field
384
+ error_response["generated_text"] = f"Error: {error_message}"
385
+
386
+ logger.info(f"Formatted error response: {len(json.dumps(error_response))} bytes")
387
+ return error_response
388
+
389
+ except Exception as e:
390
+ logger.error(f"Error formatting error response: {str(e)}")
391
+ # Fall back to simple error response
392
+ return {"generated_text": f"Error: {error_message}"}
393
 
394
  # For local testing
395
  if __name__ == "__main__":
 
408
 
409
  # Run the test
410
  result = handler(test_with_messages)
411
+ print(json.dumps(result, indent=2))