Yong Liu committed on
Commit
051c5a5
·
1 Parent(s): ead8711

update handler

Browse files
Files changed (5) hide show
  1. README.md +81 -0
  2. __pycache__/handler.cpython-310.pyc +0 -0
  3. handler.py +146 -416
  4. requirements.txt +4 -0
  5. test_handler.py +71 -0
README.md ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Phi-4 Mini Inference Endpoint Handler
2
+
3
+ This repository contains code for deploying the Phi-4 Mini model to a HuggingFace Inference Endpoint with an OpenAI-compatible API format.
4
+
5
+ ## Setup
6
+
7
+ 1. Install the required dependencies:
8
+ ```
9
+ pip install -r requirements.txt
10
+ ```
11
+
12
+ 2. Set the environment variable to your model path (optional if model is in the same directory):
13
+ ```
14
+ export MODEL_PATH=/path/to/your/model
15
+ ```
16
+
17
+ ## Usage
18
+
19
+ When deploying to a HuggingFace Inference Endpoint, the `handler.py` file will be used to process requests. The endpoint accepts requests in an OpenAI-compatible format:
20
+
21
+ ```json
22
+ {
23
+ "messages": [
24
+ {"role": "system", "content": "You are a helpful assistant."},
25
+ {"role": "user", "content": "Tell me about language models."}
26
+ ],
27
+ "max_tokens": 256,
28
+ "temperature": 0.7,
29
+ "top_p": 1.0,
30
+ "n": 1,
31
+ "stop": ["\n", "User:"]
32
+ }
33
+ ```
34
+
35
+ The endpoint returns responses in an OpenAI-compatible format:
36
+
37
+ ```json
38
+ {
39
+ "id": "cmpl-12345",
40
+ "object": "chat.completion",
41
+ "created": 0,
42
+ "model": "phi4-mini-raw",
43
+ "choices": [
44
+ {
45
+ "index": 0,
46
+ "message": {
47
+ "role": "assistant",
48
+ "content": "Language models are computational systems designed to understand and generate human language..."
49
+ },
50
+ "finish_reason": "stop"
51
+ }
52
+ ],
53
+ "usage": {
54
+ "prompt_tokens": 42,
55
+ "completion_tokens": 156,
56
+ "total_tokens": 198
57
+ }
58
+ }
59
+ ```
60
+
61
+ ## Local Testing
62
+
63
+ To test the handler locally before deployment:
64
+
65
+ ```python
66
+ from handler import EndpointHandler
67
+
68
+ # Initialize the handler with your model path
69
+ handler = EndpointHandler("./phi4-mini-raw")
70
+
71
+ # Test with a sample request
72
+ request = {
73
+ "messages": [
74
+ {"role": "system", "content": "You are a helpful assistant."},
75
+ {"role": "user", "content": "Hello, how are you?"}
76
+ ]
77
+ }
78
+
79
+ response = handler(request)
80
+ print(response)
81
+ ```
__pycache__/handler.cpython-310.pyc ADDED
Binary file (4.23 kB). View file
 
handler.py CHANGED
@@ -1,436 +1,166 @@
1
  import os
2
- import torch
3
- import logging
4
- import time
5
- import traceback
6
  import json
7
- from typing import Dict, List, Any, Union, Generator
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
- from threading import Thread
10
-
11
- # Set up logging
12
- logging.basicConfig(
13
- level=logging.INFO,
14
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
15
- )
16
- logger = logging.getLogger(__name__)
17
 
18
  class EndpointHandler:
19
  def __init__(self, path=""):
20
- """
21
- Initialize the model and tokenizer for Phi-4 inference.
22
-
23
- Args:
24
- path (str): Path to the model directory
25
- """
26
- # Set default parameters for inference
27
- self.max_new_tokens = 1024 # Keep at 1024 to avoid timeouts
28
- self.temperature = 0.7
29
- self.top_p = 0.9
30
- self.do_sample = True
31
-
32
- # Determine if CUDA is available
33
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
- self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
35
 
36
- logger.info(f"Initializing model from {path} on {self.device}")
 
 
 
 
 
 
 
 
37
 
 
 
38
  try:
39
- # Load tokenizer - use original model ID as fallback
40
- # This helps with common tokenizer mismatch issues
41
- try:
42
- self.tokenizer = AutoTokenizer.from_pretrained(path)
43
- logger.info(f"Loaded tokenizer from local path")
44
- except Exception as e:
45
- logger.warning(f"Failed to load tokenizer from local path: {e}")
46
- self.tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
47
- logger.info("Loaded tokenizer from microsoft/Phi-4-mini-instruct")
48
-
49
- # Ensure tokenizer has EOS token set
50
- if self.tokenizer.eos_token_id is None:
51
- logger.warning("EOS token not set in tokenizer, using default")
52
- self.tokenizer.eos_token_id = 199999 # Phi-4's default EOS token
53
-
54
- # Load model with appropriate settings
55
- self.model = AutoModelForCausalLM.from_pretrained(
56
- path,
57
- torch_dtype=self.dtype,
58
- device_map="auto" if self.device == "cuda" else None,
59
- trust_remote_code=True
60
- )
61
-
62
- # Move model to device if CPU
63
- if self.device == "cpu":
64
- self.model = self.model.to(self.device)
65
 
66
- # Set model to evaluation mode
67
- self.model.eval()
68
-
69
- # Print diagnostic information
70
- logger.info(f"Model loaded on {self.device} using {self.dtype}")
71
- logger.info(f"Tokenizer vocabulary size: {len(self.tokenizer)}")
72
- logger.info(f"Model vocabulary size: {self.model.config.vocab_size}")
73
- logger.info(f"Model embedding size: {self.model.get_input_embeddings().weight.shape}")
74
-
75
- if len(self.tokenizer) != self.model.config.vocab_size:
76
- logger.warning(f"Tokenizer vocab size ({len(self.tokenizer)}) doesn't match model vocab size ({self.model.config.vocab_size})")
77
 
 
 
78
  except Exception as e:
79
- logger.error(f"Error during model initialization: {str(e)}")
80
- logger.error(traceback.format_exc())
81
- raise
82
-
83
- def format_prompt_with_system(self, user_message, system_message=None):
84
- """
85
- Format the prompt with system and user messages according to Phi-4 format.
 
 
 
 
 
 
 
86
 
87
- Args:
88
- user_message (str): The user's message
89
- system_message (str, optional): The system message/instruction
90
-
91
- Returns:
92
- str: Formatted prompt ready for the model
93
- """
94
- # Format using Phi-4's expected chat template:
95
- # <|system|>
96
- # {system_message}
97
- # <|user|>
98
- # {user_message}
99
- # <|assistant|>
100
 
101
- if system_message:
102
- prompt = f"<|system|>\n{system_message}\n<|user|>\n{user_message}\n<|assistant|>"
103
- else:
104
- # If no system message, just use user message with assistant tag
105
- prompt = f"<|user|>\n{user_message}\n<|assistant|>"
 
 
 
 
 
 
106
 
107
- logger.info(f"Formatted prompt with {'system message and ' if system_message else ''}user message")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  return prompt
109
-
110
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
111
- """
112
- Process the input data and generate a response using the Phi-4 model.
 
113
 
114
- Args:
115
- data (Dict[str, Any]): Input data containing the prompt and generation parameters
116
-
117
- Returns:
118
- Dict[str, Any]: Model response
119
- """
120
- start_time = time.time()
121
- logger.info(f"Starting request processing")
122
 
123
- try:
124
- # Extract input parameters with defaults
125
- if "inputs" not in data:
126
- logger.warning("No 'inputs' field in request data")
127
- error_msg = "Missing 'inputs' field in request"
128
- return self._format_error_response(error_msg)
129
-
130
- # Track user and system messages
131
- user_message = ""
132
- system_message = None
133
-
134
- # Handle different input formats
135
- # 1. Direct string input
136
- if isinstance(data["inputs"], str):
137
- user_message = data["inputs"]
138
- system_message = data.get("parameters", {}).get("system_message", None)
139
-
140
- # 2. Dict with messages format
141
- elif isinstance(data["inputs"], dict) and "messages" in data["inputs"]:
142
- messages = data["inputs"]["messages"]
143
-
144
- # Extract system and user messages for prompt formatting
145
- for msg in messages:
146
- if msg.get("role") == "system":
147
- system_message = msg.get("content", "")
148
- elif msg.get("role") == "user":
149
- user_message = msg.get("content", "")
150
-
151
- # 3. Direct messages list format
152
- elif isinstance(data["inputs"], list):
153
- messages = data["inputs"]
154
-
155
- # Extract system and user messages for prompt formatting
156
- for msg in messages:
157
- if msg.get("role") == "system":
158
- system_message = msg.get("content", "")
159
- elif msg.get("role") == "user":
160
- user_message = msg.get("content", "")
161
- else:
162
- logger.warning(f"Unsupported input format: {type(data['inputs'])}")
163
- error_msg = "Unsupported input format. Expected string or messages object."
164
- return self._format_error_response(error_msg)
165
-
166
- logger.info(f"Extracted user message length: {len(user_message)} characters")
167
- if system_message:
168
- logger.info(f"Extracted system message length: {len(system_message)} characters")
169
-
170
- # Format the prompt with system and user messages
171
- prompt = self.format_prompt_with_system(user_message, system_message)
172
-
173
- parameters = data.get("parameters", {})
174
-
175
- logger.info(f"Processing input with {len(prompt)} characters")
176
-
177
- # Get generation parameters with fallbacks to defaults
178
- max_new_tokens = min(parameters.get("max_new_tokens", self.max_new_tokens), 2048)
179
- temperature = parameters.get("temperature", self.temperature)
180
- top_p = parameters.get("top_p", self.top_p)
181
- do_sample = parameters.get("do_sample", self.do_sample)
182
-
183
- logger.info(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_p={top_p}, do_sample={do_sample}")
184
-
185
- # Manually implement generation to avoid token index errors
186
- try:
187
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
188
- logger.info(f"Input tokens shape: {input_ids.shape}")
189
-
190
- # Create attention mask
191
- attention_mask = torch.ones_like(input_ids)
192
-
193
- # Perform safe generation with error handling for out-of-vocabulary issues
194
- response_text = self._safe_generate(
195
- input_ids,
196
- attention_mask,
197
- max_new_tokens,
198
- temperature,
199
- top_p,
200
- do_sample,
201
- prompt
202
- )
203
-
204
- logger.info(f"Response generation completed, text length: {len(response_text) if isinstance(response_text, str) else 'N/A'}")
205
-
206
- # Format and return response in OpenAI format
207
- if isinstance(response_text, str):
208
- return self._format_openai_response(
209
- response_text,
210
- input_ids.shape[1],
211
- len(self.tokenizer.encode(response_text)) if response_text else 0
212
- )
213
- else:
214
- return self._format_error_response(f"Error during generation: {response_text}")
215
-
216
- except RuntimeError as e:
217
- logger.error(f"Runtime Error during generation: {str(e)}")
218
- logger.error(traceback.format_exc())
219
- return self._format_error_response(f"Error during generation: {str(e)}")
220
-
221
- except Exception as e:
222
- logger.error(f"Unexpected error during request processing: {str(e)}")
223
- logger.error(traceback.format_exc())
224
- return self._format_error_response(f"Unexpected error: {str(e)}")
225
- finally:
226
- duration = time.time() - start_time
227
- logger.info(f"Request processing completed in {duration:.2f} seconds")
228
-
229
- def _safe_generate(self, input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample, prompt):
230
- """Safely generate text handling potential token index errors"""
231
- try:
232
- with torch.no_grad():
233
- logger.info("Starting safe generation")
234
-
235
- # Get the input text to exclude from final output
236
- input_text = prompt
237
- logger.info(f"Input prompt length: {len(input_text)} characters")
238
-
239
- # Generate one token at a time to avoid index errors
240
- # Increase from 250 to 500 to allow for longer completions
241
- max_steps = min(max_new_tokens, 500)
242
- current_ids = input_ids.clone()
243
-
244
- logger.info(f"Generating up to {max_steps} tokens")
245
-
246
- # Keep track of last 5 tokens to detect repetition
247
- last_tokens = []
248
- repetition_detected = False
249
-
250
- for i in range(max_steps):
251
- if i % 50 == 0:
252
- logger.info(f"Generated {i} tokens so far")
253
-
254
- # Get logits for next token
255
- outputs = self.model(
256
- input_ids=current_ids,
257
- attention_mask=attention_mask,
258
- return_dict=True
259
- )
260
-
261
- next_token_logits = outputs.logits[:, -1, :]
262
-
263
- # Apply temperature and sampling
264
- if temperature > 0:
265
- next_token_logits = next_token_logits / temperature
266
-
267
- if do_sample:
268
- # Apply top_p sampling
269
- sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
270
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
271
-
272
- # Remove tokens with cumulative probability above the threshold
273
- sorted_indices_to_remove = cumulative_probs > top_p
274
- # Shift the indices to the right to keep also the first token above the threshold
275
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
276
- sorted_indices_to_remove[..., 0] = 0
277
-
278
- indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
279
- next_token_logits[indices_to_remove] = -float('Inf')
280
-
281
- # Sample from the filtered distribution
282
- probs = torch.softmax(next_token_logits, dim=-1)
283
- next_token = torch.multinomial(probs, num_samples=1)
284
- else:
285
- # Take the token with highest probability
286
- next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
287
-
288
- # Add the predicted token to the sequence
289
- current_ids = torch.cat([current_ids, next_token], dim=-1)
290
- attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
291
-
292
- # Add to last tokens list for repetition detection
293
- last_tokens.append(next_token.item())
294
- if len(last_tokens) > 5:
295
- last_tokens.pop(0)
296
-
297
- # Check for repetition (if we have at least 5 tokens)
298
- if len(last_tokens) >= 5:
299
- # Check if all last 5 tokens are the same
300
- if len(set(last_tokens)) == 1:
301
- logger.warning(f"Repetition detected after {i+1} tokens, stopping generation")
302
- repetition_detected = True
303
- break
304
-
305
- # Check if we've generated an EOS token
306
- if next_token[0, 0].item() == self.tokenizer.eos_token_id:
307
- logger.info(f"EOS token generated after {i+1} tokens")
308
- break
309
-
310
- # Decode the generated sequence
311
- generated_text = self.tokenizer.decode(current_ids[0], skip_special_tokens=True)
312
- logger.info(f"Decoded generated text: {len(generated_text)} characters")
313
-
314
- # Return only the newly generated text (after the assistant tag)
315
- split_text = generated_text.split("<|assistant|>")
316
- if len(split_text) > 1:
317
- response_text = split_text[1].strip()
318
- logger.info(f"Extracted assistant response: {len(response_text)} characters")
319
-
320
- # Check if the response text ends with a complete sentence
321
- if not repetition_detected and not response_text.endswith(('.', '!', '?', ':', ';', '"', "'", ')', ']', '}')):
322
- # Add an ellipsis to indicate truncation
323
- response_text += "..."
324
- logger.info("Added ellipsis to incomplete sentence")
325
-
326
- else:
327
- # Fallback if the expected format is not found
328
- logger.warning("Could not find assistant tag in generated text")
329
- response_text = generated_text
330
-
331
- return response_text
332
-
333
- except Exception as e:
334
- logger.error(f"Error in _safe_generate: {str(e)}")
335
- logger.error(traceback.format_exc())
336
- return f"Generation error: {str(e)}. Please try a simpler input."
337
-
338
- def _format_openai_response(self, response_text, prompt_tokens, completion_tokens):
339
- """Format the response in OpenAI-style format"""
340
- try:
341
- # Create a response ID
342
- response_id = f"phi4-{int(time.time())}"
343
-
344
- # Build OpenAI-compatible response
345
- openai_response = {
346
- "id": response_id,
347
- "object": "chat.completion",
348
- "created": int(time.time()),
349
- "model": "phi-4-mini",
350
- "choices": [
351
- {
352
- "index": 0,
353
- "message": {
354
- "role": "assistant",
355
- "content": response_text
356
- },
357
- "finish_reason": "stop"
358
- }
359
- ],
360
- "usage": {
361
- "prompt_tokens": prompt_tokens,
362
- "completion_tokens": completion_tokens,
363
- "total_tokens": prompt_tokens + completion_tokens
364
- }
365
- }
366
-
367
- # For compatibility with Hugging Face UI, include the generated_text field
368
- openai_response["generated_text"] = response_text
369
-
370
- logger.info(f"Formatted OpenAI-style response: {len(json.dumps(openai_response))} bytes")
371
- return openai_response
372
-
373
- except Exception as e:
374
- logger.error(f"Error formatting OpenAI response: {str(e)}")
375
- # Fall back to simple response
376
- return {"generated_text": response_text}
377
 
378
- def _format_error_response(self, error_message):
379
- """Format an error response in OpenAI-style format"""
380
- try:
381
- error_response = {
382
- "id": f"phi4-error-{int(time.time())}",
383
- "object": "chat.completion",
384
- "created": int(time.time()),
385
- "model": "phi-4-mini",
386
- "choices": [
387
- {
388
- "index": 0,
389
- "message": {
390
- "role": "assistant",
391
- "content": f"Error: {error_message}"
392
- },
393
- "finish_reason": "error"
394
- }
395
- ],
396
- "usage": {
397
- "prompt_tokens": 0,
398
- "completion_tokens": 0,
399
- "total_tokens": 0
400
  },
401
- "error": {
402
- "message": error_message,
403
- "type": "invalid_request_error",
404
- "code": "error"
405
- }
 
 
 
 
 
 
 
 
406
  }
407
-
408
- # For compatibility with Hugging Face UI, include the generated_text field
409
- error_response["generated_text"] = f"Error: {error_message}"
410
-
411
- logger.info(f"Formatted error response: {len(json.dumps(error_response))} bytes")
412
- return error_response
413
-
414
- except Exception as e:
415
- logger.error(f"Error formatting error response: {str(e)}")
416
- # Fall back to simple error response
417
- return {"generated_text": f"Error: {error_message}"}
418
-
419
- # For local testing
420
- if __name__ == "__main__":
421
- # Example usage
422
- handler = EndpointHandler()
423
-
424
- # Test with messages format
425
- test_with_messages = {
426
- "inputs": {
427
- "messages": [
428
- {"role": "system", "content": "You are an AI assistant that provides helpful, accurate, and concise information about AI models."},
429
- {"role": "user", "content": "What are the major features of Phi-4?"}
430
- ]
431
- }
432
- }
433
-
434
- # Run the test
435
- result = handler(test_with_messages)
436
- print(json.dumps(result, indent=2))
 
1
import json
import os
import time
import uuid
from typing import Dict, List, Any, Optional, Union

import torch
from transformers import pipeline, AutoTokenizer
 
 
 
 
 
 
 
6
 
7
class EndpointHandler:
    """OpenAI-compatible chat-completion handler around a HF text-generation pipeline."""

    def __init__(self, path=""):
        """Load the tokenizer and build the text-generation pipeline.

        Args:
            path: Path to the model directory; falls back to the MODEL_PATH
                environment variable when empty.
        """
        self.model_path = path if path else os.environ.get("MODEL_PATH", "")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

        # Create text generation pipeline
        self.pipe = pipeline(
            "text-generation",
            model=self.model_path,
            tokenizer=self.tokenizer,
            torch_dtype=torch.float16,
            device_map="auto",
            return_full_text=False  # Only return the generated text, not the prompt
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an inference request in OpenAI-like format.

        Returns:
            An OpenAI-style ``chat.completion`` dict, or an ``error`` dict
            when the request is malformed or generation fails.
        """
        try:
            # Parse input data
            inputs = self._parse_input(data)

            # Generate response
            outputs = self._generate(inputs)

            # Format response in OpenAI-like format
            return self._format_response(outputs, inputs)
        except Exception as e:
            # Report failures in an OpenAI-style error envelope instead of raising,
            # so the endpoint always returns JSON.
            return {
                "error": {
                    "message": str(e),
                    "type": "invalid_request_error",
                    "code": 400
                }
            }

    def _parse_input(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the prompt, messages, and generation parameters from a request.

        Raises:
            ValueError: if no messages are provided.
        """
        messages = data.get("messages", [])
        if not messages:
            raise ValueError("No messages provided")

        # Convert messages to prompt
        prompt = self._convert_messages_to_prompt(messages)

        # OpenAI allows "stop" to be a single string or a list of strings;
        # normalize to a list so the truncation loop never iterates characters.
        stop = data.get("stop", None)
        if isinstance(stop, str):
            stop = [stop]

        # Extract generation parameters with OpenAI-style defaults
        generation_params = {
            "max_tokens": data.get("max_tokens", 256),
            "temperature": data.get("temperature", 0.7),
            "top_p": data.get("top_p", 1.0),
            "n": data.get("n", 1),
            "stream": data.get("stream", False),
            "stop": stop,
            "presence_penalty": data.get("presence_penalty", 0.0),
            "frequency_penalty": data.get("frequency_penalty", 0.0),
        }

        return {
            "prompt": prompt,
            "messages": messages,
            "generation_params": generation_params
        }

    def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Render a chat message list as a plain "Role: content" prompt string."""
        prompt = ""
        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")

            if role == "system":
                prompt += f"System: {content}\n\n"
            elif role == "user":
                prompt += f"User: {content}\n\n"
            elif role == "assistant":
                prompt += f"Assistant: {content}\n\n"

        # Cue the model to answer as the assistant.
        prompt += "Assistant: "
        return prompt

    def _generate(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the pipeline and return generated texts plus token counts."""
        prompt = inputs["prompt"]
        params = inputs["generation_params"]

        # Count input tokens
        input_tokens = len(self.tokenizer.encode(prompt))

        # Convert OpenAI-like parameters to pipeline parameters.
        # NOTE: stop sequences are handled by post-hoc truncation below;
        # passing raw strings as `stopping_criteria` is not valid for the
        # transformers pipeline (it expects StoppingCriteria objects).
        generation_kwargs = {
            "max_new_tokens": params["max_tokens"],
            "temperature": params["temperature"],
            "top_p": params["top_p"],
            "num_return_sequences": params["n"],
            # Greedy decoding when temperature is 0, sampling otherwise.
            "do_sample": params["temperature"] > 0,
        }

        # Generate output using the pipeline
        pipeline_outputs = self.pipe(
            prompt,
            **generation_kwargs
        )

        # Extract generated texts, truncating at any provided stop sequence
        generated_texts = []
        for output in pipeline_outputs:
            gen_text = output["generated_text"]

            if params["stop"]:
                for stop in params["stop"]:
                    if stop in gen_text:
                        gen_text = gen_text[:gen_text.find(stop)]

            generated_texts.append(gen_text)

        # Count completion tokens
        completion_tokens = [len(self.tokenizer.encode(text)) for text in generated_texts]

        return {
            "generated_texts": generated_texts,
            "prompt_tokens": input_tokens,
            "completion_tokens": completion_tokens,
        }

    def _format_response(self, outputs: Dict[str, Any], inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Build an OpenAI-style chat.completion response dict."""
        generated_texts = outputs["generated_texts"]
        prompt_tokens = outputs["prompt_tokens"]
        completion_tokens = outputs["completion_tokens"]

        choices = []
        for i, text in enumerate(generated_texts):
            choices.append({
                "index": i,
                "message": {
                    "role": "assistant",
                    "content": text
                },
                "finish_reason": "stop"
            })

        return {
            # uuid gives a collision-resistant id; hash() is salted per process
            # and is not suitable for identifiers.
            "id": f"cmpl-{uuid.uuid4().hex[:12]}",
            "object": "chat.completion",
            # "created" is a Unix timestamp per the OpenAI API (the previous
            # code used the CUDA device index here, which is not a time).
            "created": int(time.time()),
            "model": os.path.basename(self.model_path),
            "choices": choices,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": sum(completion_tokens),
                "total_tokens": prompt_tokens + sum(completion_tokens)
            }
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.30.0
3
+ accelerate>=0.21.0
4
+ sentencepiece>=0.1.99
test_handler.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import sys
4
+ from handler import EndpointHandler
5
+
6
def test_inference(model_path=".", prompt=None, max_tokens=150, temperature=0.7):
    """
    Exercise the inference endpoint handler once with a sample request.

    Args:
        model_path: Path to the model directory
        prompt: Custom prompt to use (optional)
        max_tokens: Maximum number of tokens to generate
        temperature: Temperature for generation
    """
    try:
        print(f"Initializing handler with model path: {model_path}")
        handler = EndpointHandler(model_path)

        # Fall back to a stock question when no custom prompt was supplied.
        user_text = "Explain quantum computing in simple terms." if prompt is None else prompt
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_text}
        ]

        # Sample request with OpenAI-like format
        request = {
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": 0.95
        }

        print("Sending request to handler...")
        print(f"Request: {json.dumps(request, indent=2)}")

        # Generate response and show it in a readable format
        response = handler(request)
        print("\nResponse:")
        print(json.dumps(response, indent=2))

        return response

    except Exception as e:
        print(f"Error during inference: {str(e)}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return {"error": str(e)}
57
+
58
if __name__ == "__main__":
    # Command-line entry point for locally smoke-testing the handler.
    cli = argparse.ArgumentParser(description="Test Phi-4 Mini inference")
    cli.add_argument("--model_path", type=str, default=".", help="Path to the model directory")
    cli.add_argument("--prompt", type=str, help="Custom prompt to use")
    cli.add_argument("--max_tokens", type=int, default=150, help="Maximum number of tokens to generate")
    cli.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation")

    options = cli.parse_args()
    # Argument names match test_inference's keyword parameters exactly.
    test_inference(**vars(options))