Image Segmentation
ultralytics
PyTorch
English
object-detection
instance-segmentation
yolov8
coco
real-time
capsule-network
interpretable-ai
symbolic-ai
Eval Results (legacy)
Instructions to use zpyuan/SymbolicCapsuleNetwork with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- ultralytics
How to use zpyuan/SymbolicCapsuleNetwork with ultralytics:
from ultralytics import YOLOv8

model = YOLOv8.from_pretrained("zpyuan/SymbolicCapsuleNetwork")
source = 'http://images.cocodataset.org/val2017/000000039769.jpg'
model.predict(source=source, save=True)
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| from typing import Iterable, Sequence | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
def squash(x: torch.Tensor, dim: int = -1, eps: float = 1e-7) -> torch.Tensor:
    """Apply the capsule "squash" nonlinearity along ``dim``.

    Every vector along ``dim`` keeps its direction, while its length ``n`` is
    rescaled to roughly ``n**2 / (1 + n**2)``, which is always below 1.

    Args:
        x: Tensor holding capsule vectors along ``dim``.
        dim: Axis over which the squared norm is accumulated.
        eps: Constant added under the square root so an all-zero vector does
            not produce a division by zero.

    Returns:
        Tensor with the same shape as ``x``.
    """
    norm_sq = torch.sum(x * x, dim=dim, keepdim=True)
    shrink = norm_sq / (1.0 + norm_sq)
    safe_norm = torch.sqrt(norm_sq + eps)
    return shrink * x / safe_norm
class ConvBNAct(nn.Module):
    """Bias-free 2D convolution followed by BatchNorm2d and an in-place SiLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels (also the BatchNorm width).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Explicit padding; when ``None`` it falls back to
            ``kernel_size // 2``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int | None = None,
    ) -> None:
        super().__init__()
        pad = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run convolution, batch normalisation and activation in sequence."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.act(normalised)
class PrimaryCaps2d(nn.Module):
    """Turn a 2D feature map into squashed primary capsules.

    A bias-free convolution produces ``num_caps * dim_caps`` channels, which
    are normalised, activated, regrouped into capsules and squashed along the
    capsule dimension.

    Args:
        in_channels: Number of input channels.
        num_caps: Number of capsules produced per spatial location.
        dim_caps: Length of each capsule vector.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Explicit padding; when ``None`` it falls back to
            ``kernel_size // 2``.
    """

    def __init__(
        self,
        in_channels: int,
        num_caps: int,
        dim_caps: int,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int | None = None,
    ) -> None:
        super().__init__()
        pad = kernel_size // 2 if padding is None else padding
        self.num_caps = num_caps
        self.dim_caps = dim_caps
        total_channels = num_caps * dim_caps
        self.conv = nn.Conv2d(
            in_channels,
            total_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(total_channels)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return capsules shaped ``[B, num_caps, dim_caps, H, W]``."""
        feats = self.act(self.bn(self.conv(x)))
        batch, _, height, width = feats.shape
        # Split the channel axis into (capsule index, capsule component).
        capsules = feats.view(batch, self.num_caps, self.dim_caps, height, width)
        return squash(capsules, dim=2)
class RoutingCaps(nn.Module):
    """Dynamic routing between capsules.

    Each of the ``num_in_caps`` input capsules casts a vote for every one of
    the ``num_out_caps`` output capsules through its own learned
    ``dim_out x dim_in`` matrix. The votes are then combined by iterative
    routing-by-agreement.

    Args:
        num_in_caps: Number of input capsules ``N``.
        dim_in: Length ``D`` of each input capsule vector.
        num_out_caps: Number of output capsules ``M``.
        dim_out: Length of each output capsule vector.
        routing_iters: Number of routing iterations; must be at least 1.

    Raises:
        ValueError: If ``routing_iters`` is smaller than 1.
    """

    def __init__(
        self,
        num_in_caps: int,
        dim_in: int,
        num_out_caps: int,
        dim_out: int,
        routing_iters: int = 3,
    ) -> None:
        super().__init__()
        # With zero iterations the routing loop never runs and forward() would
        # have no output to return, so reject that configuration up front.
        if routing_iters < 1:
            raise ValueError(f"routing_iters must be >= 1, got {routing_iters}")
        self.num_in_caps = num_in_caps
        self.dim_in = dim_in
        self.num_out_caps = num_out_caps
        self.dim_out = dim_out
        self.routing_iters = routing_iters
        # One transformation matrix per (input capsule, output capsule) pair;
        # the leading 1 broadcasts over the batch.
        weight = torch.randn(1, num_in_caps, num_out_caps, dim_out, dim_in) * 0.01
        self.W = nn.Parameter(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route input capsules to output capsules.

        Args:
            x: Input capsules shaped ``[B, num_in_caps, dim_in]``.

        Returns:
            Output capsules shaped ``[B, num_out_caps, dim_out]``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional or its capsule count or
                capsule length does not match this layer.
        """
        if x.ndim != 3:
            raise ValueError(f"RoutingCaps expects [B, N, D], got {tuple(x.shape)}")
        # Without this check an input with N == 1 would silently broadcast
        # against W, and other mismatches would surface as opaque matmul errors.
        if x.shape[1] != self.num_in_caps or x.shape[2] != self.dim_in:
            raise ValueError(
                f"RoutingCaps expects [B, {self.num_in_caps}, {self.dim_in}], "
                f"got {tuple(x.shape)}"
            )
        bsz = x.shape[0]
        x = x.unsqueeze(2).unsqueeze(-1)  # [B, N, 1, D, 1]
        u_hat = torch.matmul(self.W, x).squeeze(-1)  # [B, N, M, Dout]
        # Routing logits start at zero, i.e. uniform coupling coefficients.
        b = x.new_zeros(bsz, self.num_in_caps, self.num_out_caps)
        for idx in range(self.routing_iters):
            c = F.softmax(b, dim=-1)  # couplings over output capsules
            s = (c.unsqueeze(-1) * u_hat).sum(dim=1)  # [B, M, Dout]
            v = squash(s, dim=-1)
            # The logits are not needed after the final iteration.
            if idx < self.routing_iters - 1:
                b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)
        return v
class DeformableCaps2d(nn.Module):
    """Deformable capsule layer with learned sampling offsets.

    For every spatial location the layer builds child capsules, samples them
    at ``num_samples`` learned offset positions, and routes the sampled
    capsules to parent capsules. The parent capsules are flattened back into
    a channel axis and optionally projected to ``out_channels``.

    Args:
        in_channels: Number of channels of the input feature map.
        num_child_caps: Child capsules per location.
        dim_child: Length of each child capsule vector.
        num_parent_caps: Parent capsules per location.
        dim_parent: Length of each parent capsule vector.
        num_samples: Number of offset sampling positions per location.
        routing_iters: Routing iterations passed to ``RoutingCaps``.
        offset_scale: Bound on the learned offsets; offsets are
            ``tanh(...) * offset_scale`` in pixel units.
        out_channels: Output channel count. A falsy value (``None`` or 0)
            keeps ``num_parent_caps * dim_parent``; any other value that
            differs from it adds a 1x1 ``ConvBNAct`` projection.
    """

    def __init__(
        self,
        in_channels: int,
        num_child_caps: int = 8,
        dim_child: int = 8,
        num_parent_caps: int = 8,
        dim_parent: int = 8,
        num_samples: int = 4,
        routing_iters: int = 3,
        offset_scale: float = 1.0,
        out_channels: int | None = None,
    ) -> None:
        super().__init__()
        self.num_child_caps = num_child_caps
        self.dim_child = dim_child
        self.num_parent_caps = num_parent_caps
        self.dim_parent = dim_parent
        self.num_samples = num_samples
        self.routing_iters = routing_iters
        self.offset_scale = offset_scale
        self.primary = PrimaryCaps2d(in_channels, num_child_caps, dim_child, kernel_size=1, stride=1, padding=0)
        # Predicts an (x, y) offset per sample and location. It is
        # zero-initialised, so every sample starts on the regular grid.
        self.offset = nn.Conv2d(in_channels, 2 * num_samples, kernel_size=3, stride=1, padding=1)
        nn.init.zeros_(self.offset.weight)
        nn.init.zeros_(self.offset.bias)
        self.routing = RoutingCaps(
            num_in_caps=num_samples * num_child_caps,
            dim_in=dim_child,
            num_out_caps=num_parent_caps,
            dim_out=dim_parent,
            routing_iters=routing_iters,
        )
        caps_channels = num_parent_caps * dim_parent
        self.out_channels = out_channels or caps_channels
        self.project = None
        if self.out_channels != caps_channels:
            self.project = ConvBNAct(caps_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    # This helper takes no ``self`` but is called as ``self._base_grid(...)``.
    # Without ``@staticmethod`` the instance would be passed as an extra
    # positional argument and every forward pass would raise ``TypeError``.
    @staticmethod
    def _base_grid(h: int, w: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the identity sampling grid in normalised coordinates.

        Args:
            h: Grid height.
            w: Grid width.
            device: Device on which to create the grid.
            dtype: Floating dtype of the grid.

        Returns:
            Tensor of shape ``[h, w, 2]`` whose last axis is ``(x, y)``, with
            both coordinates spanning ``[-1, 1]``.
        """
        ys = torch.linspace(-1.0, 1.0, h, device=device, dtype=dtype)
        xs = torch.linspace(-1.0, 1.0, w, device=device, dtype=dtype)
        yy, xx = torch.meshgrid(ys, xs, indexing="ij")
        return torch.stack((xx, yy), dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute routed parent-capsule features.

        Args:
            x: Feature map shaped ``[B, in_channels, H, W]``.

        Returns:
            Tensor shaped ``[B, out_channels, H, W]``.
        """
        child = self.primary(x)  # [B, Nc, Dc, H, W]
        bsz, _, _, h, w = child.shape
        child_flat = child.view(bsz, self.num_child_caps * self.dim_child, h, w)

        # Offsets are bounded in pixel units, then converted to the normalised
        # [-1, 1] coordinates that grid_sample expects with align_corners=True.
        offsets = self.offset(x).view(bsz, self.num_samples, 2, h, w)
        offsets = torch.tanh(offsets) * self.offset_scale
        scale_x = max(w - 1, 1) / 2.0
        scale_y = max(h - 1, 1) / 2.0
        scale = offsets.new_tensor([scale_x, scale_y]).view(1, 1, 2, 1, 1)
        offsets = offsets / scale

        base = self._base_grid(h, w, x.device, x.dtype).view(1, 1, h, w, 2)
        grids = base + offsets.permute(0, 1, 3, 4, 2)  # [B, S, H, W, 2]

        # Sample the child capsules once per offset set.
        sampled = torch.stack(
            [
                F.grid_sample(
                    child_flat,
                    grids[:, idx],
                    mode="bilinear",
                    padding_mode="zeros",
                    align_corners=True,
                )
                for idx in range(self.num_samples)
            ],
            dim=1,
        )  # [B, S, Nc * Dc, H, W]
        sampled = sampled.view(bsz, self.num_samples, self.num_child_caps, self.dim_child, h, w)
        # Fold the spatial positions into the batch axis so routing runs
        # independently at every location.
        sampled = sampled.permute(0, 4, 5, 1, 2, 3).contiguous()
        sampled = sampled.view(bsz * h * w, self.num_samples * self.num_child_caps, self.dim_child)

        routed = self.routing(sampled)  # [B * H * W, Np, Dp]
        routed = routed.view(bsz, h, w, self.num_parent_caps, self.dim_parent)
        routed = routed.permute(0, 3, 4, 1, 2).contiguous()
        out = routed.view(bsz, self.num_parent_caps * self.dim_parent, h, w)
        if self.project is not None:
            out = self.project(out)
        return out
class DeformableCapsBlock(nn.Module):
    """Backbone block: a 3x3 ``ConvBNAct`` followed by ``DeformableCaps2d``.

    Args:
        c1: Input channel count.
        c2: Output channel count of both the convolution and the capsule layer.
        num_child_caps: Child capsules per location in the capsule layer.
        dim_child: Length of each child capsule vector.
        num_parent_caps: Parent capsules per location in the capsule layer.
        dim_parent: Length of each parent capsule vector.
        num_samples: Number of offset sampling positions per location.
        routing_iters: Routing iterations used by the capsule layer.
        stride: Stride of the leading convolution.
    """

    def __init__(
        self,
        c1: int,
        c2: int,
        num_child_caps: int = 8,
        dim_child: int = 8,
        num_parent_caps: int = 8,
        dim_parent: int = 8,
        num_samples: int = 4,
        routing_iters: int = 3,
        stride: int = 1,
    ) -> None:
        super().__init__()
        capsule_options = {
            "num_child_caps": num_child_caps,
            "dim_child": dim_child,
            "num_parent_caps": num_parent_caps,
            "dim_parent": dim_parent,
            "num_samples": num_samples,
            "routing_iters": routing_iters,
        }
        self.down = ConvBNAct(c1, c2, kernel_size=3, stride=stride)
        # The capsule layer reads and writes c2 channels, so the block's
        # output width equals the convolution's output width.
        self.caps = DeformableCaps2d(in_channels=c2, out_channels=c2, **capsule_options)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, then the deformable capsule layer."""
        reduced = self.down(x)
        return self.caps(reduced)
class CapsuleBackbone(nn.Module):
    """Capsule-based backbone that returns one feature map per stage.

    A stride-2 stem convolution is followed by one stage per entry in
    ``stages``. Each stage is a stride-2 ``ConvBNAct`` and a
    ``DeformableCaps2d`` with matching input and output width.

    Args:
        in_channels: Channel count of the input image.
        stem_channels: Output channel count of the stem convolution.
        stages: Output channel count of each stage, in order.
        capsule_cfgs: Optional per-stage keyword arguments forwarded to
            ``DeformableCaps2d``; when ``None`` every stage uses the defaults.

    Raises:
        ValueError: If ``capsule_cfgs`` does not have one entry per stage.
    """

    def __init__(
        self,
        in_channels: int = 3,
        stem_channels: int = 64,
        stages: Sequence[int] = (128, 256, 512),
        capsule_cfgs: Iterable[dict] | None = None,
    ) -> None:
        super().__init__()
        self.stem = ConvBNAct(in_channels, stem_channels, kernel_size=3, stride=2)

        if capsule_cfgs is None:
            stage_cfgs = [{}] * len(stages)
        else:
            stage_cfgs = list(capsule_cfgs)
        if len(stage_cfgs) != len(stages):
            raise ValueError("capsule_cfgs must match stages length")

        # Each stage's input width is the previous stage's output width,
        # starting from the stem.
        input_widths = [stem_channels, *stages]
        stage_modules = [
            nn.Sequential(
                ConvBNAct(width_in, width_out, kernel_size=3, stride=2),
                DeformableCaps2d(width_out, out_channels=width_out, **cfg),
            )
            for width_in, width_out, cfg in zip(input_widths, stages, stage_cfgs)
        ]
        self.stages = nn.ModuleList(stage_modules)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Return the output of every stage, in order from first to last."""
        feats = self.stem(x)
        pyramid = []
        for stage in self.stages:
            feats = stage(feats)
            pyramid.append(feats)
        return tuple(pyramid)
# Public API of this module, in alphabetical order (classes first, then functions).
__all__ = [
    "CapsuleBackbone", "ConvBNAct",
    "DeformableCaps2d", "DeformableCapsBlock",
    "PrimaryCaps2d", "RoutingCaps",
    "squash",
]