Spaces:
Sleeping
Sleeping
| """Model 1: Custom ResNet-style CNN built entirely from scratch. | |
| Design decisions: | |
| - Residual connections: Enable training of deeper networks by mitigating vanishing gradients. | |
| - Batch normalization: Stabilizes training and allows higher learning rates. | |
| - Global average pooling: Reduces parameters vs. fully-connected layers, less prone to overfitting. | |
| - SE (Squeeze-and-Excitation) attention: Channel-wise recalibration — the model learns WHICH | |
| feature maps matter most for each pathology. Critical for multi-label medical imaging where | |
| different disease channels compete. Based on Hu et al. 2018 (CVPR best paper). | |
| - Dual heads: Shared backbone extracts features once; separate heads specialize per task. | |
| - Sigmoid activation: Multi-label (not mutually exclusive) requires independent probabilities per class. | |
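| - Depth [3,4,6,3]: Same per-stage block counts as ResNet-34/50; with the basic two-conv blocks used | |
| here this is a ResNet-34-style layout — twice the residual blocks (16 vs. 8) of the original ResNet-18 | |
| style [2,2,2,2] baseline, justified by the complexity of 14 simultaneous pathology signals. | |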
| """ | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
class SEBlock(nn.Module):
    """Channel-attention gate in the Squeeze-and-Excitation style.

    The block works in three steps:
        1. Squeeze: average every channel over its spatial extent, giving one
           scalar descriptor per channel, i.e. a (B, C) tensor.
        2. Excitation: feed the descriptors through a two-layer bottleneck MLP
           (ReLU in the middle, sigmoid at the end) to get one gate per channel.
        3. Scale: multiply each input channel by its gate.

    Intent (per the original design notes): different pathologies activate
    different feature channels, so learned gates let the network emphasise
    disease-relevant channels and damp background-texture channels.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``channels`` to size the bottleneck. The
            bottleneck is never narrower than 4 units.

    Reference: Hu et al., "Squeeze-and-Excitation Networks", CVPR 2018.
    """

    def __init__(self, channels: int, reduction: int = 16) -> None:
        super().__init__()
        # Floor of 4 keeps the bottleneck usable for narrow feature maps.
        hidden = max(4, channels // reduction)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # Layer positions inside the Sequential are part of the checkpoint
        # layout (excitation.1 / excitation.3), so the order must not change.
        gate_layers = [
            nn.Flatten(),
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        ]
        self.excitation = nn.Sequential(*gate_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with every channel rescaled by its learned gate.

        Args:
            x: Feature map with exactly four dimensions, (B, C, H, W).

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch, chans, _height, _width = x.shape
        # Squeeze: one averaged value per (sample, channel).
        descriptor = self.squeeze(x).view(batch, chans)
        # Excitation: per-channel gate produced by the bottleneck MLP.
        gate = self.excitation(descriptor)
        # Scale: broadcast the gate over the spatial dimensions.
        return x * gate.view(batch, chans, 1, 1)
class ResidualBlock(nn.Module):
    """Two-convolution residual unit with an optional SE gate on the main branch.

    Main branch: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN -> (SE gate).
    The result is added to the skip branch and passed through a final ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution (and of the skip projection,
            when one is needed).
        use_se: If true, an ``SEBlock(out_channels)`` gates the main branch
            before the addition; otherwise that step is an identity.
    """

    def __init__(
        self, in_channels: int, out_channels: int, stride: int = 1, use_se: bool = False
    ) -> None:
        super().__init__()
        # Only the first convolution may change resolution / channel count.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Channel attention sits after the second BN, ahead of the residual sum.
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()

        # The skip path needs a 1x1 projection whenever the main branch changes
        # the tensor shape; otherwise the input is passed through untouched.
        skip: nn.Module
        if in_channels == out_channels and stride == 1:
            skip = nn.Identity()
        else:
            skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.shortcut = skip

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and return the activated residual sum."""
        skip = self.shortcut(x)

        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)
        main = self.se(main)  # identity unless use_se was requested

        # In-place add targets the freshly computed main-branch tensor, never x.
        main += skip
        return self.relu(main)
class CheXVisionScratch(nn.Module):
    """Custom ResNet-style CNN with two classification heads for chest X-rays.

    Architecture (default configuration):
        - Stem: 7x7 conv (stride 2) + BN + ReLU + 3x3 max pool (stride 2)
          -> 64 channels at 1/4 of the input resolution.
        - Stage 1: 3x ResidualBlock [64 -> 64, stride 1]
        - Stage 2: 4x ResidualBlock [64 -> 128, stride 2]
        - Stage 3: 6x ResidualBlock [128 -> 256, stride 2]
        - Stage 4: 3x ResidualBlock [256 -> 512, stride 2]
        - Global average pooling -> Dropout
        - Multilabel head: Linear(512 -> num_classes), one logit per pathology
        - Binary head: Linear(512 -> 1), Normal vs Abnormal logit

    The default (3, 4, 6, 3) layout uses basic two-convolution blocks, i.e. it
    matches a ResNet-34-style backbone (ResNet-50 has the same block counts
    but uses bottleneck blocks). With SE enabled the default model has on the
    order of 21M parameters.

    Both heads emit raw logits; no sigmoid is applied inside this module, so
    the caller is responsible for it (e.g. via a with-logits loss).

    Args:
        in_channels: Channels of the input image tensor.
        num_classes: Width of the multilabel head.
        block_config: Number of residual blocks per stage. Every stage holds
            at least one block, even for an entry below 1.
        filter_sizes: Output channels per stage; ``filter_sizes[0]`` is also
            the stem width. Must have the same length as ``block_config``.
        dropout: Dropout probability applied to the pooled feature vector.
        use_se: Whether every residual block gets an SE attention gate.

    Raises:
        ValueError: If ``block_config`` or ``filter_sizes`` is empty, or if
            their lengths differ.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 14,
        block_config: tuple[int, ...] = (3, 4, 6, 3),
        filter_sizes: tuple[int, ...] = (64, 128, 256, 512),
        dropout: float = 0.5,
        use_se: bool = True,
    ) -> None:
        super().__init__()

        # Fail fast on inconsistent configs. Pairing the two tuples with zip()
        # would silently truncate to the shorter one, and the heads (sized from
        # filter_sizes[-1]) could then disagree with the backbone output width.
        if not block_config or not filter_sizes:
            raise ValueError(
                "block_config and filter_sizes must each describe at least one stage"
            )
        if len(block_config) != len(filter_sizes):
            raise ValueError(
                "block_config and filter_sizes must have the same length, "
                f"got {len(block_config)} and {len(filter_sizes)}"
            )

        # Stem: the 7x7 conv captures low-level features (edges, textures);
        # stride 2 plus the stride-2 max pool cut resolution early for efficiency.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, filter_sizes[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(filter_sizes[0]),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Residual stages: every stage except the first halves the resolution
        # in its first block; the remaining blocks keep shape.
        self.stages = nn.ModuleList()
        current_channels = filter_sizes[0]
        for i, (num_blocks, out_channels) in enumerate(zip(block_config, filter_sizes)):
            stride = 1 if i == 0 else 2
            blocks = [ResidualBlock(current_channels, out_channels, stride=stride, use_se=use_se)]
            blocks.extend(
                ResidualBlock(out_channels, out_channels, stride=1, use_se=use_se)
                for _ in range(1, num_blocks)
            )
            self.stages.append(nn.Sequential(*blocks))
            current_channels = out_channels

        # Global average pooling collapses (B, C, H, W) to one averaged value
        # per channel, so spatial layout is discarded at this point.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)

        # After validation the backbone output width equals filter_sizes[-1].
        feature_dim = current_channels
        # Multi-label head: one independent logit per pathology class.
        self.multilabel_head = nn.Linear(feature_dim, num_classes)
        # Binary head: single Normal-vs-Abnormal logit.
        self.binary_head = nn.Linear(feature_dim, 1)

        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialise every Conv2d, BatchNorm2d and Linear in the module tree.

        Conv weights use Kaiming-normal (fan_out, ReLU gain); BatchNorm starts
        as identity (weight 1, bias 0); Linear weights use Xavier-uniform with
        zero bias. ``self.modules()`` recurses, so Linear layers nested inside
        sub-blocks are covered as well as the two heads.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the shared backbone once and evaluate both heads.

        Args:
            x: Input tensor of shape (B, C, H, W).

        Returns:
            Dict with 'multilabel_logits' of shape (B, num_classes) and
            'binary_logits' of shape (B, 1). Values are raw logits.
        """
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)  # (B, C, 1, 1) -> (B, C)
        x = self.dropout(x)
        return {
            "multilabel_logits": self.multilabel_head(x),
            "binary_logits": self.binary_head(x),
        }