| import torch |
| import torchvision.transforms as transforms |
| import os |
| import logging |
| import pickle |
|
|
def read_pkl_data(pkl_path, img_path):
    """Load per-video annotations from a pickle file and build frame paths.

    The pickle is expected to map ``vid_name -> list of dicts`` where each
    dict carries ``Frame_id``, ``Phase_gt`` and ``Step_gt`` keys
    (assumption based on the keys read below — confirm against the
    annotation format).

    Args:
        pkl_path: Path to the pickle file of annotations.
        img_path: Root directory holding ``<vid_name>/<Frame_id>.jpg``
            frames. If it does not exist, any 'train'/'val'/'test'
            substring is stripped from it as a fallback.

    Returns:
        Tuple ``(imgs, phases, steps)`` of parallel per-video lists:
        frame paths, phase labels, and step labels, ordered by sorted
        video name.
    """
    # Lazy %-style args so the string is only built if the level is enabled.
    logging.info('reading pickle file: %s', pkl_path)
    # NOTE(review): pickle.load on untrusted input is unsafe — this path is
    # assumed to be a trusted, locally generated annotation file.
    with open(pkl_path, "rb") as fp:
        data = pickle.load(fp)
    # (removed redundant fp.close(); the `with` block already closes it)

    root_dir = img_path
    if not os.path.exists(root_dir):
        # Fallback: frames may live in a split-agnostic directory.
        root_dir = root_dir.replace('train', '').replace('val', '').replace('test', '')
    imgs, phases, steps = [], [], []
    for vid_name in sorted(data.keys()):
        paths = [
            os.path.join(root_dir, vid_name, f"{item['Frame_id']}.jpg")
            for item in data[vid_name]
        ]
        imgs.append(paths)
        phases.append([item['Phase_gt'] for item in data[vid_name]])
        steps.append([item['Step_gt'] for item in data[vid_name]])

    return imgs, phases, steps
|
|
|
|
| |
| |
# Load the three dataset splits. NOTE(review): the pickle filenames were
# f-strings with no placeholders (f'1fps_100_0.pickle'), and the frames path
# was a single-argument os.path.join — both replaced with plain literals.
# No new module-level names are introduced: a config loader that collects
# module globals would otherwise pick them up.

# --- train -------------------------------------------------------------------
labels = os.path.join(
    '/gpfswork/rech/okw/ukw13bv/MultiBypass140/labels',
    'bern', 'labels_by70_splits/labels', 'train', '1fps_100_0.pickle',
)
images = '/gpfsscratch/rech/okw/ukw13bv/bypass/BernBypass70/frames'
videos_train, phase_labels_train, step_labels_train = read_pkl_data(
    labels, images
)

# --- val ---------------------------------------------------------------------
labels = os.path.join(
    '/gpfswork/rech/okw/ukw13bv/MultiBypass140/labels',
    'bern', 'labels_by70_splits/labels', 'val', '1fps_0.pickle',
)
images = '/gpfsscratch/rech/okw/ukw13bv/bypass/BernBypass70/frames'
videos_val, phase_labels_val, step_labels_val = read_pkl_data(
    labels, images
)

# --- test --------------------------------------------------------------------
labels = os.path.join(
    '/gpfswork/rech/okw/ukw13bv/MultiBypass140/labels',
    'bern', 'labels_by70_splits/labels', 'test', '1fps_0.pickle',
)
images = '/gpfsscratch/rech/okw/ukw13bv/bypass/BernBypass70/frames'
videos_test, phase_labels_test, step_labels_test = read_pkl_data(labels, images)
|
|
_base_ = ['../base.py']

# Per-split dataset configs: one entry per video, pairing its frame paths
# with its phase labels. The preprocessing pipeline is identical for all
# three splits: resize -> center-crop 224 -> tensor -> ImageNet normalize.
config = dict(
    train_config=[
        dict(
            type='Recognition_frame_bypass',
            img_list=frame_paths,
            label_list=phase_seq,
            transforms=transforms.Compose([
                transforms.Resize((360, 640)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]),
        )
        for frame_paths, phase_seq in zip(videos_train, phase_labels_train)
    ],
    val_config=[
        dict(
            type='Recognition_frame_bypass',
            img_list=frame_paths,
            label_list=phase_seq,
            transforms=transforms.Compose([
                transforms.Resize((360, 640)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]),
        )
        for frame_paths, phase_seq in zip(videos_val, phase_labels_val)
    ],
    test_config=[
        dict(
            type='Recognition_frame_bypass',
            img_list=frame_paths,
            label_list=phase_seq,
            transforms=transforms.Compose([
                transforms.Resize((360, 640)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]),
        )
        for frame_paths, phase_seq in zip(videos_test, phase_labels_test)
    ],
    # Feature-extractor model: ResNet-50 image branch + BioBERT text branch.
    model_config=dict(
        type='MVNet_feature_extractor',
        backbone_img=dict(
            type='img_backbones/ImageEncoder_feature_extractor',
            num_classes=768,
            pretrained='imagenet',
            backbone_name='resnet_50',
            img_norm=False,
        ),
        backbone_text=dict(
            type='text_backbones/BertEncoder',
            text_bert_type='/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000',
            text_last_n_layers=4,
            text_aggregate_method='sum',
            text_norm=False,
            text_embedding_dim=768,
            text_freeze_bert=False,
            text_agg_tokens=True,
        ),
    ),
)
|
|
|
|