import torch
import torchvision.transforms as transforms
import os
import logging
import pickle


def read_pkl_data(pkl_path, img_path):
    """Load one split's annotation pickle and build per-video frame paths and labels.

    Args:
        pkl_path: path to a pickle mapping video name -> list of per-frame
            dicts with keys 'Frame_id', 'Phase_gt', 'Step_gt'.
        img_path: root directory holding one sub-directory of .jpg frames
            per video.

    Returns:
        (imgs, phases, steps): three parallel lists (one entry per video,
        sorted by video name); each entry is the per-frame list of image
        paths / phase labels / step labels.
    """
    logging.info('reading pickle file: ' + pkl_path)
    # NOTE(review): pickle.load is only safe on trusted project data;
    # never point this at untrusted input.
    with open(pkl_path, "rb") as fp:
        data = pickle.load(fp)
    # (redundant fp.close() removed: the with-statement already closes the file)

    root_dir = img_path
    if not os.path.exists(root_dir):
        # Fall back to a split-agnostic frames directory when a
        # split-specific one ('.../train' etc.) does not exist.
        root_dir = root_dir.replace('train', '').replace('val', '').replace('test', '')

    imgs, phases, steps = [], [], []
    for vid_name in sorted(data.keys()):
        items = data[vid_name]
        imgs.append([
            os.path.join(root_dir, vid_name, f"{item['Frame_id']}.jpg")
            for item in items
        ])
        phases.append([item['Phase_gt'] for item in items])
        steps.append([item['Step_gt'] for item in items])
    return imgs, phases, steps


# Shared dataset roots (hard-coded cluster paths, as in the original config).
_LABEL_ROOT = '/gpfswork/rech/okw/ukw13bv/MultiBypass140/labels'
_FRAME_ROOT = '/gpfsscratch/rech/okw/ukw13bv/bypass/BernBypass70/frames'


## Read test pickle files
#### TRAIN ####
labels = os.path.join(_LABEL_ROOT, 'bern', 'labels_by70_splits/labels',
                      'train', '1fps_100_0.pickle')
images = _FRAME_ROOT
videos_train, phase_labels_train, step_labels_train = read_pkl_data(labels, images)

#### VAL ####
labels = os.path.join(_LABEL_ROOT, 'bern', 'labels_by70_splits/labels',
                      'val', '1fps_0.pickle')
images = _FRAME_ROOT
videos_val, phase_labels_val, step_labels_val = read_pkl_data(labels, images)

#### TEST ####
labels = os.path.join(_LABEL_ROOT, 'bern', 'labels_by70_splits/labels',
                      'test', '1fps_0.pickle')
images = _FRAME_ROOT
videos_test, phase_labels_test, step_labels_test = read_pkl_data(labels, images)


def _build_transforms():
    """Preprocessing pipeline shared by all splits (ImageNet normalization)."""
    return transforms.Compose([
        transforms.Resize((360, 640)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def _split_config(video_lists, label_lists):
    """Build one Recognition_frame_bypass dataset dict per video of a split."""
    return [
        dict(
            type='Recognition_frame_bypass',
            img_list=v,
            label_list=l,
            # A fresh Compose per video, matching the original per-entry
            # construction.
            transforms=_build_transforms(),
        )
        for v, l in zip(video_lists, label_lists)
    ]


_base_ = ['../base.py']

# NOTE(review): only the phase labels are used below; the step labels are
# loaded but unused, as in the original config.
config = dict(
    train_config=_split_config(videos_train, phase_labels_train),
    val_config=_split_config(videos_val, phase_labels_val),
    test_config=_split_config(videos_test, phase_labels_test),
    model_config=dict(
        type='MVNet_feature_extractor',
        backbone_img=dict(
            type='img_backbones/ImageEncoder_feature_extractor',
            # type='img_backbones/ImageEncoder_CLIPVISUAL',
            num_classes=768,
            pretrained='imagenet',  # imagenet/ssl/random
            backbone_name='resnet_50',
            # backbone_name='resnet_50_clip'
            img_norm=False,
        ),
        backbone_text=dict(
            type='text_backbones/BertEncoder',
            text_bert_type='/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000',
            text_last_n_layers=4,
            text_aggregate_method='sum',
            text_norm=False,
            text_embedding_dim=768,
            text_freeze_bert=False,
            text_agg_tokens=True,
        ),
    ),
)