Divyanshu Tak
Add BrainIAC IDH Classification app with Vision Transformer model
65bee5d
raw
history blame
2.61 kB
import torch
import torch.nn as nn
from monai.networks.nets import ViT
import os
class ViTBackboneNet(nn.Module):
    """3D ViT feature extractor, optionally initialised from a SimCLR checkpoint.

    Wraps a MONAI ``ViT`` configured for single-channel 96x96x96 inputs with
    16x16x16 patches, hidden size 768, 12 layers and 12 heads, and returns one
    hidden-size vector per input volume (see ``forward``).
    """

    def __init__(self, simclr_ckpt_path: str):
        """Build the ViT backbone and load SimCLR weights when available.

        Args:
            simclr_ckpt_path: Path to a SimCLR checkpoint. If it is falsy
                (e.g. ``""`` or ``None``) or the path does not exist, the backbone
                keeps its random initialisation and a warning is printed.
        """
        super().__init__()
        self.backbone = ViT(
            in_channels=1,
            img_size=(96, 96, 96),
            patch_size=(16, 16, 16),
            hidden_size=768,
            mlp_dim=3072,
            num_layers=12,
            num_heads=12,
            save_attn=True,
        )
        if not (simclr_ckpt_path and os.path.exists(simclr_ckpt_path)):
            print("Warning: SimCLR checkpoint not found or not provided. Using randomly initialized backbone.")
            return

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
        # which can execute code. Only pass checkpoints from a trusted source.
        ckpt = torch.load(simclr_ckpt_path, map_location="cpu", weights_only=False)
        # Weights may be nested under "state_dict"; otherwise the loaded object
        # itself is treated as the state dict.
        state_dict = ckpt.get("state_dict", ckpt)
        backbone_state_dict = self._extract_backbone_state_dict(state_dict)
        if not backbone_state_dict:
            # Without this check an unrelated checkpoint would load nothing
            # (strict=False) yet still report a successful load.
            print(
                f"Warning: no 'backbone.'-prefixed keys found in SimCLR checkpoint "
                f"{simclr_ckpt_path}. Using randomly initialized backbone."
            )
            return
        missing, unexpected = self.backbone.load_state_dict(backbone_state_dict, strict=False)
        print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}")

    @staticmethod
    def _extract_backbone_state_dict(state_dict, prefix: str = "backbone.") -> dict:
        """Return the entries of *state_dict* whose key starts with *prefix*, prefix stripped."""
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the final-layer embedding at token position 0 for each sample.

        Args:
            x: Batch of volumes; presumably (batch, 1, 96, 96, 96) to match the
                ViT configuration above. TODO: confirm against callers.

        Returns:
            Tensor of shape (batch, hidden_size), i.e. presumably (batch, 768).
        """
        # The backbone output is indexed with [0]. MONAI's ViT presumably returns
        # (tokens, per-layer hidden states); confirm if the backbone is swapped.
        tokens = self.backbone(x)[0]
        # NOTE(review): MONAI's ViT only prepends a CLS token when built with
        # classification=True, which is not set above, so position 0 may be the
        # first patch token rather than a CLS token. Kept as-is, because changing
        # it would alter the features any already-trained classifier expects.
        return tokens[:, 0]
class Classifier(nn.Module):
    """Linear prediction head.

    Maps ``d_model``-dimensional feature vectors to ``num_classes`` raw outputs;
    no activation is applied here.
    """

    def __init__(self, d_model: int = 768, num_classes: int = 1):
        super(Classifier, self).__init__()
        # The attribute name "fc" determines the state-dict keys, so keep it stable.
        self.fc = nn.Linear(in_features=d_model, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to *x* (features on the last dimension)."""
        logits = self.fc(x)
        return logits
class SingleScanModelBP(nn.Module):
    """Shared-backbone model that averages per-scan features before classifying.

    Every slice along dimension 1 of the input is encoded by the same backbone.
    The resulting feature vectors are mean-pooled, passed through dropout
    (p=0.2) and fed to the classifier.
    """

    def __init__(self, backbone: nn.Module, classifier: nn.Module):
        super().__init__()
        # Submodules are registered in the order backbone, classifier, dropout.
        self.backbone = backbone
        self.classifier = classifier
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each scan, average the features and classify.

        Args:
            x: Assumed (batch, n_scans, C, D, H, W). Callers are expected to pass
                2 scans (TODO confirm); the code itself splits whatever number
                of scans is present along dim 1.

        Returns:
            The classifier output for the mean-pooled, dropout-regularised features.
        """
        per_scan = [
            self.backbone(chunk.squeeze(1))
            for chunk in x.split(1, dim=1)
        ]
        pooled = torch.stack(per_scan, dim=1).mean(dim=1)
        return self.classifier(self.dropout(pooled))