| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| import torch | |
| import numpy as np | |
| from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset | |
| from LWMTemporal.models.lwm import ( | |
| LWMBackbone, | |
| LWMConfig, | |
| ComplexPatchTokenizer, | |
| masked_nmse_loss, | |
| ) | |
# ----- 1. Load one sequence (complex tensor) -----
data_cfg = AngleDelayDatasetConfig(raw_path=Path("examples/data/parow.p"))
dataset = AngleDelaySequenceDataset(data_cfg)
# Take the first sample, add a batch axis, then truncate to the first 11 frames.
full_sequence = dataset[0]["sequence"].unsqueeze(0)  # (1, T, N, M)
sequence = full_sequence[:, :11]
print("Sequence shape:", sequence.shape)  # expect (1, 11, 32, 8)
# ----- 2. Tokenise and select tokens to mask -----
tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
tokens, base_mask = tokenizer(sequence, patch_size=(1, 1))  # tokens: (B, S, D)
B, S, D = tokens.shape

mask_ratio = 0.60  # fraction of tokens to hide
# Loop-invariant: every batch item hides the same number of tokens,
# so compute it once instead of once per iteration.
num_mask = int(mask_ratio * S)

mask = base_mask.clone()
# Randomly choose, independently for each batch item, the positions to hide.
for b in range(B):
    masked_positions = torch.randperm(S)[:num_mask]
    mask[b, masked_positions] = True

# Create the corrupted input by zeroing the masked tokens.
# Boolean (B, S) indexing broadcasts the scalar across the feature dim D.
corrupted_tokens = tokens.clone()
corrupted_tokens[mask] = 0.0
# ----- 3. Load the pretrained backbone -----
# Need max_seq_len >= S (here 11 * 32 * 8 = 2816).
backbone_settings = dict(
    patch_size=(1, 1),
    phase_mode="real_imag",
    embed_dim=32,
    depth=12,
    num_heads=8,
    mlp_ratio=4.0,
    same_frame_window=2,
    temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
    temporal_spatial_window=2,
    temporal_drift_h=1,
    temporal_drift_w=1,
    routing_topk_enable=True,
    topk_per_head=True,
    max_seq_len=2816,
)
cfg = LWMConfig(**backbone_settings)
backbone = LWMBackbone.from_pretrained(Path("checkpoints/pytorch_model.bin"), config=cfg)
backbone.eval()
# ----- 4. Run reconstruction and compute NMSE on the masked positions -----
with torch.no_grad():
    # Temporal / spatial extents come straight from the input sequence shape.
    T = sequence.size(1)
    H = sequence.size(2)
    W = sequence.size(3)
    outputs = backbone.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
    reconstructed = outputs["reconstruction"]
    # NMSE is evaluated only on the masked positions (the ones the model had to infer).
    nmse = masked_nmse_loss(reconstructed, tokens, mask)
    nmse_db = 10 * torch.log10(nmse)

# Fix: mask.sum() counts masked tokens across the WHOLE batch, so the
# denominator must be mask.numel() (= B * S), not the per-item S.
# (Identical output for B == 1, correct for B > 1.)
print(f"Masked {mask_ratio*100:.1f}% of tokens ({mask.sum().item()} / {mask.numel()})")
print(f"NMSE (linear): {nmse.item():.6f}")
print(f"NMSE (dB): {nmse_db.item():.2f} dB")
# NOTE(review): everything below is a commented-out alternative version of this
# example (random-init vs pretrained comparison). Consider moving it to its own
# script or deleting it rather than keeping it as dead code.
# import torch
| # from pathlib import Path | |
| # from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset | |
| # from LWMTemporal.models.lwm import ( | |
| # LWMBackbone, | |
| # LWMConfig, | |
| # ComplexPatchTokenizer, | |
| # masked_nmse_loss, | |
| # ) | |
| # # --- 1. Load one sample from the dataset and keep the first 11 frames --- | |
| # cfg = AngleDelayDatasetConfig(raw_path=Path("LWMTemporal/data/parow.p")) | |
| # dataset = AngleDelaySequenceDataset(cfg) | |
| # sequence = dataset[0]["sequence"].unsqueeze(0)[:, :11] # (1, 11, 32, 8) | |
| # # --- 2. Tokenise and randomly mask 40% of the tokens --- | |
| # tokenizer = ComplexPatchTokenizer(phase_mode="real_imag") | |
| # tokens, base_mask = tokenizer(sequence, patch_size=(1, 1)) | |
| # mask = base_mask.clone() | |
| # B, S, _ = tokens.shape | |
| # mask_fraction = 0.40 | |
| # for b in range(B): | |
| # num_mask = int(mask_fraction * S) | |
| # masked_positions = torch.randperm(S)[:num_mask] | |
| # mask[b, masked_positions] = True | |
| # corrupted_tokens = tokens.clone() | |
| # corrupted_tokens[mask] = 0.0 | |
| # T = sequence.size(1) | |
| # H = sequence.size(2) | |
| # W = sequence.size(3) | |
| # # --- 3. Helper to run a model and report NMSE --- | |
| # def run_model(model: LWMBackbone, label: str) -> None: | |
| # model.eval() | |
| # with torch.no_grad(): | |
| # outputs = model.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False) | |
| # reconstructed = outputs["reconstruction"] | |
| # nmse = masked_nmse_loss(reconstructed, tokens, mask) | |
| # nmse_db = 10 * torch.log10(nmse) | |
| # print(f"{label:>12}: NMSE = {nmse.item():.6f} ({nmse_db.item():.2f} dB)") | |
| # # --- 4. Random-weights model --- | |
| # cfg_random = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3)) | |
| # model_random = LWMBackbone(cfg_random) | |
| # run_model(model_random, "random init") | |
| # # --- 5. Pretrained checkpoint --- | |
| # cfg_pretrained = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3)) | |
| # model_ckpt = LWMBackbone.from_pretrained(Path("LWMTemporal/models"), config=cfg_pretrained) | |
| # run_model(model_ckpt, "checkpoint") |