| |
| |
| |
|
|
| from pdb import set_trace as bb |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as tvf |
|
|
| from core.conv_mixer import ConvMixer |
|
|
# Channel-wise RGB normalization using the standard ImageNet mean/std values.
# It is applied to float images already scaled to [0, 1]
# (DescPipeline.extract_descs divides by 255 before calling it).
norm_RGB = tvf.Normalize(
    [0.485, 0.456, 0.406],  # mean
    [0.229, 0.224, 0.225],  # std
)
|
|
|
|
class PixelDesc(nn.Module):
    """Descriptor extractor wrapping a pretrained ConvMixer network.

    The ConvMixer (output_dim=128, hidden_dim=512, depth=7, patch_size=4,
    kernel_size=9) is switched to eval mode, and its weights are loaded from a
    checkpoint on disk when the object is constructed.
    """

    def __init__(self, path='models/PUMP_st.pt'):
        super().__init__()
        # NOTE(review): torch.load unpickles the file, so only point `path` at
        # trusted checkpoints (consider weights_only=True where supported).
        weights = torch.load(path, map_location='cpu')
        network = ConvMixer(output_dim=128, hidden_dim=512, depth=7,
                            patch_size=4, kernel_size=9)
        self.pixel_desc = network.eval()
        self.pixel_desc.load_state_dict(weights)

    def configure(self, pipeline):
        """Mix DescPipeline into *pipeline*'s class in place and return self.

        A new subclass named ``<OriginalName>_Trained`` is created with
        DescPipeline first in its bases. The pipeline's ``__class__`` is then
        rebound to it, so DescPipeline's methods take precedence over the
        pipeline's own.
        """
        base_cls = type(pipeline)
        mixed_cls = type(base_cls.__name__ + '_Trained', (DescPipeline, base_cls), {})
        pipeline.__class__ = mixed_cls
        return self

    def get_atomic_patch_size(self):
        """Return the atomic patch size: 4, the same as the ConvMixer ``patch_size``."""
        return 4

    def forward(self, img, stride=1, offset=0):
        """Run the descriptor network on *img* and optionally subsample the result.

        Args:
            img: image tensor. A 3-D input gets a leading batch dimension added.
            stride: step used to subsample the output along its last two
                (spatial) dimensions.
            offset: index of the first row/column kept by the subsampling.

        Returns:
            ``(desc, trf)``. ``desc`` is the contiguous, subsampled network
            output. ``trf`` is a 3x3 identity matrix on the image's device.
            NOTE(review): ``trf`` stays the identity whatever ``stride`` and
            ``offset`` are. Confirm that callers do not expect it to encode the
            subsampling.
        """
        batch = img if img.ndim != 3 else img.unsqueeze(0)
        identity = torch.eye(3, device=batch.device)

        keep = slice(offset, None, stride)
        desc = self.pixel_desc(batch)[..., keep, keep].contiguous()
        return desc, identity
|
|
|
|
class DescPipeline:
    """Mixin providing descriptor extraction and the first-level correlation.

    PixelDesc.configure injects this class into a pipeline's class. The host
    object is expected to provide ``demultiplex_img_trf`` and a ``pixel_desc``
    attribute (a callable module with a ``type`` method).
    """

    @staticmethod
    def _as_dtype(tensor, dtype):
        """Cast *tensor* to *dtype*, or return it unchanged when *dtype* is None.

        ``Tensor.type(None)`` returns the type *name* (a str) rather than a
        tensor, so the cast has to be skipped explicitly.
        """
        return tensor if dtype is None else tensor.type(dtype)

    def extract_descs(self, img1, img2, dtype=None):
        """Extract descriptors for an image pair.

        Args:
            img1, img2: inputs accepted by ``self.demultiplex_img_trf``, which
                returns an (image, transform) pair for each.
            dtype: optional dtype for both the normalized images and the
                returned descriptors. With None, no cast is applied: images
                are only divided by 255, and descriptors keep the network's
                output dtype.

        Returns:
            ``((img1, img2), (desc1, desc2), (trf1, trf2))``.
            - ``desc1`` is sampled with stride 4 and offset 2; ``desc2`` is not
              subsampled.
            - Each transform is the demultiplexed transform matrix-multiplied
              with the one returned by ``pixel_desc``.
        """
        img1, sca1 = self.demultiplex_img_trf(img1)
        img2, sca2 = self.demultiplex_img_trf(img2)

        # Scale to [0, 1], then apply norm_RGB.
        fimg1, fimg2 = [norm_RGB(self._as_dtype(img, dtype) / 255) for img in (img1, img2)]

        # Cast the descriptor network in place so it matches the image dtype.
        self.pixel_desc.type(fimg1.dtype)
        desc1, trf1 = self.pixel_desc(fimg1, stride=4, offset=2)
        desc2, trf2 = self.pixel_desc(fimg2)
        descs = (self._as_dtype(desc1, dtype), self._as_dtype(desc2, dtype))
        return (img1, img2), descs, (sca1 @ trf1, sca2 @ trf2)

    def first_level(self, desc1, desc2, **kw):
        """Build the first-level correlation volume between two descriptor maps.

        Each descriptor of ``desc1`` becomes a 1x1 convolution filter applied
        to ``desc2``. As a result, ``corr[i, j, y, x]`` is the dot product of
        ``desc1[0, :, i, j]`` and ``desc2[0, :, y, x]``.

        Only a batch size of 1 is supported. ``desc1`` is flattened into
        ``H*W`` filters, which fails when B > 1, and only the first batch entry
        of the convolution output is kept.

        Args:
            desc1: descriptor map of shape (1, C, H1, W1).
            desc2: descriptor map of shape (B2, C, H2, W2); only its first
                entry ends up in the output.
            **kw: ignored here.

        Returns:
            ``(corr, norms)``. ``corr`` has shape (H1, W1, H2, W2). ``norms`` is
            an all-ones tensor of shape (H1, W1) on ``corr``'s device.
        """
        _, C, H, W = desc1.shape
        # (1, C, H, W) -> (H*W, C, 1, 1): one 1x1 filter per desc1 location.
        weights = desc1.permute(0, 2, 3, 1).view(H * W, C, 1, 1)
        corr = F.conv2d(desc2, weights, padding=0, bias=None)[0]
        norms = torch.ones(desc1.shape[-2:], device=corr.device)
        return corr.view(desc1.shape[-2:] + desc2.shape[-2:]), norms
|
|