"""Dataset for surgical vision-language instruction tuning (LLaVA-format JSON)."""
from torch.utils.data import Dataset
import os
import json
from PIL import Image
from llava.mm_utils import process_images
import torch


class SurgDataset(Dataset):
    """(image, question, answer) samples loaded from LLaVA-style JSON files.

    Each JSON record is expected to hold:
      * ``image``         -- path to an image file, or a list of paths (only
                             the first entry is used);
      * ``conversations`` -- list of ``{'from': 'human'|'gpt', 'value': str}``
                             with the question first and the answer second.

    Records whose image file does not exist on disk, or whose question/answer
    text is empty, are dropped at construction time.
    """

    # At most this many samples are kept when a single explicit test JSON
    # file is supplied (quick-evaluation cap, matching the original code).
    _TEST_CAP = 200
    # When no test.json exists, the trailing 10% of train.json is used for
    # testing; this is the fraction of train.json kept for training.
    _TRAIN_SPLIT = 0.9

    def __init__(self, args, image_processor=None, model_config=None, mode='train') -> None:
        """
        Args:
            args: namespace with a ``data_path`` attribute -- either a JSON
                file, or a directory containing ``train.json`` / ``test.json``.
            image_processor: processor handed to ``process_images``. May be
                ``None`` when ``__getitem__`` is never called (smoke tests);
                defaults added so ``SurgDataset(args=args)`` no longer raises
                TypeError (bug fix for the ``__main__`` smoke test below).
            model_config: model config forwarded to ``process_images``.
            mode: ``'train'`` or ``'test'``; selects which JSON split to load.
        """
        super().__init__()
        self.args = args
        self.data_json = self._load_split(args.data_path, mode)

        # Keep only records whose image exists and whose Q/A are non-empty.
        self.valid_data_json = []
        for record in self.data_json:
            img_path = record['image']
            if isinstance(img_path, list):  # multi-view entry: use first view
                img_path = img_path[0]
            convs = record['conversations']
            is_empty = convs[0]['value'] == '' or convs[1]['value'] == ''
            if os.path.isfile(img_path) and not is_empty:
                record['image'] = img_path
                self.valid_data_json.append(record)

        self.image_processor = image_processor
        self.model_config = model_config

    @classmethod
    def _load_split(cls, data_path, mode):
        """Return the list of raw JSON records for the requested split."""
        if mode == 'test':
            if os.path.isfile(data_path):
                # Explicit test file: cap its size for quick evaluation.
                data = cls._read_json(data_path)
                return data[:cls._TEST_CAP] if len(data) > cls._TEST_CAP else data
            test_path = os.path.join(data_path, 'test.json')
            if os.path.isfile(test_path):
                return cls._read_json(test_path)
            # No test.json: fall back to the last 10% of the training data.
            data = cls._read_json(os.path.join(data_path, 'train.json'))
            return data[int(cls._TRAIN_SPLIT * len(data)):]
        # Training: data_path is either the JSON file itself or a directory.
        train_path = data_path if os.path.isfile(data_path) \
            else os.path.join(data_path, 'train.json')
        return cls._read_json(train_path)

    @staticmethod
    def _read_json(path):
        """Load and return the JSON payload stored at *path*."""
        with open(path) as f:
            return json.load(f)

    def __len__(self):
        """Number of records that survived the validity filter."""
        return len(self.valid_data_json)

    def __getitem__(self, idx):
        """Return ``(raw_record, question, answer, image_tensor, image_size)``.

        The image tensor is the first element produced by ``process_images``,
        cast to bfloat16; ``image_size`` is a one-element list with the PIL
        ``(width, height)`` tuple.
        """
        record = self.valid_data_json[idx]
        image = Image.open(record['image'])
        image_size = [image.size]
        image_tensor = process_images([image], self.image_processor, self.model_config)
        image_tensor = [t.to(dtype=torch.bfloat16) for t in image_tensor]

        convs = record['conversations']
        # The human turn holds the question and the gpt turn the answer;
        # tolerate either ordering in the JSON.
        question = convs[0] if convs[0]['from'] == 'human' else convs[1]
        question = question['value']
        # XTuner-format questions are prefixed with an image tag; strip it
        # only when actually present.
        # BUG FIX: the original tested ``startswith('')`` (always True) and
        # unconditionally dropped the first 8 characters of every question.
        # NOTE(review): '<image>\n' is exactly 8 characters, matching the
        # original ``[8:]`` slice -- confirm this prefix against the data.
        image_tag = '<image>\n'
        if question.startswith(image_tag):
            question = question[len(image_tag):]

        answer = convs[1] if convs[1]['from'] == 'gpt' else convs[0]
        answer = answer['value']
        return record, question, answer, image_tensor[0], image_size


if __name__ == '__main__':
    # Smoke test: build the dataset without an image processor / model config.
    from run_finetune_llava import parse_args
    args = parse_args()
    SurgDataset(args=args)