| import os |
| import torch |
| import logging |
| import time |
| import traceback |
| import json |
| import re |
| from typing import Dict, List, Any, Union, Generator |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
| from threading import Thread |
|
|
| |
# Configure root logging once at import time so handler activity is visible
# in the endpoint container's output; module logger follows stdlib convention.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for Phi-4-mini.

    Accepts either a raw prompt string or OpenAI-style chat messages under
    ``data["inputs"]`` and returns an OpenAI ``chat.completion``-shaped dict.
    Generation is done with a manual token-by-token loop (see _safe_generate)
    rather than ``model.generate``.
    """

    def __init__(self, path=""):
        """
        Initialize the model and tokenizer for Phi-4 inference.

        Args:
            path (str): Path to the model directory
        """
        # Default generation parameters; each request may override them via
        # its "parameters" field (see __call__).
        self.max_new_tokens = 1024
        self.temperature = 0.7
        self.top_p = 0.9
        self.do_sample = True

        # Prefer CUDA + bfloat16 when a GPU is available; otherwise CPU + float32.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        logger.info(f"Initializing model from {path} on {self.device}")

        try:
            # Tokenizer: try the local checkpoint first, then fall back to the
            # public hub copy so the endpoint can still start.
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(path)
                logger.info(f"Loaded tokenizer from local path")
            except Exception as e:
                logger.warning(f"Failed to load tokenizer from local path: {e}")
                self.tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
                logger.info("Loaded tokenizer from microsoft/Phi-4-mini-instruct")

            # _safe_generate stops on eos_token_id, so make sure one is set.
            # NOTE(review): 199999 is assumed to be the Phi-4 EOS id — confirm
            # against the model's tokenizer config.
            if self.tokenizer.eos_token_id is None:
                logger.warning("EOS token not set in tokenizer, using default")
                self.tokenizer.eos_token_id = 199999

            # device_map="auto" lets accelerate place/shard the model on GPU;
            # on CPU we load normally and move it explicitly below.
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=self.dtype,
                device_map="auto" if self.device == "cuda" else None,
                trust_remote_code=True
            )

            if self.device == "cpu":
                self.model = self.model.to(self.device)

            # Inference only — disable dropout etc.
            self.model.eval()

            logger.info(f"Model loaded on {self.device} using {self.dtype}")
            logger.info(f"Tokenizer vocabulary size: {len(self.tokenizer)}")
            logger.info(f"Model vocabulary size: {self.model.config.vocab_size}")
            logger.info(f"Model embedding size: {self.model.get_input_embeddings().weight.shape}")

            # A mismatch can surface as out-of-range token ids at sampling time;
            # warn but continue.
            if len(self.tokenizer) != self.model.config.vocab_size:
                logger.warning(f"Tokenizer vocab size ({len(self.tokenizer)}) doesn't match model vocab size ({self.model.config.vocab_size})")

        except Exception as e:
            # Initialization failure is fatal — re-raise so the endpoint
            # reports unhealthy instead of serving a broken model.
            logger.error(f"Error during model initialization: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def format_prompt_with_system(self, user_message: str, system_message=None) -> str:
        """
        Format the prompt with system and user messages according to Phi-4 format.

        Args:
            user_message (str): The user's message
            system_message (str, optional): The system message/instruction

        Returns:
            str: Formatted prompt ready for the model
        """
        # NOTE(review): this hand-builds the chat markup; Phi-4's official chat
        # template may also use end-of-turn markers — confirm against the
        # tokenizer's apply_chat_template output.
        if system_message:
            prompt = f"<|system|>\n{system_message}\n<|user|>\n{user_message}\n<|assistant|>"
        else:
            prompt = f"<|user|>\n{user_message}\n<|assistant|>"

        logger.info(f"Formatted prompt with {'system message and ' if system_message else ''}user message")
        return prompt

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and generate a response using the Phi-4 model.

        Args:
            data (Dict[str, Any]): Input data containing the prompt and generation parameters

        Returns:
            Dict[str, Any]: Model response
        """
        start_time = time.time()
        logger.info(f"Starting request processing")

        try:
            # "inputs" is the standard Inference Endpoints payload field.
            if "inputs" not in data:
                logger.warning("No 'inputs' field in request data")
                error_msg = "Missing 'inputs' field in request"
                return self._format_error_response(error_msg)

            user_message = ""
            system_message = None

            # Accepted shapes for "inputs":
            #   1. a plain prompt string (optional system message under
            #      parameters.system_message)
            if isinstance(data["inputs"], str):
                user_message = data["inputs"]
                system_message = data.get("parameters", {}).get("system_message", None)

            #   2. {"messages": [...]} — OpenAI chat format; the LAST message
            #      of each role wins.
            elif isinstance(data["inputs"], dict) and "messages" in data["inputs"]:
                messages = data["inputs"]["messages"]

                for msg in messages:
                    if msg.get("role") == "system":
                        system_message = msg.get("content", "")
                    elif msg.get("role") == "user":
                        user_message = msg.get("content", "")

            #   3. a bare list of chat messages (same last-wins extraction)
            elif isinstance(data["inputs"], list):
                messages = data["inputs"]

                for msg in messages:
                    if msg.get("role") == "system":
                        system_message = msg.get("content", "")
                    elif msg.get("role") == "user":
                        user_message = msg.get("content", "")
            else:
                logger.warning(f"Unsupported input format: {type(data['inputs'])}")
                error_msg = "Unsupported input format. Expected string or messages object."
                return self._format_error_response(error_msg)

            logger.info(f"Extracted user message length: {len(user_message)} characters")
            if system_message:
                logger.info(f"Extracted system message length: {len(system_message)} characters")

            prompt = self.format_prompt_with_system(user_message, system_message)

            parameters = data.get("parameters", {})

            logger.info(f"Processing input with {len(prompt)} characters")

            # Per-request overrides of the __init__ defaults; max_new_tokens
            # is hard-capped at 1024.
            max_new_tokens = min(parameters.get("max_new_tokens", self.max_new_tokens), 1024)
            temperature = parameters.get("temperature", self.temperature)
            top_p = parameters.get("top_p", self.top_p)
            do_sample = parameters.get("do_sample", self.do_sample)

            logger.info(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_p={top_p}, do_sample={do_sample}")

            try:
                input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
                logger.info(f"Input tokens shape: {input_ids.shape}")

                # No padding on a single prompt, so the mask is all ones.
                attention_mask = torch.ones_like(input_ids)

                # Returns the generated text on success, or an error string on
                # internal failure (see _safe_generate).
                response_text = self._safe_generate(
                    input_ids,
                    attention_mask,
                    max_new_tokens,
                    temperature,
                    top_p,
                    do_sample,
                    prompt
                )

                logger.info(f"Response generation completed, text length: {len(response_text) if isinstance(response_text, str) else 'N/A'}")

                # NOTE(review): _safe_generate always returns a str (even on
                # error), so the else branch below looks unreachable.
                if isinstance(response_text, str):
                    # Re-encode to report a completion token count in "usage".
                    response_tokens = len(self.tokenizer.encode(response_text)) if response_text else 0
                    logger.info(f"Response token count: {response_tokens}")

                    return self._format_openai_response(
                        response_text,
                        input_ids.shape[1],
                        response_tokens
                    )
                else:
                    return self._format_error_response(f"Error during generation: {response_text}")

            except RuntimeError as e:
                # Typical source: CUDA OOM / device-side asserts during the
                # forward pass.
                logger.error(f"Runtime Error during generation: {str(e)}")
                logger.error(traceback.format_exc())
                return self._format_error_response(f"Error during generation: {str(e)}")

        except Exception as e:
            # Top-level boundary: never raise out of the handler; always
            # return a well-formed error payload.
            logger.error(f"Unexpected error during request processing: {str(e)}")
            logger.error(traceback.format_exc())
            return self._format_error_response(f"Unexpected error: {str(e)}")
        finally:
            duration = time.time() - start_time
            logger.info(f"Request processing completed in {duration:.2f} seconds")

    def _complete_sentence(self, text: str) -> str:
        """Ensure the text ends with a complete sentence"""
        # Already ends with terminal punctuation — nothing to trim.
        if text.strip().endswith(('.', '!', '?')):
            return text

        # Split into sentences, keeping the punctuation (captured group):
        # even indices are sentence bodies, odd indices their terminators.
        sentences = re.split(r'([.!?])\s+', text)
        if len(sentences) <= 1:
            # No complete sentence at all; mark the truncation explicitly.
            return text + "..."

        # Re-join all complete sentences, dropping the trailing fragment
        # (the last list element, which has no terminator).
        result = ""
        for i in range(len(sentences) - 1):
            if i % 2 == 0:
                result += sentences[i]
            else:
                result += sentences[i] + " "

        return result.strip()

    def _safe_generate(self, input_ids, attention_mask, max_new_tokens, temperature, top_p, do_sample, prompt):
        """Safely generate text handling potential token index errors"""
        # Returns the assistant's text on success, or a human-readable error
        # string on failure — callers distinguish by content, not exceptions.
        try:
            with torch.no_grad():
                logger.info("Starting safe generation")

                input_text = prompt
                logger.info(f"Input prompt length: {len(input_text)} characters")

                # Manual decoding loop (no model.generate, no KV cache): each
                # step re-runs the full forward pass over the whole sequence.
                # Steps are capped at 450 regardless of max_new_tokens.
                max_steps = min(max_new_tokens, 450)
                current_ids = input_ids.clone()

                logger.info(f"Generating up to {max_steps} tokens")

                # Sliding window of the last 5 token ids for repetition detection.
                # NOTE(review): repetition_detected is set but never read.
                last_tokens = []
                repetition_detected = False

                for i in range(max_steps):
                    if i % 50 == 0:
                        logger.info(f"Generated {i} tokens so far")

                    # In the final 50 steps, try to stop early once the reply
                    # already looks like a few complete sentences (decodes the
                    # whole sequence each step — expensive but bounded).
                    if i >= max_steps - 50:
                        temp_text = self.tokenizer.decode(current_ids[0], skip_special_tokens=True)

                        # NOTE(review): skip_special_tokens=True may strip the
                        # "<|assistant|>" marker, making this branch dead —
                        # confirm with the Phi-4 tokenizer.
                        if "<|assistant|>" in temp_text:
                            temp_response = temp_text.split("<|assistant|>")[1].strip()

                            # Heuristic: >100 chars and at least 3 periods.
                            if len(temp_response) > 100 and temp_response.count('.') >= 3:
                                logger.info(f"Early termination at {i} tokens with complete response detected")
                                break

                    # Full forward pass; we only use the logits of the last position.
                    outputs = self.model(
                        input_ids=current_ids,
                        attention_mask=attention_mask,
                        return_dict=True
                    )

                    next_token_logits = outputs.logits[:, -1, :]

                    # Temperature scaling (skipped when temperature == 0 to
                    # avoid division by zero).
                    if temperature > 0:
                        next_token_logits = next_token_logits / temperature

                    if do_sample:
                        # Nucleus (top-p) filtering: sort logits, find the
                        # smallest prefix whose cumulative probability exceeds
                        # top_p, and mask everything outside it.
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token crossing the threshold
                        # is still kept (standard top-p implementation detail).
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0

                        # Map the mask back from sorted order to vocab order.
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        next_token_logits[indices_to_remove] = -float('Inf')

                        probs = torch.softmax(next_token_logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        # Greedy decoding.
                        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                    # Append the new token and extend the attention mask to match.
                    current_ids = torch.cat([current_ids, next_token], dim=-1)
                    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

                    last_tokens.append(next_token.item())
                    if len(last_tokens) > 5:
                        last_tokens.pop(0)

                    # Stop if the last 5 generated tokens are all identical.
                    if len(last_tokens) >= 5:
                        if len(set(last_tokens)) == 1:
                            logger.warning(f"Repetition detected after {i+1} tokens, stopping generation")
                            repetition_detected = True
                            break

                    # Stop on end-of-sequence.
                    if next_token[0, 0].item() == self.tokenizer.eos_token_id:
                        logger.info(f"EOS token generated after {i+1} tokens")
                        break

                generated_text = self.tokenizer.decode(current_ids[0], skip_special_tokens=True)
                logger.info(f"Decoded generated text: {len(generated_text)} characters")

                # Keep only the text after the assistant marker (if present)
                # and trim any trailing incomplete sentence.
                split_text = generated_text.split("<|assistant|>")
                if len(split_text) > 1:
                    assistant_response = split_text[1].strip()
                    logger.info(f"Raw assistant response: {len(assistant_response)} characters")

                    response_text = self._complete_sentence(assistant_response)
                    logger.info(f"Processed assistant response: {len(response_text)} characters")
                else:
                    # Marker missing (see NOTE above): fall back to the full
                    # decoded text, prompt included.
                    logger.warning("Could not find assistant tag in generated text")
                    response_text = generated_text

                return response_text

        except Exception as e:
            # Degrade to an error string; __call__ wraps it for the client.
            logger.error(f"Error in _safe_generate: {str(e)}")
            logger.error(traceback.format_exc())
            return f"Generation error: {str(e)}. Please try a simpler input."

    def _format_openai_response(self, response_text, prompt_tokens, completion_tokens):
        """Format the response in OpenAI-style format"""
        # Builds an OpenAI chat.completion-shaped dict plus a top-level
        # "generated_text" field for generic Inference Endpoints clients.
        try:
            # Unix-timestamp id — not unique across same-second requests.
            response_id = f"phi4-{int(time.time())}"

            openai_response = {
                "id": response_id,
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "phi-4-mini",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": response_text
                        },
                        "finish_reason": "stop"
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens
                }
            }

            openai_response["generated_text"] = response_text

            logger.info(f"Formatted OpenAI-style response: {len(json.dumps(openai_response))} bytes")
            return openai_response

        except Exception as e:
            # Last-resort minimal payload so the request still returns text.
            logger.error(f"Error formatting OpenAI response: {str(e)}")
            return {"generated_text": response_text}

    def _format_error_response(self, error_message):
        """Format an error response in OpenAI-style format"""
        # Mirrors _format_openai_response but with zeroed usage and an
        # OpenAI-style "error" object alongside the assistant message.
        try:
            error_response = {
                "id": f"phi4-error-{int(time.time())}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "phi-4-mini",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": f"Error: {error_message}"
                        },
                        "finish_reason": "error"
                    }
                ],
                "usage": {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0
                },
                "error": {
                    "message": error_message,
                    "type": "invalid_request_error",
                    "code": "error"
                }
            }

            error_response["generated_text"] = f"Error: {error_message}"

            logger.info(f"Formatted error response: {len(json.dumps(error_response))} bytes")
            return error_response

        except Exception as e:
            # Last-resort minimal payload; never raise from error formatting.
            logger.error(f"Error formatting error response: {str(e)}")
            return {"generated_text": f"Error: {error_message}"}
|
|
| |
| if __name__ == "__main__": |
| |
| handler = EndpointHandler() |
| |
| |
| test_with_messages = { |
| "inputs": { |
| "messages": [ |
| {"role": "system", "content": "You are an AI assistant that provides helpful, accurate, and concise information about AI models."}, |
| {"role": "user", "content": "What are the major features of Phi-4?"} |
| ] |
| } |
| } |
| |
| |
| result = handler(test_with_messages) |
| print(json.dumps(result, indent=2)) |