| import warnings |
| import torch.nn as nn |
| import torch |
| from romatch.models.matcher import * |
| from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention |
| from romatch.models.encoders import * |
| from romatch.models.tiny import TinyRoMa |
|
|
def tiny_roma_v1_model(weights=None, freeze_xfeat=False, exact_softmax=False, xfeat=None):
    """Construct a ``TinyRoMa`` model and optionally load a state dict into it.

    Args:
        weights: State dict passed to ``load_state_dict``. When ``None`` the
            model is returned exactly as ``TinyRoMa`` constructed it.
        freeze_xfeat: Forwarded unchanged to ``TinyRoMa``.
        exact_softmax: Forwarded unchanged to ``TinyRoMa``.
        xfeat: Forwarded unchanged to ``TinyRoMa``.

    Returns:
        The constructed ``TinyRoMa`` instance.
    """
    tiny_matcher = TinyRoMa(
        xfeat=xfeat,
        freeze_xfeat=freeze_xfeat,
        exact_softmax=exact_softmax,
    )
    # Nothing to load: hand back the freshly constructed model.
    if weights is None:
        return tiny_matcher
    tiny_matcher.load_state_dict(weights)
    return tiny_matcher
|
|
def roma_model(resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype = torch.float16, **kwargs):
    """Assemble the RoMa ``RegressionMatcher`` and optionally load weights.

    The matcher is built from a ``CNNandDinov2`` encoder and a ``Decoder``
    that combines a ``TransformerDecoder``, one ``GP`` module at scale "16",
    per-scale 1x1-conv projections and per-scale ``ConvRefiner`` modules for
    the scales "16", "8", "4", "2" and "1".

    Side effect: installs a process-wide ``warnings`` filter that ignores
    ``UserWarning`` messages starting with "TypedStorage is deprecated".

    Args:
        resolution: Pair ``(h, w)``; unpacked and passed to
            ``RegressionMatcher`` as ``h`` and ``w``.
        upsample_preds: Forwarded unchanged to ``RegressionMatcher``.
        device: Passed to ``.to(device)`` on the constructed matcher.
        weights: State dict for ``matcher.load_state_dict``. When ``None``
            no weights are loaded and the matcher keeps its constructed
            parameters (same convention as ``tiny_roma_v1_model``).
        dinov2_weights: Forwarded unchanged to ``CNNandDinov2``.
        amp_dtype: Forwarded unchanged to ``CNNandDinov2``.
        **kwargs: Extra keyword arguments forwarded to ``RegressionMatcher``.

    Returns:
        The constructed ``RegressionMatcher``, moved to ``device``.
    """
    # NOTE: this filter is global to the process, not scoped to this call.
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp=True,
        pos_enc=False,
    )

    # Keyword arguments shared by every ConvRefiner below.
    shared_refiner_kwargs = dict(
        kernel_size=5,
        dw=True,
        hidden_blocks=8,
        displacement_emb="linear",
        amp=True,
        disable_local_corr_grad=True,
        bn_momentum=0.01,
    )
    # (scale, channels, displacement_emb_dim, local_corr_radius or None).
    # `channels` equals the output width of the projection head defined
    # further down for the same scale. The two finest scales pass no
    # local-correlation arguments, so ConvRefiner's own defaults apply there.
    refiner_specs = (
        ("16", 512, 128, 7),
        ("8", 512, 64, 3),
        ("4", 256, 32, 2),
        ("2", 64, 16, None),
        ("1", 9, 6, None),
    )
    refiners = {}
    for scale, channels, emb_dim, radius in refiner_specs:
        in_dim = 2 * channels + emb_dim
        local_corr_kwargs = {}
        if radius is not None:
            # A (2r+1) x (2r+1) window adds that many input channels.
            in_dim += (2 * radius + 1) ** 2
            local_corr_kwargs = dict(local_corr_radius=radius, corr_in_other=True)
        # Hidden width equals the input width at every scale; 2 + 1 outputs.
        refiners[scale] = ConvRefiner(
            in_dim,
            in_dim,
            2 + 1,
            displacement_emb_dim=emb_dim,
            **local_corr_kwargs,
            **shared_refiner_kwargs,
        )
    conv_refiner = nn.ModuleDict(refiners)

    gp16 = GP(
        CosKernel,
        T=0.2,
        learn_temperature=False,
        only_attention=False,
        gp_dim=gp_dim,
        basis="fourier",
        no_cov=True,
    )
    gps = nn.ModuleDict({"16": gp16})

    # (scale, in_channels, out_channels) for the 1x1 conv + BatchNorm heads.
    proj_specs = (
        ("16", 1024, 512),
        ("8", 512, 512),
        ("4", 256, 256),
        ("2", 128, 64),
        ("1", 64, 9),
    )
    proj = nn.ModuleDict({
        scale: nn.Sequential(nn.Conv2d(c_in, c_out, 1, 1), nn.BatchNorm2d(c_out))
        for scale, c_in, c_out in proj_specs
    })

    decoder = Decoder(
        coordinate_decoder,
        gps,
        proj,
        conv_refiner,
        detach=True,
        scales=["16", "8", "4", "2", "1"],
        displacement_dropout_p=0.0,
        gm_warp_dropout_p=0.0,
    )

    encoder = CNNandDinov2(
        cnn_kwargs=dict(pretrained=False, amp=True),
        amp=True,
        use_vgg=True,
        dinov2_weights=dinov2_weights,
        amp_dtype=amp_dtype,
    )

    h, w = resolution
    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        upsample_preds=upsample_preds,
        symmetric=True,
        attenuate_cert=True,
        sample_mode="threshold_balanced",
        **kwargs,
    ).to(device)
    # `weights` defaults to None; loading unconditionally would raise on the
    # default call, so only load when a state dict was actually supplied.
    if weights is not None:
        matcher.load_state_dict(weights)
    return matcher
|
|