File size: 4,813 Bytes
853f08d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python3
"""
example_usage.py

Demonstrates how to use the BYOL Mammogram model for feature extraction
and classification tasks.
"""

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
from pathlib import Path

# Import the BYOL model classes
from train_byol_mammo import MammogramBYOL


def load_byol_model(checkpoint_path: str, device: torch.device) -> "MammogramBYOL":
    """Load the pre-trained BYOL model for feature extraction.

    Rebuilds the network (a ResNet50 with its final layer removed, wrapped in
    ``MammogramBYOL``), restores the weights stored under the checkpoint's
    ``'model_state_dict'`` key, and switches the model to eval mode.

    Args:
        checkpoint_path: Path to the checkpoint file written during training.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The ``MammogramBYOL`` instance in eval mode, located on ``device``.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"📥 Loading BYOL model from: {checkpoint_path}")

    # Randomly initialised ResNet50; dropping the last child removes the
    # classification layer so the backbone emits pooled features.
    resnet = models.resnet50(weights=None)
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    # These dimensions must match the ones used at training time, otherwise
    # load_state_dict() below fails on shape mismatches.
    model = MammogramBYOL(
        backbone=backbone,
        input_dim=2048,      # ResNet50 feature dimension
        hidden_dim=4096,     # BYOL projection head hidden dim
        proj_dim=256         # BYOL projection dimension
    ).to(device)

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # 'loss' is optional metadata; only apply numeric formatting when the
    # stored value really is numeric, so a missing key cannot crash loading.
    loss = checkpoint.get('loss', 'Unknown')
    try:
        loss_text = f"{float(loss):.4f}"
    except (TypeError, ValueError):
        loss_text = str(loss)

    print("✅ Model loaded successfully!")
    print(f"   Epoch: {checkpoint.get('epoch', 'Unknown')}")
    print(f"   Final loss: {loss_text}")

    return model


def create_inference_transform(tile_size: int = 512):
    """Build the preprocessing pipeline applied to each image at inference.

    Args:
        tile_size: Side length, in pixels, that every image is resized to
            (the output is square: ``tile_size`` x ``tile_size``).

    Returns:
        A ``transforms.Compose`` that resizes, converts to a tensor, and
        normalizes each of the three channels with mean 0.5 and std 0.5.
    """
    target_size = (tile_size, tile_size)
    pipeline = [
        transforms.Resize(target_size, antialias=True),
        transforms.ToTensor(),
        # Same statistics for all three channels.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return transforms.Compose(pipeline)


def extract_features(model, image_tensor, device):
    """Run ``model.get_features`` on a tensor and return a NumPy array.

    Args:
        model: Object exposing ``get_features(tensor)``.
        image_tensor: Input tensor; it is moved to ``device`` before the
            forward pass. Presumably shaped (batch, 3, H, W) - confirm
            against the model.
        device: Device on which the forward pass runs.

    Returns:
        The model's features, copied to the CPU and converted to a NumPy
        array.
    """
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        on_device = image_tensor.to(device)
        embedding = model.get_features(on_device)
    host_copy = embedding.cpu()
    return host_copy.numpy()


def main():
    """Demonstrate model usage with three small examples.

    Loads ``mammogram_byol_best.pth`` (resolved relative to the current
    working directory) and then:

    1. extracts features from a single tile,
    2. extracts features from a batch of four tiles,
    3. prints the cosine similarity between the first two batch features.

    Random noise stands in for real mammogram tiles, so the printed numbers
    only illustrate shapes and API usage.
    """
    tile_size = 512
    batch_size = 4

    def random_rgb_tile():
        """Return a random-noise grayscale tile converted to RGB."""
        # randint's upper bound is exclusive: 256 covers the full uint8 range.
        pixels = np.random.randint(0, 256, (tile_size, tile_size), dtype=np.uint8)
        return Image.fromarray(pixels).convert('RGB')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️  Using device: {device}")

    # Load the pre-trained BYOL model
    model = load_byol_model("mammogram_byol_best.pth", device)

    # Create preprocessing transform
    transform = create_inference_transform(tile_size=tile_size)

    # Example 1: Feature extraction from a single image
    print("\n📊 Example 1: Feature Extraction")
    print("-" * 40)

    # Dummy mammogram tile (replace with actual image loading)
    dummy_image = random_rgb_tile()

    # Preprocess and add the batch dimension the model call expects
    image_tensor = transform(dummy_image).unsqueeze(0)

    features = extract_features(model, image_tensor, device)

    print(f"✅ Input shape: {image_tensor.shape}")
    print(f"✅ Feature shape: {features.shape}")
    print(f"✅ Feature vector (first 10 values): {features[0][:10]}")

    # Example 2: Batch processing multiple images
    print("\n📊 Example 2: Batch Feature Extraction")
    print("-" * 40)

    # Stack individually preprocessed tiles into one batch tensor
    dummy_batch = torch.stack(
        [transform(random_rgb_tile()) for _ in range(batch_size)]
    )

    # Extract features for the entire batch in one forward pass
    batch_features = extract_features(model, dummy_batch, device)

    print(f"✅ Batch input shape: {dummy_batch.shape}")
    print(f"✅ Batch features shape: {batch_features.shape}")
    print(f"✅ Features per image: {batch_features.shape[1]} dimensions")

    # Example 3: Similarity computation
    print("\n📊 Example 3: Feature Similarity")
    print("-" * 40)

    # Imported here so scikit-learn is only needed for this last example.
    from sklearn.metrics.pairwise import cosine_similarity

    # Slices (not plain indexing) keep each operand 2-D, as
    # cosine_similarity requires; the result is a 1x1 matrix.
    similarity = cosine_similarity(
        batch_features[0:1],
        batch_features[1:2]
    )[0][0]

    print(f"✅ Cosine similarity between image 1 and 2: {similarity:.4f}")

    print("\n🎯 Next Steps:")
    print("- Use these 2048D features for downstream classification")
    print("- Train a classifier using train_classification.py")
    print("- Fine-tune the entire model for specific tasks")
    print("- Use for similarity search or clustering")

    print("\n📚 Model Summary:")
    print("- Architecture: ResNet50 + BYOL")
    print("- Input: 512x512 RGB mammogram tiles")
    print("- Output: 2048-dimensional feature vectors")
    print("- Training: Self-supervised on breast tissue tiles")
    print("- Use case: Medical image analysis and classification")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()