File size: 1,843 Bytes
8960670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#!/usr/bin/env python3
"""Malignancy classifier using a 3D ResNet backbone."""

from pathlib import Path

import torch
import torch.nn as nn


class MalignancyClassifier(nn.Module):
    """3D ResNet-based malignancy classifier.

    Wraps torchvision's ``r3d_18`` video ResNet and replaces its final
    fully-connected layer with a small MLP head ending in a sigmoid, so the
    module outputs one value in ``[0, 1]`` per input sample.

    Args:
        pretrained_path: Optional path to a checkpoint whose tensors are
            copied into the backbone on a best-effort basis (non-strict
            load). When a path is given, torchvision's own pretrained
            weights are not requested.
        use_torchvision_pretrained: When true and ``pretrained_path`` is
            ``None``, initialise the backbone with torchvision's default
            pretrained ``r3d_18`` weights.
    """

    def __init__(self, pretrained_path=None, use_torchvision_pretrained=True):
        super().__init__()

        want_torchvision_weights = bool(
            use_torchvision_pretrained and pretrained_path is None
        )

        # Only the import is version-dependent, so only the import is guarded.
        try:
            from torchvision.models.video import R3D_18_Weights, r3d_18
        except ImportError:
            # Older torchvision without the ``weights=`` enum API.
            from torchvision.models.video import r3d_18

            self.backbone = r3d_18(pretrained=want_torchvision_weights)
        else:
            weights = R3D_18_Weights.DEFAULT if want_torchvision_weights else None
            self.backbone = r3d_18(weights=weights)

        if pretrained_path:
            self._load_pretrained(pretrained_path)

        # Replace the classification layer; reuse the backbone's own feature
        # width (512 for r3d_18) instead of hard-coding it.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _strip_module_prefix(key):
        """Remove leading ``module.`` prefixes (as added by DataParallel).

        Only the prefix is stripped; a ``module.`` substring elsewhere in the
        key (e.g. ``submodule.weight``) is left untouched.
        """
        prefix = 'module.'
        while key.startswith(prefix):
            key = key[len(prefix):]
        return key

    def _load_pretrained(self, pretrained_path):
        """Best-effort load of checkpoint tensors into ``self.backbone``.

        Never raises: every failure is reported on stdout and the backbone
        keeps whatever weights it already had.

        Args:
            pretrained_path: Path to a checkpoint that is either a state dict
                or a dict holding one under the ``'state_dict'`` key.

        Returns:
            The number of checkpoint tensors whose names matched a backbone
            parameter or buffer (0 on any failure).
        """
        if not Path(pretrained_path).exists():
            print(
                f"  Warning: pretrained weights not found at {pretrained_path}; "
                "backbone weights left unchanged"
            )
            return 0

        try:
            # NOTE(security): weights_only=False unpickles arbitrary objects.
            # Only point this at checkpoints from a trusted source.
            state = torch.load(pretrained_path, map_location='cpu', weights_only=False)
            if 'state_dict' in state:
                state = state['state_dict']
            clean_state = {
                self._strip_module_prefix(key): value for key, value in state.items()
            }
            result = self.backbone.load_state_dict(clean_state, strict=False)
        except Exception as exc:
            # Deliberate best-effort: a bad checkpoint must not abort model
            # construction, but it must be visible.
            print(f"  Warning: could not load MedicalNet weights: {exc}")
            return 0

        # With strict=False, names absent from the backbone are skipped
        # silently; count what was actually used so the message is honest.
        matched = len(clean_state) - len(result.unexpected_keys)
        if matched == 0:
            print(
                f"  Warning: no tensors in {pretrained_path} matched the backbone; "
                "weights left unchanged"
            )
        else:
            print(
                f"  Loaded MedicalNet weights from {pretrained_path} "
                f"({matched}/{len(clean_state)} tensors matched)"
            )
        return matched

    def forward(self, x):
        """Forward pass on (B, 1, 32, 32, 32) patches.

        The single input channel is tiled three times along dim 1 because the
        backbone's stem expects 3-channel input.

        Returns:
            The backbone output for the tiled input.
        """
        x = x.repeat(1, 3, 1, 1, 1)
        return self.backbone(x)