| |
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torchxrayvision as xrv
|
| | from torchvision import models
|
| | import timm
|
| |
|
class DenseNetTB(nn.Module):
    """TB classifier built on the torchxrayvision DenseNet.

    The backbone's classifier layer is swapped for a linear layer with a
    single output, so the network emits one value per sample.

    Args:
        pretrained: When true, the backbone is created with the
            ``"densenet121-res224-all"`` weights; otherwise it is created
            with ``weights=None``.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        if not pretrained:
            backbone = xrv.models.DenseNet(weights=None)
        else:
            backbone = xrv.models.DenseNet(weights="densenet121-res224-all")
            # NOTE(review): presumably this switches off torchxrayvision's
            # operating-threshold post-processing of the outputs -- confirm
            # against the torchxrayvision documentation.
            backbone.op_threshs = None

        # Replace the stock head with a one-output head of matching width.
        head_width = backbone.classifier.in_features
        backbone.classifier = nn.Linear(head_width, 1)

        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output unchanged."""
        output = self.model(x)
        return output
|
| |
|
class EfficientNetTB(nn.Module):
    """TB classifier built on timm's ``efficientnet_b3``.

    The backbone is created for single-channel input (``in_chans=1``) with a
    single output (``num_classes=1``).

    Args:
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        backbone_options = {
            "pretrained": pretrained,
            "num_classes": 1,
            "in_chans": 1,
        }
        self.model = timm.create_model('efficientnet_b3', **backbone_options)

    def forward(self, x):
        """Return the backbone output with its last dimension squeezed.

        ``squeeze(-1)`` only removes that dimension when it has size 1.
        """
        raw_output = self.model(x)
        return raw_output.squeeze(-1)
|
| |
|
class ResNetTB(nn.Module):
    """TB classifier built on torchvision's ResNet50.

    The stock network is adapted in two ways: the first convolution is
    replaced by a single-channel one (grayscale input), and the final fully
    connected layer is replaced by a one-output linear layer.

    Args:
        pretrained: When true, ImageNet weights are loaded and the new
            single-channel stem is initialised from them (see below);
            otherwise every layer keeps its default random initialisation.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # Newer torchvision releases replace the deprecated ``pretrained=``
        # keyword with a ``weights=`` enum; IMAGENET1K_V1 is the checkpoint
        # the legacy ``pretrained=True`` resolved to. Fall back to the old
        # keyword on releases that predate the enum.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            backbone = models.resnet50(pretrained=pretrained)
        else:
            backbone = models.resnet50(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )

        rgb_stem = backbone.conv1
        gray_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Do not throw the pretrained stem away: summing the RGB filters
            # over the channel axis gives a 1-channel filter that responds to
            # a grayscale image exactly as the original stem responds to that
            # image replicated on three channels.
            with torch.no_grad():
                gray_stem.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True))
        backbone.conv1 = gray_stem

        # One-output head with the same input width as the stock classifier.
        backbone.fc = nn.Linear(backbone.fc.in_features, 1)

        self.model = backbone

    def forward(self, x):
        """Return the backbone output with its last dimension squeezed.

        ``squeeze(-1)`` only removes that dimension when it has size 1.
        """
        return self.model(x).squeeze(-1)
|
| |
|
class TBEnsemble(nn.Module):
    """Weighted-average ensemble of the DenseNet, EfficientNet and ResNet TB models.

    Every backbone yields one logit per sample. The logits are passed through
    a sigmoid and blended with fixed, non-trainable mixing weights.

    Args:
        weights: Three mixing coefficients in the order (densenet,
            efficientnet, resnet). Defaults to ``[0.4, 0.35, 0.25]``. The
            values are used as given; they are not normalised.
        use_mc_dropout: When true, a dropout layer is created so that
            ``predict_with_uncertainty`` can be used.
        dropout_rate: Drop probability of that dropout layer.

    Raises:
        ValueError: If ``weights`` does not hold exactly three values.
    """

    def __init__(self, weights=None, use_mc_dropout=False, dropout_rate=0.3):
        super().__init__()

        # Validate before building the (expensive) backbones.
        if weights is None:
            weights = [0.4, 0.35, 0.25]
        mixing = (
            torch.as_tensor(weights, dtype=torch.float32).detach().flatten().clone()
        )
        if mixing.numel() != 3:
            raise ValueError(
                f"weights must contain exactly 3 values, got {mixing.numel()}"
            )

        self.densenet = DenseNetTB(pretrained=True)
        self.efficientnet = EfficientNetTB(pretrained=True)
        self.resnet = ResNetTB(pretrained=True)

        # A buffer follows the module through .to()/.half(), unlike a plain
        # tensor attribute. It is non-persistent so the state_dict layout is
        # unchanged and existing checkpoints still load strictly.
        self.register_buffer("weights", mixing, persistent=False)

        self.use_mc_dropout = use_mc_dropout
        if use_mc_dropout:
            self.dropout = nn.Dropout(dropout_rate)

    def _backbone_logits(self, x):
        """Return the three backbones' logits, each flattened to ``(batch,)``."""
        # reshape(-1) instead of a bare squeeze(): squeeze() would also drop
        # the batch dimension when the batch holds a single sample.
        return tuple(
            net(x).reshape(-1)
            for net in (self.densenet, self.efficientnet, self.resnet)
        )

    def _apply_dropout(self, logits):
        """Apply the MC dropout layer to each logit tensor, in backbone order."""
        return tuple(self.dropout(logit) for logit in logits)

    def _weighted_probability(self, logits):
        """Sigmoid each logit tensor and blend the results with ``self.weights``."""
        prob_densenet, prob_efficientnet, prob_resnet = (
            torch.sigmoid(logit) for logit in logits
        )
        return (
            self.weights[0] * prob_densenet
            + self.weights[1] * prob_efficientnet
            + self.weights[2] * prob_resnet
        )

    def forward(self, x):
        """Return the blended probability for each sample, shape ``(batch,)``."""
        return self._weighted_probability(self._backbone_logits(x))

    def _forward_with_dropout(self, x):
        """Forward pass with dropout on logits for MC uncertainty estimation."""
        logits = self._apply_dropout(self._backbone_logits(x))
        return self._weighted_probability(logits)

    def predict_with_uncertainty(self, x, n_samples=20):
        """Estimate the prediction mean and spread with MC dropout on the logits.

        The backbones run once in eval mode under ``torch.no_grad()``; only
        the dropout layer on their logits is stochastic, so it alone is
        re-sampled ``n_samples`` times. The train/eval mode of every
        submodule is restored before returning.

        Args:
            x: Input batch, forwarded to each backbone.
            n_samples: Number of stochastic passes (at least 1).

        Returns:
            ``(mean, std)`` over the passes, each of shape ``(batch,)``.
            With ``n_samples=1`` the default (unbiased) standard deviation
            is NaN.

        Raises:
            RuntimeError: If the ensemble was built with
                ``use_mc_dropout=False``.
            ValueError: If ``n_samples`` is smaller than 1.
        """
        if not self.use_mc_dropout:
            raise RuntimeError(
                "predict_with_uncertainty requires the ensemble to be built "
                "with use_mc_dropout=True"
            )
        if n_samples < 1:
            raise ValueError(f"n_samples must be at least 1, got {n_samples}")

        # Remember each module's mode so this call leaves no lasting side
        # effect on a model that is in the middle of training.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        self.dropout.train()
        try:
            with torch.no_grad():
                # Eval-mode backbone logits are identical on every pass, so
                # compute them once and re-sample only the dropout masks.
                logits = self._backbone_logits(x)
                draws = torch.stack([
                    self._weighted_probability(self._apply_dropout(logits))
                    for _ in range(n_samples)
                ])
        finally:
            for module, mode in previous_modes:
                module.training = mode

        return draws.mean(dim=0), draws.std(dim=0)
|
| |
|
def load_ensemble(checkpoint_path=None, device='cuda'):
    """Build a ``TBEnsemble`` with MC dropout enabled and return it in eval mode.

    Args:
        checkpoint_path: Optional path to a saved ``state_dict``. When it is
            falsy (``None`` or an empty string) no checkpoint is loaded.
        device: Target device (string or ``torch.device``). If a CUDA device
            is requested but CUDA is unavailable, a ``RuntimeWarning`` is
            issued and the model is placed on the CPU instead.

    Returns:
        The ensemble, moved to the resolved device and switched to eval mode.
    """
    import warnings

    target = torch.device(device)
    if target.type == "cuda" and not torch.cuda.is_available():
        warnings.warn(
            "CUDA was requested but is not available; "
            "loading the ensemble on the CPU instead.",
            RuntimeWarning,
            stacklevel=2,
        )
        target = torch.device("cpu")

    model = TBEnsemble(use_mc_dropout=True)

    if checkpoint_path:
        # SECURITY: torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources; on
        # PyTorch versions that support it, consider weights_only=True.
        # map_location uses the resolved device so that a checkpoint saved
        # on a GPU can still be opened on a CPU-only host.
        state = torch.load(checkpoint_path, map_location=target)
        model.load_state_dict(state)

    model = model.to(target)
    model.eval()
    return model
|
| |
|
if __name__ == "__main__":
    # Smoke test: build the ensemble, push a random batch through the plain
    # forward pass and through the MC-dropout path, and print the results.
    ensemble = TBEnsemble(use_mc_dropout=True)

    # Two random single-channel 224x224 inputs.
    dummy_batch = torch.randn(2, 1, 224, 224)

    prediction = ensemble(dummy_batch)
    print(f"Output shape: {prediction.shape}")
    print(f"Output: {prediction}")

    mc_mean, mc_std = ensemble.predict_with_uncertainty(dummy_batch, n_samples=10)
    print(f"\nMean prediction: {mc_mean}")
    print(f"Std prediction: {mc_std}")

    print("\n✅ Ensemble model test passed")
|
| |
|