| |
| |
|
|
| |
| |
| |
|
|
import ast

import torch

from .croco import CroCoNet
|
|
|
|
def croco_args_from_ckpt(ckpt):
    """Recover the CroCoNet constructor kwargs stored in a checkpoint.

    Args:
        ckpt: checkpoint dictionary. Newer checkpoints store the kwargs
            directly under the 'croco_kwargs' key; older ones store a
            string such as "CroCoNet(enc_depth=24, ...)" in
            ``ckpt['args'].model``.

    Returns:
        dict of keyword arguments suitable for ``CroCoNet(**kwargs)``;
        empty dict when the checkpoint carries no model description.
    """
    if 'croco_kwargs' in ckpt:
        return ckpt['croco_kwargs']
    elif 'args' in ckpt and hasattr(ckpt['args'], 'model'):
        s = ckpt['args'].model
        assert s.startswith('CroCoNet(')
        # SECURITY: the previous implementation ran eval() on this string.
        # Checkpoints are often downloaded and therefore untrusted input,
        # so we parse the keyword arguments with ast instead — same result
        # for every valid "CroCoNet(k1=v1, ...)" string (only keyword
        # arguments could ever work, since the old code evaluated
        # 'dict(' + ... ), but no arbitrary code can execute.
        call = ast.parse(s, mode='eval').body
        return {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    else:
        return dict()
|
|
class CroCoDownstreamMonocularEncoder(CroCoNet):
    """Encoder-only CroCo model for monocular downstream tasks.

    The mask generator, mask token, decoder and pretraining prediction
    head of the parent ``CroCoNet`` are disabled by overriding their
    setter methods below, so only the encoder (plus the task head) is
    instantiated.
    """

    def __init__(self, head, **kwargs):
        """Build the network for a monocular downstream task.

        Args:
            head: task-specific module; ``head.setup(self)`` is called
                here with the constructed network, and at inference time
                the head is invoked as ``head(features, img_info)`` where
                ``img_info`` is a dict with 'height' and 'width' keys.
            **kwargs: forwarded unchanged to ``CroCoNet.__init__``.

        NOTE: this works by calling ``super().__init__()`` with the
        setters below redefined to skip the unused sub-modules.
        """
        super().__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """No mask generator."""
        return

    def _set_mask_token(self, *args, **kwargs):
        """No mask token."""
        self.mask_token = None
        return

    def _set_decoder(self, *args, **kwargs):
        """No decoder."""
        return

    def _set_prediction_head(self, *args, **kwargs):
        """No 'prediction head' for downstream tasks."""
        return

    def forward(self, img):
        """Encode ``img`` (batch_size x 3 x h x w) and apply the head."""
        _, _, height, width = img.size()
        img_info = {'height': height, 'width': width}
        # Some heads consume the features of every encoder block, not
        # just the last one; they advertise this via `return_all_blocks`.
        want_all_blocks = getattr(self.head, 'return_all_blocks', False)
        features, _, _ = self._encode_image(img, do_mask=False,
                                            return_all_blocks=want_all_blocks)
        return self.head(features, img_info)
| |
| |
class CroCoDownstreamBinocular(CroCoNet):
    """CroCo encoder + decoder with a task head, for binocular tasks.

    The mask generator, mask token and pretraining prediction head are
    disabled via the overridden setters below; the decoder is kept
    (``_set_decoder`` is intentionally not overridden) since ``forward``
    relies on it to fuse the two views.
    """

    def __init__(self, head, **kwargs):
        """Build the network for a binocular downstream task.

        Args:
            head: task-specific module; ``head.setup(self)`` is called
                here, and at inference time the head is invoked as
                ``head(features, img_info)`` where ``img_info`` is a
                dict with 'height' and 'width' keys.
            **kwargs: forwarded unchanged to ``CroCoNet.__init__``.
        """
        super().__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """No mask generator."""
        return

    def _set_mask_token(self, *args, **kwargs):
        """No mask token."""
        self.mask_token = None
        return

    def _set_prediction_head(self, *args, **kwargs):
        """No prediction head for downstream tasks, define your own head."""
        return

    def encode_image_pairs(self, img1, img2, return_all_blocks=False):
        """Run the encoder on a pair of images in a single pass.

        Concatenating the two images along the batch dimension is
        measurably (~5%) faster than encoding them separately.

        Returns:
            (out, out2, pos, pos2): features and positions for each view;
            when ``return_all_blocks`` is True, ``out`` is the list of
            per-block features for view 1 while ``out2`` is only the
            last-block features of view 2.
        """
        stacked = torch.cat((img1, img2), dim=0)
        feats, pos, _ = self._encode_image(stacked, do_mask=False,
                                           return_all_blocks=return_all_blocks)
        if return_all_blocks:
            # Split every block's features back into the two views.
            halves = [f.chunk(2, dim=0) for f in feats]
            out, out2 = map(list, zip(*halves))
            out2 = out2[-1]
        else:
            out, out2 = feats.chunk(2, dim=0)
        pos, pos2 = pos.chunk(2, dim=0)
        return out, out2, pos, pos2

    def forward(self, img1, img2):
        """Encode both views, run the decoder, and apply the head."""
        _, _, height, width = img1.size()
        img_info = {'height': height, 'width': width}
        want_all_blocks = getattr(self.head, 'return_all_blocks', False)
        out, out2, pos, pos2 = self.encode_image_pairs(
            img1, img2, return_all_blocks=want_all_blocks)
        if want_all_blocks:
            # Decoder consumes only the last encoder block; the head then
            # receives all encoder blocks followed by all decoder blocks.
            decout = self._decoder(out[-1], pos, None, out2, pos2,
                                   return_all_blocks=True)
            decout = out + decout
        else:
            decout = self._decoder(out, pos, None, out2, pos2,
                                   return_all_blocks=False)
        return self.head(decout, img_info)