# phi4-mini-raw/handler.py
import torch
from threading import Thread
from typing import Any, Dict

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the model and tokenizer for Phi-4 inference.

        Args:
            path (str): Path to the model directory
        """
        # Default generation parameters; per-request values can override these
        self.max_new_tokens = 4096
        self.temperature = 0.7
        self.top_p = 0.9
        self.do_sample = True

        # Use CUDA with bfloat16 when a GPU is available, otherwise CPU with float32
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Load model; device_map="auto" lets accelerate place the weights on GPU
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=True,
        )

        # On CPU no device map is used, so move the model explicitly
        if self.device == "cpu":
            self.model = self.model.to(self.device)

        # Set model to evaluation mode
        self.model.eval()
        print(f"Model loaded on {self.device} using {self.dtype}")

    def format_prompt(self, prompt: str) -> str:
        """
        Format the user prompt for the Phi-4 model.

        Args:
            prompt (str): User input prompt

        Returns:
            str: Formatted prompt
        """
        # For Phi-4-mini-instruct the raw prompt is passed through unchanged;
        # adjust this if your specific fine-tune expects a template.
        return prompt
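
    # The method below is a sketch of a chat-template alternative, not part of
    # the original handler. It assumes the model ships with a chat template in
    # its tokenizer config (Phi-4-mini-instruct does); verify against your
    # model's tokenizer_config.json before relying on it.
    def format_chat_prompt(self, prompt: str) -> str:
        """Hypothetical alternative: wrap the prompt with the tokenizer's chat template."""
        messages = [{"role": "user", "content": prompt}]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )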

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and generate a response using the Phi-4 model.

        Args:
            data (Dict[str, Any]): Input data containing the prompt and generation parameters

        Returns:
            Dict[str, Any]: Model response
        """
        # Extract the prompt and any per-request generation parameters
        prompt = data.pop("inputs", "")
        parameters = data.pop("parameters", {})

        # Fall back to the defaults set in __init__ when a parameter is absent
        max_new_tokens = parameters.get("max_new_tokens", self.max_new_tokens)
        temperature = parameters.get("temperature", self.temperature)
        top_p = parameters.get("top_p", self.top_p)
        do_sample = parameters.get("do_sample", self.do_sample)
        stream = parameters.get("stream", False)

        # Format the prompt according to model requirements
        formatted_prompt = self.format_prompt(prompt)

        # Tokenize the input and move it to the model's device
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)

        # Dispatch to streaming or non-streaming generation
        if stream:
            return self._generate_stream(inputs, max_new_tokens, temperature, top_p, do_sample)
        return self._generate(inputs, max_new_tokens, temperature, top_p, do_sample)
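
    # For reference, a request payload this handler accepts looks like the
    # following (illustrative values, not taken from the original file):
    #
    #   {
    #       "inputs": "Explain beam search in one paragraph.",
    #       "parameters": {"max_new_tokens": 256, "temperature": 0.2, "stream": False}
    #   }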

    def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text in non-streaming mode."""
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Return only the newly generated tokens. Slicing at the token level is
        # more reliable than trimming the decoded string by character count,
        # since re-decoding the prompt does not always round-trip exactly.
        prompt_length = inputs.input_ids.shape[1]
        response_text = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        return {"generated_text": response_text}

    def _generate_stream(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text in streaming mode."""
        # skip_prompt=True makes the streamer emit only newly generated text,
        # so no manual prompt-stripping is needed on the consumer side.
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        # Run generation in a background thread; the streamer yields decoded
        # text chunks on this thread as tokens are produced.
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield each chunk as it arrives
        def generate_stream():
            for text in streamer:
                yield {"generated_text": text}

        return generate_stream()


# For local testing
if __name__ == "__main__":
    # Example usage
    handler = EndpointHandler()
    result = handler({"inputs": "What are the major features of Phi-4?"})
    print(result)
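
    # A streaming call returns a generator of chunks rather than a single
    # dict. This is a sketch of how it could be consumed locally; whether the
    # serving stack accepts generator responses depends on the endpoint
    # toolkit, which the original file does not specify.
    stream_result = handler(
        {
            "inputs": "What are the major features of Phi-4?",
            "parameters": {"stream": True, "max_new_tokens": 64},
        }
    )
    for chunk in stream_result:
        print(chunk["generated_text"], end="", flush=True)
    print()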