import os
import pickle
import platform

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from typing import Optional, Dict, Any
from numpy.random import Generator, default_rng

try:
    from tqdm import tqdm  # type: ignore
except ImportError:  # pragma: no cover - optional dependency
    def tqdm(iterable, *args, **kwargs):
        return iterable

# Optional deps for MATLAB .mat (v7.3 HDF5) loading
try:
    import h5py  # type: ignore
except Exception:
    h5py = None  # Fallback handled below

try:
    from scipy.io import loadmat  # type: ignore
except Exception:
    loadmat = None  # Only used if available

from collections import defaultdict
from torch.utils.data import TensorDataset, DataLoader

# Use tqdm for better progress display
USE_TQDM = True


def count_parameters(model, log: bool = True):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if log:
        print(f"πŸ“Š Model: {total:,} total, {trainable:,} trainable")
    return total


def generate_spectrograms_and_labels(scenario_name, spectrogram_path, cache_path):
    # TEMP FIX: Skip cache if cache_path is None
    if cache_path and os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            cached_data = pickle.load(f)
        # Handle different cache formats
        if isinstance(cached_data, dict) and 'samples' in cached_data:
            spectrograms = cached_data['samples']
        else:
            spectrograms = cached_data
    else:
        # Load data directly if cache doesn't exist or cache_path is None
        spectrograms = load_spectrogram_data(spectrogram_path)
        # Create cache file (only if cache_path is provided)
        if cache_path:
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, 'wb') as f:
                pickle.dump(spectrograms, f)

    labels = torch.zeros(len(spectrograms), dtype=torch.long)
    # Convert list of tensors to single tensor if needed
    if isinstance(spectrograms, list):
        spectrograms = torch.stack(spectrograms)
    return spectrograms, labels


def load_spectrogram_data(path):
    """Load spectrogram data from a .pkl file, .mat file, or directory.

    Returns a numpy array with shape:
    - (N, rows, cols) for single-channel spectrograms
    - (N, C, rows, cols) for multi-channel spectrograms
    """
    specs = []

    def _load_from_pkl(file_path):
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        if isinstance(data, dict) and 'spectrograms' in data:
            arr = data['spectrograms']
            if isinstance(arr, np.ndarray):
                return arr
        if isinstance(data, np.ndarray):
            return data
        return None

    def _load_from_mat(file_path):
        # Primary path: MATLAB v7.3 (HDF5) via h5py
        if h5py is not None:
            try:
                with h5py.File(file_path, 'r') as f:
                    # Prefer 'spectrograms'; otherwise pick the largest numeric dataset
                    if 'spectrograms' in f:
                        ds = f['spectrograms']
                    else:
                        cand = []

                        def _collect(name, obj):
                            try:
                                if isinstance(obj, h5py.Dataset) and obj.dtype.kind in ('f', 'i', 'u', 'c', 'V'):
                                    cand.append((name, obj))
                            except Exception:
                                pass

                        f.visititems(_collect)
                        if not cand:
                            return None
                        # Pick the dataset with the most elements
                        name, ds = max(cand, key=lambda kv: np.prod(kv[1].shape) if hasattr(kv[1], 'shape') else 0)

                    # Complex handling: structured dtype with fields 'real'/'imag' or native complex dtype
                    if hasattr(ds.dtype, 'fields') and ds.dtype.fields and 'real' in ds.dtype.fields and 'imag' in ds.dtype.fields:
                        real = ds['real'][...]
                        imag = ds['imag'][...]
                        arr = real + 1j * imag
                    else:
                        arr = ds[...]
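                # Note (assumption about the data source): MATLAB writes v7.3 (.mat/HDF5)
                # arrays column-major, so the axis order seen through h5py is the reverse
                # of MATLAB's size(); _normalize_shape below reorders such layouts.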
                return np.array(arr)
            except Exception:
                # Fallback to scipy if available
                pass

        # Fallback path: older MATLAB formats via scipy.io.loadmat
        if loadmat is not None:
            try:
                data = loadmat(file_path)
                # Prefer exact key; else choose first suitable numeric array
                if 'spectrograms' in data:
                    arr = data['spectrograms']
                    return np.array(arr)
                for k, v in data.items():
                    if k.startswith('__'):
                        continue
                    if isinstance(v, np.ndarray) and v.ndim >= 2 and v.size > 0 and np.issubdtype(v.dtype, np.number):
                        return np.array(v)
            except Exception:
                pass
        return None

    def _normalize_shape(arr: np.ndarray) -> np.ndarray:
        """Normalize array to (N, rows, cols) or (N, C, rows, cols).

        Handles both MATLAB-saved HDF5 layouts and already-normalized tensors:
        - (rows, cols)       -> (1, rows, cols)
        - (rows, cols, N)    -> (N, rows, cols)
        - (N, rows, cols)    -> (N, rows, cols)
        - (rows, cols, C, N) -> (N, C, rows, cols)
        - (N, C, rows, cols) -> (N, C, rows, cols)
        """
        if arr.ndim == 2:
            return arr[None, ...]
        if arr.ndim == 3:
            # Heuristic: if last dim looks like N, transpose; else assume already (N, rows, cols)
            if arr.shape[2] > 4 and arr.shape[0] <= 512 and arr.shape[1] <= 512:
                return np.transpose(arr, (2, 0, 1))
            else:
                return arr
        if arr.ndim == 4:
            # Two common patterns: (rows, cols, C, N) or (N, C, rows, cols)
            # Detect by which axis likely holds N (#samples)
            # If first axis is large and second is small (#channels), likely already (N, C, rows, cols)
            if arr.shape[0] > 4 and arr.shape[1] in (1, 2, 4, 8, 16, 32):
                return arr
            # Else if last axis is large (N) and third axis is small (C), transpose
            if arr.shape[3] > 4 and arr.shape[2] in (1, 2, 4, 8, 16, 32):
                return np.transpose(arr, (3, 2, 0, 1))
            # Fallback to original assumption
            return np.transpose(arr, (3, 2, 0, 1))
        return arr

    # File path
    if os.path.isfile(path):
        if path.endswith('.pkl'):
            arr = _load_from_pkl(path)
            if arr is not None:
                arr = _normalize_shape(arr)
                return arr
        if path.endswith('.mat'):
            arr = _load_from_mat(path)
            if arr is not None:
                arr = _normalize_shape(arr)
                return arr
        return np.array([])

    # Directory path
    for root, _, files in os.walk(path):
        for f in files:
            file_path = os.path.join(root, f)
            if f.endswith('.pkl'):
                arr = _load_from_pkl(file_path)
            elif f.endswith('.mat'):
                arr = _load_from_mat(file_path)
            else:
                arr = None
            if isinstance(arr, np.ndarray):
                arr = _normalize_shape(arr)
                # Consolidate into list of samples
                if arr.ndim == 3:
                    # (N, rows, cols)
                    for i in range(arr.shape[0]):
                        specs.append(arr[i])
                elif arr.ndim == 4:
                    # (N, C, rows, cols)
                    for i in range(arr.shape[0]):
                        specs.append(arr[i])

    return np.array(specs) if specs else np.array([])


def tokenizer_train(
    spectrograms,
    max_len=None,
    masking_percent=0.4,
    mask=False,
    seed=None,
    metadata=None,
    dataset_stats=None,
    normalization="dataset",
    interleaved: bool = False,
    show_progress: bool = True,
):
    # Auto-calculate max_len if not provided
    if max_len is None and len(spectrograms) > 0:
        max_len = calculate_max_len_from_spectrogram(spectrograms[0])
        print(f"Auto-calculated max_len: {max_len:,} (from spectrogram shape {spectrograms[0].shape})")
    elif max_len is None:
        max_len = 513  # fallback default
        print(f"Using default max_len: {max_len}")

    total_specs = len(spectrograms)
    if show_progress:
        print(f"Tokenizing {total_specs} samples...")

    rng: Generator = default_rng(seed) if seed is not None else default_rng()
    seq_groups = defaultdict(list)
    tensor_samples = []
    skipped_empty = 0

    if metadata is not None:
        meta_arrays = {k: np.asarray(v) for k, v in metadata.items()}
    else:
        meta_arrays = None

    normalization = normalization or "dataset"
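    # Normalization modes:
    #   "dataset"    - z-score each spectrogram with the dataset-wide mean/std passed in
    #                  via dataset_stats (defaults to mean=0, std=1, i.e. a no-op).
    #   "per_sample" - z-score each spectrogram with its own mean/std.
    # The raw per-sample mean/std are kept in the 'power_stats' metadata either way.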
    if normalization not in {"dataset", "per_sample"}:
        raise ValueError(f"Unsupported normalization mode: {normalization}")

    if dataset_stats is not None:
        ds_mean = float(dataset_stats.get('mean', 0.0))
        ds_std = float(dataset_stats.get('std', 1.0))
        if abs(ds_std) < 1e-6:
            ds_std = 1e-6
    else:
        ds_mean = 0.0
        ds_std = 1.0
    eps = 1e-6

    iterator = spectrograms
    if USE_TQDM and show_progress:
        iterator = tqdm(spectrograms, desc="Tokenizing", total=total_specs)

    for idx, spec in enumerate(iterator):
        spec_np = np.array(spec, dtype=np.float32, copy=False)
        mean_db = float(spec_np.mean())
        std_db = float(spec_np.std())
        if normalization == "per_sample":
            denom = std_db if abs(std_db) > eps else eps
            spec_proc = (spec_np - mean_db) / denom
        else:
            spec_proc = (spec_np - ds_mean) / ds_std

        patch = patch_maker(spec_proc, interleaved=interleaved)
        if patch.size == 0:
            skipped_empty += 1
            continue

        n_patches = patch.shape[0]
        patch_size = patch.shape[1] if patch.ndim > 1 else 16
        n_masks = int(masking_percent * n_patches)
        word2id = {
            '[CLS]': np.full(patch_size, 0.2, dtype=np.float32),
            '[MASK]': np.full(patch_size, 0.1, dtype=np.float32),
        }
        sample = make_sample(patch, word2id, n_masks, patch_size, mask=mask, rng=rng)

        sample_meta = {}
        if meta_arrays is not None:
            for key, values in meta_arrays.items():
                sample_meta[key] = values[idx]
        sample_meta['power_stats'] = np.array([mean_db, std_db], dtype=np.float32)

        if mask:
            input_ids, masked_tokens, masked_pos = sample
            seq_len = len(input_ids)
            if seq_len <= 1:
                continue
            if masked_tokens:
                masked_tokens = np.stack(masked_tokens).astype(np.float32, copy=False)
            else:
                masked_tokens = np.empty((0, patch_size), dtype=np.float32)
            seq_groups[seq_len].append({
                'input_ids': input_ids,
                'masked_pos': masked_pos,
                'masked_tokens': masked_tokens,
                'n_patches': seq_len - 1,
                **sample_meta,
            })
        else:
            tensor_samples.append({
                'sample': sample,
                **sample_meta,
            })

    if skipped_empty:
        print(f"⚠️ Skipped {skipped_empty} spectrograms with empty patches")

    if mask:
        filtered_data = {k: v for k, v in seq_groups.items() if k > 0 and v}
        total_samples = sum(len(v) for v in filtered_data.values())
        if not filtered_data:
            print("Warning: No valid data after filtering!")
            return {}
        if show_progress:
            print(f"βœ… Tokenization completed: {total_samples} samples across {len(filtered_data)} sequence lengths")
        return {k: filtered_data[k] for k in sorted(filtered_data.keys())}

    if not tensor_samples:
        print("Warning: No validation data after processing!")
        return torch.empty(0)
    stacked = torch.stack([
        torch.tensor(item['sample'], dtype=torch.float32) if isinstance(item['sample'], np.ndarray) else item['sample']
        for item in tensor_samples
    ])
    if show_progress:
        print(f"βœ… Tokenization completed: {len(tensor_samples)} validation samples")
    return stacked
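# Illustrative sketch (not called anywhere in the pipeline; names beginning with
# _ are hypothetical helpers): with mask=True, tokenizer_train returns a dict keyed
# by sequence length, and each entry is a list of per-sample dicts consumed by
# MaskedSpectrogramDataset / create_train_dataloader further below. Shapes assume
# the default 4x4 patches.
def _tokenizer_train_sketch():
    rng = default_rng(0)
    dummy_specs = [rng.standard_normal((128, 128)).astype(np.float32) for _ in range(2)]
    groups = tokenizer_train(dummy_specs, mask=True, seed=0, show_progress=False)
    for seq_len, samples in groups.items():
        first = samples[0]
        # seq_len = (128/4) * (128/4) patches + 1 CLS token = 1025
        print(seq_len, first['input_ids'].shape, first['masked_tokens'].shape, first['masked_pos'].shape)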
def calculate_max_len_from_spectrogram(spec, patch_rows=4, patch_cols=4):
    """Calculate the maximum sequence length needed for a given spectrogram size.

    Args:
        spec: Spectrogram tensor/array
        patch_rows: Number of rows per patch
        patch_cols: Number of columns per patch

    Returns:
        int: Maximum sequence length (number of patches + 1 for CLS token)
    """
    if hasattr(spec, 'shape'):
        shape = spec.shape
    else:
        shape = spec

    # Handle different shape formats
    if len(shape) == 3 and shape[0] == 1:
        # [1, height, width]
        n_rows, n_cols = shape[1], shape[2]
    elif len(shape) == 4 and shape[0] == 1 and shape[1] == 1:
        # [1, 1, height, width]
        n_rows, n_cols = shape[2], shape[3]
    elif len(shape) == 2:
        # [height, width]
        n_rows, n_cols = shape[0], shape[1]
    else:
        raise ValueError(f"Unexpected spec shape: {shape}")

    n_patches_r = n_rows // patch_rows
    n_patches_c = n_cols // patch_cols
    total_patches = n_patches_r * n_patches_c
    return total_patches + 1  # +1 for CLS token


def patch_maker(spec, patch_rows=4, patch_cols=4, interleaved: bool = False):
    # Handle normalized spectrograms: [1, height, width] or [1, 1, height, width]
    if len(spec.shape) == 3 and spec.shape[0] == 1:
        # [1, height, width]
        spec = spec.squeeze(0)  # Remove batch dimension: [height, width]
    elif len(spec.shape) == 4 and spec.shape[0] == 1 and spec.shape[1] == 1:
        # [1, 1, height, width]
        spec = spec.squeeze(0).squeeze(0)  # Remove both dimensions: [height, width]
    elif len(spec.shape) == 2:
        # [height, width] - already processed
        pass
    else:
        raise ValueError(f"Unexpected spec shape: {spec.shape}")

    n_rows, n_cols = spec.shape

    if interleaved:
        # Treat last axis as interleaved [real, imag, real, imag, ...]
        # Compute patches across columns in pairs (2x per complex bin)
        n_patches_r = n_rows // patch_rows
        n_complex_cols = n_cols // 2
        n_patches_c = n_complex_cols // patch_cols
        if n_patches_r == 0 or n_patches_c == 0:
            print(f"❌ PATCH CREATION FAILED (interleaved): {n_rows}x{n_cols} too small for {patch_rows}x{patch_cols}")
            return np.array([])
        # Crop to full patches: rows and 2x columns for interleaving
        cropped = spec[:n_patches_r * patch_rows, :n_patches_c * patch_cols * 2]
        if cropped.size == 0:
            print(f"⚠️ No patches generated from {n_rows}x{n_cols} spectrogram (interleaved)")
            return np.array([])
        # Reshape to (n_patches_r, patch_rows, n_patches_c, patch_cols*2)
        reshaped = cropped.reshape(n_patches_r, patch_rows, n_patches_c, patch_cols * 2)
        result = reshaped.transpose(0, 2, 1, 3).reshape(-1, patch_rows * patch_cols * 2)
        return result.astype(np.float32, copy=False)

    # Non-interleaved real-valued path (existing behavior)
    n_patches_r, n_patches_c = n_rows // patch_rows, n_cols // patch_cols
    if n_patches_r == 0 or n_patches_c == 0:
        print(f"❌ PATCH CREATION FAILED: spectrogram {n_rows}x{n_cols} too small for {patch_rows}x{patch_cols} patches")
        print(f"   n_patches_r: {n_patches_r}, n_patches_c: {n_patches_c}")
        return np.array([])
    cropped = spec[:n_patches_r * patch_rows, :n_patches_c * patch_cols]
    if cropped.size == 0:
        print(f"⚠️ No patches generated from {n_rows}x{n_cols} spectrogram")
        return np.array([])
    reshaped = cropped.reshape(n_patches_r, patch_rows, n_patches_c, patch_cols)
    result = reshaped.transpose(0, 2, 1, 3).reshape(-1, patch_rows * patch_cols)
    return result.astype(np.float32, copy=False)
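# Illustrative sketch (assumed shapes, not part of the pipeline): patch_maker turns a
# (rows, cols) spectrogram into (n_patches, patch_rows * patch_cols) row-major patches.
# With interleaved=True, columns are treated as [real, imag, ...] pairs, so each patch
# carries twice the values and the patch count along the column axis is halved.
def _patch_maker_sketch():
    spec = np.arange(64 * 64, dtype=np.float32).reshape(64, 64)
    real_patches = patch_maker(spec)                   # (16 * 16, 16) = (256, 16)
    iq_patches = patch_maker(spec, interleaved=True)   # (16 * 8, 32)  = (128, 32)
    print(real_patches.shape, iq_patches.shape)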
def make_sample(tokens, word2id, n_masks, patch_size, mask=True, rng: Generator | None = None):
    rng = rng or default_rng()
    input_ids = np.vstack((word2id['[CLS]'], tokens))
    if not mask:
        return torch.tensor(input_ids, dtype=torch.float32)

    n_patches = tokens.shape[0]
    if n_masks <= 0 or n_patches == 0:
        masked_pos = np.empty(0, dtype=np.int64)
    else:
        n_masks = min(n_masks, n_patches)
        mask_candidates = np.arange(1, n_patches + 1)
        masked_pos = rng.choice(mask_candidates, size=n_masks, replace=False)

    masked_tokens = []
    for pos in masked_pos:
        masked_tokens.append(input_ids[pos].astype(np.float32, copy=True))
        rnd = rng.random()
        if rnd < 0.1:
            input_ids[pos] = rng.random(patch_size, dtype=np.float32)
        elif rnd < 0.9:
            input_ids[pos] = word2id['[MASK]']

    return [input_ids.astype(np.float32, copy=False), masked_tokens, masked_pos]


def patch_reconstructor(patches, rows, cols, patch_rows=4, patch_cols=4):
    if isinstance(patches, torch.Tensor):
        patches = patches.detach().cpu().numpy()
    batch_size, num_patches, _ = patches.shape
    n_h, n_w = rows // patch_rows, cols // patch_cols
    patches = patches.reshape(batch_size, n_h, n_w, patch_rows, patch_cols)
    reconstructed = np.zeros((batch_size, rows, cols))
    for i in range(n_h):
        for j in range(n_w):
            reconstructed[:, i*patch_rows:(i+1)*patch_rows, j*patch_cols:(j+1)*patch_cols] = patches[:, i, j]
    return reconstructed


def plot_radar_chart(names, opt_scores, base_scores, save_path="results/chart.png"):
    try:
        import matplotlib.pyplot as plt
        from math import pi

        N = len(names)
        angles = [n / float(N) * 2 * pi for n in range(N)] + [0]
        fig, ax = plt.subplots(subplot_kw=dict(projection='polar'))
        ax.plot(angles, opt_scores + opt_scores[:1], 'o-', label='Optimized', color='#1f77b4')
        ax.fill(angles, opt_scores + opt_scores[:1], alpha=0.25, color='#1f77b4')
        ax.plot(angles, base_scores + base_scores[:1], 'o-', label='Baseline', color='#ff7f0e')
        ax.fill(angles, base_scores + base_scores[:1], alpha=0.25, color='#ff7f0e')
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(names)
        ax.set_ylim(0, 1)
        ax.legend()
        ax.grid(True, alpha=0.3)
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"πŸ“Š Chart saved: {save_path}")
    except Exception as e:
        print(f"⚠️ Radar chart not generated: {e}")


class MaskedSpectrogramDataset(torch.utils.data.Dataset):
    """Lazy dataset that materializes masked spectrogram samples per access."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        input_ids = torch.from_numpy(sample['input_ids']).float()
        masked_tokens = torch.from_numpy(sample['masked_tokens']).float()
        masked_pos = torch.from_numpy(sample['masked_pos']).long()
        snr_db = torch.tensor(sample.get('snr_db', 0.0), dtype=torch.float32)
        doppler_id = torch.tensor(sample.get('doppler_id', 0), dtype=torch.long)
        power_stats = torch.tensor(sample.get('power_stats', np.zeros(2, dtype=np.float32)), dtype=torch.float32)
        snr_id = torch.tensor(sample.get('snr_id', -1), dtype=torch.long)
        modulation_id = torch.tensor(sample.get('modulation_id', -1), dtype=torch.long)
        return (
            input_ids,
            masked_tokens,
            masked_pos,
            snr_db,
            doppler_id,
            power_stats,
            snr_id,
            modulation_id,
        )


def create_train_dataloader(data, batch_size, shuffle, num_workers=0):
    loaders = {}
    for seq_len, group in data.items():
        print(f"Dataloader: Processing seq_len={seq_len} with {len(group)} samples")
        # Expect labels to be provided as group_labels in data if available
        group_labels = None
        if isinstance(group, tuple) and len(group) == 2:
            group, group_labels = group

        # Masked data with dict structure
        if isinstance(group[0], dict):
            print("  Processing as masked data (dict structure)")
            dataset = MaskedSpectrogramDataset(group)
            loaders[seq_len] = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                pin_memory=True,
                num_workers=num_workers,
            )
            print(f"  Created DataLoader with {len(dataset)} samples (lazy loading)")
        elif isinstance(group[0], list):
            print("  Processing as masked data (list structure)")
            ids, tokens, pos = zip(*group)
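            # Note: stacking into dense tensors below assumes every sample in this
            # seq_len group has identical input_ids / masked_tokens / masked_pos shapes
            # (i.e. the same number of masked positions); otherwise torch.tensor will fail.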
            # If labels are available, use them; else, use zeros
            if group_labels is not None:
                label_tensor = torch.tensor(group_labels, dtype=torch.long)
            else:
                label_tensor = torch.zeros(len(group), dtype=torch.long)
            dataset = TensorDataset(
                torch.tensor(ids, dtype=torch.float32),
                torch.tensor(tokens, dtype=torch.float32),
                torch.tensor(pos, dtype=torch.long),
                label_tensor,
            )
            loaders[seq_len] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers)
            print(f"  Created DataLoader with {len(dataset)} samples (with labels)")
        else:
            print("  Processing as non-masked data")
            if isinstance(group[0], torch.Tensor):
                dataset = TensorDataset(*group)
            else:
                tensor_group = [torch.tensor(g, dtype=torch.float32) for g in group]
                dataset = TensorDataset(*tensor_group)
            loaders[seq_len] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers)
            print(f"  Created DataLoader with {len(dataset)} samples")
    return loaders


def train_lwm(
    model,
    train_loaders,
    val_loaders,
    optimizer,
    scheduler,
    epochs,
    device,
    save_dir="models",
    log_file="training_log.csv",
    checkpoint_suffix: str = "",
    distributed_context: Optional[Dict[str, Any]] = None,
):
    distributed_context = distributed_context or {}
    is_distributed = distributed_context.get("is_distributed", False)
    rank = distributed_context.get("rank", 0)
    world_size = max(1, distributed_context.get("world_size", 1))
    is_primary = distributed_context.get("is_primary", rank == 0)

    os.makedirs(save_dir, exist_ok=True)

    # Initialize logging
    log_file_path = os.path.join(save_dir, log_file)
    use_tensorboard = False
    writer = None

    # Try to initialize TensorBoard writer
    if is_primary:
        try:
            from torch.utils.tensorboard import SummaryWriter
            tensorboard_dir = f"{save_dir}/tensorboard"
            writer = SummaryWriter(tensorboard_dir)
            print(f"πŸ“Š TensorBoard logs will be saved to: {tensorboard_dir}")
            use_tensorboard = True
        except (ImportError, AttributeError) as e:
            print(f"⚠️ TensorBoard not available ({e}), using CSV logging instead")
            # Initialize CSV logging as fallback
            with open(log_file_path, 'w') as f:
                f.write("epoch,train_loss,val_loss,val_nmse,lr\n")

    criterion = nn.MSELoss(reduction='sum')
    best_mse = float('inf')
    train_losses, val_losses, val_nmse_losses = [], [], []

    # Early stopping parameters
    patience = 3  # Stop if no improvement for 3 epochs
    patience_counter = 0

    def _sync_sum(value: float) -> float:
        if not is_distributed or not dist.is_available() or not dist.is_initialized():
            return float(value)
        tensor = torch.tensor(value, dtype=torch.float64, device=device)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return float(tensor.item())

    for epoch in range(epochs):
        # Training
        model.train()
        train_mse, train_samples = 0.0, 0
        if is_primary:
            print(f"\nEpoch {epoch+1}/{epochs}")
        for loader in train_loaders.values():
            pbar = tqdm(
                loader,
                desc="Train",
                postfix={"loss": 0.0, "avg_loss": 0.0},
                disable=not is_primary,
            )
            for batch in pbar:
                optimizer.zero_grad()
                if len(batch) >= 3:
                    ids, tokens, pos = batch[0], batch[1], batch[2]
                else:
                    raise ValueError(f"Unexpected batch length: {len(batch)}")
                ids = ids.to(device).float()
                tokens = tokens.to(device).float()
                pos = pos.to(device).long()
                logits = model(ids, pos)[0]
                loss = criterion(tokens, logits)
                loss.backward()
                optimizer.step()
                scheduler.step()
                train_mse += loss.item()
                train_samples += ids.shape[0]
                # Update tqdm postfix with real-time metrics
                current_avg_loss = train_mse / max(train_samples, 1)
                batch_size = ids.shape[0]
                if is_primary:
                    pbar.set_postfix({
                        "loss": f"{loss.item()/batch_size:.4f}",
                        "avg_loss": f"{current_avg_loss:.4f}",
                    })
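        # In the distributed case, the summed loss and sample counts are all-reduced
        # across ranks before computing the epoch-mean train loss; on a single process
        # _sync_sum is a no-op.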
        total_train_mse = _sync_sum(train_mse)
        total_train_samples = _sync_sum(train_samples)
        train_mse = total_train_mse / max(total_train_samples, 1)
        train_losses.append(train_mse)

        # Log training metrics
        if use_tensorboard and writer:
            writer.add_scalar('Loss/train', train_mse, epoch + 1)
            writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], epoch + 1)
        elif is_primary:
            # Log to CSV
            lr = optimizer.param_groups[0]['lr']
            with open(log_file_path, 'a') as f:
                f.write(f"{epoch+1},{train_mse},,,{lr}\n")

        # Validation every epoch
        model.eval()
        val_mse, val_nmse, val_samples = 0.0, 0.0, 0
        with torch.no_grad():
            for loader in val_loaders.values():
                progress_bar = tqdm(
                    loader,
                    desc="Val",
                    postfix={"mse": 0.0, "nmse": 0.0},
                    disable=not is_primary,
                )
                for batch in progress_bar:
                    # Check if validation data has masking (3 or 4 elements) or not (1 element)
                    if len(batch) >= 3:
                        # Masked validation data (training format)
                        ids, tokens, pos = batch[0], batch[1], batch[2]
                        ids = ids.to(device).float()
                        tokens = tokens.to(device).float()
                        pos = pos.to(device).long()
                        logits = model(ids, pos)[0]
                    elif len(batch) == 1:
                        # Non-masked validation data (tensor format)
                        val_tensor = batch[0].to(device, dtype=torch.float32) if 'mps' in str(device) else batch[0].to(device)
                        # For validation, call model without masked_pos (None)
                        output = model(val_tensor)  # Returns [batch_size, seq_len, d_model]
                        # Apply decoder to get predictions in original dimension
                        # Handle DataParallel wrapper
                        model_module = model.module if hasattr(model, 'module') else model
                        logits = model_module.decoder(output) + model_module.decoder_bias  # [batch_size, seq_len, element_length]
                        # For non-masked validation, tokens = input (no masking applied)
                        tokens = val_tensor
                        ids = val_tensor
                    else:
                        raise ValueError(f"Unexpected batch length: {len(batch)}")

                    val_mse += criterion(tokens, logits).item()
                    # Safe numpy conversion for MPS compatibility
                    tokens_np = tokens.float().cpu().numpy().astype(np.float32) if 'mps' in str(device) else tokens.cpu().numpy()
                    logits_np = logits.float().cpu().numpy().astype(np.float32) if 'mps' in str(device) else logits.cpu().numpy()
                    nmse_val = nmse_loss(tokens_np, logits_np)
                    val_nmse += nmse_val * ids.shape[0]
                    val_samples += ids.shape[0]

                    # Update progress bar with real-time metrics
                    current_mse = val_mse / max(val_samples, 1)
                    current_nmse = val_nmse / max(val_samples, 1)
                    current_nmse_db = 10 * np.log10(max(current_nmse, 1e-8))  # Convert to dB scale
                    batch_size = ids.shape[0]
                    if is_primary:
                        progress_bar.set_postfix({
                            "mse": f"{current_mse:.4f}",
                            "nmse": f"{current_nmse_db:.2f}dB",
                        })

        total_val_mse = _sync_sum(val_mse)
        total_val_nmse = _sync_sum(val_nmse)
        total_val_samples = _sync_sum(val_samples)
        val_mse = total_val_mse / max(total_val_samples, 1)
        val_nmse = total_val_nmse / max(total_val_samples, 1)
        val_losses.append(val_mse)
        val_nmse_losses.append(val_nmse)

        # Log validation metrics
        if use_tensorboard and writer:
            writer.add_scalar('Loss/validation', val_mse, epoch + 1)
            writer.add_scalar('Loss/nmse', val_nmse, epoch + 1)
        elif is_primary:
            # Update CSV with validation metrics
            lr = optimizer.param_groups[0]['lr']
            # Read the last line and update it with validation metrics
            with open(log_file_path, 'r') as f:
                lines = f.readlines()
            if lines:
                # Update the last line with validation metrics
                last_line = lines[-1].strip()
                parts = last_line.split(',')
                if len(parts) >= 5:
                    parts[2] = f"{val_mse}"
                    parts[3] = f"{val_nmse}"
                    lines[-1] = ','.join(parts) + '\n'
                    with open(log_file_path, 'w') as f:
                        f.writelines(lines)
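        # Checkpointing / early stopping: a checkpoint is written only when the epoch's
        # validation MSE improves on the best seen so far; `patience` consecutive epochs
        # without improvement stop training early.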
        if val_mse < best_mse:
            best_mse = val_mse
            patience_counter = 0  # Reset counter on improvement
            suffix = checkpoint_suffix or ""
            if is_primary:
                path = f"{save_dir}/lwm_epoch{epoch+1}_val{val_mse:.4f}{suffix}.pth"
                torch.save(model.state_dict(), path)
                print(f"βœ… Saved: {path}")
        else:
            patience_counter += 1
            if is_primary:
                print(f"⏸️ No improvement for {patience_counter}/{patience} epochs")

        # Early stopping check
        if patience_counter >= patience:
            if is_primary:
                print(f"πŸ›‘ Early stopping triggered after {epoch+1} epochs")
                print(f"   Best validation MSE: {best_mse:.4f}")
            break

        if is_primary:
            print(f"Train MSE: {train_mse:.4f}")
            val_nmse_db = 10 * np.log10(max(val_nmse, 1e-8))
            print(f"Val MSE: {val_mse:.4f}, NMSE: {val_nmse_db:.2f}dB")

    # Ensure val_losses and val_nmse_losses have same length as train_losses
    # Fill missing validation data with None or last available value
    while len(val_losses) < len(train_losses):
        val_losses.append(None)
    while len(val_nmse_losses) < len(train_losses):
        val_nmse_losses.append(None)

    # Save training history
    # Convert numpy types to Python native types for JSON serialization
    def convert_numpy_types(obj):
        """Convert numpy types to Python native types for JSON serialization"""
        if isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, list):
            return [convert_numpy_types(item) for item in obj]
        elif isinstance(obj, dict):
            return {key: convert_numpy_types(value) for key, value in obj.items()}
        else:
            return obj

    training_history = {
        'train_losses': convert_numpy_types(train_losses),
        'val_losses': convert_numpy_types(val_losses),
        'val_nmse_losses': convert_numpy_types(val_nmse_losses),
        'epochs': list(range(1, len(train_losses) + 1)),
        'best_val_mse': convert_numpy_types(best_mse),
    }

    if is_primary:
        import json
        history_file = f"{save_dir}/training_history.json"
        with open(history_file, 'w') as f:
            json.dump(training_history, f, indent=2)
        print(f"πŸ“Š Training history saved: {history_file}")

        # Close TensorBoard writer
        if use_tensorboard and writer:
            writer.close()
            print(f"πŸ“Š TensorBoard logs saved: {tensorboard_dir}")
        else:
            print(f"πŸ“Š Training logs saved: {log_file_path}")
    elif use_tensorboard and writer:
        writer.close()

    return model


def nmse_loss(y_true, y_pred):
    if isinstance(y_true, torch.Tensor):
        mse = torch.mean((y_true - y_pred) ** 2)
        power = torch.mean(y_true ** 2)
    else:
        mse = np.mean((y_true - y_pred) ** 2)
        power = np.mean(y_true ** 2)
    return mse / (power + 1e-8)
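
# Illustrative sketch (not part of the training pipeline): how NMSE relates to the dB
# values logged in train_lwm. A prediction with 10% relative error on every element
# gives NMSE ~= 0.01, i.e. 10 * log10(0.01) = -20 dB.
def _nmse_db_sketch():
    rng = default_rng(0)
    y_true = rng.standard_normal((4, 16)).astype(np.float32)
    y_pred = y_true + 0.1 * y_true  # 10% relative error per element
    nmse = nmse_loss(y_true, y_pred)
    nmse_db = 10 * np.log10(max(float(nmse), 1e-8))
    print(f"NMSE: {float(nmse):.4f} ({nmse_db:.2f} dB)")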