|
|
import copy |
|
|
import os |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Dict |
|
|
import torch |
|
|
import transformers |
|
|
import ujson as json |
|
|
from torch.utils.data import Dataset |
|
|
from qwen_vl_utils import process_vision_info |
|
|
from PIL import Image |
|
|
import re |
|
|
import yaml |
|
|
import random |
|
|
import math |
|
|
import pprint |
|
|
|
|
|
from .params import DataArguments |
|
|
from .constants import * |
|
|
|
|
|
|
|
|
def truncate_sequence(input_ids, labels, max_length, eos_token_id):
    """Clip ``input_ids``/``labels`` to ``max_length`` tokens.

    When the sequence is over-long it is cut to ``max_length - 1`` tokens
    and, if ``eos_token_id`` is not None, a single EOS token is appended to
    both the inputs and the labels so the truncated sample still terminates
    cleanly. Sequences that already fit are returned unchanged.
    """
    if input_ids.size(0) > max_length:
        # Keep max_length - 1 tokens to leave room for the EOS token below.
        input_ids = input_ids[:max_length-1]
        labels = labels[:max_length-1]

        if eos_token_id is not None:
            eos = torch.tensor([eos_token_id])
            input_ids = torch.cat([input_ids, eos])
            labels = torch.cat([labels, eos])

    return input_ids, labels
|
|
|
|
|
def pad_sequence(sequences, padding_side='right', padding_value=0):
    """
    Pad a list of variable-length tensors into one batch tensor.

    Args:
        sequences: list of tensors shaped [seq_len, *trailing_dims]; the
            trailing dims are taken from the first sequence and assumed to
            match across all of them.
        padding_side: 'right' pads after the data, 'left' pads before it.
        padding_value: fill value for padded positions.

    Returns:
        Tensor of shape [batch_size, max_len, *trailing_dims] with the same
        dtype/device as the first sequence.
    """
    assert padding_side in ['right', 'left']
    trailing_dims = sequences[0].size()[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    # new_full inherits dtype/device from the first sequence.
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        # FIX: write via normal indexing instead of the deprecated `.data`
        # attribute, which silently bypasses autograd bookkeeping.
        if padding_side == 'right':
            output[i, :length] = seq
        else:
            output[i, -length:] = seq
    return output
|
|
|
|
|
def get_image_info(image_path, min_pixel, max_pixel):
    """Load and preprocess a single image through qwen_vl_utils.

    Args:
        image_path: local path or URL of the image.
        min_pixel: lower bound on total pixels after resizing.
        max_pixel: upper bound on total pixels after resizing.

    Returns:
        The first preprocessed image produced by process_vision_info.
    """
    # Wrap the image in the one-turn message format process_vision_info
    # expects.
    # BUGFIX: the keys must be the plural "min_pixels"/"max_pixels" —
    # qwen_vl_utils reads those names, so the singular forms were silently
    # ignored and the resize bounds never applied (get_video_info below
    # already uses the plural keys).
    messages = [
        {"role": "user",
         "content": [
             {
                 "type": "image",
                 "image": image_path,
                 "min_pixels": min_pixel,
                 "max_pixels": max_pixel
             }
         ]
         }
    ]

    image_input, _ = process_vision_info(messages)

    return image_input[0]
|
|
|
|
|
def get_video_info(video_path, min_pixels, max_pixels, fps):
    """Load and preprocess a single video through qwen_vl_utils.

    Args:
        video_path: local path or URL of the video.
        min_pixels: lower bound on total pixels per frame after resizing.
        max_pixels: upper bound on total pixels per frame after resizing.
        fps: sampling rate passed through to the video loader.

    Returns:
        Tuple of (first preprocessed video, extra video kwargs) as produced
        by process_vision_info with return_video_kwargs=True.
    """
    # Build the one-turn message format process_vision_info expects.
    video_entry = {
        "type": "video",
        "video": video_path,
        "min_pixels": min_pixels,
        "max_pixels": max_pixels,
        "fps": fps
    }
    messages = [{"role": "user", "content": [video_entry]}]

    _, video_input, video_kwargs = process_vision_info(messages, return_video_kwargs=True)

    return video_input[0], video_kwargs
|
|
|
|
|
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Loads LLaVA-style conversation samples (optionally with images or
    videos), converts them to the Qwen chat format, and tokenizes each
    user/assistant turn pair into input_ids/labels with the prompt part
    masked out via IGNORE_INDEX.
    """

    def __init__(
        self,
        data_path: str | list,
        processor: transformers.ProcessorMixin,
        data_args: DataArguments,
        model_id,
        padding=True,
    ):
        super(SupervisedDataset, self).__init__()
        # data_path may be: a .json file of samples, a .yaml manifest listing
        # several json/jsonl datasets with optional sub-sampling, or an
        # already-loaded list of sample dicts.
        if isinstance(data_path, str):
            if data_path.endswith(".json"):
                list_data_dict = json.load(open(data_path, "r"))

            elif data_path.endswith(".yaml"):
                list_data_dict = []
                with open(data_path, "r") as file:
                    yaml_data = yaml.safe_load(file)
                    pprint.pprint(yaml_data)
                    datasets = yaml_data.get("datasets")
                    # Expose the component json paths on data_args so callers
                    # can see which files made up the mixture.
                    data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
                    for dataset in datasets:
                        json_path = dataset.get("json_path")
                        # sampling_strategy is "all" or "<first|end|random>:N",
                        # where N may also be a percentage like "50%".
                        sampling_strategy = dataset.get("sampling_strategy", "all")
                        sampling_number = None

                        print(f"Loading {json_path} with {sampling_strategy} sampling strategy")

                        if json_path.endswith(".jsonl"):
                            # jsonl: one JSON object per line.
                            cur_data_dict = []
                            with open(json_path, "r") as json_file:
                                for line in json_file:
                                    cur_data_dict.append(json.loads(line.strip()))
                        elif json_path.endswith(".json"):
                            with open(json_path, "r") as json_file:
                                cur_data_dict = json.load(json_file)
                        else:
                            raise ValueError(f"Unsupported file type: {json_path}")

                        if ":" in sampling_strategy:
                            sampling_strategy, sampling_number = sampling_strategy.split(":")
                            if "%" in sampling_number:
                                # Percentage of the current dataset, rounded up.
                                sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                            else:
                                sampling_number = int(sampling_number)

                        # Apply the requested sub-sampling to this dataset.
                        if sampling_strategy == "first" and sampling_number is not None:
                            cur_data_dict = cur_data_dict[:sampling_number]
                        elif sampling_strategy == "end" and sampling_number is not None:
                            cur_data_dict = cur_data_dict[-sampling_number:]
                        elif sampling_strategy == "random" and sampling_number is not None:
                            # NOTE(review): random.shuffle is unseeded here, so
                            # "random" sampling is not reproducible across runs.
                            random.shuffle(cur_data_dict)
                            cur_data_dict = cur_data_dict[:sampling_number]

                        print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                        list_data_dict.extend(cur_data_dict)
                print(f"Loaded {len(list_data_dict)} samples from {data_path} in total")
        else:
            # Already a list of sample dicts.
            list_data_dict = data_path

        self.model_id = model_id
        self.processor = processor
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.padding = padding
        # Vision preprocessing bounds (total pixel counts) and video fps.
        self.image_min_pixel = data_args.image_min_pixels
        self.image_max_pixel = data_args.image_max_pixels
        self.video_min_pixel = data_args.video_min_pixels
        self.video_max_pixel = data_args.video_max_pixels
        self.fps = data_args.fps

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Tokenize sample i into model inputs.

        Returns a dict with input_ids, attention_mask, labels, and — for
        image/video samples — the corresponding pixel values and grid_thw
        tensors (plus second_per_grid_ts for Qwen2.5 video inputs).
        """
        sources = self.list_data_dict[i]

        is_video = False

        processor = self.processor
        if "image" in sources:
            videos = None
            grid_key = "image_grid_thw"
            pixel_key = "pixel_values"

            image_files = sources["image"]
            image_folder = self.data_args.image_folder

            # A single path is normalized to a one-element list.
            if isinstance(image_files, str):
                image_files = [image_files]

            images = []

            for image_file in image_files:
                if not os.path.exists(image_file):
                    if not image_file.startswith("http"):
                        # NOTE(review): dataset-specific path rewrite — paths
                        # containing 'share/world_model/' are re-rooted under
                        # image_folder; confirm this matches your data layout.
                        if 'share' in image_file:
                            image_file = image_file.split('share/world_model/')[1]
                        image_file = os.path.join(image_folder, image_file)
                images.append(get_image_info(image_file, self.image_min_pixel, self.image_max_pixel))

        elif "video" in sources:
            is_video = True
            images=None
            grid_key = "video_grid_thw"
            pixel_key = "pixel_values_videos"

            video_files = sources["video"]
            # Videos share the same root folder setting as images.
            video_folder = self.data_args.image_folder

            if isinstance(video_files, str):
                video_files = [video_files]

            videos = []
            for video_file in video_files:
                if not os.path.exists(video_file):
                    if not video_file.startswith("http"):
                        # Same dataset-specific path rewrite as for images.
                        if 'share' in video_file:
                            video_file = video_file.split('share/world_model/')[1]
                        video_file = os.path.join(video_folder, video_file)
                # video_kwargs (e.g. effective fps) from the last video is
                # reused when calling the processor below.
                video_input, video_kwargs = get_video_info(video_file, self.video_min_pixel, self.video_max_pixel, self.data_args.fps)
                videos.append(video_input)
        else:
            # Text-only sample: no vision tensors.
            grid_key = None
            pixel_key = None
            images=None
            videos=None

        # Convert LLaVA-format turns to OpenAI chat format (deep copy so the
        # cached sample list is never mutated).
        sources = copy.deepcopy(llava_to_openai(sources['conversations'], is_video=is_video))

        all_input_ids = []
        all_labels = []
        all_pixel_values = []
        all_image_grid_thw = []
        all_second_gird = []

        # Optional system prompt, fully masked out of the loss.
        if len(SYSTEM_MESSAGE) > 0:
            system_message = f"{DEFAULT_IM_START_TOKEN}system\n{SYSTEM_MESSAGE}\n{DEFAULT_IM_END_TOKEN}\n"
            system_message_input_ids = processor.tokenizer(system_message, add_special_tokens=False, return_tensors='pt')['input_ids']
            system_labels = torch.full_like(system_message_input_ids, IGNORE_INDEX)

            all_input_ids.append(system_message_input_ids.squeeze(0))
            all_labels.append(system_labels.squeeze(0))

        # Turns come in (user, assistant) pairs; process each pair.
        for _, j in enumerate(range(0, len(sources), 2)):
            user_input = sources[j]
            gpt_response = sources[j + 1]

            # Prompt text ends with the assistant role header so the response
            # tokens start immediately after it.
            user_input = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
            gpt_response = f"{gpt_response['content']}\n{DEFAULT_IM_END_TOKEN}\n"

            if DEFAULT_IMAGE_TOKEN in user_input:
                # Image turn: the processor expands the image token and emits
                # pixel values plus grid shapes.
                inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])

            elif DEFAULT_VIDEO_TOKEN in user_input:
                if "Qwen2.5" in self.model_id:
                    # Qwen2.5 video processing takes the extra kwargs from
                    # get_video_info and returns per-grid timestamps.
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt', **video_kwargs)
                    all_second_gird.extend(inputs["second_per_grid_ts"])
                else:
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])

            else:
                # Text-only turn: tokenizer alone is enough.
                prompt_input_ids = processor.tokenizer(user_input, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']

            response_input_ids = processor.tokenizer(gpt_response, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']

            input_ids = torch.cat([prompt_input_ids, response_input_ids], dim=1).squeeze(0)
            # Loss is computed only on the response: prompt positions get
            # IGNORE_INDEX.
            labels = torch.cat(
                [
                    torch.tensor([IGNORE_INDEX] * len(prompt_input_ids[0])),
                    response_input_ids.squeeze(0),
                ],
                dim=0,
            )

            all_input_ids.append(input_ids)
            all_labels.append(labels)

        input_ids = torch.cat(all_input_ids, dim=0).to(torch.long)
        labels = torch.cat(all_labels, dim=0).to(torch.long)

        # All real token ids exceed -1000000, so this yields an all-ones mask
        # (padding happens later in the collator).
        attention_mask = (input_ids > -1000000).to(torch.long)

        data_dict = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        # Attach vision tensors only for image/video samples.
        if pixel_key and grid_key:
            pixel_values = torch.cat(all_pixel_values, dim=0)
            image_thw = torch.cat(all_image_grid_thw, dim=0)
            data_dict[pixel_key] = pixel_values
            data_dict[grid_key] = image_thw

        if len(all_second_gird) > 0:
            second_gird = all_second_gird
            data_dict["second_per_grid_ts"] = second_gird

        return data_dict
|
|
|
|
|
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Pads input_ids/labels to a common length and concatenates any vision
    tensors (image or video pixel values and their grid shapes) across the
    batch.
    """

    def __init__(self, pad_token_id: int):
        # Token id used to right-pad input_ids; padded positions are also
        # excluded from the attention mask below.
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        input_id_list = []
        label_list = []
        image_pixels = []
        video_pixels = []
        video_grids = []
        image_grids = []
        second_ts = []

        # Gather per-example tensors, splitting video and image vision data.
        for example in examples:
            if "pixel_values_videos" in example:
                video_pixels.append(example["pixel_values_videos"])
                video_grids.append(example["video_grid_thw"])
            elif "pixel_values" in example:
                image_pixels.append(example["pixel_values"])
                image_grids.append(example["image_grid_thw"])

            input_id_list.append(example["input_ids"])
            label_list.append(example["labels"])

            if "second_per_grid_ts" in example:
                second_ts.extend(example["second_per_grid_ts"])

        input_ids = pad_sequence(
            input_id_list, padding_side='right', padding_value=self.pad_token_id
        )

        # Mask out exactly the positions filled with the pad token.
        attention_mask = input_ids != self.pad_token_id
        labels = pad_sequence(label_list, padding_side='right', padding_value=IGNORE_INDEX)

        data_dict = {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
        }

        # Vision tensors are concatenated along dim 0 across the batch.
        if image_pixels:
            data_dict["pixel_values"] = torch.cat(image_pixels, dim=0)
            data_dict["image_grid_thw"] = torch.cat(image_grids, dim=0)

        if video_pixels:
            data_dict["pixel_values_videos"] = torch.cat(video_pixels, dim=0)
            data_dict["video_grid_thw"] = torch.cat(video_grids, dim=0)

        if second_ts:
            data_dict["second_per_grid_ts"] = second_ts

        return data_dict
|
|
|
|
|
|
|
|
def replace_image_tokens(input_string, is_video=False):
    """Swap LLaVA-style media placeholders for Qwen vision-token spans.

    Each occurrence of the LLaVA image (or video) token, together with any
    directly adjacent newlines, is replaced by the corresponding default
    media token wrapped in vision start/end markers.
    """
    # Pick the source placeholder and its replacement token by modality.
    if is_video:
        source_token = LLAVA_VIDEO_TOKEN
        target_token = DEFAULT_VIDEO_TOKEN
    else:
        source_token = LLAVA_IMAGE_TOKEN
        target_token = DEFAULT_IMAGE_TOKEN

    # Optional surrounding newlines are consumed by the substitution.
    pattern = r'\n?' + re.escape(source_token) + r'\n?'
    replacement = VISION_START_TOKEN + target_token + VISION_END_TOKEN

    return re.sub(pattern, replacement, input_string)
|
|
|
|
|
def llava_to_openai(conversations, is_video=False):
    """Convert LLaVA-format conversation turns to OpenAI chat format.

    Maps the 'human'/'gpt' speaker tags to 'user'/'assistant' (unknown tags
    pass through unchanged) and rewrites media placeholder tokens in each
    turn's text via replace_image_tokens.
    """
    role_mapping = {"human": "user", "gpt": "assistant"}

    return [
        {
            "role": role_mapping.get(turn["from"], turn["from"]),
            "content": replace_image_tokens(turn["value"], is_video=is_video),
        }
        for turn in conversations
    ]
|
|
|
|
|
def make_supervised_data_module(model_id, processor, data_args):
    """Make dataset and collator for supervised fine-tuning.

    Builds a SupervisedDataset from data_args.data_path and pairs it with a
    collator that pads using the processor tokenizer's pad token. Returns a
    dict with 'train_dataset', 'eval_dataset' (always None here), and
    'data_collator', suitable for passing to a HuggingFace Trainer.
    """
    dataset = SupervisedDataset(
        data_path=data_args.data_path,
        processor=processor,
        data_args=data_args,
        model_id=model_id,
    )
    collator = DataCollatorForSupervisedDataset(pad_token_id=processor.tokenizer.pad_token_id)

    return dict(
        train_dataset=dataset,
        eval_dataset=None,
        data_collator=collator,
    )