import os
import torch
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
class EndpointHandler:
    """Inference handler serving a Phi-4 causal language model.

    Loads the model/tokenizer once at startup and answers requests of the
    form ``{"inputs": <prompt>, "parameters": {...}}``, supporting both
    one-shot and streaming generation.
    """

    def __init__(self, path=""):
        """
        Initialize the model and tokenizer for Phi-4 inference.

        Args:
            path (str): Path to the model directory.
        """
        # Default generation parameters; per-request "parameters" override them.
        self.max_new_tokens = 4096
        self.temperature = 0.7
        self.top_p = 0.9
        self.do_sample = True

        # Prefer GPU with bfloat16; fall back to CPU float32.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # device_map="auto" lets accelerate shard/placement-manage on GPU;
        # on CPU we place the model manually below.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=True,
        )
        if self.device == "cpu":
            self.model = self.model.to(self.device)

        # Inference only — disable dropout etc.
        self.model.eval()
        print(f"Model loaded on {self.device} using {self.dtype}")

    def format_prompt(self, prompt: str) -> str:
        """
        Format the user prompt for the Phi-4 model.

        Args:
            prompt (str): User input prompt.

        Returns:
            str: Formatted prompt. Currently a pass-through; adjust here if a
            chat template or instruction wrapper is needed for a fine-tune.
        """
        return prompt

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and generate a response using the Phi-4 model.

        Args:
            data (Dict[str, Any]): Request payload with "inputs" (the prompt)
                and optional "parameters" (generation overrides, including
                "stream": True for incremental output).

        Returns:
            Dict[str, Any]: ``{"generated_text": ...}`` in one-shot mode, or a
            generator yielding such dicts in streaming mode.
        """
        # Use .get (not .pop) so the caller's dict is not mutated.
        prompt = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Per-request overrides with instance-level fallbacks.
        max_new_tokens = parameters.get("max_new_tokens", self.max_new_tokens)
        temperature = parameters.get("temperature", self.temperature)
        top_p = parameters.get("top_p", self.top_p)
        do_sample = parameters.get("do_sample", self.do_sample)
        stream = parameters.get("stream", False)

        formatted_prompt = self.format_prompt(prompt)
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)

        if stream:
            return self._generate_stream(inputs, max_new_tokens, temperature, top_p, do_sample)
        return self._generate(inputs, max_new_tokens, temperature, top_p, do_sample)

    def _generate(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text in non-streaming mode and return only the completion."""
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Strip the prompt at the *token* level: decoding prompt+completion and
        # slicing by character count is fragile because detokenization does not
        # always round-trip to the exact same prefix string.
        prompt_token_count = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_token_count:]
        response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return {"generated_text": response_text}

    def _generate_stream(self, inputs, max_new_tokens, temperature, top_p, do_sample):
        """Generate text in streaming mode, yielding completion chunks."""
        # skip_prompt=True makes the streamer emit only newly generated text.
        # (The streamer yields *incremental* chunks, not cumulative text, so
        # character-length arithmetic against the prompt cannot work here.)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # daemon=True so an abandoned stream cannot block interpreter shutdown.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs, daemon=True)
        thread.start()

        def generate_stream():
            # Each chunk is already prompt-free thanks to skip_prompt above.
            for text in streamer:
                if text:
                    yield {"generated_text": text}

        return generate_stream()
# For local testing
if __name__ == "__main__":
# Example usage
handler = EndpointHandler()
result = handler({"inputs": "What are the major features of Phi-4?"})
print(result) |