File size: 2,608 Bytes
65bee5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
from monai.networks.nets import ViT
import os


class ViTBackboneNet(nn.Module):
    """MONAI 3D ViT encoder returning one embedding vector per input volume.

    The encoder is configured for single-channel 96x96x96 inputs split into
    16x16x16 patches (hidden size 768, 12 layers, 12 heads). When
    ``simclr_ckpt_path`` names an existing file, the encoder is initialised from
    the entries of that checkpoint whose keys start with ``backbone.``.

    Args:
        simclr_ckpt_path: Path to a SimCLR checkpoint. A falsy value or a path
            that does not exist leaves the encoder randomly initialised (a
            warning is printed).
    """

    # Key prefix under which the SimCLR model stored its encoder weights.
    _BACKBONE_PREFIX = "backbone."

    def __init__(self, simclr_ckpt_path: str):
        super().__init__()
        self.backbone = ViT(
            in_channels=1,
            img_size=(96, 96, 96),
            patch_size=(16, 16, 16),
            hidden_size=768,
            mlp_dim=3072,
            num_layers=12,
            num_heads=12,
            save_attn=True,
        )
        # Load pretrained weights from SimCLR checkpoint if provided
        if simclr_ckpt_path and os.path.exists(simclr_ckpt_path):
            self._load_simclr_weights(simclr_ckpt_path)
        else:
            print("Warning: SimCLR checkpoint not found or not provided. Using randomly initialized backbone.")

    @staticmethod
    def _strip_prefix(state_dict, prefix: str) -> dict:
        """Return the entries whose key starts with ``prefix``, with the prefix removed."""
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def _load_simclr_weights(self, ckpt_path: str) -> None:
        """Copy the ``backbone.*`` tensors of the checkpoint at ``ckpt_path`` into ``self.backbone``.

        Loading is non-strict, so missing/unexpected keys are only counted and
        printed. If no key carries the ``backbone.`` prefix, nothing is loaded
        and a warning is printed instead of a misleading success message.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects and can
        # execute code; only point this at checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # Weights may be nested under "state_dict"; otherwise the loaded object
        # itself is treated as the state dict.
        state_dict = ckpt.get("state_dict", ckpt)
        backbone_state_dict = self._strip_prefix(state_dict, self._BACKBONE_PREFIX)
        if not backbone_state_dict:
            print(
                f"Warning: no '{self._BACKBONE_PREFIX}*' keys found in {ckpt_path}. "
                "Using randomly initialized backbone."
            )
            return
        missing, unexpected = self.backbone.load_state_dict(backbone_state_dict, strict=False)
        print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return token 0 of the backbone's final token sequence.

        NOTE(review): the ViT is constructed without ``classification=True``. In
        MONAI that presumably means no learnable CLS token is prepended, so index
        0 is the first *patch* token rather than a CLS summary. Confirm against
        the installed MONAI version. Kept unchanged to preserve existing outputs.
        """
        # The backbone output is indexed at [0]; assumed (per MONAI's ViT) to be
        # the (batch, tokens, hidden) sequence, followed by per-layer states.
        tokens = self.backbone(x)[0]
        return tokens[:, 0]


class Classifier(nn.Module):
    """Single linear head mapping a ``d_model``-dim feature to ``num_classes`` outputs.

    No activation is applied, so the outputs are raw (unnormalised) scores. With
    the default ``num_classes=1`` that is one score per sample.
    """

    def __init__(self, d_model: int = 768, num_classes: int = 1):
        super().__init__()
        # The attribute name ``fc`` fixes the state_dict keys ("fc.weight", "fc.bias").
        self.fc = nn.Linear(in_features=d_model, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``x`` from ``d_model`` to ``num_classes``."""
        logits = self.fc(x)
        return logits


class SingleScanModelBP(nn.Module):
    """Classify a stack of scans by averaging per-scan backbone features.

    The same ``backbone`` encodes every slice ``x[:, i]`` along dim 1. The
    resulting feature tensors are averaged across scans, passed through dropout,
    and fed to ``classifier``.

    Args:
        backbone: Module applied to one scan batch at a time.
        classifier: Module applied to the averaged, dropped-out features.
        dropout_p: Dropout probability on the averaged features. Defaults to
            0.2, the value previously hard-coded here.
    """

    def __init__(self, backbone: nn.Module, classifier: nn.Module, dropout_p: float = 0.2):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``classifier(dropout(mean_i backbone(x[:, i])))``.

        ``x`` is expected as ``(batch_size, num_scans, C, D, H, W)``; the
        original setup used two scans. Each scan is encoded with a separate
        backbone call rather than being folded into the batch dimension.
        """
        # unbind(dim=1) yields the same per-scan tensors as split(1, dim=1) + squeeze(1).
        per_scan_features = [self.backbone(scan) for scan in x.unbind(dim=1)]
        merged_features = torch.stack(per_scan_features, dim=1).mean(dim=1)
        return self.classifier(self.dropout(merged_features))