"""Example: masked-token reconstruction with a pretrained LWM backbone.

Loads one angle-delay channel sequence, randomly masks a fraction of its
tokens, reconstructs them with the pretrained backbone, and reports the
NMSE over the masked positions (linear and dB).
"""
import sys
from pathlib import Path

# Make the repository root importable when this example is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
import numpy as np

from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
from LWMTemporal.models.lwm import (
    LWMBackbone,
    LWMConfig,
    ComplexPatchTokenizer,
    masked_nmse_loss,
)

# ----- 1. Load one sequence (complex tensor) -----
data_cfg = AngleDelayDatasetConfig(raw_path=Path("examples/data/parow.p"))
dataset = AngleDelaySequenceDataset(data_cfg)
sequence = dataset[0]["sequence"].unsqueeze(0)  # (1, T, N, M)
sequence = sequence[:, :11]  # keep only the first 11 time steps
print("Sequence shape:", sequence.shape)  # expect (1, 11, 32, 8)

# ----- 2. Tokenise and select tokens to mask -----
tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
tokens, base_mask = tokenizer(sequence, patch_size=(1, 1))  # tokens: (B, S, D)
B, S, D = tokens.shape

mask_ratio = 0.60  # choose the fraction to hide
mask = base_mask.clone()
# Randomly choose the positions that will be hidden. The count is the
# same for every batch element, so compute it once outside the loop.
num_mask = int(mask_ratio * S)
for b in range(B):
    masked_positions = torch.randperm(S)[:num_mask]
    mask[b, masked_positions] = True

# Create the corrupted input by zeroing the masked tokens.
corrupted_tokens = tokens.clone()
corrupted_tokens[mask] = 0.0

# ----- 3. Load the pretrained backbone -----
# max_seq_len must cover every token position. Derive it from S (here
# 11 * 32 * 8 = 2816) rather than hard-coding the number, so the config
# stays correct if the frame count or patch size above is changed.
cfg = LWMConfig(
    patch_size=(1, 1),
    phase_mode="real_imag",
    embed_dim=32,
    depth=12,
    num_heads=8,
    mlp_ratio=4.0,
    same_frame_window=2,
    temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
    temporal_spatial_window=2,
    temporal_drift_h=1,
    temporal_drift_w=1,
    routing_topk_enable=True,
    topk_per_head=True,
    max_seq_len=S,
)
backbone = LWMBackbone.from_pretrained(Path("checkpoints/pytorch_model.bin"), config=cfg)
backbone.eval()

# ----- 4. Run reconstruction and compute NMSE on the masked positions -----
with torch.no_grad():
    # Spatial/temporal extents the backbone needs to lay out the tokens
    # (T frames, H x W patches per frame).
    T = sequence.size(1)
    H = sequence.size(2)
    W = sequence.size(3)
    outputs = backbone.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
    reconstructed = outputs["reconstruction"]
    nmse = masked_nmse_loss(reconstructed, tokens, mask)
    nmse_db = 10 * torch.log10(nmse)

print(f"Masked {mask_ratio*100:.1f}% of tokens ({mask.sum().item()} / {S})")
print(f"NMSE (linear): {nmse.item():.6f}")
print(f"NMSE (dB): {nmse_db.item():.2f} dB")