from typing import List, Optional, Union, Dict

import base64
import copy
import io
import math
import os

import numpy as np
import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor, PreTrainedTokenizer
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

from .configuration_logics import LogicsConfig
from .siglip_encoder import SigLipImageProcessor

try:
    from decord import VideoReader, cpu
except ImportError:
    print("Please install decord to use video processing functions.")

AutoImageProcessor.register(LogicsConfig, SigLipImageProcessor)


IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def process_video_with_decord(video_file, data_args):
    """Decode a video with decord and sample frames at data_args.video_fps,
    capped at data_args.frames_upbound frames (uniformly resampled if needed)."""
    vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
    total_frame_num = len(vr)

    # Take one frame every `sample_interval` source frames to hit the target fps.
    sample_interval = max(1, round(vr.get_avg_fps() / data_args.video_fps))
    frame_idx = list(range(0, total_frame_num, sample_interval))
    frame_time = [i / vr.get_avg_fps() for i in frame_idx]

    # Fall back to uniform sampling when the clip yields too many frames,
    # or when uniform sampling is forced.
    if data_args.frames_upbound > 0:
        if len(frame_idx) > data_args.frames_upbound or data_args.force_sample:
            uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int)
            frame_idx = uniform_sampled_frames.tolist()
            frame_time = [i / vr.get_avg_fps() for i in frame_idx]

    try:
        video = vr.get_batch(frame_idx).asnumpy()
    except Exception:
        # Decoding failed: substitute an empty clip so the pipeline can continue.
        video = torch.zeros((10, 720, 720, 3)).numpy()
        print(f"load {video_file} error, use empty tensor instead.")

    vr.seek(0)
    return video


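# Illustrative sketch (not part of the original module): process_video_with_decord
# only reads video_fps, frames_upbound, and force_sample from its config argument,
# so a minimal stand-in is enough for local testing; the values below are assumptions.
#
#     from types import SimpleNamespace
#     data_args = SimpleNamespace(video_fps=1, frames_upbound=32, force_sample=True)
#     frames = process_video_with_decord("clip.mp4", data_args)  # (N, H, W, 3) uint8

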
def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    # Scale by the tighter dimension so the image fits inside the target box.
    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    resized_image = image.resize((new_width, new_height))

    # Center the resized image on a black canvas of the target resolution.
    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


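# Illustrative, self-contained example (not part of the original module): a wide
# 400x200 image fits a 336x336 target at scale 0.84, giving a centered 336x168
# region letterboxed with black bars.
def _demo_resize_and_pad():
    img = Image.new("RGB", (400, 200), (255, 0, 0))
    out = resize_and_pad_image(img, (336, 336))
    assert out.size == (336, 336)
    return out

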
def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
        # Area of the image after fitting it inside this candidate resolution.
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)

        # Prefer the candidate that preserves the most image area; break ties
        # by the least unused (padded) area.
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


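# Illustrative, self-contained example (not part of the original module): for a
# 1000x600 image, the 672x672 grid keeps more image area (672x403 after
# downscaling) than 1344x336 (560x336 after downscaling), so it wins.
def _demo_select_best_resolution():
    best = select_best_resolution((1000, 600), [(672, 672), (1344, 336)])
    assert best == (672, 672)
    return best

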
def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


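# Illustrative, self-contained example (not part of the original module): a
# 672x672 image tiles into four 336x336 patches, scanned row-major.
def _demo_divide_to_patches():
    img = Image.new("RGB", (672, 672))
    patches = divide_to_patches(img, 336)
    assert len(patches) == 4 and patches[0].size == (336, 336)
    return patches

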
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions (the "anyres" scheme).

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    possible_resolutions = grid_pinpoints

    # Fit the image into the best candidate grid, then tile it into crops.
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size["height"])

    if isinstance(processor.size, dict):
        shortest_edge = processor.size["height"]
    else:
        shortest_edge = min(processor.size)
    # Prepend a low-resolution view of the whole image to the tiled crops.
    image_original_resize = image.resize((shortest_edge, shortest_edge))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


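# Illustrative usage sketch (requires a real image processor, e.g. the registered
# SigLipImageProcessor; the grid list and 336-pixel crop size are assumptions):
# a 1000x600 image best fits the 672x672 grid, which tiles into four crops, so
# the returned tensor stacks 1 + 4 = 5 patches on dim 0.
#
#     tensor = process_anyres_image(Image.new("RGB", (1000, 600)),
#                                   image_processor, [(672, 672), (1344, 336)])

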
def preprocess_qwen(sources, tokenizer: PreTrainedTokenizer, enable_thinking: bool = True, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
    roles = {"human": "user", "gpt": "assistant", "system": "system"}

    # Work on a copy so the extra <image> token does not leak into the caller's tokenizer.
    tokenizer = copy.deepcopy(tokenizer)
    if has_image:
        tokenizer.add_tokens(["<image>"], special_tokens=True)
        image_token_index = tokenizer.convert_tokens_to_ids("<image>")

    chat_template = "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set content = message.content %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in message.content %}\n                {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}"
    tokenizer.chat_template = chat_template

    messages = [{"role": "system", "content": system_message}]

    for conv in sources:
        # Accept both {"role", "content"} and legacy {"from", "value"} schemas.
        try:
            role = conv["role"]
            content = conv["content"]
        except KeyError:
            role = conv["from"]
            content = conv["value"]
        role = roles.get(role, role)

        # A trailing assistant turn with no content marks the generation slot; skip it.
        if role == "assistant" and content is None:
            continue
        messages.append({"role": role, "content": content})

    input_id = tokenizer.apply_chat_template(messages, add_generation_prompt=True, enable_thinking=enable_thinking)

    # Swap the tokenizer's <image> id for the model's IMAGE_TOKEN_INDEX sentinel.
    if has_image:
        for idx, encode_id in enumerate(input_id):
            if encode_id == image_token_index:
                input_id[idx] = IMAGE_TOKEN_INDEX
    input_ids = torch.tensor([input_id], dtype=torch.long)

    return input_ids


def prepare_prompt_ids(question, img_path, tokenizer, enable_thinking):
    # Prepend one <image> placeholder per provided image, then tokenize the
    # single-turn conversation with an empty assistant slot for generation.
    if img_path is not None:
        question = DEFAULT_IMAGE_TOKEN * len(img_path) + f"\n{question}"
        input_ids = preprocess_qwen([{'from': 'human', 'value': question}, {'from': 'gpt', 'value': None}], tokenizer, enable_thinking, has_image=True)
    else:
        input_ids = preprocess_qwen([{'from': 'human', 'value': question}, {'from': 'gpt', 'value': None}], tokenizer, enable_thinking, has_image=False)

    return input_ids


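# Illustrative usage sketch (requires a Qwen-style tokenizer; the path and
# question are placeholders): one <image> placeholder is prepended per image,
# and each becomes the IMAGE_TOKEN_INDEX (-200) sentinel in the returned ids.
#
#     ids = prepare_prompt_ids("Describe the image.", ["example.jpg"],
#                              tokenizer, enable_thinking=False)
#     assert (ids == IMAGE_TOKEN_INDEX).sum() == 1

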
class LogicsProcessor(ProcessorMixin):
    config_class = LogicsConfig
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SigLipImageProcessor"
    tokenizer_class = "Qwen2Tokenizer"

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        self.config = LogicsConfig()
        self.image_token = "<image>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )

        super().__init__(image_processor, tokenizer, chat_template=chat_template)
        self.text_tokenizer = prepare_prompt_ids

    def _load_image(self, image_source: Union[str, Image.Image]) -> Image.Image:
        """Load an image from a URL, data URI, or file path and return it as RGB;
        PIL images are returned unchanged."""
        if isinstance(image_source, Image.Image):
            return image_source

        if not isinstance(image_source, str):
            raise TypeError(f"Unsupported image source type: {type(image_source)}")

        if image_source.startswith("http://") or image_source.startswith("https://"):
            try:
                response = requests.get(image_source, timeout=30)
                response.raise_for_status()
                image_bytes = response.content
                image = Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                raise IOError(f"Failed to load image from URL: {image_source}") from e

        elif image_source.startswith("data:image"):
            try:
                _, encoded_data = image_source.split(',', 1)
                image_bytes = base64.b64decode(encoded_data)
                image = Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                raise IOError("Failed to decode Base64 image string.") from e

        elif os.path.exists(image_source):
            try:
                image = Image.open(image_source)
            except Exception as e:
                raise IOError(f"Failed to load image from file path: {image_source}") from e
        else:
            raise ValueError(f"Input string is not a valid file path, URL, or Base64-encoded image: {image_source[:100]}...")

        return image.convert('RGB')

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        **kwargs
    ) -> BatchFeature:
        # Place pixel tensors on the GPU when available.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if images is not None and not isinstance(images, list):
            images = [images]

        video_extensions = {"mp4", "avi", "mov", "mkv"}
        is_video = images is not None and any(
            isinstance(x, str) and x.split('.')[-1].lower() in video_extensions for x in images
        )

        if images is not None and not is_video:
            # Image branch: anyres-process every image into a stack of patches.
            image_inputs = {}
            pixel_values = []
            image_grid_thw = []
            modalities = []
            for image_source in images:
                image = self._load_image(image_source)
                width, height = image.size
                image = process_anyres_image(image, self.image_processor, self.config.image_grid_pinpoints).to(dtype=torch.bfloat16, device=device)
                pixel_values.append(image)
                image_grid_thw.append((width, height))
                modalities.append("image")
            # Stack only when every image produced the same number of patches;
            # otherwise keep a list of per-image tensors.
            if all(x.shape == pixel_values[0].shape for x in pixel_values):
                pixel_values = torch.stack(pixel_values, dim=0)
            image_inputs["images_inputs"] = pixel_values
            image_inputs["image_sizes"] = image_grid_thw
            image_inputs["modalities"] = modalities

        elif images is not None:
            # Video branch: decode frames with decord and preprocess them as one clip.
            image_inputs = {}
            video = images[0]
            video_frames = process_video_with_decord(video, self.config)
            video_tensor = self.image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].to(dtype=torch.bfloat16, device=device)

            image_inputs["images_inputs"] = [video_tensor]
            # (width, height) of a decoded frame; video_frames is (N, H, W, 3).
            image_inputs["image_sizes"] = [(video_frames.shape[2], video_frames.shape[1])]
            image_inputs["modalities"] = ["video"]

        text_ids = self.text_tokenizer(question=text, img_path=images, tokenizer=self.tokenizer, enable_thinking=self.config.enable_thinking)
        text_inputs = {"text_inputs": text_ids}

        if images is not None:
            return BatchFeature(data={**text_inputs, **image_inputs})
        else:
            return BatchFeature(data={**text_inputs})

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

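
# Illustrative end-to-end usage (hedged sketch; the tokenizer source, the no-arg
# SigLipImageProcessor construction, and the file name are assumptions, not part
# of this module):
#
#     processor = LogicsProcessor(image_processor=SigLipImageProcessor(),
#                                 tokenizer=tokenizer)
#     batch = processor(images=["example.jpg"], text="Describe the image.")
#     # batch["text_inputs"]: (1, seq_len) prompt ids
#     # batch["images_inputs"]: anyres patch tensors; batch["image_sizes"]: (w, h)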