|
|
import os
import pickle
import platform

import numpy as np
import torch
import torch.nn as nn
|
|
import torch.distributed as dist |
|
|
from typing import Optional, Dict, Any |
|
|
from numpy.random import Generator, default_rng |
|
|
|
|
|
try: |
|
|
from tqdm import tqdm |
|
|
except ImportError: |
|
|
    # Minimal no-op fallback so tqdm(...) calls still work without the package.
    def tqdm(iterable, *args, **kwargs):
        return iterable
|
|
|
|
|
|
|
|
try: |
|
|
import h5py |
|
|
except Exception: |
|
|
h5py = None |
|
|
try: |
|
|
from scipy.io import loadmat |
|
|
except Exception: |
|
|
loadmat = None |
|
|
from collections import defaultdict |
|
|
from torch.utils.data import TensorDataset, DataLoader |
|
|
|
|
|
|
|
|
USE_TQDM = True |
|
|
|
|
|
def count_parameters(model, log: bool = True):
    """Return the model's total parameter count, optionally logging total/trainable."""
|
|
total = sum(p.numel() for p in model.parameters()) |
|
|
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
if log: |
|
|
print(f"π Model: {total:,} total, {trainable:,} trainable") |
|
|
return total |
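
# Usage sketch (illustrative only):
#   >>> count_parameters(nn.Linear(10, 4))
#   📊 Model: 44 total, 44 trainable   # 10*4 weights + 4 biases
#   44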
|
|
|
|
|
|
|
|
def generate_spectrograms_and_labels(scenario_name, spectrogram_path, cache_path):
    """Load (or build and cache) spectrograms and return them with placeholder labels.

    `scenario_name` is accepted for API compatibility but is not used here.
    Returns (spectrograms, labels) where labels are all-zero class indices.
    """
|
|
|
|
|
if cache_path and os.path.exists(cache_path): |
|
|
with open(cache_path, 'rb') as f: |
|
|
cached_data = pickle.load(f) |
|
|
|
|
|
if isinstance(cached_data, dict) and 'samples' in cached_data: |
|
|
spectrograms = cached_data['samples'] |
|
|
else: |
|
|
spectrograms = cached_data |
|
|
else: |
|
|
|
|
|
spectrograms = load_spectrogram_data(spectrogram_path) |
|
|
|
|
|
|
|
|
if cache_path: |
|
|
os.makedirs(os.path.dirname(cache_path), exist_ok=True) |
|
|
with open(cache_path, 'wb') as f: |
|
|
pickle.dump(spectrograms, f) |
|
|
|
|
|
labels = torch.zeros(len(spectrograms), dtype=torch.long) |
|
|
|
|
|
if isinstance(spectrograms, list): |
|
|
spectrograms = torch.stack(spectrograms) |
|
|
|
|
|
return spectrograms, labels |
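
# Cache round-trip sketch (hypothetical paths; labels are placeholders here):
#   specs, labels = generate_spectrograms_and_labels(
#       "demo", "data/specs.mat", "cache/specs.pkl")
#   # The first call loads the .mat and writes the pickle; later calls hit the cache.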
|
|
|
|
|
def load_spectrogram_data(path): |
|
|
"""Load spectrogram data from a .pkl, .mat file, or directory. |
|
|
|
|
|
Returns a numpy array with shape: |
|
|
- (N, rows, cols) for single-channel spectrograms |
|
|
- (N, C, rows, cols) for multi-channel spectrograms |
|
|
""" |
|
|
specs = [] |
|
|
|
|
|
def _load_from_pkl(file_path): |
|
|
with open(file_path, 'rb') as f: |
|
|
data = pickle.load(f) |
|
|
if isinstance(data, dict) and 'spectrograms' in data: |
|
|
arr = data['spectrograms'] |
|
|
if isinstance(arr, np.ndarray): |
|
|
return arr |
|
|
if isinstance(data, np.ndarray): |
|
|
return data |
|
|
return None |
|
|
|
|
|
def _load_from_mat(file_path): |
|
|
|
|
|
if h5py is not None: |
|
|
try: |
|
|
with h5py.File(file_path, 'r') as f: |
|
|
|
|
|
if 'spectrograms' in f: |
|
|
ds = f['spectrograms'] |
|
|
else: |
|
|
cand = [] |
|
|
def _collect(name, obj): |
|
|
try: |
|
|
if isinstance(obj, h5py.Dataset) and obj.dtype.kind in ('f','i','u','c','V'): |
|
|
cand.append((name, obj)) |
|
|
except Exception: |
|
|
pass |
|
|
f.visititems(_collect) |
|
|
if not cand: |
|
|
return None |
|
|
|
|
|
name, ds = max(cand, key=lambda kv: np.prod(kv[1].shape) if hasattr(kv[1], 'shape') else 0) |
|
|
|
|
|
if hasattr(ds.dtype, 'fields') and ds.dtype.fields and 'real' in ds.dtype.fields and 'imag' in ds.dtype.fields: |
|
|
real = ds['real'][...] |
|
|
imag = ds['imag'][...] |
|
|
arr = real + 1j * imag |
|
|
else: |
|
|
arr = ds[...] |
|
|
return np.array(arr) |
|
|
except Exception: |
|
|
|
|
|
pass |
|
|
|
|
|
if loadmat is not None: |
|
|
try: |
|
|
data = loadmat(file_path) |
|
|
|
|
|
if 'spectrograms' in data: |
|
|
arr = data['spectrograms'] |
|
|
return np.array(arr) |
|
|
for k, v in data.items(): |
|
|
if k.startswith('__'): |
|
|
continue |
|
|
if isinstance(v, np.ndarray) and v.ndim >= 2 and v.size > 0 and np.issubdtype(v.dtype, np.number): |
|
|
return np.array(v) |
|
|
except Exception: |
|
|
pass |
|
|
return None |
|
|
|
|
|
def _normalize_shape(arr: np.ndarray) -> np.ndarray: |
|
|
"""Normalize array to (N, rows, cols) or (N, C, rows, cols). |
|
|
|
|
|
Handles both MATLAB-saved HDF5 layouts and already-normalized tensors: |
|
|
- (rows, cols) -> (1, rows, cols) |
|
|
- (rows, cols, N) -> (N, rows, cols) |
|
|
- (N, rows, cols) -> (N, rows, cols) |
|
|
- (rows, cols, C, N) -> (N, C, rows, cols) |
|
|
- (N, C, rows, cols) -> (N, C, rows, cols) |
|
|
""" |
|
|
if arr.ndim == 2: |
|
|
return arr[None, ...] |
|
|
if arr.ndim == 3: |
|
|
|
|
|
            # Heuristic: MATLAB saves stacks as (rows, cols, N); a long last axis with
            # plausible spectrogram dims up front means samples-last, so move N forward.
            if arr.shape[2] > 4 and arr.shape[0] <= 512 and arr.shape[1] <= 512:
|
|
return np.transpose(arr, (2, 0, 1)) |
|
|
else: |
|
|
return arr |
|
|
if arr.ndim == 4: |
|
|
|
|
|
|
|
|
|
|
|
            # Already (N, C, rows, cols): many samples with a small channel count.
            if arr.shape[0] > 4 and arr.shape[1] in (1, 2, 4, 8, 16, 32):
|
|
return arr |
|
|
|
|
|
            # MATLAB layout (rows, cols, C, N): bring samples and channels forward.
            if arr.shape[3] > 4 and arr.shape[2] in (1, 2, 4, 8, 16, 32):
|
|
return np.transpose(arr, (3, 2, 0, 1)) |
|
|
|
|
|
            # Ambiguous: default to the MATLAB (rows, cols, C, N) interpretation.
            return np.transpose(arr, (3, 2, 0, 1))
|
|
return arr |
|
|
|
|
|
|
|
|
if os.path.isfile(path): |
|
|
if path.endswith('.pkl'): |
|
|
arr = _load_from_pkl(path) |
|
|
if arr is not None: |
|
|
arr = _normalize_shape(arr) |
|
|
return arr |
|
|
if path.endswith('.mat'): |
|
|
arr = _load_from_mat(path) |
|
|
if arr is not None: |
|
|
arr = _normalize_shape(arr) |
|
|
return arr |
|
|
return np.array([]) |
|
|
|
|
|
|
|
|
for root, _, files in os.walk(path): |
|
|
for f in files: |
|
|
file_path = os.path.join(root, f) |
|
|
if f.endswith('.pkl'): |
|
|
arr = _load_from_pkl(file_path) |
|
|
elif f.endswith('.mat'): |
|
|
arr = _load_from_mat(file_path) |
|
|
else: |
|
|
arr = None |
|
|
if isinstance(arr, np.ndarray): |
|
|
arr = _normalize_shape(arr) |
|
|
|
|
|
if arr.ndim == 3: |
|
|
|
|
|
for i in range(arr.shape[0]): |
|
|
specs.append(arr[i]) |
|
|
elif arr.ndim == 4: |
|
|
|
|
|
for i in range(arr.shape[0]): |
|
|
specs.append(arr[i]) |
|
|
|
|
|
return np.array(specs) if specs else np.array([]) |
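

def _example_matlab_layout():
    """Illustrative sketch (synthetic shapes, not real data) of the layout
    normalization performed inside load_spectrogram_data for MATLAB files."""
    stack = np.zeros((128, 64, 10), dtype=np.float32)      # (rows, cols, N) from MATLAB
    assert np.transpose(stack, (2, 0, 1)).shape == (10, 128, 64)
    multi = np.zeros((128, 64, 2, 10), dtype=np.float32)   # (rows, cols, C, N)
    assert np.transpose(multi, (3, 2, 0, 1)).shape == (10, 2, 128, 64)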
|
|
|
|
|
|
|
|
def tokenizer_train( |
|
|
spectrograms, |
|
|
max_len=None, |
|
|
masking_percent=0.4, |
|
|
mask=False, |
|
|
seed=None, |
|
|
metadata=None, |
|
|
dataset_stats=None, |
|
|
normalization="dataset", |
|
|
interleaved: bool = False, |
|
|
show_progress: bool = True, |
|
|
):
    """Tokenize spectrograms into [CLS]+patch sequences for LWM training.

    With mask=True, returns a dict {seq_len: [sample dicts]} for masked modeling;
    otherwise returns a stacked float tensor of unmasked sequences.
    """
|
|
|
|
|
if max_len is None and len(spectrograms) > 0: |
|
|
max_len = calculate_max_len_from_spectrogram(spectrograms[0]) |
|
|
print(f"Auto-calculated max_len: {max_len} (from spectrogram shape {spectrograms[0].shape})") |
|
|
elif max_len is None: |
|
|
max_len = 513 |
|
|
print(f"Using default max_len: {max_len}") |
|
|
|
|
|
total_specs = len(spectrograms) |
|
|
if show_progress: |
|
|
print(f"Tokenizing {total_specs} samples...") |
|
|
|
|
|
    rng: Generator = default_rng(seed)  # default_rng(None) seeds from OS entropy
|
|
seq_groups = defaultdict(list) |
|
|
tensor_samples = [] |
|
|
skipped_empty = 0 |
|
|
|
|
|
if metadata is not None: |
|
|
meta_arrays = {k: np.asarray(v) for k, v in metadata.items()} |
|
|
else: |
|
|
meta_arrays = None |
|
|
|
|
|
normalization = normalization or "dataset" |
|
|
if normalization not in {"dataset", "per_sample"}: |
|
|
raise ValueError(f"Unsupported normalization mode: {normalization}") |
|
|
|
|
|
if dataset_stats is not None: |
|
|
ds_mean = float(dataset_stats.get('mean', 0.0)) |
|
|
ds_std = float(dataset_stats.get('std', 1.0)) |
|
|
if abs(ds_std) < 1e-6: |
|
|
ds_std = 1e-6 |
|
|
else: |
|
|
ds_mean = 0.0 |
|
|
ds_std = 1.0 |
|
|
|
|
|
eps = 1e-6 |
|
|
|
|
|
iterator = spectrograms |
|
|
if USE_TQDM and show_progress: |
|
|
iterator = tqdm(spectrograms, desc="Tokenizing", total=total_specs) |
|
|
|
|
|
for idx, spec in enumerate(iterator): |
|
|
spec_np = np.array(spec, dtype=np.float32, copy=False) |
|
|
mean_db = float(spec_np.mean()) |
|
|
std_db = float(spec_np.std()) |
|
|
if normalization == "per_sample": |
|
|
denom = std_db if abs(std_db) > eps else eps |
|
|
spec_proc = (spec_np - mean_db) / denom |
|
|
else: |
|
|
spec_proc = (spec_np - ds_mean) / ds_std |
|
|
|
|
|
patch = patch_maker(spec_proc, interleaved=interleaved) |
|
|
if patch.size == 0: |
|
|
skipped_empty += 1 |
|
|
continue |
|
|
|
|
|
n_patches = patch.shape[0] |
|
|
        patch_size = patch.shape[1] if patch.ndim > 1 else 16  # 16 = 4x4 default patch
|
|
n_masks = int(masking_percent * n_patches) |
|
|
|
|
|
word2id = { |
|
|
'[CLS]': np.full(patch_size, 0.2, dtype=np.float32), |
|
|
'[MASK]': np.full(patch_size, 0.1, dtype=np.float32), |
|
|
} |
|
|
|
|
|
sample = make_sample(patch, word2id, n_masks, patch_size, mask=mask, rng=rng) |
|
|
|
|
|
sample_meta = {} |
|
|
if meta_arrays is not None: |
|
|
for key, values in meta_arrays.items(): |
|
|
sample_meta[key] = values[idx] |
|
|
sample_meta['power_stats'] = np.array([mean_db, std_db], dtype=np.float32) |
|
|
|
|
|
if mask: |
|
|
input_ids, masked_tokens, masked_pos = sample |
|
|
seq_len = len(input_ids) |
|
|
|
|
|
if seq_len <= 1: |
|
|
continue |
|
|
|
|
|
if masked_tokens: |
|
|
masked_tokens = np.stack(masked_tokens).astype(np.float32, copy=False) |
|
|
else: |
|
|
masked_tokens = np.empty((0, patch_size), dtype=np.float32) |
|
|
|
|
|
seq_groups[seq_len].append({ |
|
|
'input_ids': input_ids, |
|
|
'masked_pos': masked_pos, |
|
|
'masked_tokens': masked_tokens, |
|
|
'n_patches': seq_len - 1, |
|
|
**sample_meta, |
|
|
}) |
|
|
else: |
|
|
tensor_samples.append({ |
|
|
'sample': sample, |
|
|
**sample_meta, |
|
|
}) |
|
|
|
|
|
if skipped_empty: |
|
|
print(f"β οΈ Skipped {skipped_empty} spectrograms with empty patches") |
|
|
|
|
|
if mask: |
|
|
filtered_data = {k: v for k, v in seq_groups.items() if k > 0 and v} |
|
|
total_samples = sum(len(v) for v in filtered_data.values()) |
|
|
if not filtered_data: |
|
|
print("Warning: No valid data after filtering!") |
|
|
return {} |
|
|
|
|
|
if show_progress: |
|
|
print(f"β
Tokenization completed: {total_samples} samples across {len(filtered_data)} sequence lengths") |
|
|
return {k: filtered_data[k] for k in sorted(filtered_data.keys())} |
|
|
|
|
|
if not tensor_samples: |
|
|
print("Warning: No validation data after processing!") |
|
|
return torch.empty(0) |
|
|
|
|
|
stacked = torch.stack([torch.tensor(item['sample'], dtype=torch.float32) if isinstance(item['sample'], np.ndarray) |
|
|
else item['sample'] for item in tensor_samples]) |
|
|
if show_progress: |
|
|
print(f"β
Tokenization completed: {len(tensor_samples)} validation samples") |
|
|
return stacked |
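
# Normalization sketch (assumed variable names, illustrative only): with the default
# normalization="dataset", pass global statistics so all samples share one scale:
#   stats = {'mean': float(all_specs.mean()), 'std': float(all_specs.std())}
#   grouped = tokenizer_train(all_specs, mask=True, dataset_stats=stats, seed=0)
# With normalization="per_sample", each spectrogram is standardized by its own
# mean/std instead, which discards absolute power information.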
|
|
|
|
|
|
|
|
def calculate_max_len_from_spectrogram(spec, patch_rows=4, patch_cols=4): |
|
|
""" |
|
|
Calculate the maximum sequence length needed for a given spectrogram size. |
|
|
|
|
|
Args: |
|
|
spec: Spectrogram tensor/array |
|
|
patch_rows: Number of rows per patch |
|
|
patch_cols: Number of columns per patch |
|
|
|
|
|
Returns: |
|
|
int: Maximum sequence length (number of patches + 1 for CLS token) |
|
|
""" |
|
|
if hasattr(spec, 'shape'): |
|
|
shape = spec.shape |
|
|
else: |
|
|
shape = spec |
|
|
|
|
|
|
|
|
if len(shape) == 3 and shape[0] == 1: |
|
|
n_rows, n_cols = shape[1], shape[2] |
|
|
elif len(shape) == 4 and shape[0] == 1 and shape[1] == 1: |
|
|
n_rows, n_cols = shape[2], shape[3] |
|
|
elif len(shape) == 2: |
|
|
n_rows, n_cols = shape[0], shape[1] |
|
|
else: |
|
|
raise ValueError(f"Unexpected spec shape: {shape}") |
|
|
|
|
|
n_patches_r = n_rows // patch_rows |
|
|
n_patches_c = n_cols // patch_cols |
|
|
total_patches = n_patches_r * n_patches_c |
|
|
|
|
|
return total_patches + 1 |
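
# Worked example (synthetic shape): a 512x128 spectrogram with 4x4 patches yields
# (512//4) * (128//4) = 128 * 32 = 4096 patches, so max_len = 4097 with [CLS]:
#   assert calculate_max_len_from_spectrogram(np.zeros((512, 128))) == 4097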
|
|
|
|
|
|
|
|
def patch_maker(spec, patch_rows=4, patch_cols=4, interleaved: bool = False):
    """Split a 2-D spectrogram into flattened non-overlapping patches.

    Returns (n_patches, patch_rows*patch_cols), or (n_patches, patch_rows*patch_cols*2)
    when interleaved=True, where adjacent real/imag column pairs stay in one patch.
    """
|
|
|
|
|
if len(spec.shape) == 3 and spec.shape[0] == 1: |
|
|
spec = spec.squeeze(0) |
|
|
elif len(spec.shape) == 4 and spec.shape[0] == 1 and spec.shape[1] == 1: |
|
|
spec = spec.squeeze(0).squeeze(0) |
|
|
elif len(spec.shape) == 2: |
|
|
pass |
|
|
else: |
|
|
raise ValueError(f"Unexpected spec shape: {spec.shape}") |
|
|
|
|
|
n_rows, n_cols = spec.shape |
|
|
|
|
|
if interleaved: |
|
|
|
|
|
|
|
|
n_patches_r = n_rows // patch_rows |
|
|
n_complex_cols = n_cols // 2 |
|
|
n_patches_c = n_complex_cols // patch_cols |
|
|
|
|
|
if n_patches_r == 0 or n_patches_c == 0: |
|
|
print(f"β PATCH CREATION FAILED (interleaved): {n_rows}x{n_cols} too small for {patch_rows}x{patch_cols}") |
|
|
return np.array([]) |
|
|
|
|
|
|
|
|
cropped = spec[:n_patches_r * patch_rows, :n_patches_c * patch_cols * 2] |
|
|
if cropped.size == 0: |
|
|
print(f"β οΈ No patches generated from {n_rows}x{n_cols} spectrogram (interleaved)") |
|
|
return np.array([]) |
|
|
|
|
|
|
|
|
reshaped = cropped.reshape(n_patches_r, patch_rows, n_patches_c, patch_cols * 2) |
|
|
result = reshaped.transpose(0, 2, 1, 3).reshape(-1, patch_rows * patch_cols * 2) |
|
|
return result.astype(np.float32, copy=False) |
|
|
|
|
|
|
|
|
n_patches_r, n_patches_c = n_rows // patch_rows, n_cols // patch_cols |
|
|
|
|
|
if n_patches_r == 0 or n_patches_c == 0: |
|
|
print(f"β PATCH CREATION FAILED: spectrogram {n_rows}x{n_cols} too small for {patch_rows}x{patch_cols} patches") |
|
|
print(f" n_patches_r: {n_patches_r}, n_patches_c: {n_patches_c}") |
|
|
return np.array([]) |
|
|
|
|
|
cropped = spec[:n_patches_r * patch_rows, :n_patches_c * patch_cols] |
|
|
if cropped.size == 0: |
|
|
print(f"β οΈ No patches generated from {n_rows}x{n_cols} spectrogram") |
|
|
return np.array([]) |
|
|
|
|
|
reshaped = cropped.reshape(n_patches_r, patch_rows, n_patches_c, patch_cols) |
|
|
result = reshaped.transpose(0, 2, 1, 3).reshape(-1, patch_rows * patch_cols) |
|
|
return result.astype(np.float32, copy=False) |
|
|
|
|
|
|
|
|
def make_sample(tokens, word2id, n_masks, patch_size, mask=True, rng: Optional[Generator] = None):
|
|
rng = rng or default_rng() |
|
|
input_ids = np.vstack((word2id['[CLS]'], tokens)) |
|
|
|
|
|
if not mask: |
|
|
return torch.tensor(input_ids, dtype=torch.float32) |
|
|
|
|
|
n_patches = tokens.shape[0] |
|
|
if n_masks <= 0 or n_patches == 0: |
|
|
masked_pos = np.empty(0, dtype=np.int64) |
|
|
else: |
|
|
n_masks = min(n_masks, n_patches) |
|
|
mask_candidates = np.arange(1, n_patches + 1) |
|
|
masked_pos = rng.choice(mask_candidates, size=n_masks, replace=False) |
|
|
|
|
|
masked_tokens = [] |
|
|
|
|
|
for pos in masked_pos: |
|
|
masked_tokens.append(input_ids[pos].astype(np.float32, copy=True)) |
|
|
        rnd = rng.random()
        if rnd < 0.1:
            # 10%: replace with random noise (BERT-style corruption)
            input_ids[pos] = rng.random(patch_size, dtype=np.float32)
        elif rnd < 0.9:
            # 80%: replace with the [MASK] token; the remaining 10% stay unchanged
            input_ids[pos] = word2id['[MASK]']
|
|
|
|
|
return [input_ids.astype(np.float32, copy=False), masked_tokens, masked_pos] |
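

def _example_make_sample():
    """Minimal sketch (synthetic tokens) of the BERT-style corruption in make_sample."""
    rng = default_rng(0)
    tokens = rng.random((8, 16), dtype=np.float32)         # 8 patches of size 16
    word2id = {'[CLS]': np.full(16, 0.2, dtype=np.float32),
               '[MASK]': np.full(16, 0.1, dtype=np.float32)}
    ids, masked, pos = make_sample(tokens, word2id, n_masks=3, patch_size=16, rng=rng)
    assert ids.shape == (9, 16) and len(masked) == 3 and pos.min() >= 1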
|
|
|
|
|
|
|
|
def patch_reconstructor(patches, rows, cols, patch_rows=4, patch_cols=4):
    """Inverse of patch_maker: reassemble (B, n_patches, patch_size) into (B, rows, cols)."""
    if isinstance(patches, torch.Tensor):
        patches = patches.detach().cpu().numpy()
|
|
batch_size, num_patches, _ = patches.shape |
|
|
n_h, n_w = rows // patch_rows, cols // patch_cols |
|
|
patches = patches.reshape(batch_size, n_h, n_w, patch_rows, patch_cols) |
|
|
reconstructed = np.zeros((batch_size, rows, cols)) |
|
|
for i in range(n_h): |
|
|
for j in range(n_w): |
|
|
reconstructed[:, i*patch_rows:(i+1)*patch_rows, j*patch_cols:(j+1)*patch_cols] = patches[:, i, j] |
|
|
return reconstructed |
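

def _example_patch_roundtrip():
    """Sketch (synthetic data): patch_maker and patch_reconstructor are inverses
    when the spectrogram dimensions are multiples of the patch size."""
    spec = np.arange(256, dtype=np.float32).reshape(16, 16)
    patches = patch_maker(spec)                     # (16, 16): a 4x4 grid of 4x4 patches
    rebuilt = patch_reconstructor(patches[None, ...], rows=16, cols=16)
    assert np.allclose(rebuilt[0], spec)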
|
|
|
|
|
|
|
|
def plot_radar_chart(names, opt_scores, base_scores, save_path="results/chart.png"): |
|
|
try: |
|
|
import matplotlib.pyplot as plt |
|
|
from math import pi |
|
|
N = len(names) |
|
|
angles = [n/float(N)*2*pi for n in range(N)] + [0] |
|
|
fig, ax = plt.subplots(subplot_kw=dict(projection='polar')) |
|
|
ax.plot(angles, opt_scores + opt_scores[:1], 'o-', label='Optimized', color='#1f77b4') |
|
|
ax.fill(angles, opt_scores + opt_scores[:1], alpha=0.25, color='#1f77b4') |
|
|
ax.plot(angles, base_scores + base_scores[:1], 'o-', label='Baseline', color='#ff7f0e') |
|
|
ax.fill(angles, base_scores + base_scores[:1], alpha=0.25, color='#ff7f0e') |
|
|
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(names)
        ax.set_ylim(0, 1)
        ax.legend()
        ax.grid(True, alpha=0.3)
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"📊 Chart saved: {save_path}")
    except Exception as e:
        print(f"⚠️ Chart generation failed: {e}")
|
|
|
|
|
|
|
|
class MaskedSpectrogramDataset(torch.utils.data.Dataset): |
|
|
"""Lazy dataset that materializes masked spectrogram samples per access.""" |
|
|
|
|
|
def __init__(self, samples): |
|
|
self.samples = samples |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.samples) |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
sample = self.samples[idx] |
|
|
input_ids = torch.from_numpy(sample['input_ids']).float() |
|
|
masked_tokens = torch.from_numpy(sample['masked_tokens']).float() |
|
|
masked_pos = torch.from_numpy(sample['masked_pos']).long() |
|
|
snr_db = torch.tensor(sample.get('snr_db', 0.0), dtype=torch.float32) |
|
|
doppler_id = torch.tensor(sample.get('doppler_id', 0), dtype=torch.long) |
|
|
power_stats = torch.tensor(sample.get('power_stats', np.zeros(2, dtype=np.float32)), dtype=torch.float32) |
|
|
snr_id = torch.tensor(sample.get('snr_id', -1), dtype=torch.long) |
|
|
modulation_id = torch.tensor(sample.get('modulation_id', -1), dtype=torch.long) |
|
|
return ( |
|
|
input_ids, |
|
|
masked_tokens, |
|
|
masked_pos, |
|
|
snr_db, |
|
|
doppler_id, |
|
|
power_stats, |
|
|
snr_id, |
|
|
modulation_id, |
|
|
) |
|
|
|
|
|
|
|
|
def create_train_dataloader(data, batch_size, shuffle, num_workers=0):
    """Build one DataLoader per sequence length from grouped tokenizer output."""
|
|
loaders = {} |
|
|
for seq_len, group in data.items(): |
|
|
print(f"Dataloader: Processing seq_len={seq_len} with {len(group)} samples") |
|
|
|
|
|
group_labels = None |
|
|
if isinstance(group, tuple) and len(group) == 2: |
|
|
group, group_labels = group |
|
|
|
|
|
if isinstance(group[0], dict): |
|
|
print(" Processing as masked data (dict structure)") |
|
|
dataset = MaskedSpectrogramDataset(group) |
|
|
loaders[seq_len] = DataLoader( |
|
|
dataset, |
|
|
batch_size=batch_size, |
|
|
shuffle=shuffle, |
|
|
pin_memory=True, |
|
|
num_workers=num_workers, |
|
|
) |
|
|
print(f" Created DataLoader with {len(dataset)} samples (lazy loading)") |
|
|
elif isinstance(group[0], list): |
|
|
print(" Processing as masked data (list structure)") |
|
|
ids, tokens, pos = zip(*group) |
|
|
|
|
|
if group_labels is not None: |
|
|
label_tensor = torch.tensor(group_labels, dtype=torch.long) |
|
|
else: |
|
|
label_tensor = torch.zeros(len(group), dtype=torch.long) |
|
|
dataset = TensorDataset(torch.tensor(ids, dtype=torch.float32), |
|
|
torch.tensor(tokens, dtype=torch.float32), |
|
|
torch.tensor(pos, dtype=torch.long), |
|
|
label_tensor) |
|
|
loaders[seq_len] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers) |
|
|
print(f" Created DataLoader with {len(dataset)} samples (with labels)") |
|
|
else: |
|
|
print(" Processing as non-masked data") |
|
|
if isinstance(group[0], torch.Tensor): |
|
|
dataset = TensorDataset(*group) |
|
|
else: |
|
|
tensor_group = [torch.tensor(g, dtype=torch.float32) for g in group] |
|
|
dataset = TensorDataset(*tensor_group) |
|
|
loaders[seq_len] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers) |
|
|
print(f" Created DataLoader with {len(dataset)} samples") |
|
|
return loaders |
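
# End-to-end sketch (hypothetical data; illustrative only):
#   specs, _ = generate_spectrograms_and_labels("demo", "data/specs.mat", None)
#   grouped = tokenizer_train(specs, mask=True, seed=0)
#   loaders = create_train_dataloader(grouped, batch_size=32, shuffle=True)
#   for seq_len, loader in loaders.items():
#       ids, masked_tokens, masked_pos, *meta = next(iter(loader))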
|
|
|
|
|
|
|
|
def train_lwm( |
|
|
model, |
|
|
train_loaders, |
|
|
val_loaders, |
|
|
optimizer, |
|
|
scheduler, |
|
|
epochs, |
|
|
device, |
|
|
save_dir="models", |
|
|
log_file="training_log.csv", |
|
|
checkpoint_suffix: str = "", |
|
|
distributed_context: Optional[Dict[str, Any]] = None, |
|
|
): |
|
|
distributed_context = distributed_context or {} |
|
|
is_distributed = distributed_context.get("is_distributed", False) |
|
|
rank = distributed_context.get("rank", 0) |
|
|
world_size = max(1, distributed_context.get("world_size", 1)) |
|
|
is_primary = distributed_context.get("is_primary", rank == 0) |
|
|
|
|
|
os.makedirs(save_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
log_file_path = f"{save_dir}/training_log.csv" |
|
|
use_tensorboard = False |
|
|
writer = None |
|
|
|
|
|
|
|
|
if is_primary: |
|
|
try: |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
tensorboard_dir = f"{save_dir}/tensorboard" |
|
|
writer = SummaryWriter(tensorboard_dir) |
|
|
print(f"π TensorBoard logs will be saved to: {tensorboard_dir}") |
|
|
use_tensorboard = True |
|
|
except (ImportError, AttributeError) as e: |
|
|
print(f"β οΈ TensorBoard not available ({e}), using CSV logging instead") |
|
|
|
|
|
with open(log_file_path, 'w') as f: |
|
|
f.write("epoch,train_loss,val_loss,val_nmse,lr\n") |
|
|
|
|
|
criterion = nn.MSELoss(reduction='sum') |
|
|
best_mse = float('inf') |
|
|
train_losses, val_losses, val_nmse_losses = [], [], [] |
|
|
|
|
|
|
|
|
patience = 3 |
|
|
patience_counter = 0 |
|
|
|
|
|
def _sync_sum(value: float) -> float: |
|
|
if not is_distributed or not dist.is_available() or not dist.is_initialized(): |
|
|
return float(value) |
|
|
tensor = torch.tensor(value, dtype=torch.float64, device=device) |
|
|
dist.all_reduce(tensor, op=dist.ReduceOp.SUM) |
|
|
return float(tensor.item()) |
|
|
|
|
|
for epoch in range(epochs): |
|
|
|
|
|
model.train() |
|
|
train_mse, train_samples = 0.0, 0 |
|
|
if is_primary: |
|
|
print(f"\nEpoch {epoch+1}/{epochs}") |
|
|
for loader in train_loaders.values(): |
|
|
pbar = tqdm( |
|
|
loader, |
|
|
desc="Train", |
|
|
postfix={"loss": 0.0, "avg_loss": 0.0}, |
|
|
disable=not is_primary, |
|
|
) |
|
|
for batch in pbar: |
|
|
optimizer.zero_grad() |
|
|
|
|
|
if len(batch) >= 3: |
|
|
ids, tokens, pos = batch[0], batch[1], batch[2] |
|
|
else: |
|
|
raise ValueError(f"Unexpected batch length: {len(batch)}") |
|
|
|
|
|
ids = ids.to(device).float() |
|
|
tokens = tokens.to(device).float() |
|
|
pos = pos.to(device).long() |
|
|
|
|
|
logits = model(ids, pos)[0] |
|
|
loss = criterion(tokens, logits) |
|
|
                loss.backward()
                optimizer.step()
                scheduler.step()
                train_mse += loss.item()
                train_samples += ids.shape[0]
|
|
|
|
|
|
|
|
current_avg_loss = train_mse / max(train_samples, 1) |
|
|
batch_size = ids.shape[0] |
|
|
if is_primary: |
|
|
pbar.set_postfix({ |
|
|
"loss": f"{loss.item()/batch_size:.4f}", |
|
|
"avg_loss": f"{current_avg_loss:.4f}" |
|
|
}) |
|
|
|
|
|
total_train_mse = _sync_sum(train_mse) |
|
|
total_train_samples = _sync_sum(train_samples) |
|
|
train_mse = total_train_mse / max(total_train_samples, 1) |
|
|
train_losses.append(train_mse) |
|
|
|
|
|
|
|
|
if use_tensorboard and writer: |
|
|
writer.add_scalar('Loss/train', train_mse, epoch + 1) |
|
|
writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], epoch + 1) |
|
|
elif is_primary: |
|
|
|
|
|
lr = optimizer.param_groups[0]['lr'] |
|
|
with open(log_file_path, 'a') as f: |
|
|
f.write(f"{epoch+1},{train_mse},,,{lr}\n") |
|
|
|
|
|
|
|
|
model.eval() |
|
|
val_mse, val_nmse, val_samples = 0.0, 0.0, 0 |
|
|
with torch.no_grad(): |
|
|
for loader in val_loaders.values(): |
|
|
progress_bar = tqdm( |
|
|
loader, |
|
|
desc="Val", |
|
|
postfix={"mse": 0.0, "nmse": 0.0}, |
|
|
disable=not is_primary, |
|
|
) |
|
|
for batch in progress_bar: |
|
|
|
|
|
if len(batch) >= 3: |
|
|
|
|
|
ids, tokens, pos = batch[0], batch[1], batch[2] |
|
|
|
|
|
ids = ids.to(device).float() |
|
|
tokens = tokens.to(device).float() |
|
|
pos = pos.to(device).long() |
|
|
|
|
|
logits = model(ids, pos)[0] |
|
|
elif len(batch) == 1: |
|
|
|
|
|
val_tensor = batch[0].to(device, dtype=torch.float32) if 'mps' in str(device) else batch[0].to(device) |
|
|
|
|
|
output = model(val_tensor) |
|
|
|
|
|
|
|
|
model_module = model.module if hasattr(model, 'module') else model |
|
|
logits = model_module.decoder(output) + model_module.decoder_bias |
|
|
|
|
|
tokens = val_tensor |
|
|
ids = val_tensor |
|
|
else: |
|
|
raise ValueError(f"Unexpected batch length: {len(batch)}") |
|
|
|
|
|
val_mse += criterion(tokens, logits).item() |
|
|
|
|
|
tokens_np = tokens.float().cpu().numpy().astype(np.float32) if 'mps' in str(device) else tokens.cpu().numpy() |
|
|
logits_np = logits.float().cpu().numpy().astype(np.float32) if 'mps' in str(device) else logits.cpu().numpy() |
|
|
nmse_val = nmse_loss(tokens_np, logits_np) |
|
|
val_nmse += nmse_val * ids.shape[0] |
|
|
val_samples += ids.shape[0] |
|
|
|
|
|
|
|
|
current_mse = val_mse / max(val_samples, 1) |
|
|
current_nmse = val_nmse / max(val_samples, 1) |
|
|
current_nmse_db = 10 * np.log10(max(current_nmse, 1e-8)) |
|
|
batch_size = ids.shape[0] |
|
|
if is_primary: |
|
|
progress_bar.set_postfix({ |
|
|
"mse": f"{current_mse:.4f}", |
|
|
"nmse": f"{current_nmse_db:.2f}dB" |
|
|
}) |
|
|
|
|
|
total_val_mse = _sync_sum(val_mse) |
|
|
total_val_nmse = _sync_sum(val_nmse) |
|
|
total_val_samples = _sync_sum(val_samples) |
|
|
val_mse = total_val_mse / max(total_val_samples, 1) |
|
|
val_nmse = total_val_nmse / max(total_val_samples, 1) |
|
|
val_losses.append(val_mse) |
|
|
val_nmse_losses.append(val_nmse) |
|
|
|
|
|
|
|
|
if use_tensorboard and writer: |
|
|
writer.add_scalar('Loss/validation', val_mse, epoch + 1) |
|
|
writer.add_scalar('Loss/nmse', val_nmse, epoch + 1) |
|
|
elif is_primary: |
|
|
|
|
|
lr = optimizer.param_groups[0]['lr'] |
|
|
|
|
|
with open(log_file_path, 'r') as f: |
|
|
lines = f.readlines() |
|
|
if lines: |
|
|
|
|
|
last_line = lines[-1].strip() |
|
|
parts = last_line.split(',') |
|
|
if len(parts) >= 5: |
|
|
parts[2] = f"{val_mse}" |
|
|
parts[3] = f"{val_nmse}" |
|
|
lines[-1] = ','.join(parts) + '\n' |
|
|
with open(log_file_path, 'w') as f: |
|
|
f.writelines(lines) |
|
|
|
|
|
if val_mse < best_mse: |
|
|
best_mse = val_mse |
|
|
patience_counter = 0 |
|
|
suffix = checkpoint_suffix or "" |
|
|
if is_primary: |
|
|
path = f"{save_dir}/lwm_epoch{epoch+1}_val{val_mse:.4f}{suffix}.pth" |
|
|
torch.save(model.state_dict(), path) |
|
|
print(f"β
Saved: {path}") |
|
|
else: |
|
|
patience_counter += 1 |
|
|
if is_primary: |
|
|
print(f"βΈοΈ No improvement for {patience_counter}/{patience} epochs") |
|
|
|
|
|
|
|
|
if patience_counter >= patience: |
|
|
if is_primary: |
|
|
print(f"π Early stopping triggered after {epoch+1} epochs") |
|
|
print(f" Best validation MSE: {best_mse:.4f}") |
|
|
break |
|
|
|
|
|
if is_primary: |
|
|
print(f"Train MSE: {train_mse:.4f}") |
|
|
val_nmse_db = 10 * np.log10(max(val_nmse, 1e-8)) |
|
|
print(f"Val MSE: {val_mse:.4f}, NMSE: {val_nmse_db:.2f}dB") |
|
|
|
|
|
|
|
|
|
|
|
while len(val_losses) < len(train_losses): |
|
|
val_losses.append(None) |
|
|
while len(val_nmse_losses) < len(train_losses): |
|
|
val_nmse_losses.append(None) |
|
|
|
|
|
|
|
|
|
|
|
def convert_numpy_types(obj): |
|
|
"""Convert numpy types to Python native types for JSON serialization""" |
|
|
if isinstance(obj, np.floating): |
|
|
return float(obj) |
|
|
elif isinstance(obj, np.integer): |
|
|
return int(obj) |
|
|
elif isinstance(obj, np.ndarray): |
|
|
return obj.tolist() |
|
|
elif isinstance(obj, list): |
|
|
return [convert_numpy_types(item) for item in obj] |
|
|
elif isinstance(obj, dict): |
|
|
return {key: convert_numpy_types(value) for key, value in obj.items()} |
|
|
else: |
|
|
return obj |
|
|
|
|
|
training_history = { |
|
|
'train_losses': convert_numpy_types(train_losses), |
|
|
'val_losses': convert_numpy_types(val_losses), |
|
|
'val_nmse_losses': convert_numpy_types(val_nmse_losses), |
|
|
        'epochs': list(range(1, len(train_losses) + 1)),  # epochs actually run (early-stop aware)
|
|
'best_val_mse': convert_numpy_types(best_mse) |
|
|
} |
|
|
|
|
|
if is_primary: |
|
|
import json |
|
|
history_file = f"{save_dir}/training_history.json" |
|
|
with open(history_file, 'w') as f: |
|
|
json.dump(training_history, f, indent=2) |
|
|
print(f"π Training history saved: {history_file}") |
|
|
|
|
|
|
|
|
if use_tensorboard and writer: |
|
|
writer.close() |
|
|
print(f"π TensorBoard logs saved: {tensorboard_dir}") |
|
|
else: |
|
|
print(f"π Training logs saved: {log_file_path}") |
|
|
elif use_tensorboard and writer: |
|
|
writer.close() |
|
|
|
|
|
return model |
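
# Invocation sketch (hypothetical model/optimizer; single process shown):
#   model = train_lwm(model, train_loaders, val_loaders, optimizer, scheduler,
#                     epochs=10, device="cuda")
# For DDP, pass e.g. distributed_context={"is_distributed": True,
#   "rank": dist.get_rank(), "world_size": dist.get_world_size(),
#   "is_primary": dist.get_rank() == 0} so logging/checkpointing stays on rank 0.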
|
|
|
|
|
|
|
|
def nmse_loss(y_true, y_pred):
    """Normalized MSE: mean((y - y_hat)**2) / mean(y**2), for matching torch or numpy inputs."""
|
|
if isinstance(y_true, torch.Tensor): |
|
|
mse = torch.mean((y_true - y_pred) ** 2) |
|
|
power = torch.mean(y_true ** 2) |
|
|
else: |
|
|
mse = np.mean((y_true - y_pred) ** 2) |
|
|
power = np.mean(y_true ** 2) |
|
|
return mse / (power + 1e-8) |
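
# Sanity checks (illustrative): a perfect prediction gives NMSE ~0, while predicting
# all zeros gives NMSE ~1, i.e. 10*log10(1) = 0 dB:
#   y = np.ones(4, dtype=np.float32)
#   nmse_loss(y, y)            # ~0.0
#   nmse_loss(y, np.zeros(4))  # ~1.0  -> 0 dB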
|
|
|