File size: 1,573 Bytes
9916246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

class CliniScanClassifier(nn.Module):
    """ResNet-50 based multi-label abnormality classifier for AI-CliniScan.

    The torchvision ResNet-50 backbone is kept intact except for its final
    ``fc`` layer, which is replaced by a small MLP head that emits one logit
    per class.
    """

    def __init__(
        self,
        num_classes: int = 15,
        freeze_features: bool = True,
        *,
        pretrained: bool = True,
        hidden_dim: int = 512,
        dropout: float = 0.3,
    ):
        """Build the backbone and the classification head.

        Args:
            num_classes: Number of output logits (one per abnormality label).
            freeze_features: If True, freeze every backbone parameter except
                those in ``layer4``. The new head is always trainable.
            pretrained: If True (the original behaviour), initialise the
                backbone with ``ResNet50_Weights.IMAGENET1K_V2``, which may
                trigger a download. Pass False to get a randomly initialised
                backbone, e.g. when a saved checkpoint will be loaded anyway.
            hidden_dim: Width of the hidden layer of the classification head.
            dropout: Dropout probability applied before each head ``Linear``.
        """
        super().__init__()
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.backbone = resnet50(weights=weights)

        if freeze_features:
            # Freeze the whole backbone, then re-enable the last residual
            # stage so that only layer4 (plus the new head) is fine-tuned.
            # NOTE(review): requires_grad=False does not stop BatchNorm layers
            # from updating their running statistics while the model is in
            # train() mode; put them in eval() if fully frozen features are
            # wanted.
            for param in self.backbone.parameters():
                param.requires_grad = False
            for param in self.backbone.layer4.parameters():
                param.requires_grad = True

        in_features = self.backbone.fc.in_features
        # Replace the fully connected layer for multi-label classification.
        # These modules are created after the freeze above, so they remain
        # trainable.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of images.

        The head ends in ``nn.Linear``, so no sigmoid/softmax is applied here.
        Presumably the training loss applies the activation (e.g.
        ``BCEWithLogitsLoss``); confirm with the caller.
        """
        return self.backbone(x)

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the layer4 feature map, used for Grad-CAM visualization.

        Runs the backbone stem and residual stages up to and including
        ``layer4``. Average pooling and the ``fc`` head are not applied, so
        the result keeps its spatial dimensions.
        """
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        features = self.backbone.layer4(x)
        return features