# llava_finetune/dataset/SurgDataset.py
# Update: integrate llama3 into finetuning code (commit 157f5b2, author: lyclyc52)
from torch.utils.data import Dataset
import os
import json
from PIL import Image
from llava.mm_utils import process_images
import torch
class SurgDataset(Dataset):
    """Dataset of (image, question, answer) samples from LLaVA-style
    instruction JSON for surgical-scene finetuning.

    Each JSON entry is expected to look like::

        {"image": <path or [path, ...]>,
         "conversations": [{"from": "human", "value": "..."},
                           {"from": "gpt",   "value": "..."}]}

    Entries whose image file is missing, whose question/answer is empty,
    or whose conversation list is malformed are filtered out up front.
    """

    def __init__(self, args, image_processor, model_config, mode='train') -> None:
        """Load the instruction JSON and build the filtered sample list.

        Args:
            args: namespace with ``data_path`` — either a JSON file or a
                directory containing ``train.json`` / ``test.json``.
            image_processor: processor forwarded to ``process_images``.
            model_config: model config forwarded to ``process_images``.
            mode: ``'train'`` or ``'test'``.
        """
        super().__init__()
        self.args = args
        if mode == 'test':
            if os.path.isfile(args.data_path):
                with open(args.data_path) as f:
                    self.data_json = json.load(f)
                # Cap the test split at 200 samples to keep evaluation fast.
                if len(self.data_json) > 200:
                    self.data_json = self.data_json[:200]
            else:
                self.data_json_path = os.path.join(args.data_path, 'test.json')
                if os.path.isfile(self.data_json_path):
                    with open(self.data_json_path) as f:
                        self.data_json = json.load(f)
                else:
                    # Without a test.json, fall back to the last 10 percent
                    # of the training data for testing.
                    self.data_json_path = os.path.join(args.data_path, 'train.json')
                    with open(self.data_json_path) as f:
                        self.data_json = json.load(f)
                    data_length = len(self.data_json)
                    self.data_json = self.data_json[int(0.9 * data_length):]
        else:
            self.data_json_path = args.data_path if os.path.isfile(args.data_path) \
                else os.path.join(args.data_path, 'train.json')
            with open(self.data_json_path) as f:
                self.data_json = json.load(f)

        # Filter out invalid data: missing image file, empty question or
        # answer, or fewer than two conversation turns (the original code
        # raised IndexError on such entries).
        self.valid_data_json = []
        for i in self.data_json:
            if len(i.get('conversations', [])) < 2:
                continue
            img_path = i['image']
            # Multi-view entries store a list of paths; keep the first view.
            if isinstance(img_path, list):
                img_path = img_path[0]
            is_empty = i['conversations'][0]['value'] == '' or i['conversations'][1]['value'] == ''
            if os.path.isfile(img_path) and not is_empty:
                i['image'] = img_path
                self.valid_data_json.append(i)

        self.image_processor = image_processor
        self.model_config = model_config

    def __len__(self):
        """Return the number of valid (filtered) samples."""
        return len(self.valid_data_json)

    def __getitem__(self, idx):
        """Return ``(raw_entry, question, answer, image_tensor, image_size)``.

        The image tensor is preprocessed by ``process_images`` and cast to
        bfloat16; ``image_size`` is a one-element list of the PIL size.
        """
        entry = self.valid_data_json[idx]
        # Context manager closes the file handle once preprocessing is done
        # (the original leaked one file descriptor per sample).
        with Image.open(entry['image']) as image:
            image_size = [image.size]
            image_tensor = process_images([image], self.image_processor, self.model_config)
        image_tensor = [_image.to(dtype=torch.bfloat16) for _image in image_tensor]

        # The human turn is usually first, but tolerate a swapped order.
        question = entry['conversations'][0] if entry['conversations'][0]['from'] == 'human' \
            else entry['conversations'][1]
        question = question['value']
        # Adopt the XTuner data format: strip the leading '<image>' tag and
        # its trailing newline if present. The original sliced a fixed 8
        # characters for the 7-character tag, which swallowed a real
        # character whenever no newline followed '<image>'.
        if question.startswith('<image>'):
            question = question[len('<image>'):]
            if question.startswith('\n'):
                question = question[1:]
        answer = entry['conversations'][1] if entry['conversations'][1]['from'] == 'gpt' \
            else entry['conversations'][0]
        answer = answer['value']
        return entry, question, answer, image_tensor[0], image_size
if __name__ == '__main__':
    # Smoke-test: build the dataset from command-line arguments.
    from run_finetune_llava import parse_args

    args = parse_args()
    # __init__ requires image_processor and model_config; the original call
    # SurgDataset(args=args) raised a TypeError for the missing arguments.
    SurgDataset(args=args, image_processor=None, model_config=None)