File size: 9,348 Bytes
8e5954e
 
 
 
85377dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
---
license: mit
tags:
- braindecode
---

Shallow conversion of the original LaBraM weights for use with braindecode.

```python

#!/usr/bin/env python3
"""
Complete LaBraM Weight Transfer Script

Combines explicit weight mapping with full backbone transfer.
Uses precise key renaming to transfer all compatible parameters.

Transfers weights from LaBraM checkpoint to Braindecode Labram model.
"""

import torch
import argparse
from braindecode.models import Labram


def create_weight_mapping():
    """
    Create the explicit LaBraM -> Braindecode key mapping for the temporal conv stem.

    Only the three conv/norm stages of ``student.patch_embed`` are renamed
    explicitly here; Braindecode nests them under ``patch_embed.temporal_conv``.
    No other backbone parameters (transformer blocks, embeddings, ``norm``,
    ``fc_norm``) are listed: ``process_state_dict()`` renames those by
    stripping the leading ``student.`` prefix.

    Returns
    -------
    dict
        Mapping from checkpoint key to Braindecode key. It has 12 entries:
        ``conv{1,2,3}`` and ``norm{1,2,3}``, each with ``weight`` and ``bias``.
    """
    return {
        # Temporal convolution stem: student.patch_embed.* -> patch_embed.temporal_conv.*
        'student.patch_embed.conv1.weight': 'patch_embed.temporal_conv.conv1.weight',
        'student.patch_embed.conv1.bias': 'patch_embed.temporal_conv.conv1.bias',
        'student.patch_embed.norm1.weight': 'patch_embed.temporal_conv.norm1.weight',
        'student.patch_embed.norm1.bias': 'patch_embed.temporal_conv.norm1.bias',
        'student.patch_embed.conv2.weight': 'patch_embed.temporal_conv.conv2.weight',
        'student.patch_embed.conv2.bias': 'patch_embed.temporal_conv.conv2.bias',
        'student.patch_embed.norm2.weight': 'patch_embed.temporal_conv.norm2.weight',
        'student.patch_embed.norm2.bias': 'patch_embed.temporal_conv.norm2.bias',
        'student.patch_embed.conv3.weight': 'patch_embed.temporal_conv.conv3.weight',
        'student.patch_embed.conv3.bias': 'patch_embed.temporal_conv.conv3.bias',
        'student.patch_embed.norm3.weight': 'patch_embed.temporal_conv.norm3.weight',
        'student.patch_embed.norm3.bias': 'patch_embed.temporal_conv.norm3.bias',
    }


def process_state_dict(state_dict, weight_mapping):
    """
    Rename checkpoint keys to Braindecode ``Labram`` parameter names.

    Each key is handled by the first rule that matches:

    1. any key containing the substring ``'head'`` is dropped;
    2. keys present in ``weight_mapping`` are renamed explicitly;
    3. other keys containing ``'patch_embed'`` but not ``'temporal_conv'``
       are dropped;
    4. keys starting with ``'student.'`` lose that leading prefix;
    5. everything else is copied unchanged.

    Parameters
    ----------
    state_dict : dict
        Original checkpoint state dictionary (key -> value).
    weight_mapping : dict
        Explicit old-key -> new-key renames, e.g. from ``create_weight_mapping()``.

    Returns
    -------
    new_state : dict
        Renamed state dict, ready for ``load_state_dict``.
    mapped_keys : list of (str, str)
        ``(original_key, new_key)`` for every kept entry, in input order.
    skipped_keys : list of (str, str)
        ``(original_key, reason)`` for every dropped entry, in input order.
    """
    prefix = 'student.'
    new_state = {}
    mapped_keys = []
    skipped_keys = []

    for key, value in state_dict.items():
        # Classification head is task-specific. This is a substring test, so
        # any key containing 'head' anywhere (e.g. '...lm_head...') is dropped.
        if 'head' in key:
            skipped_keys.append((key, 'head layer'))
            continue

        if key in weight_mapping:
            new_key = weight_mapping[key]
        elif 'patch_embed' in key and 'temporal_conv' not in key:
            # patch_embed parameters not covered by the explicit mapping
            # (original comment: SegmentPatch) are not transferred.
            skipped_keys.append((key, 'patch_embed (non-temporal)'))
            continue
        elif key.startswith(prefix):
            # Strip only the leading prefix; str.replace would also delete any
            # later 'student.' occurrence inside the key.
            new_key = key[len(prefix):]
        else:
            new_key = key

        new_state[new_key] = value
        mapped_keys.append((key, new_key))

    return new_state, mapped_keys, skipped_keys


def transfer_labram_weights(
    checkpoint_path,
    n_times=1600,
    n_chans=64,
    n_outputs=4,
    output_path=None,
    verbose=True,
    device='cpu',
):
    """
    Transfer LaBraM weights to Braindecode Labram using explicit mapping.

    Loads the checkpoint, renames its keys via ``create_weight_mapping()`` and
    ``process_state_dict()``, builds ``Labram(..., neural_tokenizer=True)`` and
    loads the renamed weights with ``strict=False``. Parameters the checkpoint
    does not provide keep the model's own initialization.

    Parameters
    ----------
    checkpoint_path : str
        Path to LaBraM checkpoint. If the loaded object is a dict with a
        ``'model'`` entry, that entry is used as the state dict.
    n_times : int
        Number of time samples
    n_chans : int
        Number of channels
    n_outputs : int
        Number of output classes
    output_path : str or None
        Where to save the model's ``state_dict``; nothing is saved if falsy.
    verbose : bool
        Print the first skipped keys and run a forward-pass smoke test.
    device : str or torch.device
        Device used as ``map_location`` for the checkpoint and for the model
        and smoke-test input (default: ``'cpu'``).

    Returns
    -------
    model : Labram
        Model with transferred weights
    stats : dict
        Transfer statistics: ``original``, ``transferred``, ``skipped`` (counts)
        and ``transfer_rate`` (formatted percentage string).
    """

    print("\n" + "=" * 70)
    print("LaBraM → Braindecode Weight Transfer")
    print("=" * 70)

    # Load checkpoint.
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from trusted sources. Presumably required because
    # the checkpoint stores non-tensor objects — confirm before switching to True.
    print(f"\nLoading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Extract model state
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state = checkpoint['model']
    else:
        state = checkpoint

    original_params = len(state)
    print(f"Original checkpoint: {original_params} parameters")

    # Rename checkpoint keys to Braindecode names
    print("\nProcessing checkpoint...")
    weight_mapping = create_weight_mapping()
    new_state, mapped_keys, skipped_keys = process_state_dict(state, weight_mapping)

    transferred_params = len(mapped_keys)
    # Guard against an empty checkpoint (would otherwise divide by zero).
    transfer_rate = transferred_params / original_params * 100 if original_params else 0.0
    print(f"Mapped keys: {transferred_params} ({transfer_rate:.1f}%)")
    print(f"Skipped keys: {len(skipped_keys)}")

    if verbose and skipped_keys:
        print("\nSkipped layers:")
        for key, reason in skipped_keys[:5]:  # Show first 5
            print(f"  - {key:50s} ({reason})")
        if len(skipped_keys) > 5:
            print(f"  ... and {len(skipped_keys) - 5} more")

    # Create model
    print("\nCreating Labram model:")
    print(f"  n_times: {n_times}")
    print(f"  n_chans: {n_chans}")
    print(f"  n_outputs: {n_outputs}")
    model = Labram(
        n_times=n_times,
        n_chans=n_chans,
        n_outputs=n_outputs,
        neural_tokenizer=True,
    ).to(device)

    # Load weights non-strictly: the dropped head and any key not present in
    # the checkpoint are reported instead of raising.
    print("\nLoading weights into model...")
    incompatible = model.load_state_dict(new_state, strict=False)

    missing_count = len(incompatible.missing_keys)
    unexpected_count = len(incompatible.unexpected_keys)

    if missing_count > 0:
        print(f"  Missing keys: {missing_count} (expected - will be initialized)")
    if unexpected_count > 0:
        print(f"  Unexpected keys: {unexpected_count}")

    # Smoke-test a forward pass on random input of the configured shape
    if verbose:
        print("\nTesting forward pass...")
        x = torch.randn(2, n_chans, n_times, device=device)
        with torch.no_grad():
            output = model(x)
        print(f"  Input shape:  {x.shape}")
        print(f"  Output shape: {output.shape}")
        print("  ✅ Forward pass successful!")

    # Save model if output_path provided
    if output_path:
        print(f"\nSaving model to: {output_path}")
        torch.save(model.state_dict(), output_path)
        print("  ✅ Model saved")

    stats = {
        'original': original_params,
        'transferred': transferred_params,
        'skipped': len(skipped_keys),
        'transfer_rate': f"{transfer_rate:.1f}%"
    }

    return model, stats


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Transfer LaBraM weights to Braindecode Labram',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Default transfer (backbone parameters)
  python labram_complete_transfer.py
  
  # Transfer and save model
  python labram_complete_transfer.py --output labram_weights.pt
  
  # Custom EEG parameters
  python labram_complete_transfer.py --n-times 2000 --n-chans 62 --n-outputs 2
  
  # Custom checkpoint path
  python labram_complete_transfer.py --checkpoint path/to/checkpoint.pth
        """
    )

    parser.add_argument(
        '--checkpoint',
        type=str,
        default='LaBraM/checkpoints/labram-base.pth',
        help='Path to LaBraM checkpoint (default: LaBraM/checkpoints/labram-base.pth)'
    )
    parser.add_argument(
        '--n-times',
        type=int,
        default=1600,
        help='Number of time samples (default: 1600)'
    )
    parser.add_argument(
        '--n-chans',
        type=int,
        default=64,
        help='Number of channels (default: 64)'
    )
    parser.add_argument(
        '--n-outputs',
        type=int,
        default=4,
        help='Number of output classes (default: 4)'
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='Output file path to save model weights'
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cpu',
        help='Device to use (default: cpu)'
    )

    args = parser.parse_args()

    # transfer_labram_weights() prints its own header banner, so none here.
    model, stats = transfer_labram_weights(
        checkpoint_path=args.checkpoint,
        n_times=args.n_times,
        n_chans=args.n_chans,
        n_outputs=args.n_outputs,
        output_path=args.output,
        verbose=True,
        device=args.device,
    )

    print("\n" + "=" * 70)
    print("✅ TRANSFER COMPLETE")
    print("=" * 70)
    print(f"Original parameters:   {stats['original']}")
    print(f"Transferred:           {stats['transferred']} ({stats['transfer_rate']})")
    print(f"Skipped:               {stats['skipped']}")
    print("=" * 70)

```