| from torch.utils.data import Dataset |
| import os |
| import json |
| from PIL import Image |
| from llava.mm_utils import process_images |
| import torch |
class SurgDataset(Dataset):
    """Dataset of surgical image / conversation pairs for LLaVA-style fine-tuning.

    Each JSON record is expected to look like::

        {"image": <path or [path, ...]>,
         "conversations": [{"from": "human", "value": ...},
                           {"from": "gpt",   "value": ...}]}

    Records whose image file does not exist, whose first two conversation
    values are empty, or that have fewer than two conversation turns are
    dropped during construction.
    """

    # Cap applied when an explicit test file is given, to keep evaluation cheap.
    MAX_TEST_SAMPLES = 200
    # Fraction of train.json treated as training data; the remaining tail is
    # used as a fallback test split when no test.json exists.
    TRAIN_SPLIT_FRACTION = 0.9

    def __init__(self, args, image_processor, model_config, mode='train') -> None:
        """Load and validate the JSON annotation file.

        Args:
            args: namespace with a ``data_path`` attribute — either a JSON file
                or a directory containing ``train.json`` / ``test.json``.
            image_processor: LLaVA image processor passed to ``process_images``.
            model_config: model config passed to ``process_images``.
            mode: ``'train'`` or ``'test'``; selects which split is loaded.
        """
        super().__init__()
        self.args = args

        if mode == 'test':
            if os.path.isfile(args.data_path):
                with open(args.data_path) as f:
                    self.data_json = json.load(f)
                # Explicit test files are truncated so evaluation stays fast.
                if len(self.data_json) > self.MAX_TEST_SAMPLES:
                    self.data_json = self.data_json[:self.MAX_TEST_SAMPLES]
            else:
                self.data_json_path = os.path.join(args.data_path, 'test.json')
                if os.path.isfile(self.data_json_path):
                    with open(self.data_json_path) as f:
                        self.data_json = json.load(f)
                else:
                    # No dedicated test split: fall back to the last 10% of
                    # train.json.  NOTE(review): train mode below loads the
                    # FULL train.json, so this tail overlaps the training
                    # data — confirm whether that leakage is intended.
                    self.data_json_path = os.path.join(args.data_path, 'train.json')
                    with open(self.data_json_path) as f:
                        self.data_json = json.load(f)
                    data_length = len(self.data_json)
                    self.data_json = self.data_json[int(self.TRAIN_SPLIT_FRACTION * data_length):]
        else:
            self.data_json_path = args.data_path if os.path.isfile(args.data_path) \
                else os.path.join(args.data_path, 'train.json')
            with open(self.data_json_path) as f:
                self.data_json = json.load(f)

        # Keep only records whose image exists on disk and whose first two
        # conversation turns are both non-empty.
        self.valid_data_json = []
        for item in self.data_json:
            img_path = item['image']
            if isinstance(img_path, list):  # some records store a one-element list
                img_path = img_path[0]
            # Guard against malformed records with fewer than two turns
            # (the original indexed [0] and [1] unconditionally).
            convs = item.get('conversations', [])
            is_empty = len(convs) < 2 or convs[0]['value'] == '' or convs[1]['value'] == ''
            if os.path.isfile(img_path) and not is_empty:
                item['image'] = img_path  # normalize list-valued paths to a plain string
                self.valid_data_json.append(item)

        self.image_processor = image_processor
        self.model_config = model_config

    def __len__(self):
        """Number of records that passed validation."""
        return len(self.valid_data_json)

    def __getitem__(self, idx):
        """Return ``(raw_record, question, answer, image_tensor, [image_size])``.

        The image tensor is the first element of ``process_images``' output,
        cast to ``bfloat16``; ``image_size`` is a one-element list holding the
        PIL ``(width, height)`` tuple.
        """
        record = self.valid_data_json[idx]
        image = Image.open(record['image'])
        image_size = [image.size]
        image_tensor = process_images([image], self.image_processor, self.model_config)
        image_tensor = [t.to(dtype=torch.bfloat16) for t in image_tensor]

        # The human turn is normally first; fall back to the second slot.
        question = record['conversations'][0] if record['conversations'][0]['from'] == 'human' \
            else record['conversations'][1]
        question = question['value']
        # Strip the leading '<image>' placeholder.  Slicing 8 chars off a
        # 7-char tag assumes a trailing newline after the tag —
        # NOTE(review): confirm the data always uses '<image>\n'.
        if question.startswith('<image>'):
            question = question[8:]

        # The gpt turn is normally second; fall back to the first slot.
        answer = record['conversations'][1] if record['conversations'][1]['from'] == 'gpt' \
            else record['conversations'][0]
        answer = answer['value']
        return record, question, answer, image_tensor[0], image_size
|
|
if __name__ == '__main__':
    # Smoke test: construct the dataset so only the JSON loading and
    # image/conversation filtering in __init__ runs.
    from run_finetune_llava import parse_args
    args = parse_args()
    # Fix: __init__ requires image_processor and model_config (no defaults);
    # the original call omitted them and always raised TypeError.  They are
    # only stored during construction, so None is safe for a smoke test.
    SurgDataset(args=args, image_processor=None, model_config=None)