import torch
from typing import Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the model and tokenizer for Phi-4 inference.

        Args:
            path (str): Path to the model directory
        """
        # Default generation parameters; each can be overridden per request
        # through the "parameters" field of the payload.
        self.max_new_tokens = 4096
        self.temperature = 0.7
        self.top_p = 0.9
        self.do_sample = True

        # Prefer GPU with bfloat16; fall back to CPU with float32.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=True,
        )
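        # Note (assumptions: the accelerate package is installed and `path`
        # points at microsoft/phi-4): device_map="auto" requires accelerate
        # and shards the weights across available GPUs; Phi-4's ~14B
        # parameters occupy roughly 28 GB in bf16.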

        # On GPU, device_map already placed the weights; on CPU, move the
        # model explicitly.
        if self.device == "cpu":
            self.model = self.model.to(self.device)

        # Inference only: disable dropout and other training-time behavior.
        self.model.eval()

        print(f"Model loaded on {self.device} using {self.dtype}")

    def format_prompt(self, prompt: str) -> str:
        """
        Format the user prompt for the Phi-4 model.

        Args:
            prompt (str): User input prompt

        Returns:
            str: Formatted prompt
        """
        # Pass-through by default: callers are expected to send a prompt
        # already in the form the model should complete.
        return prompt
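
    def format_chat_prompt(self, prompt: str) -> str:
        """
        Chat-style formatting (a sketch; not wired into __call__).

        Assumption: the tokenizer ships with a chat template, as the
        Hugging Face release of Phi-4 does. Single-turn requests can then
        be rendered with apply_chat_template instead of being passed
        through raw; swap this into __call__ if clients send plain chat
        messages rather than pre-formatted prompts.
        """
        messages = [{"role": "user", "content": prompt}]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )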

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and generate a response using the Phi-4 model.

        Args:
            data (Dict[str, Any]): Input data containing the prompt and
                generation parameters

        Returns:
            Dict[str, Any]: Model response
        """
        prompt = data.pop("inputs", "")
        parameters = data.pop("parameters", {})

        # Per-request overrides fall back to the defaults set in __init__.
        max_new_tokens = parameters.get("max_new_tokens", self.max_new_tokens)
        temperature = parameters.get("temperature", self.temperature)
        top_p = parameters.get("top_p", self.top_p)
        do_sample = parameters.get("do_sample", self.do_sample)
        stream = parameters.get("stream", False)

        formatted_prompt = self.format_prompt(prompt)

        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)

        if stream:
            return self._generate_stream(inputs, max_new_tokens, temperature, top_p, do_sample)
        else:
            return self._generate(inputs, max_new_tokens, temperature, top_p, do_sample)

    def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text in non-streaming mode."""
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Strip the prompt at the token level. Slicing the decoded string by
        # character count can mis-cut when decoding normalizes whitespace or
        # special tokens.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return {"generated_text": response_text}

    def _generate_stream(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text in streaming mode."""
        # skip_prompt=True makes the streamer emit only newly generated
        # tokens, so no manual prompt stripping is needed downstream.
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Run generation on a background thread; the streamer yields text
        # chunks on this thread as they are produced.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        def generate_stream():
            for text in streamer:
                yield {"generated_text": text}
            thread.join()

        return generate_stream()


if __name__ == "__main__":
    # Smoke test: pass a local checkpoint directory or a Hub model ID;
    # the constructor's default path of "" would fail in from_pretrained.
    handler = EndpointHandler("microsoft/phi-4")
    result = handler({"inputs": "What are the major features of Phi-4?"})
    print(result)
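
    # Streaming usage (a sketch against the same handler API): with
    # "stream": True the handler returns a generator of text chunks.
    stream = handler({
        "inputs": "Summarize Phi-4's training approach in two sentences.",
        "parameters": {"max_new_tokens": 128, "stream": True},
    })
    for chunk in stream:
        print(chunk["generated_text"], end="", flush=True)
    print()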