File size: 1,019 Bytes
6146368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np
import skimage
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torchvision.ops import roi_align
from torchvision.transforms import Resize

from .query_generator import C_base
from .sam_mask import MaskProcessor


class Box_correction(nn.Module):
    """Refine per-image predictions using the output of a ``MaskProcessor``.

    On every forward pass the wrapped processor is called once, and the
    ``"scores"`` and ``"pred_boxes"`` entries of each prediction dict are
    overwritten with the values it returns.
    """

    def __init__(self, reduction, image_size, emb_dim):
        """Create the wrapped mask processor.

        Args:
            reduction: Passed to ``MaskProcessor`` as its third argument.
            image_size: Passed to ``MaskProcessor`` as its second argument.
            emb_dim: Passed to ``MaskProcessor`` as its first argument.
        """
        super().__init__()

        # Flag is only set here; nothing in this class reads it.
        self.sam_corr = True
        self.sam_mask = MaskProcessor(emb_dim, image_size, reduction)

    def forward(self, feats, outputs, x):
        """Overwrite scores and boxes of ``outputs`` in place.

        Args:
            feats: Forwarded unchanged to the mask processor.
            outputs: Indexable collection of dicts, each holding a
                ``"pred_boxes"`` tensor; the dicts are mutated in place.
            x: Tensor whose last dimension size is used as the divisor
                for the corrected boxes.

        Returns:
            A tuple ``(outputs, masks)``: the same ``outputs`` object that
            was passed in, and the masks returned by the mask processor.
        """
        masks, ious, corrected_bboxes = self.sam_mask(feats, outputs)

        for idx, prediction in enumerate(outputs):
            prediction["scores"] = ious[idx]

            # Read the device before the entry is replaced below.
            target_device = prediction["pred_boxes"].device
            boxes = corrected_bboxes[idx].to(target_device).unsqueeze(0)

            # NOTE(review): every coordinate is divided by the size of the
            # last dimension of x -- presumably the image width, which
            # would assume square inputs; confirm against the caller.
            prediction["pred_boxes"] = boxes / x.shape[-1]

        return outputs, masks