|
|
|
|
|
|
|
|
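"""Convert Satlas-Pretrain model checkpoints to a format accepted by TorchGeo.

Run this script from the directory containing the downloaded ``*.pth``
checkpoints. For each checkpoint whose filename contains no ``-``, the script:

* keeps only the backbone weights;
* loads them into the matching timm ResNet or torchvision Swin V2 model;
* saves the result as ``<name>-<first 8 hex digits of its SHA-256>.pth``.

Reference implementation:

* https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py
"""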
|
|
|
|
|
|
import glob
|
|
|
import hashlib
|
|
|
import os
|
|
|
|
|
|
import timm
|
|
|
import torch
|
|
|
import torchvision
|
|
|
|
|
|
|
|
|
# Wrapper prefixes around the backbone weights in Satlas checkpoints. Each entry
# pairs a probe key, whose presence identifies the layout, with the prefix to
# strip. Entries are tried in order and the first match wins.
_BACKBONE_PREFIXES = (
    ('backbone.backbone.resnet.conv1.weight', 'backbone.backbone.resnet.'),
    ('backbone.resnet.conv1.weight', 'backbone.resnet.'),
    ('backbone.backbone.backbone.features.0.0.weight', 'backbone.backbone.backbone.'),
    ('backbone.backbone.features.0.0.weight', 'backbone.backbone.'),
)

# Snapshot the directory listing up front. Converted files are written into the
# same directory, and a lazy iterator could otherwise pick them up mid-loop.
for checkpoint in sorted(glob.glob('*.pth')):
    # Converted outputs carry a '-<hash>' suffix; skip them so re-runs are safe.
    if '-' in checkpoint:
        continue

    print(checkpoint)

    state_dict = torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=True)

    # Keep only the backbone weights, with the wrapper prefix removed, so the keys
    # match a bare timm / torchvision model. If no probe key matches, the state
    # dict is passed through unchanged.
    for probe_key, prefix in _BACKBONE_PREFIXES:
        if probe_key in state_dict:
            state_dict = {
                key.removeprefix(prefix): value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
            break

    if 'resnet' in checkpoint:
        # Read the input-channel count from the first conv instead of assuming 3.
        in_chans = state_dict['conv1.weight'].shape[1]

        # The second '_'-separated token of the filename is used as the timm model
        # name. Assumes names like '<prefix>_<model>_...'; TODO confirm for new files.
        model_name = checkpoint.split('_')[1]
        model = timm.create_model(model_name, in_chans=in_chans)
    elif 'swin' in checkpoint:
        out_channels, num_channels, kernel_size_0, kernel_size_1 = state_dict['features.0.0.weight'].shape

        if 'swint' in checkpoint:
            model = torchvision.models.swin_v2_t()
        elif 'swinb' in checkpoint:
            model = torchvision.models.swin_v2_b()
        else:
            # Without this, `model` would be undefined or left over from the
            # previous iteration.
            raise ValueError(f'Unrecognized Swin variant in checkpoint name {checkpoint!r}')

        # Rebuild the patch-embedding conv so its shape matches the checkpoint's
        # input channels. The stride is not stored in the weights, so (4, 4) is
        # hard-coded.
        model.features[0][0] = torch.nn.Conv2d(num_channels, out_channels, kernel_size=(kernel_size_0, kernel_size_1), stride=(4, 4))
    else:
        raise ValueError(f'Unrecognized architecture in checkpoint name {checkpoint!r}')

    # load_state_dict is strict by default, so any missing or unexpected key raises.
    model.load_state_dict(state_dict)

    # Save to a temporary file first: the final name depends on the file's hash.
    tmp_path = f'{checkpoint}.tmp'
    torch.save(model.state_dict(), tmp_path)

    with open(tmp_path, 'rb') as f:
        checksum = hashlib.file_digest(f, 'sha256').hexdigest()

    # '<name>-<first 8 hex digits of SHA-256>.pth' is the naming scheme that
    # torch.hub's check_hash verifies. os.replace (unlike os.rename) also
    # overwrites an existing destination on Windows.
    os.replace(tmp_path, f'{os.path.splitext(checkpoint)[0]}-{checksum[:8]}.pth')
|
|
|
|