from typing import Dict, List, Any
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoImageProcessor,
)
import torch
from PIL import Image
import base64
import io
# Pick dtype and device: bfloat16 on GPU, float32 on CPU
# (half precision is poorly supported for generation on CPU)
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
class EndpointHandler:
    def __init__(self, path=""):
        print(f"Initializing model on device: {device}")
        print(f"Using dtype: {dtype}")
        # Load the tokenizer and image processor from the model repo;
        # trust_remote_code is required for this custom architecture
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.image_processor = AutoImageProcessor.from_pretrained(path, trust_remote_code=True)
        # Load the model with explicit device placement
        if device == "cuda":
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=dtype,
                trust_remote_code=True,
                device_map="auto",  # automatically map to available GPUs
            )
        else:
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=dtype,
                trust_remote_code=True,
            ).to(device)
        self.model.eval()
        print(f"Model loaded successfully on device: {self.model.device}")
        print(f"Model dtype: {next(self.model.parameters()).dtype}")
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data: request payload with
                inputs (:obj:`str` or :obj:`list`): a plain text prompt or messages in chat format
                parameters (:obj:`dict`): generation parameters
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned.
        """
print("Call inside handler")
# get inputs
inputs = data.pop("inputs", data)
parameters = data.pop("parameters", {})
print("parameters", parameters)
# Remove parameters that might cause issues
parameters.pop("details", None)
parameters.pop("stop", None)
parameters.pop("return_full_text", None)
if "do_sample" in parameters:
parameters["do_sample"] = True
# Set default generation parameters
max_new_tokens = parameters.pop("max_new_tokens", 512)
temperature = parameters.pop("temperature", 0)
        try:
            # Handle different input formats
            if isinstance(inputs, str):
                # A bare string is treated as a plain text prompt
                input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
                with torch.no_grad():
                    generated_ids = self.model.generate(
                        input_ids,
                        max_new_tokens=max_new_tokens,
                        **parameters,
                    )
                # Strip the prompt tokens so only the completion is decoded
                prompt_len = input_ids.shape[1]
                generated_ids = generated_ids[:, prompt_len:]
                output_text = self.tokenizer.batch_decode(
                    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )
                return [{"generated_text": output_text[0]}]
            elif isinstance(inputs, list):
                # Chat format, possibly containing images
                messages = inputs
                # Apply the chat template and append the generation prompt
                input_ids = self.tokenizer.apply_chat_template(
                    messages, tokenize=True, add_generation_prompt=True
                )
                input_text = self.tokenizer.decode(
                    input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
                )
                print(input_text)
                input_ids = torch.tensor([input_ids]).to(self.model.device)
                # Collect pixel values and grid shapes for ALL images in the messages
                pixel_values_list = []
                grid_thws_list = []
                for message in messages:
                    if isinstance(message.get("content"), list):
                        for content_item in message["content"]:
                            if content_item.get("type") == "image_url":
                                image_data = content_item.get("image_url", {}).get("url", "")
                                if image_data.startswith("data:image"):
                                    # Decode the base64 payload of the data URL
                                    image_data = image_data.split(",")[1]
                                    image_bytes = base64.b64decode(image_data)
                                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                                    # Preprocess each image individually
                                    info = self.image_processor.preprocess(images=[image])
                                    pixel_values = torch.tensor(info["pixel_values"]).to(
                                        dtype=dtype, device=self.model.device
                                    )
                                    grid_thws = torch.tensor(info["image_grid_thw"]).to(self.model.device)
                                    pixel_values_list.append(pixel_values)
                                    grid_thws_list.append(grid_thws)
                # Generate the response
                if pixel_values_list and grid_thws_list:
                    # Multi-modal generation: concatenate all pixel_values and
                    # grid_thws so every image is processed in one batch
                    all_pixel_values = torch.cat(pixel_values_list, dim=0)
                    all_grid_thws = torch.cat(grid_thws_list, dim=0)
                    print(f"Processing {len(pixel_values_list)} images")
                    print(f"pixel_values shape: {all_pixel_values.shape}")
                    print(f"grid_thws shape: {all_grid_thws.shape}")
                    print("grid_thws", all_grid_thws)
                    # Ensure all tensors are on the same device as the model
                    all_pixel_values = all_pixel_values.to(self.model.device)
                    all_grid_thws = all_grid_thws.to(self.model.device)
                    with torch.no_grad():
                        generated_ids = self.model.generate(
                            input_ids,
                            pixel_values=all_pixel_values,
                            grid_thws=all_grid_thws,
                            max_new_tokens=max_new_tokens,
                            **parameters,
                        )
                else:
                    # Text-only generation
                    with torch.no_grad():
                        generated_ids = self.model.generate(
                            input_ids,
                            max_new_tokens=max_new_tokens,
                            **parameters,
                        )
                # Strip the prompt tokens and decode only the completion
                prompt_len = input_ids.shape[1]
                generated_ids = generated_ids[:, prompt_len:]
                output_text = self.tokenizer.batch_decode(
                    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )
                print("##Model Response##", output_text)
                return [{"generated_text": output_text[0]}]
            else:
                raise ValueError(f"Unsupported input type: {type(inputs)}")
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            return [{"error": str(e)}]