from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn


def _largest_divisor_at_most(n: int, cap: int) -> int:
    """Return the largest divisor of ``n`` in ``[1, min(cap, n)]``.

    Falls back to 1 when that range is empty (``cap <= 0`` or ``n <= 0``).
    Used to derive group counts that evenly divide a channel dimension, as
    both ``nn.GroupNorm`` and grouped ``nn.Conv3d`` require.
    """
    for g in range(min(cap, n), 0, -1):
        if n % g == 0:
            return g
    return 1


def _choose_gn_groups(channels: int, max_groups: int = 8) -> int:
    """Pick a ``nn.GroupNorm`` group count for ``channels``.

    Returns the largest divisor of ``channels`` that does not exceed
    ``max_groups`` (1 if there is none), so that
    ``nn.GroupNorm(groups, channels)`` is always valid.
    """
    return _largest_divisor_at_most(channels, max_groups)


class _ChannelGate(nn.Module):
    """Squeeze-and-excitation style per-channel gate.

    Global average pooling reduces each channel to a single value. A 1x1x1
    bottleneck (``channels -> hidden -> channels`` with GELU in between)
    turns those values into per-channel logits. The input is then rescaled
    channel-wise by their sigmoid.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck reduction factor. The hidden width is
            ``max(channels // reduction, 8)``, so it never drops below 8.
    """

    def __init__(self, channels: int, reduction: int = 4) -> None:
        super().__init__()
        hidden = max(channels // reduction, 8)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Conv3d(channels, hidden, kernel_size=1, bias=True)
        self.act = nn.GELU()
        self.fc2 = nn.Conv3d(hidden, channels, kernel_size=1, bias=True)
        self.gate = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled channel-wise by a learned gate in (0, 1)."""
        s = self.pool(x)  # the three trailing dims collapse to size 1
        s = self.fc1(s)
        s = self.act(s)
        s = self.fc2(s)
        # The gate broadcasts back over the pooled (size-1) dims.
        return x * self.gate(s)


class _FastHyperBlock(nn.Module):
    """
    Efficient RF-expanding residual block.

    Each block contributes one effective k=3 receptive-field expansion stage
    via three parallel branches operating on the same expanded activation:

    - spatial depthwise (1,3,3)
    - temporal depthwise (3,1,1)
    - grouped 3D mixing (3,3,3), preceded by its own GroupNorm

    The processing order is:

    1. GroupNorm, then a 1x1x1 expansion to ``mid_dim``, then GELU.
    2. The three branches above, summed.
    3. A 1x1x1 projection back to ``channels``, then GELU.
    4. The channel gate, then optional Dropout3d.
    5. A residual add with the block input.

    Every convolution is padded so that the extent of the three trailing
    dims is preserved.

    Args:
        channels: Residual stream width (input and output channels).
        mid_dim: Expanded width used inside the branches.
        mix_groups: Requested group count for the (3,3,3) mixing conv. It is
            clamped down to the largest divisor of ``mid_dim`` that does not
            exceed it (minimum 1), since grouped convs require divisibility.
        dropout_p: ``nn.Dropout3d`` probability; ``<= 0`` disables dropout.
        gate_reduction: Reduction factor passed to ``_ChannelGate``.
    """

    def __init__(
        self,
        channels: int,
        mid_dim: int,
        mix_groups: int = 6,
        dropout_p: float = 0.02,
        gate_reduction: int = 4,
    ) -> None:
        super().__init__()
        gn1 = _choose_gn_groups(channels)
        gn2 = _choose_gn_groups(mid_dim)
        # Same divisor search as GroupNorm: largest valid group count for
        # mid_dim that does not exceed the requested mix_groups.
        mix_groups = _largest_divisor_at_most(mid_dim, mix_groups)

        self.pre = nn.Sequential(
            nn.GroupNorm(gn1, channels),
            nn.Conv3d(channels, mid_dim, kernel_size=1, bias=True),
            nn.GELU(),
        )
        self.spatial = nn.Sequential(
            nn.Conv3d(
                mid_dim,
                mid_dim,
                kernel_size=(1, 3, 3),
                padding=(0, 1, 1),
                groups=mid_dim,  # depthwise
                bias=True,
            ),
            nn.GELU(),
        )
        self.temporal = nn.Sequential(
            nn.Conv3d(
                mid_dim,
                mid_dim,
                kernel_size=(3, 1, 1),
                padding=(1, 0, 0),
                groups=mid_dim,  # depthwise
                bias=True,
            ),
            nn.GELU(),
        )
        self.mixed = nn.Sequential(
            nn.GroupNorm(gn2, mid_dim),
            nn.Conv3d(
                mid_dim,
                mid_dim,
                kernel_size=3,
                padding=1,
                groups=mix_groups,
                bias=True,
            ),
            nn.GELU(),
        )
        self.fuse = nn.Sequential(
            nn.Conv3d(mid_dim, channels, kernel_size=1, bias=True),
            nn.GELU(),
        )
        self.gate = _ChannelGate(channels, reduction=gate_reduction)
        self.dropout = nn.Dropout3d(dropout_p) if dropout_p > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block and add the result to ``x`` (residual)."""
        h = self.pre(x)
        # All three branches read the same expanded activation.
        h = self.spatial(h) + self.temporal(h) + self.mixed(h)
        h = self.fuse(h)
        h = self.gate(h)
        h = self.dropout(h)
        return x + h


class PredecoderFastHyperRF13V1(nn.Module):
    """
    Faster-stronger candidate for model 6 under the public Ising-Decoding API.

    Input / output shape: (B, 4, T, D, D) -> (B, 4, T, D, D)

    More generally the shape is
    ``(B, input_channels, T, H, W) -> (B, out_channels, T, H, W)``; every
    layer preserves the extent of the three trailing dims.

    The network has three stages:

    - A ``stem_kernel_size`` conv stem with GroupNorm and GELU.
    - ``num_blocks`` residual ``_FastHyperBlock`` stages at width
      ``hidden_dim``.
    - A GroupNorm plus two-layer 1x1x1 conv head.

    Each block grows the local receptive field by 2 per axis through its
    (3,3,3) mixing conv. The local convolutional receptive field is
    therefore ``stem_kernel_size + 2 * num_blocks`` voxels per axis, which
    is 13 with the defaults. The channel gates (global average pooling) and
    GroupNorm statistics additionally mix information from the whole input
    volume.

    Args:
        input_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        hidden_dim: Residual stream width.
        mid_dim: Expanded width inside each block.
        mix_groups: Requested group count for each block's (3,3,3) conv.
        num_blocks: Number of residual blocks.
        stem_kernel_size: Cubic kernel size of the stem conv. It must be a
            positive odd integer so that ``padding = k // 2`` preserves the
            input extent.
        dropout_p: Dropout3d probability inside each block.
        gate_reduction: Channel-gate reduction factor inside each block.
        **_: Extra keyword arguments are accepted and ignored. Presumably
            this lets a shared config dict be passed straight through;
            confirm with the callers.

    Raises:
        ValueError: If ``stem_kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self,
        input_channels: int = 4,
        out_channels: int = 4,
        hidden_dim: int = 96,
        mid_dim: int = 144,
        mix_groups: int = 6,
        num_blocks: int = 5,
        stem_kernel_size: int = 3,
        dropout_p: float = 0.02,
        gate_reduction: int = 4,
        **_: Any,
    ) -> None:
        # An even kernel with symmetric padding k // 2 would grow every axis
        # by one voxel and silently break the documented shape contract.
        if stem_kernel_size < 1 or stem_kernel_size % 2 == 0:
            raise ValueError(
                "stem_kernel_size must be a positive odd integer so the stem "
                f"preserves the input extent; got {stem_kernel_size!r}"
            )
        super().__init__()
        pad = stem_kernel_size // 2
        gn = _choose_gn_groups(hidden_dim)

        self.stem = nn.Sequential(
            nn.Conv3d(
                input_channels,
                hidden_dim,
                kernel_size=stem_kernel_size,
                padding=pad,
                bias=True,
            ),
            nn.GroupNorm(gn, hidden_dim),
            nn.GELU(),
        )
        self.blocks = nn.Sequential(
            *[
                _FastHyperBlock(
                    channels=hidden_dim,
                    mid_dim=mid_dim,
                    mix_groups=mix_groups,
                    dropout_p=dropout_p,
                    gate_reduction=gate_reduction,
                )
                for _ in range(num_blocks)
            ]
        )
        self.head = nn.Sequential(
            nn.GroupNorm(gn, hidden_dim),
            nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1, bias=True),
            nn.GELU(),
            nn.Conv3d(hidden_dim, out_channels, kernel_size=1, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, C_in, T, H, W)`` to ``(B, C_out, T, H, W)``."""
        x = self.stem(x)
        x = self.blocks(x)
        x = self.head(x)
        return x