from __future__ import annotations
import numpy as np
import torch
import cv2
from PIL import Image
import torchvision.transforms as T
from timm.data import resolve_data_config

from backbones import get_backbone
from segmenters import BaseSegmenter


class PCASegmenter(BaseSegmenter):
    """Foreground segmentation via PCA on backbone patch features."""

    def __init__(
        self,
        backbone_name: str = "dinov3_base",
        device: str | None = None,
        threshold: float = 2.5,
        kernel_size: int = 5,
        border: float = 0.2,
    ):
        super().__init__()
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model = get_backbone(backbone_name).to(self.device)
        self.model.eval()


        cfg = resolve_data_config({}, model=self.model)
        _, img_size, _ = cfg["input_size"]
        # timm may expose pretrained_cfg as a dict or a dataclass; handle both.
        pcfg = getattr(self.model, "pretrained_cfg", None)
        if isinstance(pcfg, dict):
            arch = pcfg.get("architecture", "")
        else:
            arch = getattr(pcfg, "architecture", "") or ""
        if isinstance(arch, str) and "dinov3" in arch:
            img_size = max(img_size, 512)

        self.img_size = img_size
        interp_name = cfg.get("interpolation", "bicubic")
        interp = getattr(T.InterpolationMode, interp_name.upper(), T.InterpolationMode.BICUBIC)
        self.transform = T.Compose(
            [
                T.Resize((self.img_size, self.img_size), interpolation=interp),
                T.ToTensor(),
                T.Normalize(
                    mean=cfg.get("mean", (0.485, 0.456, 0.406)),
                    std=cfg.get("std", (0.229, 0.224, 0.225)),
                ),
            ]
        )
        self.threshold = threshold
        self.border = border
        self.kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """Return a boolean foreground mask at the input image's resolution."""
        h0, w0 = image.shape[:2]
        pil = Image.fromarray(image.astype(np.uint8))
        x = self.transform(pil).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            out = self.model.forward_features(x)

        # Backbones may return a dict (DINO-style) or a raw tensor.
        tokens = out.get("x_norm_patchtokens") if isinstance(out, dict) else out
        if tokens is None and isinstance(out, dict):
            tokens = out.get("x")
        if tokens is None:
            raise RuntimeError("backbone returned no patch tokens")
        if tokens.ndim == 4:
            # Convolutional feature map: (B, C, Hf, Wf) -> (B, Hf*Wf, C).
            B, C, Hf, Wf = tokens.shape
            tokens = tokens.permute(0, 2, 3, 1).reshape(B, Hf * Wf, C)

        # Infer the patch grid; fall back to a square guess if the model
        # does not expose patch_embed.grid_size.
        gh = int(np.sqrt(tokens.shape[1]))
        gw = max(1, tokens.shape[1] // max(1, gh))

        if hasattr(self.model, "patch_embed") and hasattr(self.model.patch_embed, "grid_size"):
            gh0, gw0 = self.model.patch_embed.grid_size
            if gh0 * gw0 == tokens.shape[1]:
                gh, gw = gh0, gw0
        n_patches = gh * gw
        tokens = tokens[:, -n_patches:, :]

        # Centre the features and project onto the first principal component.
        feats = tokens.squeeze(0).detach().cpu().numpy().astype(np.float32)
        feats -= feats.mean(0, keepdims=True)
        _, _, vh = np.linalg.svd(feats, full_matrices=False)
        pc1 = vh[0]
        scores = feats @ pc1
        mask = scores > self.threshold
        m_grid = mask.reshape(gh, gw)
        # The sign of the first PC is arbitrary: if the central region of the
        # grid is mostly background, flip the thresholding direction.
        bh = int(gh * self.border)
        bw = int(gw * self.border)
        inner = m_grid[bh : gh - bh, bw : gw - bw]
        if inner.size > 0 and inner.mean() <= 0.35:
            mask = scores < -self.threshold
            m_grid = mask.reshape(gh, gw)
        # Clean up the grid mask and upsample back to the input resolution.
        mask = m_grid.astype(np.uint8)
        mask = cv2.dilate(mask, self.kernel, iterations=1)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, self.kernel)
        mask = cv2.resize(mask, (w0, h0), interpolation=cv2.INTER_NEAREST)
        return mask.astype(bool)
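The PCA-thresholding step at the core of `get_object_mask` can be illustrated standalone: centre an (N, C) feature matrix, take the first right singular vector, and threshold the projections. A minimal NumPy sketch with synthetic features (the data and the use of `|score|` instead of the class's sign-flip heuristic are simplifications for illustration; the 2.5 threshold mirrors the class default):

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic patch features: 64 background patches near the origin,
# 16 "object" patches shifted along a common direction.
direction = rng.normal(size=32)
direction /= np.linalg.norm(direction)
background = rng.normal(scale=0.1, size=(64, 32))
objects = 5.0 * direction + rng.normal(scale=0.1, size=(16, 32))
feats = np.vstack([background, objects]).astype(np.float32)

# Centre, project onto the first principal component, threshold.
feats -= feats.mean(0, keepdims=True)
_, _, vh = np.linalg.svd(feats, full_matrices=False)
scores = feats @ vh[0]
mask = np.abs(scores) > 2.5  # sign of PC1 is arbitrary, so use |score| here

print(mask[64:].all(), mask[:64].any())  # object patches flagged, background not
```

Because PC1 aligns with the direction separating object from background features, the object patches project far from the centred mean while background patches stay within the threshold.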