---
license: mit
tags:
- braindecode
---

Conversion of the original LaBraM pre-trained weights to the Braindecode `Labram` model. The script below renames the checkpoint keys and transfers every compatible backbone parameter; a minimal loading example follows it.

```python
#!/usr/bin/env python3
"""
Complete LaBraM Weight Transfer Script

Combines explicit weight mapping with full backbone transfer.
Uses precise key renaming to transfer all compatible parameters.

Transfers weights from a LaBraM checkpoint to the Braindecode Labram model.
"""

import argparse

import torch
from braindecode.models import Labram


def create_weight_mapping():
    """
    Create comprehensive weight mapping from LaBraM to Braindecode.

    Includes:
    - Temporal convolution layers (patch_embed)
    - All transformer blocks
    - Position embeddings
    - Other backbone components
    """
    return {
        # Temporal convolution layers
        'student.patch_embed.conv1.weight': 'patch_embed.temporal_conv.conv1.weight',
        'student.patch_embed.conv1.bias': 'patch_embed.temporal_conv.conv1.bias',
        'student.patch_embed.norm1.weight': 'patch_embed.temporal_conv.norm1.weight',
        'student.patch_embed.norm1.bias': 'patch_embed.temporal_conv.norm1.bias',
        'student.patch_embed.conv2.weight': 'patch_embed.temporal_conv.conv2.weight',
        'student.patch_embed.conv2.bias': 'patch_embed.temporal_conv.conv2.bias',
        'student.patch_embed.norm2.weight': 'patch_embed.temporal_conv.norm2.weight',
        'student.patch_embed.norm2.bias': 'patch_embed.temporal_conv.norm2.bias',
        'student.patch_embed.conv3.weight': 'patch_embed.temporal_conv.conv3.weight',
        'student.patch_embed.conv3.bias': 'patch_embed.temporal_conv.conv3.bias',
        'student.patch_embed.norm3.weight': 'patch_embed.temporal_conv.norm3.weight',
        'student.patch_embed.norm3.bias': 'patch_embed.temporal_conv.norm3.bias',
        # Note: other backbone layers (blocks, embeddings, norm, fc_norm) are
        # handled by stripping the 'student.' prefix in process_state_dict().
    }


def process_state_dict(state_dict, weight_mapping):
    """
    Process a checkpoint state dict with the explicit mapping.

    Parameters
    ----------
    state_dict : dict
        Original checkpoint state dictionary.
    weight_mapping : dict
        Explicit mapping for special layers (patch_embed).

    Returns
    -------
    new_state : dict
        Processed state dict ready for the Braindecode model.
    mapped_keys : list of (str, str)
        (original key, new key) pairs that were transferred.
    skipped_keys : list of (str, str)
        (key, reason) pairs that were skipped.
    """
    new_state = {}
    mapped_keys = []
    skipped_keys = []

    for key, value in state_dict.items():
        # Skip classification head (task-specific)
        if 'head' in key:
            skipped_keys.append((key, 'head layer'))
            continue

        # Use explicit mapping for patch_embed temporal_conv
        if key in weight_mapping:
            new_key = weight_mapping[key]
            new_state[new_key] = value
            mapped_keys.append((key, new_key))
            continue

        # Skip original patch_embed if not in mapping (SegmentPatch)
        if 'patch_embed' in key and 'temporal_conv' not in key:
            skipped_keys.append((key, 'patch_embed (non-temporal)'))
            continue

        # For backbone layers, strip the 'student.' prefix
        if key.startswith('student.'):
            new_key = key[len('student.'):]
            new_state[new_key] = value
            mapped_keys.append((key, new_key))
            continue

        # Keep other keys as-is
        new_state[key] = value
        mapped_keys.append((key, key))

    return new_state, mapped_keys, skipped_keys


def transfer_labram_weights(
    checkpoint_path,
    n_times=1600,
    n_chans=64,
    n_outputs=4,
    output_path=None,
    device='cpu',
    verbose=True,
):
    """
    Transfer LaBraM weights to Braindecode Labram using explicit mapping.

    Parameters
    ----------
    checkpoint_path : str
        Path to the LaBraM checkpoint.
    n_times : int
        Number of time samples.
    n_chans : int
        Number of channels.
    n_outputs : int
        Number of output classes.
    output_path : str or None
        Where to save the converted weights.
    device : str
        Device used for the forward-pass test.
    verbose : bool
        Print transfer details.

    Returns
    -------
    model : Labram
        Model with transferred weights.
    stats : dict
        Transfer statistics.
    """
    print("\n" + "=" * 70)
    print("LaBraM → Braindecode Weight Transfer")
    print("=" * 70)

    # Load checkpoint
    print(f"\nLoading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Extract model state
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state = checkpoint['model']
    else:
        state = checkpoint

    original_tensors = len(state)
    print(f"Original checkpoint: {original_tensors} tensors")

    # Create weight mapping
    weight_mapping = create_weight_mapping()

    # Process state dict
    print("\nProcessing checkpoint...")
    new_state, mapped_keys, skipped_keys = process_state_dict(state, weight_mapping)

    transferred = len(mapped_keys)
    print(f"Mapped keys: {transferred} ({transferred / original_tensors * 100:.1f}%)")
    print(f"Skipped keys: {len(skipped_keys)}")

    if verbose and skipped_keys:
        print("\nSkipped layers:")
        for key, reason in skipped_keys[:5]:  # Show first 5
            print(f"  - {key:50s} ({reason})")
        if len(skipped_keys) > 5:
            print(f"  ... and {len(skipped_keys) - 5} more")

    # Create model
    print("\nCreating Labram model:")
    print(f"  n_times: {n_times}")
    print(f"  n_chans: {n_chans}")
    print(f"  n_outputs: {n_outputs}")
    model = Labram(
        n_times=n_times,
        n_chans=n_chans,
        n_outputs=n_outputs,
        neural_tokenizer=True,
    )

    # Load weights
    print("\nLoading weights into model...")
    incompatible = model.load_state_dict(new_state, strict=False)

    missing_count = len(incompatible.missing_keys)
    unexpected_count = len(incompatible.unexpected_keys)

    if missing_count > 0:
        print(f"  Missing keys: {missing_count} (expected - will be initialized)")
    if unexpected_count > 0:
        print(f"  Unexpected keys: {unexpected_count}")

    # Test forward pass
    if verbose:
        print("\nTesting forward pass...")
        model = model.to(device)
        x = torch.randn(2, n_chans, n_times, device=device)
        with torch.no_grad():
            output = model(x)
        model = model.to('cpu')  # Move back so saved/returned weights stay on CPU
        print(f"  Input shape: {x.shape}")
        print(f"  Output shape: {output.shape}")
        print("  ✅ Forward pass successful!")

    # Save model if output_path provided
    if output_path:
        print(f"\nSaving model to: {output_path}")
        torch.save(model.state_dict(), output_path)
        print("  ✅ Model saved")

    stats = {
        'original': original_tensors,
        'transferred': transferred,
        'skipped': len(skipped_keys),
        'transfer_rate': f"{transferred / original_tensors * 100:.1f}%",
    }

    return model, stats


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Transfer LaBraM weights to Braindecode Labram',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Default transfer (backbone parameters)
  python labram_complete_transfer.py

  # Transfer and save model
  python labram_complete_transfer.py --output labram_weights.pt

  # Custom EEG parameters
  python labram_complete_transfer.py --n-times 2000 --n-chans 62 --n-outputs 2

  # Custom checkpoint path
  python labram_complete_transfer.py --checkpoint path/to/checkpoint.pth
"""
    )

    parser.add_argument(
        '--checkpoint',
        type=str,
        default='LaBraM/checkpoints/labram-base.pth',
        help='Path to LaBraM checkpoint (default: LaBraM/checkpoints/labram-base.pth)'
    )
    parser.add_argument(
        '--n-times',
        type=int,
        default=1600,
        help='Number of time samples (default: 1600)'
    )
    parser.add_argument(
        '--n-chans',
        type=int,
        default=64,
        help='Number of channels (default: 64)'
    )
    parser.add_argument(
        '--n-outputs',
        type=int,
        default=4,
        help='Number of output classes (default: 4)'
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='Output file path to save model weights'
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cpu',
        help='Device for the forward-pass test (default: cpu)'
    )

    args = parser.parse_args()

    # Transfer weights (the function prints its own banner)
    model, stats = transfer_labram_weights(
        checkpoint_path=args.checkpoint,
        n_times=args.n_times,
        n_chans=args.n_chans,
        n_outputs=args.n_outputs,
        output_path=args.output,
        device=args.device,
        verbose=True,
    )

    print("\n" + "=" * 70)
    print("✅ TRANSFER COMPLETE")
    print("=" * 70)
    print(f"Original tensors: {stats['original']}")
    print(f"Transferred: {stats['transferred']} ({stats['transfer_rate']})")
    print(f"Skipped: {stats['skipped']}")
    print("=" * 70)
```
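
After running the script with `--output labram_weights.pt`, the converted weights can be loaded back into a fresh model. A minimal loading sketch, assuming the script's default hyperparameters and that hypothetical output filename:

```python
import torch
from braindecode.models import Labram

# Hyperparameters must match those used during conversion (script defaults shown).
model = Labram(n_times=1600, n_chans=64, n_outputs=4, neural_tokenizer=True)

# "labram_weights.pt" is the file written via the script's --output flag.
state_dict = torch.load("labram_weights.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```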