| import copy |
| import os |
| from dataclasses import dataclass, field |
| from typing import Dict |
| import torch |
| import transformers |
| import ujson as json |
| from torch.utils.data import Dataset |
| from qwen_vl_utils import process_vision_info |
| from PIL import Image |
| from transformers import AutoImageProcessor |
| import re |
| import numpy as np |
| import cv2 |
| from torchvision import transforms |
| import random |
|
|
| from segment_anything import build_sam_vit_h, sam_model_registry, SamPredictor |
| from src.anchors.DepthAnything.depth_anything_v2.dpt import DepthAnythingV2 |
| from diffusers import AutoencoderKL |
| from transformers import AutoModel, CLIPImageProcessor |
|
|
| from .params import DataArguments |
| from .constants import * |
|
|
|
|
def truncate_sequence(input_ids, labels, max_length, eos_token_id):
    """Clip a token/label pair to fit ``max_length`` and optionally append EOS.

    Args:
        input_ids: 1-D tensor of token ids.
        labels: 1-D tensor of label ids aligned with ``input_ids``.
        max_length: Upper bound on the length of the returned tensors.
        eos_token_id: Id appended to both tensors, or ``None`` to append
            nothing.

    Returns:
        Tuple ``(input_ids, labels)``.
    """
    # When EOS is appended one slot must stay free, so a sequence that is
    # exactly ``max_length`` long has to be cut as well; otherwise the result
    # would be ``max_length + 1`` tokens.
    limit = max_length if eos_token_id is None else max_length - 1
    if input_ids.size(0) > limit:
        keep = max(max_length - 1, 0)
        input_ids = input_ids[:keep]
        labels = labels[:keep]

    if eos_token_id is not None:
        # Build the EOS tensors on the inputs' devices so torch.cat also works
        # for non-CPU tensors.
        input_ids = torch.cat(
            [input_ids, torch.tensor([eos_token_id], device=input_ids.device)]
        )
        labels = torch.cat(
            [labels, torch.tensor([eos_token_id], device=labels.device)]
        )

    return input_ids, labels
|
|
def pad_sequence(sequences, padding_side='right', padding_value=0):
    """Pad a list of sequences to the same length.

    Args:
        sequences: Non-empty list of tensors shaped ``[seq_len, *]``; the
            trailing dimensions are taken from the first tensor.
        padding_side: ``'right'`` or ``'left'``.
        padding_value: Fill value for the padded positions.

    Returns:
        Tensor of shape ``[len(sequences), max_seq_len, *]`` with the dtype
        and device of the first sequence.

    Raises:
        ValueError: If ``padding_side`` is not ``'right'`` or ``'left'``.
    """
    if padding_side not in ('right', 'left'):
        raise ValueError(
            f"padding_side must be 'right' or 'left', got {padding_side!r}"
        )

    first = sequences[0]
    trailing_dims = tuple(first.size()[1:])
    max_len = max(seq.size(0) for seq in sequences)
    output = first.new_full(
        (len(sequences), max_len) + trailing_dims, padding_value
    )

    for row, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[row, :length] = seq
        else:
            # Use an explicit start offset: ``-length:`` degenerates to ``0:``
            # for an empty sequence and would address the whole row.
            output.data[row, max_len - length:] = seq
    return output
|
|
def get_image_info(image_path, min_pixel, max_pixel, width, height):
    """Load and resize one image through ``process_vision_info``.

    Args:
        image_path: Value placed under the ``"image"`` key of the message
            (callers in this file pass a PIL image).
        min_pixel: Lower pixel budget, or ``None`` to use the library default.
        max_pixel: Upper pixel budget, or ``None`` to use the library default.
        width: Explicit resize width; only used when ``height`` is also given.
        height: Explicit resize height; only used when ``width`` is also given.

    Returns:
        The first image returned by ``process_vision_info``.
    """
    content = {
        "type": "image",
        "image": image_path,
    }

    # qwen_vl_utils reads the plural keys ``min_pixels`` / ``max_pixels`` (the
    # same keys get_video_info uses); the singular spelling is ignored and the
    # configured budget would never be applied.
    if min_pixel is not None:
        content["min_pixels"] = min_pixel
    if max_pixel is not None:
        content["max_pixels"] = max_pixel

    if width is not None and height is not None:
        content["resized_width"] = width
        content["resized_height"] = height

    messages = [
        {"role": "user", "content": [content]},
    ]

    image_input, _ = process_vision_info(messages)
    return image_input[0]
|
|
def get_video_info(video_path, min_pixels, max_pixels, fps):
    """Load one video through ``process_vision_info``.

    Returns:
        Tuple ``(video, video_kwargs)``: the first video returned by
        ``process_vision_info`` and the extra kwargs it reports.
    """
    video_entry = {
        "type": "video",
        "video": video_path,
        "min_pixels": min_pixels,
        "max_pixels": max_pixels,
        "fps": fps,
    }
    conversation = [{"role": "user", "content": [video_entry]}]

    # Image outputs are not needed here, only the video part.
    _, loaded_videos, extra_kwargs = process_vision_info(
        conversation, return_video_kwargs=True
    )
    return loaded_videos[0], extra_kwargs
|
|
def add_anchor_pad(user_input, anchor_nums, anchor_tokens):
    """Insert anchor placeholder spans after every vision-end token.

    For each ``(count, token)`` pair a span
    ``ANCHOR_START_TOKEN + token * count + ANCHOR_END_TOKEN`` is built; the
    concatenated spans are placed right after each ``VISION_END_TOKEN`` in
    ``user_input``. Text without that token is returned unchanged.
    """
    spans = [
        ANCHOR_START_TOKEN + token * count + ANCHOR_END_TOKEN
        for count, token in zip(anchor_nums, anchor_tokens)
    ]
    anchor_block = "".join(spans)
    if VISION_END_TOKEN not in user_input:
        return user_input
    return user_input.replace(VISION_END_TOKEN, VISION_END_TOKEN + anchor_block)
|
|
def add_cot_anchor_pad_in_user_input(user_input, anchor_nums, anchor_tokens, anchor_names=None):
    """Prefix ``user_input`` with a sentence listing each anchor's placeholder span.

    Args:
        user_input: Prompt text to extend.
        anchor_nums: Number of placeholder tokens per anchor.
        anchor_tokens: Placeholder token string per anchor.
        anchor_names: Human-readable name per anchor, aligned with
            ``anchor_nums``. Required whenever ``anchor_nums`` is non-empty.

    Returns:
        ``user_input`` unchanged when there are no anchors, otherwise the
        descriptive sentence followed by ``user_input``.

    Raises:
        ValueError: If anchors are given without ``anchor_names``.
    """
    if len(anchor_nums) == 0:
        return user_input
    if anchor_names is None:
        # The names are part of the generated sentence, so there is nothing
        # sensible to emit without them.
        raise ValueError("anchor_names is required when anchor_nums is not empty")

    anchor_pads = [
        ANCHOR_START_TOKEN + anchor_token * anchor_num + ANCHOR_END_TOKEN
        for anchor_num, anchor_token in zip(anchor_nums, anchor_tokens)
    ]

    if len(anchor_pads) == 1:
        cot_pad = f"The {anchor_names[0]} of the image is {anchor_pads[0]}. "
    else:
        pieces = []
        last = len(anchor_names) - 1
        for i, (anchor_name, anchor_pad) in enumerate(zip(anchor_names, anchor_pads)):
            if i == 0:
                pieces.append(f"The {anchor_name} of the image is {anchor_pad}, ")
            elif i == last:
                pieces.append(f"and the {anchor_name} of the image is {anchor_pad}. ")
            else:
                pieces.append(f"the {anchor_name} of the image is {anchor_pad}, ")
        cot_pad = "".join(pieces)
    return cot_pad + user_input
|
|
def get_cot_data_in_response(response, anchor_nums, anchor_tokens, anchor_names):
    """Prefix ``response`` with a "Because ..." sentence listing anchor spans.

    Args:
        response: Answer text to extend.
        anchor_nums: Number of placeholder tokens per anchor.
        anchor_tokens: Placeholder token string per anchor.
        anchor_names: Human-readable name per anchor.

    Returns:
        ``response`` unchanged when ``anchor_nums`` is empty, otherwise
        ``"Because the <name> of the image is <span>, ... . "`` followed by
        ``response``.
    """
    if len(anchor_nums) == 0:
        return response

    anchor_pads = [
        ANCHOR_START_TOKEN + anchor_token * anchor_num + ANCHOR_END_TOKEN
        for anchor_num, anchor_token in zip(anchor_nums, anchor_tokens)
    ]

    described = list(zip(anchor_names, anchor_pads))
    last = len(described) - 1
    pieces = ["Because "]
    for position, (anchor_name, anchor_pad) in enumerate(described):
        pieces.append(f"the {anchor_name} of the image is {anchor_pad}")
        # Pick the connector by position, not by comparing names: repeated
        # names would otherwise get the wrong separator.
        if position == last:
            pieces.append(". ")
        elif position == last - 1:
            pieces.append(", and ")
        else:
            pieces.append(", ")
    return "".join(pieces) + response
|
|
|
|
# Chain-of-thought sentence templates. Every template shares the same sentence
# stem for its "single" form (stem + ". ") and its "multiple" form
# (stem + "{connector}"), and uses the same three connectors.
COT_TEMPLATES = [
    {
        "name": template_name,
        "single": stem + ". ",
        "multiple": stem + "{connector}",
        "connectors": {
            "middle": ", ",
            "second_last": ", and ",
            "last": ". ",
        },
    }
    for template_name, stem in (
        ("basic_causal", "Because the {anchor_name} of the image is {anchor_pad}"),
        ("observational", "I can observe that the {anchor_name} of the image is {anchor_pad}"),
        ("analytical", "After analyzing the image, the {anchor_name} is {anchor_pad}"),
        ("descriptive", "The image shows that the {anchor_name} is {anchor_pad}"),
        ("conditional", "Given that the {anchor_name} of the image is {anchor_pad}"),
        ("evidence_based", "Based on the visual evidence, the {anchor_name} of the image is {anchor_pad}"),
    )
]
|
|
def get_random_cot_template():
    """Return one entry of ``COT_TEMPLATES`` picked with ``random.choice``."""
    chosen_template = random.choice(COT_TEMPLATES)
    return chosen_template
|
|
def apply_cot_template(template, anchor_names, anchor_pads):
    """Render ``template`` for the given anchors.

    A single anchor uses ``template["single"]``. Otherwise
    ``template["multiple"]`` is formatted once per ``(name, pad)`` pair and the
    results are concatenated; the connector is ``"last"`` for the final name,
    ``"second_last"`` for the one before it and ``"middle"`` for the rest.
    """
    name_count = len(anchor_names)
    if name_count == 1:
        return template["single"].format(
            anchor_name=anchor_names[0],
            anchor_pad=anchor_pads[0],
        )

    rendered = []
    for position, (name, pad) in enumerate(zip(anchor_names, anchor_pads)):
        remaining = name_count - 1 - position
        if remaining == 0:
            connector_key = "last"
        elif remaining == 1:
            connector_key = "second_last"
        else:
            connector_key = "middle"
        rendered.append(
            template["multiple"].format(
                anchor_name=name,
                anchor_pad=pad,
                connector=template["connectors"][connector_key],
            )
        )
    return "".join(rendered)
|
|
def get_templates_comt_data_in_response(response, anchor_nums, anchor_tokens, anchor_names):
    """Wrap ``response`` in ``<think>``/``<answer>`` tags using a random template.

    The ``<think>`` part is a randomly chosen ``COT_TEMPLATES`` entry rendered
    with one placeholder span per anchor. With no anchors ``response`` is
    returned unchanged.
    """
    if len(anchor_nums) == 0:
        return response

    spans = [
        ANCHOR_START_TOKEN + token * count + ANCHOR_END_TOKEN
        for count, token in zip(anchor_nums, anchor_tokens)
    ]
    reasoning = apply_cot_template(get_random_cot_template(), anchor_names, spans)
    return "".join(["<think>", reasoning, "</think>", "<answer>", response, "</answer>"])
| |
def get_comt_data_in_response(response, anchor_nums, anchor_tokens, anchor_names):
    """Wrap ``response`` in ``<think>``/``<answer>`` tags with an anchor sentence.

    Args:
        response: Answer text to wrap.
        anchor_nums: Number of placeholder tokens per anchor.
        anchor_tokens: Placeholder token string per anchor.
        anchor_names: Human-readable name per anchor.

    Returns:
        ``response`` unchanged when ``anchor_nums`` is empty, otherwise
        ``"<think> Because ... </think>\\n<answer> {response} </answer>"``.
    """
    if len(anchor_nums) == 0:
        return response

    anchor_pads = [
        ANCHOR_START_TOKEN + anchor_token * anchor_num + ANCHOR_END_TOKEN
        for anchor_num, anchor_token in zip(anchor_nums, anchor_tokens)
    ]

    described = list(zip(anchor_names, anchor_pads))
    last = len(described) - 1
    pieces = ["<think> Because "]
    for position, (anchor_name, anchor_pad) in enumerate(described):
        pieces.append(f"the {anchor_name} of the image is {anchor_pad}")
        # Pick the connector by position, not by comparing names: repeated
        # names would otherwise get the wrong separator.
        if position == last:
            pieces.append(". ")
        elif position == last - 1:
            pieces.append(", and ")
        else:
            pieces.append(", ")
    cot_start = "".join(pieces)
    return cot_start + " </think>\n" + "<answer> " + response + " </answer>"
| |
def get_feature_data(user_input, gpt_response, anchor_nums, anchor_tokens, anchor_names):
    """Build a "What is the <names> of the image?" prompt/answer pair.

    Only the ``'role'`` fields of ``user_input`` and ``gpt_response`` are used.
    The answer consists solely of the concatenated anchor placeholder spans.

    Returns:
        Tuple ``(prompt_text, response_text)``.
    """
    # The names take part in the zip only so that it stops at the shortest of
    # the three inputs.
    spans = "".join(
        ANCHOR_START_TOKEN + token * count + ANCHOR_END_TOKEN
        for count, token, _ in zip(anchor_nums, anchor_tokens, anchor_names)
    )
    joined_names = ", ".join(anchor_names)
    image_slot = VISION_START_TOKEN + DEFAULT_IMAGE_TOKEN + VISION_END_TOKEN

    prompt_text = (
        f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n"
        f"{image_slot}What is the {joined_names} of the image?\n"
        f"{DEFAULT_IM_END_TOKEN}\n"
        f"{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
    )
    response_text = f"{spans}\n{DEFAULT_IM_END_TOKEN}\n"
    return prompt_text, response_text
|
|
def replace_pad_with_anchor_tokens(gpt_response):
    """Expand ``<segmentation>``-style markers into runs of anchor pad tokens.

    ``<segmentation>`` becomes 8 SAM pad tokens; every other marker becomes 4
    copies of its pad token.
    """
    expansions = (
        ("<segmentation>", SAM_PAD_TOKEN * 8),
        ("<depth>", DEPTH_PAD_TOKEN * 4),
        ("<dino>", DINO_PAD_TOKEN * 4),
        ("<pidinet>", PIDINET_PAD_TOKEN * 4),
        ("<siglip>", SIGLIP_PAD_TOKEN * 4),
        ("<metaclip>", METACLIP_PAD_TOKEN * 4),
    )
    expanded = gpt_response
    for marker, pad_run in expansions:
        expanded = expanded.replace(marker, pad_run)
    return expanded
| |
def get_token_num(anchor_model_id):
    """Return the number of placeholder tokens for each anchor model id.

    Args:
        anchor_model_id: Iterable of anchor model ids, or ``None`` for "no
            anchors".

    Returns:
        List of token counts in input order. Unrecognised ids are skipped.
    """
    if anchor_model_id is None:
        return []

    token_counts = {
        "sam": 8,
        "dino": 4,
        "depth": 4,
        "SD": 4,
        "InternViT": 4,
        "pidinet": 4,
        "siglip": 4,
        "metaclip": 4,
    }
    return [
        token_counts[anchor_model]
        for anchor_model in anchor_model_id
        if anchor_model in token_counts
    ]
|
|
def get_anchor_token(anchor_model_id):
    """Return the pad-token constant for each anchor model id.

    Args:
        anchor_model_id: Iterable of anchor model ids, or ``None`` for "no
            anchors".

    Returns:
        List of pad tokens in input order. Unrecognised ids are skipped.
    """
    if anchor_model_id is None:
        return []

    anchor_tokens = []
    # Kept as a chain so each constant is only looked up when its id is
    # actually requested.
    for anchor_model in anchor_model_id:
        if anchor_model == "sam":
            anchor_tokens.append(SAM_PAD_TOKEN)
        elif anchor_model == "dino":
            anchor_tokens.append(DINO_PAD_TOKEN)
        elif anchor_model == "depth":
            anchor_tokens.append(DEPTH_PAD_TOKEN)
        elif anchor_model == "SD":
            anchor_tokens.append(SD_PAD_TOKEN)
        elif anchor_model == "InternViT":
            anchor_tokens.append(INTERN_PAD_TOKEN)
        elif anchor_model == "pidinet":
            anchor_tokens.append(PIDINET_PAD_TOKEN)
        elif anchor_model == "siglip":
            anchor_tokens.append(SIGLIP_PAD_TOKEN)
        elif anchor_model == "metaclip":
            anchor_tokens.append(METACLIP_PAD_TOKEN)
    return anchor_tokens
|
|
def get_anchor_task_name(anchor_model_id):
    """Return the human-readable task name for each anchor model id.

    Args:
        anchor_model_id: Iterable of anchor model ids, or ``None`` for "no
            anchors".

    Returns:
        List of task names in input order. Unrecognised ids are skipped.
    """
    if anchor_model_id is None:
        return []

    task_names = {
        "sam": "segmentation",
        "dino": "perception feature",
        "depth": "depth map",
        "SD": "style",
        "InternViT": "caption",
        "pidinet": "edge map",
        "siglip": "clip feature",
        "metaclip": "metaclip feature",
    }
    return [
        task_names[anchor_model]
        for anchor_model in anchor_model_id
        if anchor_model in task_names
    ]
| |
|
|
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Samples come from one of three places: a directory of parquet files (rows
    are read lazily on access), a JSON file, or an already-loaded sequence.
    Each item is rendered into prompt/response token ids; the text format
    depends on ``cur_step`` relative to the ``stage_*_step`` thresholds of
    ``data_args``:

    * ``cur_step < stage_0_step``: plain chat turn, anchor placeholder spans
      are inserted after the vision-end token of the prompt.
    * ``cur_step < stage_1_step``: "What is the <task> of the image?" prompt
      whose answer is the anchor placeholder spans.
    * ``cur_step < stage_2_step``: answer wrapped in ``<think>``/``<answer>``
      using all anchors (only when the prompt holds an image token).
    * otherwise: one draw in six keeps the plain answer, the others wrap it
      using a random non-empty subset of the anchors.

    ``cur_step`` is incremented on every ``__getitem__`` call.
    """

    def __init__(
        self,
        data_path: str | list,
        processor: transformers.ProcessorMixin,
        data_args: DataArguments,
        model_id,
        padding=True,
        shuffle=True,
        random_seed=42,
        anchor_model_id=None,
    ):
        """Load or index the data and cache the configuration.

        Args:
            data_path: Directory containing parquet files (searched
                recursively), path of a JSON file, or an already-loaded
                sequence of samples.
            processor: Processor used to tokenize text and preprocess media.
            data_args: Data configuration (pixel budgets, resize size, fps,
                image folder, stage thresholds).
            model_id: Model identifier; checked for ``"Qwen2.5"`` when
                processing videos.
            padding: Stored on the instance; not used inside this class.
            shuffle: Whether to shuffle the sample order once, up front.
            random_seed: Seed for the shuffle.
            anchor_model_id: Iterable of anchor model ids, or ``None`` for no
                anchors.

        Raises:
            FileNotFoundError: If ``data_path`` is a directory without any
                parquet file.
        """
        super(SupervisedDataset, self).__init__()
        self._is_parquet = False
        self._shuffle_perm = None
        list_data_dict = None

        if isinstance(data_path, str):
            if os.path.isdir(data_path):
                self._index_parquet_dir(data_path)
            else:
                # Close the file as soon as it is parsed.
                with open(data_path, "r") as handle:
                    list_data_dict = json.load(handle)
        else:
            list_data_dict = data_path

        self.model_id = model_id
        self.processor = processor
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.padding = padding
        self.image_min_pixel = data_args.image_min_pixels
        self.image_max_pixel = data_args.image_max_pixels
        self.image_resized_w = data_args.image_resized_width
        self.image_resized_h = data_args.image_resized_height
        self.video_min_pixel = data_args.video_min_pixels
        self.video_max_pixel = data_args.video_max_pixels
        self.fps = data_args.fps

        # The lookup helpers iterate over the ids, so the default ``None``
        # must be handed over as an empty list.
        anchor_ids = [] if anchor_model_id is None else anchor_model_id
        self.anchor_model_id = anchor_model_id
        self.anchor_token_nums = get_token_num(anchor_ids)
        self.anchor_tokens = get_anchor_token(anchor_ids)
        self.anchor_task_names = get_anchor_task_name(anchor_ids)

        self.cur_step = 0
        self.stage_0_step = data_args.stage_0_step
        self.stage_1_step = data_args.stage_1_step
        self.stage_2_step = data_args.stage_2_step

        self.rng = np.random.default_rng(seed=random_seed)

        if shuffle:
            if self._is_parquet:
                # Rows stay on disk; only an index permutation is kept.
                self._shuffle_perm = np.random.RandomState(random_seed).permutation(self._total_rows)
            else:
                # ``datasets`` is only needed to recognise HF datasets; plain
                # sequences can be shuffled without it.
                try:
                    from datasets import Dataset as _HFDataset
                except ImportError:
                    _HFDataset = None
                if _HFDataset is not None and isinstance(self.list_data_dict, _HFDataset):
                    self.list_data_dict = self.list_data_dict.shuffle(seed=random_seed)
                else:
                    # In-place shuffle: a sequence passed by the caller is
                    # reordered too.
                    self.rng.shuffle(self.list_data_dict)

    def _index_parquet_dir(self, data_path):
        """Record the parquet files under ``data_path`` and their cumulative row counts."""
        import glob
        import pyarrow.parquet as pq

        parquet_files = sorted(glob.glob(os.path.join(data_path, "**", "*.parquet"), recursive=True))
        if not parquet_files:
            raise FileNotFoundError(f"No parquet files found under {data_path}")
        print(f"[DataLoader] Found {len(parquet_files)} parquet files (streaming mode)")

        cumulative_rows = []
        total_rows = 0
        for path in parquet_files:
            total_rows += pq.ParquetFile(path).metadata.num_rows
            cumulative_rows.append(total_rows)

        self._is_parquet = True
        self._parquet_files = parquet_files
        self._cumulative_rows = cumulative_rows
        self._total_rows = total_rows
        print(f"[DataLoader] Total parquet rows: {total_rows}")
        # Lazily filled per-file caches: open handle and cumulative row-group sizes.
        self._pf_handles = {}
        self._row_group_index = {}

    def set_cur_step(self, step: int):
        """Set the step counter that selects the formatting stage."""
        self.cur_step = step
        print(f"[Dataset] cur_step has been set to {step}")

    def __len__(self):
        if self._is_parquet:
            return self._total_rows
        return len(self.list_data_dict)

    def _get_parquet_row(self, idx):
        """Read row ``idx`` (after the optional shuffle) from the parquet files.

        The row group containing the row is read on every call. An ``image``
        column holding a dict with ``bytes`` or ``path`` is decoded into an RGB
        PIL image.
        """
        import bisect
        import io
        import pyarrow.parquet as pq

        real_idx = self._shuffle_perm[idx] if self._shuffle_perm is not None else idx

        # Locate the file, then the row offset inside that file.
        fi = bisect.bisect_right(self._cumulative_rows, real_idx)
        local_idx = real_idx - self._cumulative_rows[fi - 1] if fi > 0 else real_idx
        fpath = self._parquet_files[fi]

        if fpath not in self._pf_handles:
            handle = pq.ParquetFile(fpath)
            self._pf_handles[fpath] = handle
            cum = []
            running = 0
            for rg in range(handle.num_row_groups):
                running += handle.metadata.row_group(rg).num_rows
                cum.append(running)
            self._row_group_index[fpath] = cum

        pf = self._pf_handles[fpath]
        cum = self._row_group_index[fpath]

        # Locate the row group, then the row offset inside that group.
        rg_idx = bisect.bisect_right(cum, local_idx)
        rg_start = 0 if rg_idx == 0 else cum[rg_idx - 1]
        in_rg_idx = local_idx - rg_start

        table = pf.read_row_group(rg_idx)
        row = table.slice(in_rg_idx, 1).to_pylist()[0]

        image_data = row.get('image')
        if isinstance(image_data, dict) and image_data.get('bytes'):
            row['image'] = Image.open(io.BytesIO(image_data['bytes'])).convert('RGB')
        elif isinstance(image_data, dict) and image_data.get("path"):
            row["image"] = Image.open(image_data["path"]).convert("RGB")

        return row

    def _open_rgb(self, item):
        """Return ``item`` (PIL image, raw bytes, or path) as an RGB PIL image."""
        if isinstance(item, Image.Image):
            return item.convert("RGB")
        if isinstance(item, (bytes, bytearray)):
            import io
            return Image.open(io.BytesIO(item)).convert("RGB")

        path = item
        folder = self.data_args.image_folder
        # Same fallback as for videos: a path that does not exist as given is
        # looked up under the configured folder.
        if isinstance(path, str) and folder and not os.path.exists(path) and not path.startswith("http"):
            candidate = os.path.join(folder, path)
            if os.path.exists(candidate):
                path = candidate
        return Image.open(path).convert("RGB")

    def _placeholder_image(self):
        """Return a processed black image of the configured resize size."""
        # NOTE(review): fails if the resize width/height are None - confirm
        # they are always set when media-less samples exist.
        black_image = Image.new("RGB", (self.image_resized_w, self.image_resized_h), (0, 0, 0))
        return get_image_info(black_image, self.image_min_pixel, self.image_max_pixel, self.image_resized_w, self.image_resized_h)

    def _render_turn(self, user_input, gpt_response):
        """Render one user/assistant exchange into ``(prompt, response)`` text.

        Raises:
            ValueError: If the user content still contains the raw LLaVA image
                placeholder but neither an image nor a video token.
        """
        content = user_input['content']
        if (DEFAULT_IMAGE_TOKEN not in content) and (DEFAULT_VIDEO_TOKEN not in content) and (LLAVA_IMAGE_TOKEN in content):
            raise ValueError(
                f"User turn still contains the raw {LLAVA_IMAGE_TOKEN!r} placeholder "
                "but no image or video token; check the sample's media type."
            )

        step = self.cur_step
        if self.stage_0_step <= step < self.stage_1_step:
            return get_feature_data(user_input, gpt_response, self.anchor_token_nums, self.anchor_tokens, self.anchor_task_names)

        prompt = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
        answer = f"{gpt_response['content']}"

        if step < self.stage_0_step:
            prompt = add_anchor_pad(prompt, self.anchor_token_nums, self.anchor_tokens)
            return prompt, f"{answer}\n{DEFAULT_IM_END_TOKEN}\n"

        if step < self.stage_2_step:
            if DEFAULT_IMAGE_TOKEN in prompt:
                answer = get_comt_data_in_response(answer, self.anchor_token_nums, self.anchor_tokens, self.anchor_task_names)
            return prompt, f"{answer}\n{DEFAULT_IM_END_TOKEN}\n"

        # Final stage: one draw in six keeps the plain answer.
        if random.randint(0, 5) != 0 and DEFAULT_IMAGE_TOKEN in prompt:
            total = len(self.anchor_tokens)
            if total == 0:
                idxs = []
            else:
                # Random non-empty subset, kept in the original anchor order.
                idxs = sorted(random.sample(range(total), random.randint(1, total)))
            answer = get_comt_data_in_response(
                answer,
                [self.anchor_token_nums[k] for k in idxs],
                [self.anchor_tokens[k] for k in idxs],
                [self.anchor_task_names[k] for k in idxs],
            )
        return prompt, f"{answer}\n{DEFAULT_IM_END_TOKEN}\n"

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the training tensors for sample ``i``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``. When
            vision features were produced it also holds the pixel values and
            grid tensor under the image or video keys and, for image samples,
            ``image_files`` (the loaded PIL images). ``second_per_grid_ts`` is
            added when the processor returned it.
        """
        if self._is_parquet:
            sources = self._get_parquet_row(i)
        else:
            sources = self.list_data_dict[i]

        is_video = False
        processor = self.processor
        image_files = None   # stays None for samples without an "image" entry
        video_kwargs = {}

        if "image" in sources:
            videos = None
            grid_key = "image_grid_thw"
            pixel_key = "pixel_values"

            raw_images = sources["image"]
            if raw_images is None:
                raw_images = []
            elif isinstance(raw_images, (Image.Image, str, bytes, bytearray)):
                raw_images = [raw_images]
            image_files = [self._open_rgb(item) for item in raw_images]

            images = [
                get_image_info(image_file, self.image_min_pixel, self.image_max_pixel, self.image_resized_w, self.image_resized_h)
                for image_file in image_files
            ]

        elif "video" in sources:
            is_video = True
            images = None
            grid_key = "video_grid_thw"
            pixel_key = "pixel_values_videos"

            video_files = sources["video"]
            video_folder = self.data_args.image_folder
            if isinstance(video_files, str):
                video_files = [video_files]

            videos = []
            for video_file in video_files:
                if not os.path.exists(video_file):
                    if not video_file.startswith("http"):
                        video_file = os.path.join(video_folder, video_file)
                video_input, video_kwargs = get_video_info(video_file, self.video_min_pixel, self.video_max_pixel, self.data_args.fps)
                videos.append(video_input)
        else:
            grid_key = None
            pixel_key = None
            images = None
            videos = None

        if images is None or len(images) == 0:
            # The processor is always handed at least one image; a black
            # placeholder stands in when the sample has none.
            print("No image or video found in the data.")
            images = [self._placeholder_image()]

        sources = copy.deepcopy(llava_to_openai(sources['conversations'], is_video=is_video))

        all_input_ids = []
        all_labels = []
        all_pixel_values = []
        all_image_grid_thw = []
        all_second_grid = []

        if len(SYSTEM_MESSAGE) > 0:
            system_message = f"{DEFAULT_IM_START_TOKEN}system\n{SYSTEM_MESSAGE}\n{DEFAULT_IM_END_TOKEN}\n"
            system_message_input_ids = processor.tokenizer(system_message, add_special_tokens=False, return_tensors='pt')['input_ids']
            system_labels = torch.full_like(system_message_input_ids, IGNORE_INDEX)

            all_input_ids.append(system_message_input_ids.squeeze(0))
            all_labels.append(system_labels.squeeze(0))

        for j in range(0, len(sources), 2):
            # Only the first user/assistant exchange is used.
            if j >= 2:
                break

            user_input, gpt_response = self._render_turn(sources[j], sources[j + 1])

            if DEFAULT_IMAGE_TOKEN in user_input:
                inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])
                torch.cuda.empty_cache()

            elif DEFAULT_VIDEO_TOKEN in user_input:
                if "Qwen2.5" in self.model_id:
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt', **video_kwargs)
                    all_second_grid.extend(inputs["second_per_grid_ts"])
                else:
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])

            else:
                prompt_input_ids = processor.tokenizer(user_input, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']

            response_input_ids = processor.tokenizer(gpt_response, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']

            input_ids = torch.cat([prompt_input_ids, response_input_ids], dim=1).squeeze(0)
            # Only the response tokens contribute to the loss.
            labels = torch.cat(
                [
                    torch.tensor([IGNORE_INDEX] * len(prompt_input_ids[0])),
                    response_input_ids.squeeze(0),
                ],
                dim=0,
            )

            all_input_ids.append(input_ids)
            all_labels.append(labels)

        input_ids = torch.cat(all_input_ids, dim=0).to(torch.long)
        labels = torch.cat(all_labels, dim=0).to(torch.long)

        # Every position is real (no padding at this point): the mask is all ones.
        attention_mask = (input_ids > -1000000).to(torch.long)

        data_dict = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        # Vision tensors exist only if a prompt with a vision token went
        # through the processor; torch.cat would fail on an empty list.
        if pixel_key and grid_key and all_pixel_values:
            data_dict[pixel_key] = torch.cat(all_pixel_values, dim=0)
            data_dict[grid_key] = torch.cat(all_image_grid_thw, dim=0)
            if image_files is not None:
                data_dict["image_files"] = image_files

        if len(all_second_grid) > 0:
            data_dict["second_per_grid_ts"] = all_second_grid

        self.cur_step += 1

        return data_dict
|
|
|
|
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    def __init__(self, pad_token_id: int):
        # Id used to right-pad ``input_ids`` and to derive the attention mask.
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        """Merge per-sample dicts into a single right-padded batch dict.

        ``input_ids`` are padded with ``pad_token_id`` and ``labels`` with
        ``IGNORE_INDEX``; the attention mask marks every position whose id
        differs from ``pad_token_id``. Vision tensors are concatenated along
        dim 0; ``second_per_grid_ts`` and ``image_files`` are collected as
        lists. Optional keys appear only when at least one sample has them.
        """
        token_rows, label_rows = [], []
        image_pixels, image_grids = [], []
        video_pixels, video_grids = [], []
        seconds_per_grid = []
        image_file_groups = []

        for sample in examples:
            present = sample.keys()

            # A sample contributes either video or image features, video first.
            if "pixel_values_videos" in present:
                video_pixels.append(sample["pixel_values_videos"])
                video_grids.append(sample["video_grid_thw"])
            elif "pixel_values" in present:
                image_pixels.append(sample["pixel_values"])
                image_grids.append(sample["image_grid_thw"])

            if "image_files" in present:
                image_file_groups.append(sample["image_files"])

            token_rows.append(sample["input_ids"])
            label_rows.append(sample["labels"])

            if "second_per_grid_ts" in present:
                seconds_per_grid.extend(sample["second_per_grid_ts"])

        input_ids = pad_sequence(
            token_rows, padding_side='right', padding_value=self.pad_token_id
        )
        attention_mask = input_ids != self.pad_token_id
        labels = pad_sequence(
            label_rows, padding_side='right', padding_value=IGNORE_INDEX
        )

        batch = {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
        }

        if image_pixels:
            batch["pixel_values"] = torch.cat(image_pixels, dim=0)
            batch["image_grid_thw"] = torch.cat(image_grids, dim=0)

        if video_pixels:
            batch["pixel_values_videos"] = torch.cat(video_pixels, dim=0)
            batch["video_grid_thw"] = torch.cat(video_grids, dim=0)

        if seconds_per_grid:
            batch["second_per_grid_ts"] = seconds_per_grid

        if image_file_groups:
            batch["image_files"] = image_file_groups

        return batch
|
|
def replace_image_tokens(input_string, is_video=False):
    """Swap LLaVA media placeholders for vision-token triplets.

    Every LLaVA image (or, with ``is_video``, video) placeholder, together with
    at most one newline directly before and after it, is replaced by
    ``VISION_START_TOKEN + <media token> + VISION_END_TOKEN``.
    """
    llava_token, media_token = (
        (LLAVA_VIDEO_TOKEN, DEFAULT_VIDEO_TOKEN)
        if is_video
        else (LLAVA_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN)
    )
    placeholder_re = r'\n?' + re.escape(llava_token) + r'\n?'
    vision_span = VISION_START_TOKEN + media_token + VISION_END_TOKEN
    return re.sub(placeholder_re, vision_span, input_string)
|
|
def llava_to_openai(conversations, is_video=False):
    """Convert LLaVA-style turns into ``{"role", "content"}`` dicts.

    ``"human"`` maps to ``"user"`` and ``"gpt"`` to ``"assistant"``; any other
    ``"from"`` value is kept as-is. The ``"value"`` text is passed through
    ``replace_image_tokens``.
    """
    roles = {"human": "user", "gpt": "assistant"}

    converted = []
    for turn in conversations:
        content = replace_image_tokens(turn["value"], is_video=is_video)
        speaker = turn["from"]
        converted.append({"role": roles.get(speaker, speaker), "content": content})
    return converted
|
|
def make_supervised_data_module(model_id, processor, data_args, anchor_model_id):
    """Make dataset and collator for supervised fine-tuning.

    Returns:
        Dict with ``train_dataset``, ``eval_dataset`` (always ``None``) and
        ``data_collator``.
    """
    train_dataset = SupervisedDataset(
        data_path=data_args.data_path,
        processor=processor,
        data_args=data_args,
        model_id=model_id,
        anchor_model_id=anchor_model_id,
    )
    collator = DataCollatorForSupervisedDataset(
        pad_token_id=processor.tokenizer.pad_token_id
    )
    return {
        "train_dataset": train_dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }