# linear_code/config_recog_bern_bypass_frame_linear.py
# Uploaded to the Hugging Face Hub with huggingface_hub (commit e744d68, verified).
import torch
import torchvision.transforms as transforms
import os
import logging
import pickle
def read_pkl_data(pkl_path, img_path):
    """Load per-video annotations from a pickle file and build frame paths.

    The pickle is expected to map video names to lists of per-frame dicts
    carrying 'Frame_id', 'Phase_gt' and 'Step_gt' keys (as written by the
    MultiBypass140 label export).

    Args:
        pkl_path: Path to the pickle file with per-video annotation lists.
        img_path: Root directory that holds one sub-directory of .jpg frames
            per video. If it does not exist, any 'train'/'val'/'test'
            component is stripped from the path and that fallback is used
            (the frame root is shared across splits on some setups).

    Returns:
        Tuple of three parallel lists (one entry per video, sorted by video
        name): image paths, phase labels, and step labels.
    """
    # Lazy %-style args: the message is only formatted if the level is enabled.
    logging.info('reading pickle file: %s', pkl_path)
    with open(pkl_path, "rb") as fp:
        # NOTE(review): pickle.load on untrusted input is unsafe; this config
        # assumes the label files are trusted project artifacts.
        data = pickle.load(fp)
    root_dir = img_path
    if not os.path.exists(root_dir):
        root_dir = root_dir.replace('train', '').replace('val', '').replace('test', '')
    imgs, phases, steps = [], [], []
    for vid_name in sorted(data.keys()):
        frames = data[vid_name]
        imgs.append([
            os.path.join(root_dir, vid_name, f"{item['Frame_id']}.jpg")
            for item in frames
        ])
        phases.append([item['Phase_gt'] for item in frames])
        steps.append([item['Step_gt'] for item in frames])
    return imgs, phases, steps
## Read the train/val/test pickle files and the shared frame directory.
# Cluster-specific roots for labels and extracted frames.
_LABEL_ROOT = '/gpfswork/rech/okw/ukw13bv/MultiBypass140/labels'
_FRAME_ROOT = '/gpfsscratch/rech/okw/ukw13bv/bypass/BernBypass70/frames'


def _load_split(split, pickle_name):
    """Load one dataset split; returns (imgs, phase_labels, step_labels)."""
    labels = os.path.join(_LABEL_ROOT, 'bern', 'labels_by70_splits/labels',
                          split, pickle_name)
    return read_pkl_data(labels, _FRAME_ROOT)


#### TRAIN ####
# Train uses the 100%-labeled 1fps export; val/test use the plain 1fps export.
videos_train, phase_labels_train, step_labels_train = _load_split(
    'train', '1fps_100_0.pickle'
)
#### VAL ####
videos_val, phase_labels_val, step_labels_val = _load_split(
    'val', '1fps_0.pickle'
)
#### TEST ####
videos_test, phase_labels_test, step_labels_test = _load_split(
    'test', '1fps_0.pickle'
)
_base_ = ['../base.py']

# Shared preprocessing pipeline: resize to 360x640, center-crop to 224,
# convert to tensor, then normalize with ImageNet statistics.
# Compose is stateless, so one instance can safely be shared by every
# per-video dataset config instead of being rebuilt per entry.
_frame_transforms = transforms.Compose(
    [
        transforms.Resize((360, 640)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def _frame_dataset_configs(video_lists, label_lists):
    """Build one 'Recognition_frame_bypass' dataset config per video.

    Args:
        video_lists: Per-video lists of frame image paths.
        label_lists: Per-video lists of labels, parallel to ``video_lists``.

    Returns:
        List of dataset config dicts, one per (img_list, label_list) pair.
    """
    return [
        dict(
            type='Recognition_frame_bypass',
            img_list=imgs,
            label_list=labels,
            transforms=_frame_transforms,
        )
        for imgs, labels in zip(video_lists, label_lists)
    ]


config = dict(
    # Frame-level phase recognition: one dataset config per video per split.
    train_config=_frame_dataset_configs(videos_train, phase_labels_train),
    val_config=_frame_dataset_configs(videos_val, phase_labels_val),
    test_config=_frame_dataset_configs(videos_test, phase_labels_test),
    model_config=dict(
        type='MVNet_feature_extractor',
        backbone_img=dict(
            type='img_backbones/ImageEncoder_feature_extractor',
            # Alternative: type='img_backbones/ImageEncoder_CLIPVISUAL'
            # with backbone_name='resnet_50_clip'.
            num_classes=768,
            pretrained='imagenet',  # one of: imagenet / ssl / random
            backbone_name='resnet_50',
            img_norm=False,
        ),
        backbone_text=dict(
            type='text_backbones/BertEncoder',
            text_bert_type='/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000',
            # Aggregate the last 4 BERT layers by summation into one embedding.
            text_last_n_layers=4,
            text_aggregate_method='sum',
            text_norm=False,
            text_embedding_dim=768,
            text_freeze_bert=False,
            text_agg_tokens=True,
        ),
    ),
)