File size: 4,630 Bytes
164610c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da859b5
164610c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
import numpy as np

from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
from LWMTemporal.models.lwm import (
    LWMBackbone,
    LWMConfig,
    ComplexPatchTokenizer,
    masked_nmse_loss,
)

# ----- 1. Load one sequence (complex tensor) -----
# NOTE(review): raw_path is hard-coded and resolved relative to the working
# directory — confirm examples/data/parow.p exists before running elsewhere.
data_cfg = AngleDelayDatasetConfig(raw_path=Path("examples/data/parow.p"))
dataset = AngleDelaySequenceDataset(data_cfg)

sequence = dataset[0]["sequence"].unsqueeze(0)        # (1, T, N, M)
sequence = sequence[:, :11]                          # keep only the first 11 time steps
print("Sequence shape:", sequence.shape)             # expect (1, 11, 32, 8)

# ----- 2. Tokenise and select tokens to mask -----
tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
tokens, base_mask = tokenizer(sequence, patch_size=(1, 1))  # tokens: (B, S, D)

B, S, D = tokens.shape
mask_ratio = 0.60                                        # fraction of tokens to hide
mask = base_mask.clone()

# Randomly choose the positions that will be hidden.
# num_mask is loop-invariant, so compute it once outside the per-batch loop.
# NOTE(review): no torch.manual_seed is set, so the masked positions (and the
# reported NMSE) differ between runs — seed explicitly if reproducibility matters.
num_mask = int(mask_ratio * S)
for b in range(B):
    masked_positions = torch.randperm(S)[:num_mask]
    mask[b, masked_positions] = True

# Create the corrupted input by zeroing the masked tokens.
corrupted_tokens = tokens.clone()
corrupted_tokens[mask] = 0.0

# ----- 3. Load the pretrained backbone -----
# max_seq_len must cover the full token sequence. Derive it from S (the actual
# token count produced by the tokenizer) instead of hard-coding 2816, so the
# script keeps working if the frame count, grid size, or patch_size change.
# For this example S = 11 * 32 * 8 = 2816.
cfg = LWMConfig(
    patch_size=(1, 1),
    phase_mode="real_imag",
    embed_dim=32,
    depth=12,
    num_heads=8,
    mlp_ratio=4.0,
    same_frame_window=2,
    temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
    temporal_spatial_window=2,
    temporal_drift_h=1,
    temporal_drift_w=1,
    routing_topk_enable=True,
    topk_per_head=True,
    max_seq_len=S,
)

backbone = LWMBackbone.from_pretrained(Path("checkpoints/pytorch_model.bin"), config=cfg)
backbone.eval()

# ----- 4. Run reconstruction and compute NMSE on the masked positions -----
with torch.no_grad():
    # Frame count and spatial grid taken straight from the raw sequence;
    # with patch_size=(1, 1) each token maps 1:1 onto a (T, H, W) cell.
    T = sequence.size(1)
    H = sequence.size(2)
    W = sequence.size(3)

    outputs = backbone.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
    reconstructed = outputs["reconstruction"]

    # NMSE is evaluated only on the masked positions, then converted to dB.
    nmse = masked_nmse_loss(reconstructed, tokens, mask)
    nmse_db = 10 * torch.log10(nmse)

print(f"Masked {mask_ratio*100:.1f}% of tokens ({mask.sum().item()} / {S})")
print(f"NMSE (linear): {nmse.item():.6f}")
print(f"NMSE (dB):     {nmse_db.item():.2f} dB")





# import torch
# from pathlib import Path
# from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
# from LWMTemporal.models.lwm import (
#     LWMBackbone,
#     LWMConfig,
#     ComplexPatchTokenizer,
#     masked_nmse_loss,
# )

# # --- 1. Load one sample from the dataset and keep the first 11 frames ---
# cfg = AngleDelayDatasetConfig(raw_path=Path("LWMTemporal/data/parow.p"))
# dataset = AngleDelaySequenceDataset(cfg)
# sequence = dataset[0]["sequence"].unsqueeze(0)[:, :11]  # (1, 11, 32, 8)

# # --- 2. Tokenise and randomly mask 40% of the tokens ---
# tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
# tokens, base_mask = tokenizer(sequence, patch_size=(1, 1))
# mask = base_mask.clone()

# B, S, _ = tokens.shape
# mask_fraction = 0.40

# for b in range(B):
#     num_mask = int(mask_fraction * S)
#     masked_positions = torch.randperm(S)[:num_mask]
#     mask[b, masked_positions] = True

# corrupted_tokens = tokens.clone()
# corrupted_tokens[mask] = 0.0

# T = sequence.size(1)
# H = sequence.size(2)
# W = sequence.size(3)

# # --- 3. Helper to run a model and report NMSE ---
# def run_model(model: LWMBackbone, label: str) -> None:
#     model.eval()
#     with torch.no_grad():
#         outputs = model.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
#         reconstructed = outputs["reconstruction"]
#         nmse = masked_nmse_loss(reconstructed, tokens, mask)
#         nmse_db = 10 * torch.log10(nmse)
#     print(f"{label:>12}: NMSE = {nmse.item():.6f}   ({nmse_db.item():.2f} dB)")

# # --- 4. Random-weights model ---
# cfg_random = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3))
# model_random = LWMBackbone(cfg_random)
# run_model(model_random, "random init")

# # --- 5. Pretrained checkpoint ---
# cfg_pretrained = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3))
# model_ckpt = LWMBackbone.from_pretrained(Path("LWMTemporal/models"), config=cfg_pretrained)
# run_model(model_ckpt, "checkpoint")