Image Segmentation
ultralytics
PyTorch
English
object-detection
instance-segmentation
yolov8
coco
real-time
capsule-network
interpretable-ai
symbolic-ai
Eval Results (legacy)
Instructions to use zpyuan/SymbolicCapsuleNetwork with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- ultralytics
How to use zpyuan/SymbolicCapsuleNetwork with ultralytics:
```python
from ultralytics import YOLOvv8
model = YOLOvv8.from_pretrained("zpyuan/SymbolicCapsuleNetwork")
source = 'http://images.cocodataset.org/val2017/000000039769.jpg'
model.predict(source=source, save=True)
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import copy | |
| import math | |
| import time | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ultralytics.nn.modules import Conv, DWConv, Detect, Segment | |
| from ultralytics.nn.modules.block import Proto26 | |
class PrimaryCaps(nn.Module):
    r"""Primary convolutional capsules.

    Two parallel convolutions turn a feature map into capsule poses and sigmoid
    activations. Both are also returned packed into one channels-last tensor.

    Args:
        A: Input feature channels.
        B: Number of capsule types.
        K: Convolution kernel size.
        P: Pose matrix side length (pose size is ``P*P``).
        stride: Convolution stride.

    Input shape:
        x: ``(N, A, H, W)``

    Output shape:
        a: ``(N, B, H_out, W_out)`` activations in ``(0, 1)``.
        p: ``(N, B*P*P, H_out, W_out)`` raw poses.
        out: ``(N, H_out, W_out, B*(P*P+1))``. All ``B*P*P`` pose channels come
            first, followed by the ``B`` activation channels.

    Parameter size:
        pose conv + act conv
        ``(K*K*A*B*P*P + B*P*P) + (K*K*A*B + B)``
    """

    def __init__(self, A: int = 32, B: int = 32, K: int = 1, P: int = 4, stride: int = 1):
        super().__init__()
        self.B = B
        self.P = P
        self.psize = P * P
        # Creation order (pose conv, then activation conv) is preserved so that
        # seeded initialisation yields the same weights.
        self.pose = nn.Conv2d(A, B * self.psize, K, stride, bias=True)
        self.a = nn.Conv2d(A, B, K, stride, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(activations, poses, packed_nhwc)`` for an ``(N, A, H, W)`` input."""
        poses = self.pose(x)  # (N, B*psize, H_out, W_out)
        acts = self.sigmoid(self.a(x))  # (N, B, H_out, W_out)
        # Pack on the channel axis (poses first), then move channels last.
        packed = torch.cat((poses, acts), 1)
        return acts, poses, packed.permute(0, 2, 3, 1).contiguous()
class ConvCaps(nn.Module):
    r"""Convolutional capsules with EM routing.

    Args:
        B: Input capsule types.
        C: Output capsule types.
        K: Patch kernel size (non-shared mode).
        P: Pose matrix side length (pose size is ``P*P``).
        stride: Spatial stride for patch extraction (non-shared mode).
        iters: Number of EM routing iterations; must be >= 1.
        coor_add: Add normalised (row, col) offsets to the first two components
            of every vote. Only used in shared mode, and skipped when ``P*P < 2``.
        w_shared: Share transform matrices across spatial positions and route the
            whole map into one set of ``C`` capsules. Requires ``K=1`` and ``stride=1``.

    Input shape:
        x: ``(N, H, W, B*(P*P+1))`` packed as ``[poses (B*P*P), activations (B)]``.

    Output shape:
        p_out: ``(N, H_out, W_out, C*P*P)``
        a_out: ``(N, H_out, W_out, C)``
        out: ``(N, H_out, W_out, C*(P*P+1))``
        In shared mode the routed capsules are broadcast back over the input grid,
        so ``H_out=H`` and ``W_out=W``.

    Parameter size:
        If ``w_shared=False``:
            ``weights: (K*K*B*C*P*P*P*P)``, ``beta_u: C``, ``beta_a: C``
        If ``w_shared=True``:
            ``weights: (B*C*P*P*P*P)``, ``beta_u: C``, ``beta_a: C``
        Total = ``weights + 2*C`` (excluding non-trainable buffers).

    Raises:
        ValueError: If ``iters < 1`` at construction. Also raised in ``forward`` when
            the input channel count is not ``B*(P*P+1)``, or when shared mode is
            configured with ``K != 1`` or ``stride != 1``.
    """

    def __init__(
        self,
        B: int = 32,
        C: int = 32,
        K: int = 3,
        P: int = 4,
        stride: int = 1,
        iters: int = 3,
        coor_add: bool = False,
        w_shared: bool = False,
    ):
        super().__init__()
        if iters < 1:
            # caps_em_routing only defines its outputs inside the routing loop.
            raise ValueError(f"ConvCaps requires iters >= 1, got {iters}")
        self.B = B
        self.C = C
        self.K = K
        self.P = P
        self.psize = P * P
        self.stride = stride
        self.iters = iters
        self.coor_add = coor_add
        self.w_shared = w_shared
        self.eps = 1e-6
        self._lambda = 1e-3  # scale applied to the activation cost inside the sigmoid
        self.register_buffer("ln_2pi", torch.tensor(math.log(2 * math.pi), dtype=torch.float32), persistent=False)
        # Matrix-caps paper uses per-capsule beta scalars.
        self.beta_u = nn.Parameter(torch.zeros(C))
        self.beta_a = nn.Parameter(torch.zeros(C))
        # Non-shared mode has K*K*B input votes per position. Shared mode stores B
        # transforms and reuses them at every spatial position.
        weight_in = B if w_shared else (K * K * B)
        self.weights = nn.Parameter(torch.randn(1, weight_in, C, self.psize, self.psize) * 0.02)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=2)

    def m_step(
        self,
        a_in: torch.Tensor,
        r: torch.Tensor,
        v: torch.Tensor,
        eps: float,
        b: int,
        B: int,
        C: int,
        psize: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """EM M-step: fit per-output Gaussians to the votes and compute activations.

        Shapes: a_in ``(b, B, 1)`` or ``(b, B, 1, 1)``, r ``(b, B, C, 1)``,
        v ``(b, B, C, psize)``.

        Returns:
            ``a_out`` ``(b, C)``, ``mu`` ``(b, 1, C, psize)`` and
            ``sigma_sq`` ``(b, 1, C, psize)``.
        """
        if a_in.ndim == 3:
            a_in = a_in.unsqueeze(2)
        r = r * a_in
        # NOTE(review): r's rows already sum to 1 over C (uniform init / softmax in
        # e_step). Renormalising over C therefore largely cancels the a_in weighting
        # applied just above. Kept as-is to preserve trained behaviour; confirm the
        # intent.
        r = r / (r.sum(dim=2, keepdim=True) + eps)
        r_sum = r.sum(dim=1, keepdim=True)
        coeff = r / (r_sum + eps)
        mu = torch.sum(coeff * v, dim=1, keepdim=True)  # (b, 1, C, psize)
        sigma_sq = torch.sum(coeff * (v - mu).pow(2), dim=1, keepdim=True) + eps
        sigma_sq = sigma_sq.clamp_min(1e-4)
        r_sum_flat = r_sum.view(b, C, 1)
        sigma_sq_flat = sigma_sq.view(b, C, psize).clamp_min(1e-4)
        cost_h = (self.beta_u.view(1, C, 1) + torch.log(torch.sqrt(sigma_sq_flat))) * r_sum_flat
        a_out = self.sigmoid(self._lambda * (self.beta_a.view(1, C) - cost_h.sum(dim=2))).clamp(1e-4, 1.0 - 1e-4)
        # Replace non-finite values so routing cannot propagate NaN/Inf.
        mu = torch.nan_to_num(mu, nan=0.0, posinf=1e4, neginf=-1e4)
        sigma_sq = torch.nan_to_num(sigma_sq, nan=1e-4, posinf=1e4, neginf=1e-4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        return a_out, mu, sigma_sq

    def e_step(
        self,
        mu: torch.Tensor,
        sigma_sq: torch.Tensor,
        a_out: torch.Tensor,
        v: torch.Tensor,
        eps: float,
        b: int,
        C: int,
    ) -> torch.Tensor:
        """EM E-step: recompute assignments ``r`` ``(b, B, C, 1)``.

        Uses the Gaussian log-likelihood of each vote plus log activations, with a
        softmax over the C outputs. Input shapes: mu/sigma_sq ``(b, 1, C, psize)``,
        a_out ``(b, C)``, v ``(b, B, C, psize)``.
        """
        sigma_sq = sigma_sq.clamp_min(1e-4)
        a_out = a_out.clamp(1e-4, 1.0 - 1e-4)
        ln_p_j_h = -1.0 * (v - mu).pow(2) / (2.0 * sigma_sq) - torch.log(torch.sqrt(sigma_sq)) - 0.5 * self.ln_2pi
        ln_ap = ln_p_j_h.sum(dim=3) + torch.log(a_out.view(b, 1, C) + eps)
        ln_ap = torch.nan_to_num(ln_ap, nan=0.0, posinf=50.0, neginf=-50.0)
        r = self.softmax(ln_ap).unsqueeze(-1)  # (b, B, C, 1)
        r = torch.nan_to_num(r, nan=(1.0 / max(C, 1)), posinf=1.0, neginf=0.0)
        return r

    def caps_em_routing(self, v: torch.Tensor, a_in: torch.Tensor, C: int, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
        """Run ``self.iters`` EM iterations over votes ``v`` ``(b, B, C, psize)``.

        Returns:
            ``p_out`` ``(b, C, psize)`` (the final means) and ``a_out`` ``(b, C)``.
        """
        b, B, _, psize = v.shape
        r = v.new_full((b, B, C, 1), 1.0 / C)  # uniform initial assignments
        for t in range(self.iters):
            a_out, mu, sigma_sq = self.m_step(a_in, r, v, eps, b, B, C, psize)
            if t < self.iters - 1:  # the final iteration ends on an M-step
                r = self.e_step(mu, sigma_sq, a_out, v, eps, b, C)
        p_out = torch.nan_to_num(mu.squeeze(1), nan=0.0, posinf=1e4, neginf=-1e4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        return p_out, a_out

    def add_pathes(self, x: torch.Tensor, B: int, K: int, psize: int, stride: int) -> tuple[torch.Tensor, int, int]:
        """Extract ``K x K`` patches (zero padding ``K // 2``) from an NHWC capsule map.

        ``B`` and ``psize`` are unused and kept only for interface compatibility.

        Returns:
            ``patches`` ``(b, oh, ow, K*K, c)``, ``oh`` and ``ow``. Each
            ``patches[..., k, :]`` is the full capsule vector at kernel position ``k``
            (row-major within the window).
        """
        b, h, w, c = x.shape
        x_chw = x.permute(0, 3, 1, 2).contiguous()
        pad = K // 2
        patches = F.unfold(x_chw, kernel_size=K, padding=pad, stride=stride)  # (b, c*K*K, L)
        oh = (h + 2 * pad - K) // stride + 1
        ow = (w + 2 * pad - K) // stride + 1
        # F.unfold flattens its second dim channel-major (index = ch*K*K + kpos).
        # Split it as (c, K*K) and move channels last. Viewing it directly as
        # (K*K, c) scrambles channels whenever K > 1.
        patches = patches.view(b, c, K * K, -1).permute(0, 3, 2, 1).contiguous().view(b, oh, ow, K * K, c)
        return patches, oh, ow

    def transform_view(self, x: torch.Tensor, w: torch.Tensor, C: int, P: int, w_shared: bool = False) -> torch.Tensor:
        """Multiply every input pose vector by its per-output transform.

        Args:
            x: Poses ``(b, in_votes, psize)`` with ``psize == P*P``.
            w: Transforms ``(1, in_votes_base, C, psize, psize)``. In shared mode,
                ``in_votes`` must be a multiple of ``in_votes_base``, and vote ``i``
                uses transform ``i % in_votes_base``.

        Returns:
            Votes ``(b, in_votes, C, psize)``.
        """
        b, in_votes, psize = x.shape
        if psize != P * P:
            raise ValueError(f"ConvCaps expected pose size {P * P}, got {psize}")
        w0 = w[0]
        if not w_shared:
            return torch.einsum("bip,icpq->bicq", x, w0)
        base = w0.size(0)
        if in_votes % base != 0:
            raise ValueError(f"ConvCaps shared votes ({in_votes}) must be a multiple of {base}")
        # Broadcast the shared transforms over the position axis instead of
        # materialising `reps` copies of the full weight tensor.
        x_grouped = x.reshape(b, in_votes // base, base, psize)
        return torch.einsum("brip,icpq->bricq", x_grouped, w0).flatten(1, 2)

    def add_coord(self, v: torch.Tensor, b: int, h: int, w: int, B: int, C: int, psize: int) -> torch.Tensor:
        """Add normalised row/col coordinates to vote components 0 and 1.

        ``v`` is ``(b, h*w*B, C, psize)`` and rectangular maps (h != w) are supported.
        Poses with fewer than two components are returned unchanged.
        """
        if psize < 2:
            return v
        v = v.view(b, h, w, B, C, psize)
        device = v.device
        dtype = v.dtype
        coor_h_vals = torch.arange(h, dtype=dtype, device=device) / float(max(h, 1))
        coor_w_vals = torch.arange(w, dtype=dtype, device=device) / float(max(w, 1))
        coor_h = torch.zeros(1, h, 1, 1, 1, psize, dtype=dtype, device=device)
        coor_w = torch.zeros(1, 1, w, 1, 1, psize, dtype=dtype, device=device)
        coor_h[0, :, 0, 0, 0, 0] = coor_h_vals
        coor_w[0, 0, :, 0, 0, 1] = coor_w_vals
        return (v + coor_h + coor_w).view(b, h * w * B, C, psize)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route an NHWC capsule map ``(b, h, w, B*(psize+1))``; returns ``(p_out, a_out, out)``."""
        b, h, w, c = x.shape
        n_pose = self.B * self.psize
        expected = self.B * (self.psize + 1)
        if c != expected:
            raise ValueError(f"ConvCaps expected {expected} input channels, got {c}")
        if not self.w_shared:
            patches, oh, ow = self.add_pathes(x, self.B, self.K, self.psize, self.stride)
            n_votes = self.K * self.K * self.B  # vote index = kernel_pos * B + capsule_type
            p_in = patches[..., :n_pose].contiguous().view(b * oh * ow, n_votes, self.psize)
            a_in = patches[..., n_pose:].contiguous().view(b * oh * ow, n_votes, 1)
            v = self.transform_view(p_in, self.weights, self.C, self.P, w_shared=False)
            p_out, a_out = self.caps_em_routing(v, a_in, self.C, self.eps)
            p_out = p_out.view(b, oh, ow, self.C * self.psize)
            a_out = a_out.view(b, oh, ow, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        else:
            if self.K != 1 or self.stride != 1:
                raise ValueError("ConvCaps shared mode requires K=1 and stride=1")
            p_in = x[..., :n_pose].contiguous().view(b, h * w * self.B, self.psize)
            a_in = x[..., n_pose:].contiguous().view(b, h * w * self.B, 1)
            v = self.transform_view(p_in, self.weights, self.C, self.P, w_shared=True)
            if self.coor_add:
                v = self.add_coord(v, b, h, w, self.B, self.C, self.psize)
            p_cls, a_cls = self.caps_em_routing(v, a_in, self.C, self.eps)
            # Broadcast class capsules back to spatial map for Detect-style dense outputs.
            p_out = p_cls.reshape(b, 1, 1, self.C * self.psize).expand(b, h, w, self.C * self.psize)
            a_out = a_cls.unsqueeze(1).unsqueeze(1).expand(b, h, w, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        return p_out, a_out, out
class DynamicConvCaps(nn.Module):
    r"""Convolutional capsules with Sabour-style dynamic routing.

    This layer keeps the same tensor interface as ``ConvCaps``:
        input: (N, H, W, B*(P*P+1)) packed as [poses (B*P*P), activations (B)]
        output: p_out (N, H_out, W_out, C*P*P), a_out (N, H_out, W_out, C), out concat

    Args:
        B: Input capsule types.
        C: Output capsule types.
        K: Patch kernel size.
        P: Pose matrix side length.
        stride: Patch stride.
        iters: Routing iterations; must be >= 1.
        coor_add: Add coordinates in shared mode (skipped when ``P*P < 2``).
        w_shared: Share transforms across spatial positions (requires K=1, stride=1).

    Raises:
        ValueError: If ``iters < 1`` at construction, or if ``forward`` receives
            inputs that do not match the configured mode.
    """

    def __init__(
        self,
        B: int = 32,
        C: int = 32,
        K: int = 3,
        P: int = 4,
        stride: int = 1,
        iters: int = 3,
        coor_add: bool = False,
        w_shared: bool = False,
    ):
        super().__init__()
        if iters < 1:
            # dynamic_routing only defines its outputs inside the routing loop.
            raise ValueError(f"DynamicConvCaps requires iters >= 1, got {iters}")
        self.B = B
        self.C = C
        self.K = K
        self.P = P
        self.psize = P * P
        self.stride = stride
        self.iters = iters
        self.coor_add = coor_add
        self.w_shared = w_shared
        self.eps = 1e-6
        weight_in = B if w_shared else (K * K * B)
        self.weights = nn.Parameter(torch.randn(1, weight_in, C, self.psize, self.psize) * 0.02)

    @staticmethod
    def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
        """Squash non-linearity: rescale ``s`` along ``dim`` to length ``|s|^2 / (1 + |s|^2)``."""
        s2 = (s * s).sum(dim=dim, keepdim=True)
        scale = s2 / (1.0 + s2)
        return scale * s / torch.sqrt(s2 + eps)

    def add_pathes(self, x: torch.Tensor, K: int, stride: int) -> tuple[torch.Tensor, int, int]:
        """Extract ``K x K`` patches (zero padding ``K // 2``) from an NHWC map.

        Returns:
            ``patches`` ``(b, oh, ow, K*K, c)``, ``oh`` and ``ow``.
        """
        b, h, w, c = x.shape
        x_chw = x.permute(0, 3, 1, 2).contiguous()
        pad = K // 2
        patches = F.unfold(x_chw, kernel_size=K, padding=pad, stride=stride)  # (b, c*K*K, L)
        oh = (h + 2 * pad - K) // stride + 1
        ow = (w + 2 * pad - K) // stride + 1
        # F.unfold flattens channel-major (index = ch*K*K + kpos). Split as
        # (c, K*K) before moving channels last, otherwise channels are scrambled
        # for K > 1.
        patches = patches.view(b, c, K * K, -1).permute(0, 3, 2, 1).contiguous().view(b, oh, ow, K * K, c)
        return patches, oh, ow

    def transform_view(self, x: torch.Tensor, w_shared: bool) -> torch.Tensor:
        """Map poses ``(b, in_votes, psize)`` to votes ``(b, in_votes, C, psize)``."""
        b, in_votes, psize = x.shape
        if psize != self.psize:
            raise ValueError('Invalid pose size for DynamicConvCaps')
        w0 = self.weights[0]
        if not w_shared:
            return torch.einsum('bip,icpq->bicq', x, w0)
        base = w0.size(0)
        if in_votes % base != 0:
            raise ValueError(f'DynamicConvCaps shared votes ({in_votes}) must be a multiple of {base}')
        # Vote i uses transform i % base. Broadcasting avoids allocating h*w copies
        # of the weights.
        x_grouped = x.reshape(b, in_votes // base, base, psize)
        return torch.einsum('brip,icpq->bricq', x_grouped, w0).flatten(1, 2)

    def add_coord(self, v: torch.Tensor, b: int, h: int, w: int, B: int, C: int, psize: int) -> torch.Tensor:
        """Add normalised row/col offsets to vote components 0/1 of ``v`` ``(b, h*w*B, C, psize)``.

        Poses with fewer than two components are returned unchanged.
        """
        if psize < 2:
            return v
        v = v.view(b, h, w, B, C, psize)
        device, dtype = v.device, v.dtype
        coor_h_vals = torch.arange(h, dtype=dtype, device=device) / float(max(h, 1))
        coor_w_vals = torch.arange(w, dtype=dtype, device=device) / float(max(w, 1))
        coor_h = torch.zeros(1, h, 1, 1, 1, psize, dtype=dtype, device=device)
        coor_w = torch.zeros(1, 1, w, 1, 1, psize, dtype=dtype, device=device)
        coor_h[0, :, 0, 0, 0, 0] = coor_h_vals
        coor_w[0, 0, :, 0, 0, 1] = coor_w_vals
        return (v + coor_h + coor_w).view(b, h * w * B, C, psize)

    def dynamic_routing(self, v: torch.Tensor, a_in: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Routing-by-agreement over votes ``v`` ``(n, in_votes, C, psize)``.

        Returns:
            Squashed output vectors ``(n, C, psize)`` and activations ``(n, C)``
            taken from their lengths.
        """
        n, in_votes, C, psize = v.shape
        b_ij = v.new_zeros(n, in_votes, C)
        a_in = a_in.clamp(1e-4, 1.0)
        for t in range(self.iters):
            c_ij = F.softmax(b_ij, dim=2)
            c_ij = c_ij * a_in
            # NOTE(review): softmax rows already sum to 1 over C, so this
            # renormalisation largely cancels the a_in weighting above. Kept as-is to
            # preserve trained behaviour; confirm the intent.
            c_ij = c_ij / (c_ij.sum(dim=2, keepdim=True) + self.eps)
            s_j = (c_ij.unsqueeze(-1) * v).sum(dim=1)
            v_j = self.squash(s_j, dim=-1, eps=self.eps)
            if t < self.iters - 1:
                agreement = (v * v_j.unsqueeze(1)).sum(dim=-1)
                b_ij = b_ij + agreement
        # activation from vector length in (0,1)
        a_out = torch.sqrt((v_j * v_j).sum(dim=-1) + self.eps).clamp(1e-4, 1.0 - 1e-4)
        return v_j, a_out

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route an NHWC capsule map; returns sanitised ``(p_out, a_out, out)``."""
        b, h, w, c = x.shape
        n_pose = self.B * self.psize
        if not self.w_shared:
            expected = self.B * (self.psize + 1)
            if c != expected:
                raise ValueError(f'DynamicConvCaps expected {expected} input channels, got {c}')
            patches, oh, ow = self.add_pathes(x, self.K, self.stride)
            n_votes = self.K * self.K * self.B
            p_in = patches[..., :n_pose].contiguous().view(b * oh * ow, n_votes, self.psize)
            a_in = patches[..., n_pose:].contiguous().view(b * oh * ow, n_votes, 1)
            votes = self.transform_view(p_in, w_shared=False)
            p_vec, a_vec = self.dynamic_routing(votes, a_in)
            p_out = p_vec.view(b, oh, ow, self.C * self.psize)
            a_out = a_vec.view(b, oh, ow, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        else:
            if c != self.B * (self.psize + 1) or self.K != 1 or self.stride != 1:
                raise ValueError('DynamicConvCaps shared mode requires K=1, stride=1 and matching capsule channels')
            p_in = x[..., :n_pose].contiguous().view(b, h * w * self.B, self.psize)
            a_in = x[..., n_pose:].contiguous().view(b, h * w * self.B, 1)
            votes = self.transform_view(p_in, w_shared=True)
            if self.coor_add:
                votes = self.add_coord(votes, b, h, w, self.B, self.C, self.psize)
            p_vec, a_vec = self.dynamic_routing(votes, a_in)
            p_out = p_vec.reshape(b, 1, 1, self.C * self.psize).expand(b, h, w, self.C * self.psize)
            a_out = a_vec.unsqueeze(1).unsqueeze(1).expand(b, h, w, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        p_out = torch.nan_to_num(p_out, nan=0.0, posinf=1e4, neginf=-1e4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        out = torch.nan_to_num(out, nan=0.0, posinf=1e4, neginf=-1e4)
        return p_out, a_out, out
class SelfRoutingConvCaps(nn.Module):
    r"""Convolutional self-routing capsules.

    Each input capsule in a ``K x K`` window predicts its own routing
    coefficients over the ``C`` outputs from its pose, in a single pass. It then
    votes through a per-(input, output) pose transform. Outputs are the
    coefficient-weighted vote sums plus an activation-mass ratio.

    Keeps the same output contract as ``ConvCaps``/``DynamicConvCaps``:
        input: (N, H, W, B*(P*P+1)) packed as [poses (B*P*P), activations (B)]
        output: p_out (N, H_out, W_out, C*P*P), a_out (N, H_out, W_out, C), out concat

    ``iters`` and ``w_shared`` are accepted for signature compatibility only and
    have no effect.

    Raises:
        ValueError: If the input channel count is not ``B*(P*P+1)``.
    """

    def __init__(
        self,
        B: int = 32,
        C: int = 32,
        K: int = 3,
        P: int = 4,
        stride: int = 1,
        iters: int = 1,
        coor_add: bool = False,
        w_shared: bool = False,
    ):
        super().__init__()
        del iters, w_shared  # API compatibility with the other capsule layers only
        self.B, self.C, self.K, self.P = B, C, K, P
        self.psize = P * P
        self.stride = stride
        self.coor_add = coor_add
        self.eps = 1e-6
        self.kk = K * K
        self.kkB = self.kk * B
        # Vote transform for every (window position * type, output capsule) pair.
        self.W1 = nn.Parameter(torch.empty(self.kkB, C, self.psize, self.psize))
        nn.init.kaiming_uniform_(self.W1, a=math.sqrt(5))
        # Linear map from each vote's pose to its C routing logits (zero init).
        self.W2 = nn.Parameter(torch.zeros(self.kkB, C, self.psize))
        self.b2 = nn.Parameter(torch.zeros(1, 1, self.kkB, C))

    def _output_hw(self, h: int, w: int) -> tuple[int, int]:
        """Spatial output size of a ``K x K`` unfold with padding ``K // 2``."""
        pad = self.K // 2

        def extent(n: int) -> int:
            return (n + 2 * pad - self.K) // self.stride + 1

        return extent(h), extent(w)

    def _add_coord(self, pose_unf: torch.Tensor, oh: int, ow: int) -> torch.Tensor:
        """Shift pose components 0/1 by the normalised (row, col) of each output position.

        ``pose_unf`` is ``(b, L, kkB, psize)`` with positions flattened row-major.
        """
        if self.psize < 2:
            return pose_unf
        n_pos = pose_unf.shape[1]
        opts = dict(device=pose_unf.device, dtype=pose_unf.dtype)
        rows = torch.arange(oh, **opts) / float(max(oh, 1))
        cols = torch.arange(ow, **opts) / float(max(ow, 1))
        grid = torch.stack(torch.meshgrid(rows, cols, indexing='ij'), dim=-1).view(1, n_pos, 1, 2)
        shifted = pose_unf[..., :2] + grid
        return torch.cat((shifted, pose_unf[..., 2:]), dim=-1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route an NHWC capsule map ``(b, h, w, B*(psize+1))``; returns sanitised ``(p_out, a_out, out)``."""
        b, h, w, c = x.shape
        expected = self.B * (self.psize + 1)
        if c != expected:
            raise ValueError(f'SelfRoutingConvCaps expected {expected} channels, got {c}')
        n_pose = self.B * self.psize
        # Unfold poses and activations in one pass. F.unfold flattens channel-major,
        # so the first n_pose*kk rows are pose channels and the rest activations.
        cols = F.unfold(x.movedim(-1, 1).contiguous(), kernel_size=self.K, stride=self.stride, padding=self.K // 2)
        pose_cols, act_cols = torch.split(cols, [n_pose * self.kk, self.B * self.kk], dim=1)
        oh, ow = self._output_hw(h, w)
        n_pos = cols.shape[-1]
        # Rearrange to (b, L, kk*B, ...); vote index = kernel_pos * B + capsule_type.
        pose_unf = (
            pose_cols.reshape(b, self.B, self.psize, self.kk, n_pos)
            .permute(0, 4, 3, 1, 2)
            .reshape(b, n_pos, self.kkB, self.psize)
        )
        act_unf = act_cols.reshape(b, self.B, self.kk, n_pos).permute(0, 3, 2, 1).reshape(b, n_pos, self.kkB)
        if self.coor_add:
            pose_unf = self._add_coord(pose_unf, oh, ow)

        # Self-routing: per-vote softmax over outputs, weighted by input activation.
        route = (torch.einsum('blip,icp->blic', pose_unf, self.W2) + self.b2).softmax(dim=3)
        weighted = act_unf.unsqueeze(-1) * route  # (b, L, kkB, C)
        mass = weighted.sum(dim=2, keepdim=True) + self.eps  # (b, L, 1, C)
        coeff = weighted / mass
        total_act = act_unf.sum(dim=2, keepdim=True) + self.eps  # (b, L, 1)
        act_vals = (mass.squeeze(2) / total_act).clamp(1e-4, 1.0 - 1e-4)  # (b, L, C)

        votes = torch.einsum('blip,icpq->blicq', pose_unf, self.W1)  # (b, L, kkB, C, psize)
        pose_vals = (coeff.unsqueeze(-1) * votes).sum(dim=2)  # (b, L, C, psize)

        p_out = pose_vals.view(b, oh, ow, self.C * self.psize)
        a_out = act_vals.view(b, oh, ow, self.C)
        out = torch.cat([p_out, a_out], dim=3)  # packed from the pre-sanitised tensors
        return (
            torch.nan_to_num(p_out, nan=0.0, posinf=1e4, neginf=-1e4),
            torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4),
            torch.nan_to_num(out, nan=0.0, posinf=1e4, neginf=-1e4),
        )
class CapsuleDualHead(nn.Module):
    """Capsule detection head for one feature level.

    Args:
        c_in: Input channels of this feature scale (from parser-provided ``ch``).
        nc: Number of classes; this is the output capsule count ``C`` of ``self.conv_caps2``.
        reg_max: Detect DFL bins; box channels are ``4 * reg_max``.
        k: Number of capsule types in ``PrimaryCaps``.
        d: Requested pose descriptor size; rounded up to the square ``P*P`` with
            ``P = max(1, ceil(sqrt(d)))``.

    Input shape:
        x: ``(N, c_in, H, W)``

    Output shape:
        boxes: ``(N, 4*reg_max, H, W)``. These are the first ``4*reg_max`` routed
            pose channels (tiled when fewer are available) plus a learned bias.
        scores: ``(N, nc, H, W)`` logits of the class-capsule activations.
        aux: ``{"caps2_a": (N, nc, H, W)}`` when ``return_aux=True``, else ``None``.

    Parameter size:
        ``PrimaryCaps(c_in,k) + ConvCaps(k,nc,w_shared=True) + box_bias(4*reg_max)``

    Structure:
        PrimaryCaps -> ConvCaps(class caps only, shared)
    """

    def __init__(self, c_in: int, nc: int, reg_max: int, k: int, d: int):
        super().__init__()
        self.nc = nc
        self.reg_max = reg_max
        # Matrix-caps poses are square: smallest side P (>= 1) with P*P >= d.
        self.P = max(1, math.ceil(math.sqrt(d)))
        self.psize = self.P ** 2
        self.primary = PrimaryCaps(A=c_in, B=k, K=1, P=self.P, stride=1)
        # One shared-transform class-capsule layer keeps the parameter count low.
        self.conv_caps2 = ConvCaps(B=k, C=nc, K=1, P=self.P, stride=1, iters=1, coor_add=True, w_shared=True)
        # Box prior; the original comment says it is set in CapsuleDetect.bias_init()
        # (not visible here — confirm).
        self.box_bias = nn.Parameter(torch.zeros(4 * reg_max))

    def _pose_to_box(self, p_out: torch.Tensor, a_out: torch.Tensor) -> torch.Tensor:
        """Take the first ``4*reg_max`` pose channels of ``p_out`` (b,h,w,nc*psize) as box logits.

        ``a_out`` is accepted for interface symmetry and intentionally unused.
        """
        del a_out
        box_ch = 4 * self.reg_max
        width = p_out.shape[-1]
        if width < box_ch:
            # Too few pose channels: tile them until at least box_ch exist.
            p_out = p_out.repeat(1, 1, 1, math.ceil(box_ch / width))
        return p_out[..., :box_ch] + self.box_bias

    def forward(self, x: torch.Tensor, return_aux: bool = False) -> tuple[torch.Tensor, torch.Tensor, dict | None]:
        """Return ``(boxes, scores, aux)`` in BCHW layout for one feature level."""
        caps0 = self.primary(x)[2]
        p2, a2, _ = self.conv_caps2(caps0)
        boxes = self._pose_to_box(p2, a2).movedim(-1, 1).contiguous()  # (b, 4*reg_max, h, w)
        scores = torch.logit(a2.clamp(1e-4, 1.0 - 1e-4)).movedim(-1, 1).contiguous()  # (b, nc, h, w)
        aux = {"caps2_a": a2.movedim(-1, 1).contiguous()} if return_aux else None
        return boxes, scores, aux
class CapsuleClsHead(nn.Module):
    """Capsule classification branch used as a drop-in replacement for Detect.cv3.

    Pipeline: ``PrimaryCaps`` (k types) -> ``SelfRoutingConvCaps`` (mid types)
    -> ``SelfRoutingConvCaps`` (nc class capsules). The class-capsule
    activations are returned as float32 logits in ``(N, nc, H, W)``.

    Args:
        c_in: Input feature channels.
        nc: Number of classes.
        k: Primary capsule types.
        d: Requested pose size, rounded up to a square ``P*P``.
        iters: Forwarded to the routing layers (which ignore it).
    """

    def __init__(self, c_in: int, nc: int, k: int = 4, d: int = 16, iters: int = 1):
        super().__init__()
        side = max(1, math.ceil(math.sqrt(d)))
        mid = int((k + nc) / 2)  # intermediate capsule count: mean of k and nc
        self.primary = PrimaryCaps(A=c_in, B=k, K=1, P=side, stride=1)
        # Internal capsule refinement layer.
        self.mid_caps = SelfRoutingConvCaps(B=k, C=mid, K=1, P=side, stride=1, iters=iters, coor_add=False, w_shared=True)
        self.class_caps = SelfRoutingConvCaps(B=mid, C=nc, K=1, P=side, stride=1, iters=iters, coor_add=False, w_shared=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Detect-compatible class logits ``(N, nc, H, W)``."""
        caps = self.primary(x)[2]
        caps = self.mid_caps(caps)[2]
        act = self.class_caps(caps)[1]
        logits = torch.logit(act.clamp(1e-4, 1.0 - 1e-4)).movedim(-1, 1).contiguous()
        return torch.nan_to_num(logits, nan=0.0, posinf=20.0, neginf=-20.0).float()
class CapsuleDetect(Detect):
    """Detect head with capsule vote aggregation for both box and cls branches.

    Input feature of level i is packed one contiguous group per capsule type:
        [pose(d_i), act(1)] repeated k_i times -> C_i = k_i * (d_i + 1)

    In forward_head:
        - split pose/act per capsule type
        - run Detect box/cls heads on each type-specific pose tensor
        - aggregate type predictions with act-driven vote weights
    Detect decode/postprocess/end2end flow is reused unchanged.

    Construction:
        Parser form: ``CapsuleDetect(nc, k_list, d_list[, reg_max, end2end], ch)``.
        Keyword form: ``CapsuleDetect(nc, k=..., d=..., reg_max=..., end2end=..., ch=...)``.
        The keyword form is used when no positional per-level settings are given.

    Raises:
        ValueError: Bad positional layout, k/d lengths different from the number of
            levels, or a level channel count other than ``k_i * (d_i + 1)``.
        TypeError: ``k`` or ``d`` is not a list/tuple.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        # Parser layout: [k_list, d_list, reg_max, end2end, ch]. Without positional
        # per-level settings, the k/d/reg_max/end2end keyword arguments are used.
        if parsed:
            if len(parsed) not in (2, 4):
                raise ValueError('CapsuleDetect expects [k_list, d_list, reg_max, end2end, ch].')
            k, d = parsed[0], parsed[1]
            if len(parsed) == 4:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
        if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
            raise TypeError('CapsuleDetect requires list/tuple k and d (per-level settings).')
        ch = tuple(int(c) for c in ch)
        nl = len(ch)
        if len(k) != nl or len(d) != nl:
            raise ValueError(f'CapsuleDetect k/d length must equal number of levels ({nl}).')
        self.k_list = tuple(int(v) for v in k)
        self.d_list = tuple(int(v) for v in d)
        for i, c in enumerate(ch):
            expected = self.k_list[i] * (self.d_list[i] + 1)
            if c != expected:
                raise ValueError(
                    f'CapsuleDetect level-{i} channel mismatch: got {c}, expected {expected} from k={self.k_list[i]}, d={self.d_list[i]}.'
                )
        # Detect heads operate on per-type pose tensors (d_i channels).
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=self.d_list)
        # Vote weights from activation channels (K_i channels), separate for cls/box.
        self.box_vote = nn.ModuleList(
            nn.Sequential(Conv(k_i, k_i, 3), nn.Conv2d(k_i, k_i, 1, bias=True)) for k_i in self.k_list
        )
        self.cls_vote = nn.ModuleList(
            nn.Sequential(Conv(k_i, k_i, 3), nn.Conv2d(k_i, k_i, 1, bias=True)) for k_i in self.k_list
        )

    def _split_caps(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Split packed feature into pose and activation tensors per level.

        Returns:
            pose_caps: list of tensors, each (B, K, D, H, W)
            act_map: list of tensors, each (B, K, H, W)
        """
        pose_caps, act_map = [], []
        for i, xi in enumerate(x):
            k_i = self.k_list[i]
            d_i = self.d_list[i]
            c = int(xi.shape[1])
            expected = k_i * (d_i + 1)
            if c != expected:
                raise ValueError(f'CapsuleDetect level-{i} channel mismatch: got {c}, expected {expected}.')
            b, _, h, w = xi.shape
            caps = xi.view(b, k_i, d_i + 1, h, w)
            pose_caps.append(caps[:, :, :d_i].contiguous())
            act_map.append(caps[:, :, d_i].contiguous())
        return pose_caps, act_map

    @staticmethod
    def _normalized_votes(raw: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
        """Turn raw vote logits (B, K, H, W) into positive weights summing to ~1 over K.

        Uses softplus + sum-normalisation rather than softmax/sigmoid.
        """
        w = F.softplus(raw) + eps
        return w / (w.sum(dim=1, keepdim=True) + eps)

    def _run_voted_head(
        self,
        pose: torch.Tensor,
        act: torch.Tensor,
        head: torch.nn.Module,
        vote_head: torch.nn.Module,
        out_ch: int,
    ) -> torch.Tensor:
        """Apply one Detect head per type and aggregate by vote weights.

        Args:
            pose: (B, K, D, H, W)
            act: (B, K, H, W)
            head: Detect box or cls head module for this level
            vote_head: vote logits module for this level
            out_ch: output channels of target prediction

        Returns:
            (B, out_ch, H, W)
        """
        b, k, d, h, w = pose.shape
        # No voting needed when there is only one capsule type.
        if k == 1:
            return head(pose[:, 0])
        # Fold capsule types into the batch so each type goes through the shared head.
        pose_bt = pose.reshape(b * k, d, h, w)
        pred_bt = head(pose_bt).reshape(b, k, out_ch, h, w)
        vote_raw = vote_head(act)  # (B, K, H, W)
        vote = self._normalized_votes(vote_raw).unsqueeze(2)  # (B, K, 1, H, W)
        return (pred_bt * vote).sum(dim=1)

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Return ``{"boxes", "scores", "feats"}``, flattened over all levels' anchors.

        Returns an empty dict when either head is missing.
        """
        if box_head is None or cls_head is None:
            return dict()
        pose_caps, act_map = self._split_caps(x)
        bs = x[0].shape[0]
        box_list = []
        cls_list = []
        for i in range(self.nl):
            box_i = self._run_voted_head(
                pose_caps[i],
                act_map[i],
                box_head[i],
                self.box_vote[i],
                out_ch=4 * self.reg_max,
            )
            cls_i = self._run_voted_head(
                pose_caps[i],
                act_map[i],
                cls_head[i],
                self.cls_vote[i],
                out_ch=self.nc,
            )
            box_list.append(box_i.view(bs, 4 * self.reg_max, -1))
            cls_list.append(cls_i.view(bs, self.nc, -1))
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
class CapsuleDetectv1(Detect):
    """Capsule Detect variant with activation-gated pose fusion.

    Per level:
        1) Split packed capsule channels into pose/activation (one contiguous
           ``[pose(d_i), act(1)]`` group per type).
        2) Use a 2-layer 1x1 gate net on activation channels.
        3) Gate pose channels with residual scaling: ``pose * (1 + alpha_i * sigmoid(gate))``.
        4) Flatten to K*D channels and run original Detect cv2/cv3 heads.

    Construction:
        Parser form: ``CapsuleDetectv1(nc, k_list, d_list[, reg_max, end2end], ch)``.
        Keyword form: ``CapsuleDetectv1(nc, k=..., d=..., reg_max=..., end2end=..., ch=...)``.
        The keyword form is used when no positional per-level settings are given.

    Raises:
        ValueError: Bad positional layout, k/d lengths different from the number of
            levels, or a level channel count other than ``k_i * (d_i + 1)``.
        TypeError: ``k`` or ``d`` is not a list/tuple.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        # Parser layout: [k_list, d_list, reg_max, end2end, ch]. Without positional
        # per-level settings, the k/d/reg_max/end2end keyword arguments are used.
        if parsed:
            if len(parsed) not in (2, 4):
                raise ValueError("CapsuleDetectv1 expects [k_list, d_list, reg_max, end2end, ch].")
            k, d = parsed[0], parsed[1]
            if len(parsed) == 4:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
        if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
            raise TypeError("CapsuleDetectv1 requires list/tuple k and d (per-level settings).")
        ch = tuple(int(c) for c in ch)
        nl = len(ch)
        if len(k) != nl or len(d) != nl:
            raise ValueError(f"CapsuleDetectv1 k/d length must equal number of levels ({nl}).")
        self.k_list = tuple(int(v) for v in k)
        self.d_list = tuple(int(v) for v in d)
        # Input from neck is packed as K*(D+1): [pose(D), act(1)] repeated K types.
        for i, c in enumerate(ch):
            expected = self.k_list[i] * (self.d_list[i] + 1)
            if c != expected:
                raise ValueError(
                    f"CapsuleDetectv1 level-{i} channel mismatch: got {c}, "
                    f"expected {expected} from k={self.k_list[i]}, d={self.d_list[i]}."
                )
        # Detect heads consume merged pose channels: K*D.
        merged_ch = tuple(k_i * d_i for k_i, d_i in zip(self.k_list, self.d_list))
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        self.pose_gates = nn.ModuleList()
        self.gate_alpha = nn.ParameterList()
        for k_i, d_i in zip(self.k_list, self.d_list):
            out_ch = k_i * d_i
            hidden = max(8, k_i * 2)
            self.pose_gates.append(
                nn.Sequential(
                    nn.Conv2d(k_i, hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(hidden, out_ch, 1, bias=True),
                )
            )
            self.gate_alpha.append(nn.Parameter(torch.tensor(0.5)))

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one level packed tensor into pose (B, K*D, H, W) and activation (B, K, H, W) maps."""
        k_i = self.k_list[i]
        d_i = self.d_list[i]
        b, c, h, w = x.shape
        expected = k_i * (d_i + 1)
        if c != expected:
            raise ValueError(f"CapsuleDetectv1 level-{i} channel mismatch: got {c}, expected {expected}.")
        caps = x.view(b, k_i, d_i + 1, h, w)
        pose = caps[:, :, :d_i].reshape(b, k_i * d_i, h, w).contiguous()
        act = caps[:, :, d_i].contiguous()
        return pose, act

    def _merge_pose(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
        """Return per-level activation-gated pose maps (B, K*D, H, W)."""
        merged = []
        for i, xi in enumerate(x):
            pose, act = self._split_pose_act(xi, i)
            gate = torch.sigmoid(self.pose_gates[i](act))
            # Residual gating keeps base pose information and improves stability.
            pose = pose * (1.0 + self.gate_alpha[i] * gate)
            merged.append(pose)
        return merged

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Return ``{"boxes", "scores", "feats"}``, flattened over all levels' anchors.

        Returns an empty dict when either head is missing.
        """
        if box_head is None or cls_head is None:
            return dict()
        pose_feats = self._merge_pose(x)
        bs = pose_feats[0].shape[0]
        box_list = [box_head[i](pose_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [cls_head[i](pose_feats[i]).view(bs, self.nc, -1) for i in range(self.nl)]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
class CapsuleDetectv2(Detect):
    """Capsule Detect v2: activation-gated pose + activation bypass for classification.

    Every neck level ``i`` delivers ``k_i * (d_i + 1)`` channels: ``k_i`` capsule types, each
    packed as ``[pose(d_i), act(1)]``. The head splits this into a pose map ``(B, k_i*d_i, H, W)``
    and an activation map ``(B, k_i, H, W)``.

    * Box branch input: ``pose * (1 + gate_alpha_i * sigmoid(pose_gates_i(act)))``.
    * Cls branch input: the same gated pose plus ``cls_beta_i * cls_bypass_i(act)``.

    The inherited Detect convs are therefore built for ``k_i * d_i`` input channels.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the gated capsule detection head.

        Args:
            nc: Number of classes.
            *args: Positional parser layout, ``[k_list, d_list, ch]`` or
                ``[k_list, d_list, reg_max, end2end, ch]``. The trailing ``ch`` may instead be
                passed as a keyword. With no positional args, the ``k``/``d`` keywords are used.
            reg_max: Box regression bins per side (box logits have ``4 * reg_max`` channels).
                Forwarded to Detect.
            end2end: Forwarded to Detect.
            k: Capsule types per level (used only when ``args`` is empty).
            d: Pose dimension per level (used only when ``args`` is empty).
            ch: Packed input channels per level; each must equal ``k_i * (d_i + 1)``.

        Raises:
            ValueError: On an unsupported positional layout or a level/channel mismatch.
            TypeError: If ``k`` or ``d`` is not a list/tuple.
        """
        parsed = list(args)
        # ch is the trailing entry of the parser layouts (odd length). Restricting the pop to
        # those lengths stops a bare [k_list, d_list] call from mistaking d_list for ch.
        if len(parsed) in (3, 5) and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        if parsed:
            # Parser layout: [k_list, d_list, reg_max, end2end, ch]
            if len(parsed) not in (2, 4):
                raise ValueError("CapsuleDetectv2 expects [k_list, d_list, reg_max, end2end, ch].")
            k, d = parsed[0], parsed[1]
            if len(parsed) == 4:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
        # Validate the packed K*(D+1) neck channels; the Detect convs consume merged K*D pose channels.
        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleDetectv2")
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        # Creation order below is kept stable: it fixes parameter registration and RNG init order.
        self.pose_gates = nn.ModuleList()
        self.gate_alpha = nn.ParameterList()
        self.cls_bypass = nn.ModuleList()
        self.cls_beta = nn.ParameterList()
        for k_i, d_i in zip(self.k_list, self.d_list):
            pose_ch = k_i * d_i
            gate_hidden = max(8, k_i * 2)
            self.pose_gates.append(
                nn.Sequential(
                    nn.Conv2d(k_i, gate_hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(gate_hidden, pose_ch, 1, bias=True),
                )
            )
            self.gate_alpha.append(nn.Parameter(torch.tensor(0.5)))
            cls_hidden = max(16, k_i * 2)
            self.cls_bypass.append(
                nn.Sequential(
                    nn.Conv2d(k_i, cls_hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(cls_hidden, pose_ch, 1, bias=True),
                )
            )
            self.cls_beta.append(nn.Parameter(torch.tensor(0.1)))

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one level packed tensor into pose ``(B, k_i*d_i, H, W)`` and act ``(B, k_i, H, W)``."""
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleDetectv2", i)

    def _fuse_pose(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Return per-level ``(box_feats, cls_feats)``: gated pose, and gated pose plus activation bypass."""
        box_feats, cls_feats = [], []
        for i, xi in enumerate(x):
            pose, act = self._split_pose_act(xi, i)
            gate = torch.sigmoid(self.pose_gates[i](act))
            pose_g = pose * (1.0 + self.gate_alpha[i] * gate)
            # Classification bypass from activation channels.
            act_skip = self.cls_bypass[i](act)
            cls_in = pose_g + self.cls_beta[i] * act_skip
            box_feats.append(pose_g)
            cls_feats.append(cls_in)
        return box_feats, cls_feats

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Return ``boxes`` ``(B, 4*reg_max, N)``, ``scores`` ``(B, nc, N)`` and ``feats`` (= ``x``).

        Returns ``{}`` when either per-level head is missing.
        """
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats = self._fuse_pose(x)
        bs = x[0].shape[0]
        box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [cls_head[i](cls_feats[i]).view(bs, self.nc, -1) for i in range(self.nl)]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
class CapsuleDetectv4(Detect):
    """Capsule Detect v4: box uses raw pose, cls uses act bypass + symbolic type prior.

    Per level, the packed ``k_i * (d_i + 1)`` input is split into a pose map
    ``(B, k_i*d_i, H, W)`` and an activation map ``(B, k_i, H, W)``.

    * Box convs consume the raw pose.
    * Cls convs consume ``pose + cls_beta_i * cls_bypass_i(act)``.
    * Class logits are shifted by ``sym_beta_i * sym_prior_i(act)``, a bias-free 1x1 conv
      ``k_i -> nc``.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the v4 capsule detection head.

        Args:
            nc: Number of classes.
            *args: Positional parser layout, ``[k_list, d_list, ch]`` or
                ``[k_list, d_list, reg_max, end2end, ch]``. The trailing ``ch`` may instead be
                passed as a keyword. With no positional args, the ``k``/``d`` keywords are used.
            reg_max: Box regression bins per side. Forwarded to Detect.
            end2end: Forwarded to Detect.
            k: Capsule types per level (used only when ``args`` is empty).
            d: Pose dimension per level (used only when ``args`` is empty).
            ch: Packed input channels per level; each must equal ``k_i * (d_i + 1)``.

        Raises:
            ValueError: On an unsupported positional layout or a level/channel mismatch.
            TypeError: If ``k`` or ``d`` is not a list/tuple.
        """
        parsed = list(args)
        # Only odd-length parser layouts carry a trailing ch; a bare [k_list, d_list] keeps d_list.
        if len(parsed) in (3, 5) and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        if parsed:
            if len(parsed) not in (2, 4):
                raise ValueError("CapsuleDetectv4 expects [k_list, d_list, reg_max, end2end, ch].")
            k, d = parsed[0], parsed[1]
            if len(parsed) == 4:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleDetectv4")
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        # Creation order below is kept stable: it fixes parameter registration and RNG init order.
        self.cls_bypass = nn.ModuleList()
        self.cls_beta = nn.ParameterList()
        self.sym_prior = nn.ModuleList()
        self.sym_beta = nn.ParameterList()
        for k_i, d_i in zip(self.k_list, self.d_list):
            pose_ch = k_i * d_i
            cls_hidden = max(16, k_i * 2)
            self.cls_bypass.append(
                nn.Sequential(
                    nn.Conv2d(k_i, cls_hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(cls_hidden, pose_ch, 1, bias=True),
                )
            )
            self.cls_beta.append(nn.Parameter(torch.tensor(0.1)))
            self.sym_prior.append(nn.Conv2d(k_i, self.nc, 1, bias=False))
            self.sym_beta.append(nn.Parameter(torch.tensor(0.1)))

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one level packed tensor into pose ``(B, k_i*d_i, H, W)`` and act ``(B, k_i, H, W)``."""
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleDetectv4", i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Return per-level ``(box_feats, cls_feats, cls_priors)`` for the v4 head."""
        box_feats, cls_feats, cls_priors = [], [], []
        for i, xi in enumerate(x):
            pose, act = self._split_pose_act(xi, i)
            cls_in = pose + self.cls_beta[i] * self.cls_bypass[i](act)
            cls_prior = self.sym_beta[i] * self.sym_prior[i](act)
            box_feats.append(pose)
            cls_feats.append(cls_in)
            cls_priors.append(cls_prior)
        return box_feats, cls_feats, cls_priors

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Return ``boxes`` ``(B, 4*reg_max, N)``, prior-shifted ``scores`` ``(B, nc, N)`` and ``feats`` (= ``x``).

        Returns ``{}`` when either per-level head is missing.
        """
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]
        box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [
            (cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1)
            for i in range(self.nl)
        ]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
def _setup_capsule_layout(
    k: list[int] | tuple[int, ...],
    d: list[int] | tuple[int, ...],
    ch: tuple,
    cls_name: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    """Validate per-level capsule settings against the packed neck channel counts.

    Level ``i`` must carry exactly ``k[i] * (d[i] + 1)`` channels: ``k[i]`` capsule types, each
    packed as ``d[i]`` pose values followed by one activation.

    Args:
        k: Capsule types per level.
        d: Pose dimension per level.
        ch: Input channel count per level; its length defines the number of levels.
        cls_name: Prefix used in error messages.

    Returns:
        ``(k_list, d_list, merged_ch)`` as int tuples. ``merged_ch[i] == k[i] * d[i]`` is the
        pose-only channel count.

    Raises:
        TypeError: If ``k`` or ``d`` is not a list/tuple.
        ValueError: If ``len(k)``/``len(d)`` differ from ``len(ch)``, or a level's channel count
            is not ``k[i] * (d[i] + 1)``.
    """
    if not (isinstance(k, (list, tuple)) and isinstance(d, (list, tuple))):
        raise TypeError(f"{cls_name} requires list/tuple k and d (per-level settings).")
    channels = tuple(map(int, ch))
    num_levels = len(channels)
    if num_levels != len(k) or num_levels != len(d):
        raise ValueError(f"{cls_name} k/d length must equal number of levels ({num_levels}).")
    k_list = tuple(map(int, k))
    d_list = tuple(map(int, d))
    for level, (got, k_i, d_i) in enumerate(zip(channels, k_list, d_list)):
        want = k_i * (d_i + 1)
        if got != want:
            raise ValueError(
                f"{cls_name} level-{level} channel mismatch: got {got}, "
                f"expected {want} from k={k_i}, d={d_i}."
            )
    merged_ch = tuple(k_i * d_i for k_i, d_i in zip(k_list, d_list))
    return k_list, d_list, merged_ch
def _init_capsule_semantic_heads(obj: nn.Module) -> None:
    """Attach per-level classification-bypass and symbolic-prior modules to ``obj``.

    ``obj`` must already define ``k_list``, ``d_list`` and ``nc``. For every level ``(k_i, d_i)``:

    * ``cls_bypass[i]``: 1x1 conv MLP ``k_i -> max(16, 2*k_i) -> k_i*d_i``, with SiLU between.
    * ``cls_beta[i]``: scalar parameter initialized to 0.1.
    * ``sym_dropout[i]``: ``Dropout2d(p=0.1)``.
    * ``sym_prior[i]``: bias-free 1x1 conv ``k_i -> nc``.
    * ``sym_norm[i]``: ``GroupNorm(1, nc)``.
    * ``sym_beta[i]``: scalar parameter initialized to 0.1.

    Per-level modules are constructed in the same order as before. That order fixes the RNG draws
    used for weight init, and the containers are attached in a fixed order, which fixes the
    ``state_dict`` / ``parameters()`` order.
    """
    bypass_convs = nn.ModuleList()
    bypass_scales = nn.ParameterList()
    prior_convs = nn.ModuleList()
    prior_norms = nn.ModuleList()
    prior_drops = nn.ModuleList()
    prior_scales = nn.ParameterList()
    for k_i, d_i in zip(obj.k_list, obj.d_list):
        hidden = max(16, 2 * k_i)
        bypass_convs.append(
            nn.Sequential(
                nn.Conv2d(k_i, hidden, 1, bias=True),
                nn.SiLU(inplace=True),
                nn.Conv2d(hidden, k_i * d_i, 1, bias=True),
            )
        )
        bypass_scales.append(nn.Parameter(torch.tensor(0.1)))
        prior_drops.append(nn.Dropout2d(p=0.1))
        prior_convs.append(nn.Conv2d(k_i, obj.nc, 1, bias=False))
        prior_norms.append(nn.GroupNorm(1, obj.nc))
        prior_scales.append(nn.Parameter(torch.tensor(0.1)))
    obj.cls_bypass = bypass_convs
    obj.cls_beta = bypass_scales
    obj.sym_prior = prior_convs
    obj.sym_norm = prior_norms
    obj.sym_dropout = prior_drops
    obj.sym_beta = prior_scales
def _capsule_split_pose_act(
    x: torch.Tensor,
    k_i: int,
    d_i: int,
    cls_name: str,
    level_i: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Unpack one level's capsule tensor into pose and activation maps.

    Channels of ``x`` form ``k_i`` consecutive groups of ``d_i + 1`` values: ``d_i`` pose channels
    followed by one activation channel.

    Args:
        x: 4-D tensor ``(B, C, H, W)`` with ``C == k_i * (d_i + 1)``.
        k_i: Capsule types at this level.
        d_i: Pose dimension per capsule type.
        cls_name: Prefix for the error message.
        level_i: Level index reported in the error message.

    Returns:
        ``(pose, act)``, both contiguous. ``pose`` is ``(B, k_i * d_i, H, W)``, the pose channels
        of all types in type order. ``act`` is ``(B, k_i, H, W)``.

    Raises:
        ValueError: If ``C`` differs from ``k_i * (d_i + 1)``.
    """
    batch, channels, height, width = x.shape
    group = d_i + 1
    if channels != k_i * group:
        raise ValueError(f"{cls_name} level-{level_i} channel mismatch: got {channels}, expected {k_i * group}.")
    # Splitting the channel dim is always expressible as a view (no copy).
    grouped = x.view(batch, k_i, group, height, width)
    act = grouped[:, :, d_i].contiguous()
    pose = grouped[:, :, :d_i].reshape(batch, k_i * d_i, height, width).contiguous()
    return pose, act
def _capsule_build_feats(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Build per-level features for the additive-bypass capsule head.

    For each level ``i`` the packed input is split into ``pose`` and ``act``:

    * box features = ``pose``
    * cls features = ``pose + tanh(cls_beta[i]) * cls_bypass[i](act)``
    * cls prior    = ``tanh(sym_beta[i]) * (p - p.mean(dim=1))``, where
      ``p = sym_norm[i](sym_prior[i](sym_dropout[i](act)))``

    Args:
        obj: Head exposing ``k_list``, ``d_list`` and the modules created by
            ``_init_capsule_semantic_heads``.
        x: Packed per-level capsule tensors.

    Returns:
        ``(box_feats, cls_feats, cls_priors)``, one entry per level.
    """
    name = type(obj).__name__
    box_feats, cls_feats, cls_priors = [], [], []
    for level, packed in enumerate(x):
        pose, act = _capsule_split_pose_act(packed, obj.k_list[level], obj.d_list[level], name, level)
        bypass = obj.cls_bypass[level](act)
        cls_feats.append(pose + torch.tanh(obj.cls_beta[level]) * bypass)
        raw_prior = obj.sym_norm[level](obj.sym_prior[level](obj.sym_dropout[level](act)))
        # Remove the per-location mean over classes so the prior only redistributes logits.
        centered = raw_prior - raw_prior.mean(dim=1, keepdim=True)
        cls_priors.append(torch.tanh(obj.sym_beta[level]) * centered)
        box_feats.append(pose)
    return box_feats, cls_feats, cls_priors
def _capsule_build_feats_gated(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Build per-level features where activations gate the cls pose multiplicatively.

    For each level ``i``:

    * box features = ``pose``
    * cls features = ``pose * (1 + tanh(cls_beta[i]) * sigmoid(cls_bypass[i](act)))``
    * cls prior    = ``tanh(sym_beta[i]) * (p - p.mean(dim=1))``, where
      ``p = sym_norm[i](sym_prior[i](sym_dropout[i](act)))``

    Args:
        obj: Head exposing ``k_list``, ``d_list`` and the ``_init_capsule_semantic_heads`` modules.
        x: Packed per-level capsule tensors.

    Returns:
        ``(box_feats, cls_feats, cls_priors)``, one entry per level.
    """
    name = type(obj).__name__
    box_feats, cls_feats, cls_priors = [], [], []
    for level, packed in enumerate(x):
        pose, act = _capsule_split_pose_act(packed, obj.k_list[level], obj.d_list[level], name, level)
        gate_map = torch.sigmoid(obj.cls_bypass[level](act))
        cls_feats.append(pose * (1.0 + torch.tanh(obj.cls_beta[level]) * gate_map))
        raw_prior = obj.sym_norm[level](obj.sym_prior[level](obj.sym_dropout[level](act)))
        centered = raw_prior - raw_prior.mean(dim=1, keepdim=True)
        cls_priors.append(torch.tanh(obj.sym_beta[level]) * centered)
        box_feats.append(pose)
    return box_feats, cls_feats, cls_priors
def _capsule_build_feats_boxcls(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Build per-level features where box and cls both consume the raw pose.

    The same ``pose`` tensor object is returned in both ``box_feats`` and ``cls_feats``. The cls
    prior is ``tanh(sym_beta[i]) * (p - p.mean(dim=1))``, where
    ``p = sym_norm[i](sym_prior[i](sym_dropout[i](act)))``.

    Args:
        obj: Head exposing ``k_list``, ``d_list`` and the ``sym_*`` modules.
        x: Packed per-level capsule tensors.

    Returns:
        ``(box_feats, cls_feats, cls_priors)``, one entry per level.
    """
    name = type(obj).__name__
    box_feats, cls_feats, cls_priors = [], [], []
    for level, packed in enumerate(x):
        pose, act = _capsule_split_pose_act(packed, obj.k_list[level], obj.d_list[level], name, level)
        raw_prior = obj.sym_norm[level](obj.sym_prior[level](obj.sym_dropout[level](act)))
        centered = raw_prior - raw_prior.mean(dim=1, keepdim=True)
        box_feats.append(pose)
        cls_feats.append(pose)
        cls_priors.append(torch.tanh(obj.sym_beta[level]) * centered)
    return box_feats, cls_feats, cls_priors
def _capsule_build_feats_boxcls_simpleprior(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Raw-pose box/cls features with an un-normalized symbolic prior.

    Same as the raw-pose (boxcls) variant except the prior skips ``sym_norm`` and mean-centering:
    ``tanh(sym_beta[i]) * sym_prior[i](sym_dropout[i](act))``. The same ``pose`` tensor object is
    returned in both ``box_feats`` and ``cls_feats``.

    Args:
        obj: Head exposing ``k_list``, ``d_list``, ``sym_dropout``, ``sym_prior`` and ``sym_beta``.
        x: Packed per-level capsule tensors.

    Returns:
        ``(box_feats, cls_feats, cls_priors)``, one entry per level.
    """
    name = type(obj).__name__
    box_feats, cls_feats, cls_priors = [], [], []
    for level, packed in enumerate(x):
        pose, act = _capsule_split_pose_act(packed, obj.k_list[level], obj.d_list[level], name, level)
        raw_prior = obj.sym_prior[level](obj.sym_dropout[level](act))
        box_feats.append(pose)
        cls_feats.append(pose)
        cls_priors.append(torch.tanh(obj.sym_beta[level]) * raw_prior)
    return box_feats, cls_feats, cls_priors
def _capsule_build_feats_open_vocab(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Build per-level features for the open-vocabulary capsule head.

    Box features are the raw pose. Cls features are also the raw pose, unless
    ``obj.with_act_gate`` is truthy; then they are
    ``pose * (1 + tanh(ov_beta[i]) * sigmoid(ov_gate[i](act)))``.

    Args:
        obj: Head exposing ``k_list``, ``d_list`` and, when gated, ``ov_beta`` / ``ov_gate``.
        x: Packed per-level capsule tensors.

    Returns:
        ``(box_feats, cls_feats, acts)``. Activation maps are returned in place of priors.
    """
    name = type(obj).__name__
    gated = getattr(obj, "with_act_gate", False)
    box_feats, cls_feats, acts = [], [], []
    for level, packed in enumerate(x):
        pose, act = _capsule_split_pose_act(packed, obj.k_list[level], obj.d_list[level], name, level)
        if gated:
            scale = torch.tanh(obj.ov_beta[level])
            cls_feats.append(pose * (1.0 + scale * torch.sigmoid(obj.ov_gate[level](act))))
        else:
            cls_feats.append(pose)
        box_feats.append(pose)
        acts.append(act)
    return box_feats, cls_feats, acts
class CapsuleDetectv5(Detect):
    """Capsule Detect v5: box uses raw pose, cls uses stabilized symbolic prior.

    Per level, the packed ``k_i * (d_i + 1)`` input is split into pose and activation maps.

    * Box convs consume the raw pose.
    * Cls convs consume ``pose + tanh(cls_beta_i) * cls_bypass_i(act)``.
    * Class logits are shifted by ``tanh(sym_beta_i) * (p - p.mean(dim=1))``, where
      ``p = sym_norm_i(sym_prior_i(sym_dropout_i(act)))``.

    Subclasses change only ``_build_feats``.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the v5 capsule detection head.

        Args:
            nc: Number of classes.
            *args: Positional parser layout, ``[k_list, d_list, ch]`` or
                ``[k_list, d_list, reg_max, end2end, ch]``. The trailing ``ch`` may instead be
                passed as a keyword. With no positional args, the ``k``/``d`` keywords are used.
            reg_max: Box regression bins per side. Forwarded to Detect.
            end2end: Forwarded to Detect.
            k: Capsule types per level (used only when ``args`` is empty).
            d: Pose dimension per level (used only when ``args`` is empty).
            ch: Packed input channels per level; each must equal ``k_i * (d_i + 1)``.

        Raises:
            ValueError: On an unsupported positional layout or a level/channel mismatch.
            TypeError: If ``k`` or ``d`` is not a list/tuple.
        """
        parsed = list(args)
        # Only odd-length parser layouts carry a trailing ch; a bare [k_list, d_list] keeps d_list.
        if len(parsed) in (3, 5) and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        if parsed:
            if len(parsed) not in (2, 4):
                raise ValueError("CapsuleDetectv5 expects [k_list, d_list, reg_max, end2end, ch].")
            k, d = parsed[0], parsed[1]
            if len(parsed) == 4:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleDetectv5")
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        _init_capsule_semantic_heads(self)

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split level ``i`` into pose ``(B, k_i*d_i, H, W)`` and act ``(B, k_i, H, W)``.

        Errors name the concrete class (e.g. a v6/v7/v8 subclass). This matches the
        ``_capsule_build_feats*`` helpers, which derive the name from the instance.
        """
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], type(self).__name__, i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Return per-level ``(box_feats, cls_feats, cls_priors)`` (additive bypass + centered prior)."""
        return _capsule_build_feats(self, x)

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Return ``boxes`` ``(B, 4*reg_max, N)``, prior-shifted ``scores`` ``(B, nc, N)`` and ``feats`` (= ``x``).

        Returns ``{}`` when either per-level head is missing.
        """
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]
        box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [
            (cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1)
            for i in range(self.nl)
        ]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
class CapsuleDetectv6(CapsuleDetectv5):
    """Capsule Detect v6: replace additive cls correction with multiplicative act gate.

    Construction and ``forward_head`` are inherited from v5; only feature building differs.
    Cls features become ``pose * (1 + tanh(cls_beta_i) * sigmoid(cls_bypass_i(act)))`` instead of
    ``pose + tanh(cls_beta_i) * cls_bypass_i(act)``.
    """

    def _build_feats(
        self, x: list[torch.Tensor]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Return ``(box_feats, cls_feats, cls_priors)`` using the multiplicative activation gate."""
        box_feats, cls_feats, cls_priors = _capsule_build_feats_gated(self, x)
        return box_feats, cls_feats, cls_priors
class CapsuleDetectv7(CapsuleDetectv5):
    """Capsule Detect v7: cls head consumes raw pose features plus symbolic priors only.

    Box and cls convs both consume the raw pose. Class logits are shifted by the centered,
    GroupNorm-ed symbolic prior, scaled by ``tanh(sym_beta_i)``.

    Setting ``profile_head = True`` accumulates wall-clock milliseconds per head stage in
    ``_head_profile``. Read them with ``get_head_profile()`` and clear them with
    ``reset_head_profile()``. When CUDA is available, every timed section is bracketed by
    ``torch.cuda.synchronize()``.
    """

    # Stage keys accumulated by the profiler, in reporting order.
    _PROFILE_KEYS = ("split_pose_act_ms", "cls_prior_ms", "box_head_ms", "cls_head_ms", "cat_ms")

    def __init__(self, *args, **kwargs):
        """Build the v5 head with profiling disabled."""
        super().__init__(*args, **kwargs)
        self.profile_head = False
        self._head_profile: dict[str, float] = {}
        self._head_profile_calls = 0

    def _ensure_profile_attrs(self) -> None:
        """Backfill profiling attributes when they are missing on this instance.

        NOTE(review): presumably guards instances restored without running ``__init__``
        (e.g. checkpoints saved before profiling existed) — confirm.
        """
        if not hasattr(self, "profile_head"):
            self.profile_head = False
        if not hasattr(self, "_head_profile"):
            self._head_profile = {}
        if not hasattr(self, "_head_profile_calls"):
            self._head_profile_calls = 0

    def reset_head_profile(self) -> None:
        """Zero every stage timer and the call counter."""
        self._ensure_profile_attrs()
        self._head_profile = dict.fromkeys(self._PROFILE_KEYS, 0.0)
        self._head_profile_calls = 0

    def get_head_profile(self) -> dict[str, float]:
        """Summarize the accumulated head timings.

        Returns:
            ``{}`` if nothing was recorded. Otherwise a dict with:
            the accumulated ``*_ms`` stage totals; ``calls`` (the number of profiled
            ``forward_head`` calls); ``total_ms`` (the sum of the stage totals); and per-call
            ``*_avg_ms`` averages, where the call count is floored at 1.
        """
        self._ensure_profile_attrs()
        if not self._head_profile:
            return {}
        out = dict(self._head_profile)
        calls = max(self._head_profile_calls, 1)
        out["calls"] = float(self._head_profile_calls)
        out["total_ms"] = sum(v for k, v in out.items() if k.endswith("_ms"))
        for key in list(self._head_profile):
            out[key.replace("_ms", "_avg_ms")] = self._head_profile[key] / calls
        return out

    def _sync_profile(self) -> None:
        """Synchronize CUDA when profiling, so timings include queued kernels."""
        self._ensure_profile_attrs()
        if self.profile_head and torch.cuda.is_available():
            torch.cuda.synchronize()

    def _timed(self, key: str, fn, *args):
        """Call ``fn(*args)`` between profiler syncs and add its wall time (ms) to ``_head_profile[key]``."""
        self._sync_profile()
        start = time.perf_counter()
        result = fn(*args)
        self._sync_profile()
        self._head_profile[key] += (time.perf_counter() - start) * 1000.0
        return result

    def _symbolic_prior(self, i: int, act: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(sym_beta[i]) * centered(sym_norm[i](sym_prior[i](sym_dropout[i](act))))``."""
        prior = self.sym_norm[i](self.sym_prior[i](self.sym_dropout[i](act)))
        prior = prior - prior.mean(dim=1, keepdim=True)
        return torch.tanh(self.sym_beta[i]) * prior

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Return per-level ``(box_feats, cls_feats, cls_priors)``; both feature lists hold the raw pose.

        Without profiling this is ``_capsule_build_feats_boxcls``. With profiling, the same math
        runs with the split and prior stages timed separately.
        """
        self._ensure_profile_attrs()
        if not self.profile_head:
            return _capsule_build_feats_boxcls(self, x)
        if not self._head_profile:
            self.reset_head_profile()
        cls_name = self.__class__.__name__
        box_feats, cls_feats, cls_priors = [], [], []
        for i, xi in enumerate(x):
            pose, act = self._timed(
                "split_pose_act_ms", _capsule_split_pose_act, xi, self.k_list[i], self.d_list[i], cls_name, i
            )
            cls_prior = self._timed("cls_prior_ms", self._symbolic_prior, i, act)
            box_feats.append(pose)
            cls_feats.append(pose)
            cls_priors.append(cls_prior)
        return box_feats, cls_feats, cls_priors

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Return ``boxes`` ``(B, 4*reg_max, N)``, prior-shifted ``scores`` ``(B, nc, N)`` and ``feats`` (= ``x``).

        Returns ``{}`` when either per-level head is missing. With ``profile_head`` set, each box
        head, cls head and the final concatenation are timed, and the call counter is incremented.
        """
        self._ensure_profile_attrs()
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]

        def box_logits(head: torch.nn.Module, feat: torch.Tensor) -> torch.Tensor:
            return head(feat).view(bs, 4 * self.reg_max, -1)

        def cls_logits(head: torch.nn.Module, feat: torch.Tensor, prior: torch.Tensor) -> torch.Tensor:
            return (head(feat) + prior).view(bs, self.nc, -1)

        if not self.profile_head:
            box_list = [box_logits(box_head[i], box_feats[i]) for i in range(self.nl)]
            cls_list = [cls_logits(cls_head[i], cls_feats[i], cls_priors[i]) for i in range(self.nl)]
            boxes = torch.cat(box_list, dim=-1)
            scores = torch.cat(cls_list, dim=-1)
            return dict(boxes=boxes, scores=scores, feats=x)
        if not self._head_profile:
            self.reset_head_profile()
        self._head_profile_calls += 1
        box_list, cls_list = [], []
        for i in range(self.nl):
            box_list.append(self._timed("box_head_ms", box_logits, box_head[i], box_feats[i]))
            cls_list.append(self._timed("cls_head_ms", cls_logits, cls_head[i], cls_feats[i], cls_priors[i]))
        boxes, scores = self._timed(
            "cat_ms", lambda: (torch.cat(box_list, dim=-1), torch.cat(cls_list, dim=-1))
        )
        return dict(boxes=boxes, scores=scores, feats=x)
class CapsuleDetectv8(CapsuleDetectv5):
    """Capsule Detect v8: raw pose cls path with simplified cls_prior (no norm/centering).

    Box and cls convs both consume the raw pose. The prior is
    ``tanh(sym_beta_i) * sym_prior_i(sym_dropout_i(act))``. The ``sym_norm`` and ``cls_bypass``
    modules created by the v5 constructor are not used on this path.
    """

    def _build_feats(
        self, x: list[torch.Tensor]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Return ``(box_feats, cls_feats, cls_priors)`` with the un-normalized symbolic prior."""
        box_feats, cls_feats, cls_priors = _capsule_build_feats_boxcls_simpleprior(self, x)
        return box_feats, cls_feats, cls_priors
class CapsuleOpenVocabDetect(Detect):
    """Capsule detection head with open-vocabulary classification via text embedding matching.

    Box regression uses the raw capsule pose. For classification, each level's (optionally
    activation-gated) pose is projected to ``embed`` dimensions. Each anchor's L2-normalized
    embedding is then scored against L2-normalized text embeddings, multiplied by
    ``exp(logit_scale)`` clamped at 100. The fixed-class cls convs inherited from Detect are not
    used.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        embed: int = 256,
        with_act_gate: bool = False,
        with_objectness_prior: bool = True,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the open-vocabulary capsule head.

        Args:
            nc: Number of classes. Used to build the inherited Detect convs; at inference the
                class count comes from the text embeddings.
            *args: Positional layout ``[k_list, d_list]``, optionally followed by
                ``[reg_max, end2end]`` or by five options in either supported 7-arg order,
                plus a trailing ``ch``. ``ch`` may instead be passed as a keyword. With no
                positional args, the ``k``/``d`` keywords are used.
            reg_max: Box regression bins per side. Forwarded to Detect.
            end2end: Forwarded to Detect.
            embed: Visual/text embedding dimension.
            with_act_gate: Gate the cls pose multiplicatively by the activation maps.
            with_objectness_prior: Scale embeddings by ``1 + sigmoid(obj_prior_i(act))``.
            k: Capsule types per level (used only when ``args`` is empty).
            d: Pose dimension per level (used only when ``args`` is empty).
            ch: Packed input channels per level; each must equal ``k_i * (d_i + 1)``.

        Raises:
            ValueError: On an unsupported positional layout or a level/channel mismatch.
            TypeError: If ``k`` or ``d`` is not a list/tuple.
        """
        parsed = list(args)
        # ch is the trailing entry of the positional layouts (3, 5 or 8 items); restricting the
        # pop to those lengths stops a bare [k_list, d_list] from mistaking d_list for ch.
        if len(parsed) in (3, 5, 8) and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        if parsed:
            if len(parsed) not in (2, 4, 7):
                raise ValueError(
                    "CapsuleOpenVocabDetect expects [k_list, d_list, (reg_max, end2end, embed, with_act_gate, "
                    "with_objectness_prior), ch]."
                )
            k, d = parsed[0], parsed[1]
            if len(parsed) == 4:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
            elif len(parsed) == 7:
                # Support both direct args order:
                #   [k_list, d_list, reg_max, end2end, embed, with_act_gate, with_objectness_prior]
                # and parser-appended order:
                #   [k_list, d_list, embed, with_act_gate, with_objectness_prior, reg_max, end2end]
                # Index 4 tells them apart: bool with_objectness_prior vs int embed.
                if type(parsed[3]) is bool and type(parsed[4]) is bool and type(parsed[6]) is bool:
                    embed = int(parsed[2])
                    with_act_gate = bool(parsed[3])
                    with_objectness_prior = bool(parsed[4])
                    reg_max = int(parsed[5])
                    end2end = bool(parsed[6])
                else:
                    reg_max = int(parsed[2])
                    end2end = bool(parsed[3])
                    embed = int(parsed[4])
                    with_act_gate = bool(parsed[5])
                    with_objectness_prior = bool(parsed[6])
        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleOpenVocabDetect")
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        self.embed = int(embed)
        self.with_act_gate = bool(with_act_gate)
        self.with_objectness_prior = bool(with_objectness_prior)
        # Creation order below is kept stable: it fixes parameter registration and RNG init order.
        self.emb_head = nn.ModuleList()
        self.ov_gate = nn.ModuleList()
        self.ov_beta = nn.ParameterList()
        self.obj_prior = nn.ModuleList()
        for k_i, d_i in zip(self.k_list, self.d_list):
            pose_ch = k_i * d_i
            self.emb_head.append(
                nn.Sequential(
                    Conv(pose_ch, pose_ch, 3),
                    DWConv(pose_ch, pose_ch, 3),
                    nn.Conv2d(pose_ch, self.embed, 1, bias=True),
                )
            )
            if self.with_act_gate:
                hidden = max(16, k_i * 2)
                self.ov_gate.append(
                    nn.Sequential(
                        nn.Conv2d(k_i, hidden, 1, bias=True),
                        nn.SiLU(inplace=True),
                        nn.Conv2d(hidden, pose_ch, 1, bias=True),
                    )
                )
                self.ov_beta.append(nn.Parameter(torch.tensor(0.1)))
            else:
                # Placeholders keep per-level indexing uniform; they are never used when ungated.
                self.ov_gate.append(nn.Identity())
                self.ov_beta.append(nn.Parameter(torch.tensor(0.0), requires_grad=False))
            if self.with_objectness_prior:
                self.obj_prior.append(nn.Conv2d(k_i, 1, 1, bias=True))
            else:
                self.obj_prior.append(nn.Identity())
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07), dtype=torch.float32))
        self.register_buffer("cached_text_embeddings", torch.empty(0), persistent=False)

    def set_text_embeddings(self, text_embs: torch.Tensor | None) -> None:
        """Cache normalized text embeddings for inference.

        Args:
            text_embs: ``(num_classes, embed)`` tensor, or None to clear the cache.

        Raises:
            ValueError: If ``text_embs`` is not 2-D or its last dim differs from ``self.embed``.
                This is checked here rather than at the first inference call.
        """
        if text_embs is None:
            self.cached_text_embeddings = torch.empty(0, device=self.logit_scale.device)
            return
        if text_embs.ndim != 2:
            raise ValueError(f"text_embs must be 2D [num_classes, embed_dim], got shape {tuple(text_embs.shape)}.")
        if text_embs.shape[-1] != self.embed:
            raise ValueError(f"text_embs last dim must equal embed={self.embed}, got {text_embs.shape[-1]}.")
        self.cached_text_embeddings = F.normalize(text_embs.detach().to(self.logit_scale.device), dim=-1)

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split level ``i`` into pose ``(B, k_i*d_i, H, W)`` and act ``(B, k_i, H, W)``."""
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], type(self).__name__, i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Return per-level ``(box_feats, cls_feats, acts)``."""
        return _capsule_build_feats_open_vocab(self, x)

    def _prepare_text_embeddings(self, text_embs: torch.Tensor | None, bs: int, device: torch.device) -> torch.Tensor | None:
        """Return L2-normalized text embeddings ``(bs, C, embed)``, or None if none are available.

        Explicit ``text_embs`` take precedence over the cache. A 2-D ``(C, embed)`` input, or a
        3-D input with batch 1, is broadcast over the image batch.

        Raises:
            ValueError: On a wrong rank, a batch size other than 1 or ``bs``, or a wrong last dim.
        """
        if text_embs is None:
            if self.cached_text_embeddings.numel() == 0:
                return None
            text = self.cached_text_embeddings
        else:
            text = text_embs
        if text.ndim == 2:
            text = text.unsqueeze(0).expand(bs, -1, -1)
        elif text.ndim != 3:
            raise ValueError(f"text_embs must be 2D or 3D, got shape {tuple(text.shape)}.")
        elif text.shape[0] == 1 and bs != 1:
            # One shared prompt set for the whole image batch.
            text = text.expand(bs, -1, -1)
        elif text.shape[0] != bs:
            raise ValueError(f"text_embs batch dim must be 1 or {bs}, got {text.shape[0]}.")
        if text.shape[-1] != self.embed:
            raise ValueError(f"text_embs last dim must equal embed={self.embed}, got {text.shape[-1]}.")
        return F.normalize(text.to(device=device, dtype=self.logit_scale.dtype), dim=-1)

    def _compute_ov_scores(
        self, cls_feats: list[torch.Tensor], acts: list[torch.Tensor], text_embs: torch.Tensor | None
    ) -> tuple[torch.Tensor | None, list[torch.Tensor], torch.Tensor | None]:
        """Embed each level and score anchors against the text embeddings.

        Returns:
            A tuple of three items. First, ``scores`` ``(B, C, N)`` (or None when no text is
            available). Second, the per-level embeddings ``(B, embed, H, W)``. Third, the
            normalized text ``(B, C, embed)``, or None.
        """
        bs = cls_feats[0].shape[0]
        level_embeddings = []
        for i in range(self.nl):
            emb = self.emb_head[i](cls_feats[i])
            if self.with_objectness_prior:
                # NOTE(review): obj_prior yields one channel, so this scales each location's whole
                # embedding vector by one factor in (1, 2). The visual tokens are L2-normalized
                # below, so the factor cancels in `scores`; it only affects the returned
                # `embeddings`. Confirm this is the intended role of obj_prior.
                emb = emb * (1.0 + torch.sigmoid(self.obj_prior[i](acts[i])))
            level_embeddings.append(emb)
        text = self._prepare_text_embeddings(text_embs, bs, cls_feats[0].device)
        if text is None:
            return None, level_embeddings, None
        visual_tokens = torch.cat(
            [F.normalize(emb.flatten(2).transpose(1, 2), dim=-1) for emb in level_embeddings],
            dim=1,
        )
        scale = self.logit_scale.exp().clamp(max=100.0)
        scores = torch.einsum("bnd,bcd->bcn", visual_tokens, text) * scale
        return scores, level_embeddings, text

    def forward_head(
        self,
        x: list[torch.Tensor],
        text_embs: torch.Tensor | None = None,
        box_head: torch.nn.Module = None,
        cls_head: torch.nn.Module = None,
    ) -> dict[str, torch.Tensor]:
        """Return the box logits, per-level embeddings and intermediates.

        When text is available, ``scores`` and ``text_embeddings`` are included. Returns ``{}``
        when ``box_head`` is missing.
        """
        del cls_head  # fixed-class cls head is unused in open-vocabulary mode
        if box_head is None:
            return dict()
        box_feats, cls_feats, acts = self._build_feats(x)
        bs = x[0].shape[0]
        boxes = torch.cat([box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
        scores, level_embeddings, text = self._compute_ov_scores(cls_feats, acts, text_embs)
        preds = {
            "boxes": boxes,
            "embeddings": level_embeddings,
            "cls_feats": cls_feats,
            "acts": acts,
            "feats": x,
        }
        if scores is not None:
            preds["scores"] = scores
            preds["text_embeddings"] = text
        return preds

    def forward(
        self, x: list[torch.Tensor], text_embs: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor] | torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Run the head. Returns raw preds in training, decoded detections in inference.

        During inference, ``self.nc`` is temporarily set to the number of text classes so the
        inherited decoding matches the score width. It is restored afterwards, even on error.
        """
        preds = self.forward_head(x, text_embs=text_embs, **self.one2many)
        if self.end2end:
            x_detach = [xi.detach() for xi in x]
            one2one = self.forward_head(x_detach, text_embs=text_embs, **self.one2one)
            preds = {"one2many": preds, "one2one": one2one}
        if self.training:
            return preds
        infer_preds = preds["one2one"] if self.end2end else preds
        if "scores" not in infer_preds:
            raise ValueError("CapsuleOpenVocabDetect inference requires text_embs or cached text embeddings.")
        original_nc = self.nc
        self.nc = int(infer_preds["scores"].shape[1])
        try:
            y = self._inference(infer_preds)
            if self.end2end:
                y = self.postprocess(y.permute(0, 2, 1))
        finally:
            self.nc = original_nc
        return y if self.export else (y, preds)
class CapsuleSegmentv1(Segment):
    """Capsule-style Segment head aligned with CapsuleDetectv6 semantics.

    The box branch runs on the box features and the class / mask-coefficient
    branches run on the merged class features produced by ``_build_feats``. Class
    logits are additionally offset by the per-level priors that ``_build_feats``
    returns. Mask prototypes come from a multi-level ``Proto26`` over the same
    merged class features.

    Positional ``args`` (the order used by model YAML files) take one of::

        [k_list, d_list, ch]
        [k_list, d_list, reg_max, end2end, ch]     # 4-form when the 4th value is a bool
        [k_list, d_list, nm, npr, ch]              # 4-form otherwise
        [k_list, d_list, nm, npr, reg_max, end2end, ch]

    The trailing ``ch`` may instead be passed by keyword. When no positional
    ``args`` are given, the ``k`` / ``d`` keyword arguments are used.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        nm: int = 32,
        npr: int = 256,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Parse YAML-style positional args, build the Segment head, then install capsule heads and Proto26.

        Args:
            nc: Number of classes.
            *args: Positional layout/config values; see the class docstring.
            nm: Number of mask prototypes (overridden by a positional ``nm``).
            npr: Prototype hidden channels (overridden by a positional ``npr``).
            reg_max: Box regression bins (overridden by a positional ``reg_max``).
            end2end: Whether to build the one-to-one branch (overridden positionally).
            k: Per-level capsule layout values, used when no positional ``args`` are given.
            d: Per-level capsule layout values, used when no positional ``args`` are given.
            ch: Per-level input channels, used when not supplied as the trailing positional list.

        Raises:
            ValueError: If the positional ``args`` do not match a supported form.
        """
        parsed = list(args)
        # A trailing list/tuple is the per-level channel spec appended by the model parser.
        # Only take it when ``ch`` was not given by keyword; otherwise ``d_list`` would be
        # popped and misread as the channel spec.
        if parsed and not ch and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        # An empty ``parsed`` means keyword construction: keep the ``k`` / ``d`` arguments.
        if len(parsed) not in (0, 2, 4, 6):
            raise ValueError("CapsuleSegmentv1 expects [k_list, d_list, (nm, npr), reg_max, end2end, ch].")
        if parsed:
            k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            # A bool in the 4th slot distinguishes [reg_max, end2end] from [nm, npr].
            if isinstance(parsed[3], bool):
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
            else:
                nm = int(parsed[2])
                npr = int(parsed[3])
        elif len(parsed) == 6:
            nm = int(parsed[2])
            npr = int(parsed[3])
            reg_max = int(parsed[4])
            end2end = bool(parsed[5])
        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleSegmentv1")
        super().__init__(nc=nc, nm=nm, npr=npr, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        _init_capsule_semantic_heads(self)
        # Multi-level prototype head over the merged capsule channels.
        self.proto = Proto26(merged_ch, self.npr, self.nm, nc)

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to ``_capsule_split_pose_act`` using level ``i``'s ``k`` / ``d`` layout values."""
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleSegmentv1", i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Return ``(box_feats, cls_feats, cls_priors)`` built by ``_capsule_build_feats_gated``."""
        return _capsule_build_feats_gated(self, x)

    def _proto_input(self, x: list[torch.Tensor], preds) -> list[torch.Tensor]:
        """Return the merged capsule class features that feed ``self.proto``.

        Reuses the ``"feats"`` list already computed by the one-to-many head pass,
        which avoids a second full ``_build_feats`` evaluation. Falls back to a fresh
        ``_build_feats`` call when ``preds`` is not a dict, or when the one-to-many
        heads were absent and ``forward_head`` returned an empty dict.
        """
        if isinstance(preds, dict):
            branch = preds.get("one2many", {}) if self.end2end else preds
            feats = branch.get("feats")
            if feats is not None:
                return feats
        return self._build_feats(x)[1]

    def forward_head(
        self,
        x: list[torch.Tensor],
        box_head: torch.nn.Module = None,
        cls_head: torch.nn.Module = None,
        mask_head: torch.nn.Module = None,
    ) -> dict[str, torch.Tensor]:
        """Apply one set of per-level heads to the capsule features.

        Args:
            x: Per-level input feature maps.
            box_head: Per-level box regression heads (indexed by level).
            cls_head: Per-level classification heads (indexed by level).
            mask_head: Optional per-level mask-coefficient heads.

        Returns:
            An empty dict if ``box_head`` or ``cls_head`` is ``None``. Otherwise a dict
            with ``boxes`` (``(bs, 4*reg_max, A)``), ``scores`` (``(bs, nc, A)``) and
            ``feats`` (the merged class features), plus ``mask_coefficient``
            (``(bs, nm, A)``) when ``mask_head`` is given. ``A`` is the total number
            of spatial positions across levels.
        """
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]
        boxes = torch.cat([box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
        # Class logits are the head output offset by the per-level prior.
        scores = torch.cat(
            [(cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1) for i in range(self.nl)],
            dim=-1,
        )
        preds = dict(boxes=boxes, scores=scores, feats=cls_feats)
        if mask_head is not None:
            preds["mask_coefficient"] = torch.cat(
                [mask_head[i](cls_feats[i]).view(bs, self.nm, -1) for i in range(self.nl)],
                dim=-1,
            )
        return preds

    def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
        """Run detection via ``Detect.forward``, then attach ``Proto26`` mask prototypes.

        Returns:
            In training mode, the prediction dict with ``"proto"`` attached (per
            branch when ``end2end``). Otherwise ``(outputs, proto)`` when
            ``self.export`` is set, else ``((outputs[0], proto), preds)``.
        """
        outputs = Detect.forward(self, x)
        preds = outputs[1] if isinstance(outputs, tuple) else outputs
        proto = self.proto(self._proto_input(x, preds))  # multi-level Proto26 over merged capsule features
        if isinstance(preds, dict):
            if self.end2end:
                preds["one2many"]["proto"] = proto
                # The one-to-one branch gets detached prototypes. Proto26 may return a tuple,
                # so each element is detached.
                preds["one2one"]["proto"] = tuple(p.detach() for p in proto) if isinstance(proto, tuple) else proto.detach()
            else:
                preds["proto"] = proto
        if self.training:
            return preds
        return (outputs, proto) if self.export else ((outputs[0], proto), preds)
class CapsuleSegmentv2(CapsuleSegmentv1):
    """Capsule Segment v2: cls head consumes raw pose features and symbolic priors only.

    Differs from ``CapsuleSegmentv1`` only in how features are built: construction
    is delegated to ``_capsule_build_feats_boxcls``.
    """

    def _build_feats(
        self, x: list[torch.Tensor]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Produce ``(box_feats, cls_feats, cls_priors)`` via ``_capsule_build_feats_boxcls``."""
        feats = _capsule_build_feats_boxcls(self, x)
        return feats
class CapsuleSegmentv3(CapsuleSegmentv1):
    """Capsule Segment v3: raw pose cls path with simplified cls_prior (no norm/centering).

    Differs from ``CapsuleSegmentv1`` only in how features are built: construction
    is delegated to ``_capsule_build_feats_boxcls_simpleprior``.
    """

    def _build_feats(
        self, x: list[torch.Tensor]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Produce ``(box_feats, cls_feats, cls_priors)`` via ``_capsule_build_feats_boxcls_simpleprior``."""
        feats = _capsule_build_feats_boxcls_simpleprior(self, x)
        return feats