PranayPalem committed on
Commit
853f08d
Β·
1 Parent(s): d921913

πŸ“– Add example usage script

Browse files

- Demonstrates BYOL model loading and feature extraction
- Shows preprocessing pipeline for inference
- Includes batch processing examples
- Feature similarity computation example
- Complete documentation for model usage

Ready-to-use code for:
βœ… Loading pre-trained BYOL model
βœ… Feature extraction from mammogram tiles
βœ… Batch processing capabilities
βœ… Downstream task preparation

Files changed (1) hide show
  1. example_usage.py +143 -0
example_usage.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ example_usage.py
4
+
5
+ Demonstrates how to use the BYOL Mammogram model for feature extraction
6
+ and classification tasks.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torchvision import models, transforms
12
+ from PIL import Image
13
+ import numpy as np
14
+ from pathlib import Path
15
+
16
+ # Import the BYOL model classes
17
+ from train_byol_mammo import MammogramBYOL
18
+
19
+
20
def load_byol_model(checkpoint_path: str, device: torch.device):
    """Load the pre-trained BYOL model from a checkpoint for feature extraction.

    Args:
        checkpoint_path: Path to a ``.pth`` checkpoint containing a
            ``model_state_dict`` entry (and optionally ``epoch`` / ``loss``).
        device: Device to place the model on.

    Returns:
        The ``MammogramBYOL`` model in eval mode on ``device``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the checkpoint lacks ``model_state_dict``.
    """
    print(f"📥 Loading BYOL model from: {checkpoint_path}")

    # Create ResNet50 backbone (same as training); dropping the final FC
    # layer leaves the pooled 2048-D feature extractor.
    resnet = models.resnet50(weights=None)
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    # Initialize BYOL model with the same architecture used during training.
    model = MammogramBYOL(
        backbone=backbone,
        input_dim=2048,   # ResNet50 feature dimension
        hidden_dim=4096,  # BYOL projection head hidden dim
        proj_dim=256      # BYOL projection dimension
    ).to(device)

    # Load the trained weights.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("✅ Model loaded successfully!")
    print(f"   Epoch: {checkpoint.get('epoch', 'Unknown')}")
    # BUG FIX: the original applied ':.4f' unconditionally, which raised a
    # ValueError whenever 'loss' was missing and the fallback string
    # 'Unknown' was pushed through the float format spec.
    loss = checkpoint.get('loss')
    loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "Unknown"
    print(f"   Final loss: {loss_str}")

    return model
47
+
48
+
49
def create_inference_transform(tile_size: int = 512):
    """Build the preprocessing pipeline applied to tiles at inference time.

    Resizes to a square ``tile_size`` x ``tile_size`` image, converts to a
    tensor, then normalizes each channel with mean/std of 0.5 (mapping
    pixel values from [0, 1] into [-1, 1]).
    """
    steps = [
        transforms.Resize((tile_size, tile_size), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)
56
+
57
+
58
def extract_features(model, image_tensor, device):
    """Extract feature vectors from a (batch of) mammogram tile tensor(s).

    Moves ``image_tensor`` to ``device``, runs ``model.get_features`` with
    gradients disabled, and returns the result as a NumPy array on the CPU.
    """
    with torch.no_grad():
        embeddings = model.get_features(image_tensor.to(device))
    return embeddings.cpu().numpy()
64
+
65
+
66
def main():
    """Demonstrate BYOL model usage: loading, feature extraction, similarity.

    Walks through three examples — single-image feature extraction, batched
    extraction, and cosine similarity between two feature vectors.  Expects
    the checkpoint file ``mammogram_byol_best.pth`` in the working directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    # Load the pre-trained BYOL model and the matching preprocessing.
    model = load_byol_model("mammogram_byol_best.pth", device)
    transform = create_inference_transform(tile_size=512)

    # Example 1: Feature extraction from a single image
    print("\n📊 Example 1: Feature Extraction")
    print("-" * 40)

    # Create a dummy grayscale tile (replace with actual image loading).
    dummy_image = Image.fromarray(np.random.randint(0, 255, (512, 512), dtype=np.uint8))
    dummy_image = dummy_image.convert('RGB')  # model expects 3-channel input

    # Preprocess and add a batch dimension.
    image_tensor = transform(dummy_image).unsqueeze(0)

    features = extract_features(model, image_tensor, device)

    print(f"✅ Input shape: {image_tensor.shape}")
    print(f"✅ Feature shape: {features.shape}")
    print(f"✅ Feature vector (first 10 values): {features[0][:10]}")

    # Example 2: Batch processing multiple images
    print("\n📊 Example 2: Batch Feature Extraction")
    print("-" * 40)

    batch_size = 4
    dummy_batch = torch.stack([
        transform(Image.fromarray(np.random.randint(0, 255, (512, 512), dtype=np.uint8)).convert('RGB'))
        for _ in range(batch_size)
    ])

    # Extract features for the entire batch in one forward pass.
    batch_features = extract_features(model, dummy_batch, device)

    print(f"✅ Batch input shape: {dummy_batch.shape}")
    print(f"✅ Batch features shape: {batch_features.shape}")
    print(f"✅ Features per image: {batch_features.shape[1]} dimensions")

    # Example 3: Similarity computation
    print("\n📊 Example 3: Feature Similarity")
    print("-" * 40)

    # IMPROVED: compute cosine similarity directly with NumPy (already
    # imported at module level) instead of pulling in scikit-learn for a
    # single dot product.
    a, b = batch_features[0], batch_features[1]
    similarity = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    print(f"✅ Cosine similarity between image 1 and 2: {similarity:.4f}")

    print("\n🎯 Next Steps:")
    print("- Use these 2048D features for downstream classification")
    print("- Train a classifier using train_classification.py")
    print("- Fine-tune the entire model for specific tasks")
    print("- Use for similarity search or clustering")

    print("\n📚 Model Summary:")
    print("- Architecture: ResNet50 + BYOL")
    print("- Input: 512x512 RGB mammogram tiles")
    print("- Output: 2048-dimensional feature vectors")
    print("- Training: Self-supervised on breast tissue tiles")
    print("- Use case: Medical image analysis and classification")
140
+
141
+
142
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()