chexvision-demo / src /models /scratch_cnn.py
arudaev's picture
fix(types): resolve all 8 mypy errors across 5 files
91347e8
"""Model 1: Custom ResNet-style CNN built entirely from scratch.
Design decisions:
- Residual connections: Enable training of deeper networks by mitigating vanishing gradients.
- Batch normalization: Stabilizes training and allows higher learning rates.
- Global average pooling: Reduces parameters vs. fully-connected layers, less prone to overfitting.
- SE (Squeeze-and-Excitation) attention: Channel-wise recalibration — the model learns WHICH
feature maps matter most for each pathology. Critical for multi-label medical imaging where
different disease channels compete. Based on Hu et al., "Squeeze-and-Excitation Networks" (CVPR 2018).
- Dual heads: Shared backbone extracts features once; separate heads specialize per task.
- Sigmoid activation: Multi-label (not mutually exclusive) requires independent probabilities per class.
- Depth [3,4,6,3]: ResNet-34 equivalent depth (two-conv basic blocks; ResNet-50 shares this stage
layout but uses bottleneck blocks) — roughly 2x the parameters of the original ResNet-18
style [2,2,2,2] baseline, justified by the complexity of 14 simultaneous pathology signals.
"""
from __future__ import annotations
import torch
import torch.nn as nn
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: per-channel gating of a feature map.

    Steps performed in ``forward``:
        1. Squeeze: global average pooling collapses each channel to one value,
           giving a (B, C) descriptor.
        2. Excitation: a two-layer bottleneck MLP ending in a sigmoid turns the
           descriptor into one gate in (0, 1) per channel.
        3. Scale: every channel of the input is multiplied by its gate.

    Motivation (chest X-ray classification): different pathologies respond on
    different feature channels, so learned gates let the network emphasise
    disease-relevant channels and damp background-texture channels.

    Reference: Hu et al., "Squeeze-and-Excitation Networks", CVPR 2018.

    Args:
        channels: Number of channels C of the incoming feature map.
        reduction: Divisor for the bottleneck width; the hidden size is
            ``channels // reduction`` but never smaller than 4.
    """

    def __init__(self, channels: int, reduction: int = 16) -> None:
        super().__init__()
        # Floor of 4 hidden units keeps narrow layers from collapsing to ~1 unit.
        hidden = max(4, channels // reduction)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Flatten(),
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        ]
        self.excitation = nn.Sequential(*gate_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with each channel rescaled by its learned gate.

        Args:
            x: Feature map of shape (B, C, H, W); the 4-way unpack below
                rejects any other rank.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch, chans, _, _ = x.shape
        # Squeeze: one scalar summary per (sample, channel).
        descriptor = self.squeeze(x).view(batch, chans)
        # Excitation: per-channel importance weights in (0, 1).
        gates = self.excitation(descriptor)
        # Scale: broadcast the gates over the spatial dimensions.
        return x * gates.view(batch, chans, 1, 1)
class ResidualBlock(nn.Module):
    """Basic residual block: two 3×3 convolutions plus a skip connection.

    The main path is conv → BN → ReLU → conv → BN, optionally followed by
    Squeeze-and-Excitation channel attention. The (possibly projected) input is
    then added and a final ReLU is applied.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
        use_se: If True, apply an ``SEBlock`` to the main path before the
            residual addition; otherwise that step is an identity.
    """

    def __init__(
        self, in_channels: int, out_channels: int, stride: int = 1, use_se: bool = False
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Channel attention sits after the second BN and before the skip is merged.
        self.se: nn.Module
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()

        # The skip path needs a strided 1×1 projection whenever the main path
        # changes the spatial size or the channel count; otherwise pass through.
        self.shortcut: nn.Module
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and return the activated sum of both paths."""
        identity = self.shortcut(x)

        features = self.conv1(x)
        features = self.bn1(features)
        features = self.relu(features)

        features = self.conv2(features)
        features = self.bn2(features)

        # Recalibrate channels before merging with the skip connection.
        features = self.se(features)
        features += identity
        return self.relu(features)
class CheXVisionScratch(nn.Module):
    """Custom ResNet-style CNN for chest X-ray classification with two output heads.

    Architecture (default config; ResNet-34 style depth built from two-conv basic blocks):
        - Stem: 7×7 conv (stride 2) + 3×3 max pool (stride 2) → 64 channels, ¼ spatial resolution
        - Stage 1: 3× ResidualBlock [64 → 64, stride 1]
        - Stage 2: 4× ResidualBlock [64 → 128, stride 2]
        - Stage 3: 6× ResidualBlock [128 → 256, stride 2]
        - Stage 4: 3× ResidualBlock [256 → 512, stride 2]
        - Global average pooling → Dropout
        - Multilabel head: Linear(512 → num_classes), one logit per pathology class
        - Binary head: Linear(512 → 1), Normal vs Abnormal logit

    Both heads return raw logits. Apply a sigmoid (or use a with-logits loss)
    to obtain independent per-class probabilities.

    Parameter count (default config): roughly 21M parameters.

    Args:
        in_channels: Number of channels of the input image tensor.
        num_classes: Output width of the multilabel head.
        block_config: Number of residual blocks in each stage. Every stage
            contains at least one block, even if its entry is below 1.
        filter_sizes: Output channels of each stage; the first entry is also
            the stem width. Must have the same length as ``block_config``.
        dropout: Dropout probability applied to the pooled features.
        use_se: Whether every residual block uses Squeeze-and-Excitation attention.

    Raises:
        ValueError: If ``block_config`` / ``filter_sizes`` are empty or differ in length.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 14,
        block_config: tuple[int, ...] = (3, 4, 6, 3),
        filter_sizes: tuple[int, ...] = (64, 128, 256, 512),
        dropout: float = 0.5,
        use_se: bool = True,
    ) -> None:
        super().__init__()

        # Fail fast: zip() below would silently drop unmatched entries, leaving
        # either ignored stages or heads sized for a width the backbone never
        # produces (which would only surface as a shape error in forward()).
        if not filter_sizes or len(block_config) != len(filter_sizes):
            raise ValueError(
                "block_config and filter_sizes must be non-empty and of equal length, "
                f"got block_config={tuple(block_config)!r} and "
                f"filter_sizes={tuple(filter_sizes)!r}"
            )

        # Stem: a 7×7 conv captures low-level features (edges, textures);
        # stride 2 plus the max pool cut spatial resolution early for efficiency.
        stem_width = filter_sizes[0]
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, stem_width, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Residual stages: every stage except the first halves the resolution
        # in its first block; the remaining blocks keep shape.
        self.stages = nn.ModuleList()
        current_channels = stem_width
        for i, (num_blocks, out_channels) in enumerate(zip(block_config, filter_sizes)):
            stride = 1 if i == 0 else 2
            blocks = [ResidualBlock(current_channels, out_channels, stride=stride, use_se=use_se)]
            blocks.extend(
                ResidualBlock(out_channels, out_channels, stride=1, use_se=use_se)
                for _ in range(num_blocks - 1)
            )
            self.stages.append(nn.Sequential(*blocks))
            current_channels = out_channels

        # Global average pooling: (B, C, H, W) → (B, C, 1, 1). Spatial layout is
        # averaged away, leaving one descriptor value per channel.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)

        # Multi-label head: one independent logit per pathology class.
        self.multilabel_head = nn.Linear(filter_sizes[-1], num_classes)
        # Binary head: a single Normal-vs-Abnormal logit.
        self.binary_head = nn.Linear(filter_sizes[-1], 1)

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialise all submodules: Kaiming for convs, unit BN, Xavier for linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming (fan_out) is the standard choice for ReLU conv nets.
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the shared backbone once and return the logits of both heads.

        Args:
            x: Input tensor of shape (B, C, H, W).

        Returns:
            Dict with 'multilabel_logits' of shape (B, num_classes) and
            'binary_logits' of shape (B, 1). Values are raw logits (no sigmoid).
        """
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        return {
            "multilabel_logits": self.multilabel_head(x),
            "binary_logits": self.binary_head(x),
        }