# NOTE(review): removed non-code paste artifacts ("Spaces:" / "Runtime error") that preceded the module docstring.
| """Datamodule for Llava Pretraining and Finetuning""" | |
| import os | |
| import re | |
| from PIL import Image | |
| import numpy as np | |
| import re | |
| import tempfile | |
| from typing import Dict, List, Union, Tuple | |
| import traceback | |
| import json | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import DataCollatorForSeq2Seq | |
| from tools.rw_utils import read_jsonlines | |
| from torch.utils.data import Dataset, DataLoader | |
| np_str_obj_array_pattern = re.compile(r"[SaUO]") | |
| default_collate_err_msg_format = ( | |
| "default_collate: batch must contain tensors, numpy arrays, numbers, " | |
| "dicts or lists; found {}" | |
| ) | |
| from .custom_data_parsers.standard_vision_parser import VisionParser | |
| from .custom_data_parsers.object_tracking_parser import ObjectTrackingParser | |
| from .custom_data_parsers.multi_images_parser import MultiImagesParser | |
| from .custom_data_parsers.video_permutation_parser import VideoPermutationParser | |
| from .custom_data_parsers.utils_visualize import visualize_image_bbox | |
| from .tarsier_processor import TarsierProcessor | |
| from tools.rw_utils import NumpyArrayEncoder | |
| from .utils import DictToObject | |
| import os | |
| HF_TOKEN = os.environ.get('HF_TOKEN', '') | |
class TarsierDataProcessor:
    """Turn raw annotation dicts into padded, batched model inputs.

    Each sample is routed to a task-specific parser (object tracking,
    multi-image, video permutation, or the generic vision parser), rendered
    into token/pixel tensors by ``TarsierProcessor``, and finally collated
    with ``DataCollatorForSeq2Seq``.
    """

    def __init__(
        self,
        processor: TarsierProcessor,
        n_frames: Union[int, list],
        max_n_frames=256,
        max_pixels=int(1280 * 720 // 2),
        min_pixels=0,
        max_seq_len=None,
        is_training=True,  # Affects: 1. frame sampling differs between train and eval; 2. the response is ignored at eval time.
        print_data_error=True,
        do_image_padding=False,
        do_image_crop=False,
        do_image_resize=True,
        video_sampling_strategy=None,
        prompt='',
        train_task='sft',
        **kwargs
    ):
        """Store processing options and build the per-task parsers.

        Args:
            processor: wrapped ``TarsierProcessor`` doing tokenization + image prep.
            n_frames: frames to sample per video (int, or list of candidates).
            max_n_frames: hard cap on sampled frames.
            max_pixels / min_pixels: per-image pixel budget bounds.
            max_seq_len: token limit; defaults to the tokenizer's ``model_max_length``.
            video_sampling_strategy: dict of sampling options; ``None`` means ``{}``
                (was a shared mutable default ``{}`` in the original).
            prompt: if non-empty, overrides every user text turn.
            train_task: 'sft' or 'dpo'.
        """
        if video_sampling_strategy is None:
            video_sampling_strategy = {}
        self.kwargs = kwargs
        self.processor = processor
        self.pad_collator = DataCollatorForSeq2Seq(processor.tokenizer, padding='longest')
        # BUG FIX: the original read ``self.tokenizer``, an attribute that is never
        # set on this class, so ``max_seq_len=None`` raised AttributeError. The
        # tokenizer lives on the wrapped processor.
        self.processor.max_seq_len = processor.tokenizer.model_max_length if max_seq_len is None else max_seq_len
        self.n_frames = n_frames
        self.max_n_frames = max_n_frames
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.is_training = is_training
        self.print_data_error = print_data_error
        self.do_image_padding = do_image_padding
        self.do_image_crop = do_image_crop
        self.do_image_resize = do_image_resize
        self.video_sampling_strategy = video_sampling_strategy
        self.prompt = prompt
        self.train_task = train_task
        self.object_tracking_parser = ObjectTrackingParser(
            n_frames=self.n_frames,
            max_objects=4,
            is_training=self.is_training,
        )
        self.multi_images_parser = MultiImagesParser(
            n_frames=self.n_frames,
            is_training=self.is_training,
        )
        self.video_permutation_parser = VideoPermutationParser(
            n_frames=self.n_frames,
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy,
        )
        self.vision_parser = VisionParser(
            n_frames=self.n_frames,
            max_n_frames=self.max_n_frames,
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy
        )

    def select_parser(self, data_dict):
        """Pick the parser for one sample, keyed on its 'task'/'dataset' fields."""
        if data_dict.get('task', None) == 'video/object_tracking':
            return self.object_tracking_parser
        elif data_dict.get('task', None) == 'multi_images':
            return self.multi_images_parser
        elif data_dict.get('dataset', None) == 'video_permutation':
            return self.video_permutation_parser
        else:
            return self.vision_parser

    def parse_image_processing_config(self, data_dict):
        """Merge the sample's 'image_processing_config' with instance defaults.

        Per-sample values win; missing keys fall back to the constructor options.
        """
        image_processing_config = data_dict.get('image_processing_config', {})
        do_padding = image_processing_config.get('do_padding', self.do_image_padding)
        do_crop = image_processing_config.get('do_crop', self.do_image_crop)
        do_resize = image_processing_config.get('do_resize', self.do_image_resize)
        max_pixels = image_processing_config.get('max_pixels', self.max_pixels)
        min_pixels = image_processing_config.get('min_pixels', self.min_pixels)
        assert min_pixels <= max_pixels
        image_processing_config['do_padding'] = do_padding
        image_processing_config['do_crop'] = do_crop
        image_processing_config['do_resize'] = do_resize
        image_processing_config['max_pixels'] = max_pixels
        image_processing_config['min_pixels'] = min_pixels
        return image_processing_config

    def _transform(self, raw_data_dict: Dict) -> Dict:
        """Process one sample into model inputs; returns a one-element list."""
        # Deep-copy via JSON so downstream mutation cannot leak into the caller's dict.
        data_dict = json.loads(json.dumps(raw_data_dict, cls=NumpyArrayEncoder))
        del raw_data_dict
        if self.prompt:
            # Global prompt override: replace every user-side text content.
            for msg in data_dict['messages']:
                if msg['role'] == 'user':
                    for content in msg['content']:
                        if content['type'] == 'text':
                            content['text'] = self.prompt
        data_dict_copy = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))
        image_processing_config = self.parse_image_processing_config(data_dict)
        parser = self.select_parser(data_dict)
        messages = parser.transform(data_dict, image_processing_config)
        data_dict_copy['extra_info'] = data_dict.pop('extra_info', {})
        # visualize_image_bbox(data_dict, image_processing_config, self.processor)
        outputs = self.processor(messages, image_processing_config, is_training=self.is_training)
        # if not self.is_training:
        outputs['raw_data_dict'] = data_dict_copy
        return [outputs]

    def _split_chosen_rejected(self, data_dict: Dict):
        """For DPO: fan one sample out into a (chosen, rejected) pair.

        Assistant text contents are expected to carry 'chosen'/'rejected' keys.
        """
        chosen_data_dict = data_dict
        rejected_data_dict = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))
        for msg in chosen_data_dict['messages']:
            if msg['role'] == 'assistant':
                for content in msg['content']:
                    if content['type'] == 'text':
                        content['text'] = content['chosen']
        for msg in rejected_data_dict['messages']:
            if msg['role'] == 'assistant':
                for content in msg['content']:
                    if content['type'] == 'text':
                        content['text'] = content['rejected']
        return chosen_data_dict, rejected_data_dict

    def transform(self, data_dict: Dict) -> Dict:
        """Process one sample; on any failure log (optionally) and return []."""
        try:
            if self.train_task == 'dpo':
                chosen_data_dict, rejected_data_dict = self._split_chosen_rejected(data_dict)
                return self._transform(chosen_data_dict) + self._transform(rejected_data_dict)
            return self._transform(data_dict)
        except Exception:
            # Best-effort: a bad sample is dropped rather than crashing the loader.
            if self.print_data_error:
                print(traceback.format_exc())
                print(f'Error occurs when processing: \n{data_dict}')
            return []

    def batch_transform(self, batch_data: List[Dict]) -> Dict:
        """Collate per-sample outputs into one padded batch of tensors."""
        model_inputs = {}
        # if not self.is_training:
        raw_data_dict = [d.pop('raw_data_dict') for d in batch_data]
        model_inputs['raw_data_dict'] = raw_data_dict
        batch_pixel_values = [d.pop('pixel_values') for d in batch_data if 'pixel_values' in d]
        batch_image_grid_thw = [d.pop('image_grid_thw') for d in batch_data if 'image_grid_thw' in d]
        if len(batch_pixel_values) == 0:
            # Text-only batch: inject a dummy image so vision weights still get gradients.
            vision_placeholder = self.get_vision_placeholder()
            batch_pixel_values = [vision_placeholder.get('pixel_values')]
            batch_image_grid_thw = [vision_placeholder.get('image_grid_thw')] if 'image_grid_thw' in vision_placeholder else []
        model_inputs['pixel_values'] = torch.cat(batch_pixel_values, dim=0)
        if len(batch_image_grid_thw) > 0:
            model_inputs['image_grid_thw'] = torch.cat(batch_image_grid_thw, dim=0)
        batch_num_images = [d.pop('num_images') for d in batch_data]
        model_inputs['num_images'] = torch.tensor(batch_num_images)
        model_inputs.update(self.pad_collator(batch_data))
        return model_inputs

    def __call__(self, batch_data: Union[Dict, List[Dict]]) -> Dict:
        """Accept one sample or a list of samples and return a collated batch."""
        if isinstance(batch_data, dict):
            batch_data = [batch_data]
        # NOTE(review): ``transform`` returns [] on failure, so ``[0]`` would raise
        # IndexError for a bad sample, and for DPO only the chosen half is kept
        # here — confirm this is the intended call path before relying on it.
        batch = [self.transform(d)[0] for d in batch_data]
        return self.batch_transform(batch)

    def get_vision_placeholder(self):
        """Return processor outputs for a blank 336x336 RGB image (dummy vision input)."""
        messages = [{"role": "user", "content": [{"type": "image", "image": Image.new(mode='RGB', size=(336, 336))}]}]
        image_processing_config = self.parse_image_processing_config({})
        return self.processor(messages, image_processing_config)

    def get_text_placeholder(self):
        """Return processor outputs for a trivial text-only exchange (dummy text input)."""
        messages = [
            {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "Thank you very much"}]},
        ]
        image_processing_config = self.parse_image_processing_config({})
        return self.processor(messages, image_processing_config)
def init_processor(processor: Union[TarsierProcessor, str]=None, config: Dict=None):
    """Build a ``TarsierDataProcessor`` from a config plus either a pretrained
    model path/name or an already-constructed ``TarsierProcessor``.

    Args:
        processor: a ``TarsierProcessor`` instance, or a string accepted by
            ``TarsierProcessor.from_pretrained``.
        config: dict or object exposing the processing options as attributes.
    Returns:
        A configured ``TarsierDataProcessor``.
    """
    if isinstance(config, dict):
        config = DictToObject(config)
    if isinstance(processor, str):
        # Load from hub/local path; HF_TOKEN is read from the environment.
        sub_processor = TarsierProcessor.from_pretrained(
            processor,
            padding_side='left',
            trust_remote_code=True,
            token=HF_TOKEN,
        )
    else:
        sub_processor = processor
    # Pull every option off the config by name instead of spelling out each keyword.
    option_names = (
        'n_frames', 'max_n_frames', 'max_pixels', 'min_pixels', 'max_seq_len',
        'is_training', 'print_data_error', 'do_image_padding', 'do_image_crop',
        'do_image_resize', 'video_sampling_strategy', 'prompt', 'train_task',
    )
    options = {name: getattr(config, name) for name in option_names}
    return TarsierDataProcessor(processor=sub_processor, **options)
class TarsierDataset(Dataset):
    """Map-style dataset over jsonlines annotations, processed by a TarsierDataProcessor.

    ``__getitem__`` returns ``(annotation, model_inputs)``; ``model_inputs`` is
    ``None`` when processing the annotation failed.
    """

    def __init__(self, ann_path="", anns=None, config: Dict=None, processor: Union[TarsierDataProcessor, TarsierProcessor, str]=None):
        """Load annotations and set up the processor.

        Args:
            ann_path: one jsonlines path or a list of paths; ignored when ``anns`` given.
            anns: pre-loaded annotation list (takes precedence over ``ann_path``).
            config: dict or object of processing options (see ``init_processor``).
            processor: ready ``TarsierDataProcessor``, or the inputs to build one.
        """
        self.config = DictToObject(config) if isinstance(config, dict) else config
        if not isinstance(processor, TarsierDataProcessor):
            self.processor = init_processor(processor, config)
        else:
            self.processor = processor
        if anns is None:
            self.anns = []
            if isinstance(ann_path, str):
                ann_path = [ann_path]
            for path in ann_path:
                self.anns.extend(read_jsonlines(path))
        else:
            self.anns = anns

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, index):
        """Return ``(annotation, model_inputs)``; ``model_inputs`` is None on error.

        Raises:
            IndexError: for negative or out-of-range indices.
        """
        if index < 0 or index >= len(self.anns):
            raise IndexError("Index out of range")
        # BUG FIX: hoisted out of the try block. In the original, a failure while
        # indexing ``self.anns`` left ``ann`` unbound, so the except handler's
        # ``return ann, None`` raised NameError instead of returning the fallback.
        ann = self.anns[index]
        try:
            model_inputs = self.processor(ann)
        except Exception as e:
            # Best-effort loader: report and let the collate side skip this sample.
            print(f"Load data error: {e}")
            return ann, None
        return ann, model_inputs