# QuantiSpect-V1 / code/model/predecoder_fasthyper_rf13_v1.py
# Author: donghufeng
# init (commit d57fabf)
from __future__ import annotations
from typing import Any
import torch
import torch.nn as nn
def _choose_gn_groups(channels: int, max_groups: int = 8) -> int:
for g in range(min(max_groups, channels), 0, -1):
if channels % g == 0:
return g
return 1
class _ChannelGate(nn.Module):
def __init__(self, channels: int, reduction: int = 4) -> None:
super().__init__()
hidden = max(channels // reduction, 8)
self.pool = nn.AdaptiveAvgPool3d(1)
self.fc1 = nn.Conv3d(channels, hidden, kernel_size=1, bias=True)
self.act = nn.GELU()
self.fc2 = nn.Conv3d(hidden, channels, kernel_size=1, bias=True)
self.gate = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> torch.Tensor:
s = self.pool(x)
s = self.fc1(s)
s = self.act(s)
s = self.fc2(s)
return x * self.gate(s)
class _FastHyperBlock(nn.Module):
    """
    Efficient receptive-field-expanding residual block.

    Each block contributes one effective k=3 receptive-field expansion stage
    via three parallel branches applied to the same expanded activation:
      - spatial depthwise conv, kernel (1, 3, 3)
      - temporal depthwise conv, kernel (3, 1, 1)
      - grouped full 3D mixing conv, kernel (3, 3, 3)
    The branch sum is fused back to ``channels`` width, channel-gated,
    dropped out, and added to the input (residual connection).
    """

    def __init__(
        self,
        channels: int,
        mid_dim: int,
        mix_groups: int = 6,
        dropout_p: float = 0.02,
        gate_reduction: int = 4,
    ) -> None:
        super().__init__()
        norm_groups_in = _choose_gn_groups(channels)
        norm_groups_mid = _choose_gn_groups(mid_dim)
        # Clamp mix_groups into [1, mid_dim], then take the largest value not
        # exceeding it that divides mid_dim (grouped conv requirement).
        upper = max(1, min(mix_groups, mid_dim))
        mix_groups = next(g for g in range(upper, 0, -1) if mid_dim % g == 0)
        # 1x1x1 expansion to the wider mid_dim working width.
        self.pre = nn.Sequential(
            nn.GroupNorm(norm_groups_in, channels),
            nn.Conv3d(channels, mid_dim, kernel_size=1, bias=True),
            nn.GELU(),
        )
        # Depthwise branch over the two spatial axes only.
        self.spatial = nn.Sequential(
            nn.Conv3d(
                mid_dim,
                mid_dim,
                kernel_size=(1, 3, 3),
                padding=(0, 1, 1),
                groups=mid_dim,
                bias=True,
            ),
            nn.GELU(),
        )
        # Depthwise branch over the first (temporal) axis only.
        self.temporal = nn.Sequential(
            nn.Conv3d(
                mid_dim,
                mid_dim,
                kernel_size=(3, 1, 1),
                padding=(1, 0, 0),
                groups=mid_dim,
                bias=True,
            ),
            nn.GELU(),
        )
        # Grouped full-3D branch that mixes channels within each group.
        self.mixed = nn.Sequential(
            nn.GroupNorm(norm_groups_mid, mid_dim),
            nn.Conv3d(
                mid_dim,
                mid_dim,
                kernel_size=3,
                padding=1,
                groups=mix_groups,
                bias=True,
            ),
            nn.GELU(),
        )
        # 1x1x1 projection back down to the residual width.
        self.fuse = nn.Sequential(
            nn.Conv3d(mid_dim, channels, kernel_size=1, bias=True),
            nn.GELU(),
        )
        self.gate = _ChannelGate(channels, reduction=gate_reduction)
        self.dropout = nn.Dropout3d(dropout_p) if dropout_p > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.pre(x)
        branches = self.spatial(expanded) + self.temporal(expanded) + self.mixed(expanded)
        residual = self.dropout(self.gate(self.fuse(branches)))
        return x + residual
class PredecoderFastHyperRF13V1(nn.Module):
    """
    Faster-stronger candidate for model 6 under the public Ising-Decoding API.

    Architecture: a conv stem lifts the input to ``hidden_dim`` channels, a
    stack of ``num_blocks`` residual `_FastHyperBlock`s expands the effective
    receptive field, and a two-conv head projects back to ``out_channels``.

    Input / output shape:
        (B, 4, T, D, D) -> (B, 4, T, D, D)
    """

    def __init__(
        self,
        input_channels: int = 4,
        out_channels: int = 4,
        hidden_dim: int = 96,
        mid_dim: int = 144,
        mix_groups: int = 6,
        num_blocks: int = 5,
        stem_kernel_size: int = 3,
        dropout_p: float = 0.02,
        gate_reduction: int = 4,
        **_: Any,
    ) -> None:
        super().__init__()
        groups = _choose_gn_groups(hidden_dim)
        # Stem: lift to hidden_dim with "same" padding for the odd kernel.
        self.stem = nn.Sequential(
            nn.Conv3d(
                input_channels,
                hidden_dim,
                kernel_size=stem_kernel_size,
                padding=stem_kernel_size // 2,
                bias=True,
            ),
            nn.GroupNorm(groups, hidden_dim),
            nn.GELU(),
        )
        # Residual trunk: identical blocks, each with its own parameters.
        stages = [
            _FastHyperBlock(
                channels=hidden_dim,
                mid_dim=mid_dim,
                mix_groups=mix_groups,
                dropout_p=dropout_p,
                gate_reduction=gate_reduction,
            )
            for _ in range(num_blocks)
        ]
        self.blocks = nn.Sequential(*stages)
        # Head: pointwise refine, then project to the output channel count.
        self.head = nn.Sequential(
            nn.GroupNorm(groups, hidden_dim),
            nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1, bias=True),
            nn.GELU(),
            nn.Conv3d(hidden_dim, out_channels, kernel_size=1, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.blocks(self.stem(x)))