# Source: models/src/training/data.py (author: litwell)
# Uploaded with huggingface_hub, revision 7c453be (verified).
import copy
import os
from dataclasses import dataclass, field
from typing import Dict
import torch
import transformers
import ujson as json
from torch.utils.data import Dataset
from qwen_vl_utils import process_vision_info
from PIL import Image
import re
import yaml
import random
import math
import pprint
from .params import DataArguments
from .constants import *
def truncate_sequence(input_ids, labels, max_length, eos_token_id):
    """Truncate ``input_ids``/``labels`` to at most ``max_length`` tokens.

    When the sequence is longer than ``max_length``, both tensors are cut to
    ``max_length - 1`` and, if ``eos_token_id`` is given, the EOS id is
    appended to both so the truncated sequence still terminates properly
    (total length ``max_length``). Sequences at or under the limit are
    returned unchanged.

    Args:
        input_ids: 1-D tensor of token ids.
        labels: 1-D tensor of label ids, same length as ``input_ids``.
        max_length: maximum allowed sequence length.
        eos_token_id: id to re-append after truncation, or ``None`` to skip.

    Returns:
        Tuple ``(input_ids, labels)`` after truncation.
    """
    if input_ids.size(0) > max_length:
        input_ids = input_ids[:max_length - 1]
        labels = labels[:max_length - 1]
        if eos_token_id is not None:
            # new_tensor keeps the dtype/device of the source tensors instead
            # of defaulting to int64-on-CPU like torch.tensor([...]) would,
            # which broke concatenation for sequences living on other devices.
            input_ids = torch.cat([input_ids, input_ids.new_tensor([eos_token_id])])
            labels = torch.cat([labels, labels.new_tensor([eos_token_id])])
    return input_ids, labels
def pad_sequence(sequences, padding_side='right', padding_value=0):
    """Pad a list of variable-length tensors into one batch tensor.

    Args:
        sequences: list of tensors shaped ``[seq_len, *trailing_dims]``;
            trailing dims must match across sequences (taken from the first).
        padding_side: ``'right'`` or ``'left'`` — which side gets the padding.
        padding_value: fill value for the padded positions.

    Returns:
        Tensor of shape ``[batch, max_len, *trailing_dims]`` with the same
        dtype/device as the first sequence.
    """
    assert padding_side in ['right', 'left']
    trailing_dims = sequences[0].size()[1:]
    max_len = max(seq.size(0) for seq in sequences)
    batch_size = len(sequences)
    # new_full inherits dtype/device from the first sequence.
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        # Write through the tensor directly: the legacy `.data` accessor is
        # unnecessary here and bypasses autograd's versioning checks.
        if padding_side == 'right':
            output[i, :length] = seq
        else:
            output[i, -length:] = seq
    return output
def get_image_info(image_path, min_pixel, max_pixel):
    """Load one image through ``qwen_vl_utils.process_vision_info``.

    The path is wrapped in a single-turn chat message because
    ``process_vision_info`` only accepts the message format.
    # Need to fix this in the future.

    Fix: the content keys must be "min_pixels"/"max_pixels" (plural), matching
    both qwen_vl_utils and the sibling ``get_video_info``. The previous
    singular "min_pixel"/"max_pixel" keys were ignored, so the pixel budget
    was never applied to images.

    Args:
        image_path: local path or URL understood by process_vision_info.
        min_pixel: minimum pixel budget for smart resizing.
        max_pixel: maximum pixel budget for smart resizing.

    Returns:
        The single processed image (first element of the image inputs).
    """
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path,
                    "min_pixels": min_pixel,
                    "max_pixels": max_pixel,
                }
            ],
        }
    ]
    image_input, _ = process_vision_info(messages)
    return image_input[0]
def get_video_info(video_path, min_pixels, max_pixels, fps):
    """Load one video through ``qwen_vl_utils.process_vision_info``.

    The path is wrapped in a single-turn chat message because
    ``process_vision_info`` only accepts the message format.
    # Need to fix this in the future.

    Args:
        video_path: local path or URL understood by process_vision_info.
        min_pixels: minimum pixel budget for smart resizing.
        max_pixels: maximum pixel budget for smart resizing.
        fps: frames-per-second sampling rate for decoding.

    Returns:
        Tuple of the single processed video input and the extra video kwargs
        (e.g. timing metadata) returned by process_vision_info.
    """
    video_entry = {
        "type": "video",
        "video": video_path,
        "min_pixels": min_pixels,
        "max_pixels": max_pixels,
        "fps": fps,
    }
    messages = [{"role": "user", "content": [video_entry]}]
    _, video_input, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
    return video_input[0], video_kwargs
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Loads LLaVA-style conversation samples (from a .json file, a .yaml
    manifest of json/jsonl files with per-file sampling strategies, or an
    already-loaded list), converts each sample to the Qwen chat format, and
    tokenizes it — together with any image or video inputs — into
    input_ids / labels / pixel tensors for the trainer.
    """

    def __init__(
        self,
        data_path: str | list,
        processor: transformers.ProcessorMixin,
        data_args: DataArguments,
        model_id,
        padding=True,
    ):
        """Load and optionally subsample the raw conversation data.

        Args:
            data_path: path to a .json file, a .yaml manifest, or a list of
                sample dicts (used as-is).
            processor: HF processor providing tokenizer + vision preprocessing.
            data_args: data configuration (folders, pixel budgets, fps).
            model_id: model name string; "Qwen2.5" models get extra video
                kwargs at tokenization time in __getitem__.
            padding: stored on the instance; not otherwise read in this class.
        """
        super(SupervisedDataset, self).__init__()
        if isinstance(data_path, str):
            if data_path.endswith(".json"):
                list_data_dict = json.load(open(data_path, "r"))
            # handle .yaml for multiple json and sampling strategy
            elif data_path.endswith(".yaml"):
                list_data_dict = []
                with open(data_path, "r") as file:
                    yaml_data = yaml.safe_load(file)
                # Debug dump of the parsed manifest.
                pprint.pprint(yaml_data)
                datasets = yaml_data.get("datasets")
                # file should be in the format of:
                # datasets:
                #   - json_path: xxxx1.json
                #     sampling_strategy: first:1000
                #   - json_path: xxxx2.json
                #     sampling_strategy: end:3000
                #   - json_path: xxxx3.json
                #     sampling_strategy: random:999
                # Record the component paths on data_args for later reference.
                data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
                for dataset in datasets:
                    json_path = dataset.get("json_path")
                    sampling_strategy = dataset.get("sampling_strategy", "all")
                    sampling_number = None
                    print(f"Loading {json_path} with {sampling_strategy} sampling strategy")
                    if json_path.endswith(".jsonl"):
                        # jsonl: one sample dict per line.
                        cur_data_dict = []
                        with open(json_path, "r") as json_file:
                            for line in json_file:
                                cur_data_dict.append(json.loads(line.strip()))
                    elif json_path.endswith(".json"):
                        with open(json_path, "r") as json_file:
                            cur_data_dict = json.load(json_file)
                    else:
                        raise ValueError(f"Unsupported file type: {json_path}")
                    # Strategy is "kind:count" or "kind:percent%"; a percent
                    # is converted to an absolute count (rounded up).
                    if ":" in sampling_strategy:
                        sampling_strategy, sampling_number = sampling_strategy.split(":")
                        if "%" in sampling_number:
                            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                        else:
                            sampling_number = int(sampling_number)
                    # Apply the sampling strategy ("all" or an unknown kind
                    # falls through and keeps every sample).
                    if sampling_strategy == "first" and sampling_number is not None:
                        cur_data_dict = cur_data_dict[:sampling_number]
                    elif sampling_strategy == "end" and sampling_number is not None:
                        cur_data_dict = cur_data_dict[-sampling_number:]
                    elif sampling_strategy == "random" and sampling_number is not None:
                        # NOTE(review): shuffles in place with the global RNG;
                        # not reproducible unless the caller seeds `random`.
                        random.shuffle(cur_data_dict)
                        cur_data_dict = cur_data_dict[:sampling_number]
                    print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                    list_data_dict.extend(cur_data_dict)
                print(f"Loaded {len(list_data_dict)} samples from {data_path} in total")
        else:
            # data_path is already a list of sample dicts.
            list_data_dict = data_path
        self.model_id = model_id
        self.processor = processor
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.padding = padding
        # Pixel budgets and fps forwarded to get_image_info / get_video_info.
        self.image_min_pixel = data_args.image_min_pixels
        self.image_max_pixel = data_args.image_max_pixels
        self.video_min_pixel = data_args.video_min_pixels
        self.video_max_pixel = data_args.video_max_pixels
        self.fps = data_args.fps

    def __len__(self):
        """Number of conversation samples."""
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Tokenize sample ``i`` into model inputs.

        Returns a dict with input_ids, attention_mask and labels, plus — when
        the sample carries media — pixel_values / image_grid_thw (images) or
        pixel_values_videos / video_grid_thw (videos), and second_per_grid_ts
        for Qwen2.5 video samples.
        """
        sources = self.list_data_dict[i]
        is_video = False
        processor = self.processor
        if "image" in sources:
            videos = None
            grid_key = "image_grid_thw"
            pixel_key = "pixel_values"
            image_files = sources["image"]
            image_folder = self.data_args.image_folder
            if isinstance(image_files, str):
                image_files = [image_files]
            images = []
            for image_file in image_files:
                # Resolve relative paths against image_folder; URLs and
                # already-existing absolute paths are used as-is.
                if not os.path.exists(image_file):
                    if not image_file.startswith("http"):
                        # NOTE(review): dataset-specific path rewrite —
                        # strips everything up to 'share/world_model/';
                        # confirm this prefix matches your data layout.
                        if 'share' in image_file:
                            image_file = image_file.split('share/world_model/')[1]
                        image_file = os.path.join(image_folder, image_file)
                images.append(get_image_info(image_file, self.image_min_pixel, self.image_max_pixel))
        elif "video" in sources:
            is_video = True
            images=None
            grid_key = "video_grid_thw"
            pixel_key = "pixel_values_videos"
            video_files = sources["video"]
            # NOTE(review): videos are resolved against image_folder (no
            # separate video folder) — presumably intentional; verify.
            video_folder = self.data_args.image_folder
            if isinstance(video_files, str):
                video_files = [video_files]
            videos = []
            for video_file in video_files:
                if not os.path.exists(video_file):
                    if not video_file.startswith("http"):
                        if 'share' in video_file:
                            video_file = video_file.split('share/world_model/')[1]
                        video_file = os.path.join(video_folder, video_file)
                # video_kwargs from the last video is reused below for the
                # Qwen2.5 processor call.
                video_input, video_kwargs = get_video_info(video_file, self.video_min_pixel, self.video_max_pixel, self.data_args.fps)
                videos.append(video_input)
        else:
            # Text-only sample: no media tensors are produced.
            grid_key = None
            pixel_key = None
            images=None
            videos=None
        # Convert LLaVA "conversations" into OpenAI-style role/content turns.
        sources = copy.deepcopy(llava_to_openai(sources['conversations'], is_video=is_video))
        all_input_ids = []
        all_labels = []
        # For video samples these hold pixel_values_videos / video_grid_thw
        # despite the image-flavored names (keyed by pixel_key / grid_key).
        all_pixel_values = []
        all_image_grid_thw = []
        all_second_gird = []  # sic: "gird" — second_per_grid_ts values (Qwen2.5 video)
        # Qwen2-VL uses a default system message so I've added this.
        if len(SYSTEM_MESSAGE) > 0:
            system_message = f"{DEFAULT_IM_START_TOKEN}system\n{SYSTEM_MESSAGE}\n{DEFAULT_IM_END_TOKEN}\n"
            system_message_input_ids = processor.tokenizer(system_message, add_special_tokens=False, return_tensors='pt')['input_ids']
            # System tokens are never trained on.
            system_labels = torch.full_like(system_message_input_ids, IGNORE_INDEX)
            all_input_ids.append(system_message_input_ids.squeeze(0))
            all_labels.append(system_labels.squeeze(0))
        # Turns alternate user/assistant; process them in pairs.
        for _, j in enumerate(range(0, len(sources), 2)):
            user_input = sources[j]
            gpt_response = sources[j + 1]
            # The prompt includes the assistant header so only the response
            # body (and its end token) becomes a training target.
            user_input = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
            gpt_response = f"{gpt_response['content']}\n{DEFAULT_IM_END_TOKEN}\n"
            if DEFAULT_IMAGE_TOKEN in user_input:
                # Image turn: the processor expands the image token and emits
                # pixel_values + image_grid_thw alongside the text ids.
                inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])
            elif DEFAULT_VIDEO_TOKEN in user_input:
                # NOTE(review): video_kwargs is only bound in the "video"
                # branch above; a sample whose text contains the video token
                # without a "video" entry would raise NameError here.
                if "Qwen2.5" in self.model_id:
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt', **video_kwargs)
                    all_second_gird.extend(inputs["second_per_grid_ts"])
                else:
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])
            else:
                # Text-only turn: plain tokenization, no media tensors.
                prompt_input_ids = processor.tokenizer(user_input, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']
            response_input_ids = processor.tokenizer(gpt_response, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']
            input_ids = torch.cat([prompt_input_ids, response_input_ids], dim=1).squeeze(0)
            # Mask the prompt part; only the response contributes to the loss.
            labels = torch.cat(
                [
                    torch.tensor([IGNORE_INDEX] * len(prompt_input_ids[0])),
                    response_input_ids.squeeze(0),
                ],
                dim=0,
            )
            all_input_ids.append(input_ids)
            all_labels.append(labels)
        # There is no need for eos or bos tokens in the input_ids
        # Qwen2-VL does not use them
        input_ids = torch.cat(all_input_ids, dim=0).to(torch.long)
        labels = torch.cat(all_labels, dim=0).to(torch.long)
        # eos_token_id = processor.tokenizer.convert_tokens_to_ids(DEFAULT_IM_END_TOKEN)
        # input_ids, labels = truncate_sequence(input_ids, labels, self.max_length, eos_token_id)
        # All positions are real tokens at this point (padding happens in the
        # collator), so this comparison yields an all-ones mask.
        attention_mask = (input_ids > -1000000).to(torch.long)
        data_dict = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        if pixel_key and grid_key:
            pixel_values = torch.cat(all_pixel_values, dim=0)
            image_thw = torch.cat(all_image_grid_thw, dim=0)
            data_dict[pixel_key] = pixel_values
            data_dict[grid_key] = image_thw
        if len(all_second_gird) > 0:
            second_gird = all_second_gird
            data_dict["second_per_grid_ts"] = second_gird
        return data_dict
class DataCollatorForSupervisedDataset(object):
    """Collate SupervisedDataset examples into padded training batches.

    Pads input_ids/labels to the longest sequence in the batch and
    concatenates any per-sample image/video tensors along dim 0.
    """

    def __init__(self, pad_token_id: int):
        # Id used to right-pad input_ids; padded positions get attention
        # mask 0 via the `!= pad_token_id` comparison in __call__.
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        input_id_seqs = []
        label_seqs = []
        image_pixels = []
        image_grids = []
        video_pixels = []
        video_grids = []
        second_ts = []

        for sample in examples:
            # A sample carries either video tensors or image tensors, never
            # both (video takes precedence, matching the dataset's branches).
            if "pixel_values_videos" in sample:
                video_pixels.append(sample["pixel_values_videos"])
                video_grids.append(sample["video_grid_thw"])
            elif "pixel_values" in sample:
                image_pixels.append(sample["pixel_values"])
                image_grids.append(sample["image_grid_thw"])
            input_id_seqs.append(sample["input_ids"])
            label_seqs.append(sample["labels"])
            if "second_per_grid_ts" in sample:
                second_ts.extend(sample["second_per_grid_ts"])

        padded_ids = pad_sequence(
            input_id_seqs, padding_side='right', padding_value=self.pad_token_id
        )
        data_dict = {
            'input_ids': padded_ids,
            'labels': pad_sequence(label_seqs, padding_side='right', padding_value=IGNORE_INDEX),
            'attention_mask': padded_ids != self.pad_token_id,
        }
        if image_pixels:
            data_dict["pixel_values"] = torch.cat(image_pixels, dim=0)
            data_dict["image_grid_thw"] = torch.cat(image_grids, dim=0)
        if video_pixels:
            data_dict["pixel_values_videos"] = torch.cat(video_pixels, dim=0)
            data_dict["video_grid_thw"] = torch.cat(video_grids, dim=0)
        if second_ts:
            data_dict["second_per_grid_ts"] = second_ts
        return data_dict
def replace_image_tokens(input_string, is_video=False):
    """Swap LLaVA media placeholders for Qwen vision-token markup.

    Replaces each occurrence of the LLaVA image (or video) token — together
    with any immediately surrounding newlines — by the corresponding Qwen
    token wrapped in vision start/end markers.
    """
    if is_video:
        source_token, target_token = LLAVA_VIDEO_TOKEN, DEFAULT_VIDEO_TOKEN
    else:
        source_token, target_token = LLAVA_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN
    pattern = r'\n?' + re.escape(source_token) + r'\n?'
    replacement = VISION_START_TOKEN + target_token + VISION_END_TOKEN
    return re.sub(pattern, replacement, input_string)
def llava_to_openai(conversations, is_video=False):
    """Convert LLaVA conversation turns to OpenAI-style chat messages.

    Maps each turn's "from" field ("human" -> "user", "gpt" -> "assistant",
    anything else passed through unchanged) and rewrites media placeholder
    tokens in the turn's text via replace_image_tokens.
    """
    role_mapping = {"human": "user", "gpt": "assistant"}
    return [
        {
            "role": role_mapping.get(turn["from"], turn["from"]),
            "content": replace_image_tokens(turn["value"], is_video=is_video),
        }
        for turn in conversations
    ]
def make_supervised_data_module(model_id, processor, data_args):
    """Make dataset and collator for supervised fine-tuning.

    Builds a SupervisedDataset from data_args.data_path and pairs it with a
    collator that pads using the processor tokenizer's pad token. The
    returned dict matches the HuggingFace Trainer keyword arguments.
    """
    dataset = SupervisedDataset(
        data_path=data_args.data_path,
        processor=processor,
        data_args=data_args,
        model_id=model_id,
    )
    collator = DataCollatorForSupervisedDataset(pad_token_id=processor.tokenizer.pad_token_id)
    return {
        "train_dataset": dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }