chexvision-demo / tests / test_models.py
arudaev's picture
Fix CI, unblock Kaggle training dispatch, upgrade demo app for HF Spaces
0e8457d
"""Unit tests for CheXVision model architectures."""
import pytest
import torch
from src.models.densenet_transfer import CheXVisionDenseNet
from src.models.scratch_cnn import CheXVisionScratch
@pytest.fixture
def dummy_input() -> torch.Tensor:
    """Batch of 2 fake 224x224 RGB images.

    Values are drawn from a private, fixed-seed generator. Every test therefore
    sees exactly the same tensor, so failures are reproducible. The global
    torch RNG state is not touched.

    Returns:
        Float tensor of shape ``(2, 3, 224, 224)`` with standard-normal values.
    """
    generator = torch.Generator().manual_seed(0)
    return torch.randn(2, 3, 224, 224, generator=generator)
class TestCheXVisionScratch:
    """Shape and gradient checks for ``CheXVisionScratch``."""

    def test_output_shapes(self, dummy_input: torch.Tensor) -> None:
        """Both heads are returned with (batch, 14) and (batch, 1) logits."""
        batch_size = dummy_input.shape[0]
        model = CheXVisionScratch(in_channels=3, num_classes=14)
        # Shape checks need no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            outputs = model(dummy_input)
        assert "multilabel_logits" in outputs
        assert "binary_logits" in outputs
        assert outputs["multilabel_logits"].shape == (batch_size, 14)
        assert outputs["binary_logits"].shape == (batch_size, 1)

    def test_custom_block_config(self, dummy_input: torch.Tensor) -> None:
        """A non-default block/filter configuration still yields correct head shapes."""
        batch_size = dummy_input.shape[0]
        model = CheXVisionScratch(
            in_channels=3,
            num_classes=14,
            block_config=(1, 1, 1, 1),
            filter_sizes=(32, 64, 128, 256),
        )
        with torch.no_grad():
            outputs = model(dummy_input)
        assert outputs["multilabel_logits"].shape == (batch_size, 14)
        # Also cover the binary head under the custom configuration.
        assert outputs["binary_logits"].shape == (batch_size, 1)

    def test_gradient_flow(self, dummy_input: torch.Tensor) -> None:
        """Every trainable parameter receives a finite gradient from both heads."""
        model = CheXVisionScratch()
        outputs = model(dummy_input)
        loss = outputs["multilabel_logits"].sum() + outputs["binary_logits"].sum()
        loss.backward()
        trainable = [
            (name, param)
            for name, param in model.named_parameters()
            if param.requires_grad
        ]
        # Guard against a vacuous pass when nothing is trainable.
        assert trainable, "Model has no trainable parameters"
        missing = [name for name, param in trainable if param.grad is None]
        assert not missing, f"No gradient for {missing}"
        non_finite = [
            name for name, param in trainable if not torch.isfinite(param.grad).all()
        ]
        assert not non_finite, f"Non-finite gradient for {non_finite}"
class TestCheXVisionDenseNet:
    """Shape, freeze/unfreeze and gradient checks for ``CheXVisionDenseNet``."""

    def test_output_shapes(self, dummy_input: torch.Tensor) -> None:
        """Both heads are returned with (batch, 14) and (batch, 1) logits."""
        batch_size = dummy_input.shape[0]
        # pretrained=False presumably skips fetching pretrained weights — keeps CI offline.
        model = CheXVisionDenseNet(num_classes=14, pretrained=False)
        # Shape checks need no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            outputs = model(dummy_input)
        assert "multilabel_logits" in outputs
        assert "binary_logits" in outputs
        assert outputs["multilabel_logits"].shape == (batch_size, 14)
        assert outputs["binary_logits"].shape == (batch_size, 1)

    def test_freeze_unfreeze(self) -> None:
        """freeze_backbone=True freezes only the backbone; unfreeze_backbone() reverts it."""
        model = CheXVisionDenseNet(pretrained=False, freeze_backbone=True)
        backbone_params = list(model.backbone.parameters())
        head_params = list(model.multilabel_head.parameters())
        # Non-empty guards: an empty module would make the checks below pass vacuously.
        assert backbone_params, "Backbone has no parameters"
        assert head_params, "Multilabel head has no parameters"
        # Check backbone is frozen
        assert not any(param.requires_grad for param in backbone_params)
        # Check heads are trainable
        assert all(param.requires_grad for param in head_params)
        # Unfreeze; re-read the parameters in case the backbone modules were replaced.
        model.unfreeze_backbone()
        assert all(param.requires_grad for param in model.backbone.parameters())

    def test_gradient_flow(self, dummy_input: torch.Tensor) -> None:
        """Every trainable parameter receives a finite gradient from both heads."""
        model = CheXVisionDenseNet(pretrained=False)
        outputs = model(dummy_input)
        loss = outputs["multilabel_logits"].sum() + outputs["binary_logits"].sum()
        loss.backward()
        trainable = [
            (name, param)
            for name, param in model.named_parameters()
            if param.requires_grad
        ]
        # Guard against a vacuous pass when nothing is trainable.
        assert trainable, "Model has no trainable parameters"
        missing = [name for name, param in trainable if param.grad is None]
        assert not missing, f"No gradient for {missing}"
        non_finite = [
            name for name, param in trainable if not torch.isfinite(param.grad).all()
        ]
        assert not non_finite, f"Non-finite gradient for {non_finite}"