#!/usr/bin/env python3
"""
example_usage.py
Demonstrates how to use the BYOL Mammogram model for feature extraction
and classification tasks.
"""
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
from pathlib import Path
# Import the BYOL model classes
from train_byol_mammo import MammogramBYOL
def load_byol_model(checkpoint_path: str, device: torch.device):
    """Load the pre-trained BYOL model for feature extraction.

    Builds a ResNet50 backbone without its last child module, wraps it in
    ``MammogramBYOL``, restores the weights stored under the checkpoint's
    ``'model_state_dict'`` key, and switches the model to eval mode.

    Args:
        checkpoint_path: Path of the checkpoint file handed to ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The ``MammogramBYOL`` instance, on ``device`` and in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"📥 Loading BYOL model from: {checkpoint_path}")

    # ResNet50 backbone, intended to match the training architecture.
    # [:-1] drops the last child, which in torchvision's ResNet is the final
    # fully connected classification layer.
    resnet = models.resnet50(weights=None)
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    # Initialize BYOL model with the same architecture
    model = MammogramBYOL(
        backbone=backbone,
        input_dim=2048,   # ResNet50 feature dimension
        hidden_dim=4096,  # BYOL projection head hidden dim
        proj_dim=256,     # BYOL projection dimension
    ).to(device)

    # NOTE(security): torch.load relies on pickle, which can execute arbitrary
    # code while loading. Only open checkpoints from a trusted source; consider
    # weights_only=True if the checkpoint holds nothing but tensors/primitives.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("✅ Model loaded successfully!")
    print(f" Epoch: {checkpoint.get('epoch', 'Unknown')}")

    # Applying ".4f" to a missing or non-numeric loss would raise, so fall
    # back to a plain textual representation in that case.
    loss = checkpoint.get('loss')
    try:
        loss_text = f"{loss:.4f}"
    except (TypeError, ValueError):
        loss_text = 'Unknown' if loss is None else str(loss)
    print(f" Final loss: {loss_text}")

    return model
def create_inference_transform(tile_size: int = 512):
    """Build the preprocessing pipeline applied to each tile at inference.

    The pipeline resizes the image to ``tile_size`` x ``tile_size`` with
    antialiasing, converts it to a tensor, and normalizes each of the three
    channels with mean 0.5 and standard deviation 0.5.

    Args:
        tile_size: Edge length, in pixels, of the square output tile.

    Returns:
        A ``transforms.Compose`` holding the three steps in that order.
    """
    target_size = (tile_size, tile_size)
    channel_mean = [0.5] * 3
    channel_std = [0.5] * 3

    steps = [
        transforms.Resize(target_size, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def extract_features(model, image_tensor, device):
    """Run ``model.get_features`` on a batch and return the result as NumPy.

    The input is moved to ``device`` and the forward pass runs with gradient
    tracking disabled. The output is copied back to the CPU before conversion.

    Args:
        model: Object exposing ``get_features(tensor)``.
        image_tensor: Input tensor of preprocessed tiles.
        device: Device on which the forward pass is executed.

    Returns:
        NumPy array holding whatever ``get_features`` produced (presumably
        2048-dimensional vectors for the ResNet50 backbone used in this file;
        confirm against ``MammogramBYOL.get_features``).
    """
    with torch.no_grad():
        embeddings = model.get_features(image_tensor.to(device))
    # Outside the no_grad block only the host copy and conversion remain.
    host_embeddings = embeddings.cpu()
    return host_embeddings.numpy()
def _make_dummy_tile(tile_size: int = 512):
    """Return a random-noise RGB PIL image standing in for a real tile."""
    # randint's upper bound is exclusive, so 256 is needed to cover 0..255.
    pixels = np.random.randint(0, 256, (tile_size, tile_size), dtype=np.uint8)
    # The noise is single-channel; convert to RGB as expected by the transform.
    return Image.fromarray(pixels).convert('RGB')


def main(checkpoint_path: str = "mammogram_byol_best.pth"):
    """Demonstrate model usage.

    Loads the BYOL checkpoint, then runs three demos on random dummy tiles:
    single-image feature extraction, batch feature extraction, and cosine
    similarity between the first two batch items. Results are printed.

    Args:
        checkpoint_path: Checkpoint file passed to ``load_byol_model``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    # Load the pre-trained BYOL model
    model = load_byol_model(checkpoint_path, device)

    # Create preprocessing transform
    tile_size = 512
    transform = create_inference_transform(tile_size=tile_size)
    rule = "-" * 40

    # Example 1: Feature extraction from a single image
    print("\n🔍 Example 1: Feature Extraction")
    print(rule)

    # Dummy mammogram tile (replace with actual image loading)
    dummy_image = _make_dummy_tile(tile_size)
    image_tensor = transform(dummy_image).unsqueeze(0)  # Add batch dimension

    features = extract_features(model, image_tensor, device)
    print(f"✅ Input shape: {image_tensor.shape}")
    print(f"✅ Feature shape: {features.shape}")
    print(f"✅ Feature vector (first 10 values): {features[0][:10]}")

    # Example 2: Batch processing multiple images
    print("\n📊 Example 2: Batch Feature Extraction")
    print(rule)

    batch_size = 4
    dummy_batch = torch.stack([
        transform(_make_dummy_tile(tile_size)) for _ in range(batch_size)
    ])

    # Extract features for the entire batch
    batch_features = extract_features(model, dummy_batch, device)
    print(f"✅ Batch input shape: {dummy_batch.shape}")
    print(f"✅ Batch features shape: {batch_features.shape}")
    print(f"✅ Features per image: {batch_features.shape[1]} dimensions")

    # Example 3: Similarity computation
    print("\n🔗 Example 3: Feature Similarity")
    print(rule)

    # Imported here so scikit-learn is only loaded once this example runs.
    from sklearn.metrics.pairwise import cosine_similarity

    # Slices (not plain indices) keep each operand 2-D, as cosine_similarity
    # expects; [0][0] picks the single entry of the resulting 1x1 matrix.
    similarity = cosine_similarity(
        batch_features[0:1],
        batch_features[1:2],
    )[0][0]
    print(f"✅ Cosine similarity between image 1 and 2: {similarity:.4f}")

    print("\n🎯 Next Steps:")
    print("- Use these 2048D features for downstream classification")
    print("- Train a classifier using train_classification.py")
    print("- Fine-tune the entire model for specific tasks")
    print("- Use for similarity search or clustering")

    print("\n📋 Model Summary:")
    print("- Architecture: ResNet50 + BYOL")
    print("- Input: 512x512 RGB mammogram tiles")
    print("- Output: 2048-dimensional feature vectors")
    print("- Training: Self-supervised on breast tissue tiles")
    print("- Use case: Medical image analysis and classification")
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()