# lwm-temporal / examples / example_reconstruction.py
# Example: masked-token reconstruction with a pretrained LWM temporal backbone.
# (Provenance: wi-lab, commit da859b5 "Add HF export + temporal updates".)
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
import numpy as np
from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
from LWMTemporal.models.lwm import (
LWMBackbone,
LWMConfig,
ComplexPatchTokenizer,
masked_nmse_loss,
)
# ----- 1. Load one sequence (complex tensor) -----
# Build the dataset from the raw pickle, take the first sample, add a batch
# dimension, and keep only the first 11 time steps in a single chained slice.
data_cfg = AngleDelayDatasetConfig(raw_path=Path("examples/data/parow.p"))
dataset = AngleDelaySequenceDataset(data_cfg)
sequence = dataset[0]["sequence"].unsqueeze(0)[:, :11]  # (1, 11, N, M)
print("Sequence shape:", sequence.shape)  # expect (1, 11, 32, 8)
# ----- 2. Tokenise and select tokens to mask -----
tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
tokens, base_mask = tokenizer(sequence, patch_size=(1, 1)) # tokens: (B, S, D)
B, S, D = tokens.shape
mask_ratio = 0.60 # choose the fraction to hide
mask = base_mask.clone()
# Number of tokens to hide per batch element is loop-invariant, so compute it
# once instead of re-deriving it inside the loop.
num_mask = int(mask_ratio * S)
# randomly choose the positions that will be hidden
for b in range(B):
    masked_positions = torch.randperm(S)[:num_mask]
    mask[b, masked_positions] = True
# create the corrupted input by zeroing the masked tokens; the (B, S) boolean
# mask selects whole token rows of the (B, S, D) tensor
corrupted_tokens = tokens.clone()
corrupted_tokens[mask] = 0.0
# ----- 3. Load the pretrained backbone -----
# max_seq_len must cover every token in the sequence: 11 frames * 32 * 8 = 2816.
cfg = LWMConfig(
    # tokenization / embedding
    patch_size=(1, 1),
    phase_mode="real_imag",
    embed_dim=32,
    max_seq_len=11 * 32 * 8,  # 2816
    # transformer depth and attention
    depth=12,
    num_heads=8,
    mlp_ratio=4.0,
    routing_topk_enable=True,
    topk_per_head=True,
    # temporal attention windows
    same_frame_window=2,
    temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
    temporal_spatial_window=2,
    temporal_drift_h=1,
    temporal_drift_w=1,
)
backbone = LWMBackbone.from_pretrained(Path("checkpoints/pytorch_model.bin"), config=cfg)
backbone.eval()
# ---- 4. Run reconstruction and compute NMSE on the masked positions -----
with torch.no_grad():
    # unpack frame count and spatial extent (N, M) straight from the sequence
    _, T, H, W = sequence.shape
    out = backbone.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
    recon = out["reconstruction"]
    nmse = masked_nmse_loss(recon, tokens, mask)
    nmse_db = 10 * torch.log10(nmse)
print(f"Masked {mask_ratio*100:.1f}% of tokens ({mask.sum().item()} / {S})")
print(f"NMSE (linear): {nmse.item():.6f}")
print(f"NMSE (dB): {nmse_db.item():.2f} dB")
# --- Alternative example kept for reference (commented out): compares a
# --- randomly initialised backbone against the pretrained checkpoint.
# import torch
# from pathlib import Path
# from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
# from LWMTemporal.models.lwm import (
# LWMBackbone,
# LWMConfig,
# ComplexPatchTokenizer,
# masked_nmse_loss,
# )
# # --- 1. Load one sample from the dataset and keep the first 11 frames ---
# cfg = AngleDelayDatasetConfig(raw_path=Path("LWMTemporal/data/parow.p"))
# dataset = AngleDelaySequenceDataset(cfg)
# sequence = dataset[0]["sequence"].unsqueeze(0)[:, :11] # (1, 11, 32, 8)
# # --- 2. Tokenise and randomly mask 40% of the tokens ---
# tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
# tokens, base_mask = tokenizer(sequence, patch_size=(1, 1))
# mask = base_mask.clone()
# B, S, _ = tokens.shape
# mask_fraction = 0.40
# for b in range(B):
# num_mask = int(mask_fraction * S)
# masked_positions = torch.randperm(S)[:num_mask]
# mask[b, masked_positions] = True
# corrupted_tokens = tokens.clone()
# corrupted_tokens[mask] = 0.0
# T = sequence.size(1)
# H = sequence.size(2)
# W = sequence.size(3)
# # --- 3. Helper to run a model and report NMSE ---
# def run_model(model: LWMBackbone, label: str) -> None:
# model.eval()
# with torch.no_grad():
# outputs = model.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
# reconstructed = outputs["reconstruction"]
# nmse = masked_nmse_loss(reconstructed, tokens, mask)
# nmse_db = 10 * torch.log10(nmse)
# print(f"{label:>12}: NMSE = {nmse.item():.6f} ({nmse_db.item():.2f} dB)")
# # --- 4. Random-weights model ---
# cfg_random = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3))
# model_random = LWMBackbone(cfg_random)
# run_model(model_random, "random init")
# # --- 5. Pretrained checkpoint ---
# cfg_pretrained = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3))
# model_ckpt = LWMBackbone.from_pretrained(Path("LWMTemporal/models"), config=cfg_pretrained)
# run_model(model_ckpt, "checkpoint")