d3v-26 commited on
Commit
8224e41
·
verified ·
1 Parent(s): b93e0b2

Upload test_models.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_models.py +354 -0
test_models.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to validate all BrainSegFounder model weights.
4
+ This script checks if models can be loaded and perform inference.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import torch
10
+ from argparse import Namespace
11
+ from pathlib import Path
12
+
13
+ try:
14
+ from huggingface_hub import hf_hub_download
15
+ HF_AVAILABLE = True
16
+ except ImportError:
17
+ HF_AVAILABLE = False
18
+ print("Warning: huggingface_hub not available. Will only test local files.")
19
+
20
+ try:
21
+ from monai.networks.nets import SwinUNETR
22
+ MONAI_AVAILABLE = True
23
+ except ImportError:
24
+ MONAI_AVAILABLE = False
25
+ print("Warning: MONAI not available. Install with: pip install git+https://github.com/Project-MONAI/MONAI.git@a23c7f54")
26
+
27
+ # Try to import SSL_Head
28
+ try:
29
+ from SSL_Head import SSLHead
30
+ SSL_HEAD_AVAILABLE = True
31
+ except ImportError:
32
+ SSL_HEAD_AVAILABLE = False
33
+ print("Warning: SSL_Head.py not found in current directory or Python path.")
34
+
35
+
36
+ class ModelTester:
37
+ """Test suite for BrainSegFounder models."""
38
+
39
+ def __init__(self, use_local=True, use_hf=False):
40
+ """
41
+ Initialize the model tester.
42
+
43
+ Args:
44
+ use_local: Test local model files
45
+ use_hf: Download and test from Hugging Face
46
+ """
47
+ self.use_local = use_local
48
+ self.use_hf = use_hf
49
+ self.repo_id = "smilelab/BrainSegFounder"
50
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+ print(f"Using device: {self.device}")
52
+ print("=" * 70)
53
+
54
+ def test_ssl_pretrain_model(self, model_path, model_name):
55
+ """Test SSL pretraining models (SSLHead)."""
56
+ if not SSL_HEAD_AVAILABLE:
57
+ print(f"⚠ Skipping {model_name}: SSL_Head not available")
58
+ return False
59
+
60
+ print(f"\n[Testing] {model_name}")
61
+ print("-" * 70)
62
+
63
+ try:
64
+ # Configure model
65
+ args = Namespace(
66
+ in_channels=2,
67
+ spatial_dims=3,
68
+ bottleneck_depth=768,
69
+ feature_size=48,
70
+ num_swin_blocks_per_stage=[2, 2, 2, 2],
71
+ num_heads_per_stage=[3, 6, 12, 24],
72
+ dropout_path_rate=0.0,
73
+ use_checkpoint=False
74
+ )
75
+
76
+ # Load model
77
+ print(f" Loading model from: {model_path}")
78
+ model = SSLHead(args)
79
+ checkpoint = torch.load(model_path, map_location="cpu")
80
+
81
+ # Handle checkpoint format (may be nested under 'state_dict' key)
82
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
83
+ print(f" Detected checkpoint format with nested state_dict")
84
+ state_dict = checkpoint['state_dict']
85
+ else:
86
+ state_dict = checkpoint
87
+
88
+ model.load_state_dict(state_dict)
89
+ model.to(self.device)
90
+ model.eval()
91
+
92
+ # Test forward pass with dummy input
93
+ print(f" Creating dummy input (2 channels, 96x96x96)...")
94
+ dummy_input = torch.randn(1, 2, 96, 96, 96).to(self.device)
95
+
96
+ print(f" Running forward pass...")
97
+ with torch.no_grad():
98
+ x_rot, x_contrastive, x_rec = model(dummy_input)
99
+
100
+ # Validate outputs
101
+ assert x_rot.shape == (1, 4), f"Expected rotation output shape (1, 4), got {x_rot.shape}"
102
+ assert x_contrastive.shape == (1, 512), f"Expected contrastive output shape (1, 512), got {x_contrastive.shape}"
103
+ assert x_rec.shape == (1, 2, 96, 96, 96), f"Expected reconstruction output shape (1, 2, 96, 96, 96), got {x_rec.shape}"
104
+
105
+ print(f" ✓ Rotation output shape: {x_rot.shape}")
106
+ print(f" ✓ Contrastive output shape: {x_contrastive.shape}")
107
+ print(f" ✓ Reconstruction output shape: {x_rec.shape}")
108
+ print(f" ✓ Model parameters: {sum(p.numel() for p in model.parameters()):,}")
109
+ print(f"✓ {model_name} passed all tests!")
110
+
111
+ # Clean up
112
+ del model, state_dict, dummy_input, x_rot, x_contrastive, x_rec
113
+ if torch.cuda.is_available():
114
+ torch.cuda.empty_cache()
115
+
116
+ return True
117
+
118
+ except Exception as e:
119
+ print(f"✗ {model_name} failed: {str(e)}")
120
+ return False
121
+
122
+ def test_swinunetr_model(self, model_path, model_name):
123
+ """Test finetuned segmentation models (SwinUNETR)."""
124
+ if not MONAI_AVAILABLE:
125
+ print(f"⚠ Skipping {model_name}: MONAI not available")
126
+ return False
127
+
128
+ print(f"\n[Testing] {model_name}")
129
+ print("-" * 70)
130
+
131
+ try:
132
+ # Configure model
133
+ depths = [2, 2, 2, 2]
134
+ num_heads = [3, 6, 12, 24]
135
+
136
+ print(f" Loading model from: {model_path}")
137
+
138
+ # Try with img_size first (older MONAI versions)
139
+ try:
140
+ model = SwinUNETR(
141
+ img_size=(96, 96, 96),
142
+ in_channels=4,
143
+ out_channels=3,
144
+ feature_size=48,
145
+ use_checkpoint=False,
146
+ depths=depths,
147
+ num_heads=num_heads
148
+ )
149
+ except TypeError:
150
+ # Newer MONAI versions use spatial_size instead of img_size
151
+ print(f" Using spatial_size parameter (newer MONAI)")
152
+ model = SwinUNETR(
153
+ spatial_size=(96, 96, 96),
154
+ in_channels=4,
155
+ out_channels=3,
156
+ feature_size=48,
157
+ use_checkpoint=False,
158
+ depths=depths,
159
+ num_heads=num_heads
160
+ )
161
+
162
+ checkpoint = torch.load(model_path, map_location="cpu")
163
+
164
+ # Handle checkpoint format (may be nested under 'state_dict' key)
165
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
166
+ print(f" Detected checkpoint format with nested state_dict")
167
+ state_dict = checkpoint['state_dict']
168
+ else:
169
+ state_dict = checkpoint
170
+
171
+ model.load_state_dict(state_dict)
172
+ model.to(self.device)
173
+ model.eval()
174
+
175
+ # Test forward pass with dummy input
176
+ print(f" Creating dummy input (4 channels, 96x96x96)...")
177
+ dummy_input = torch.randn(1, 4, 96, 96, 96).to(self.device)
178
+
179
+ print(f" Running forward pass...")
180
+ with torch.no_grad():
181
+ output = model(dummy_input)
182
+
183
+ # Validate output
184
+ assert output.shape == (1, 3, 96, 96, 96), f"Expected output shape (1, 3, 96, 96, 96), got {output.shape}"
185
+
186
+ print(f" ✓ Output shape: {output.shape}")
187
+ print(f" ✓ Model parameters: {sum(p.numel() for p in model.parameters()):,}")
188
+ print(f"✓ {model_name} passed all tests!")
189
+
190
+ # Clean up
191
+ del model, state_dict, dummy_input, output
192
+ if torch.cuda.is_available():
193
+ torch.cuda.empty_cache()
194
+
195
+ return True
196
+
197
+ except Exception as e:
198
+ print(f"✗ {model_name} failed: {str(e)}")
199
+ return False
200
+
201
+ def test_local_models(self):
202
+ """Test all local model files."""
203
+ print("\n" + "=" * 70)
204
+ print("TESTING LOCAL MODEL FILES")
205
+ print("=" * 70)
206
+
207
+ models = {
208
+ "model_weights_UKB-pretrain.pt": ("ssl", "UKB Pretrain"),
209
+ "model_weights_BRATS-pretrain.pt": ("ssl", "BRATS Pretrain"),
210
+ "model_weights_ATLAS-pretrain.pt": ("ssl", "ATLAS Pretrain"),
211
+ "model_weights_BRATS-finetune.pt": ("swinunetr", "BRATS Finetune"),
212
+ "model_weights_ATLAS-finetune.pt": ("swinunetr", "ATLAS Finetune"),
213
+ }
214
+
215
+ results = {}
216
+ for filename, (model_type, display_name) in models.items():
217
+ if not os.path.exists(filename):
218
+ print(f"\n⚠ Skipping {display_name}: {filename} not found locally")
219
+ results[display_name] = "not_found"
220
+ continue
221
+
222
+ if model_type == "ssl":
223
+ success = self.test_ssl_pretrain_model(filename, display_name)
224
+ else:
225
+ success = self.test_swinunetr_model(filename, display_name)
226
+
227
+ results[display_name] = "passed" if success else "failed"
228
+
229
+ return results
230
+
231
+ def test_huggingface_models(self):
232
+ """Test models downloaded from Hugging Face."""
233
+ if not HF_AVAILABLE:
234
+ print("\n⚠ Skipping Hugging Face tests: huggingface_hub not installed")
235
+ return {}
236
+
237
+ print("\n" + "=" * 70)
238
+ print("TESTING HUGGING FACE MODELS")
239
+ print("=" * 70)
240
+
241
+ models = {
242
+ "model_weights_UKB-pretrain.pt": ("ssl", "UKB Pretrain (HF)"),
243
+ "model_weights_BRATS-pretrain.pt": ("ssl", "BRATS Pretrain (HF)"),
244
+ "model_weights_ATLAS-pretrain.pt": ("ssl", "ATLAS Pretrain (HF)"),
245
+ "model_weights_BRATS-finetune.pt": ("swinunetr", "BRATS Finetune (HF)"),
246
+ "model_weights_ATLAS-finetune.pt": ("swinunetr", "ATLAS Finetune (HF)"),
247
+ }
248
+
249
+ results = {}
250
+ for filename, (model_type, display_name) in models.items():
251
+ try:
252
+ print(f"\n[Downloading] {display_name} from Hugging Face...")
253
+ model_path = hf_hub_download(
254
+ repo_id=self.repo_id,
255
+ filename=filename,
256
+ cache_dir=".hf_cache"
257
+ )
258
+
259
+ if model_type == "ssl":
260
+ success = self.test_ssl_pretrain_model(model_path, display_name)
261
+ else:
262
+ success = self.test_swinunetr_model(model_path, display_name)
263
+
264
+ results[display_name] = "passed" if success else "failed"
265
+
266
+ except Exception as e:
267
+ print(f"✗ Failed to download/test {display_name}: {str(e)}")
268
+ results[display_name] = "failed"
269
+
270
+ return results
271
+
272
+ def print_summary(self, local_results, hf_results):
273
+ """Print test summary."""
274
+ print("\n" + "=" * 70)
275
+ print("TEST SUMMARY")
276
+ print("=" * 70)
277
+
278
+ all_results = {}
279
+ if local_results:
280
+ print("\nLocal Models:")
281
+ for name, status in local_results.items():
282
+ symbol = "✓" if status == "passed" else "✗" if status == "failed" else "⚠"
283
+ print(f" {symbol} {name}: {status}")
284
+ all_results[name] = status
285
+
286
+ if hf_results:
287
+ print("\nHugging Face Models:")
288
+ for name, status in hf_results.items():
289
+ symbol = "✓" if status == "passed" else "✗"
290
+ print(f" {symbol} {name}: {status}")
291
+ all_results[name] = status
292
+
293
+ # Overall statistics
294
+ passed = sum(1 for s in all_results.values() if s == "passed")
295
+ failed = sum(1 for s in all_results.values() if s == "failed")
296
+ skipped = sum(1 for s in all_results.values() if s == "not_found")
297
+ total = len(all_results)
298
+
299
+ print(f"\nOverall: {passed}/{total} passed, {failed} failed, {skipped} skipped")
300
+
301
+ return failed == 0 and passed > 0
302
+
303
+
304
+ def main():
305
+ """Main test function."""
306
+ import argparse
307
+
308
+ parser = argparse.ArgumentParser(
309
+ description="Test BrainSegFounder model weights"
310
+ )
311
+ parser.add_argument(
312
+ "--local",
313
+ action="store_true",
314
+ default=True,
315
+ help="Test local model files (default: True)"
316
+ )
317
+ parser.add_argument(
318
+ "--hf",
319
+ action="store_true",
320
+ help="Download and test models from Hugging Face"
321
+ )
322
+ parser.add_argument(
323
+ "--no-local",
324
+ action="store_true",
325
+ help="Skip testing local files"
326
+ )
327
+
328
+ args = parser.parse_args()
329
+
330
+ use_local = not args.no_local
331
+ use_hf = args.hf
332
+
333
+ if not use_local and not use_hf:
334
+ print("Error: Must test either local or Hugging Face models (or both)")
335
+ sys.exit(1)
336
+
337
+ tester = ModelTester(use_local=use_local, use_hf=use_hf)
338
+
339
+ local_results = {}
340
+ hf_results = {}
341
+
342
+ if use_local:
343
+ local_results = tester.test_local_models()
344
+
345
+ if use_hf:
346
+ hf_results = tester.test_huggingface_models()
347
+
348
+ success = tester.print_summary(local_results, hf_results)
349
+
350
+ sys.exit(0 if success else 1)
351
+
352
+
353
+ if __name__ == "__main__":
354
+ main()