Aakash-Tripathi commited on
Commit
88d9d81
Β·
verified Β·
1 Parent(s): 15f5ea3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +400 -1
README.md CHANGED
@@ -135,6 +135,405 @@ for i, score in enumerate(output.risk_scores.numpy()):
135
  print(f"Year {i+1}: {float(score)}")
136
  ```
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  ## πŸ“ˆ Performance Metrics
139
 
140
  | Dataset | 1-Year AUC | 6-Year AUC | Sample Size |
@@ -215,7 +614,7 @@ This Hugging Face implementation is based on the original work by:
215
  MIT License - See [LICENSE](LICENSE) file
216
 
217
  - Original Model Β© 2022 Peter Mikhael & Jeremy Wohlwend
218
- - HF Adaptation Β© 2025 Aakash Tripathi
219
 
220
  ## πŸ”§ Troubleshooting
221
 
 
135
  print(f"Year {i+1}: {float(score)}")
136
  ```
137
 
138
+ ## πŸ”¬ Advanced Usage: Embedding Extraction
139
+
140
+ ### Extract Embeddings Before Dropout Layer
141
+
142
+ You can extract 512-dimensional embedding vectors from the layer immediately before the dropout layer. This captures the learned risk features before the final prediction layer.
143
+
144
+ ```python
145
+ from huggingface_hub import snapshot_download
146
+ import sys
147
+ import os
148
+ import torch
149
+ import numpy as np
150
+
151
+ # Download and setup model
152
+ model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
153
+ sys.path.append(model_path)
154
+
155
+ from modeling_sybil_hf import SybilHFWrapper
156
+ from configuration_sybil import SybilConfig
157
+
158
+ def extract_embeddings(dicom_paths):
159
+ """
160
+ Extract embeddings from the layer after ReLU, before Dropout.
161
+
162
+ Args:
163
+ dicom_paths: List of DICOM file paths
164
+
165
+ Returns:
166
+ numpy array of shape (512,) - averaged embeddings across ensemble
167
+ """
168
+ # Initialize model
169
+ config = SybilConfig()
170
+ model = SybilHFWrapper(config)
171
+
172
+ # Set each model in ensemble to eval mode
173
+ for m in model.models:
174
+ m.eval()
175
+
176
+ # Storage for embeddings from each model in ensemble
177
+ all_embeddings = []
178
+
179
+ # Register hooks on each model in the ensemble
180
+ for model_idx, ensemble_model in enumerate(model.models):
181
+ embeddings_buffer = []
182
+
183
+ def create_hook(buffer):
184
+ def hook(module, input, output):
185
+ # Capture the output of ReLU layer (before dropout)
186
+ buffer.append(output.detach().cpu())
187
+ return hook
188
+
189
+ # Register hook on the ReLU layer
190
+ hook_handle = ensemble_model.relu.register_forward_hook(create_hook(embeddings_buffer))
191
+
192
+ # Run forward pass
193
+ with torch.no_grad():
194
+ _ = model(dicom_paths=dicom_paths)
195
+
196
+ # Remove hook
197
+ hook_handle.remove()
198
+
199
+ # Get the embeddings (should be shape [1, 512])
200
+ if embeddings_buffer:
201
+ embedding = embeddings_buffer[0].numpy().squeeze()
202
+ all_embeddings.append(embedding)
203
+ print(f"Model {model_idx + 1}: Embedding shape = {embedding.shape}")
204
+
205
+ # Average embeddings across ensemble
206
+ averaged_embedding = np.mean(all_embeddings, axis=0)
207
+ return averaged_embedding
208
+
209
+ # Usage
210
+ dicom_dir = "path/to/volume"
211
+ dicom_paths = [os.path.join(dicom_dir, f) for f in os.listdir(dicom_dir) if f.endswith('.dcm')]
212
+
213
+ embeddings = extract_embeddings(dicom_paths)
214
+ print(f"\nEmbedding vector shape: {embeddings.shape}")
215
+ print(f"Embedding statistics:")
216
+ print(f" Mean: {np.mean(embeddings):.6f}")
217
+ print(f" Std: {np.std(embeddings):.6f}")
218
+ print(f" Min: {np.min(embeddings):.6f}")
219
+ print(f" Max: {np.max(embeddings):.6f}")
220
+ ```
221
+
222
+ ## 🎯 Extracting Embeddings at Other Layers
223
+
224
+ ### Available Extraction Points
225
+
226
+ The Sybil model has several key layers where you can extract intermediate representations:
227
+
228
+ ```python
229
+ import torch
230
+ from huggingface_hub import snapshot_download
231
+ import sys
232
+
233
+ model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
234
+ sys.path.append(model_path)
235
+
236
+ from modeling_sybil_hf import SybilHFWrapper
237
+ from configuration_sybil import SybilConfig
238
+
239
+ config = SybilConfig()
240
+ model = SybilHFWrapper(config)
241
+
242
+ # Get first model from ensemble for demonstration
243
+ first_model = model.models[0]
244
+
245
+ # Model architecture flow:
246
+ # Input β†’ image_encoder β†’ pool β†’ relu β†’ dropout β†’ prob_of_failure_layer β†’ Output
247
+
248
+ def extract_layer_output(model, layer_name, dicom_paths):
249
+ """
250
+ Extract output from any layer in the model.
251
+
252
+ Args:
253
+ model: SybilHFWrapper model
254
+ layer_name: Name of the layer to extract from
255
+ dicom_paths: List of DICOM file paths
256
+
257
+ Returns:
258
+ Extracted features from the specified layer
259
+ """
260
+ features = []
261
+
262
+ def hook_fn(module, input, output):
263
+ features.append(output.detach().cpu())
264
+
265
+ # Register hook on the specified layer
266
+ for m in model.models:
267
+ layer = dict(m.named_modules())[layer_name]
268
+ hook_handle = layer.register_forward_hook(hook_fn)
269
+
270
+ # Run forward pass
271
+ with torch.no_grad():
272
+ _ = model(dicom_paths=dicom_paths)
273
+
274
+ # Remove hook
275
+ hook_handle.remove()
276
+
277
+ return features
278
+
279
+ # Example 1: Extract from image encoder (3D feature maps)
280
+ # Shape: (batch, 512, time, height, width)
281
+ encoder_features = extract_layer_output(model, 'image_encoder', dicom_paths)
282
+ print(f"Image encoder output shape: {encoder_features[0].shape}")
283
+
284
+ # Example 2: Extract from pooling layer (before ReLU)
285
+ # Shape: (batch, 512)
286
+ pool_features = extract_layer_output(model, 'pool', dicom_paths)
287
+ print(f"Pool layer output shape: {pool_features[0].shape}")
288
+
289
+ # Example 3: Extract from ReLU layer (before dropout) - RECOMMENDED
290
+ # Shape: (batch, 512)
291
+ relu_features = extract_layer_output(model, 'relu', dicom_paths)
292
+ print(f"ReLU layer output shape: {relu_features[0].shape}")
293
+
294
+ # Example 4: Extract from dropout layer (before final prediction)
295
+ # Shape: (batch, 512)
296
+ dropout_features = extract_layer_output(model, 'dropout', dicom_paths)
297
+ print(f"Dropout layer output shape: {dropout_features[0].shape}")
298
+ ```
299
+
300
+ ### Custom Layer Extraction Template
301
+
302
+ ```python
303
+ def extract_custom_layer(dicom_paths, target_layer_name):
304
+ """
305
+ Template for extracting features from any layer.
306
+
307
+ Args:
308
+ dicom_paths: List of DICOM file paths
309
+ target_layer_name: Name of target layer (e.g., 'relu', 'pool', 'image_encoder')
310
+
311
+ Returns:
312
+ Extracted features averaged across ensemble
313
+ """
314
+ from huggingface_hub import snapshot_download
315
+ import sys
316
+ import torch
317
+ import numpy as np
318
+
319
+ model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
320
+ sys.path.append(model_path)
321
+
322
+ from modeling_sybil_hf import SybilHFWrapper
323
+ from configuration_sybil import SybilConfig
324
+
325
+ config = SybilConfig()
326
+ model = SybilHFWrapper(config)
327
+
328
+ all_features = []
329
+
330
+ for ensemble_model in model.models:
331
+ ensemble_model.eval()
332
+ features_buffer = []
333
+
334
+ # Get the target layer
335
+ target_layer = dict(ensemble_model.named_modules())[target_layer_name]
336
+
337
+ # Register hook
338
+ def hook(module, input, output):
339
+ features_buffer.append(output.detach().cpu())
340
+
341
+ hook_handle = target_layer.register_forward_hook(hook)
342
+
343
+ # Forward pass
344
+ with torch.no_grad():
345
+ _ = model(dicom_paths=dicom_paths)
346
+
347
+ hook_handle.remove()
348
+
349
+ if features_buffer:
350
+ all_features.append(features_buffer[0])
351
+
352
+ # Average across ensemble
353
+ averaged_features = torch.stack(all_features).mean(dim=0)
354
+ return averaged_features.numpy()
355
+ ```
356
+
357
+ ## πŸ” Model Architecture Inspection
358
+
359
+ ### Print Full Model Architecture
360
+
361
+ ```python
362
+ from huggingface_hub import snapshot_download
363
+ import sys
364
+
365
+ model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
366
+ sys.path.append(model_path)
367
+
368
+ from modeling_sybil_hf import SybilHFWrapper
369
+ from configuration_sybil import SybilConfig
370
+
371
+ config = SybilConfig()
372
+ model = SybilHFWrapper(config)
373
+
374
+ # Print configuration
375
+ print("=" * 80)
376
+ print("MODEL CONFIGURATION:")
377
+ print("=" * 80)
378
+ print(config)
379
+
380
+ # Print ensemble information
381
+ print("\n" + "=" * 80)
382
+ print("ENSEMBLE INFORMATION:")
383
+ print("=" * 80)
384
+ print(f"Number of models in ensemble: {len(model.models)}")
385
+ print(f"Device: {model.device}")
386
+
387
+ # Print architecture of first model
388
+ print("\n" + "=" * 80)
389
+ print("MODEL ARCHITECTURE (First model in ensemble):")
390
+ print("=" * 80)
391
+ first_model = model.models[0]
392
+ print(first_model)
393
+ ```
394
+
395
+ ### Count Model Parameters
396
+
397
+ ```python
398
+ from huggingface_hub import snapshot_download
399
+ import sys
400
+
401
+ model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
402
+ sys.path.append(model_path)
403
+
404
+ from modeling_sybil_hf import SybilHFWrapper
405
+ from configuration_sybil import SybilConfig
406
+
407
+ config = SybilConfig()
408
+ model = SybilHFWrapper(config)
409
+
410
+ print("=" * 80)
411
+ print("MODEL PARAMETERS:")
412
+ print("=" * 80)
413
+
414
+ # Parameters per model in ensemble
415
+ for i, ensemble_model in enumerate(model.models):
416
+ total_params = sum(p.numel() for p in ensemble_model.parameters())
417
+ trainable_params = sum(p.numel() for p in ensemble_model.parameters() if p.requires_grad)
418
+
419
+ print(f"\nModel {i+1}:")
420
+ print(f" Total parameters: {total_params:,}")
421
+ print(f" Trainable parameters: {trainable_params:,}")
422
+ print(f" Non-trainable parameters: {total_params - trainable_params:,}")
423
+
424
+ # Total ensemble parameters
425
+ total_ensemble = sum(
426
+ sum(p.numel() for p in m.parameters())
427
+ for m in model.models
428
+ )
429
+ print(f"\nTotal ensemble parameters: {total_ensemble:,}")
430
+ ```
431
+
432
+ ### List Model Components
433
+
434
+ ```python
435
+ from huggingface_hub import snapshot_download
436
+ import sys
437
+
438
+ model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
439
+ sys.path.append(model_path)
440
+
441
+ from modeling_sybil_hf import SybilHFWrapper
442
+ from configuration_sybil import SybilConfig
443
+
444
+ config = SybilConfig()
445
+ model = SybilHFWrapper(config)
446
+ first_model = model.models[0]
447
+
448
+ print("=" * 80)
449
+ print("MODEL COMPONENTS:")
450
+ print("=" * 80)
451
+
452
+ # Print each component with parameter count
453
+ for name, module in first_model.named_children():
454
+ num_params = sum(p.numel() for p in module.parameters())
455
+ print(f"{name}: {module.__class__.__name__} ({num_params:,} parameters)")
456
+
457
+ print("\n" + "=" * 80)
458
+ print("DETAILED LAYER NAMES:")
459
+ print("=" * 80)
460
+
461
+ # Print all named modules (including nested layers)
462
+ for name, module in first_model.named_modules():
463
+ if name: # Skip the root module
464
+ print(f" {name}: {module.__class__.__name__}")
465
+ ```
466
+
467
+ ### Model Architecture Overview
468
+
469
+ The Sybil model consists of the following key components:
470
+
471
+ ```
472
+ Input (3D CT Volume)
473
+ ↓
474
+ image_encoder (R3D-18 backbone)
475
+ - 3D convolutional neural network
476
+ - Pretrained on Kinetics-400
477
+ - Output: (batch, 512, time, height, width)
478
+ ↓
479
+ pool (MultiAttentionPool)
480
+ - Attention-based pooling mechanisms
481
+ - Combines multiple pooling strategies
482
+ - Output: (batch, 512)
483
+ ↓
484
+ relu (ReLU activation)
485
+ - Non-linear activation
486
+ - Output: (batch, 512) ← EMBEDDING EXTRACTION POINT
487
+ ↓
488
+ dropout (Dropout layer)
489
+ - Regularization (p=0.0 in inference)
490
+ - Output: (batch, 512)
491
+ ↓
492
+ prob_of_failure_layer (CumulativeProbabilityLayer)
493
+ - Hazard function prediction
494
+ - Output: (batch, 6) - one score per year
495
+ ↓
496
+ sigmoid (applied post-forward)
497
+ ↓
498
+ Risk Scores (final output)
499
+ ```
500
+
501
+ ### Get Layer-by-Layer Summary
502
+
503
+ ```python
504
+ def print_model_summary(model):
505
+ """Print a detailed summary of the model architecture."""
506
+ from huggingface_hub import snapshot_download
507
+ import sys
508
+
509
+ model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
510
+ sys.path.append(model_path)
511
+
512
+ from modeling_sybil_hf import SybilHFWrapper
513
+ from configuration_sybil import SybilConfig
514
+
515
+ config = SybilConfig()
516
+ model = SybilHFWrapper(config)
517
+ first_model = model.models[0]
518
+
519
+ print(f"{'Layer Name':<40} {'Type':<30} {'Parameters':>15}")
520
+ print("=" * 85)
521
+
522
+ total_params = 0
523
+ for name, module in first_model.named_modules():
524
+ if name: # Skip root
525
+ num_params = sum(p.numel() for p in module.parameters())
526
+ if num_params > 0:
527
+ print(f"{name:<40} {module.__class__.__name__:<30} {num_params:>15,}")
528
+ total_params += num_params
529
+
530
+ print("=" * 85)
531
+ print(f"{'TOTAL':<40} {'':<30} {total_params:>15,}")
532
+
533
+ # Usage
534
+ print_model_summary(model)
535
+ ```
536
+
537
  ## πŸ“ˆ Performance Metrics
538
 
539
  | Dataset | 1-Year AUC | 6-Year AUC | Sample Size |
 
614
  MIT License - See [LICENSE](LICENSE) file
615
 
616
  - Original Model Β© 2022 Peter Mikhael & Jeremy Wohlwend
617
+ - HF Adaptation with Embeddings Β© 2025 [Aakash Tripathi](https://github.com/Aakash-Tripathi)
618
 
619
  ## πŸ”§ Troubleshooting
620