from __future__ import annotations import numpy as np import torch import cv2 from PIL import Image import torchvision.transforms as T from timm.data import resolve_data_config from backbones import get_backbone from segmenters import BaseSegmenter class PCASegmenter(BaseSegmenter): def __init__( self, backbone_name: str = "dinov3_base", device: str | None = None, threshold: float = 2.5, kernel_size: int = 5, border: float = 0.2, ): super().__init__() self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) self.model = get_backbone(backbone_name).to(self.device) self.model.eval() cfg = resolve_data_config({}, model=self.model) _, img_size, _ = cfg["input_size"] arch = getattr(getattr(self.model, "pretrained_cfg", {}), "get", lambda k, d=None: {})( # type: ignore[arg-type] "architecture", "" ) if isinstance(arch, str) and "dinov3" in arch: img_size = max(img_size, 512) self.img_size = img_size interp = cfg.get("interpolation", "bicubic") self.transform = T.Compose( [ T.Resize((self.img_size, self.img_size), interpolation=getattr(T.InterpolationMode, interp.upper(), T.InterpolationMode.BICUBIC)), T.ToTensor(), T.Normalize(mean=cfg.get("mean", (0.485, 0.456, 0.406)), std=cfg.get("std", (0.229, 0.224, 0.225))), ] ) self.threshold = threshold self.border = border self.kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) def get_object_mask(self, image: np.ndarray) -> np.ndarray: h0, w0 = image.shape[:2] pil = Image.fromarray(image.astype(np.uint8)) x = self.transform(pil).unsqueeze(0).to(self.device) with torch.inference_mode(): out = self.model.forward_features(x) tokens = out.get("x_norm_patchtokens") if isinstance(out, dict) else out if tokens is None and isinstance(out, dict): tokens = out.get("x") if tokens is not None and tokens.ndim == 4: B, C, Hf, Wf = tokens.shape tokens = tokens.permute(0, 2, 3, 1).reshape(B, Hf * Wf, C) gh_dyn = int(np.sqrt(tokens.shape[1])) gw_dyn = max(1, tokens.shape[1] // max(1, gh_dyn)) gh, gw = gh_dyn, gw_dyn if hasattr(self.model, "patch_embed") and hasattr(self.model.patch_embed, "grid_size"): gh0, gw0 = self.model.patch_embed.grid_size if gh0 * gw0 == tokens.shape[1]: gh, gw = gh0, gw0 n_patches = gh * gw tokens = tokens[:, -n_patches:, :] feats = tokens.squeeze(0).detach().cpu().numpy().astype(np.float32) feats -= feats.mean(0, keepdims=True) u, s, vh = np.linalg.svd(feats, full_matrices=False) pc1 = vh[0] scores = feats @ pc1 mask = scores > self.threshold m_grid = mask.reshape(gh, gw) bh = int(gh * self.border) bw = int(gw * self.border) inner = m_grid[bh : gh - bh, bw : gw - bw] if inner.size > 0 and inner.mean() <= 0.35: mask = scores < -self.threshold m_grid = mask.reshape(gh, gw) mask = m_grid.astype(np.uint8) mask = cv2.dilate(mask, self.kernel, iterations=1) mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, self.kernel) mask = cv2.resize(mask, (w0, h0), interpolation=cv2.INTER_NEAREST) return mask.astype(bool)