|
|
import ast
import re
from abc import ABC
from typing import List, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageOps, ImageDraw, ImageFont
from torchvision import transforms
from transformers import TextStreamer
from transformers.tokenization_utils import PreTrainedTokenizer as T
|
|
|
|
|
|
|
|
def load_image(image_path):
    """Open the image at *image_path* and apply its EXIF orientation.

    Returns the orientation-corrected PIL image, or None when the file
    cannot be opened or transposed (the error is printed, not raised).
    """
    try:
        raw = Image.open(image_path)
        return ImageOps.exif_transpose(raw)
    except Exception as exc:
        print(f"error: {exc}")
        return None
|
|
|
|
|
|
|
|
def re_match(text):
    """Find every <|ref|>...<|/ref|><|det|>...<|/det|> block in *text*.

    Returns a tuple of:
      - all regex matches as (full_match, label, coords) string tuples,
      - the full-match strings whose label is 'image',
      - the remaining full-match strings.
    """
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)

    image_blocks = [m[0] for m in matches if '<|ref|>image<|/ref|>' in m[0]]
    other_blocks = [m[0] for m in matches if '<|ref|>image<|/ref|>' not in m[0]]

    return matches, image_blocks, other_blocks
|
|
|
|
|
|
|
|
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Parse one re_match tuple into (label, coordinate_list).

    Args:
        ref_text: a (full_match, label, coords_string) tuple from re_match.
        image_width: accepted for interface compatibility; scaling to pixel
            coordinates is done by the caller (draw_bounding_boxes).
        image_height: same as image_width.

    Returns:
        (label, coords) where coords is the parsed Python list of boxes,
        or None when the coordinate string cannot be parsed.
    """
    try:
        label_type = ref_text[1]
        # Fix: ast.literal_eval replaces eval(). The coordinate string comes
        # from model output, so arbitrary code execution must not be possible;
        # literal_eval accepts exactly the list/number literals expected here.
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)
|
|
|
|
|
|
|
|
def draw_bounding_boxes(image, refs, ouput_path):
    """Draw a labelled bounding box for each ref/det match onto a copy of *image*.

    Args:
        image: source PIL image. Regions labelled 'image' are additionally
            cropped from the original and saved under "{ouput_path}/images/"
            as sequentially numbered JPEGs.
        refs: iterable of match tuples as produced by re_match
            (full_match, label, coordinate-list string).
        ouput_path: output directory (parameter name kept as-is for callers).

    Returns:
        A new PIL image with outlines, translucent fills and label text drawn.
    """
    image_width, image_height = image.size

    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Fills go on a separate transparent overlay so they stay translucent
    # when composited back onto the drawing at the end.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    # Running index for cropped 'image' regions saved to disk.
    img_idx = 0

    for i, ref in enumerate(refs):
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if result:
                label_type, points_list = result

                # One random colour per ref; first two channels capped at 200
                # so the label text stays visible on light backgrounds.
                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))

                # Same colour with a near-transparent alpha for the fill.
                color_a = color + (20, )
                for points in points_list:
                    x1, y1, x2, y2 = points

                    # Coordinates are normalised to a 0-999 grid; scale to pixels.
                    x1 = int(x1 / 999 * image_width)
                    y1 = int(y1 / 999 * image_height)

                    x2 = int(x2 / 999 * image_width)
                    y2 = int(y2 / 999 * image_height)

                    if label_type == 'image':
                        try:
                            cropped = image.crop((x1, y1, x2, y2))
                            cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
                        except Exception as e:
                            # Best-effort: a failed crop/save (e.g. missing
                            # directory) is reported but does not stop drawing.
                            print(e)
                            pass
                        # Index advances even when saving failed, keeping the
                        # numbering aligned with the region order.
                        img_idx += 1

                    try:
                        if label_type == 'title':
                            # Titles get a thicker outline than other labels.
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        else:
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        text_x = x1
                        # Place the label just above the box, clamped on-canvas.
                        text_y = max(0, y1 - 15)

                        # Measure the label to draw a light backing rectangle
                        # behind the text.
                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]
                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                       fill=(255, 255, 255, 30))

                        draw.text((text_x, text_y), label_type, font=font, fill=color)
                    except:
                        # NOTE(review): bare except silently drops drawing errors
                        # for this box — consider narrowing to specific errors.
                        pass
        except:
            # NOTE(review): bare except skips the whole ref on any error,
            # e.g. a malformed points tuple during unpacking.
            continue
    # Composite the translucent fills over the outlined image.
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
|
|
|
|
|
|
|
|
def process_image_with_refs(image, ref_texts, output_path):
    """Render bounding boxes for *ref_texts* onto a copy of *image*.

    Thin convenience wrapper around draw_bounding_boxes; returns the
    annotated image.
    """
    return draw_bounding_boxes(image, ref_texts, output_path)
|
|
|
|
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the grid shape from *target_ratios* closest to *aspect_ratio*.

    Candidates are (cols, rows) pairs. On an exact tie in aspect-ratio
    distance, a later candidate wins only when the source image area
    (width * height) exceeds half the pixel budget of that candidate's
    grid — i.e. the image is large enough to justify more tiles.
    """
    best = (1, 1)
    best_diff = float('inf')
    area = width * height

    for cand in target_ratios:
        diff = abs(aspect_ratio - cand[0] / cand[1])
        if diff < best_diff:
            best_diff = diff
            best = cand
        elif diff == best_diff and area > 0.5 * image_size * image_size * cand[0] * cand[1]:
            best = cand

    return best
|
|
|
|
|
|
|
|
def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
    """Split *image* into a grid of image_size x image_size tiles.

    The grid shape (cols, rows) is chosen from all pairs whose product
    lies in [min_num, max_num] so that the grid's aspect ratio best
    matches the input image. Returns (tiles, (cols, rows)); when
    *use_thumbnail* is true and more than one tile was produced, a
    whole-image thumbnail is appended.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Enumerate every feasible grid shape once, ordered by tile count.
    candidate_grids = sorted(
        {
            (cols, rows)
            for total in range(min_num, max_num + 1)
            for cols in range(1, total + 1)
            for rows in range(1, total + 1)
            if min_num <= cols * rows <= max_num
        },
        key=lambda grid: grid[0] * grid[1],
    )

    grid = find_closest_aspect_ratio(
        src_ratio,
        candidate_grids,
        src_w,
        src_h,
        image_size,
    )

    # Resize to an exact multiple of the tile size, then slice row-major.
    cols, rows = grid
    canvas_w = image_size * cols
    canvas_h = image_size * rows
    tile_count = cols * rows

    resized = image.resize((canvas_w, canvas_h))
    tiles = []
    tiles_per_row = canvas_w // image_size
    for idx in range(tile_count):
        col = idx % tiles_per_row
        row = idx // tiles_per_row
        tiles.append(
            resized.crop((
                col * image_size,
                row * image_size,
                (col + 1) * image_size,
                (row + 1) * image_size,
            ))
        )

    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles, grid
|
|
|
|
|
|
|
|
def normalize_transform(mean, std):
    """Build a torchvision Normalize transform from optional mean/std.

    A missing mean defaults to zeros and a missing std to ones (matched
    to the other's length). Returns None when both are absent.
    """
    if mean is None and std is None:
        return None
    if mean is None:
        mean = [0.] * len(std)
    elif std is None:
        std = [1.] * len(mean)
    return transforms.Normalize(mean=mean, std=std)
|
|
|
|
|
def format_messages(
    tokenizer: T,
    conversations: List[Dict[str, str]],
    system_prompt: str = "",
):
    """Apply the tokenizer's chat template to *conversations*.

    When *system_prompt* is a non-empty string, it is prepended as a
    system turn before templating.
    """
    if system_prompt is not None and system_prompt != "":
        system_turn = {
            "role": "system",
            "content": system_prompt,
        }
        conversations = [system_turn] + conversations

    return tokenizer.apply_chat_template(conversations)
|
|
|
|
|
|
|
|
def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """Tokenize *text*, optionally wrapping it with BOS/EOS ids.

    A BOS/EOS id is only prepended/appended when the tokenizer actually
    defines one. (Qwen2VL has bos_token_id=None, so BOS is skipped there;
    its chat template handles special tokens itself.)
    """
    ids = tokenizer.encode(text, add_special_tokens=False)

    if bos and tokenizer.bos_token_id is not None:
        ids = [tokenizer.bos_token_id] + ids

    if eos and tokenizer.eos_token_id is not None:
        ids = ids + [tokenizer.eos_token_id]

    return ids
|
|
|
|
|
def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """Load every image referenced by the user turns of *conversations*.

    Content may be a list of typed parts or a single dict; an image part
    carries its path under "image" (preferred) or "data". At most one
    image is collected per message (the last image part wins), and images
    that fail to load are skipped.
    """
    loaded = []

    for message in conversations:
        if message["role"].lower() != "user":
            continue

        content = message["content"]
        image = None

        if isinstance(content, List):
            for part in content:
                if part.get("type", "") == "image":
                    image_path = part.get("image") or part.get("data", "")
                    image = load_image(image_path)
        elif isinstance(content, Dict):
            if content.get("type", "") == "image":
                image_path = content.get("image") or content.get("data", "")
                image = load_image(image_path)

        if image is not None:
            loaded.append(image)

    return loaded
|
|
|
|
|
|
|
|
class BaseTransform(ABC):
    """Interface for callable tensor transforms.

    Subclasses override __call__ to produce a tensor; set_rng is an
    optional hook for seeding stochastic transforms, and default_shape
    must be provided by concrete implementations.
    """

    def set_rng(self, *args, **kwargs):
        # No-op by default; stochastic transforms override this.
        pass

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        pass

    @property
    def default_shape(self):
        raise NotImplementedError
|
|
|
|
|
|
|
|
class BasicImageTransform(BaseTransform):
    """ToTensor followed by optional per-channel normalization."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        self.mean = mean
        self.std = std

        steps = [transforms.ToTensor()]

        # normalize_transform returns None only when mean and std are both
        # None; when normalization is disabled outright an Identity is used.
        norm_step = normalize_transform(mean, std) if normalize else nn.Identity()
        if norm_step is not None:
            steps.append(norm_step)

        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        return self.transform(x)
|
|
|
|
|
class NoEOSTextStreamer(TextStreamer):
    """TextStreamer that prints the EOS token as a newline.

    The literal EOS text (e.g. "<|im_end|>") is replaced with "\n" before
    the chunk is written to stdout.
    """

    def on_finalized_text(self, text: str, stream_end: bool = False):
        eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
        print(text.replace(eos_text, "\n"), flush=True, end="")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, List, Any, Tuple |
|
|
from PIL import Image, ImageOps |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
import io |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DeepQwenDataCollator:
    """
    Data collator for DeepQwen model using Qwen2VL tokenizer.

    This collator processes images using DeepSeek OCR's dynamic cropping algorithm
    while maintaining compatibility with Qwen2VL's tokenization format.

    Key token mappings (Qwen2VL):
    - image_token: <|image_pad|> (id=151655)
    - vision_start: <|vision_start|> (id=151652)
    - vision_end: <|vision_end|> (id=151653)
    - eos_token: <|im_end|> (id=151645)
    - NO bos_token (bos_token_id is None)

    Args:
        tokenizer: Qwen2VL Tokenizer
        model: Model
        image_size: Size for image patches (default: 640)
        base_size: Size for global view (default: 1024)
        crop_mode: Whether to use dynamic cropping for large images
        train_on_responses_only: If True, only train on assistant responses (mask user prompts)
    """
    # NOTE(review): the explicit __init__ below overrides the @dataclass-generated
    # one, so these field declarations act only as class-level annotations.
    tokenizer: T
    model: Any
    image_size: int = 640
    base_size: int = 1024
    crop_mode: bool = True
    train_on_responses_only: bool = True

    def __init__(
        self,
        tokenizer,
        model,
        image_size: int = 640,
        base_size: int = 1024,
        crop_mode: bool = True,
        train_on_responses_only: bool = True,
        max_length: int = None,
    ):
        self.tokenizer = tokenizer
        self.model = model
        self.image_size = image_size
        self.base_size = base_size
        self.crop_mode = crop_mode
        # Image tensors are cast to the model's dtype at transform time.
        self.dtype = model.dtype
        self.train_on_responses_only = train_on_responses_only
        self.max_length = max_length

        # Resolve the image placeholder id from the tokenizer when available;
        # fall back to Qwen2VL's <|image_pad|> id otherwise.
        self.image_token_id = getattr(tokenizer, 'image_token_id', None)
        if self.image_token_id is None:
            self.image_token_id = 151655

        self.image_token = tokenizer.decode([self.image_token_id], skip_special_tokens=False)

        # Vision span delimiters (defaults are Qwen2VL's ids); stored but not
        # referenced elsewhere in this class.
        self.vision_start_token_id = getattr(tokenizer, 'vision_start_token_id', 151652)
        self.vision_end_token_id = getattr(tokenizer, 'vision_end_token_id', 151653)

        self.image_transform = BasicImageTransform(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            normalize=True
        )
        # Vision-encoder geometry used when sizing image token sequences.
        self.patch_size = 16
        self.downsample_ratio = 4

        # Cached special token ids; pad_token_id drives padding and masking.
        self.bos_id = tokenizer.bos_token_id
        self.eos_id = tokenizer.eos_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def deserialize_image(self, image_data) -> Image.Image:
        """Convert image data (bytes dict, PIL Image, or file path) to PIL Image in RGB mode"""
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")
        elif isinstance(image_data, str):
            # Path on disk; load_image applies EXIF orientation and returns
            # None on failure, which is escalated here.
            image = load_image(image_data)
            if image is None:
                raise ValueError(f"Failed to load image from path: {image_data}")
            return image.convert("RGB")
        elif isinstance(image_data, dict) and 'bytes' in image_data:
            # HF-datasets style {'bytes': ...} payload.
            image_bytes = image_data['bytes']
            image = Image.open(io.BytesIO(image_bytes))
            return image.convert("RGB")
        else:
            raise ValueError(f"Unsupported image format: {type(image_data)}")

    def calculate_image_token_count(self, image: Image.Image, crop_ratio: Tuple[int, int]) -> int:
        """Calculate the number of tokens this image will generate"""
        # Query tokens per side after patchify (patch_size) + downsampling.
        num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
        num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)

        width_crop_num, height_crop_num = crop_ratio

        if self.crop_mode:
            img_tokens = num_queries_base * num_queries_base + 1
            if width_crop_num > 1 or height_crop_num > 1:
                img_tokens += (num_queries * width_crop_num + 1) * (num_queries * height_crop_num)
        else:
            img_tokens = num_queries * num_queries + 1

        # NOTE(review): this global-view count (n*n + 1) does not match the
        # sequence actually built in process_image, which emits (n+1)*n + 1
        # tokens for the global view — verify before relying on this method.
        return img_tokens

    def process_image(self, image: Image.Image) -> Tuple[List, List, List, List, Tuple[int, int]]:
        """
        Process a single image based on crop_mode and size thresholds

        Returns:
            Tuple of (images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio)
        """
        images_list = []        # transformed global view tensor(s)
        images_crop_list = []   # transformed local crop tensors
        images_spatial_crop = []  # [cols, rows] grid shape per image

        if self.crop_mode:
            # Small images get no local crops; larger ones are tiled with
            # DeepSeek OCR's dynamic grid.
            if image.size[0] <= 640 and image.size[1] <= 640:
                crop_ratio = (1, 1)
                images_crop_raw = []
            else:
                images_crop_raw, crop_ratio = dynamic_preprocess(
                    image, min_num=2, max_num=9,
                    image_size=self.image_size, use_thumbnail=False
                )

            # Global view: pad to a base_size square using the normalization
            # mean as the pad colour (neutral after normalization).
            global_view = ImageOps.pad(
                image, (self.base_size, self.base_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean)
            )
            images_list.append(self.image_transform(global_view).to(self.dtype))

            width_crop_num, height_crop_num = crop_ratio
            images_spatial_crop.append([width_crop_num, height_crop_num])

            if width_crop_num > 1 or height_crop_num > 1:
                for crop_img in images_crop_raw:
                    images_crop_list.append(
                        self.image_transform(crop_img).to(self.dtype)
                    )

            num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
            num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)

            # Placeholder sequence: one id per query plus one extra id per
            # row, plus a trailing separator id for the global view.
            tokenized_image = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base
            tokenized_image += [self.image_token_id]

            # Local crop placeholders follow the same per-row layout.
            if width_crop_num > 1 or height_crop_num > 1:
                tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
                    num_queries * height_crop_num)

        else:
            crop_ratio = (1, 1)
            images_spatial_crop.append([1, 1])

            # Without crop mode only a single global view is produced; small
            # base sizes use a plain (aspect-distorting) resize, larger ones pad.
            if self.base_size <= 640:
                resized_image = image.resize((self.base_size, self.base_size), Image.LANCZOS)
                images_list.append(self.image_transform(resized_image).to(self.dtype))
            else:
                global_view = ImageOps.pad(
                    image, (self.base_size, self.base_size),
                    color=tuple(int(x * 255) for x in self.image_transform.mean)
                )
                images_list.append(self.image_transform(global_view).to(self.dtype))

            num_queries = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
            tokenized_image = ([self.image_token_id] * num_queries + [self.image_token_id]) * num_queries
            tokenized_image += [self.image_token_id]

        return images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio

    def process_single_sample(self, messages: List[Dict]) -> Dict[str, Any]:
        """
        Process a single conversation into model inputs.

        Expected message format (Qwen2.5-VL native style):
        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": <PIL.Image or path or bytes>},
                    {"type": "text", "text": "Describe this image."}
                ]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "This is a description..."}]
            }
        ]

        Also supports string content for backward compatibility.

        Returns a dict of unpadded per-sample tensors plus prompt_token_count,
        the number of tokens preceding the first assistant turn (used later
        for response-only loss masking).
        """
        tokenized_str = []    # flat token id sequence for the whole sample
        images_seq_mask = []  # True at image placeholder positions
        images_list, images_crop_list, images_spatial_crop = [], [], []

        prompt_token_count = -1
        assistant_started = False

        for message in messages:
            role = message["role"].lower()
            content = message["content"]

            # Record where the first assistant turn begins; everything before
            # it is treated as prompt.
            if role == "assistant":
                if not assistant_started:
                    prompt_token_count = len(tokenized_str)
                    assistant_started = True

            if isinstance(content, list):
                content_parts = []  # NOTE(review): unused; consider removing.

                for item in content:
                    item_type = item.get("type", "")

                    if item_type == "image":
                        image_data = item.get("image") or item.get("data")
                        if image_data is not None:
                            pil_image = self.deserialize_image(image_data)

                            img_list, crop_list, spatial_crop, tok_img, _ = self.process_image(pil_image)

                            images_list.extend(img_list)
                            images_crop_list.extend(crop_list)
                            images_spatial_crop.extend(spatial_crop)

                            tokenized_str.extend(tok_img)
                            images_seq_mask.extend([True] * len(tok_img))

                    elif item_type == "text":
                        text = item.get("text", "")

                        # Append EOS only to the final content part of an
                        # assistant turn so generation learns to stop.
                        if role == "assistant" and item == content[-1]:
                            if self.tokenizer.eos_token:
                                text = f"{text.strip()}{self.tokenizer.eos_token}"

                        tokenized_text = text_encode(self.tokenizer, text, bos=False, eos=False)
                        tokenized_str.extend(tokenized_text)
                        images_seq_mask.extend([False] * len(tokenized_text))

            else:
                # Backward-compatible path: content is a plain string.
                text_content = content

                if role == "assistant" and self.tokenizer.eos_token:
                    text_content = f"{text_content.strip()}{self.tokenizer.eos_token}"

                tokenized_text = text_encode(self.tokenizer, text_content, bos=False, eos=False)
                tokenized_str.extend(tokenized_text)
                images_seq_mask.extend([False] * len(tokenized_text))

        # With no assistant turn, the whole sequence counts as prompt so it
        # is fully masked under train_on_responses_only.
        if not assistant_started:
            print("Warning: No assistant message found in sample. Masking all tokens.")
            prompt_token_count = len(tokenized_str)

        # NOTE(review): torch.stack raises if the sample contained no image;
        # the per-sample try/except in __call__ is what catches that today.
        images_ori = torch.stack(images_list, dim=0)
        images_spatial_crop_tensor = torch.tensor(images_spatial_crop, dtype=torch.long)

        if images_crop_list:
            images_crop = torch.stack(images_crop_list, dim=0)
        else:
            # Placeholder when no local crops were produced.
            # NOTE(review): this placeholder uses base_size while real crops
            # are image_size squares — confirm the model accepts this shape.
            images_crop = torch.zeros((1, 3, self.base_size, self.base_size), dtype=self.dtype)

        return {
            "input_ids": torch.tensor(tokenized_str, dtype=torch.long),
            "images_seq_mask": torch.tensor(images_seq_mask, dtype=torch.bool),
            "images_ori": images_ori,
            "images_crop": images_crop,
            "images_spatial_crop": images_spatial_crop_tensor,
            "prompt_token_count": prompt_token_count,
        }

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate batch of samples.

        Expected feature format:
        {
            "prompt": str,     # The user's question/instruction
            "response": str,   # The assistant's response
            "image": PIL.Image or bytes dict  # The image
        }

        This will be converted to Qwen2.5-VL native conversation format:
        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": <PIL.Image>},
                    {"type": "text", "text": "<prompt>"}
                ]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "<response>"}]
            }
        ]
        """
        batch_data = []

        for feature in features:
            try:
                # Either key may carry the image payload.
                image_data = feature.get('image') or feature.get('image_path')
                if image_data is None:
                    raise ValueError("Sample missing both 'image' and 'image_path' keys")

                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image_data},
                            {"type": "text", "text": feature['prompt']}
                        ]
                    },
                    {
                        "role": "assistant",
                        "content": [
                            {"type": "text", "text": feature["response"]}
                        ]
                    }
                ]

                processed = self.process_single_sample(messages)
                batch_data.append(processed)
            except Exception as e:
                # Bad samples are dropped rather than failing the whole batch.
                print(f"Error processing sample: {e}")
                continue

        if not batch_data:
            raise ValueError("No valid samples in batch")

        input_ids_list = [item['input_ids'] for item in batch_data]
        images_seq_mask_list = [item['images_seq_mask'] for item in batch_data]
        prompt_token_counts = [item['prompt_token_count'] for item in batch_data]

        # Right-pad every sequence to the longest one in the batch.
        input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=self.pad_token_id)
        images_seq_mask = pad_sequence(images_seq_mask_list, batch_first=True, padding_value=False)

        # Truncate from the right when a max length is configured.
        if self.max_length is not None and input_ids.shape[1] > self.max_length:
            input_ids = input_ids[:, :self.max_length]
            images_seq_mask = images_seq_mask[:, :self.max_length]

            prompt_token_counts = [min(p, self.max_length) for p in prompt_token_counts]

        labels = input_ids.clone()

        # No loss on padding positions...
        # NOTE(review): if pad_token_id equals a real token id (e.g. eos used
        # as pad), genuine tokens are masked here too — confirm the tokenizer
        # defines a dedicated pad token.
        labels[labels == self.pad_token_id] = -100

        # ...nor on image placeholder positions.
        labels[images_seq_mask] = -100

        # Mask everything before the first assistant token so only responses
        # contribute to the loss.
        if self.train_on_responses_only:
            for idx, prompt_count in enumerate(prompt_token_counts):
                if prompt_count > 0:
                    labels[idx, :prompt_count] = -100

        attention_mask = (input_ids != self.pad_token_id).long()

        # Per-sample (local crops, global view) tensor pairs for the model.
        images_batch = []
        for item in batch_data:
            images_batch.append((item['images_crop'], item['images_ori']))

        images_spatial_crop = torch.cat([item['images_spatial_crop'] for item in batch_data], dim=0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "images": images_batch,
            "images_seq_mask": images_seq_mask,
            "images_spatial_crop": images_spatial_crop,
        }
|
|
|
|
|
|