import torch
from typing import Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the model and tokenizer for Phi-4 inference.

        Args:
            path (str): Path to the model directory
        """
        # Default generation parameters
        self.max_new_tokens = 4096
        self.temperature = 0.7
        self.top_p = 0.9
        self.do_sample = True

        # Use bfloat16 on CUDA, float32 on CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Load model with appropriate settings
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=True,
        )

        # device_map="auto" already places the model on GPU; only move it manually on CPU
        if self.device == "cpu":
            self.model = self.model.to(self.device)

        # Set model to evaluation mode
        self.model.eval()

        print(f"Model loaded on {self.device} using {self.dtype}")

    def format_prompt(self, prompt: str) -> str:
        """
        Format the user prompt for the Phi-4 model.

        Args:
            prompt (str): User input prompt

        Returns:
            str: Formatted prompt
        """
        # Phi-4-mini-instruct accepts plain text, so the prompt is passed
        # through unchanged; adjust this for your specific fine-tuning if needed.
        return prompt
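    # A minimal sketch (not part of the original handler, and not called by
    # __call__): if the checkpoint ships a chat template, the tokenizer can
    # build the instruct-format prompt instead of passing raw text through.
    def format_chat_prompt(self, prompt: str) -> str:
        messages = [{"role": "user", "content": prompt}]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )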
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and generate a response using the Phi-4 model.

        Args:
            data (Dict[str, Any]): Input data containing the prompt and
                generation parameters

        Returns:
            Dict[str, Any]: Model response
        """
        # Extract input parameters with defaults
        prompt = data.pop("inputs", "")
        parameters = data.pop("parameters", {})

        # Get generation parameters with fallbacks to the defaults set in __init__
        max_new_tokens = parameters.get("max_new_tokens", self.max_new_tokens)
        temperature = parameters.get("temperature", self.temperature)
        top_p = parameters.get("top_p", self.top_p)
        do_sample = parameters.get("do_sample", self.do_sample)
        stream = parameters.get("stream", False)

        # Format the prompt according to model requirements
        formatted_prompt = self.format_prompt(prompt)

        # Tokenize the input and move it to the model's device
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)

        # Dispatch to streaming or non-streaming generation
        if stream:
            return self._generate_stream(inputs, max_new_tokens, temperature, top_p, do_sample)
        return self._generate(inputs, max_new_tokens, temperature, top_p, do_sample)

    def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text in non-streaming mode."""
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Return only the newly generated text (without the prompt). Slicing at
        # the token level is more reliable than trimming the decoded string by
        # character count, which can mis-slice when retokenization shifts whitespace.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return {"generated_text": response_text}

    def _generate_stream(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text in streaming mode."""
        # skip_prompt=True makes the streamer emit only newly generated text,
        # so the prompt never has to be stripped from the output chunks.
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        # Run generation in a background thread; the streamer yields chunks
        # as they become available.
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the output one chunk at a time
        def generate_stream():
            for text in streamer:
                yield {"generated_text": text}

        return generate_stream()


# For local testing
if __name__ == "__main__":
    # Example usage
    handler = EndpointHandler()
    result = handler({"inputs": "What are the major features of Phi-4?"})
    print(result)
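    # Streaming usage (a sketch; the prompt is illustrative): when
    # parameters.stream is set, the handler returns a generator of chunks.
    stream_result = handler(
        {
            "inputs": "What are the major features of Phi-4?",
            "parameters": {"stream": True},
        }
    )
    for chunk in stream_result:
        print(chunk["generated_text"], end="", flush=True)
    print()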