wi-lab committed on
Commit
9f9c8ef
·
verified ·
1 Parent(s): 8df1cab

Upload pretraining/train_lwm_spectro.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pretraining/train_lwm_spectro.py +741 -0
pretraining/train_lwm_spectro.py ADDED
@@ -0,0 +1,741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # =============================================================================
3
+ # train_lwm_spectro.py - LWM Pretraining with Complex-Valued Spectrogram Support
4
+ # Modified from train_lwm_spectro_no_contrast.py to handle complex spectrograms
5
+ # by separating real and imaginary parts and flattening them (similar to train_lwm.py)
6
+ # =============================================================================
7
+
8
+ # =============================================================================
9
+ # 1. IMPORTS AND WARNINGS SETUP
10
+ # - Load necessary PyTorch modules, utilities, and suppress UserWarnings
11
+ # =============================================================================
12
+ import sys
13
+ import os
14
+ import argparse
15
+ # Add project root to path (Windows compatible)
16
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
+ sys.path.insert(0, project_root)
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.utils.data import DataLoader, random_split, TensorDataset
22
+ import torch.optim as optim
23
+ from utils import (generate_spectrograms_and_labels, tokenizer_train,
24
+ create_train_dataloader, count_parameters, train_lwm)
25
+ import numpy as np
26
+ import pretrained_model # Assuming this contains the LWM model definition
27
+ from torch.optim.lr_scheduler import LambdaLR
28
+ from torch.optim import AdamW
29
+ import warnings
30
+ import platform
31
+ import re
32
+ from tqdm import tqdm
33
+ from datetime import datetime
34
+ import concurrent.futures
35
+ import multiprocessing
36
+ from collections import Counter
37
+ from functools import lru_cache
38
+ import json
39
+
40
+ SNR_PATTERN = re.compile(r"SNR(-?\d+)dB")
41
+ DOPPLER_MAP = {"static": 0, "pedestrian": 1, "vehicular": 2}
42
+ DOPPLER_INV = {v: k for k, v in DOPPLER_MAP.items()}
43
+
44
+
45
+ def _parse_snr_and_doppler(path: str) -> tuple[float, int]:
46
+ snr_db = 0.0
47
+ doppler_id = 0
48
+
49
+ matches = SNR_PATTERN.findall(path)
50
+ if matches:
51
+ try:
52
+ snr_db = float(matches[-1])
53
+ except ValueError:
54
+ snr_db = 0.0
55
+
56
+ normalized_path = os.path.normpath(path)
57
+ parts = normalized_path.split(os.sep)
58
+ for part in parts:
59
+ if part in DOPPLER_MAP:
60
+ doppler_id = DOPPLER_MAP[part]
61
+ break
62
+
63
+ return snr_db, doppler_id
64
+
65
warnings.filterwarnings("ignore", category=UserWarning)

# Use simple progress display instead of tqdm on Windows
USE_TQDM = platform.system() != 'Windows'

# Compute the CPU core count (set conservatively, with memory usage in mind)
total_cores = multiprocessing.cpu_count()
if total_cores >= 16:
    MAX_WORKERS = min(8, total_cores // 2)  # high-end server: cap at 8 cores
else:
    MAX_WORKERS = max(2, total_cores // 2)  # ordinary system: use half the cores
print(f"🚀 Using {MAX_WORKERS}/{total_cores} CPU cores for parallel processing")
77
+
78
# Opt-in diagnostics: set LWM_PRINT_CONVERSION_STATS to 1/true/yes to print shapes and ranges.
PRINT_CONVERSION_STATS = os.environ.get("LWM_PRINT_CONVERSION_STATS", "").strip().lower() in {"1", "true", "yes"}


def convert_complex_to_interleaved(spectrograms):
    """
    Convert complex-valued spectrograms to real-imaginary interleaved format.

    Real and imaginary parts are interleaved along the last axis: even indices
    hold real parts, odd indices hold imaginary parts.

    Args:
        spectrograms (np.ndarray): Array of shape (n_samples, n_rows, n_cols)
            or (n_samples, C, n_rows, n_cols). For 4-D input only channel 0
            is kept.

    Returns:
        np.ndarray: For complex input, a new float32 array of shape
            (n_samples, n_rows, n_cols * 2). For real-valued input, the
            (channel-squeezed) input itself, unchanged.

    Raises:
        ValueError: If the array is not 3-D after the optional channel removal.
    """
    # Handle different input shapes
    if spectrograms.ndim == 4:
        # Remove channel dimension if present: (n_samples, 1, n_rows, n_cols) -> (n_samples, n_rows, n_cols)
        spectrograms = spectrograms[:, 0, :, :]

    # Validate the rank once, so complex and real inputs fail the same way.
    if spectrograms.ndim != 3:
        raise ValueError(f"Unexpected spectrogram shape: {spectrograms.shape}")

    if not np.iscomplexobj(spectrograms):
        # Already real-valued: nothing to convert
        if PRINT_CONVERSION_STATS:
            print(f" ℹ️ Data is already real-valued: {spectrograms.shape}")
        return spectrograms

    n_samples, n_rows, n_cols = spectrograms.shape
    real_part = spectrograms.real
    imag_part = spectrograms.imag

    # Output shape: (n_samples, n_rows, n_cols * 2)
    interleaved = np.empty((n_samples, n_rows, n_cols * 2), dtype=np.float32)
    interleaved[:, :, 0::2] = real_part  # Even indices: real parts
    interleaved[:, :, 1::2] = imag_part  # Odd indices: imaginary parts

    if PRINT_CONVERSION_STATS:
        print(f" ℹ️ Converted complex spectrograms: {spectrograms.shape} -> {interleaved.shape}")
        # min()/max() raise on empty arrays, so only report ranges when there is data
        if spectrograms.size:
            print(f" Real part range: [{real_part.min():.2e}, {real_part.max():.2e}]")
            print(f" Imag part range: [{imag_part.min():.2e}, {imag_part.max():.2e}]")

    return interleaved
130
+
131
+
132
def process_single_scenario(scenario_info):
    """Load one scenario and split it 80/20 into train/validation arrays.

    Args:
        scenario_info: ``(scenario_name, spectrogram_path)`` pair.

    Returns:
        dict | None: On success a dict with keys ``scenario``, ``train_data``,
        ``val_data`` (float32 arrays), ``train_meta``, ``val_meta`` (dicts with
        per-sample ``snr_db`` and ``doppler_id`` arrays), ``train_size`` and
        ``val_size``. ``None`` if nothing was loaded or any error occurred;
        errors are printed with a traceback rather than raised.
    """
    scenario_name, spectrogram_path = scenario_info

    try:
        # Labels returned by the loader are not used here.
        scenario_spectrograms, _ = generate_spectrograms_and_labels(
            scenario_name=scenario_name,
            spectrogram_path=spectrogram_path,
            cache_path=None,  # cache disabled because of memory problems
        )

        # Validate load
        if scenario_spectrograms is None or (
            hasattr(scenario_spectrograms, 'size') and scenario_spectrograms.size == 0
        ):
            print(f" ⚠️ No data loaded from: {spectrogram_path}")
            return None

        # Convert complex spectrograms to interleaved real-imaginary format
        scenario_spectrograms = convert_complex_to_interleaved(scenario_spectrograms)

        snr_db, doppler_id = _parse_snr_and_doppler(spectrogram_path)

        # Data split: the first 80% of samples go to training, the rest to validation
        total_samples = len(scenario_spectrograms)
        train_size = int(0.8 * total_samples)

        # Keep the data as numpy arrays to save memory (converted to tensors only when needed)
        train_data = np.array(scenario_spectrograms[:train_size], dtype=np.float32)
        val_data = np.array(scenario_spectrograms[train_size:], dtype=np.float32)

        # Every sample of a scenario shares the SNR / Doppler parsed from its path
        snr_array = np.full(total_samples, snr_db, dtype=np.float32)
        doppler_array = np.full(total_samples, doppler_id, dtype=np.int64)
        train_meta = {
            'snr_db': snr_array[:train_size],
            'doppler_id': doppler_array[:train_size],
        }
        val_meta = {
            'snr_db': snr_array[train_size:],
            'doppler_id': doppler_array[train_size:],
        }

        # Drop the reference to the full array as soon as it is no longer needed
        del scenario_spectrograms

        return {
            'scenario': scenario_name,
            'train_data': train_data,
            'val_data': val_data,
            'train_meta': train_meta,
            'val_meta': val_meta,
            'train_size': len(train_data),
            'val_size': len(val_data)
        }
    except Exception as e:
        # Deliberate best-effort: one bad scenario must not abort the whole run
        print(f"❌ Error processing scenario {scenario_name}: {e}")
        import traceback
        traceback.print_exc()
        return None
191
+
192
+ # GPU Memory Monitor import (for Lambda) - Removed
193
+
194
+ # =============================================================================
195
+ # 2. SCENARIO LIST DEFINITION
196
+ # - Define the list of scenario names to iterate over for data generation
197
+ # =============================================================================
198
+
199
+ # Supported communications; can be limited via CLI
200
+ SUPPORTED_COMM_TYPES = {"LTE", "WiFi", "5G"}
201
+
202
+
203
+ def _parse_standard_args():
204
+ parser = argparse.ArgumentParser(add_help=False)
205
+ parser.add_argument('--standards', nargs='+', choices=SUPPORTED_COMM_TYPES,
206
+ help='Specify one or more communication types to include (default: all).')
207
+ for comm in SUPPORTED_COMM_TYPES:
208
+ parser.add_argument(f'--{comm}', dest=f'flag_{comm}', action='store_true',
209
+ help=f'Include only {comm} data (can be combined).')
210
+ parser.add_argument('--city', '--cities', dest='cities', nargs='+',
211
+ help='Limit scenarios to one or more city prefixes (e.g., "0" or "city_0").')
212
+ parser.add_argument(
213
+ '--normalization',
214
+ choices=('per_sample', 'dataset'),
215
+ default='per_sample',
216
+ help='Normalization mode applied during tokenization (default: %(default)s).'
217
+ )
218
+ parser.add_argument('--help', action='help')
219
+
220
+ args, remaining = parser.parse_known_args()
221
+
222
+ enabled = set(SUPPORTED_COMM_TYPES)
223
+ if args.standards:
224
+ enabled = set(args.standards)
225
+ else:
226
+ flagged = {comm for comm in SUPPORTED_COMM_TYPES if getattr(args, f'flag_{comm}', False)}
227
+ if flagged:
228
+ enabled = flagged
229
+
230
+ selected_cities: list[str] | None = None
231
+ if args.cities:
232
+ selected_cities = []
233
+ for city_token in args.cities:
234
+ token = str(city_token).strip()
235
+ if not token:
236
+ continue
237
+ if token.startswith('city_'):
238
+ selected_cities.append(token)
239
+ else:
240
+ selected_cities.append(f'city_{token}')
241
+ if not selected_cities:
242
+ selected_cities = None
243
+
244
+ # Return remaining args to allow downstream parsing if needed
245
+ sys.argv = [sys.argv[0]] + remaining
246
+ return enabled, selected_cities, args.normalization
247
+
248
+
249
ENABLED_COMM_TYPES, ENABLED_CITY_PREFIXES, NORMALIZATION_MODE = _parse_standard_args()
# Optional cap on the number of scenario files, read from LWM_MAX_SCENARIOS.
# Unset, 0 or negative means "no cap"; a negative value must not reach the
# later list slice, where it would silently drop entries from the end.
MAX_SCENARIOS = max(int(os.environ.get("LWM_MAX_SCENARIOS", "0")), 0) or None
251
+
252
+
253
def _extract_scenario_token(file_path):
    """Derive the base scenario token (without city) from the file path.

    The token is built from the first path component that names a supported
    communication type plus up to four components that follow it, joined with
    underscores. If no component matches, a filename of the form
    ``spectrogram_<COMM>_...`` is used instead (first five underscore-separated
    tokens after ``spectrogram``).

    Args:
        file_path: Path to a spectrogram file.

    Returns:
        str | None: The scenario token, or ``None`` if no communication type
        can be identified.
    """
    parts = os.path.normpath(file_path).split(os.sep)

    for i, part in enumerate(parts):
        if part in SUPPORTED_COMM_TYPES:
            # The slice always contains at least the matching component.
            return '_'.join(parts[i:i + 5])

    # Fallback for datasets where the communication type is only captured in the filename
    base_name = os.path.splitext(os.path.basename(file_path))[0]
    if base_name.startswith('spectrogram_'):
        tokens = base_name.split('_')[1:]  # drop 'spectrogram'
        if tokens and tokens[0] in SUPPORTED_COMM_TYPES:
            return '_'.join(tokens[:5])

    return None
275
+
276
+
277
def _build_scenario_entry(file_path, city_name):
    """Return ``(scenario_id, file_path, city_name, base_token)`` or ``None`` if the file is filtered out."""
    base_token = _extract_scenario_token(file_path)
    if not base_token:
        return None
    comm_type = base_token.split('_', 1)[0]
    if comm_type not in ENABLED_COMM_TYPES:
        return None
    return (f"{city_name}::{base_token}", file_path, city_name, base_token)


@lru_cache(maxsize=1)
def _collect_scenario_file_info():
    """Discover spectrogram files on disk and describe each as a scenario entry.

    Two layouts are scanned, relative to the current working directory:

    1. ``ls_data/MATLAB/receiver_pipeline/<group>/**/spectrogram_*.mat`` - the
       directory directly under ``receiver_pipeline`` is used as the grouping
       ("city") name. The city-prefix filter is not applied to this layout.
    2. ``spectrograms/city_*/**/spectrogram_*.mat`` - filtered by
       ``ENABLED_CITY_PREFIXES`` when set. If a city has no ``.mat`` file,
       ``**/512FFT/**/spectrograms/*.pkl`` is used instead.

    Files whose communication type is not in ``ENABLED_COMM_TYPES`` are
    skipped, and the list is truncated to ``MAX_SCENARIOS`` entries when set.
    The result is cached, so the filesystem is scanned once per process and
    every caller receives the same list object.

    Returns:
        list[tuple]: ``(scenario_id, file_path, city_name, base_token)`` tuples
        with ``scenario_id == f"{city_name}::{base_token}"``.
    """
    import glob

    scenario_entries = []

    # New MATLAB receiver pipeline output
    new_base = os.path.join('ls_data', 'MATLAB', 'receiver_pipeline')
    if os.path.isdir(new_base):
        pattern = os.path.join(new_base, '*', '**', 'spectrogram_*.mat')
        for file_path in sorted(glob.glob(pattern, recursive=True)):
            parts = os.path.normpath(file_path).split(os.sep)
            # Determine a grouping token similar to city_name; use the standard folder name
            try:
                idx = parts.index('receiver_pipeline')
                city_name = parts[idx + 1] if idx + 1 < len(parts) else 'receiver_pipeline'
            except ValueError:
                city_name = 'receiver_pipeline'

            entry = _build_scenario_entry(file_path, city_name)
            if entry is not None:
                scenario_entries.append(entry)

    # Legacy repo layouts under spectrograms/city_*
    for city_dir in sorted(glob.glob(os.path.join('spectrograms', 'city_*'))):
        if not os.path.isdir(city_dir):
            continue
        city_name = os.path.basename(city_dir)
        if ENABLED_CITY_PREFIXES and not any(
            city_name.startswith(prefix) for prefix in ENABLED_CITY_PREFIXES
        ):
            continue

        # Look for complex spectrogram outputs; support both nested and flat layouts
        candidate_patterns = [
            os.path.join(city_dir, '**', 'complex_raw', '**', 'spectrogram_*.mat'),
            os.path.join(city_dir, '**', 'spectrogram_*.mat'),
        ]
        city_files = []
        seen_paths = set()
        for pattern in candidate_patterns:
            # The two patterns overlap, so each path is kept only once.
            for file_path in sorted(glob.glob(pattern, recursive=True)):
                if file_path in seen_paths:
                    continue
                seen_paths.add(file_path)
                city_files.append(file_path)

        # Fallback: 512FFT pattern (backward compatibility)
        if not city_files:
            pattern = os.path.join(city_dir, '**', '512FFT', '**', 'spectrograms', '*.pkl')
            city_files = sorted(glob.glob(pattern, recursive=True))

        for file_path in city_files:
            entry = _build_scenario_entry(file_path, city_name)
            if entry is not None:
                scenario_entries.append(entry)

    if MAX_SCENARIOS:
        scenario_entries = scenario_entries[:MAX_SCENARIOS]

    return scenario_entries
351
+
352
+
353
def scenarios_list():
    """Return the ids of all selected scenarios and print a selection summary.

    Returns:
        np.ndarray: 1-D array of scenario id strings, in discovery order; an
        empty array (after printing a warning) if no files were found.
    """
    scenario_entries = _collect_scenario_file_info()

    if not scenario_entries:
        print("⚠️ No spectrogram files found for pretraining.")
        return np.array([])

    print(f"Enabled communication types: {sorted(ENABLED_COMM_TYPES)}")
    if ENABLED_CITY_PREFIXES:
        print(f"Selected city prefixes: {sorted(ENABLED_CITY_PREFIXES)}")

    # entry[2] is the city / grouping name
    city_counts = Counter(entry[2] for entry in scenario_entries)
    print("Using scenarios from the following city datasets:")
    for city_name, count in city_counts.items():
        print(f" - {city_name}: {count} files")

    print(f"Total scenarios selected: {len(scenario_entries)}")
    return np.array([entry[0] for entry in scenario_entries])
370
+
371
+
372
+ # =============================================================================
373
+ # 3. SCENARIO PROPERTIES MAPPING
374
+ # - Map each scenario name to its corresponding properties
375
+ # =============================================================================
376
+
377
def scenario_prop():
    """Map each scenario id to the file paths associated with it.

    Returns:
        dict: ``scenario_id -> {'spectrogram_path': ..., 'cache_path': ...}``
        where ``cache_path`` is
        ``spectrograms/<city_name>/spectrogram_cache_128x128.pkl``. If several
        files share a scenario id, the last one discovered wins.
    """
    return {
        scenario_id: {
            'spectrogram_path': file_path,
            'cache_path': os.path.join('spectrograms', city_name, 'spectrogram_cache_128x128.pkl')
        }
        for scenario_id, file_path, city_name, _ in _collect_scenario_file_info()
    }
388
+
389
+ # =============================================================================
390
+ # 4. TRAINING PARAMETERS AND HYPERPARAMETERS
391
+ # - Set training epochs, batch sizes, learning rates, model dimensions, etc.
392
+ # =============================================================================
393
+
394
EPOCHS = 20  # Increased for better convergence
# Optimized batch size for A100 GPU (40GB)
BATCH_SIZE = 16
VAL_BATCH_SIZE = 16
WARMUP_EPOCHS = 5
BASE_LR = 5e-4
MIN_LR = 1e-8
# Updated for 128x128 complex spectrograms with real-imaginary interleaving
N_ROWS = 4
N_COLUMNS = 4
ELEMENT_LENGTH = N_ROWS * N_COLUMNS * 2  # Complex spectrograms: 2x for real+imaginary interleaving
D_MODEL = 128
MAX_LEN = 1025  # (128/4) * (128/4) + 1 = 32 * 32 + 1 = 1024 + 1 for [CLS] token
# Interleaving keeps the same number of spatial patches (32x32) while doubling patch width
# so each token covers 4x4 complex bins (real+imag) and sequence length stays at 1025.
N_LAYERS = 12
device_idx = 0
WEIGHT_DECAY = 0.05
BETA1 = 0.9
BETA2 = 0.999
MASK_PERCENT = 0.6
N_HEADS = 8
DROPOUT = 0.1

print("📊 Model configuration for complex spectrograms:")
print(f" Patch size: {N_ROWS}x{N_COLUMNS}")
print(f" Element length: {ELEMENT_LENGTH} (includes real+imag interleaving)")
print(f" Max sequence length: {MAX_LEN}")
422
+
423
+ # =============================================================================
424
+ # 5. DATA GENERATION LOOP
425
+ # - Iterate over scenarios to generate spectrogram samples and labels
426
+ # =============================================================================
427
+
428
scenarios = scenarios_list()
scenario_properties = scenario_prop()

# Collect all training and validation data separately
train_spectrogram_chunks = []
val_spectrogram_chunks = []
train_label_chunks = []
val_label_chunks = []
train_meta_chunks = []
val_meta_chunks = []

print(f"📂 Loading {len(scenarios)} scenarios...")

# TEMP: Modified to not use cache
print("⚠️ TEMPORARY FIX: Skipping cache to avoid memory issues")
cache_path = None  # Disable cache usage

# Single-process scenario handling (multiprocessing disabled)
scenario_info_list = []
missing_props = []
for scenario in scenarios:
    props = scenario_properties.get(scenario)
    if props is None:
        missing_props.append(scenario)
        continue
    scenario_info_list.append((scenario, props["spectrogram_path"]))

if missing_props:
    print("⚠️ Missing metadata for the following scenarios; skipping:")
    for scen in missing_props:
        print(f" - {scen}")

print(f"📂 Loading {len(scenario_info_list)} scenarios using single process...")

# Process in a single process
successful_scenarios = 0
scenario_results = []

for scenario_info in tqdm(scenario_info_list, desc="Processing scenarios", unit="scenario"):
    scenario_name = scenario_info[0]
    try:
        result = process_single_scenario(scenario_info)
        if result is not None:
            # Collect data (accumulated per scenario)
            train_spectrogram_chunks.append(result['train_data'])
            val_spectrogram_chunks.append(result['val_data'])
            # Labels are placeholders: every sample gets class 0
            train_label_chunks.append(np.zeros(result['train_size'], dtype=np.int64))
            val_label_chunks.append(np.zeros(result['val_size'], dtype=np.int64))
            train_meta_chunks.append(result['train_meta'])
            val_meta_chunks.append(result['val_meta'])
            successful_scenarios += 1
    except Exception as e:
        print(f"❌ Scenario {scenario_name} processing failed: {e}")

print(f"✅ Processing completed! Successful scenarios: {successful_scenarios}/{len(scenario_info_list)}")

if not train_spectrogram_chunks or not val_spectrogram_chunks:
    raise ValueError("No spectrogram data collected; check scenario configuration.")

print("🔄 Collating spectrogram arrays...")
train_spectrograms = np.concatenate(train_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
val_spectrograms = np.concatenate(val_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
train_labels = np.concatenate(train_label_chunks, axis=0)
val_labels = np.concatenate(val_label_chunks, axis=0)
492
+
493
+ def _concat_metadata_dicts(dict_list):
494
+ if not dict_list:
495
+ return {}
496
+ keys = dict_list[0].keys()
497
+ return {k: np.concatenate([d[k] for d in dict_list], axis=0) for k in keys}
498
+
499
train_metadata = _concat_metadata_dicts(train_meta_chunks)
val_metadata = _concat_metadata_dicts(val_meta_chunks)

# The per-scenario chunks are no longer needed once concatenated
del train_spectrogram_chunks, val_spectrogram_chunks, train_label_chunks, val_label_chunks
del train_meta_chunks, val_meta_chunks

print(f"Training spectrograms shape: {train_spectrograms.shape}")
print(f"Validation spectrograms shape: {val_spectrograms.shape}")
print(f"Memory usage: {train_spectrograms.nbytes + val_spectrograms.nbytes + train_labels.nbytes + val_labels.nbytes:,} bytes")

# Normalization statistics come from the training split only
train_mean = float(train_spectrograms.mean())
train_std = float(train_spectrograms.std())
if abs(train_std) < 1e-6:
    print("⚠️ Training std near zero, using epsilon for stability")
    train_std = 1e-6
dataset_normalization = {'mean': train_mean, 'std': train_std, 'normalization': NORMALIZATION_MODE}
print(f"Dataset normalization stats -> mean: {train_mean:.4f}, std: {train_std:.4f}")
516
+
517
+ # =============================================================================
518
+ # 6. DATA TOKENIZATION
519
+ # - Tokenize spectrogram matrices into input sequences with masking for pretraining
520
+ # =============================================================================
521
+
522
+ # Tokenize training data
523
print("🔄 Starting tokenization of training data...")
preprocessed_train = tokenizer_train(
    train_spectrograms,
    max_len=MAX_LEN,
    masking_percent=MASK_PERCENT,
    mask=True,
    seed=42,
    metadata=train_metadata,
    dataset_stats=dataset_normalization,
    normalization=NORMALIZATION_MODE,
    interleaved=True,
)
print("✅ Training data tokenization completed!")

# Tokenize validation data (with masking for pretraining evaluation)
print("🔄 Starting tokenization of validation data...")
preprocessed_val = tokenizer_train(
    val_spectrograms,
    max_len=MAX_LEN,
    masking_percent=MASK_PERCENT,
    mask=True,  # Apply masking for pretraining evaluation
    seed=42,
    metadata=val_metadata,
    dataset_stats=dataset_normalization,  # same training-split statistics as above
    normalization=NORMALIZATION_MODE,
    interleaved=True,
)
print("✅ Validation data tokenization completed!")
551
+
552
+ # =============================================================================
553
+ # 7. TRAIN/VALIDATION DATA SETUP
554
+ # - Use pre-split training and validation data
555
+ # =============================================================================
556
+
557
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use pre-split data (the 80/20 split was already made per scenario)
train_data = preprocessed_train
val_data = preprocessed_val
564
+
565
+ # =============================================================================
566
+ # 8. DATALOADER CREATION
567
+ # - Build PyTorch DataLoader objects for batched training and validation
568
+ # =============================================================================
569
+
570
+ # Handle different data formats
571
print("🔧 Creating data loaders...")

if isinstance(train_data, dict):
    print(f" Training data format: dict with {len(train_data)} sequence lengths")
    # Training data with masking
    train_loaders = create_train_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
else:
    print(f" Training data format: tensor with shape {train_data.shape}")
    # Training data without masking (fallback)
    train_dataset = TensorDataset(train_data)
    train_loaders = {'seq_0': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)}

if isinstance(val_data, dict):
    print(f" Validation data format: dict with {len(val_data)} sequence lengths")
    # Validation data with masking
    val_loaders = create_train_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
else:
    print(f" Validation data format: tensor with shape {val_data.shape}")
    # Validation data without masking
    val_dataset = TensorDataset(val_data)
    val_loaders = {'seq_0': DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)}

print("✅ Data loaders created successfully!")
594
+
595
+ # =============================================================================
596
+ # 9. MODEL INITIALIZATION
597
+ # - Instantiate the LWM transformer model and optionally load pre-trained weights
598
+ # - Wrap with DataParallel for multi-GPU support
599
+ # =============================================================================
600
+
601
+ # Device selection with MPS support for Mac
602
print("🔧 Setting up device and GPU configuration...")

# torch.backends.mps does not exist on every torch build, so look it up defensively.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device_count = torch.cuda.device_count()
    print(f" CUDA available: {device_count} GPU(s) detected")

    device = torch.device("cuda:0")

    # Use every visible CUDA device
    gpu_ids = list(range(device_count))  # 0, 1, 2... auto-detect
    print(f" Using CUDA GPUs: {gpu_ids}")

    # GPU memory status
    for i in gpu_ids:
        try:
            mem_total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f" GPU {i}: Total: {mem_total:.1f}GB, Allocated: {mem_allocated:.1f}GB")
        except Exception as e:
            print(f" GPU {i}: Error getting memory info - {e}")

elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
    gpu_ids = []  # MPS doesn't support DataParallel
    print(" Using MPS (Apple Silicon GPU)")
else:
    device = torch.device("cpu")
    gpu_ids = []
    print(" Using CPU")

print(f" Final device: {device}")
print(f" GPU IDs for DataParallel: {gpu_ids}")
634
+
635
print("🤖 Initializing LWM model...")
print(f" Model parameters: element_length={ELEMENT_LENGTH}, d_model={D_MODEL}, n_layers={N_LAYERS}, max_len={MAX_LEN}, n_heads={N_HEADS}")

try:
    model = pretrained_model.lwm(
        element_length=ELEMENT_LENGTH,  # Complex spectrograms with real-imag interleaving
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        max_len=MAX_LEN,
        n_heads=N_HEADS,
        dropout=DROPOUT
    )
    print(" ✅ Model created successfully")

    print(f" Moving model to device: {device}")
    # MPS only supports float32, so set dtype
    if 'mps' in str(device):
        model = model.to(device).float()
        print(" ✅ Model moved to MPS device (float32)")
    else:
        model = model.to(device)
        print(" ✅ Model moved to device successfully")

except Exception as e:
    print(f" ❌ Model initialization failed: {e}")
    import traceback
    traceback.print_exc()
    # sys.exit rather than the bare exit(): the latter is injected by the
    # `site` module and is not guaranteed to exist in every interpreter setup.
    sys.exit(1)

# Optional: Load pre-trained model
load_model = False
if load_model:
    model.load_state_dict(torch.load("models/model_checkpoint.pth", map_location=device))
    print("Pre-trained model loaded successfully.")

# Use DataParallel for multi-GPU support (skip for MPS)
if gpu_ids:
    model = nn.DataParallel(model, device_ids=gpu_ids)
    print(f"Model loaded successfully on GPU {device.index}")
else:
    print(f"Model loaded successfully on {device}")
n_parameters = count_parameters(model)
print(f"Number of trainable parameters: {n_parameters:,}")
678
+
679
+ # =============================================================================
680
+ # 10. OPTIMIZER AND LEARNING RATE SCHEDULER
681
+ # - Configure AdamW optimizer and a cosine-with-warmup LR schedule based on total steps
682
+ # =============================================================================
683
+
684
# One scheduler step per training batch: batches per epoch summed over all loaders.
TOTAL_STEPS = sum(len(loader) for loader in train_loaders.values()) * EPOCHS
WARMUP_STEPS = sum(len(loader) for loader in train_loaders.values()) * WARMUP_EPOCHS

optimizer = AdamW(
    model.parameters(),
    lr=BASE_LR,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY
)


def lr_lambda(current_step):
    """Return the LR multiplier (relative to ``BASE_LR``) for a scheduler step.

    Linear warmup from 0 to 1 over ``WARMUP_STEPS``, then cosine decay so that
    the effective learning rate goes from ``BASE_LR`` down to ``MIN_LR`` at
    ``TOTAL_STEPS``. Past ``TOTAL_STEPS`` the multiplier stays at the
    ``MIN_LR`` level.
    """
    if current_step < WARMUP_STEPS:
        return current_step / WARMUP_STEPS

    decay_steps = TOTAL_STEPS - WARMUP_STEPS
    if decay_steps <= 0:
        # No decay phase configured (warmup covers the whole run): avoid 0-division
        scaled_progress = 1.0
    else:
        # Clamp so extra scheduler steps cannot push the cosine back upwards
        scaled_progress = min(1.0, (current_step - WARMUP_STEPS) / decay_steps)
    cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
    return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR


scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
703
+
704
+ # =============================================================================
705
+ # 11. PRE-TRAINING LOOP
706
+ # - Call the train_lwm utility to run the pre-training epochs, logging metrics and saving models
707
+ # =============================================================================
708
+
709
+ # Create timestamp-based save directory
710
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = f"models/{timestamp}_complex"
print(f"📁 Models and logs will be saved to: {save_dir}")
os.makedirs(save_dir, exist_ok=True)

# Persist the normalization statistics next to the checkpoints
stats_path = os.path.join(save_dir, "dataset_stats.json")
with open(stats_path, 'w', encoding='utf-8') as f:
    json.dump(dataset_normalization, f, indent=2)
print(f"📝 Saved dataset stats to {stats_path}")

# Checkpoint suffix lists the enabled standards, e.g. "_5G-LTE-WiFi"
comm_selection = sorted(ENABLED_COMM_TYPES) if ENABLED_COMM_TYPES else []
if comm_selection:
    comm_suffix = "_" + "-".join(comm_selection)
    print(f"[INFO] Communication standards for this run: {', '.join(comm_selection)}")
else:
    comm_suffix = ""

if __name__ == "__main__":
    pretrained_model_output = train_lwm(
        model,
        train_loaders,
        val_loaders,
        optimizer,
        scheduler,
        EPOCHS,
        device=device,
        save_dir=save_dir,
        log_file="training_log.csv",
        checkpoint_suffix=comm_suffix + "_complex",
    )
    print("🎉 Training completed successfully!")
741
+ print("πŸŽ‰ Training completed successfully!")