import copy
import os
from dataclasses import dataclass, field
from typing import Dict
import torch
import transformers
import ujson as json
from torch.utils.data import Dataset
from qwen_vl_utils import process_vision_info
from PIL import Image
from transformers import AutoImageProcessor
import re
import numpy as np
import cv2
from torchvision import transforms
import random
from segment_anything import build_sam_vit_h, sam_model_registry, SamPredictor
from src.anchors.DepthAnything.depth_anything_v2.dpt import DepthAnythingV2
from diffusers import AutoencoderKL
from transformers import AutoModel, CLIPImageProcessor
from .params import DataArguments
from .constants import *
def truncate_sequence(input_ids, labels, max_length, eos_token_id):
    """Clip a 1-D token sequence and its labels to at most ``max_length`` items.

    Sequences already within ``max_length`` are returned unchanged. When clipping
    is needed and ``eos_token_id`` is given, the first ``max_length - 1`` items are
    kept and ``eos_token_id`` is appended to both tensors, so the result still ends
    with it. Without an EOS id the first ``max_length`` items are kept.

    Args:
        input_ids: 1-D tensor of token ids.
        labels: 1-D tensor aligned with ``input_ids``.
        max_length: maximum number of tokens to keep.
        eos_token_id: id appended after clipping, or ``None``.

    Returns:
        ``(input_ids, labels)``.
    """
    if input_ids.size(0) <= max_length:
        return input_ids, labels
    if eos_token_id is None:
        # Nothing to append, so keep the full budget of max_length tokens.
        return input_ids[:max_length], labels[:max_length]
    # new_tensor keeps the EOS entry on the same device / dtype as the source tensors.
    input_ids = torch.cat([input_ids[:max_length - 1], input_ids.new_tensor([eos_token_id])])
    labels = torch.cat([labels[:max_length - 1], labels.new_tensor([eos_token_id])])
    return input_ids, labels
def pad_sequence(sequences, padding_side='right', padding_value=0):
    """
    Pad a list of sequences to the same length.

    Args:
        sequences: non-empty list of tensors shaped ``[seq_len, *]``. All must share
            the trailing dimensions of ``sequences[0]``.
        padding_side: ``'right'`` (pad after the data) or ``'left'`` (pad before it).
        padding_value: fill value for padded positions.

    Returns:
        Tensor of shape ``[len(sequences), max_seq_len, *]`` with the dtype and
        device of ``sequences[0]``.

    Raises:
        ValueError: on an unknown ``padding_side`` or an empty ``sequences`` list.
    """
    if padding_side not in ('right', 'left'):
        raise ValueError(f"padding_side must be 'right' or 'left', got {padding_side!r}")
    if len(sequences) == 0:
        raise ValueError("pad_sequence needs at least one sequence")
    trailing_dims = sequences[0].size()[1:]
    max_len = max(seq.size(0) for seq in sequences)
    output = sequences[0].new_full((len(sequences), max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            # Use an explicit start offset: a "-length:" slice would cover the
            # whole row when length == 0.
            output.data[i, max_len - length:] = seq
    return output
def get_image_info(image_path, min_pixel, max_pixel, width, height):
# Using this because of process_vision_info function
# Need to fix this in the future
content = {
"type": "image",
"image": image_path,
"min_pixel": min_pixel,
"max_pixel": max_pixel
}
if width is not None and height is not None:
content["resized_width"] = width
content["resized_height"] = height
messages = [
{"role": "user",
"content": [content]
}
]
image_input, _ = process_vision_info(messages)
return image_input[0]
def get_video_info(video_path, min_pixels, max_pixels, fps):
    """Decode one video through ``qwen_vl_utils.process_vision_info``.

    Args:
        video_path: local path or URL of the video.
        min_pixels / max_pixels: per-frame pixel-count bounds forwarded as-is.
        fps: sampling rate forwarded as-is.

    Returns:
        ``(video_tensor, video_kwargs)``. The kwargs are what
        ``process_vision_info(..., return_video_kwargs=True)`` returns, to be
        forwarded to the processor.
    """
    # Wrapped in a single-message conversation because that is the input format
    # process_vision_info expects. Need to fix this in the future.
    video_element = {
        "type": "video",
        "video": video_path,
        "min_pixels": min_pixels,
        "max_pixels": max_pixels,
        "fps": fps,
    }
    conversation = [{"role": "user", "content": [video_element]}]
    _, decoded_videos, video_kwargs = process_vision_info(conversation, return_video_kwargs=True)
    return decoded_videos[0], video_kwargs
def add_anchor_pad(user_input, anchor_nums, anchor_tokens):
    """Insert anchor pad spans right after every ``VISION_END_TOKEN`` in *user_input*.

    Each (count, token) pair becomes
    ``ANCHOR_START_TOKEN + token * count + ANCHOR_END_TOKEN``. All spans are
    concatenated in order. Text without ``VISION_END_TOKEN`` is returned unchanged.
    """
    pad_block = "".join(
        ANCHOR_START_TOKEN + token * count + ANCHOR_END_TOKEN
        for count, token in zip(anchor_nums, anchor_tokens)
    )
    if VISION_END_TOKEN not in user_input:
        return user_input
    return user_input.replace(VISION_END_TOKEN, VISION_END_TOKEN + pad_block)
def add_cot_anchor_pad_in_user_input(user_input, anchor_nums, anchor_tokens, anchor_names=None):
    """Prefix *user_input* with a sentence that names each anchor's pad span.

    One anchor gives ``"The <name> of the image is <pads>. "``. Several give
    ``"The <a> ... , the <b> ... , and the <c> ... . "``.

    Args:
        user_input: prompt text to prefix.
        anchor_nums: number of pad tokens per anchor.
        anchor_tokens: pad token string per anchor.
        anchor_names: human-readable name per anchor. Required whenever
            ``anchor_nums`` is non-empty (the previous version referenced it
            without ever receiving it).

    Raises:
        ValueError: if anchors are given but ``anchor_names`` is missing.
    """
    if len(anchor_nums) == 0:
        return user_input
    if anchor_names is None:
        raise ValueError("anchor_names is required when anchor_nums is non-empty")
    anchor_pads = [
        ANCHOR_START_TOKEN + anchor_token * anchor_num + ANCHOR_END_TOKEN
        for anchor_num, anchor_token in zip(anchor_nums, anchor_tokens)
    ]
    if len(anchor_pads) == 1:
        cot_pad = f"The {anchor_names[0]} of the image is {anchor_pads[0]}. "
    else:
        pairs = list(zip(anchor_names, anchor_pads))
        last = len(pairs) - 1
        cot_pad = ""
        for i, (anchor_name, anchor_pad) in enumerate(pairs):
            if i == 0:
                cot_pad += f"The {anchor_name} of the image is {anchor_pad}, "
            elif i == last:
                cot_pad += f"and the {anchor_name} of the image is {anchor_pad}. "
            else:
                cot_pad += f"the {anchor_name} of the image is {anchor_pad}, "
    return cot_pad + user_input
def get_cot_data_in_response(response, anchor_nums, anchor_tokens, anchor_names):
    """Prefix *response* with a "Because the <anchor> of the image is <pads> ..." rationale.

    One anchor gives ``"Because the A of the image is P. "``. Several are joined
    as ``"Because the A ... , the B ... , and the C ... . "``. Connectors are
    chosen by position, so repeated anchor names no longer confuse the
    punctuation (the old version compared name values).

    Returns *response* unchanged when ``anchor_nums`` is empty.
    """
    if len(anchor_nums) == 0:
        return response
    anchor_pads = [
        ANCHOR_START_TOKEN + anchor_token * anchor_num + ANCHOR_END_TOKEN
        for anchor_num, anchor_token in zip(anchor_nums, anchor_tokens)
    ]
    clauses = [
        f"the {anchor_name} of the image is {anchor_pad}"
        for anchor_name, anchor_pad in zip(anchor_names, anchor_pads)
    ]
    if not clauses:
        body = ""
    elif len(clauses) == 1:
        body = clauses[0] + ". "
    else:
        body = ", ".join(clauses[:-1]) + ", and " + clauses[-1] + ". "
    return "Because " + body + response
# Phrasings sampled by get_random_cot_template. "single" renders one anchor.
# "multiple" is rendered once per anchor and followed by a positional connector
# from "connectors" (see apply_cot_template).
COT_TEMPLATES = [
    {
        "name": template_name,
        "single": single_form,
        "multiple": multiple_form,
        # A fresh dict per template, so entries never share mutable state.
        "connectors": {
            "middle": ", ",
            "second_last": ", and ",
            "last": ". ",
        },
    }
    for template_name, single_form, multiple_form in (
        (
            "basic_causal",
            "Because the {anchor_name} of the image is {anchor_pad}. ",
            "Because the {anchor_name} of the image is {anchor_pad}{connector}",
        ),
        (
            "observational",
            "I can observe that the {anchor_name} of the image is {anchor_pad}. ",
            "I can observe that the {anchor_name} of the image is {anchor_pad}{connector}",
        ),
        (
            "analytical",
            "After analyzing the image, the {anchor_name} is {anchor_pad}. ",
            "After analyzing the image, the {anchor_name} is {anchor_pad}{connector}",
        ),
        (
            "descriptive",
            "The image shows that the {anchor_name} is {anchor_pad}. ",
            "The image shows that the {anchor_name} is {anchor_pad}{connector}",
        ),
        (
            "conditional",
            "Given that the {anchor_name} of the image is {anchor_pad}. ",
            "Given that the {anchor_name} of the image is {anchor_pad}{connector}",
        ),
        (
            "evidence_based",
            "Based on the visual evidence, the {anchor_name} of the image is {anchor_pad}. ",
            "Based on the visual evidence, the {anchor_name} of the image is {anchor_pad}{connector}",
        ),
    )
]
def get_random_cot_template():
    """Return one entry of ``COT_TEMPLATES``, picked with the module-level ``random`` RNG."""
    templates = COT_TEMPLATES
    chosen = random.choice(templates)
    return chosen
def apply_cot_template(template, anchor_names, anchor_pads):
    """Render a COT template for the given anchors.

    With exactly one name, ``template["single"]`` is formatted once. Otherwise
    ``template["multiple"]`` is formatted once per (name, pad) pair and the
    pieces are concatenated. Each pair's ``{connector}`` is chosen by position:
    ``"last"`` for the final pair, ``"second_last"`` for the one before it and
    ``"middle"`` for every earlier pair. No names yields ``""``.
    """
    total = len(anchor_names)
    if total == 1:
        return template["single"].format(
            anchor_name=anchor_names[0],
            anchor_pad=anchor_pads[0],
        )

    def connector_for(position):
        distance_from_end = total - 1 - position
        if distance_from_end == 0:
            return template["connectors"]["last"]
        if distance_from_end == 1:
            return template["connectors"]["second_last"]
        return template["connectors"]["middle"]

    pieces = []
    for position, (name, pad) in enumerate(zip(anchor_names, anchor_pads)):
        connector = connector_for(position)
        pieces.append(
            template["multiple"].format(anchor_name=name, anchor_pad=pad, connector=connector)
        )
    return "".join(pieces)
def get_templates_comt_data_in_response(response, anchor_nums, anchor_tokens, anchor_names):
    """Prefix *response* with an anchor rationale in a randomly chosen COT phrasing.

    The pad span for each anchor is
    ``ANCHOR_START_TOKEN + token * count + ANCHOR_END_TOKEN``. The phrasing comes
    from ``get_random_cot_template`` and is rendered by ``apply_cot_template``.
    Returns *response* unchanged when ``anchor_nums`` is empty.
    """
    if len(anchor_nums) == 0:
        return response
    pad_spans = [
        ANCHOR_START_TOKEN + token * count + ANCHOR_END_TOKEN
        for count, token in zip(anchor_nums, anchor_tokens)
    ]
    chosen_template = get_random_cot_template()
    rationale = apply_cot_template(chosen_template, anchor_names, pad_spans)
    # NOTE(review): the empty "" literals look like delimiter tags that were lost
    # from this copy of the file. As written they add nothing — confirm intent.
    return "" + rationale + "" + "" + response + ""
def get_comt_data_in_response(response, anchor_nums, anchor_tokens, anchor_names):
    """Prefix *response* with a " Because the <anchor> of the image is <pads> ..." rationale.

    The rationale is followed by ``" \\n"``, a space, *response* and a trailing
    space. Several anchors are joined as ``"A, B, and C. "``, with connectors
    chosen by position so repeated names keep correct punctuation (the old
    version compared name values).

    Returns *response* unchanged when ``anchor_nums`` is empty.
    """
    if len(anchor_nums) == 0:
        return response
    anchor_pads = [
        ANCHOR_START_TOKEN + anchor_token * anchor_num + ANCHOR_END_TOKEN
        for anchor_num, anchor_token in zip(anchor_nums, anchor_tokens)
    ]
    clauses = [
        f"the {anchor_name} of the image is {anchor_pad}"
        for anchor_name, anchor_pad in zip(anchor_names, anchor_pads)
    ]
    if not clauses:
        body = ""
    elif len(clauses) == 1:
        body = clauses[0] + ". "
    else:
        body = ", ".join(clauses[:-1]) + ", and " + clauses[-1] + ". "
    cot_start = " Because " + body
    return cot_start + " \n" + " " + response + " "
def get_feature_data(user_input, gpt_response, anchor_nums, anchor_tokens, anchor_names):
    """Turn one conversation pair into a "What is the <anchors> of the image?" query.

    Only the ``role`` fields of *user_input* and *gpt_response* are kept. Their
    contents are replaced by an image-placeholder question naming every anchor
    (comma-joined) and an answer made of the concatenated anchor pad spans.

    Returns:
        ``(prompt_text, response_text)`` as fully templated chat strings.
    """
    # The three-way zip stops at the shortest list, including anchor_names.
    anchor_pads = "".join(
        ANCHOR_START_TOKEN + token * count + ANCHOR_END_TOKEN
        for count, token, _name in zip(anchor_nums, anchor_tokens, anchor_names)
    )
    anchor_name = ", ".join(anchor_names)
    prompt_text = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{VISION_START_TOKEN + DEFAULT_IMAGE_TOKEN + VISION_END_TOKEN}What is the {anchor_name} of the image?\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
    response_text = f"{anchor_pads}\n{DEFAULT_IM_END_TOKEN}\n"
    return prompt_text, response_text
def replace_pad_with_anchor_tokens(gpt_response, token_dict=None):
    """Expand placeholder markers in *gpt_response* into runs of anchor pad tokens.

    Args:
        gpt_response: text to rewrite.
        token_dict: optional ``{placeholder: replacement}`` mapping. Defaults to
            the built-in table below.

    NOTE(review): in this copy of the file every default placeholder key is the
    empty string (they look like stripped ``<...>`` tags). A dict literal with
    duplicate keys keeps only the last entry, and ``str.replace("", x)`` inserts
    ``x`` between every character, so the previous loop corrupted the whole
    response. Empty placeholders are now skipped. Restore the real placeholder
    strings, or pass ``token_dict``, to get an actual substitution.
    """
    if token_dict is None:
        # TODO(review): restore the intended placeholder keys — all are "" here.
        token_dict = {
            "": SAM_PAD_TOKEN * 8,
            "": DEPTH_PAD_TOKEN * 4,
            "": DINO_PAD_TOKEN * 4,
            "": PIDINET_PAD_TOKEN * 4,
            "": SIGLIP_PAD_TOKEN * 4,
            "": METACLIP_PAD_TOKEN * 4,
        }
    for placeholder, anchor_token in token_dict.items():
        if not placeholder:
            # Replacing "" would splice anchor_token between every character.
            continue
        gpt_response = gpt_response.replace(placeholder, anchor_token)
    return gpt_response
def get_token_num(anchor_model_id):
    """Return how many pad tokens each anchor model contributes.

    "sam" uses 8 tokens. "dino", "depth", "SD", "InternViT", "pidinet", "siglip"
    and "metaclip" use 4. Unknown ids are skipped, as get_anchor_token and
    get_anchor_task_name also do, so the three lists stay index-aligned.
    ``None`` means no anchors (``SupervisedDataset`` defaults ``anchor_model_id``
    to ``None``).
    """
    if anchor_model_id is None:
        return []
    tokens_per_model = {
        "sam": 8,
        "dino": 4,
        "depth": 4,
        "SD": 4,
        "InternViT": 4,
        "pidinet": 4,
        "siglip": 4,
        "metaclip": 4,
    }
    return [tokens_per_model[model] for model in anchor_model_id if model in tokens_per_model]
def get_anchor_token(anchor_model_id):
    """Return the pad token string used by each anchor model.

    Unknown ids are skipped, as get_token_num and get_anchor_task_name also do,
    so the three lists stay index-aligned. ``None`` means no anchors
    (``SupervisedDataset`` defaults ``anchor_model_id`` to ``None``).
    """
    if anchor_model_id is None:
        return []
    anchor_tokens = []
    # An if/elif chain, not a dict, so each *_PAD_TOKEN constant is only looked
    # up when its model is actually requested.
    for anchor_model in anchor_model_id:
        if anchor_model == "sam":
            anchor_tokens.append(SAM_PAD_TOKEN)
        elif anchor_model == "dino":
            anchor_tokens.append(DINO_PAD_TOKEN)
        elif anchor_model == "depth":
            anchor_tokens.append(DEPTH_PAD_TOKEN)
        elif anchor_model == "SD":
            anchor_tokens.append(SD_PAD_TOKEN)
        elif anchor_model == "InternViT":
            anchor_tokens.append(INTERN_PAD_TOKEN)
        elif anchor_model == "pidinet":
            anchor_tokens.append(PIDINET_PAD_TOKEN)
        elif anchor_model == "siglip":
            anchor_tokens.append(SIGLIP_PAD_TOKEN)
        elif anchor_model == "metaclip":
            anchor_tokens.append(METACLIP_PAD_TOKEN)
    return anchor_tokens
def get_anchor_task_name(anchor_model_id):
    """Return the human-readable task name used in prompts for each anchor model.

    Unknown ids are skipped, as get_token_num and get_anchor_token also do, so the
    three lists stay index-aligned. ``None`` means no anchors
    (``SupervisedDataset`` defaults ``anchor_model_id`` to ``None``).
    """
    if anchor_model_id is None:
        return []
    task_names = {
        "sam": "segmentation",
        "dino": "perception feature",
        "depth": "depth map",
        "SD": "style",
        "InternViT": "caption",
        "pidinet": "edge map",
        "siglip": "clip feature",
        "metaclip": "metaclip feature",
    }
    return [task_names[model] for model in anchor_model_id if model in task_names]
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    ``data_path`` may be:

    * a directory: every ``*.parquet`` file below it (recursive glob) is streamed
      row by row through pyarrow, with no HF cache;
    * a path to a JSON file holding the list of records;
    * an already-loaded sequence of records, e.g. a ``list`` or a HF ``datasets.Dataset``.

    Records carry ``conversations`` (LLaVA ``from``/``value`` turns) and
    optionally ``image`` or ``video``. How a sample is rendered depends on
    ``self.cur_step`` relative to the ``stage_0_step`` / ``stage_1_step`` /
    ``stage_2_step`` thresholds read from ``data_args`` (see ``__getitem__``).
    """

    def __init__(
        self,
        data_path: str | list,
        processor: transformers.ProcessorMixin,
        data_args: DataArguments,
        model_id,
        padding=True,
        shuffle=True,
        random_seed=42,
        anchor_model_id=None,
    ):
        """
        Args:
            data_path: parquet directory, JSON file path, or in-memory records.
                A ``list`` is copied before shuffling so the caller's list is not
                reordered in place.
            processor: HF processor providing ``__call__`` and ``tokenizer``.
            data_args: source of pixel limits, fps, image folder and stage steps.
            model_id: model name. ``"Qwen2.5"`` in it enables ``second_per_grid_ts``.
            padding: stored, not used in this class.
            shuffle: shuffle sample order once, seeded by ``random_seed``.
            anchor_model_id: list of anchor-model ids, or ``None`` for none.
        """
        super(SupervisedDataset, self).__init__()
        self._is_parquet = False
        if isinstance(data_path, str):
            import glob as _glob
            if os.path.isdir(data_path):
                # Parquet streaming: no HF cache, reads rows on-demand via pyarrow
                import pyarrow.parquet as _pq
                parquet_files = sorted(_glob.glob(os.path.join(data_path, "**", "*.parquet"), recursive=True))
                if not parquet_files:
                    raise FileNotFoundError(f"No parquet files found under {data_path}")
                print(f"[DataLoader] Found {len(parquet_files)} parquet files (streaming mode)")
                # cumulative_rows[k] = total number of rows in files 0..k
                cumulative_rows = []
                cumsum = 0
                for _f in parquet_files:
                    cumsum += _pq.ParquetFile(_f).metadata.num_rows
                    cumulative_rows.append(cumsum)
                self._is_parquet = True
                self._parquet_files = parquet_files
                self._cumulative_rows = cumulative_rows
                self._total_rows = cumsum
                print(f"[DataLoader] Total parquet rows: {cumsum}")
                # Lazily filled per file by _get_parquet_row.
                self._pf_handles = {}
                self._row_group_index = {}
                list_data_dict = None
            else:
                with open(data_path, "r") as f:
                    list_data_dict = json.load(f)
        else:
            # Copy plain lists so the in-place shuffle below cannot reorder the caller's data.
            list_data_dict = list(data_path) if isinstance(data_path, list) else data_path

        self.model_id = model_id
        self.processor = processor
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.padding = padding
        self.image_min_pixel = data_args.image_min_pixels
        self.image_max_pixel = data_args.image_max_pixels
        self.image_resized_w = data_args.image_resized_width
        self.image_resized_h = data_args.image_resized_height
        self.video_min_pixel = data_args.video_min_pixels
        self.video_max_pixel = data_args.video_max_pixels
        self.fps = data_args.fps
        self.anchor_model_id = anchor_model_id
        self.anchor_token_nums = get_token_num(anchor_model_id)
        self.anchor_tokens = get_anchor_token(anchor_model_id)
        self.anchor_task_names = get_anchor_task_name(anchor_model_id)
        # Incremented on every __getitem__. NOTE(review): with several DataLoader
        # workers each worker presumably advances its own copy; set_cur_step
        # exists to resynchronise — confirm how the trainer calls it.
        self.cur_step = 0
        self.stage_0_step = data_args.stage_0_step
        self.stage_1_step = data_args.stage_1_step
        self.stage_2_step = data_args.stage_2_step

        # for shuffle
        self.rng = np.random.default_rng(seed=random_seed)
        self._shuffle_perm = None
        if shuffle:
            if self._is_parquet:
                # Parquet rows are never materialised, so shuffle through an index permutation.
                self._shuffle_perm = np.random.RandomState(random_seed).permutation(self._total_rows)
            else:
                from datasets import Dataset as _HFDataset
                if isinstance(self.list_data_dict, _HFDataset):
                    self.list_data_dict = self.list_data_dict.shuffle(seed=random_seed)
                else:
                    self.rng.shuffle(self.list_data_dict)

    def set_cur_step(self, step: int):
        """Override the training-step counter that selects the data stage."""
        self.cur_step = step
        print(f"[Dataset] cur_step has been set to {step}")

    def __len__(self):
        if self._is_parquet:
            return self._total_rows
        return len(self.list_data_dict)

    def _get_parquet_row(self, idx):
        """Read row ``idx`` (through the shuffle permutation, if any) from the parquet files.

        Only the row group that holds the row is read. An ``image`` value that is
        a dict with ``bytes`` or ``path`` is decoded to an RGB ``PIL.Image``.
        """
        import pyarrow.parquet as pq, bisect, io
        real_idx = int(self._shuffle_perm[idx]) if self._shuffle_perm is not None else int(idx)
        # bisect to find file
        fi = bisect.bisect_right(self._cumulative_rows, real_idx)
        local_idx = real_idx - self._cumulative_rows[fi - 1] if fi > 0 else real_idx
        fpath = self._parquet_files[fi]
        # Open the ParquetFile lazily and cache both its handle and its
        # cumulative row-group sizes, so they are computed once per file.
        if fpath not in self._pf_handles:
            pf = pq.ParquetFile(fpath)
            cum = []
            s = 0
            for rg_idx in range(pf.num_row_groups):
                s += pf.metadata.row_group(rg_idx).num_rows
                cum.append(s)
            self._pf_handles[fpath] = pf
            self._row_group_index[fpath] = cum
        pf = self._pf_handles[fpath]
        cum = self._row_group_index[fpath]
        # Find which row group contains local_idx
        rg_idx = bisect.bisect_right(cum, local_idx)
        rg_start = 0 if rg_idx == 0 else cum[rg_idx - 1]
        in_rg_idx = local_idx - rg_start
        # Read just this row group (typically ~1k-10k rows, much smaller than file)
        table = pf.read_row_group(rg_idx)
        row = table.slice(in_rg_idx, 1).to_pylist()[0]
        # Image decode
        image_data = row.get('image')
        if isinstance(image_data, dict) and image_data.get('bytes'):
            row['image'] = Image.open(io.BytesIO(image_data['bytes'])).convert('RGB')
        elif isinstance(image_data, dict) and image_data.get("path"):
            row["image"] = Image.open(image_data["path"]).convert("RGB")
        return row

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build one tokenized training example.

        Only the first (user, assistant) pair of the conversation is used. Its
        text depends on ``self.cur_step``:

        * ``< stage_0_step``: original prompt, with anchor pads inserted after the
          vision-end token, and the original answer;
        * ``< stage_1_step``: a "What is the <anchors> of the image?" query
          answered with the anchor pads (``get_feature_data``);
        * ``< stage_2_step``: original prompt. When the prompt contains the image
          token, the answer is prefixed with a rationale over all anchors;
        * otherwise: with probability 1/6 the plain answer, else a rationale over
          a random non-empty subset of anchors.

        Returns ``input_ids``, ``attention_mask`` (all ones) and ``labels``
        (prompt positions set to ``IGNORE_INDEX``). Pixel tensors and grid sizes
        are added when the processor produced pixels, ``image_files`` for image
        samples, and ``second_per_grid_ts`` for Qwen2.5 video prompts.
        Increments ``self.cur_step``.
        """
        if self._is_parquet:
            sources = self._get_parquet_row(i)
        else:
            sources = self.list_data_dict[i]

        is_video = False
        processor = self.processor
        # Output keys of the visual tensors. Set only where the processor actually
        # produced pixels, so they always match the tensors collected below.
        pixel_key = None
        grid_key = None
        image_files = None  # loaded PIL images of an image sample
        images = None
        videos = None
        video_kwargs = {}

        # `.get(...) is not None` also covers parquet rows whose image/video column is null.
        if sources.get("image") is not None:
            image_files = sources["image"]
            if isinstance(image_files, Image.Image):
                # Already a PIL Image (e.g. from parquet HF dataset)
                image_files = [image_files.convert("RGB")]
            elif isinstance(image_files, str):
                image_files = [Image.open(image_files).convert("RGB")]
            elif isinstance(image_files, (bytes, bytearray)):
                import io as _io
                image_files = [Image.open(_io.BytesIO(image_files)).convert("RGB")]
            else:
                image_files = [
                    img_f.convert("RGB") if isinstance(img_f, Image.Image) else Image.open(img_f).convert("RGB")
                    for img_f in image_files
                ]
            images = [
                get_image_info(image_file, self.image_min_pixel, self.image_max_pixel, self.image_resized_w, self.image_resized_h)
                for image_file in image_files
            ]
        elif sources.get("video") is not None:
            is_video = True
            video_files = sources["video"]
            video_folder = self.data_args.image_folder
            if isinstance(video_files, str):
                video_files = [video_files]
            videos = []
            for video_file in video_files:
                # Relative, non-URL paths are resolved against data_args.image_folder.
                if not os.path.exists(video_file) and not video_file.startswith("http"):
                    video_file = os.path.join(video_folder, video_file)
                video_input, video_kwargs = get_video_info(video_file, self.video_min_pixel, self.video_max_pixel, self.data_args.fps)
                videos.append(video_input)

        if not images:
            # Video / text-only samples (images is None) and empty image lists get one
            # black placeholder image, presumably so prompts that contain the image
            # token (e.g. the stage-1 feature query) still have pixels.
            # NOTE(review): Image.new needs image_resized_width/height to be set — confirm.
            print("No image or video found in the data.")
            black_image = Image.new("RGB", (self.image_resized_w, self.image_resized_h), (0, 0, 0))
            images = [get_image_info(black_image, self.image_min_pixel, self.image_max_pixel, self.image_resized_w, self.image_resized_h)]

        sources = copy.deepcopy(llava_to_openai(sources['conversations'], is_video=is_video))

        all_input_ids = []
        all_labels = []
        all_pixel_values = []
        all_image_grid_thw = []
        all_second_grid = []

        # Qwen2-VL uses a default system message so I've added this.
        if len(SYSTEM_MESSAGE) > 0:
            system_message = f"{DEFAULT_IM_START_TOKEN}system\n{SYSTEM_MESSAGE}\n{DEFAULT_IM_END_TOKEN}\n"
            system_message_input_ids = processor.tokenizer(system_message, add_special_tokens=False, return_tensors='pt')['input_ids']
            system_labels = torch.full_like(system_message_input_ids, IGNORE_INDEX)
            all_input_ids.append(system_message_input_ids.squeeze(0))
            all_labels.append(system_labels.squeeze(0))

        # Only the first (user, assistant) pair is used.
        for j in range(0, min(len(sources), 2), 2):
            user_input = sources[j]
            gpt_response = sources[j + 1]

            if (DEFAULT_IMAGE_TOKEN not in user_input['content']) and (DEFAULT_VIDEO_TOKEN not in user_input['content']) and (LLAVA_IMAGE_TOKEN in user_input['content']):
                # NOTE(review): reached only when a raw LLaVA image token survived
                # llava_to_openai. The text is built but the branch always raises —
                # a deliberate tripwire, kept as-is.
                user_input = f"{DEFAULT_IM_START_TOKEN}{VISION_START_TOKEN + DEFAULT_IMAGE_TOKEN + VISION_END_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
                user_input = add_anchor_pad(user_input, self.anchor_token_nums, self.anchor_tokens)
                gpt_response = f"{gpt_response['content']}\n{DEFAULT_IM_END_TOKEN}\n"
                raise ValueError('Every man is a poet when he is in love')
            elif self.cur_step < self.stage_0_step:
                user_input = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
                user_input = add_anchor_pad(user_input, self.anchor_token_nums, self.anchor_tokens)
                gpt_response = f"{gpt_response['content']}\n{DEFAULT_IM_END_TOKEN}\n"
            elif self.cur_step < self.stage_1_step:
                user_input, gpt_response = get_feature_data(user_input, gpt_response, self.anchor_token_nums, self.anchor_tokens, self.anchor_task_names)
            elif self.cur_step < self.stage_2_step:
                user_input = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
                gpt_response = f"{gpt_response['content']}"
                if DEFAULT_IMAGE_TOKEN in user_input:
                    gpt_response = get_comt_data_in_response(gpt_response, self.anchor_token_nums, self.anchor_tokens, self.anchor_task_names)
                gpt_response = f"{gpt_response}\n{DEFAULT_IM_END_TOKEN}\n"
            else:
                # randint is inclusive: 1 chance in 6 of the plain answer.
                xxx = random.randint(0, 5)
                if xxx == 0:
                    user_input = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
                    gpt_response = f"{gpt_response['content']}\n{DEFAULT_IM_END_TOKEN}\n"
                else:
                    user_input = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
                    gpt_response = f"{gpt_response['content']}"
                    if DEFAULT_IMAGE_TOKEN in user_input:
                        # Rationale over a random non-empty subset of anchors, in original order.
                        total = len(self.anchor_tokens)
                        if total == 0:
                            selected_anchor_token_nums = []
                            selected_anchor_tokens = []
                            selected_anchor_task_names = []
                        else:
                            x = random.randint(1, total)
                            idxs = sorted(random.sample(range(total), x))
                            selected_anchor_token_nums = [self.anchor_token_nums[k] for k in idxs]
                            selected_anchor_tokens = [self.anchor_tokens[k] for k in idxs]
                            selected_anchor_task_names = [self.anchor_task_names[k] for k in idxs]
                        gpt_response = get_comt_data_in_response(gpt_response, selected_anchor_token_nums, selected_anchor_tokens, selected_anchor_task_names)
                    gpt_response = f"{gpt_response}\n{DEFAULT_IM_END_TOKEN}\n"

            if DEFAULT_IMAGE_TOKEN in user_input:
                inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                # The prompt holds an image token, so the image tensors are the relevant ones.
                pixel_key, grid_key = "pixel_values", "image_grid_thw"
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])
            elif DEFAULT_VIDEO_TOKEN in user_input:
                if "Qwen2.5" in self.model_id:
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt', **video_kwargs)
                    all_second_grid.extend(inputs["second_per_grid_ts"])
                else:
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                pixel_key, grid_key = "pixel_values_videos", "video_grid_thw"
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])
            else:
                prompt_input_ids = processor.tokenizer(user_input, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']

            response_input_ids = processor.tokenizer(gpt_response, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']
            input_ids = torch.cat([prompt_input_ids, response_input_ids], dim=1).squeeze(0)
            # Loss is computed on the response only.
            labels = torch.cat(
                [
                    torch.full((prompt_input_ids.size(1),), IGNORE_INDEX, dtype=torch.long),
                    response_input_ids.squeeze(0),
                ],
                dim=0,
            )
            all_input_ids.append(input_ids)
            all_labels.append(labels)

        # There is no need for eos or bos tokens in the input_ids
        # Qwen2-VL does not use them
        input_ids = torch.cat(all_input_ids, dim=0).to(torch.long)
        labels = torch.cat(all_labels, dim=0).to(torch.long)
        attention_mask = torch.ones_like(input_ids)

        data_dict = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        # Skip when no processor call produced pixels (torch.cat([]) would raise).
        if pixel_key and grid_key and all_pixel_values:
            data_dict[pixel_key] = torch.cat(all_pixel_values, dim=0)
            data_dict[grid_key] = torch.cat(all_image_grid_thw, dim=0)
            # Video samples have no image_files.
            if image_files is not None:
                data_dict["image_files"] = image_files
        if len(all_second_grid) > 0:
            data_dict["second_per_grid_ts"] = all_second_grid
        self.cur_step += 1
        return data_dict
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Right-pads ``input_ids`` (with ``pad_token_id``) and ``labels`` (with
    ``IGNORE_INDEX``). Concatenates image or video pixel tensors and their grid
    sizes along dim 0, and gathers ``second_per_grid_ts`` and ``image_files``
    when examples provide them.
    """

    def __init__(self, pad_token_id: int):
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        batch_input_ids = []
        batch_label_ids = []
        batch_pixel_values = []
        batch_pixel_video_values = []
        batch_video_thw = []
        batch_image_thw = []
        batch_second_per_grid_ts = []
        batch_image_files = []

        for example in examples:
            keys = example.keys()
            # Video pixels take precedence; image pixels are read only when no video pixels exist.
            if "pixel_values_videos" in keys:
                batch_pixel_video_values.append(example["pixel_values_videos"])
                batch_video_thw.append(example["video_grid_thw"])
            elif "pixel_values" in keys:
                batch_pixel_values.append(example["pixel_values"])
                batch_image_thw.append(example["image_grid_thw"])
            if "image_files" in keys:
                batch_image_files.append(example["image_files"])
            batch_input_ids.append(example["input_ids"])
            batch_label_ids.append(example["labels"])
            if "second_per_grid_ts" in keys:
                batch_second_per_grid_ts.extend(example["second_per_grid_ts"])

        # The attention mask below is built from sequence lengths, so the pad id
        # only fills the tail. Fall back to 0 when the tokenizer has no pad token.
        pad_value = self.pad_token_id if self.pad_token_id is not None else 0
        input_ids = pad_sequence(
            batch_input_ids, padding_side='right', padding_value=pad_value
        )
        # Length-based mask: a genuine occurrence of the pad token inside a
        # sequence must stay visible, which `input_ids != pad_token_id` would hide.
        attention_mask = torch.zeros(input_ids.shape, dtype=torch.bool, device=input_ids.device)
        for row, ids in enumerate(batch_input_ids):
            attention_mask[row, :ids.size(0)] = True
        labels = pad_sequence(batch_label_ids, padding_side='right', padding_value=IGNORE_INDEX)

        data_dict = {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
        }
        if len(batch_pixel_values) > 0:
            data_dict["pixel_values"] = torch.cat(batch_pixel_values, dim=0)
            data_dict["image_grid_thw"] = torch.cat(batch_image_thw, dim=0)
        if len(batch_pixel_video_values) > 0:
            data_dict["pixel_values_videos"] = torch.cat(batch_pixel_video_values, dim=0)
            data_dict["video_grid_thw"] = torch.cat(batch_video_thw, dim=0)
        if len(batch_second_per_grid_ts) > 0:
            data_dict["second_per_grid_ts"] = batch_second_per_grid_ts
        if len(batch_image_files) > 0:
            data_dict["image_files"] = batch_image_files
        return data_dict
def replace_image_tokens(input_string, is_video=False):
    """Swap LLaVA's image/video placeholder for the Qwen vision sequence.

    Each LLaVA token, together with at most one adjacent newline on either side,
    becomes ``VISION_START_TOKEN + <Qwen image|video token> + VISION_END_TOKEN``.
    """
    llava_token, qwen_token = (
        (LLAVA_VIDEO_TOKEN, DEFAULT_VIDEO_TOKEN) if is_video
        else (LLAVA_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN)
    )
    pattern = r'\n?' + re.escape(llava_token) + r'\n?'
    replacement = VISION_START_TOKEN + qwen_token + VISION_END_TOKEN
    return re.sub(pattern, replacement, input_string)
def llava_to_openai(conversations, is_video=False):
    """Convert LLaVA ``{"from", "value"}`` turns into ``{"role", "content"}`` messages.

    ``human`` maps to ``user`` and ``gpt`` maps to ``assistant``. Any other
    speaker is passed through. Vision placeholders in the text are rewritten by
    ``replace_image_tokens``.
    """
    speaker_to_role = {"human": "user", "gpt": "assistant"}

    def convert_turn(turn):
        content = replace_image_tokens(turn["value"], is_video=is_video)
        speaker = turn["from"]
        return {"role": speaker_to_role.get(speaker, speaker), "content": content}

    return [convert_turn(turn) for turn in conversations]
def make_supervised_data_module(model_id, processor, data_args, anchor_model_id):
    """Build the training dataset and its collator for supervised fine-tuning.

    Returns a dict with ``train_dataset`` (a ``SupervisedDataset`` over
    ``data_args.data_path``), ``eval_dataset`` (always ``None``) and
    ``data_collator`` (padding with the tokenizer's pad id).
    """
    train_dataset = SupervisedDataset(
        data_path=data_args.data_path,
        processor=processor,
        data_args=data_args,
        model_id=model_id,
        anchor_model_id=anchor_model_id,
    )
    collator = DataCollatorForSupervisedDataset(pad_token_id=processor.tokenizer.pad_token_id)
    return {
        "train_dataset": train_dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }