| """ |
| This file contains some useful functions for train / val. |
| """ |
| import os |
| import numpy as np |
| import torch |
| import random |
| from scalelsd.ssl.models.detector import ScaleLSD |
|
|
|
|
| |
| |
| |
def parse_h5_data(h5_data):
    """Convert every top-level entry of an h5 dataset into a numpy array.

    Args:
        h5_data: Object exposing ``keys()`` and item access by key
            (presumably an open h5py file or group -- confirm with callers).

    Returns:
        dict: The same keys, each mapped to ``np.array(h5_data[key])``.
    """
    return {name: np.array(h5_data[name]) for name in h5_data.keys()}
|
|
|
|
def fix_seeds(random_seed):
    """Seed the Python, numpy and torch random generators with one value.

    Also writes the seed to the ``PYTHONHASHSEED`` environment variable and
    switches cuDNN to deterministic mode with benchmarking disabled.

    Args:
        random_seed: Seed value handed to every generator.
    """
    # CPU-side generators: Python's `random` first, then numpy.
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(random_seed)
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    torch.manual_seed(random_seed)

    # CUDA generators are only touched when a CUDA device is available.
    if torch.cuda.is_available():
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(random_seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
| |
| |
| |
| |
|
|
| |
| |
|
|
|
|
def load_scalelsd_model(ckpt_path, device='cuda'):
    """Build a ScaleLSD model and load its weights from a known checkpoint.

    The checkpoint's file name selects the ``use_layer_scale`` flag passed to
    ``ScaleLSD``; only the two file names listed below are accepted.

    Args:
        ckpt_path: Path to the checkpoint file. Only its base name is used
            to choose the model configuration.
        device: Device the model is moved to. Defaults to ``'cuda'``.

    Returns:
        The ``ScaleLSD`` model in eval mode, on ``device``, with the
        checkpoint weights loaded.

    Raises:
        ValueError: If the checkpoint file name is not a known model.
    """
    # Known checkpoint file names and the `use_layer_scale` value each needs.
    layer_scale_by_ckpt = {
        'scalelsd-vitbase-v1-train-sa1b.pt': False,
        'scalelsd-vitbase-v2-train-sa1b.pt': True,
    }
    ckpt_name = os.path.basename(ckpt_path)
    if ckpt_name not in layer_scale_by_ckpt:
        raise ValueError(f'Unknown model: {ckpt_name}')

    model = ScaleLSD(gray_scale=True,
                     use_layer_scale=layer_scale_by_ckpt[ckpt_name])
    model = model.eval().to(device)

    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    # The weights are either nested under 'model_state' or the checkpoint is
    # the state dict itself. Choosing explicitly (instead of a bare `except`)
    # lets a genuine loading error surface with its real cause.
    if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
        weights = checkpoint['model_state']
    else:
        weights = checkpoint
    model.load_state_dict(weights)

    return model
|
|
|
|