"""Dataset, collator and prompt helpers for supervised fine-tuning.

The helpers in this module build chat-formatted prompts that are optionally
augmented with "anchor" pad-token spans, and ``SupervisedDataset`` turns
LLaVA-style conversation records into tokenized training examples.
"""

import bisect
import copy
import glob
import io
import os
import random
import re
from dataclasses import dataclass, field
from typing import Dict

import torch
import transformers
import ujson as json
from torch.utils.data import Dataset
from qwen_vl_utils import process_vision_info
from PIL import Image
from transformers import AutoImageProcessor
import numpy as np
import cv2
from torchvision import transforms
from segment_anything import build_sam_vit_h, sam_model_registry, SamPredictor
from src.anchors.DepthAnything.depth_anything_v2.dpt import DepthAnythingV2
from diffusers import AutoencoderKL
from transformers import AutoModel, CLIPImageProcessor

from .params import DataArguments
from .constants import *


def truncate_sequence(input_ids, labels, max_length, eos_token_id):
    """Clip ``input_ids``/``labels`` to ``max_length`` and optionally append EOS.

    Sequences longer than ``max_length`` are cut to ``max_length - 1`` items.
    When ``eos_token_id`` is not ``None`` it is appended to both tensors
    (whether or not truncation happened).

    Returns:
        The ``(input_ids, labels)`` pair.
    """
    if input_ids.size(0) > max_length:
        input_ids = input_ids[:max_length - 1]
        labels = labels[:max_length - 1]

    if eos_token_id is not None:
        # new_tensor keeps the dtype/device of the sequence being extended.
        input_ids = torch.cat([input_ids, input_ids.new_tensor([eos_token_id])])
        labels = torch.cat([labels, labels.new_tensor([eos_token_id])])

    return input_ids, labels


def pad_sequence(sequences, padding_side='right', padding_value=0):
    """Pad a list of sequences to the same length.

    Args:
        sequences: list of tensors in [seq_len, *] shape.
        padding_side: ``'right'`` or ``'left'``.
        padding_value: fill value for the padded positions.

    Returns:
        A tensor of shape ``(len(sequences), max_len, *)`` created with the
        dtype/device of the first sequence.

    Raises:
        ValueError: if ``padding_side`` is not ``'right'`` or ``'left'``.
    """
    if padding_side not in ('right', 'left'):
        raise ValueError(
            f"padding_side must be 'right' or 'left', got {padding_side!r}"
        )

    trailing_dims = sequences[0].size()[1:]
    max_len = max(len(seq) for seq in sequences)
    output = sequences[0].new_full(
        (len(sequences), max_len) + trailing_dims, padding_value
    )
    for row, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output[row, :length] = seq
        else:
            # An explicit start index stays correct for empty sequences,
            # where a "-0:" slice would select the whole row.
            output[row, max_len - length:] = seq
    return output


def get_image_info(image_path, min_pixel, max_pixel, width, height):
    """Run one image through ``process_vision_info`` and return the result.

    ``width``/``height`` are forwarded as ``resized_width``/``resized_height``
    only when both are set.
    """
    # Using this because of process_vision_info function
    # Need to fix this in the future
    content = {
        "type": "image",
        "image": image_path,
        # Plural keys, matching get_video_info below; the singular spelling
        # is not a key qwen_vl_utils looks up, so the bounds were ignored.
        "min_pixels": min_pixel,
        "max_pixels": max_pixel,
    }
    if width is not None and height is not None:
        content["resized_width"] = width
        content["resized_height"] = height

    messages = [{"role": "user", "content": [content]}]
    image_input, _ = process_vision_info(messages)
    return image_input[0]


def get_video_info(video_path, min_pixels, max_pixels, fps):
    """Run one video through ``process_vision_info``.

    Returns:
        ``(video_input, video_kwargs)`` for the single video.
    """
    # Using this because of process_vision_info function
    # Need to fix this in the future
    video_entry = {
        "type": "video",
        "video": video_path,
        "min_pixels": min_pixels,
        "max_pixels": max_pixels,
        "fps": fps,
    }
    messages = [{"role": "user", "content": [video_entry]}]
    _, video_input, video_kwargs = process_vision_info(
        messages, return_video_kwargs=True
    )
    return video_input[0], video_kwargs


def _build_anchor_pads(anchor_nums, anchor_tokens):
    """Return one ``START + token * num + END`` string per (num, token) pair."""
    return [
        ANCHOR_START_TOKEN + token * num + ANCHOR_END_TOKEN
        for num, token in zip(anchor_nums, anchor_tokens)
    ]


def _causal_connector(index, total):
    """Return the separator that follows clause ``index`` out of ``total``."""
    if index == total - 2:
        return ", and "
    if index == total - 1:
        return ". "
    return ", "


def _causal_clauses(anchor_names, anchor_pads):
    """Join "the <name> of the image is <pad>" clauses with connectors.

    Connectors are chosen by position, so repeated anchor names do not
    confuse the ", and " / ". " placement.
    """
    total = len(anchor_names)
    return "".join(
        f"the {name} of the image is {pad}" + _causal_connector(i, total)
        for i, (name, pad) in enumerate(zip(anchor_names, anchor_pads))
    )


def add_anchor_pad(user_input, anchor_nums, anchor_tokens):
    """Insert all anchor pad spans right after every ``VISION_END_TOKEN``."""
    # add anchor pad after VISION_END_TOKEN or ANCHOR_END_TOKEN
    anchor_pads = "".join(_build_anchor_pads(anchor_nums, anchor_tokens))
    if VISION_END_TOKEN in user_input:
        user_input = user_input.replace(
            VISION_END_TOKEN, VISION_END_TOKEN + anchor_pads
        )
    return user_input


def add_cot_anchor_pad_in_user_input(user_input, anchor_nums, anchor_tokens,
                                     anchor_names=None):
    """Prefix ``user_input`` with a sentence listing every anchor pad span.

    Args:
        user_input: prompt text to prefix.
        anchor_nums: number of pad tokens per anchor.
        anchor_tokens: pad token string per anchor.
        anchor_names: human-readable name per anchor. Optional only for
            backward compatibility of the signature; it is required as soon
            as there is at least one anchor.

    Raises:
        ValueError: if anchors are given without ``anchor_names``.
    """
    if len(anchor_nums) == 0:
        return user_input
    if anchor_names is None:
        raise ValueError(
            "anchor_names is required when anchor_nums is not empty"
        )

    anchor_pads = _build_anchor_pads(anchor_nums, anchor_tokens)
    if len(anchor_pads) == 1:
        cot_pad = f"The {anchor_names[0]} of the image is {anchor_pads[0]}. "
    else:
        last = len(anchor_names) - 1
        pieces = []
        for i, (name, pad) in enumerate(zip(anchor_names, anchor_pads)):
            if i == 0:
                pieces.append(f"The {name} of the image is {pad}, ")
            elif i == last:
                pieces.append(f"and the {name} of the image is {pad}. ")
            else:
                pieces.append(f"the {name} of the image is {pad}, ")
        cot_pad = "".join(pieces)

    return cot_pad + user_input


def get_cot_data_in_response(response, anchor_nums, anchor_tokens, anchor_names):
    """Prefix ``response`` with a "Because the ... is ..." anchor rationale.

    Returns ``response`` unchanged when ``anchor_nums`` is empty.
    """
    if len(anchor_nums) == 0:
        return response

    anchor_pads = _build_anchor_pads(anchor_nums, anchor_tokens)
    return "Because " + _causal_clauses(anchor_names, anchor_pads) + response


def _make_cot_template(name, stem):
    """Build one template dict; each call gets its own connectors dict."""
    return {
        "name": name,
        "single": stem + ". ",
        "multiple": stem + "{connector}",
        "connectors": {
            "middle": ", ",
            "second_last": ", and ",
            "last": ". ",
        },
    }


COT_TEMPLATES = [
    _make_cot_template(
        "basic_causal",
        "Because the {anchor_name} of the image is {anchor_pad}",
    ),
    _make_cot_template(
        "observational",
        "I can observe that the {anchor_name} of the image is {anchor_pad}",
    ),
    _make_cot_template(
        "analytical",
        "After analyzing the image, the {anchor_name} is {anchor_pad}",
    ),
    _make_cot_template(
        "descriptive",
        "The image shows that the {anchor_name} is {anchor_pad}",
    ),
    _make_cot_template(
        "conditional",
        "Given that the {anchor_name} of the image is {anchor_pad}",
    ),
    _make_cot_template(
        "evidence_based",
        "Based on the visual evidence, the {anchor_name} of the image is {anchor_pad}",
    ),
]


def get_random_cot_template():
    """Pick one entry of ``COT_TEMPLATES`` uniformly at random."""
    return random.choice(COT_TEMPLATES)


def apply_cot_template(template, anchor_names, anchor_pads):
    """Render ``template`` for the given anchors.

    One anchor uses the ``"single"`` pattern; otherwise every anchor is
    rendered with the ``"multiple"`` pattern followed by a position-dependent
    connector. An empty ``anchor_names`` yields an empty string.
    """
    if len(anchor_names) == 1:
        return template["single"].format(
            anchor_name=anchor_names[0],
            anchor_pad=anchor_pads[0],
        )

    connectors = template["connectors"]
    total = len(anchor_names)
    rendered = []
    for i, (name, pad) in enumerate(zip(anchor_names, anchor_pads)):
        if i == total - 1:
            connector = connectors["last"]
        elif i == total - 2:
            connector = connectors["second_last"]
        else:
            connector = connectors["middle"]
        rendered.append(
            template["multiple"].format(
                anchor_name=name, anchor_pad=pad, connector=connector
            )
        )
    return "".join(rendered)


def get_templates_comt_data_in_response(response, anchor_nums, anchor_tokens,
                                        anchor_names):
    """Prefix ``response`` with a randomly templated anchor rationale.

    Returns ``response`` unchanged when ``anchor_nums`` is empty.
    """
    if len(anchor_nums) == 0:
        return response

    anchor_pads = _build_anchor_pads(anchor_nums, anchor_tokens)
    cot_text = apply_cot_template(
        get_random_cot_template(), anchor_names, anchor_pads
    )
    # NOTE(review): the empty-string wrappers are kept exactly as found; they
    # look like delimiters whose text is missing from this file - confirm.
    response = "" + cot_text + "" + "" + response + ""
    return response


def get_comt_data_in_response(response, anchor_nums, anchor_tokens, anchor_names):
    """Wrap ``response`` with a leading " Because ..." anchor rationale.

    Returns ``response`` unchanged when ``anchor_nums`` is empty.
    """
    if len(anchor_nums) == 0:
        return response

    anchor_pads = _build_anchor_pads(anchor_nums, anchor_tokens)
    cot_start = " Because " + _causal_clauses(anchor_names, anchor_pads)
    return cot_start + " \n" + " " + response + " "


def get_feature_data(user_input, gpt_response, anchor_nums, anchor_tokens,
                     anchor_names):
    """Build a "What is the <names> of the image?" prompt/answer pair.

    Only the ``role`` fields of ``user_input``/``gpt_response`` are used; the
    answer consists of the concatenated anchor pad spans.

    Returns:
        ``(prompt_text, response_text)``.
    """
    # Zipping all three lists truncates to the shortest, as before.
    anchor_pads = "".join(
        ANCHOR_START_TOKEN + token * num + ANCHOR_END_TOKEN
        for num, token, _ in zip(anchor_nums, anchor_tokens, anchor_names)
    )
    anchor_name = ", ".join(anchor_names)

    prompt = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{VISION_START_TOKEN + DEFAULT_IMAGE_TOKEN + VISION_END_TOKEN}What is the {anchor_name} of the image?\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
    answer = f"{anchor_pads}\n{DEFAULT_IM_END_TOKEN}\n"
    return prompt, answer


def replace_pad_with_anchor_tokens(gpt_response):
    """Replace placeholder strings in ``gpt_response`` with pad-token runs."""
    # NOTE(review): every key below is the empty string, so the literal
    # collapses to a single entry. The placeholder names appear to be missing
    # from this file - restore them before relying on this function.
    token_dict = {
        "": SAM_PAD_TOKEN * 8,
        "": DEPTH_PAD_TOKEN * 4,
        "": DINO_PAD_TOKEN * 4,
        "": PIDINET_PAD_TOKEN * 4,
        "": SIGLIP_PAD_TOKEN * 4,
        "": METACLIP_PAD_TOKEN * 4,
    }
    for token, anchor_token in token_dict.items():
        if not token:
            # str.replace("", x) would insert x between every character.
            continue
        gpt_response = gpt_response.replace(token, anchor_token)
    return gpt_response


# Number of pad tokens emitted per anchor model id.
_ANCHOR_TOKEN_COUNTS = {
    "sam": 8,
    "dino": 4,
    "depth": 4,
    "SD": 4,
    "InternViT": 4,
    "pidinet": 4,
    "siglip": 4,
    "metaclip": 4,
}

# Task name used in prompts per anchor model id.
_ANCHOR_TASK_NAMES = {
    "sam": "segmentation",
    "dino": "perception feature",
    "depth": "depth map",
    "SD": "style",
    "InternViT": "caption",
    "pidinet": "edge map",
    "siglip": "clip feature",
    "metaclip": "metaclip feature",
}


def get_token_num(anchor_model_id):
    """Return the pad-token count for each known id in ``anchor_model_id``.

    Unknown ids are skipped; ``None`` is treated as an empty list.
    """
    return [
        _ANCHOR_TOKEN_COUNTS[model]
        for model in (anchor_model_id or [])
        if model in _ANCHOR_TOKEN_COUNTS
    ]


def get_anchor_token(anchor_model_id):
    """Return the pad token for each known id in ``anchor_model_id``.

    Unknown ids are skipped; ``None`` is treated as an empty list.
    """
    anchor_tokens = []
    # The constants are looked up branch by branch so that only the tokens of
    # the requested models have to exist.
    for anchor_model in (anchor_model_id or []):
        if anchor_model == "sam":
            anchor_tokens.append(SAM_PAD_TOKEN)
        elif anchor_model == "dino":
            anchor_tokens.append(DINO_PAD_TOKEN)
        elif anchor_model == "depth":
            anchor_tokens.append(DEPTH_PAD_TOKEN)
        elif anchor_model == "SD":
            anchor_tokens.append(SD_PAD_TOKEN)
        elif anchor_model == "InternViT":
            anchor_tokens.append(INTERN_PAD_TOKEN)
        elif anchor_model == "pidinet":
            anchor_tokens.append(PIDINET_PAD_TOKEN)
        elif anchor_model == "siglip":
            anchor_tokens.append(SIGLIP_PAD_TOKEN)
        elif anchor_model == "metaclip":
            anchor_tokens.append(METACLIP_PAD_TOKEN)
    return anchor_tokens


def get_anchor_task_name(anchor_model_id):
    """Return the task name for each known id in ``anchor_model_id``.

    Unknown ids are skipped; ``None`` is treated as an empty list.
    """
    return [
        _ANCHOR_TASK_NAMES[model]
        for model in (anchor_model_id or [])
        if model in _ANCHOR_TASK_NAMES
    ]


def _resolve_media_path(path, folder):
    """Join ``path`` onto ``folder`` unless it exists, is a URL, or no folder is set."""
    if os.path.exists(path) or path.startswith("http") or not folder:
        return path
    return os.path.join(folder, path)


def _chat_prompt(user_input, gpt_response):
    """Format the user turn followed by the opening of the assistant turn."""
    return f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"


# Side length of the black placeholder image when no resize size is configured.
_PLACEHOLDER_IMAGE_SIDE = 224


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Records come from a JSON file, an in-memory sequence, or a directory of
    parquet files that is read row by row. How a conversation is rendered
    depends on ``cur_step`` relative to the ``stage_*_step`` thresholds in
    ``data_args`` (see ``_format_turn``).
    """

    def __init__(
        self,
        data_path: str | list,
        processor: transformers.ProcessorMixin,
        data_args: DataArguments,
        model_id,
        padding=True,
        shuffle=True,
        random_seed=42,
        anchor_model_id=None,
    ):
        """Index the data source and cache the configuration.

        Args:
            data_path: a directory containing ``*.parquet`` files (searched
                recursively), a JSON file path, or an already loaded
                sequence of records.
            processor: processor used to tokenize text and encode visuals.
            data_args: pixel limits, resize size, fps, media folder and the
                stage step thresholds.
            model_id: model identifier; checked for ``"Qwen2.5"`` on video.
            padding: stored on the instance.
            shuffle: shuffle records once, seeded with ``random_seed``. A
                plain list passed as ``data_path`` is shuffled in place.
            random_seed: seed for the shuffle.
            anchor_model_id: anchor model ids; ``None`` means no anchors.

        Raises:
            FileNotFoundError: if a directory contains no parquet files.
        """
        super(SupervisedDataset, self).__init__()

        self._is_parquet = False
        self._shuffle_perm = None
        list_data_dict = data_path
        if isinstance(data_path, str):
            if os.path.isdir(data_path):
                self._init_parquet_index(data_path)
                list_data_dict = None
            else:
                # Context manager so the file handle is closed after loading.
                with open(data_path, "r") as handle:
                    list_data_dict = json.load(handle)

        self.model_id = model_id
        self.processor = processor
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.padding = padding
        self.image_min_pixel = data_args.image_min_pixels
        self.image_max_pixel = data_args.image_max_pixels
        self.image_resized_w = data_args.image_resized_width
        self.image_resized_h = data_args.image_resized_height
        self.video_min_pixel = data_args.video_min_pixels
        self.video_max_pixel = data_args.video_max_pixels
        self.fps = data_args.fps

        self.anchor_model_id = anchor_model_id
        self.anchor_token_nums = get_token_num(anchor_model_id)
        self.anchor_tokens = get_anchor_token(anchor_model_id)
        self.anchor_task_names = get_anchor_task_name(anchor_model_id)

        self.cur_step = 0
        self.stage_0_step = data_args.stage_0_step
        self.stage_1_step = data_args.stage_1_step
        self.stage_2_step = data_args.stage_2_step

        # for shuffle
        self.rng = np.random.default_rng(seed=random_seed)
        if shuffle:
            self._shuffle_records(random_seed)

    def _init_parquet_index(self, data_path):
        """Record the parquet files under ``data_path`` and their cumulative row counts."""
        # Parquet streaming: no HF cache, reads rows on-demand via pyarrow
        import pyarrow.parquet as pq

        parquet_files = sorted(
            glob.glob(os.path.join(data_path, "**", "*.parquet"), recursive=True)
        )
        if not parquet_files:
            raise FileNotFoundError(f"No parquet files found under {data_path}")
        print(f"[DataLoader] Found {len(parquet_files)} parquet files (streaming mode)")

        cumulative_rows = []
        total = 0
        for path in parquet_files:
            total += pq.ParquetFile(path).metadata.num_rows
            cumulative_rows.append(total)

        self._is_parquet = True
        self._parquet_files = parquet_files
        self._cumulative_rows = cumulative_rows
        self._total_rows = total
        print(f"[DataLoader] Total parquet rows: {total}")
        # Both caches are filled lazily by _get_parquet_row.
        self._pf_handles = {}
        self._row_group_index = {}

    def _shuffle_records(self, random_seed):
        """Shuffle once: a permutation for parquet, in place otherwise."""
        if self._is_parquet:
            self._shuffle_perm = np.random.RandomState(random_seed).permutation(
                self._total_rows
            )
            return

        try:
            from datasets import Dataset as _HFDataset
        except ImportError:
            # `datasets` is only needed to recognise HF datasets.
            _HFDataset = None

        if _HFDataset is not None and isinstance(self.list_data_dict, _HFDataset):
            self.list_data_dict = self.list_data_dict.shuffle(seed=random_seed)
        else:
            self.rng.shuffle(self.list_data_dict)

    def set_cur_step(self, step: int):
        """Set the step counter that selects the prompt-formatting stage."""
        self.cur_step = step
        print(f"[Dataset] cur_step has been set to {step}")

    def __len__(self):
        """Return the number of records (total parquet rows in parquet mode)."""
        if self._is_parquet:
            return self._total_rows
        return len(self.list_data_dict)

    def _get_parquet_row(self, idx):
        """Return row ``idx`` (after shuffling) as a dict, decoding its image.

        The row group containing the row is read in full and sliced. An
        ``image`` value given as a dict with ``bytes`` or ``path`` is
        replaced by an RGB ``PIL.Image``.
        """
        import pyarrow.parquet as pq

        real_idx = self._shuffle_perm[idx] if self._shuffle_perm is not None else idx
        real_idx = int(real_idx)  # plain int instead of a numpy scalar

        # bisect to find file
        file_pos = bisect.bisect_right(self._cumulative_rows, real_idx)
        local_idx = real_idx
        if file_pos > 0:
            local_idx = real_idx - self._cumulative_rows[file_pos - 1]
        fpath = self._parquet_files[file_pos]

        # Open ParquetFile lazily and cache handle
        if fpath not in self._pf_handles:
            self._pf_handles[fpath] = pq.ParquetFile(fpath)
        pf = self._pf_handles[fpath]

        # Build row-group cumulative table once per file
        if fpath not in self._row_group_index:
            bounds = []
            running = 0
            for rg in range(pf.num_row_groups):
                running += pf.metadata.row_group(rg).num_rows
                bounds.append(running)
            self._row_group_index[fpath] = bounds
        bounds = self._row_group_index[fpath]

        # Find which row group contains local_idx
        rg_idx = bisect.bisect_right(bounds, local_idx)
        rg_start = 0 if rg_idx == 0 else bounds[rg_idx - 1]

        # Read just this row group rather than the whole file
        table = pf.read_row_group(rg_idx)
        row = table.slice(local_idx - rg_start, 1).to_pylist()[0]

        # Image decode
        image_data = row.get('image')
        if isinstance(image_data, dict):
            if image_data.get('bytes'):
                row['image'] = Image.open(io.BytesIO(image_data['bytes'])).convert('RGB')
            elif image_data.get("path"):
                row["image"] = Image.open(image_data["path"]).convert("RGB")
        return row

    def _open_image(self, item):
        """Turn a PIL image, raw bytes, or path/file object into an RGB image."""
        if isinstance(item, Image.Image):
            return item.convert("RGB")
        if isinstance(item, (bytes, bytearray)):
            return Image.open(io.BytesIO(item)).convert("RGB")
        if isinstance(item, str):
            # Same folder fallback the video branch applies to its paths.
            item = _resolve_media_path(item, self.data_args.image_folder)
        return Image.open(item).convert("RGB")

    def _open_images(self, raw):
        """Normalise the record's ``image`` field into a list of RGB images."""
        if isinstance(raw, (Image.Image, str, bytes, bytearray)):
            raw = [raw]
        return [self._open_image(item) for item in raw]

    def _load_videos(self, raw):
        """Load every video of the record; returns ``(videos, video_kwargs)``.

        ``video_kwargs`` are those of the last video, or ``{}`` if none.
        """
        video_files = [raw] if isinstance(raw, str) else raw
        videos = []
        video_kwargs = {}
        for video_file in video_files:
            video_file = _resolve_media_path(video_file, self.data_args.image_folder)
            video_input, video_kwargs = get_video_info(
                video_file, self.video_min_pixel, self.video_max_pixel,
                self.data_args.fps,
            )
            videos.append(video_input)
        return videos, video_kwargs

    def _placeholder_image(self):
        """Return a processed all-black image for records without an image."""
        width = self.image_resized_w or _PLACEHOLDER_IMAGE_SIDE
        height = self.image_resized_h or _PLACEHOLDER_IMAGE_SIDE
        black_image = Image.new("RGB", (width, height), (0, 0, 0))
        return get_image_info(
            black_image, self.image_min_pixel, self.image_max_pixel,
            self.image_resized_w, self.image_resized_h,
        )

    def _sample_anchor_subset(self):
        """Pick a random non-empty, order-preserving subset of the anchors."""
        total = len(self.anchor_tokens)
        if total == 0:
            return [], [], []
        count = random.randint(1, total)
        idxs = sorted(random.sample(range(total), count))
        return (
            [self.anchor_token_nums[k] for k in idxs],
            [self.anchor_tokens[k] for k in idxs],
            [self.anchor_task_names[k] for k in idxs],
        )

    def _format_turn(self, user_input, gpt_response):
        """Render one user/assistant pair for the current training stage.

        * ``cur_step < stage_0_step``: plain chat text, anchor pads inserted
          after the vision end token of the prompt.
        * ``cur_step < stage_1_step``: the feature question/answer pair.
        * ``cur_step < stage_2_step``: plain prompt; for image prompts the
          answer is wrapped with a rationale that uses all anchors.
        * otherwise: with probability 1/6 plain chat text, else as the
          previous stage but with a random subset of the anchors.

        Returns:
            ``(prompt_text, response_text)``.
        """
        if self.stage_0_step <= self.cur_step < self.stage_1_step:
            return get_feature_data(
                user_input, gpt_response, self.anchor_token_nums,
                self.anchor_tokens, self.anchor_task_names,
            )

        prompt = _chat_prompt(user_input, gpt_response)
        answer = f"{gpt_response['content']}"
        terminator = f"\n{DEFAULT_IM_END_TOKEN}\n"

        if self.cur_step < self.stage_0_step:
            prompt = add_anchor_pad(prompt, self.anchor_token_nums, self.anchor_tokens)
            return prompt, answer + terminator

        if self.cur_step < self.stage_2_step:
            if DEFAULT_IMAGE_TOKEN in prompt:
                answer = get_comt_data_in_response(
                    answer, self.anchor_token_nums, self.anchor_tokens,
                    self.anchor_task_names,
                )
            return prompt, answer + terminator

        # Final stage: one draw in six keeps the answer free of any rationale.
        if random.randint(0, 5) == 0:
            return prompt, answer + terminator
        if DEFAULT_IMAGE_TOKEN in prompt:
            nums, tokens, names = self._sample_anchor_subset()
            answer = get_comt_data_in_response(answer, nums, tokens, names)
        return prompt, answer + terminator

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the training example for record ``i``.

        Only the first user/assistant pair of the conversation is used.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and ``labels``
            (prompt positions set to ``IGNORE_INDEX``), plus the pixel/grid
            tensors, ``image_files`` and ``second_per_grid_ts`` when they
            are available.

        Raises:
            ValueError: if the user turn still contains the raw LLaVA image
                placeholder after conversion.
        """
        sources = self._get_parquet_row(i) if self._is_parquet else self.list_data_dict[i]

        processor = self.processor
        is_video = False
        image_files = None  # stays None for video / text-only records
        images = None
        videos = None
        video_kwargs = {}
        grid_key = None
        pixel_key = None

        if "image" in sources:
            grid_key = "image_grid_thw"
            pixel_key = "pixel_values"
            image_files = self._open_images(sources["image"])
            images = [
                get_image_info(
                    image, self.image_min_pixel, self.image_max_pixel,
                    self.image_resized_w, self.image_resized_h,
                )
                for image in image_files
            ]
        elif "video" in sources:
            is_video = True
            grid_key = "video_grid_thw"
            pixel_key = "pixel_values_videos"
            videos, video_kwargs = self._load_videos(sources["video"])

        if not images:
            # Also taken for video records, exactly as before: the feature
            # stage puts an image token into every prompt, so an image has
            # to be available to the processor.
            print("No image or video found in the data.")
            images = [self._placeholder_image()]

        # llava_to_openai builds fresh dicts, so no defensive copy is needed.
        sources = llava_to_openai(sources['conversations'], is_video=is_video)

        all_input_ids = []
        all_labels = []
        all_pixel_values = []
        all_image_grid_thw = []
        all_second_gird = []

        # Qwen2-VL uses a default system message so I've added this.
        if len(SYSTEM_MESSAGE) > 0:
            system_message = f"{DEFAULT_IM_START_TOKEN}system\n{SYSTEM_MESSAGE}\n{DEFAULT_IM_END_TOKEN}\n"
            system_ids = processor.tokenizer(
                system_message, add_special_tokens=False, return_tensors='pt'
            )['input_ids']
            all_input_ids.append(system_ids.squeeze(0))
            all_labels.append(torch.full_like(system_ids, IGNORE_INDEX).squeeze(0))

        # Only the first (user, assistant) pair is consumed.
        for j in range(0, min(len(sources), 2), 2):
            user_turn = sources[j]
            gpt_turn = sources[j + 1]
            content = user_turn['content']

            if (
                DEFAULT_IMAGE_TOKEN not in content
                and DEFAULT_VIDEO_TOKEN not in content
                and LLAVA_IMAGE_TOKEN in content
            ):
                raise ValueError(
                    f"User turn still contains the raw {LLAVA_IMAGE_TOKEN!r} "
                    "placeholder after conversion; this input is not supported."
                )

            user_input, gpt_response = self._format_turn(user_turn, gpt_turn)

            # NOTE(review): pixel_key/grid_key are None for records without
            # an image or video; if such a record reaches the image-token
            # branch (feature stage) the lookups below fail - confirm intent.
            if DEFAULT_IMAGE_TOKEN in user_input:
                inputs = processor(text=[user_input], images=images, videos=videos,
                                   padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])
                torch.cuda.empty_cache()
            elif DEFAULT_VIDEO_TOKEN in user_input:
                if "Qwen2.5" in self.model_id:
                    inputs = processor(text=[user_input], images=images, videos=videos,
                                       padding=False, return_tensors='pt', **video_kwargs)
                    all_second_gird.extend(inputs["second_per_grid_ts"])
                else:
                    inputs = processor(text=[user_input], images=images, videos=videos,
                                       padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])
            else:
                prompt_input_ids = processor.tokenizer(
                    user_input, add_special_tokens=False, padding=False,
                    return_tensors='pt',
                )['input_ids']

            response_input_ids = processor.tokenizer(
                gpt_response, add_special_tokens=False, padding=False,
                return_tensors='pt',
            )['input_ids']

            all_input_ids.append(
                torch.cat([prompt_input_ids, response_input_ids], dim=1).squeeze(0)
            )
            # Loss is computed on the response only.
            all_labels.append(
                torch.cat(
                    [
                        torch.full((prompt_input_ids.size(1),), IGNORE_INDEX,
                                   dtype=torch.long),
                        response_input_ids.squeeze(0),
                    ],
                    dim=0,
                )
            )

        # There is no need for eos or bos tokens in the input_ids
        # Qwen2-VL does not use them
        input_ids = torch.cat(all_input_ids, dim=0).to(torch.long)
        labels = torch.cat(all_labels, dim=0).to(torch.long)

        # Nothing is padded at this point, so every position is attended to.
        attention_mask = torch.ones_like(input_ids)

        data_dict = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        # Skip when the prompt held no visual token and nothing was collected.
        if pixel_key and grid_key and all_pixel_values:
            data_dict[pixel_key] = torch.cat(all_pixel_values, dim=0)
            data_dict[grid_key] = torch.cat(all_image_grid_thw, dim=0)

        if image_files is not None:
            data_dict["image_files"] = image_files

        if len(all_second_gird) > 0:
            data_dict["second_per_grid_ts"] = all_second_gird

        self.cur_step += 1
        return data_dict


class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    def __init__(self, pad_token_id: int):
        """Store the token id used to right-pad ``input_ids``."""
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        """Pad and merge a list of dataset examples into one batch dict.

        ``input_ids`` are right-padded with ``pad_token_id`` and ``labels``
        with ``IGNORE_INDEX``; the attention mask marks every position that
        differs from ``pad_token_id``. Visual tensors are concatenated along
        dim 0, while ``image_files`` and ``second_per_grid_ts`` are gathered
        into lists.
        """
        batch_input_ids = []
        batch_label_ids = []
        batch_pixel_values = []
        batch_pixel_video_values = []
        batch_video_thw = []
        batch_image_thw = []
        batch_second_per_grid_ts = []
        batch_image_files = []

        for example in examples:
            if "pixel_values_videos" in example:
                batch_pixel_video_values.append(example["pixel_values_videos"])
                batch_video_thw.append(example["video_grid_thw"])
            elif "pixel_values" in example:
                batch_pixel_values.append(example["pixel_values"])
                batch_image_thw.append(example["image_grid_thw"])

            if "image_files" in example:
                batch_image_files.append(example["image_files"])

            batch_input_ids.append(example["input_ids"])
            batch_label_ids.append(example["labels"])

            if "second_per_grid_ts" in example:
                batch_second_per_grid_ts.extend(example["second_per_grid_ts"])

        input_ids = pad_sequence(
            batch_input_ids, padding_side='right', padding_value=self.pad_token_id
        )
        labels = pad_sequence(
            batch_label_ids, padding_side='right', padding_value=IGNORE_INDEX
        )

        data_dict = {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': input_ids != self.pad_token_id,
        }

        if len(batch_pixel_values) > 0:
            data_dict["pixel_values"] = torch.cat(batch_pixel_values, dim=0)
            data_dict["image_grid_thw"] = torch.cat(batch_image_thw, dim=0)

        if len(batch_pixel_video_values) > 0:
            data_dict["pixel_values_videos"] = torch.cat(batch_pixel_video_values, dim=0)
            data_dict["video_grid_thw"] = torch.cat(batch_video_thw, dim=0)

        if len(batch_second_per_grid_ts) > 0:
            data_dict["second_per_grid_ts"] = batch_second_per_grid_ts

        if len(batch_image_files) > 0:
            data_dict["image_files"] = batch_image_files

        return data_dict


def replace_image_tokens(input_string, is_video=False):
    """Swap the LLaVA image/video placeholder for the wrapped vision token.

    One newline directly before and/or after the placeholder is consumed.
    """
    if is_video:
        source_token = LLAVA_VIDEO_TOKEN
        replacement = VISION_START_TOKEN + DEFAULT_VIDEO_TOKEN + VISION_END_TOKEN
    else:
        source_token = LLAVA_IMAGE_TOKEN
        replacement = VISION_START_TOKEN + DEFAULT_IMAGE_TOKEN + VISION_END_TOKEN

    pattern = r'\n?' + re.escape(source_token) + r'\n?'
    return re.sub(pattern, replacement, input_string)


def llava_to_openai(conversations, is_video=False):
    """Convert LLaVA ``from``/``value`` turns into ``role``/``content`` dicts.

    ``human`` becomes ``user`` and ``gpt`` becomes ``assistant``; any other
    speaker name is kept as is.
    """
    role_mapping = {"human": "user", "gpt": "assistant"}
    return [
        {
            "role": role_mapping.get(turn["from"], turn["from"]),
            "content": replace_image_tokens(turn["value"], is_video=is_video),
        }
        for turn in conversations
    ]


def make_supervised_data_module(model_id, processor, data_args, anchor_model_id):
    """Make dataset and collator for supervised fine-tuning."""
    sft_dataset = SupervisedDataset(
        data_path=data_args.data_path,
        processor=processor,
        data_args=data_args,
        model_id=model_id,
        anchor_model_id=anchor_model_id,
    )
    data_collator = DataCollatorForSupervisedDataset(
        pad_token_id=processor.tokenizer.pad_token_id
    )

    return dict(
        train_dataset=sft_dataset,
        eval_dataset=None,
        data_collator=data_collator,
    )