""" CapsNeck: efficient capsule-style neck blocks for Ultralytics YAML models. Design intent: - Keep capsule semantics (type/channel grouping + routing-style fusion). - Stay lightweight and export-friendly for detection training/inference. - Avoid expensive iterative EM/dynamic routing inside the neck path. This neck is "capsule-style" rather than a full matrix-capsule network: 1) CapsProj : CNN feature -> packed capsules (K types * D dims) 2) CapsAlign : scale alignment between pyramid levels (no global context) 3) CapsRoute : efficient self-routing proxy across sources (softmax source gating) 4) CapsDecode: packed capsules -> standard feature map for Detect 5) CapsuleTap: optional pass-through cache hook for analysis/aux losses Note: - Routing here is source-level and single-step by default (iters=1), chosen for speed. - If stronger capsule routing is needed, it should be added in the head where cost is lower. """ from __future__ import annotations from typing import List, Optional, Tuple, Union import math import time import torch import torch.nn as nn import torch.nn.functional as F from ultralytics.nn.modules import C3k2, Conv, DWConv # ------------------------- # 1) CapsProj # ------------------------- class CapsProj(nn.Module): """ Project a standard feature map into packed capsule channels using one C3k2 block. Input: x [B, C, H, W] Output: u [B, K*(D+1), H, W] Args: K: number of capsule types D: capsule pose dimension per type mix/mix_kernel: kept for backward YAML compatibility (unused) """ def __init__(self, c1: int, K: int = 4, D: int = 16): super().__init__() self.K = int(K) self.D = int(D) self.c_out = self.K * (self.D + 1) # Use a single C3k2 block as the capsule projection operator. self.map = C3k2(c1, self.c_out, n=1, c3k=False, e=0.5, g=1, shortcut=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.map(x) # ------------------------- # 2) CapsAlign (no context) # ------------------------- class CapsAlign(nn.Module): """ Align packed capsules across pyramid levels with YOLO-style ops. - Upsampling uses ``nn.Upsample(scale_factor=2, mode='nearest')``. - Downsampling uses stride-2 ``Conv`` blocks. Args: c1: input/output channel count. src_level: source pyramid level in {3,4,5}. tgt_level: target pyramid level in {3,4,5}. down_groups: groups for downsample Conv. Use capsule-type count K to keep each capsule block isolated. """ def __init__(self, c1: int, src_level: int, tgt_level: int, down_groups: int = 1): super().__init__() self.c1 = int(c1) self.src_level = int(src_level) self.tgt_level = int(tgt_level) self.down_groups = int(down_groups) if self.src_level not in (3, 4, 5) or self.tgt_level not in (3, 4, 5): raise ValueError("CapsAlign levels must be in {3,4,5}.") if self.down_groups < 1 or self.c1 % self.down_groups != 0: raise ValueError(f"CapsAlign down_groups={self.down_groups} must divide c1={self.c1}.") steps = abs(self.src_level - self.tgt_level) if self.src_level == self.tgt_level: self.mode = 'identity' self.ops = nn.ModuleList() elif self.src_level > self.tgt_level: self.mode = 'up' # YOLO-style top-down path: nearest-neighbor upsample x2 per level. self.ops = nn.ModuleList(nn.Upsample(scale_factor=2, mode='nearest') for _ in range(steps)) else: self.mode = 'down' # YOLO-style bottom-up path: stride-2 grouped Conv per level. self.ops = nn.ModuleList(Conv(self.c1, self.c1, 3, 2, g=self.down_groups) for _ in range(steps)) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.mode == 'identity': return x for op in self.ops: x = op(x) return x # ------------------------- # 3) CapsRoute (light, parser-friendly) # ------------------------- class ConvSelfRouting(nn.Module): """Grouped-conv self-routing over stacked capsule sources. Args: K_in: input capsule type count. P_in: input pose dimension. K_out: output capsule type count. P_out: output pose dimension. kernel_size: grouped conv kernel for local capsule mixing. """ def __init__(self, K_in: int, P_in: int, K_out: int, P_out: int, kernel_size: int = 3): super().__init__() self.K_in = int(K_in) self.P_in = int(P_in) self.K_out = int(K_out) self.P_out = int(P_out) if min(self.K_in, self.P_in, self.K_out, self.P_out) <= 0: raise ValueError('ConvSelfRouting expects positive K/P values.') self.c_in = self.K_in * (self.P_in + 1) self.c_out = self.K_out * (self.P_out + 1) k = int(kernel_size) padding = k//2 self.mix = nn.Conv2d(self.c_in, self.c_in, kernel_size=k, stride=1, padding=padding, groups=self.K_in, bias=False) self.gate = nn.Conv2d(self.c_in, self.K_in, kernel_size=1, stride=1, padding=0, groups=self.K_in, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B,C,H,W], C = K_in*(P_in+1) b, c, h, w = x.shape if c != self.c_in: raise ValueError(f'ConvSelfRouting expected C={self.c_in}, got C={c}') mixed = self.mix(x) logits = self.gate(mixed).reshape(b, self.K_in, h, w) weights = logits.softmax(dim=1) caps = mixed.reshape(b, self.K_in, self.P_in + 1, h, w) routed = weights.unsqueeze(2) * caps routed = routed.reshape(b, self.c_in, h, w) return routed class SelfRouting(nn.Module): """Pose-transform self-routing on packed capsule tensor. Args: K_in: input capsule type count. P_in: input pose dimension. K_out: output capsule type count. P_out: output pose dimension. Input: x: [B, K_in*(P_in+1), H, W] Output: y: [B, K_out*(P_out+1), H, W] """ def __init__(self, K_in: int, P_in: int, K_out: int, P_out: int): super().__init__() self.K_in = int(K_in) self.P_in = int(P_in) self.K_out = int(K_out) self.P_out = int(P_out) if min(self.K_in, self.P_in, self.K_out, self.P_out) <= 0: raise ValueError('SelfRouting expects positive K/P values.') self.c_in = self.K_in * (self.P_in + 1) self.c_out = self.K_out * (self.P_out + 1) self.eps = 1e-6 self.W_pose = nn.Parameter(torch.empty(self.K_in, self.K_out, self.P_in, self.P_out)) nn.init.kaiming_uniform_(self.W_pose, a=math.sqrt(5)) self.W_gate = nn.Parameter(torch.zeros(self.K_in, self.K_out, self.P_in)) self.b_gate = nn.Parameter(torch.zeros(1, self.K_in, self.K_out, 1, 1)) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B, C, H, W], C = K_in*(P_in+1) if x.ndim != 4: raise TypeError(f'SelfRouting expects [B,C,H,W], got {tuple(x.shape)}') b, c, h, w = x.shape if c != self.c_in: raise ValueError(f'SelfRouting expected C={self.c_in}, got C={c}') # Packed capsule layout is interleaved per type: [pose(P), act(1)]. # x_caps: [B, K_in, P_in+1, H, W] x_caps = x.reshape(b, self.K_in, self.P_in + 1, h, w) pose = x_caps[:, :, :self.P_in] # [B, K_in, P_in, H, W] act = x_caps[:, :, self.P_in : self.P_in + 1].sigmoid() # [B, K_in, 1, H, W] # votes: [B, K_in, K_out, H, W, P_out] votes = torch.einsum('bkphw,kopq->bkohwq', pose, self.W_pose) # logits/weights: [B, K_in, K_out, H, W] logits = torch.einsum('bkphw,kop->bkohw', pose, self.W_gate) + self.b_gate weights = logits.softmax(dim=2) ar = weights * act # [B, K_in, K_out, H, W] ar_sum = ar.sum(dim=1, keepdim=True) + self.eps coeff = ar / ar_sum pose_out = (coeff.unsqueeze(-1) * votes).sum(dim=1) # [B, K_out, H, W, P_out] pose_out = pose_out.permute(0, 1, 4, 2, 3) # [B, K_out, P_out, H, W] act_out = ar_sum.squeeze(1).unsqueeze(2) # [B, K_out, 1, H, W] # Keep interleaved packed output: [pose(P_out), act(1)] per capsule type. out = torch.cat([pose_out, act_out], dim=2).reshape(b, self.c_out, h, w) return out class HybridRoute1(nn.Module): """Conv-heavy replacement for SelfRouting with lightweight capsule-aware gating.""" def __init__(self, K_in: int, P_in: int, K_out: int, P_out: int): super().__init__() self.K_in = int(K_in) self.P_in = int(P_in) self.K_out = int(K_out) self.P_out = int(P_out) self.c_in = self.K_in * (self.P_in + 1) self.c_out = self.K_out * (self.P_out + 1) pose_in = self.K_in * self.P_in pose_out = self.K_out * self.P_out vote_groups = math.gcd(self.K_in, self.K_out) vote_groups = max(int(vote_groups), 1) self.vote_proj = Conv(pose_in, pose_out, 1, 1, g=vote_groups) self.gate_proj = nn.Conv2d(self.c_in, self.K_out, kernel_size=1, stride=1, padding=0, bias=True) self.act_proj = nn.Conv2d(self.K_in, self.K_out, kernel_size=1, stride=1, padding=0, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: if x.ndim != 4: raise TypeError(f'HybridRoute1 expects [B,C,H,W], got {tuple(x.shape)}') b, c, h, w = x.shape if c != self.c_in: raise ValueError(f'HybridRoute1 expected C={self.c_in}, got C={c}') x_caps = x.reshape(b, self.K_in, self.P_in + 1, h, w) pose = x_caps[:, :, :self.P_in].reshape(b, self.K_in * self.P_in, h, w) act = x_caps[:, :, self.P_in].contiguous() pose_votes = self.vote_proj(pose).reshape(b, self.K_out, self.P_out, h, w) gate = self.gate_proj(x).sigmoid().unsqueeze(2) pose_out = pose_votes * gate act_out = self.act_proj(act).sigmoid().unsqueeze(2) out = torch.cat([pose_out, act_out], dim=2).reshape(b, self.c_out, h, w) return out class CapsRoute(nn.Module): """Capsule routing fusion by direct capsule concatenation. Args: K_in: list of input capsule type counts per source. P_in: list of input pose dimensions per source. K_out: target output capsule type count. P_out: target output pose dimension. kernel_size: grouped-conv kernel for ``ConvSelfRouting``. Notes: Inputs are concatenated directly (no pre-projection). For direct packed concat, all ``P_in`` must be identical. """ def __init__( self, K_in: Union[List[int], Tuple[int, ...]], P_in: Union[List[int], Tuple[int, ...]], K_out: int, P_out: int, kernel_size: int = 3, pre_k: int = 3, post_k: int = 3, pre_groups: Optional[int] = None, post_groups: Optional[int] = None, ): super().__init__() self.K_in_list = [int(v) for v in K_in] self.P_in_list = [int(v) for v in P_in] if len(self.K_in_list) < 2 or len(self.K_in_list) != len(self.P_in_list): raise ValueError('CapsRoute expects K_in/P_in lists with same length >= 2.') if min(*self.K_in_list, *self.P_in_list) <= 0: raise ValueError('CapsRoute expects positive K_in/P_in values.') # Direct capsule concat requires a shared pose dimension. if len(set(self.P_in_list)) != 1: raise ValueError('CapsRoute direct concat requires all P_in to be identical.') self.num_sources = len(self.K_in_list) self.P_cat = int(self.P_in_list[0]) self.K_cat = int(sum(self.K_in_list)) self.c_cat = self.K_cat * (self.P_cat + 1) self.K_out = int(K_out) self.P_out = int(P_out) if min(self.K_out, self.P_out) <= 0: raise ValueError('CapsRoute expects positive K_out/P_out values.') self.c_out = self.K_out * (self.P_out + 1) # self.conv_route = ConvSelfRouting( # K_in=self.K_cat, # P_in=self.P_cat, # K_out=self.K_cat, # P_out=self.P_cat, # kernel_size=kernel_size, # ) # Grouped Conv before routing: C = K_cat * (P_cat + 1), groups = K_cat. self.conv_route = Conv(self.c_cat, self.c_cat, 3, 1, g=self.K_cat) self.route1 = SelfRouting(K_in=self.K_cat, P_in=self.P_cat, K_out=self.K_out, P_out=self.P_out) # Grouped Conv after routing: C = K_out * (P_out + 1), groups = K_out. self.spagg = Conv(self.c_out, self.c_out, 3, 1, g=self.K_out) # self.route2 = SelfRouting(K_in=self.K_out, P_in=self.P_out, K_out=self.K_out, P_out=self.P_out) def forward(self, xs: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]) -> torch.Tensor: if not isinstance(xs, (list, tuple)): raise TypeError(f'CapsRoute expects list/tuple inputs, got {type(xs)}') if len(xs) != self.num_sources: raise ValueError(f'CapsRoute expected {self.num_sources} sources, got {len(xs)}') h, w = int(xs[0].shape[-2]), int(xs[0].shape[-1]) cat_parts = [] for i, x in enumerate(xs): expected_c = self.K_in_list[i] * (self.P_in_list[i] + 1) if int(x.shape[1]) != expected_c: raise ValueError(f'CapsRoute source-{i} expected C={expected_c} from K_in/P_in, got C={int(x.shape[1])}') if int(x.shape[-2]) != h or int(x.shape[-1]) != w: raise ValueError('CapsRoute inputs must share H,W. Use CapsAlign before routing.') cat_parts.append(x) x_cat = torch.cat(cat_parts, dim=1) # [B, K_cat*(P+1), H, W] routed = self.route1(self.conv_route(x_cat)) routed = self.spagg(routed) return routed class CapsRoutev2(CapsRoute): """CapsRoute with per-capsule pose refinement and act residual update.""" def __init__( self, K_in: Union[List[int], Tuple[int, ...]], P_in: Union[List[int], Tuple[int, ...]], K_out: int, P_out: int, kernel_size: int = 3, pre_k: int = 3, post_k: int = 3, pre_groups: Optional[int] = None, post_groups: Optional[int] = None, ): super().__init__(K_in, P_in, K_out, P_out, kernel_size, pre_k, post_k, pre_groups, post_groups) _ = (post_k, post_groups, pre_k, pre_groups) # kept for YAML/API compatibility self.profile_route = False self._route_profile = { 'cat_ms': 0.0, 'conv_route_ms': 0.0, 'route1_ms': 0.0, 'pose_refine_ms': 0.0, 'act_from_pose_ms': 0.0, 'pack_ms': 0.0, 'calls': 0.0, } deep_stage = self.K_out >= 64 pose_ch = self.K_out * self.P_out # Match YOLO26 neck style: # - shallow/mid stages: C3k2(n=2, c3k=True, attn=False) # - deep stage: C3k2(n=1, c3k=True, attn=True) pose_e = 0.5 if (self.P_out % 2 == 0) else 1.0 self.pose_refine = C3k2( pose_ch, pose_ch, n=1 if deep_stage else 2, c3k=True, e=pose_e, attn=deep_stage, g=self.K_out, shortcut=True, ) self.act_from_pose = Conv(pose_ch, self.K_out, 1, 1, g=self.K_out) self.act_alpha = nn.Parameter(torch.tensor(0.1)) @staticmethod def _sync_profile() -> None: if torch.cuda.is_available(): torch.cuda.synchronize() def _ensure_route_profile_state(self) -> None: if not hasattr(self, "profile_route"): self.profile_route = False if not hasattr(self, "_route_profile"): self._route_profile = { 'cat_ms': 0.0, 'conv_route_ms': 0.0, 'route1_ms': 0.0, 'pose_refine_ms': 0.0, 'act_from_pose_ms': 0.0, 'pack_ms': 0.0, 'calls': 0.0, } def reset_route_profile(self) -> None: self._ensure_route_profile_state() for k in self._route_profile: self._route_profile[k] = 0.0 def get_route_profile(self) -> dict: self._ensure_route_profile_state() calls = max(float(self._route_profile.get('calls', 0.0)), 1.0) total = ( self._route_profile['cat_ms'] + self._route_profile['conv_route_ms'] + self._route_profile['route1_ms'] + self._route_profile['pose_refine_ms'] + self._route_profile['act_from_pose_ms'] + self._route_profile['pack_ms'] ) out = dict(self._route_profile) out['total_ms'] = total out['cat_avg_ms'] = self._route_profile['cat_ms'] / calls out['conv_route_avg_ms'] = self._route_profile['conv_route_ms'] / calls out['route1_avg_ms'] = self._route_profile['route1_ms'] / calls out['pose_refine_avg_ms'] = self._route_profile['pose_refine_ms'] / calls out['act_from_pose_avg_ms'] = self._route_profile['act_from_pose_ms'] / calls out['pack_avg_ms'] = self._route_profile['pack_ms'] / calls out['total_avg_ms'] = total / calls return out def forward(self, xs: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]) -> torch.Tensor: if not isinstance(xs, (list, tuple)): raise TypeError(f'CapsRoutev2 expects list/tuple inputs, got {type(xs)}') if len(xs) != self.num_sources: raise ValueError(f'CapsRoutev2 expected {self.num_sources} sources, got {len(xs)}') h, w = int(xs[0].shape[-2]), int(xs[0].shape[-1]) cat_parts = [] for i, x in enumerate(xs): expected_c = self.K_in_list[i] * (self.P_in_list[i] + 1) if int(x.shape[1]) != expected_c: raise ValueError(f'CapsRoutev2 source-{i} expected C={expected_c}, got C={int(x.shape[1])}') if int(x.shape[-2]) != h or int(x.shape[-1]) != w: raise ValueError('CapsRoutev2 inputs must share H,W. Use CapsAlign before routing.') cat_parts.append(x) self._ensure_route_profile_state() if getattr(self, "profile_route", False): self._route_profile['calls'] += 1.0 self._sync_profile() t0 = time.perf_counter() x_cat = torch.cat(cat_parts, dim=1) # [B, K_cat*(P+1), H, W] self._sync_profile() self._route_profile['cat_ms'] += (time.perf_counter() - t0) * 1000.0 t0 = time.perf_counter() conv_out = self.conv_route(x_cat) self._sync_profile() self._route_profile['conv_route_ms'] += (time.perf_counter() - t0) * 1000.0 t0 = time.perf_counter() routed = self.route1(conv_out) # [B, K_out*(P_out+1), H, W] self._sync_profile() self._route_profile['route1_ms'] += (time.perf_counter() - t0) * 1000.0 else: x_cat = torch.cat(cat_parts, dim=1) # [B, K_cat*(P+1), H, W] routed = self.route1(self.conv_route(x_cat)) # [B, K_out*(P_out+1), H, W] b, _, _, _ = routed.shape # Packed layout by type: [pose(P), act(1)] repeated K times. caps = routed.reshape(b, self.K_out, self.P_out + 1, h, w) pose = caps[:, :, :self.P_out].contiguous() # [B, K_out, P_out, H, W] act = caps[:, :, self.P_out].contiguous() # [B, K_out, H, W] # Grouped pose refinement across type blocks (equivalent to per-type grouped processing). pose_flat = pose.reshape(b, self.K_out * self.P_out, h, w) if getattr(self, "profile_route", False): t0 = time.perf_counter() pose_flat = self.pose_refine(pose_flat) self._sync_profile() self._route_profile['pose_refine_ms'] += (time.perf_counter() - t0) * 1000.0 t0 = time.perf_counter() act_delta = self.act_from_pose(pose_flat) act_final = act + act_delta self._sync_profile() self._route_profile['act_from_pose_ms'] += (time.perf_counter() - t0) * 1000.0 else: pose_flat = self.pose_refine(pose_flat) act_delta = self.act_from_pose(pose_flat) act_final = act + act_delta if getattr(self, "profile_route", False): t0 = time.perf_counter() pose_pack = pose_flat.reshape(b, self.K_out, self.P_out, h, w) out = torch.cat([pose_pack, act_final.unsqueeze(2)], dim=2).reshape(b, self.c_out, h, w) self._sync_profile() self._route_profile['pack_ms'] += (time.perf_counter() - t0) * 1000.0 else: pose_pack = pose_flat.reshape(b, self.K_out, self.P_out, h, w) out = torch.cat([pose_pack, act_final.unsqueeze(2)], dim=2).reshape(b, self.c_out, h, w) return out # ------------------------- # 4) CapsDecode # ------------------------- class CapsRoutev3(CapsRoute): """CapsRoute with DS-style lightweight pose refinement and act residual update.""" def __init__( self, K_in: Union[List[int], Tuple[int, ...]], P_in: Union[List[int], Tuple[int, ...]], K_out: int, P_out: int, kernel_size: int = 3, pre_k: int = 3, post_k: int = 3, pre_groups: Optional[int] = None, post_groups: Optional[int] = None, ): super().__init__(K_in, P_in, K_out, P_out, kernel_size, pre_k, post_k, pre_groups, post_groups) _ = (post_k, post_groups, pre_k, pre_groups) self.profile_route = False self._route_profile = { 'cat_ms': 0.0, 'conv_route_ms': 0.0, 'route1_ms': 0.0, 'pose_refine_ms': 0.0, 'act_from_pose_ms': 0.0, 'pack_ms': 0.0, 'calls': 0.0, } pose_ch = self.K_out * self.P_out # Keep refinement fully type-grouped to preserve capsule semantics: # each capsule type only mixes its own pose channels. self.pose_refine = nn.Sequential( Conv(pose_ch, pose_ch, 1, 1, g=self.K_out), Conv(pose_ch, pose_ch, 3, 1, g=self.K_out), Conv(pose_ch, pose_ch, 1, 1, g=self.K_out), ) self.act_from_pose = Conv(pose_ch, self.K_out, 1, 1, g=self.K_out) self.act_alpha = nn.Parameter(torch.tensor(0.1)) @staticmethod def _sync_profile() -> None: if torch.cuda.is_available(): torch.cuda.synchronize() def _ensure_route_profile_state(self) -> None: if not hasattr(self, "profile_route"): self.profile_route = False if not hasattr(self, "_route_profile"): self._route_profile = { 'cat_ms': 0.0, 'conv_route_ms': 0.0, 'route1_ms': 0.0, 'pose_refine_ms': 0.0, 'act_from_pose_ms': 0.0, 'pack_ms': 0.0, 'calls': 0.0, } def reset_route_profile(self) -> None: self._ensure_route_profile_state() for k in self._route_profile: self._route_profile[k] = 0.0 def get_route_profile(self) -> dict: self._ensure_route_profile_state() calls = max(float(self._route_profile.get('calls', 0.0)), 1.0) total = ( self._route_profile['cat_ms'] + self._route_profile['conv_route_ms'] + self._route_profile['route1_ms'] + self._route_profile['pose_refine_ms'] + self._route_profile['act_from_pose_ms'] + self._route_profile['pack_ms'] ) out = dict(self._route_profile) out['total_ms'] = total out['cat_avg_ms'] = self._route_profile['cat_ms'] / calls out['conv_route_avg_ms'] = self._route_profile['conv_route_ms'] / calls out['route1_avg_ms'] = self._route_profile['route1_ms'] / calls out['pose_refine_avg_ms'] = self._route_profile['pose_refine_ms'] / calls out['act_from_pose_avg_ms'] = self._route_profile['act_from_pose_ms'] / calls out['pack_avg_ms'] = self._route_profile['pack_ms'] / calls out['total_avg_ms'] = total / calls return out def forward(self, xs: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]) -> torch.Tensor: if not isinstance(xs, (list, tuple)): raise TypeError(f'CapsRoutev3 expects list/tuple inputs, got {type(xs)}') if len(xs) != self.num_sources: raise ValueError(f'CapsRoutev3 expected {self.num_sources} sources, got {len(xs)}') h, w = int(xs[0].shape[-2]), int(xs[0].shape[-1]) cat_parts = [] for i, x in enumerate(xs): expected_c = self.K_in_list[i] * (self.P_in_list[i] + 1) if int(x.shape[1]) != expected_c: raise ValueError(f'CapsRoutev3 source-{i} expected C={expected_c}, got C={int(x.shape[1])}') if int(x.shape[-2]) != h or int(x.shape[-1]) != w: raise ValueError('CapsRoutev3 inputs must share H,W. Use CapsAlign before routing.') cat_parts.append(x) self._ensure_route_profile_state() if getattr(self, "profile_route", False): self._route_profile['calls'] += 1.0 self._sync_profile() t0 = time.perf_counter() x_cat = torch.cat(cat_parts, dim=1) self._sync_profile() self._route_profile['cat_ms'] += (time.perf_counter() - t0) * 1000.0 t0 = time.perf_counter() conv_out = self.conv_route(x_cat) self._sync_profile() self._route_profile['conv_route_ms'] += (time.perf_counter() - t0) * 1000.0 t0 = time.perf_counter() routed = self.route1(conv_out) self._sync_profile() self._route_profile['route1_ms'] += (time.perf_counter() - t0) * 1000.0 else: x_cat = torch.cat(cat_parts, dim=1) routed = self.route1(self.conv_route(x_cat)) b, _, _, _ = routed.shape caps = routed.reshape(b, self.K_out, self.P_out + 1, h, w) pose = caps[:, :, :self.P_out].contiguous() act = caps[:, :, self.P_out].contiguous() pose_flat = pose.reshape(b, self.K_out * self.P_out, h, w) if getattr(self, "profile_route", False): t0 = time.perf_counter() pose_flat = pose_flat + self.pose_refine(pose_flat) self._sync_profile() self._route_profile['pose_refine_ms'] += (time.perf_counter() - t0) * 1000.0 t0 = time.perf_counter() act_delta = self.act_from_pose(pose_flat) act_final = act + act_delta self._sync_profile() self._route_profile['act_from_pose_ms'] += (time.perf_counter() - t0) * 1000.0 else: pose_flat = pose_flat + self.pose_refine(pose_flat) act_delta = self.act_from_pose(pose_flat) act_final = act + act_delta if getattr(self, "profile_route", False): t0 = time.perf_counter() pose_pack = pose_flat.reshape(b, self.K_out, self.P_out, h, w) out = torch.cat([pose_pack, act_final.unsqueeze(2)], dim=2).reshape(b, self.c_out, h, w) self._sync_profile() self._route_profile['pack_ms'] += (time.perf_counter() - t0) * 1000.0 else: pose_pack = pose_flat.reshape(b, self.K_out, self.P_out, h, w) out = torch.cat([pose_pack, act_final.unsqueeze(2)], dim=2).reshape(b, self.c_out, h, w) return out class CapsRoutev4(CapsRoutev2): """CapsRoutev2 with conv-heavy HybridRoute1 to reduce routing overhead.""" def __init__( self, K_in: Union[List[int], Tuple[int, ...]], P_in: Union[List[int], Tuple[int, ...]], K_out: int, P_out: int, kernel_size: int = 3, pre_k: int = 3, post_k: int = 3, pre_groups: Optional[int] = None, post_groups: Optional[int] = None, ): super().__init__(K_in, P_in, K_out, P_out, kernel_size, pre_k, post_k, pre_groups, post_groups) self.route1 = HybridRoute1(K_in=self.K_cat, P_in=self.P_cat, K_out=self.K_out, P_out=self.P_out) class CapsDecode(nn.Module): """ Decode routed capsule features to standard feature map for Detect. Input: y [B, C_in, H, W] (often concat of weighted sources, so C_in = S*(K*D)) Output: f [B, C_out, H, W] Args: c2: output channels (e.g., 256/512/1024) """ def __init__(self, c1: int, c2: int): super().__init__() self.conv = nn.Conv2d(c1, c2, kernel_size=1, stride=1, padding=0, bias=False) self.bn = nn.BatchNorm2d(c2) self.act = nn.SiLU(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.act(self.bn(self.conv(x))) # ------------------------- # 5) CapsuleTap # ------------------------- class CapsuleTap(nn.Module): """ Pass-through hook to cache feature maps for explainability/aux loss. MUST NOT change tensor shape. Returns x unchanged. Args: tag: string identifier ("F3"/"F4"/"F5") K,D: capsule hyperparams (metadata only) cache_enabled: if True, cache during training (disabled in tracing/scripting) """ def __init__(self, tag: str = "F", K: int = 4, D: int = 16, cache_enabled: bool = True): super().__init__() self.tag = str(tag) self.K = int(K) self.D = int(D) self.cache_enabled = bool(cache_enabled) self.last_x: Optional[torch.Tensor] = None def clear_cache(self) -> None: self.last_x = None def forward(self, x: torch.Tensor) -> torch.Tensor: if ( self.cache_enabled and self.training and (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()) ): self.last_x = x.detach() return x