Unify_dataset / supervised.py
asdjghh's picture
Upload supervised.py with huggingface_hub
3c7fc5a verified
# Copyright 2025 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import Any, Callable
from typing_extensions import TypedDict # Python 3.10+
import torch
import transformers
from torch.utils.data import Dataset
from torchvision import transforms
from transformers.tokenization_utils import PaddingStrategy, TruncationStrategy
from align_anything.utils.multi_process import get_current_device
from align_anything.utils.tools import right_padding, convert_to_rgb, ends_with_any
from datasets import load_dataset
import json
import os
def read_jsonl(file_path):
data = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
json_obj = json.loads(line.strip()) # 每行解析为一个字典
data.append(json_obj)
return data
def write_jsonl(file_path, data):
with open(file_path, 'w', encoding='utf-8') as f:
for item in data:
json.dump(item, f, ensure_ascii=False) # 将每个字典写入文件
f.write('\n') # 每个 JSON 对象占一行
def read_json(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
return data
def write_json(file_path, data):
with open(file_path, 'w', encoding='utf-8') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
IGNORE_INDEX = -100
__all__ = [
'SupervisedDataset',
'SupervisedTokenizedDataset',
'SupervisedCollator',
'SupervisedSample',
'SupervisedBatch',
]
class SupervisedSample(TypedDict, total=True):
input_ids: torch.LongTensor # size = (L,)
labels: torch.LongTensor # size = (L,)
pixel_values: torch.LongTensor | None # size = (B, C, H, W)
class SupervisedBatch(TypedDict, total=True):
input_ids: torch.LongTensor # size = (B, L)
labels: torch.LongTensor # size = (B, L)
attention_mask: torch.BoolTensor # size = (B, L)
pixel_values: torch.LongTensor | None # size = (B, C, H, W)
task: str
images_seq_mask: torch.BoolTensor | None # size = (B, L)
images_emb_mask: torch.BoolTensor | None # size = (B, N)
class SupervisedDataset(Dataset):
def __init__(
self,
path: str,
template: str,
tokenizer: transformers.PreTrainedTokenizer,
processor: transformers.ProcessorMixin | transforms.Compose | None = None,
padding_side: str = 'right',
name: str | None = None,
size: int | None = None,
split: str | None = None,
subset: str | None = None,
data_files: str | None = None,
optional_args: list | str = [],
):
super().__init__()
assert path, f'You must set the valid datasets path! Here is {path}'
assert template, f'You must set the valid template path! Here is {template}'
self.tokenizer = tokenizer
self.processor = processor
self.padding_side = padding_side
self.raw_data = load_dataset(
path,
split=split,
data_files=data_files,
subset=subset,
*optional_args,
trust_remote_code=True,
)
# print('++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++Loading dataset from', path)
self.raw_data = read_json(f'{path}/{data_files}')
# print(self.raw_data)
if size:
self.raw_data = self.raw_data.select(range(int(size)))
self.template = template
# print(template)
def preprocess(self, raw_sample: dict[str, Any]) -> SupervisedSample:
prompt, conversation, meta_info = self.template.format_supervised_sample(raw_sample)
if not ends_with_any(conversation, self.tokenizer.eos_token):
conversation += self.tokenizer.eos_token
# return return_dict
full_inputs = self.processor(
prompt=conversation, images=[meta_info['image']], return_tensors='pt'
)
prompt_inputs = self.processor(
prompt=prompt, images=[meta_info['image']], return_tensors='pt'
)
return_dict = {}
return_dict['input_ids'] = full_inputs['input_ids'][0]
return_dict['attention_mask'] = full_inputs['attention_mask'][0]
return_dict['pixel_values'] = full_inputs['pixel_values'][0]
return_dict['images_seq_mask'] = full_inputs['images_seq_mask'][0]
return_dict['images_emb_mask'] = full_inputs['images_emb_mask'][0]
return_dict['labels'] = return_dict['input_ids'].clone()
return_dict['labels'][: len(prompt_inputs['input_ids'][0])] = IGNORE_INDEX
return_dict['task'] = 'understanding'
return return_dict
def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
return SupervisedCollator(self.tokenizer.pad_token_id, self.processor, self.padding_side)
def tokenize(
self,
conversation: str,
add_special_tokens: bool = True,
padding: bool | str | PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
truncation: bool | str | TruncationStrategy = TruncationStrategy.LONGEST_FIRST,
max_length: int | None = None,
) -> torch.LongTensor: # size = (L,)
"""Tokenize a text string into a tensor representation."""
if max_length is None:
max_length = self.tokenizer.model_max_length
return self.tokenizer(
text=conversation,
add_special_tokens=add_special_tokens,
padding=padding,
max_length=max_length,
truncation=truncation,
return_tensors='pt',
)
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
"""Get a tokenized data sample by index."""
raw_sample = self.raw_data[index]
print('RARAAAAAAAAAAAAA')
print(raw_sample)
data = self.preprocess([raw_sample['prompt']].copy())
return data
def __len__(self) -> int:
"""Get the number of samples in the dataset."""
return len(self.raw_data)
class SupervisedTokenizedDataset(Dataset):
def __init__(
self,
path: str,
template: str | None = None,
tokenizer: transformers.PreTrainedTokenizer | None = None,
processor: transformers.ProcessorMixin | transforms.Compose | None = None,
padding_side: str = 'right',
size: int | None = None,
name: str | None = None,
split: str | None = None,
subset: str | None = None,
data_files: str | None = None,
optional_args: list | str = [],
):
super().__init__()
assert path, f'You must set the valid datasets path! Here is {path}'
assert template, f'You must set the valid template path! Here is {template}'
self.tokenizer = tokenizer
self.processor = processor
self.padding_side = padding_side
self.raw_data = torch.load(f'{path}/{data_files}', map_location=torch.device('cpu'), weights_only=False)
# self.raw_data = read_json(f'{path}/{data_files}')
# print('++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++Loading dataset from', path)
# print(self.raw_data)
# self.raw_data[0] = self.tokenize(self.raw_data[0]['prompt'])
for x in self.raw_data:
x['source_image'] = x['source_image']
if size:
self.raw_data = self.raw_data.select(range(int(size)))
self.template = template
def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
return SupervisedCollator(self.tokenizer.pad_token_id, self.processor, self.padding_side)
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
"""Get a tokenized data sample by index."""
raw_sample = self.raw_data[index]
return raw_sample
def __len__(self) -> int:
"""Get the number of samples in the dataset."""
return len(self.raw_data)
import PIL.Image
def process_image(image_paths, vl_chat_processor):
# images = [PIL.Image.open(image_path).convert("RGB") for image_path in image_paths]
images_outputs = vl_chat_processor.image_processor(image_paths, return_tensors="pt")
return images_outputs['pixel_values']
class SupervisedCollator:
def __init__(self, pad_token_id: int, processor: transformers.ProcessorMixin | transforms.Compose | None = None, padding_side: str = 'right') -> None:
self.pad_token_id = pad_token_id
self.processor = processor
self.padding_side = padding_side
def __call__(self, samples: list[SupervisedSample]) -> SupervisedBatch:
return_dict = {}
current_device = get_current_device()
print('SASASADDDDDDDDDDDDDDDD')
print(samples)
return_dict['input_ids'] = right_padding(
[sample['input_ids'] for sample in samples],
padding_value=self.pad_token_id,
).to(current_device)
return_dict['labels'] = right_padding(
[sample['labels'] for sample in samples],
padding_value=IGNORE_INDEX,
).to(current_device)
if 'attention_mask' in samples[0]:
return_dict['attention_mask'] = right_padding(
[sample['attention_mask'] for sample in samples],
padding_value=0,
).to(current_device)
if 'pixel_values' in samples[0]:
return_dict['pixel_values'] = right_padding(
[sample['pixel_values'] for sample in samples],
padding_value=0,
).to(current_device)
if 'source_image' in samples[0]:
return_dict['source_image'] = [sample['source_image'] for sample in samples]
if 'sft_format' in samples[0]:
return_dict['sft_format'] = [sample['sft_format'] for sample in samples]
if 'task' in samples[0]:
return_dict['task'] = samples[0]['task']
if "images_seq_mask" in samples[0]:
return_dict['images_seq_mask'] = right_padding(
[sample['images_seq_mask'] for sample in samples],
padding_value=0,
).to(current_device)
if "images_emb_mask" in samples[0]:
return_dict['images_emb_mask'] = right_padding(
[sample['images_emb_mask'] for sample in samples],
padding_value=0,
).to(current_device)
print(return_dict)
return return_dict