| import torch | |
| from segment_anything.modeling import ImageEncoderViT | |
| # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa | |
class ImageEncoderViTHQ(ImageEncoderViT):
    """ViT image encoder that additionally exposes intermediate block outputs.

    Overrides ``ImageEncoderViT.forward`` so that, besides the output of the
    neck, the outputs of selected transformer blocks are collected and
    returned alongside it.
    """

    def forward(
        self, x: torch.Tensor
    ) -> "tuple[torch.Tensor, list[torch.Tensor]]":
        """Encode ``x`` and collect intermediate embeddings.

        Args:
            x: Input tensor handed to ``self.patch_embed``.

        Returns:
            A ``(features, interm_embeddings)`` tuple. ``features`` is the
            output of ``self.neck``. ``interm_embeddings`` holds, in block
            order, the output of every block whose ``window_size`` is 0,
            captured before the permute applied for the neck.
        """
        # NOTE: the return annotation is quoted so it is not evaluated at
        # definition time and needs no extra typing imports.
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        interm_embeddings = []
        for blk in self.blocks:
            x = blk(x)
            # Only blocks with window_size == 0 contribute an intermediate
            # embedding (presumably the global-attention blocks of the
            # upstream encoder -- confirm against ImageEncoderViT).
            if blk.window_size == 0:
                interm_embeddings.append(x)

        # Move the last axis to position 1 before the neck
        # (presumably (B, H, W, C) -> (B, C, H, W) -- confirm).
        x = self.neck(x.permute(0, 3, 1, 2))
        return x, interm_embeddings