Spaces:
Running
Running
| """Unit tests for CheXVision model architectures.""" | |
| import pytest | |
| import torch | |
| from src.models.densenet_transfer import CheXVisionDenseNet | |
| from src.models.scratch_cnn import CheXVisionScratch | |
@pytest.fixture
def dummy_input() -> torch.Tensor:
    """Batch of 2 fake 224x224 RGB images, shape ``(2, 3, 224, 224)``.

    Registered as a pytest fixture so test methods can request it by
    parameter name. Values come from a locally seeded generator, so every
    run sees the same tensor without mutating the global torch RNG state.
    """
    generator = torch.Generator().manual_seed(0)
    return torch.randn(2, 3, 224, 224, generator=generator)
class TestCheXVisionScratch:
    """Tests for the from-scratch CNN: output contract, configurability, backprop."""

    def test_output_shapes(self, dummy_input: torch.Tensor) -> None:
        """Default-config model returns both heads with the expected logit shapes."""
        model = CheXVisionScratch(in_channels=3, num_classes=14)
        batch_size = dummy_input.shape[0]
        # Forward-only check: no need to build an autograd graph.
        with torch.no_grad():
            outputs = model(dummy_input)
        assert "multilabel_logits" in outputs
        assert "binary_logits" in outputs
        assert outputs["multilabel_logits"].shape == (batch_size, 14)
        assert outputs["binary_logits"].shape == (batch_size, 1)

    def test_custom_block_config(self, dummy_input: torch.Tensor) -> None:
        """A shallower/narrower configuration still yields both heads correctly shaped."""
        model = CheXVisionScratch(
            in_channels=3,
            num_classes=14,
            block_config=(1, 1, 1, 1),
            filter_sizes=(32, 64, 128, 256),
        )
        batch_size = dummy_input.shape[0]
        with torch.no_grad():
            outputs = model(dummy_input)
        # Check both heads: a config change must not break the binary branch either.
        assert outputs["multilabel_logits"].shape == (batch_size, 14)
        assert outputs["binary_logits"].shape == (batch_size, 1)

    def test_gradient_flow(self, dummy_input: torch.Tensor) -> None:
        """Backprop from both heads reaches every trainable parameter with finite values."""
        model = CheXVisionScratch()
        outputs = model(dummy_input)
        loss = outputs["multilabel_logits"].sum() + outputs["binary_logits"].sum()
        loss.backward()
        trainable = [
            (name, param)
            for name, param in model.named_parameters()
            if param.requires_grad
        ]
        # Guard against a vacuous pass when no parameter is trainable.
        assert trainable, "Model has no trainable parameters"
        for name, param in trainable:
            assert param.grad is not None, f"No gradient for {name}"
            assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}"
class TestCheXVisionDenseNet:
    """Tests for the DenseNet transfer model: output contract, freezing, backprop."""

    def test_output_shapes(self, dummy_input: torch.Tensor) -> None:
        """Randomly initialised model returns both heads with the expected logit shapes."""
        model = CheXVisionDenseNet(num_classes=14, pretrained=False)
        batch_size = dummy_input.shape[0]
        # Forward-only check: no need to build an autograd graph.
        with torch.no_grad():
            outputs = model(dummy_input)
        assert "multilabel_logits" in outputs
        assert "binary_logits" in outputs
        assert outputs["multilabel_logits"].shape == (batch_size, 14)
        assert outputs["binary_logits"].shape == (batch_size, 1)

    def test_freeze_unfreeze(self) -> None:
        """``freeze_backbone=True`` freezes only the backbone; ``unfreeze_backbone`` restores it."""
        model = CheXVisionDenseNet(pretrained=False, freeze_backbone=True)
        backbone_params = list(model.backbone.named_parameters())
        head_params = list(model.multilabel_head.named_parameters())
        # Non-empty guards: the checks below would pass vacuously on empty lists.
        assert backbone_params, "Backbone exposes no parameters"
        assert head_params, "Multilabel head exposes no parameters"

        # Check backbone is frozen
        not_frozen = [name for name, p in backbone_params if p.requires_grad]
        assert not not_frozen, f"Backbone params still trainable: {not_frozen[:5]}"
        # Check heads are trainable
        frozen_head = [name for name, p in head_params if not p.requires_grad]
        assert not frozen_head, f"Head params unexpectedly frozen: {frozen_head[:5]}"

        # Unfreeze
        model.unfreeze_backbone()
        still_frozen = [
            name
            for name, p in model.backbone.named_parameters()
            if not p.requires_grad
        ]
        assert not still_frozen, f"Backbone params still frozen: {still_frozen[:5]}"
        # Unfreezing the backbone must not re-freeze the head.
        assert all(p.requires_grad for p in model.multilabel_head.parameters())

    def test_gradient_flow(self, dummy_input: torch.Tensor) -> None:
        """Backprop from both heads reaches every trainable parameter with finite values."""
        model = CheXVisionDenseNet(pretrained=False)
        outputs = model(dummy_input)
        loss = outputs["multilabel_logits"].sum() + outputs["binary_logits"].sum()
        loss.backward()
        trainable = [
            (name, param)
            for name, param in model.named_parameters()
            if param.requires_grad
        ]
        # Guard against a vacuous pass when no parameter is trainable.
        assert trainable, "Model has no trainable parameters"
        for name, param in trainable:
            assert param.grad is not None, f"No gradient for {name}"
            assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}"