#!/usr/bin/env python3
"""
Baseline 2D ResNet-18 (Slice-Level): standalone module.

This module provides the ResNet2D18SliceLevel class as a standalone import
per the ablation study specification. The actual implementation lives in
baselines.py; this file re-exports it for clean per-model imports.

Usage:
    from src.models.baseline_resnet2d import ResNet2D18SliceLevel
    model = ResNet2D18SliceLevel(num_classes=1)

Architecture:
    - Backbone: torchvision resnet18 (2D)
    - Input conv modified: 3-channel → 1-channel (grayscale CT slices)
    - Each 3D volume (B, 1, D, H, W) is processed slice-by-slice:
        1. Reshape → (B*D, 1, H, W)
        2. Forward through ResNet-18 backbone → (B*D, 512)
        3. Reshape back → (B, D, 512)
        4. Global Average Pool over depth → (B, 512)
        5. Classification head → (B, 1)
    - Forward signature: forward(nodule_patch, context_patch=None)
      (context_patch is ignored; the 2D baseline uses nodule slices only)
"""

from src.models.baselines import ResNet2D18SliceLevel

__all__ = ['ResNet2D18SliceLevel']