import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
import numpy as np
from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
from LWMTemporal.models.lwm import (
LWMBackbone,
LWMConfig,
ComplexPatchTokenizer,
masked_nmse_loss,
)
# ----- 1. Load one sequence (complex tensor) -----
dataset_config = AngleDelayDatasetConfig(raw_path=Path("examples/data/parow.p"))
dataset = AngleDelaySequenceDataset(dataset_config)
# Take the first sample, add a batch dimension, and keep only the first 11 frames.
sequence = dataset[0]["sequence"].unsqueeze(0)[:, :11]  # (1, T, N, M)
print("Sequence shape:", sequence.shape)  # expected: (1, 11, 32, 8)
# ----- 2. Tokenise and select tokens to mask -----
tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
tokens, base_mask = tokenizer(sequence, patch_size=(1, 1))  # tokens: (B, S, D)
B, S, _ = tokens.shape  # embedding dim is not needed below

mask_ratio = 0.60  # fraction of tokens to hide
num_mask = int(mask_ratio * S)  # loop-invariant: the same count for every batch item
mask = base_mask.clone()
# Randomly choose, independently per batch item, the positions that will be hidden.
for b in range(B):
    masked_positions = torch.randperm(S)[:num_mask]
    mask[b, masked_positions] = True

# Create the corrupted input by zeroing the masked tokens.
corrupted_tokens = tokens.clone()
corrupted_tokens[mask] = 0.0
# ----- 3. Load the pretrained backbone -----
# max_seq_len must cover every token: S = 11 frames * 32 * 8 = 2816 here.
config = LWMConfig(
    # tokenisation — keep in sync with the ComplexPatchTokenizer call above
    patch_size=(1, 1),
    phase_mode="real_imag",
    # transformer size
    embed_dim=32,
    depth=12,
    num_heads=8,
    mlp_ratio=4.0,
    # temporal/spatial attention routing
    same_frame_window=2,
    temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
    temporal_spatial_window=2,
    temporal_drift_h=1,
    temporal_drift_w=1,
    routing_topk_enable=True,
    topk_per_head=True,
    max_seq_len=2816,
)
backbone = LWMBackbone.from_pretrained(Path("checkpoints/pytorch_model.bin"), config=config)
backbone.eval()
# ---- 4. Run reconstruction and compute NMSE on the masked positions -----
with torch.no_grad():
    # Frame count and spatial extents come straight from the input (B, T, N, M).
    _, T, H, W = sequence.shape
    outputs = backbone.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
    reconstructed = outputs["reconstruction"]
    # NMSE is evaluated only on the positions hidden from the model.
    nmse = masked_nmse_loss(reconstructed, tokens, mask)
    nmse_db = 10 * torch.log10(nmse)

print(f"Masked {mask_ratio*100:.1f}% of tokens ({mask.sum().item()} / {S})")
print(f"NMSE (linear): {nmse.item():.6f}")
print(f"NMSE (dB): {nmse_db.item():.2f} dB")
# import torch
# from pathlib import Path
# from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
# from LWMTemporal.models.lwm import (
# LWMBackbone,
# LWMConfig,
# ComplexPatchTokenizer,
# masked_nmse_loss,
# )
# # --- 1. Load one sample from the dataset and keep the first 11 frames ---
# cfg = AngleDelayDatasetConfig(raw_path=Path("LWMTemporal/data/parow.p"))
# dataset = AngleDelaySequenceDataset(cfg)
# sequence = dataset[0]["sequence"].unsqueeze(0)[:, :11] # (1, 11, 32, 8)
# # --- 2. Tokenise and randomly mask 40% of the tokens ---
# tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
# tokens, base_mask = tokenizer(sequence, patch_size=(1, 1))
# mask = base_mask.clone()
# B, S, _ = tokens.shape
# mask_fraction = 0.40
# for b in range(B):
# num_mask = int(mask_fraction * S)
# masked_positions = torch.randperm(S)[:num_mask]
# mask[b, masked_positions] = True
# corrupted_tokens = tokens.clone()
# corrupted_tokens[mask] = 0.0
# T = sequence.size(1)
# H = sequence.size(2)
# W = sequence.size(3)
# # --- 3. Helper to run a model and report NMSE ---
# def run_model(model: LWMBackbone, label: str) -> None:
# model.eval()
# with torch.no_grad():
# outputs = model.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
# reconstructed = outputs["reconstruction"]
# nmse = masked_nmse_loss(reconstructed, tokens, mask)
# nmse_db = 10 * torch.log10(nmse)
# print(f"{label:>12}: NMSE = {nmse.item():.6f} ({nmse_db.item():.2f} dB)")
# # --- 4. Random-weights model ---
# cfg_random = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3))
# model_random = LWMBackbone(cfg_random)
# run_model(model_random, "random init")
# # --- 5. Pretrained checkpoint ---
# cfg_pretrained = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3))
# model_ckpt = LWMBackbone.from_pretrained(Path("LWMTemporal/models"), config=cfg_pretrained)
# run_model(model_ckpt, "checkpoint")