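"""Supervised (SFT) datasets and collator for multimodal understanding tasks.

This module provides:
- JSON/JSONL helpers (`read_json`, `write_json`, `read_jsonl`, `write_jsonl`);
- `SupervisedDataset`, which formats raw samples with a chat template and a
  multimodal processor, masking prompt tokens in the labels;
- `SupervisedTokenizedDataset`, which loads samples tokenized offline;
- `SupervisedCollator`, which right-pads each field and moves it to the current device.
"""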
import json
import os
from typing import Any, Callable

import PIL.Image
import torch
import transformers
from datasets import load_dataset
from torch.utils.data import Dataset
from torchvision import transforms
from transformers.tokenization_utils import PaddingStrategy, TruncationStrategy
from typing_extensions import TypedDict

from align_anything.utils.multi_process import get_current_device
from align_anything.utils.tools import convert_to_rgb, ends_with_any, right_padding

def read_jsonl(file_path):
    """Read a JSONL file into a list of dicts (one JSON object per line)."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_obj = json.loads(line.strip())
            data.append(json_obj)
    return data


def write_jsonl(file_path, data):
    """Write a list of JSON-serializable objects to a JSONL file, one object per line."""
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            json.dump(item, f, ensure_ascii=False)
            f.write('\n')


def read_json(file_path):
    """Read a single JSON document from a file."""
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data


def write_json(file_path, data):
    """Write a JSON-serializable object to a file with pretty indentation."""
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)

IGNORE_INDEX = -100

__all__ = [
    'SupervisedDataset',
    'SupervisedTokenizedDataset',
    'SupervisedCollator',
    'SupervisedSample',
    'SupervisedBatch',
]

class SupervisedSample(TypedDict, total=True):
    input_ids: torch.LongTensor
    labels: torch.LongTensor
    pixel_values: torch.LongTensor | None


class SupervisedBatch(TypedDict, total=True):
    input_ids: torch.LongTensor
    labels: torch.LongTensor
    attention_mask: torch.BoolTensor
    pixel_values: torch.LongTensor | None
    task: str
    images_seq_mask: torch.BoolTensor | None
    images_emb_mask: torch.BoolTensor | None

class SupervisedDataset(Dataset):

    def __init__(
        self,
        path: str,
        template: str,
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin | transforms.Compose | None = None,
        padding_side: str = 'right',
        name: str | None = None,
        size: int | None = None,
        split: str | None = None,
        subset: str | None = None,
        data_files: str | None = None,
        optional_args: list | str = [],
    ):
        super().__init__()
        assert path, f'You must set a valid dataset path! Here is {path}'
        assert template, f'You must set a valid template! Here is {template}'
        self.tokenizer = tokenizer
        self.processor = processor
        self.padding_side = padding_side

        local_file = os.path.join(path, data_files) if data_files else None
        if local_file and os.path.isfile(local_file):
            # A local JSON file: read it directly into a list of raw samples.
            self.raw_data = read_json(local_file)
            if size:
                self.raw_data = self.raw_data[: int(size)]
        else:
            # Otherwise load from the Hugging Face hub (or a dataset script on disk).
            # `subset` names a dataset configuration, which `load_dataset` exposes as `name`.
            self.raw_data = load_dataset(
                path,
                name=name if name else subset,
                split=split,
                data_files=data_files,
                *optional_args,
                trust_remote_code=True,
            )
            if size:
                self.raw_data = self.raw_data.select(range(int(size)))
        self.template = template

    def preprocess(self, raw_sample: dict[str, Any]) -> SupervisedSample:
        """Format a raw sample with the template and tokenize it with the multimodal processor."""
        prompt, conversation, meta_info = self.template.format_supervised_sample(raw_sample)
        if not ends_with_any(conversation, self.tokenizer.eos_token):
            conversation += self.tokenizer.eos_token

        full_inputs = self.processor(
            prompt=conversation, images=[meta_info['image']], return_tensors='pt'
        )
        prompt_inputs = self.processor(
            prompt=prompt, images=[meta_info['image']], return_tensors='pt'
        )

        return_dict = {}
        return_dict['input_ids'] = full_inputs['input_ids'][0]
        return_dict['attention_mask'] = full_inputs['attention_mask'][0]
        return_dict['pixel_values'] = full_inputs['pixel_values'][0]
        return_dict['images_seq_mask'] = full_inputs['images_seq_mask'][0]
        return_dict['images_emb_mask'] = full_inputs['images_emb_mask'][0]
        # Supervise only the response tokens: mask out the prompt prefix in the labels.
        return_dict['labels'] = return_dict['input_ids'].clone()
        return_dict['labels'][: len(prompt_inputs['input_ids'][0])] = IGNORE_INDEX
        return_dict['task'] = 'understanding'

        return return_dict

    def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
        return SupervisedCollator(self.tokenizer.pad_token_id, self.processor, self.padding_side)

    def tokenize(
        self,
        conversation: str,
        add_special_tokens: bool = True,
        padding: bool | str | PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation: bool | str | TruncationStrategy = TruncationStrategy.LONGEST_FIRST,
        max_length: int | None = None,
    ) -> torch.LongTensor:
        """Tokenize a text string into a tensor representation."""
        if max_length is None:
            max_length = self.tokenizer.model_max_length

        return self.tokenizer(
            text=conversation,
            add_special_tokens=add_special_tokens,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_tensors='pt',
        )

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Get a tokenized data sample by index."""
        raw_sample = self.raw_data[index]
        return self.preprocess(raw_sample.copy())

    def __len__(self) -> int:
        """Get the number of samples in the dataset."""
        return len(self.raw_data)

class SupervisedTokenizedDataset(Dataset):

    def __init__(
        self,
        path: str,
        template: str | None = None,
        tokenizer: transformers.PreTrainedTokenizer | None = None,
        processor: transformers.ProcessorMixin | transforms.Compose | None = None,
        padding_side: str = 'right',
        size: int | None = None,
        name: str | None = None,
        split: str | None = None,
        subset: str | None = None,
        data_files: str | None = None,
        optional_args: list | str = [],
    ):
        super().__init__()
        assert path, f'You must set a valid dataset path! Here is {path}'
        assert template, f'You must set a valid template! Here is {template}'
        self.tokenizer = tokenizer
        self.processor = processor
        self.padding_side = padding_side

        # The samples were tokenized offline and saved with `torch.save`; load them on CPU.
        self.raw_data = torch.load(
            f'{path}/{data_files}', map_location=torch.device('cpu'), weights_only=False
        )

        if size:
            # `torch.load` yields a plain Python list here, so slice instead of `.select`.
            self.raw_data = self.raw_data[: int(size)]
        self.template = template

    def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
        return SupervisedCollator(self.tokenizer.pad_token_id, self.processor, self.padding_side)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Get a tokenized data sample by index."""
        return self.raw_data[index]

    def __len__(self) -> int:
        """Get the number of samples in the dataset."""
        return len(self.raw_data)

def process_image(image_paths, vl_chat_processor):
    """Run the chat processor's image processor over the given images and return pixel values."""
    images_outputs = vl_chat_processor.image_processor(image_paths, return_tensors='pt')
    return images_outputs['pixel_values']

class SupervisedCollator:

    def __init__(
        self,
        pad_token_id: int,
        processor: transformers.ProcessorMixin | transforms.Compose | None = None,
        padding_side: str = 'right',
    ) -> None:
        self.pad_token_id = pad_token_id
        self.processor = processor
        self.padding_side = padding_side

    def __call__(self, samples: list[SupervisedSample]) -> SupervisedBatch:
        return_dict = {}
        current_device = get_current_device()
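        # Right-pad each variable-length tensor field to the longest sample in the
        # batch and move it to the current device; list-valued fields are passed through.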
        return_dict['input_ids'] = right_padding(
            [sample['input_ids'] for sample in samples],
            padding_value=self.pad_token_id,
        ).to(current_device)

        return_dict['labels'] = right_padding(
            [sample['labels'] for sample in samples],
            padding_value=IGNORE_INDEX,
        ).to(current_device)

        if 'attention_mask' in samples[0]:
            return_dict['attention_mask'] = right_padding(
                [sample['attention_mask'] for sample in samples],
                padding_value=0,
            ).to(current_device)

        if 'pixel_values' in samples[0]:
            return_dict['pixel_values'] = right_padding(
                [sample['pixel_values'] for sample in samples],
                padding_value=0,
            ).to(current_device)

        if 'source_image' in samples[0]:
            return_dict['source_image'] = [sample['source_image'] for sample in samples]
        if 'sft_format' in samples[0]:
            return_dict['sft_format'] = [sample['sft_format'] for sample in samples]

        if 'task' in samples[0]:
            return_dict['task'] = samples[0]['task']

        if 'images_seq_mask' in samples[0]:
            return_dict['images_seq_mask'] = right_padding(
                [sample['images_seq_mask'] for sample in samples],
                padding_value=0,
            ).to(current_device)

        if 'images_emb_mask' in samples[0]:
            return_dict['images_emb_mask'] = right_padding(
                [sample['images_emb_mask'] for sample in samples],
                padding_value=0,
            ).to(current_device)

        return return_dict
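

# A minimal usage sketch (comments only, never executed). The model id, dataset path,
# and template object below are illustrative assumptions, not values defined in this module;
# the template must provide `format_supervised_sample` as used by `SupervisedDataset.preprocess`.
#
#     tokenizer = transformers.AutoTokenizer.from_pretrained('some/model')   # assumed model id
#     processor = transformers.AutoProcessor.from_pretrained('some/model')   # assumed model id
#     dataset = SupervisedDataset(
#         path='path/to/dataset',     # assumed local dir or hub dataset id
#         template=some_template,     # assumed chat-template object
#         tokenizer=tokenizer,
#         processor=processor,
#     )
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=4, collate_fn=dataset.get_collator()
#     )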