Yong Liu committed on
Commit
02d7d65
·
1 Parent(s): 4aa4d08

update handler

Browse files
Files changed (2) hide show
  1. README.md +0 -81
  2. handler.py +436 -229
README.md DELETED
@@ -1,81 +0,0 @@
1
- # Phi-4 Mini Inference Endpoint Handler
2
-
3
- This repository contains code for deploying the Phi-4 Mini model to a HuggingFace Inference Endpoint with an OpenAI-compatible API format.
4
-
5
- ## Setup
6
-
7
- 1. Install the required dependencies:
8
- ```
9
- pip install -r requirements.txt
10
- ```
11
-
12
- 2. Set the environment variable to your model path (optional if model is in the same directory):
13
- ```
14
- export MODEL_PATH=/path/to/your/model
15
- ```
16
-
17
- ## Usage
18
-
19
- When deploying to a HuggingFace Inference Endpoint, the `handler.py` file will be used to process requests. The endpoint accepts requests in an OpenAI-compatible format:
20
-
21
- ```json
22
- {
23
- "messages": [
24
- {"role": "system", "content": "You are a helpful assistant."},
25
- {"role": "user", "content": "Tell me about language models."}
26
- ],
27
- "max_tokens": 256,
28
- "temperature": 0.7,
29
- "top_p": 1.0,
30
- "n": 1,
31
- "stop": ["\n", "User:"]
32
- }
33
- ```
34
-
35
- The endpoint returns responses in an OpenAI-compatible format:
36
-
37
- ```json
38
- {
39
- "id": "cmpl-12345",
40
- "object": "chat.completion",
41
- "created": 0,
42
- "model": "phi4-mini-raw",
43
- "choices": [
44
- {
45
- "index": 0,
46
- "message": {
47
- "role": "assistant",
48
- "content": "Language models are computational systems designed to understand and generate human language..."
49
- },
50
- "finish_reason": "stop"
51
- }
52
- ],
53
- "usage": {
54
- "prompt_tokens": 42,
55
- "completion_tokens": 156,
56
- "total_tokens": 198
57
- }
58
- }
59
- ```
60
-
61
- ## Local Testing
62
-
63
- To test the handler locally before deployment:
64
-
65
- ```python
66
- from handler import EndpointHandler
67
-
68
- # Initialize the handler with your model path
69
- handler = EndpointHandler("./phi4-mini-raw")
70
-
71
- # Test with a sample request
72
- request = {
73
- "messages": [
74
- {"role": "system", "content": "You are a helpful assistant."},
75
- {"role": "user", "content": "Hello, how are you?"}
76
- ]
77
- }
78
-
79
- response = handler(request)
80
- print(response)
81
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
handler.py CHANGED
@@ -1,265 +1,472 @@
1
  import os
2
- import json
3
  import torch
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
- from typing import Dict, List, Any
6
-
7
- # Fix for the rope_scaling validation issue
8
- import transformers.models.phi3.configuration_phi3
9
- # Store original method
10
- original_validation = transformers.models.phi3.configuration_phi3.Phi3Config._rope_scaling_validation
11
-
12
- # Replace with a no-op function
13
- def no_validation(self):
14
- pass
15
 
16
- # Apply the patch
17
- transformers.models.phi3.configuration_phi3.Phi3Config._rope_scaling_validation = no_validation
 
 
 
 
18
 
19
  class EndpointHandler:
20
  def __init__(self, path=""):
21
- # Initialize model and tokenizer
22
- self.model_path = path if path else os.environ.get("MODEL_PATH", "")
23
- print(f"Loading model from: {self.model_path}")
24
-
25
- # Load tokenizer
26
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
27
 
28
- # Determine the device to use
29
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- print(f"Using device: {self.device}")
 
 
 
 
 
31
 
32
- # Load model directly without pipeline
33
- self.model = AutoModelForCausalLM.from_pretrained(
34
- self.model_path,
35
- torch_dtype=torch.float16,
36
- device_map="auto"
37
- )
38
- # Ensure model is on the correct device
39
- if torch.cuda.is_available():
40
- self.model = self.model.cuda()
41
 
42
- print("Model loaded successfully")
43
 
44
- # For Phi3 models, monkey patch the RotaryEmbedding
45
  try:
46
- from transformers.models.phi3.modeling_phi3 import PhiRotaryEmbedding
47
- original_forward = PhiRotaryEmbedding.forward
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- def patched_forward(self, position_ids, query, key, value=None):
50
- # Ensure position_ids is on the same device as query
51
- position_ids = position_ids.to(query.device)
52
- return original_forward(self, position_ids, query, key, value)
53
-
54
- PhiRotaryEmbedding.forward = patched_forward
55
- print("Successfully patched PhiRotaryEmbedding.forward")
56
  except Exception as e:
57
- print(f"Could not patch PhiRotaryEmbedding: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
 
 
 
 
 
 
 
 
 
59
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
60
- """Handle inference request in OpenAI-like format or HuggingFace Inference API format"""
 
 
 
 
 
 
 
 
 
 
 
61
  try:
62
- # Debugging: Print the received data
63
- print(f"Received data: {json.dumps(data, indent=2)}")
 
 
 
64
 
65
- # Handle HuggingFace Inference API format
66
- if "inputs" in data:
67
- # Extract data from inputs key
68
- if isinstance(data["inputs"], dict):
69
- # If inputs contains a dictionary, extract it
70
- input_data = data["inputs"]
71
- elif isinstance(data["inputs"], str):
72
- # If inputs is a string, create a simple message
73
- input_data = {
74
- "messages": [
75
- {"role": "user", "content": data["inputs"]}
76
- ]
77
- }
78
- else:
79
- print(f"Unexpected inputs format: {type(data['inputs'])}")
80
- # Try to convert to string if possible
81
- try:
82
- input_data = {
83
- "messages": [
84
- {"role": "user", "content": str(data["inputs"])}
85
- ]
86
- }
87
- except:
88
- raise ValueError(f"Unsupported inputs format: {type(data['inputs'])}")
 
 
 
 
 
 
 
89
  else:
90
- # Assume direct OpenAI format
91
- input_data = data
 
92
 
93
- # Debugging: Print the parsed input data
94
- print(f"Parsed input data: {json.dumps(input_data, indent=2)}")
 
95
 
96
- # Parse input data
97
- inputs = self._parse_input(input_data)
98
 
99
- # Generate response
100
- outputs = self._generate(inputs)
101
 
102
- # Format response in OpenAI-like format
103
- return self._format_response(outputs, inputs)
104
- except Exception as e:
105
- print(f"Error during processing: {str(e)}")
106
- import traceback
107
- traceback.print_exc()
108
- return {
109
- "error": {
110
- "message": str(e),
111
- "type": "invalid_request_error",
112
- "code": 400
113
- }
114
- }
115
-
116
- def _parse_input(self, data: Dict[str, Any]) -> Dict[str, Any]:
117
- """Parse input data to extract generation parameters"""
118
- # Extract messages
119
- messages = data.get("messages", [])
120
- if not messages:
121
- print(f"No messages found in data: {json.dumps(data, indent=2)}")
122
- raise ValueError("No messages provided")
123
-
124
- # Convert messages to prompt
125
- prompt = self._convert_messages_to_prompt(messages)
126
-
127
- # Extract generation parameters with defaults
128
- generation_params = {
129
- "max_tokens": data.get("max_tokens", 256),
130
- "temperature": data.get("temperature", 0.7),
131
- "top_p": data.get("top_p", 1.0),
132
- "n": data.get("n", 1),
133
- "stream": data.get("stream", False),
134
- "stop": data.get("stop", None),
135
- "presence_penalty": data.get("presence_penalty", 0.0),
136
- "frequency_penalty": data.get("frequency_penalty", 0.0),
137
- }
138
-
139
- return {
140
- "prompt": prompt,
141
- "messages": messages,
142
- "generation_params": generation_params
143
- }
144
-
145
- def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
146
- """Convert list of messages to a prompt string"""
147
- prompt = ""
148
- for message in messages:
149
- role = message.get("role", "")
150
- content = message.get("content", "")
151
 
152
- if role == "system":
153
- prompt += f"System: {content}\n\n"
154
- elif role == "user":
155
- prompt += f"User: {content}\n\n"
156
- elif role == "assistant":
157
- prompt += f"Assistant: {content}\n\n"
158
 
159
- # Add final assistant prompt
160
- prompt += "Assistant: "
161
- return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- def _generate(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
164
- """Generate response using the model directly"""
165
- prompt = inputs["prompt"]
166
- params = inputs["generation_params"]
167
-
168
- # Get the model's device
169
- device = next(self.model.parameters()).device
170
- print(f"Model is on device: {device}")
171
-
172
- # Tokenize input and ensure it's on the correct device
173
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)
174
- print(f"Input tensor device: {input_ids.device}")
175
 
176
- # Count input tokens
177
- input_tokens = input_ids.shape[1]
 
 
 
178
 
179
- # Convert OpenAI-like parameters to HF parameters
180
- generation_kwargs = {
181
- "max_new_tokens": params["max_tokens"],
182
- "temperature": params["temperature"],
183
- "top_p": params["top_p"],
184
- "num_return_sequences": params["n"],
185
- "do_sample": params["temperature"] > 0,
186
- "pad_token_id": self.tokenizer.eos_token_id,
187
- }
188
 
189
- # Generate output
 
 
 
190
  try:
191
  with torch.no_grad():
192
- outputs = self.model.generate(
193
- input_ids,
194
- **generation_kwargs
195
- )
196
- print(f"Output tensor device: {outputs.device}")
197
- except RuntimeError as e:
198
- if "Expected all tensors to be on the same device" in str(e):
199
- print("Caught device mismatch error, trying to fix...")
200
- # A more drastic approach: move the model completely to CPU if there's a device issue
201
- if torch.cuda.is_available():
202
- print("Moving everything to CPU as a fallback")
203
- self.model = self.model.cpu()
204
- input_ids = input_ids.cpu()
205
- with torch.no_grad():
206
- outputs = self.model.generate(
207
- input_ids,
208
- **generation_kwargs
209
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  else:
211
- raise
212
- else:
213
- raise
214
-
215
- # Decode output
216
- generated_texts = []
217
- for i in range(params["n"]):
218
- gen_text = self.tokenizer.decode(outputs[i][input_tokens:], skip_special_tokens=True)
219
 
220
- # Apply stop sequences if provided
221
- if params["stop"]:
222
- for stop in params["stop"]:
223
- if stop in gen_text:
224
- gen_text = gen_text[:gen_text.find(stop)]
 
 
 
 
 
225
 
226
- generated_texts.append(gen_text)
227
-
228
- # Count completion tokens
229
- completion_tokens = [len(self.tokenizer.encode(text)) for text in generated_texts]
230
-
231
- return {
232
- "generated_texts": generated_texts,
233
- "prompt_tokens": input_tokens,
234
- "completion_tokens": completion_tokens,
235
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
- def _format_response(self, outputs: Dict[str, Any], inputs: Dict[str, Any]) -> Dict[str, Any]:
238
- """Format response in OpenAI-like format"""
239
- generated_texts = outputs["generated_texts"]
240
- prompt_tokens = outputs["prompt_tokens"]
241
- completion_tokens = outputs["completion_tokens"]
242
-
243
- choices = []
244
- for i, text in enumerate(generated_texts):
245
- choices.append({
246
- "index": i,
247
- "message": {
248
- "role": "assistant",
249
- "content": text
 
 
 
 
 
 
 
 
 
250
  },
251
- "finish_reason": "stop"
252
- })
253
-
254
- return {
255
- "id": f"cmpl-{hash(inputs['prompt']) % 10000}",
256
- "object": "chat.completion",
257
- "created": int(torch.cuda.current_device()) if torch.cuda.is_available() else 0,
258
- "model": os.path.basename(self.model_path),
259
- "choices": choices,
260
- "usage": {
261
- "prompt_tokens": prompt_tokens,
262
- "completion_tokens": sum(completion_tokens),
263
- "total_tokens": prompt_tokens + sum(completion_tokens)
264
  }
265
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
2
  import torch
3
+ import logging
4
+ import time
5
+ import traceback
6
+ import json
7
+ import re
8
+ from typing import Dict, List, Any, Union, Generator
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
+ from threading import Thread
 
 
 
11
 
12
+ # Set up logging
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
+ )
17
+ logger = logging.getLogger(__name__)
18
 
19
  class EndpointHandler:
20
  def __init__(self, path=""):
21
+ """
22
+ Initialize the model and tokenizer for Phi-4 inference.
 
 
 
 
23
 
24
+ Args:
25
+ path (str): Path to the model directory
26
+ """
27
+ # Set default parameters for inference
28
+ self.max_new_tokens = 1024 # Keep at 1024 to avoid timeouts
29
+ self.temperature = 0.7
30
+ self.top_p = 0.9
31
+ self.do_sample = True
32
 
33
+ # Determine if CUDA is available
34
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
35
+ self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
 
 
 
 
 
36
 
37
+ logger.info(f"Initializing model from {path} on {self.device}")
38
 
 
39
  try:
40
+ # Load tokenizer - use original model ID as fallback
41
+ # This helps with common tokenizer mismatch issues
42
+ try:
43
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
44
+ logger.info(f"Loaded tokenizer from local path")
45
+ except Exception as e:
46
+ logger.warning(f"Failed to load tokenizer from local path: {e}")
47
+ self.tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
48
+ logger.info("Loaded tokenizer from microsoft/Phi-4-mini-instruct")
49
+
50
+ # Ensure tokenizer has EOS token set
51
+ if self.tokenizer.eos_token_id is None:
52
+ logger.warning("EOS token not set in tokenizer, using default")
53
+ self.tokenizer.eos_token_id = 199999 # Phi-4's default EOS token
54
+
55
+ # Load model with appropriate settings
56
+ self.model = AutoModelForCausalLM.from_pretrained(
57
+ path,
58
+ torch_dtype=self.dtype,
59
+ device_map="auto" if self.device == "cuda" else None,
60
+ trust_remote_code=True
61
+ )
62
+
63
+ # Move model to device if CPU
64
+ if self.device == "cpu":
65
+ self.model = self.model.to(self.device)
66
+
67
+ # Set model to evaluation mode
68
+ self.model.eval()
69
+
70
+ # Print diagnostic information
71
+ logger.info(f"Model loaded on {self.device} using {self.dtype}")
72
+ logger.info(f"Tokenizer vocabulary size: {len(self.tokenizer)}")
73
+ logger.info(f"Model vocabulary size: {self.model.config.vocab_size}")
74
+ logger.info(f"Model embedding size: {self.model.get_input_embeddings().weight.shape}")
75
+
76
+ if len(self.tokenizer) != self.model.config.vocab_size:
77
+ logger.warning(f"Tokenizer vocab size ({len(self.tokenizer)}) doesn't match model vocab size ({self.model.config.vocab_size})")
78
 
 
 
 
 
 
 
 
79
  except Exception as e:
80
+ logger.error(f"Error during model initialization: {str(e)}")
81
+ logger.error(traceback.format_exc())
82
+ raise
83
+
84
+ def format_prompt_with_system(self, user_message, system_message=None):
85
+ """
86
+ Format the prompt with system and user messages according to Phi-4 format.
87
+
88
+ Args:
89
+ user_message (str): The user's message
90
+ system_message (str, optional): The system message/instruction
91
+
92
+ Returns:
93
+ str: Formatted prompt ready for the model
94
+ """
95
+ # Format using Phi-4's expected chat template:
96
+ # <|system|>
97
+ # {system_message}
98
+ # <|user|>
99
+ # {user_message}
100
+ # <|assistant|>
101
 
102
+ if system_message:
103
+ prompt = f"<|system|>\n{system_message}\n<|user|>\n{user_message}\n<|assistant|>"
104
+ else:
105
+ # If no system message, just use user message with assistant tag
106
+ prompt = f"<|user|>\n{user_message}\n<|assistant|>"
107
+
108
+ logger.info(f"Formatted prompt with {'system message and ' if system_message else ''}user message")
109
+ return prompt
110
+
111
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
112
+ """
113
+ Process the input data and generate a response using the Phi-4 model.
114
+
115
+ Args:
116
+ data (Dict[str, Any]): Input data containing the prompt and generation parameters
117
+
118
+ Returns:
119
+ Dict[str, Any]: Model response
120
+ """
121
+ start_time = time.time()
122
+ logger.info(f"Starting request processing")
123
+
124
  try:
125
+ # Extract input parameters with defaults
126
+ if "inputs" not in data:
127
+ logger.warning("No 'inputs' field in request data")
128
+ error_msg = "Missing 'inputs' field in request"
129
+ return self._format_error_response(error_msg)
130
 
131
+ # Track user and system messages
132
+ user_message = ""
133
+ system_message = None
134
+
135
+ # Handle different input formats
136
+ # 1. Direct string input
137
+ if isinstance(data["inputs"], str):
138
+ user_message = data["inputs"]
139
+ system_message = data.get("parameters", {}).get("system_message", None)
140
+
141
+ # 2. Dict with messages format
142
+ elif isinstance(data["inputs"], dict) and "messages" in data["inputs"]:
143
+ messages = data["inputs"]["messages"]
144
+
145
+ # Extract system and user messages for prompt formatting
146
+ for msg in messages:
147
+ if msg.get("role") == "system":
148
+ system_message = msg.get("content", "")
149
+ elif msg.get("role") == "user":
150
+ user_message = msg.get("content", "")
151
+
152
+ # 3. Direct messages list format
153
+ elif isinstance(data["inputs"], list):
154
+ messages = data["inputs"]
155
+
156
+ # Extract system and user messages for prompt formatting
157
+ for msg in messages:
158
+ if msg.get("role") == "system":
159
+ system_message = msg.get("content", "")
160
+ elif msg.get("role") == "user":
161
+ user_message = msg.get("content", "")
162
  else:
163
+ logger.warning(f"Unsupported input format: {type(data['inputs'])}")
164
+ error_msg = "Unsupported input format. Expected string or messages object."
165
+ return self._format_error_response(error_msg)
166
 
167
+ logger.info(f"Extracted user message length: {len(user_message)} characters")
168
+ if system_message:
169
+ logger.info(f"Extracted system message length: {len(system_message)} characters")
170
 
171
+ # Format the prompt with system and user messages
172
+ prompt = self.format_prompt_with_system(user_message, system_message)
173
 
174
+ parameters = data.get("parameters", {})
 
175
 
176
+ logger.info(f"Processing input with {len(prompt)} characters")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ # Get generation parameters with fallbacks to defaults
179
+ max_new_tokens = min(parameters.get("max_new_tokens", self.max_new_tokens), 1024)
180
+ temperature = parameters.get("temperature", self.temperature)
181
+ top_p = parameters.get("top_p", self.top_p)
182
+ do_sample = parameters.get("do_sample", self.do_sample)
 
183
 
184
+ logger.info(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_p={top_p}, do_sample={do_sample}")
185
+
186
+ # Manually implement generation to avoid token index errors
187
+ try:
188
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
189
+ logger.info(f"Input tokens shape: {input_ids.shape}")
190
+
191
+ # Create attention mask
192
+ attention_mask = torch.ones_like(input_ids)
193
+
194
+ # Perform safe generation with error handling for out-of-vocabulary issues
195
+ response_text = self._safe_generate(
196
+ input_ids,
197
+ attention_mask,
198
+ max_new_tokens,
199
+ temperature,
200
+ top_p,
201
+ do_sample,
202
+ prompt
203
+ )
204
+
205
+ logger.info(f"Response generation completed, text length: {len(response_text) if isinstance(response_text, str) else 'N/A'}")
206
+
207
+ # Format and return response in OpenAI format
208
+ if isinstance(response_text, str):
209
+ response_tokens = len(self.tokenizer.encode(response_text)) if response_text else 0
210
+ logger.info(f"Response token count: {response_tokens}")
211
+
212
+ return self._format_openai_response(
213
+ response_text,
214
+ input_ids.shape[1],
215
+ response_tokens
216
+ )
217
+ else:
218
+ return self._format_error_response(f"Error during generation: {response_text}")
219
+
220
+ except RuntimeError as e:
221
+ logger.error(f"Runtime Error during generation: {str(e)}")
222
+ logger.error(traceback.format_exc())
223
+ return self._format_error_response(f"Error during generation: {str(e)}")
224
+
225
+ except Exception as e:
226
+ logger.error(f"Unexpected error during request processing: {str(e)}")
227
+ logger.error(traceback.format_exc())
228
+ return self._format_error_response(f"Unexpected error: {str(e)}")
229
+ finally:
230
+ duration = time.time() - start_time
231
+ logger.info(f"Request processing completed in {duration:.2f} seconds")
232
 
233
+ def _complete_sentence(self, text):
234
+ """Ensure the text ends with a complete sentence"""
235
+ # If text is already a complete sentence, return it
236
+ if text.strip().endswith(('.', '!', '?')):
237
+ return text
 
 
 
 
 
 
 
238
 
239
+ # Find the last complete sentence end
240
+ sentences = re.split(r'([.!?])\s+', text)
241
+ if len(sentences) <= 1:
242
+ # No complete sentences found, return as is with ellipsis
243
+ return text + "..."
244
 
245
+ # Reconstruct text up to the last complete sentence
246
+ result = ""
247
+ for i in range(len(sentences) - 1):
248
+ if i % 2 == 0: # Content before punctuation
249
+ result += sentences[i]
250
+ else: # Punctuation
251
+ result += sentences[i] + " "
 
 
252
 
253
+ return result.strip()
254
+
255
+ def _safe_generate(self, input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample, prompt):
256
+ """Safely generate text handling potential token index errors"""
257
  try:
258
  with torch.no_grad():
259
+ logger.info("Starting safe generation")
260
+
261
+ # Get the input text to exclude from final output
262
+ input_text = prompt
263
+ logger.info(f"Input prompt length: {len(input_text)} characters")
264
+
265
+ # Generate one token at a time to avoid index errors
266
+ # Use a lower absolute maximum to ensure completion
267
+ max_steps = min(max_new_tokens, 450) # Adjusted down from 500
268
+ current_ids = input_ids.clone()
269
+
270
+ logger.info(f"Generating up to {max_steps} tokens")
271
+
272
+ # Keep track of last 5 tokens to detect repetition
273
+ last_tokens = []
274
+ repetition_detected = False
275
+
276
+ for i in range(max_steps):
277
+ if i % 50 == 0:
278
+ logger.info(f"Generated {i} tokens so far")
279
+
280
+ # Early termination if we're getting close to the limit to allow for post-processing
281
+ if i >= max_steps - 50:
282
+ # Temporarily decode to check if we have a complete response already
283
+ temp_text = self.tokenizer.decode(current_ids[0], skip_special_tokens=True)
284
+
285
+ if "<|assistant|>" in temp_text:
286
+ temp_response = temp_text.split("<|assistant|>")[1].strip()
287
+
288
+ # If we have a reasonably complete response, stop early
289
+ if len(temp_response) > 100 and temp_response.count('.') >= 3:
290
+ logger.info(f"Early termination at {i} tokens with complete response detected")
291
+ break
292
+
293
+ # Get logits for next token
294
+ outputs = self.model(
295
+ input_ids=current_ids,
296
+ attention_mask=attention_mask,
297
+ return_dict=True
298
+ )
299
+
300
+ next_token_logits = outputs.logits[:, -1, :]
301
+
302
+ # Apply temperature and sampling
303
+ if temperature > 0:
304
+ next_token_logits = next_token_logits / temperature
305
+
306
+ if do_sample:
307
+ # Apply top_p sampling
308
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
309
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
310
+
311
+ # Remove tokens with cumulative probability above the threshold
312
+ sorted_indices_to_remove = cumulative_probs > top_p
313
+ # Shift the indices to the right to keep also the first token above the threshold
314
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
315
+ sorted_indices_to_remove[..., 0] = 0
316
+
317
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
318
+ next_token_logits[indices_to_remove] = -float('Inf')
319
+
320
+ # Sample from the filtered distribution
321
+ probs = torch.softmax(next_token_logits, dim=-1)
322
+ next_token = torch.multinomial(probs, num_samples=1)
323
+ else:
324
+ # Take the token with highest probability
325
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
326
+
327
+ # Add the predicted token to the sequence
328
+ current_ids = torch.cat([current_ids, next_token], dim=-1)
329
+ attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
330
+
331
+ # Add to last tokens list for repetition detection
332
+ last_tokens.append(next_token.item())
333
+ if len(last_tokens) > 5:
334
+ last_tokens.pop(0)
335
+
336
+ # Check for repetition (if we have at least 5 tokens)
337
+ if len(last_tokens) >= 5:
338
+ # Check if all last 5 tokens are the same
339
+ if len(set(last_tokens)) == 1:
340
+ logger.warning(f"Repetition detected after {i+1} tokens, stopping generation")
341
+ repetition_detected = True
342
+ break
343
+
344
+ # Check if we've generated an EOS token
345
+ if next_token[0, 0].item() == self.tokenizer.eos_token_id:
346
+ logger.info(f"EOS token generated after {i+1} tokens")
347
+ break
348
+
349
+ # Decode the generated sequence
350
+ generated_text = self.tokenizer.decode(current_ids[0], skip_special_tokens=True)
351
+ logger.info(f"Decoded generated text: {len(generated_text)} characters")
352
+
353
+ # Return only the newly generated text (after the assistant tag)
354
+ split_text = generated_text.split("<|assistant|>")
355
+ if len(split_text) > 1:
356
+ assistant_response = split_text[1].strip()
357
+ logger.info(f"Raw assistant response: {len(assistant_response)} characters")
358
+
359
+ # Process the response to ensure complete sentences
360
+ response_text = self._complete_sentence(assistant_response)
361
+ logger.info(f"Processed assistant response: {len(response_text)} characters")
362
  else:
363
+ # Fallback if the expected format is not found
364
+ logger.warning("Could not find assistant tag in generated text")
365
+ response_text = generated_text
366
+
367
+ return response_text
 
 
 
368
 
369
+ except Exception as e:
370
+ logger.error(f"Error in _safe_generate: {str(e)}")
371
+ logger.error(traceback.format_exc())
372
+ return f"Generation error: {str(e)}. Please try a simpler input."
373
+
374
+ def _format_openai_response(self, response_text, prompt_tokens, completion_tokens):
375
+ """Format the response in OpenAI-style format"""
376
+ try:
377
+ # Create a response ID
378
+ response_id = f"phi4-{int(time.time())}"
379
 
380
+ # Build OpenAI-compatible response
381
+ openai_response = {
382
+ "id": response_id,
383
+ "object": "chat.completion",
384
+ "created": int(time.time()),
385
+ "model": "phi-4-mini",
386
+ "choices": [
387
+ {
388
+ "index": 0,
389
+ "message": {
390
+ "role": "assistant",
391
+ "content": response_text
392
+ },
393
+ "finish_reason": "stop"
394
+ }
395
+ ],
396
+ "usage": {
397
+ "prompt_tokens": prompt_tokens,
398
+ "completion_tokens": completion_tokens,
399
+ "total_tokens": prompt_tokens + completion_tokens
400
+ }
401
+ }
402
+
403
+ # For compatibility with Hugging Face UI, include the generated_text field
404
+ openai_response["generated_text"] = response_text
405
+
406
+ logger.info(f"Formatted OpenAI-style response: {len(json.dumps(openai_response))} bytes")
407
+ return openai_response
408
+
409
+ except Exception as e:
410
+ logger.error(f"Error formatting OpenAI response: {str(e)}")
411
+ # Fall back to simple response
412
+ return {"generated_text": response_text}
413
 
414
+ def _format_error_response(self, error_message):
415
+ """Format an error response in OpenAI-style format"""
416
+ try:
417
+ error_response = {
418
+ "id": f"phi4-error-{int(time.time())}",
419
+ "object": "chat.completion",
420
+ "created": int(time.time()),
421
+ "model": "phi-4-mini",
422
+ "choices": [
423
+ {
424
+ "index": 0,
425
+ "message": {
426
+ "role": "assistant",
427
+ "content": f"Error: {error_message}"
428
+ },
429
+ "finish_reason": "error"
430
+ }
431
+ ],
432
+ "usage": {
433
+ "prompt_tokens": 0,
434
+ "completion_tokens": 0,
435
+ "total_tokens": 0
436
  },
437
+ "error": {
438
+ "message": error_message,
439
+ "type": "invalid_request_error",
440
+ "code": "error"
441
+ }
 
 
 
 
 
 
 
 
442
  }
443
+
444
+ # For compatibility with Hugging Face UI, include the generated_text field
445
+ error_response["generated_text"] = f"Error: {error_message}"
446
+
447
+ logger.info(f"Formatted error response: {len(json.dumps(error_response))} bytes")
448
+ return error_response
449
+
450
+ except Exception as e:
451
+ logger.error(f"Error formatting error response: {str(e)}")
452
+ # Fall back to simple error response
453
+ return {"generated_text": f"Error: {error_message}"}
454
+
455
+ # For local testing
456
+ if __name__ == "__main__":
457
+ # Example usage
458
+ handler = EndpointHandler()
459
+
460
+ # Test with messages format
461
+ test_with_messages = {
462
+ "inputs": {
463
+ "messages": [
464
+ {"role": "system", "content": "You are an AI assistant that provides helpful, accurate, and concise information about AI models."},
465
+ {"role": "user", "content": "What are the major features of Phi-4?"}
466
+ ]
467
+ }
468
+ }
469
+
470
+ # Run the test
471
+ result = handler(test_with_messages)
472
+ print(json.dumps(result, indent=2))