bruAristimunha commited on
Commit
85377dd
·
verified ·
1 Parent(s): 8e5954e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +302 -1
README.md CHANGED
@@ -2,4 +2,305 @@
2
  license: mit
3
  tags:
4
  - braindecode
5
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  license: mit
3
  tags:
4
  - braindecode
5
+ ---
6
+
7
+ Shallow conversion from the original weight for braindecode.
8
+
9
+ ```python
10
+
11
+ #!/usr/bin/env python3
12
+ """
13
+ Complete LaBraM Weight Transfer Script
14
+
15
+ Combines explicit weight mapping with full backbone transfer.
16
+ Uses precise key renaming to transfer all compatible parameters.
17
+
18
+ Transfers weights from LaBraM checkpoint to Braindecode Labram model.
19
+ """
20
+
21
+ import torch
22
+ import argparse
23
+ from braindecode.models import Labram
24
+
25
+
26
+ def create_weight_mapping():
27
+ """
28
+ Create comprehensive weight mapping from LaBraM to Braindecode.
29
+
30
+ Includes:
31
+ - Temporal convolution layers (patch_embed)
32
+ - All transformer blocks
33
+ - Position embeddings
34
+ - Other backbone components
35
+ """
36
+ return {
37
+ # Temporal Convolution Layers
38
+ 'student.patch_embed.conv1.weight': 'patch_embed.temporal_conv.conv1.weight',
39
+ 'student.patch_embed.conv1.bias': 'patch_embed.temporal_conv.conv1.bias',
40
+ 'student.patch_embed.norm1.weight': 'patch_embed.temporal_conv.norm1.weight',
41
+ 'student.patch_embed.norm1.bias': 'patch_embed.temporal_conv.norm1.bias',
42
+ 'student.patch_embed.conv2.weight': 'patch_embed.temporal_conv.conv2.weight',
43
+ 'student.patch_embed.conv2.bias': 'patch_embed.temporal_conv.conv2.bias',
44
+ 'student.patch_embed.norm2.weight': 'patch_embed.temporal_conv.norm2.weight',
45
+ 'student.patch_embed.norm2.bias': 'patch_embed.temporal_conv.norm2.bias',
46
+ 'student.patch_embed.conv3.weight': 'patch_embed.temporal_conv.conv3.weight',
47
+ 'student.patch_embed.conv3.bias': 'patch_embed.temporal_conv.conv3.bias',
48
+ 'student.patch_embed.norm3.weight': 'patch_embed.temporal_conv.norm3.weight',
49
+ 'student.patch_embed.norm3.bias': 'patch_embed.temporal_conv.norm3.bias',
50
+ # Note: Other backbone layers (blocks, embeddings, norm, fc_norm) are handled
51
+ # by removing 'student.' prefix in process_state_dict()
52
+ }
53
+
54
+
55
+ def process_state_dict(state_dict, weight_mapping):
56
+ """
57
+ Process checkpoint state dict with explicit mapping.
58
+
59
+ Parameters:
60
+ -----------
61
+ state_dict : dict
62
+ Original checkpoint state dictionary
63
+ weight_mapping : dict
64
+ Explicit mapping for special layers (patch_embed)
65
+
66
+ Returns:
67
+ --------
68
+ dict : Processed state dict ready for Braindecode model
69
+ """
70
+ new_state = {}
71
+ mapped_keys = []
72
+ skipped_keys = []
73
+
74
+ for key, value in state_dict.items():
75
+ # Skip classification head (task-specific)
76
+ if 'head' in key:
77
+ skipped_keys.append((key, 'head layer'))
78
+ continue
79
+
80
+ # Use explicit mapping for patch_embed temporal_conv
81
+ if key in weight_mapping:
82
+ new_key = weight_mapping[key]
83
+ new_state[new_key] = value
84
+ mapped_keys.append((key, new_key))
85
+ continue
86
+
87
+ # Skip original patch_embed if not in mapping (SegmentPatch)
88
+ if 'patch_embed' in key and 'temporal_conv' not in key:
89
+ skipped_keys.append((key, 'patch_embed (non-temporal)'))
90
+ continue
91
+
92
+ # For backbone layers, remove 'student.' prefix
93
+ if key.startswith('student.'):
94
+ new_key = key.replace('student.', '')
95
+ new_state[new_key] = value
96
+ mapped_keys.append((key, new_key))
97
+ continue
98
+
99
+ # Keep other keys as-is
100
+ new_state[key] = value
101
+ mapped_keys.append((key, key))
102
+
103
+ return new_state, mapped_keys, skipped_keys
104
+
105
+
106
+ def transfer_labram_weights(
107
+ checkpoint_path,
108
+ n_times=1600,
109
+ n_chans=64,
110
+ n_outputs=4,
111
+ output_path=None,
112
+ verbose=True
113
+ ):
114
+ """
115
+ Transfer LaBraM weights to Braindecode Labram using explicit mapping.
116
+
117
+ Parameters:
118
+ -----------
119
+ checkpoint_path : str
120
+ Path to LaBraM checkpoint
121
+ n_times : int
122
+ Number of time samples
123
+ n_chans : int
124
+ Number of channels
125
+ n_outputs : int
126
+ Number of output classes
127
+ output_path : str
128
+ Where to save the model
129
+ verbose : bool
130
+ Print transfer details
131
+
132
+ Returns:
133
+ --------
134
+ model : Labram
135
+ Model with transferred weights
136
+ stats : dict
137
+ Transfer statistics
138
+ """
139
+
140
+ print("\n" + "="*70)
141
+ print("LaBraM → Braindecode Weight Transfer")
142
+ print("="*70)
143
+
144
+ # Load checkpoint
145
+ print(f"\nLoading checkpoint: {checkpoint_path}")
146
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
147
+
148
+ # Extract model state
149
+ if isinstance(checkpoint, dict) and 'model' in checkpoint:
150
+ state = checkpoint['model']
151
+ else:
152
+ state = checkpoint
153
+
154
+ original_params = len(state)
155
+ print(f"Original checkpoint: {original_params} parameters")
156
+
157
+ # Create weight mapping
158
+ weight_mapping = create_weight_mapping()
159
+
160
+ # Process state dict
161
+ print("\nProcessing checkpoint...")
162
+ new_state, mapped_keys, skipped_keys = process_state_dict(state, weight_mapping)
163
+
164
+ transferred_params = len(mapped_keys)
165
+ print(f"Mapped keys: {transferred_params} ({transferred_params/original_params*100:.1f}%)")
166
+ print(f"Skipped keys: {len(skipped_keys)}")
167
+
168
+ if verbose and skipped_keys:
169
+ print(f"\nSkipped layers:")
170
+ for key, reason in skipped_keys[:5]: # Show first 5
171
+ print(f" - {key:50s} ({reason})")
172
+ if len(skipped_keys) > 5:
173
+ print(f" ... and {len(skipped_keys) - 5} more")
174
+
175
+ # Create model
176
+ print(f"\nCreating Labram model:")
177
+ print(f" n_times: {n_times}")
178
+ print(f" n_chans: {n_chans}")
179
+ print(f" n_outputs: {n_outputs}")
180
+ model = Labram(
181
+ n_times=n_times,
182
+ n_chans=n_chans,
183
+ n_outputs=n_outputs,
184
+ neural_tokenizer=True,
185
+ )
186
+
187
+ # Load weights
188
+ print("\nLoading weights into model...")
189
+ incompatible = model.load_state_dict(new_state, strict=False)
190
+
191
+ missing_count = len(incompatible.missing_keys) if incompatible.missing_keys else 0
192
+ unexpected_count = len(incompatible.unexpected_keys) if incompatible.unexpected_keys else 0
193
+
194
+ if missing_count > 0:
195
+ print(f" Missing keys: {missing_count} (expected - will be initialized)")
196
+ if unexpected_count > 0:
197
+ print(f" Unexpected keys: {unexpected_count}")
198
+
199
+ # Test forward pass
200
+ if verbose:
201
+ print("\nTesting forward pass...")
202
+ x = torch.randn(2, n_chans, n_times)
203
+ with torch.no_grad():
204
+ output = model(x)
205
+ print(f" Input shape: {x.shape}")
206
+ print(f" Output shape: {output.shape}")
207
+ print(" ✅ Forward pass successful!")
208
+
209
+ # Save model if output_path provided
210
+ if output_path:
211
+ print(f"\nSaving model to: {output_path}")
212
+ torch.save(model.state_dict(), output_path)
213
+ print(f" ✅ Model saved")
214
+
215
+ stats = {
216
+ 'original': original_params,
217
+ 'transferred': transferred_params,
218
+ 'skipped': len(skipped_keys),
219
+ 'transfer_rate': f"{transferred_params/original_params*100:.1f}%"
220
+ }
221
+
222
+ return model, stats
223
+
224
+
225
+ if __name__ == '__main__':
226
+ parser = argparse.ArgumentParser(
227
+ description='Transfer LaBraM weights to Braindecode Labram',
228
+ formatter_class=argparse.RawDescriptionHelpFormatter,
229
+ epilog="""
230
+ Examples:
231
+ # Default transfer (backbone parameters)
232
+ python labram_complete_transfer.py
233
+
234
+ # Transfer and save model
235
+ python labram_complete_transfer.py --output labram_weights.pt
236
+
237
+ # Custom EEG parameters
238
+ python labram_complete_transfer.py --n-times 2000 --n-chans 62 --n-outputs 2
239
+
240
+ # Custom checkpoint path
241
+ python labram_complete_transfer.py --checkpoint path/to/checkpoint.pth
242
+ """
243
+ )
244
+
245
+ parser.add_argument(
246
+ '--checkpoint',
247
+ type=str,
248
+ default='LaBraM/checkpoints/labram-base.pth',
249
+ help='Path to LaBraM checkpoint (default: LaBraM/checkpoints/labram-base.pth)'
250
+ )
251
+ parser.add_argument(
252
+ '--n-times',
253
+ type=int,
254
+ default=1600,
255
+ help='Number of time samples (default: 1600)'
256
+ )
257
+ parser.add_argument(
258
+ '--n-chans',
259
+ type=int,
260
+ default=64,
261
+ help='Number of channels (default: 64)'
262
+ )
263
+ parser.add_argument(
264
+ '--n-outputs',
265
+ type=int,
266
+ default=4,
267
+ help='Number of output classes (default: 4)'
268
+ )
269
+ parser.add_argument(
270
+ '--output',
271
+ type=str,
272
+ default=None,
273
+ help='Output file path to save model weights'
274
+ )
275
+ parser.add_argument(
276
+ '--device',
277
+ type=str,
278
+ default='cpu',
279
+ help='Device to use (default: cpu)'
280
+ )
281
+
282
+ args = parser.parse_args()
283
+
284
+ print("="*70)
285
+ print("LaBraM → Braindecode Weight Transfer")
286
+ print("="*70)
287
+
288
+ # Transfer weights
289
+ model, stats = transfer_labram_weights(
290
+ checkpoint_path=args.checkpoint,
291
+ n_times=args.n_times,
292
+ n_chans=args.n_chans,
293
+ n_outputs=args.n_outputs,
294
+ output_path=args.output,
295
+ verbose=True
296
+ )
297
+
298
+ print("\n" + "="*70)
299
+ print("✅ TRANSFER COMPLETE")
300
+ print("="*70)
301
+ print(f"Original parameters: {stats['original']}")
302
+ print(f"Transferred: {stats['transferred']} ({stats['transfer_rate']})")
303
+ print(f"Skipped: {stats['skipped']}")
304
+ print("="*70)
305
+
306
+ ```