Sadjad Alikhani committed on
Commit
164610c
Β·
0 Parent(s):

Initial commit

Browse files
.cursorignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python cache
2
+ __pycache__/
3
+ *.pyc
4
+
5
+ # Experiment artifacts
6
+ cache/
7
+ logs/
8
+ wandb/
9
+ figs/
10
+ checkpoints/*.pth
11
+ checkpoints/*.bin
12
+
13
+ # Data files
14
+ examples/data/*.p
15
+ *.pkl
16
+ *.pickle
17
+
.gitattributes ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+ MANIFEST
23
+
24
+ # PyTorch
25
+ *.pth
26
+ *.pt
27
+ *.bin
28
+ *.ckpt
29
+ !checkpoints/**/*.pth
30
+ !checkpoints/**/*.bin
31
+ !checkpoints/**/*.json
32
+ !LWMTemporal/models/config.json
33
+
34
+ # Data
35
+ *.p
36
+ *.pkl
37
+ *.pickle
38
+ *.h5
39
+ *.hdf5
40
+ cache/
41
+ data/
42
+ !examples/data/
43
+ !examples/data/*.p
44
+ !examples/data/README.md
45
+
46
+ # Experiments
47
+ logs/
48
+ figs/
49
+ wandb/
50
+ outputs/
51
+ # checkpoints/
52
+ runs/
53
+
54
+ # IDE
55
+ .vscode/
56
+ .idea/
57
+ *.swp
58
+ *.swo
59
+ *~
60
+ .DS_Store
61
+
62
+ # Testing
63
+ .pytest_cache/
64
+ .coverage
65
+ htmlcov/
66
+ .tox/
67
+
68
+ # Environment
69
+ .env
70
+ .venv
71
+ env/
72
+ venv/
73
+ ENV/
74
+ env.bak/
75
+ venv.bak/
76
+
LICENSE ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Sadjad Alikhani
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
LWMTemporal/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """LWM Temporal Model package."""
2
+
3
+ import warnings
4
+
5
+ warnings.filterwarnings("ignore")
6
+
7
+ from .models.lwm import LWMConfig, LWMModel, LWMBackbone
8
+
9
+ __version__ = "0.1.0"
10
+ __all__ = ["LWMConfig", "LWMModel", "LWMBackbone", "__version__"]
LWMTemporal/cli/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Command line entrypoints for the LWM foundation package."""
LWMTemporal/cli/channel_prediction.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ from typing import Optional, Sequence
6
+ import torch
7
+
8
+ from ..tasks.channel_prediction import (
9
+ ChannelPredictionArgs,
10
+ ChannelPredictionTrainer,
11
+ DatasetArgs,
12
+ ModelArgs,
13
+ PredictionArgs,
14
+ TrainingArgs,
15
+ )
16
+ from ..utils.logging import setup_logging
17
+
18
+
19
def parse_args(argv: Optional[Sequence[str]] = None) -> ChannelPredictionArgs:
    """Parse CLI flags into a structured :class:`ChannelPredictionArgs`.

    Args:
        argv: Explicit argument list for testing; ``None`` means ``sys.argv[1:]``.

    Returns:
        A ``ChannelPredictionArgs`` bundle grouping dataset, model, training
        and prediction sub-configurations.
    """
    parser = argparse.ArgumentParser(description="Channel prediction trainer")

    # --- dataset ---
    parser.add_argument("--data_path", type=Path, required=True)
    parser.add_argument("--keep_percentage", type=float, default=0.25)
    parser.add_argument("--normalize", type=str, default="global_rms", choices=["global_rms", "per_sample_rms", "none"])
    parser.add_argument("--cache_dir", type=Path, default=Path("cache"))
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--overwrite_cache", action="store_true")
    parser.add_argument("--snr_db", type=float, default=None)
    parser.add_argument("--noise_seed", type=int, default=None)
    parser.add_argument("--max_time_steps", type=int, default=None)
    parser.add_argument("--train_limit", type=int, default=500)
    parser.add_argument("--val_limit", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)

    # --- model ---
    parser.add_argument("--patch_size", type=int, nargs=2, default=(1, 1))
    parser.add_argument("--phase_mode", type=str, default="real_imag", choices=["real_imag", "mag_phase"])
    parser.add_argument("--embed_dim", type=int, default=32)
    parser.add_argument("--depth", type=int, default=12)
    parser.add_argument("--num_heads", type=int, default=8)
    parser.add_argument("--mlp_ratio", type=float, default=4.0)
    parser.add_argument("--same_frame_window", type=int, default=2)
    parser.add_argument("--temporal_offsets", type=int, nargs="*", default=[-1, -2, -3, -4, -5, -6, -7])
    parser.add_argument("--temporal_spatial_window", type=int, default=2)
    parser.add_argument("--temporal_drift_h", type=int, default=1)
    parser.add_argument("--temporal_drift_w", type=int, default=1)
    # FIX: these two were `action="store_true", default=True`, which made the
    # flags permanent no-ops (always True, impossible to disable from the CLI).
    # BooleanOptionalAction keeps `--routing_topk_enable` / `--topk_per_head`
    # working unchanged and adds `--no-...` counterparts to turn them off.
    parser.add_argument("--routing_topk_enable", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--routing_topk_fraction", type=float, default=0.2)
    parser.add_argument("--routing_topk_min", type=int, default=8)
    parser.add_argument("--routing_topk_max", type=int, default=32)
    parser.add_argument("--topk_per_head", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--posenc", type=str, default="learned", choices=["learned", "rope_sincos"])
    parser.add_argument("--rope_base", type=float, default=10000.0)
    parser.add_argument("--global_cls", action="store_true")
    parser.add_argument("--pretrained", type=Path, default=None)
    parser.add_argument("--finetune_last_n", type=int, default=0)
    parser.add_argument("--train_head_only", action="store_true")

    # --- training ---
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--loss", type=str, default="nmse", choices=["nmse", "mse"])
    parser.add_argument("--use_dataparallel", action="store_true")
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--save_dir", type=Path, default=Path("models"))
    parser.add_argument("--save_prefix", type=str, default="channel_prediction")
    parser.add_argument("--inference_only", action="store_true")
    parser.add_argument("--inference_split", type=str, default="val", choices=["train", "val", "all"])
    parser.add_argument("--verbose_inference", action="store_true")
    parser.add_argument("--log_dir", type=Path, default=Path("logs"))
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--wandb_entity", type=str, default=None)
    parser.add_argument("--wandb_run_name", type=str, default=None)

    # --- prediction ---
    parser.add_argument("--Tpast", type=int, default=10)
    parser.add_argument("--horizon", type=int, default=1)
    parser.add_argument("--num_visual_samples", type=int, default=4)
    parser.add_argument("--viz_dir", type=Path, default=Path("figs/predictions"))

    ns = parser.parse_args(argv)

    dataset = DatasetArgs(
        data_path=ns.data_path,
        keep_percentage=ns.keep_percentage,
        normalize=ns.normalize,
        cache_dir=ns.cache_dir,
        use_cache=not ns.no_cache,  # CLI exposes the negated flag
        overwrite_cache=ns.overwrite_cache,
        snr_db=ns.snr_db,
        noise_seed=ns.noise_seed,
        max_time_steps=ns.max_time_steps,
        train_limit=ns.train_limit,
        val_limit=ns.val_limit,
        seed=ns.seed,
    )

    model = ModelArgs(
        patch_size=tuple(ns.patch_size),
        phase_mode=ns.phase_mode,
        embed_dim=ns.embed_dim,
        depth=ns.depth,
        num_heads=ns.num_heads,
        mlp_ratio=ns.mlp_ratio,
        same_frame_window=ns.same_frame_window,
        temporal_offsets=tuple(ns.temporal_offsets),
        temporal_spatial_window=ns.temporal_spatial_window,
        temporal_drift_h=ns.temporal_drift_h,
        temporal_drift_w=ns.temporal_drift_w,
        routing_topk_enable=ns.routing_topk_enable,
        routing_topk_fraction=ns.routing_topk_fraction,
        routing_topk_min=ns.routing_topk_min,
        routing_topk_max=ns.routing_topk_max,
        topk_per_head=ns.topk_per_head,
        posenc=ns.posenc,
        rope_base=ns.rope_base,
        global_cls=ns.global_cls,
        pretrained=ns.pretrained,
        finetune_last_n=ns.finetune_last_n,
        train_head_only=ns.train_head_only,
    )

    training = TrainingArgs(
        device=ns.device,
        epochs=ns.epochs,
        batch_size=ns.batch_size,
        lr=ns.lr,
        weight_decay=ns.weight_decay,
        warmup_ratio=ns.warmup_ratio,
        loss=ns.loss,
        use_dataparallel=ns.use_dataparallel,
        grad_clip=ns.grad_clip,
        log_interval=ns.log_interval,
        save_dir=ns.save_dir,
        save_prefix=ns.save_prefix,
        inference_only=ns.inference_only,
        inference_split=ns.inference_split,
        verbose_inference=ns.verbose_inference,
        log_dir=ns.log_dir,
        use_wandb=ns.use_wandb,
        wandb_project=ns.wandb_project,
        wandb_entity=ns.wandb_entity,
        wandb_run_name=ns.wandb_run_name,
    )

    prediction = PredictionArgs(
        Tpast=ns.Tpast,
        horizon=ns.horizon,
        num_visual_samples=ns.num_visual_samples,
        viz_dir=ns.viz_dir,
    )

    return ChannelPredictionArgs(dataset=dataset, model=model, training=training, prediction=prediction)
156
+
157
+
158
def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: parse flags, configure logging, run the trainer."""
    run_args = parse_args(argv)
    run_logger = setup_logging("LWMTemporal.channel_prediction", run_args.training.log_dir)
    run_logger.info(
        "Starting channel prediction run | device=%s inference_only=%s use_wandb=%s",
        run_args.training.device,
        run_args.training.inference_only,
        run_args.training.use_wandb,
    )
    ChannelPredictionTrainer(run_args, logger=run_logger).train()


__all__ = ["parse_args", "main"]


if __name__ == "__main__":
    main()
LWMTemporal/cli/pretrain.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Sequence
4
+
5
+ from ..tasks.pretraining import build_parser, build_pretraining_args, PretrainingTrainer
6
+ from ..utils.logging import setup_logging
7
+
8
+
9
def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: build the parser, assemble config, launch pretraining."""
    cli_ns = build_parser().parse_args(args=list(argv) if argv is not None else None)
    run_args = build_pretraining_args(cli_ns)
    run_logger = setup_logging("LWMTemporal.pretraining", run_args.logging.log_dir)
    run_logger.info(
        "Starting pretraining run | device=%s epochs=%d batch_size=%d use_wandb=%s",
        run_args.optim.device,
        run_args.optim.epochs,
        run_args.optim.batch_size,
        run_args.logging.use_wandb,
    )
    PretrainingTrainer(run_args, logger=run_logger).train()


__all__ = ["main"]
LWMTemporal/models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .lwm import LWMConfig, LWMModel, LWMBackbone
2
+
3
+ __all__ = ["LWMConfig", "LWMModel", "LWMBackbone"]
LWMTemporal/models/config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "patch_size": [1, 1],
3
+ "phase_mode": "real_imag",
4
+ "embed_dim": 32,
5
+ "depth": 12,
6
+ "num_heads": 8,
7
+ "mlp_ratio": 4.0,
8
+ "same_frame_window": 2,
9
+ "same_frame_window_h": null,
10
+ "same_frame_window_w": null,
11
+ "same_frame_dilation_h": 1,
12
+ "same_frame_dilation_w": 1,
13
+ "temporal_offsets": [-4, -3, -2, -1, 1, 2, 3],
14
+ "temporal_spatial_window": 2,
15
+ "temporal_spatial_window_h": null,
16
+ "temporal_spatial_window_w": null,
17
+ "temporal_spatial_dilation_h": 1,
18
+ "temporal_spatial_dilation_w": 1,
19
+ "temporal_drift_h": 1,
20
+ "temporal_drift_w": 1,
21
+ "spatial_only": false,
22
+ "routing_topk_enable": true,
23
+ "routing_topk_fraction": 0.2,
24
+ "routing_topk_min": 8,
25
+ "routing_topk_max": 32,
26
+ "routing_topk_per_head": true,
27
+ "topk_neighbors": null,
28
+ "topk_per_head": true,
29
+ "global_cls": false,
30
+ "posenc": "learned",
31
+ "rope_base": 10000.0,
32
+ "rope_mode": "flat",
33
+ "rope_base_t": null,
34
+ "rope_base_h": null,
35
+ "rope_base_w": null,
36
+ "max_seq_len": null
37
+ }
LWMTemporal/models/lwm.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import math
5
+ from dataclasses import dataclass, asdict, fields
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ from torch.utils.data import Dataset
15
+
16
+ # -----------------------------------------------------------------------------
17
+ # Tokenization
18
+ # -----------------------------------------------------------------------------
19
class ComplexPatchTokenizer:
    """Turn a complex channel sequence (B, T, N, M) into flat patch tokens.

    Each complex entry is expanded into two real channels — either
    (real, imag) or (magnitude, phase) depending on ``phase_mode`` — and the
    (N, M) grid is cut into non-overlapping ``patch_size`` patches, flattened
    in (t, h, w) order.
    """

    def __init__(self, phase_mode: str = "real_imag") -> None:
        if phase_mode not in {"real_imag", "mag_phase"}:
            raise ValueError("phase_mode must be 'real_imag' or 'mag_phase'")
        self.phase_mode = phase_mode

    def _split_channels(self, tensor: Tensor) -> Tensor:
        """Expand a complex tensor into two real channels along a new last dim."""
        if self.phase_mode == "real_imag":
            parts = (tensor.real, tensor.imag)
        else:
            parts = (tensor.abs(), torch.angle(tensor))
        return torch.stack(parts, dim=-1)

    def __call__(self, seq: Tensor, patch_size: Tuple[int, int]) -> Tuple[Tensor, Tensor]:
        """Tokenize ``seq``; returns (tokens, mask) with an all-False mask.

        Returns:
            tokens: (B, T*H*W, ph*pw*2) real tensor of flattened patches.
            mask:   (B, T*H*W) boolean tensor, all False (no tokens masked).
        """
        if not torch.is_complex(seq):
            raise TypeError("expected complex tensor shaped (B, T, N, M)")
        ph, pw = patch_size
        if seq.size(2) % ph or seq.size(3) % pw:
            raise ValueError("patch_size must evenly divide channel dimensions")
        split = self._split_channels(seq)
        b, t, n, m, c = split.shape
        rows, cols = n // ph, m // pw
        # Group the grid into patches, then flatten (t, row, col) into tokens.
        split = split.view(b, t, rows, ph, cols, pw, c)
        split = split.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
        tokens = split.view(b, t * rows * cols, ph * pw * c)
        mask = tokens.new_zeros((b, tokens.size(1)), dtype=torch.bool)
        return tokens, mask
49
+
50
+
51
+ # -----------------------------------------------------------------------------
52
+ # Sparse spatio-temporal attention
53
+ # -----------------------------------------------------------------------------
54
@dataclass(frozen=True)
class AttentionCacheKey:
    """Hashable snapshot of the sparse-attention geometry settings.

    Frozen so instances can serve (together with T/H/W) as dictionary keys
    for cached neighbor index tables.
    """

    # Token-grid extents.
    temporal: int
    height: int
    width: int
    # Same-frame spatial window geometry.
    same_frame_window: int
    same_frame_window_h: Optional[int]
    same_frame_window_w: Optional[int]
    same_frame_dilation_h: int
    same_frame_dilation_w: int
    # Cross-frame attention geometry.
    temporal_offsets: Tuple[int, ...]
    temporal_spatial_window: int
    temporal_spatial_window_h: Optional[int]
    temporal_spatial_window_w: Optional[int]
    temporal_spatial_dilation_h: int
    temporal_spatial_dilation_w: int
    temporal_drift_h: int
    temporal_drift_w: int
    # Whether a global CLS token participates in attention.
    include_cls: bool
73
+
74
+
75
class NeighborIndexer:
    """Builds and caches sparse neighbor tables for spatio-temporal attention.

    Every token at (t, h, w) gets a list of flat token indices it may attend
    to: a same-frame spatial window plus clamped windows in other frames
    selected by ``temporal_offsets``. Tables are cached per
    (T, H, W, geometry) key; ragged rows are right-padded with -1.
    """

    def __init__(self) -> None:
        self._cache: Dict[Tuple[int, int, int, AttentionCacheKey], Tensor] = {}

    def get(self, T: int, H: int, W: int, include_cls: bool, config: "LWMConfig", device: torch.device) -> Tensor:
        """Return a (num_queries, max_neighbors) long tensor; -1 marks padding."""
        cache_key = (
            T,
            H,
            W,
            AttentionCacheKey(
                temporal=T,
                height=H,
                width=W,
                same_frame_window=config.same_frame_window,
                same_frame_window_h=config.same_frame_window_h,
                same_frame_window_w=config.same_frame_window_w,
                same_frame_dilation_h=config.same_frame_dilation_h,
                same_frame_dilation_w=config.same_frame_dilation_w,
                temporal_offsets=config.temporal_offsets,
                temporal_spatial_window=config.temporal_spatial_window,
                temporal_spatial_window_h=config.temporal_spatial_window_h,
                temporal_spatial_window_w=config.temporal_spatial_window_w,
                temporal_spatial_dilation_h=config.temporal_spatial_dilation_h,
                temporal_spatial_dilation_w=config.temporal_spatial_dilation_w,
                temporal_drift_h=config.temporal_drift_h,
                temporal_drift_w=config.temporal_drift_w,
                include_cls=include_cls,
            ),
        )
        cached = self._cache.get(cache_key)
        if cached is not None:
            # Cache keeps the first-seen device; move on demand without
            # rewriting the cached copy.
            return cached if cached.device == device else cached.to(device)
        rows = self._build_indices(T, H, W, include_cls, config)
        if rows:
            widest = max(len(r) for r in rows)
            if any(len(r) != widest for r in rows):
                rows = [r + [-1] * (widest - len(r)) for r in rows]
        table = torch.as_tensor(rows, dtype=torch.long, device=device)
        self._cache[cache_key] = table
        return table

    def _build_indices(self, T: int, H: int, W: int, include_cls: bool, config: "LWMConfig") -> List[List[int]]:
        """Enumerate, per query token, the sorted unique flat indices it may attend to."""
        rows: List[List[int]] = []
        win_h = config.same_frame_window_h if config.same_frame_window_h is not None else config.same_frame_window
        win_w = config.same_frame_window_w if config.same_frame_window_w is not None else config.same_frame_window

        def base_of(frame: int) -> int:
            return frame * H * W

        for t in range(T):
            frame_start = base_of(t)
            for row in range(H):
                for col in range(W):
                    self_idx = frame_start + row * W + col
                    cands: List[int] = []
                    if config.same_frame_window < 0:
                        # Negative window means "attend to the whole frame".
                        cands.extend(range(frame_start, frame_start + H * W))
                    else:
                        for dh in range(-win_h, win_h + 1, config.same_frame_dilation_h):
                            for dw in range(-win_w, win_w + 1, config.same_frame_dilation_w):
                                rr, cc = row + dh, col + dw
                                if 0 <= rr < H and 0 <= cc < W:
                                    cands.append(frame_start + rr * W + cc)
                    if not config.spatial_only:
                        for dt in config.temporal_offsets:
                            src_t = t + dt
                            if not 0 <= src_t < T:
                                continue
                            src_start = base_of(src_t)
                            # Drift caps how far the cross-frame window may wander
                            # per frame of temporal offset (0 disables the cap).
                            drift_h = config.temporal_spatial_window if config.temporal_drift_h == 0 else min(config.temporal_spatial_window, abs(dt) * config.temporal_drift_h)
                            drift_w = config.temporal_spatial_window if config.temporal_drift_w == 0 else min(config.temporal_spatial_window, abs(dt) * config.temporal_drift_w)
                            twin_h = config.temporal_spatial_window_h if config.temporal_spatial_window_h is not None else config.temporal_spatial_window
                            twin_w = config.temporal_spatial_window_w if config.temporal_spatial_window_w is not None else config.temporal_spatial_window
                            span_h = min(twin_h, drift_h)
                            span_w = min(twin_w, drift_w)
                            for dh in range(-span_h, span_h + 1, config.temporal_spatial_dilation_h):
                                for dw in range(-span_w, span_w + 1, config.temporal_spatial_dilation_w):
                                    # Clamp to the grid instead of dropping out-of-range cells.
                                    rr = min(H - 1, max(0, row + dh))
                                    cc = min(W - 1, max(0, col + dw))
                                    cands.append(src_start + rr * W + cc)
                    if include_cls:
                        cands.append(T * H * W)  # global CLS token sits after all patches
                    if not cands:
                        cands.append(self_idx)  # never leave a query with no keys
                    rows.append(sorted(set(cands)))
        if include_cls:
            # The CLS query attends to every patch token.
            rows.append(list(range(T * H * W)))
        return rows
167
+
168
+
169
+ class SparseSpatioTemporalAttention(nn.Module):
170
+ def __init__(self, config: "LWMConfig", embed_dim: int, num_heads: int) -> None:
171
+ super().__init__()
172
+ self.config = config
173
+ self.embed_dim = embed_dim
174
+ self.num_heads = num_heads
175
+ self.head_dim = embed_dim // num_heads
176
+ if self.head_dim * num_heads != embed_dim:
177
+ raise ValueError("embed_dim must be divisible by num_heads")
178
+ self.scale = self.head_dim ** -0.5
179
+ self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
180
+ self.proj = nn.Linear(embed_dim, embed_dim)
181
+ self.indexer = NeighborIndexer()
182
+
183
+ def _apply_rope(self, x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
184
+ x1 = x[..., ::2]
185
+ x2 = x[..., 1::2]
186
+ rotated_first = x1 * cos - x2 * sin
187
+ rotated_second = x1 * sin + x2 * cos
188
+ return torch.stack([rotated_first, rotated_second], dim=-1).flatten(-2)
189
+
190
+ def _rope_factors(self, S: int, device: torch.device) -> Tuple[Tensor, Tensor]:
191
+ half = self.head_dim // 2
192
+ inv_freq = 1.0 / (self.config.rope_base ** (torch.arange(0, half, dtype=torch.float32, device=device) / max(1, half)))
193
+ positions = torch.arange(S, dtype=torch.float32, device=device)
194
+ angles = positions[:, None] * inv_freq[None, :]
195
+ return torch.cos(angles)[None, None, :, :], torch.sin(angles)[None, None, :, :]
196
+
197
+ def forward(self, hidden_states: Tensor, T: int, H: int, W: int, include_cls: bool) -> Tensor:
198
+ bsz, seq_len, _ = hidden_states.shape
199
+ neighbors = self.indexer.get(T, H, W, include_cls, self.config, hidden_states.device)
200
+ qkv = self.qkv(hidden_states)
201
+ qkv = qkv.view(bsz, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
202
+ q, k, v = qkv[0], qkv[1], qkv[2]
203
+
204
+ if self.config.posenc == "rope_sincos":
205
+ cos, sin = self._rope_factors(seq_len, hidden_states.device)
206
+ q = self._apply_rope(q, cos, sin)
207
+ k = self._apply_rope(k, cos, sin)
208
+
209
+ gather_idx = neighbors.clamp_min(0)
210
+ valid_mask = neighbors >= 0
211
+ k = k[:, :, gather_idx, :]
212
+ v = v[:, :, gather_idx, :]
213
+ scores = torch.einsum("bhqd,bhqkd->bhqk", q, k) * self.scale
214
+ scores = scores.masked_fill(~valid_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
215
+
216
+ if self.config.routing_topk_enable:
217
+ K = scores.size(-1)
218
+ keep = min(self.config.routing_topk_max, max(self.config.routing_topk_min, int(self.config.routing_topk_fraction * K)))
219
+ if self.config.routing_topk_per_head:
220
+ _, idx = torch.topk(scores, keep, dim=-1)
221
+ topk_mask = torch.zeros_like(scores, dtype=torch.bool)
222
+ topk_mask.scatter_(-1, idx, True)
223
+ else:
224
+ avg_scores = scores.mean(dim=1, keepdim=True)
225
+ _, idx = torch.topk(avg_scores, keep, dim=-1)
226
+ topk_mask = torch.zeros_like(scores, dtype=torch.bool)
227
+ topk_mask.scatter_(-1, idx.expand_as(scores), True)
228
+ scores = scores.masked_fill(~topk_mask, float("-inf"))
229
+ elif self.config.topk_neighbors is not None:
230
+ keep = min(self.config.topk_neighbors, scores.size(-1))
231
+ if self.config.topk_per_head:
232
+ _, idx = torch.topk(scores, keep, dim=-1)
233
+ topk_mask = torch.zeros_like(scores, dtype=torch.bool)
234
+ topk_mask.scatter_(-1, idx, True)
235
+ else:
236
+ avg_scores = scores.mean(dim=1, keepdim=True)
237
+ _, idx = torch.topk(avg_scores, keep, dim=-1)
238
+ topk_mask = torch.zeros_like(scores, dtype=torch.bool)
239
+ topk_mask.scatter_(-1, idx.expand_as(scores), True)
240
+ scores = scores.masked_fill(~topk_mask, float("-inf"))
241
+
242
+ attn = torch.softmax(scores, dim=-1)
243
+ attn = attn.masked_fill(~valid_mask.unsqueeze(0).unsqueeze(0), 0.0)
244
+ context = torch.einsum("bhqk,bhqkd->bhqd", attn, v)
245
+ context = context.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_dim)
246
+ return self.proj(context)
247
+
248
+
249
class LWMEncoderLayer(nn.Module):
    """Pre-norm transformer block: sparse attention + GELU MLP, residual on both."""

    def __init__(self, config: "LWMConfig") -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.embed_dim)
        self.attn = SparseSpatioTemporalAttention(config, config.embed_dim, config.num_heads)
        self.norm2 = nn.LayerNorm(config.embed_dim)
        mlp_hidden = int(config.embed_dim * config.mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(config.embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, config.embed_dim),
        )

    def forward(self, x: Tensor, T: int, H: int, W: int, include_cls: bool) -> Tensor:
        x = x + self.attn(self.norm1(x), T, H, W, include_cls)
        return x + self.mlp(self.norm2(x))
266
+
267
+
268
class LWMEncoder(nn.Module):
    """Stack of ``config.depth`` LWMEncoderLayer blocks plus a final LayerNorm."""

    def __init__(self, config: "LWMConfig") -> None:
        super().__init__()
        self.layers = nn.ModuleList(LWMEncoderLayer(config) for _ in range(config.depth))
        self.norm = nn.LayerNorm(config.embed_dim)

    def forward(self, x: Tensor, T: int, H: int, W: int, include_cls: bool) -> Tensor:
        for block in self.layers:
            x = block(x, T, H, W, include_cls)
        return self.norm(x)
278
+
279
+
280
+ # -----------------------------------------------------------------------------
281
+ # Hugging Face configuration and model definitions
282
+ # -----------------------------------------------------------------------------
283
@dataclass
class LWMConfig:
    """Hyperparameters for the LWM sparse spatio-temporal transformer.

    Window/dilation fields ending in ``_h``/``_w`` override the shared value
    when not None. ``__post_init__`` coerces sequence-valued fields to tuples
    so configs loaded from JSON lists compare equal to defaults.
    """

    # Tokenization.
    patch_size: Tuple[int, int] = (1, 1)
    phase_mode: str = "real_imag"
    # Transformer shape.
    embed_dim: int = 32
    depth: int = 12
    num_heads: int = 8
    mlp_ratio: float = 4.0
    # Same-frame attention window.
    same_frame_window: int = 2
    same_frame_window_h: Optional[int] = None
    same_frame_window_w: Optional[int] = None
    same_frame_dilation_h: int = 1
    same_frame_dilation_w: int = 1
    # Cross-frame attention geometry.
    temporal_offsets: Tuple[int, ...] = (-4, -3, -2, -1, 1, 2, 3)
    temporal_spatial_window: int = 2
    temporal_spatial_window_h: Optional[int] = None
    temporal_spatial_window_w: Optional[int] = None
    temporal_spatial_dilation_h: int = 1
    temporal_spatial_dilation_w: int = 1
    temporal_drift_h: int = 1
    temporal_drift_w: int = 1
    spatial_only: bool = False
    # Score-based top-k routing.
    routing_topk_enable: bool = True
    routing_topk_fraction: float = 0.2
    routing_topk_min: int = 8
    routing_topk_max: int = 32
    routing_topk_per_head: bool = True
    # Fixed top-k fallback (used only when routing is disabled).
    topk_neighbors: Optional[int] = None
    topk_per_head: bool = True
    # Global CLS token and positional encoding.
    global_cls: bool = False
    posenc: str = "learned"
    rope_base: float = 10000.0
    rope_mode: str = "flat"
    rope_base_t: Optional[float] = None
    rope_base_h: Optional[float] = None
    rope_base_w: Optional[float] = None
    max_seq_len: Optional[int] = None

    def __post_init__(self) -> None:
        # Normalize list/tuple inputs (e.g. from JSON) to canonical tuples.
        self.patch_size = (int(self.patch_size[0]), int(self.patch_size[1]))
        self.temporal_offsets = tuple(int(o) for o in self.temporal_offsets)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (JSON-friendly apart from tuples)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LWMConfig":
        """Rebuild a config from :meth:`to_dict` output or a parsed JSON object."""
        return cls(**data)
331
+
332
+
333
+ class LWMModel(nn.Module):
334
+ def __init__(self, config: LWMConfig) -> None:
335
+ super().__init__()
336
+ self.config = config
337
+ patch_dim = config.patch_size[0] * config.patch_size[1] * 2
338
+ self.tokenizer = ComplexPatchTokenizer(config.phase_mode)
339
+ self.patch_embed = nn.Linear(patch_dim, config.embed_dim)
340
+ self.global_cls = config.global_cls
341
+ pos_len = (config.max_seq_len or 0) + (1 if self.global_cls else 0)
342
+ if pos_len == 0:
343
+ pos_len = 1
344
+ if config.posenc == "learned":
345
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_len, config.embed_dim))
346
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
347
+ else:
348
+ self.register_buffer("pos_embed", torch.zeros(1, pos_len, config.embed_dim), persistent=False)
349
+ if self.global_cls:
350
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
351
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
352
+ self.encoder = LWMEncoder(config)
353
+ self.head = nn.Linear(config.embed_dim, patch_dim)
354
+ self._init_weights()
355
+
356
+ def _init_weights(self) -> None:
357
+ for module in self.modules():
358
+ if isinstance(module, nn.Linear):
359
+ nn.init.trunc_normal_(module.weight, std=0.02)
360
+ if module.bias is not None:
361
+ nn.init.zeros_(module.bias)
362
+ elif isinstance(module, nn.LayerNorm):
363
+ nn.init.ones_(module.weight)
364
+ nn.init.zeros_(module.bias)
365
+
366
+ def _add_positional(self, tokens: Tensor) -> Tensor:
367
+ if self.config.posenc == "learned":
368
+ return tokens + self.pos_embed[:, : tokens.size(1)]
369
+ return tokens
370
+
371
+ def forward_tokens(
372
+ self,
373
+ tokens: Tensor,
374
+ mask: Tensor,
375
+ T: int,
376
+ H: int,
377
+ W: int,
378
+ *,
379
+ return_cls: bool = False,
380
+ ) -> Dict[str, Optional[Tensor]]:
381
+ embeddings = self.patch_embed(tokens)
382
+ include_cls = self.global_cls
383
+ if include_cls:
384
+ cls_tokens = self.cls_token.expand(embeddings.size(0), -1, -1)
385
+ embeddings = torch.cat([embeddings, cls_tokens], dim=1)
386
+ cls_mask = torch.zeros((embeddings.size(0), 1), dtype=torch.bool, device=embeddings.device)
387
+ mask = torch.cat([mask, cls_mask], dim=1)
388
+ # Add positional embeddings BEFORE masking (matching original implementation)
389
+ embeddings = self._add_positional(embeddings)
390
+ # Then mask embeddings (zeros out both token embedding AND positional embedding)
391
+ embeddings = embeddings.masked_fill(mask.unsqueeze(-1), 0.0)
392
+ encoded = self.encoder(embeddings, T, H, W, include_cls)
393
+ if include_cls:
394
+ reconstruction = self.head(encoded[:, :-1, :])
395
+ cls = encoded[:, -1, :]
396
+ else:
397
+ reconstruction = self.head(encoded)
398
+ cls = None
399
+ return {"reconstruction": reconstruction, "cls": cls if return_cls else None}
400
+
401
+ def forward(self, seq: Tensor, mask: Optional[Tensor] = None, *, return_cls: bool = False) -> Dict[str, Optional[Tensor]]:
402
+ tokens, base_mask = self.tokenizer(seq, self.config.patch_size)
403
+ total_mask = base_mask if mask is None else mask
404
+ ph, pw = self.config.patch_size
405
+ T = seq.size(1)
406
+ H = seq.size(2) // ph
407
+ W = seq.size(3) // pw
408
+ return self.forward_tokens(tokens, total_mask, T, H, W, return_cls=return_cls)
409
+
410
+ @torch.no_grad()
411
+ def forward_features(self, seq: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
412
+ outputs = self.forward(seq, return_cls=True)
413
+ return outputs["reconstruction"], outputs["cls"]
414
+
415
+
416
+ class LWMBackbone(LWMModel):
417
+ """Minor alias kept for backwards compatibility with legacy scripts."""
418
+
419
+ @classmethod
420
+ def from_pretrained(
421
+ cls,
422
+ pretrained_model_name_or_path: str | Path,
423
+ *model_args: Any,
424
+ config: Optional[LWMConfig] = None,
425
+ map_location: str | torch.device = "cpu",
426
+ **kwargs: Any,
427
+ ) -> "LWMBackbone":
428
+ path = Path(pretrained_model_name_or_path)
429
+ state: Dict[str, Tensor]
430
+ checkpoint_config: Optional[Dict[str, Any]] = None
431
+
432
+ if path.is_dir():
433
+ directory = path
434
+ state_path = directory / "pytorch_model.bin"
435
+ if not state_path.exists():
436
+ raise FileNotFoundError(f"Pretrained weights not found at {state_path}")
437
+ raw = torch.load(state_path, map_location=map_location)
438
+ if isinstance(raw, dict) and any(isinstance(v, torch.Tensor) for v in raw.values()):
439
+ state = {k: v for k, v in raw.items() if isinstance(v, torch.Tensor)}
440
+ else:
441
+ raise ValueError(f"Unexpected checkpoint format at {state_path}")
442
+ # Always try to load checkpoint config first, then merge with provided config
443
+ checkpoint_config_dict = None
444
+ config_path = directory / "config.json"
445
+ if config_path.exists():
446
+ with config_path.open("r") as handle:
447
+ checkpoint_config_dict = json.load(handle)
448
+ checkpoint_config = LWMConfig.from_dict(checkpoint_config_dict)
449
+ if config is None:
450
+ config = checkpoint_config
451
+ else:
452
+ # Merge: use checkpoint config as base, override with provided config
453
+ checkpoint_dict = checkpoint_config.to_dict()
454
+ provided_dict = config.to_dict()
455
+ merged_dict = {**checkpoint_dict, **provided_dict}
456
+ config = LWMConfig.from_dict(merged_dict)
457
+ else:
458
+ if not path.exists():
459
+ raise FileNotFoundError(f"Pretrained weights not found at {path}")
460
+ raw = torch.load(path, map_location=map_location)
461
+ if isinstance(raw, dict) and "model_state_dict" in raw:
462
+ state = raw["model_state_dict"]
463
+ checkpoint_config = raw.get("config")
464
+ elif isinstance(raw, dict):
465
+ state = {k: v for k, v in raw.items() if isinstance(v, torch.Tensor)}
466
+ else:
467
+ raise ValueError("Unsupported checkpoint format; expected a state_dict or training checkpoint.")
468
+
469
+ if config is None and checkpoint_config is not None:
470
+ config = cls._config_from_checkpoint(checkpoint_config)
471
+ if config is None:
472
+ config = LWMConfig()
473
+
474
+ if config.max_seq_len is None and "pos_embed" in state:
475
+ pos_len = int(state["pos_embed"].shape[1])
476
+ cls_tokens = 1 if config.global_cls else 0
477
+ inferred = max(0, pos_len - cls_tokens)
478
+ if inferred > 0:
479
+ config.max_seq_len = inferred
480
+ remapped_state = cls._remap_state_dict(state)
481
+ model = cls(config, *model_args, **kwargs)
482
+ model.load_state_dict(remapped_state, strict=False)
483
+ return model
484
+
485
+ def save_pretrained(self, save_directory: str | Path, **kwargs: Any) -> None:
486
+ directory = Path(save_directory)
487
+ directory.mkdir(parents=True, exist_ok=True)
488
+ config_path = directory / "config.json"
489
+ with config_path.open("w") as handle:
490
+ json.dump(self.config.to_dict(), handle, indent=2)
491
+ state_path = directory / "pytorch_model.bin"
492
+ torch.save(self.state_dict(), state_path)
493
+
494
+ @staticmethod
495
+ def _config_from_checkpoint(data: Any) -> Optional[LWMConfig]:
496
+ if not isinstance(data, dict):
497
+ return None
498
+ model_cfg = data.get("model", data)
499
+ if not isinstance(model_cfg, dict):
500
+ return None
501
+ allowed = {field.name for field in fields(LWMConfig)}
502
+ kwargs: Dict[str, Any] = {}
503
+ for key, value in model_cfg.items():
504
+ if key not in allowed:
505
+ continue
506
+ if key == "patch_size" and isinstance(value, (list, tuple)):
507
+ value = tuple(int(v) for v in value)
508
+ if key == "temporal_offsets" and isinstance(value, (list, tuple)):
509
+ value = tuple(int(v) for v in value)
510
+ kwargs[key] = value
511
+ if not kwargs:
512
+ return None
513
+ return LWMConfig(**kwargs)
514
+
515
+ @staticmethod
516
+ def _remap_state_dict(state: Dict[str, Tensor]) -> Dict[str, Tensor]:
517
+ remapped: Dict[str, Tensor] = {}
518
+ for key, value in state.items():
519
+ new_key = key
520
+ if key.startswith("embed."):
521
+ new_key = key.replace("embed", "patch_embed", 1)
522
+ elif key.startswith("blocks."):
523
+ new_key = key.replace("blocks", "encoder.layers", 1)
524
+ elif key.startswith("norm."):
525
+ new_key = key.replace("norm", "encoder.norm", 1)
526
+ remapped[new_key] = value
527
+ return remapped
528
+
529
+
530
+ def compute_nmse(pred: Tensor, target: Tensor, mask: Tensor) -> float:
531
+ """
532
+ Compute NMSE per sample, then average across batch (matching original implementation).
533
+ For each sample: nmse_b = sum((pred-target)^2 [mask]) / sum(target^2 [mask])
534
+ """
535
+ B = pred.size(0)
536
+ nmse_vals = []
537
+ for b in range(B):
538
+ m = mask[b]
539
+ if m.sum() == 0:
540
+ continue
541
+ se = (pred[b][m] - target[b][m]).pow(2).sum()
542
+ sp = target[b][m].pow(2).sum().clamp_min(1e-12)
543
+ nmse_vals.append((se / sp).item())
544
+ if not nmse_vals:
545
+ return float('nan')
546
+ return sum(nmse_vals) / len(nmse_vals)
547
+
548
+
549
+ def masked_nmse_loss(pred: Tensor, target: Tensor, mask: Tensor) -> Tensor:
550
+ diff = (pred - target).abs() ** 2
551
+ power = target.abs() ** 2
552
+ mask_f = mask.float()
553
+ diff_sum = (diff.sum(-1) * mask_f).sum(-1)
554
+ power_sum = (power.sum(-1) * mask_f).sum(-1).clamp_min(1e-12)
555
+ nmse = diff_sum / power_sum
556
+ valid = mask.sum(-1) > 0
557
+ return nmse[valid].mean() if valid.any() else nmse.mean()
558
+
559
+
560
+ def masked_mse_loss(pred: Tensor, target: Tensor, mask: Tensor) -> Tensor:
561
+ diff = (pred - target).abs() ** 2
562
+ mask_f = mask.float()
563
+ num = (diff.sum(-1) * mask_f).sum()
564
+ denom = mask_f.sum().clamp_min(1.0)
565
+ return num / denom
566
+
567
+
568
+ __all__ = [
569
+ "ComplexPatchTokenizer",
570
+ "LWMConfig",
571
+ "LWMModel",
572
+ "LWMBackbone",
573
+ "compute_nmse",
574
+ "masked_nmse_loss",
575
+ "masked_mse_loss",
576
+ ]
LWMTemporal/tasks/channel_prediction.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import dataclasses
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Subset

from ..data import AngleDelayDatasetConfig, AngleDelaySequenceDataset
from ..models import LWMBackbone, LWMConfig
from ..models.lwm import compute_nmse, masked_mse_loss, masked_nmse_loss

try:
    import wandb  # type: ignore
except ImportError:  # pragma: no cover
    wandb = None  # type: ignore
23
+
24
+
25
+ @dataclasses.dataclass
26
+ class DatasetArgs:
27
+ data_path: Path
28
+ keep_percentage: float = 0.25
29
+ normalize: str = "global_rms"
30
+ cache_dir: Path = Path("cache")
31
+ use_cache: bool = True
32
+ overwrite_cache: bool = False
33
+ snr_db: Optional[float] = None
34
+ noise_seed: Optional[int] = None
35
+ max_time_steps: Optional[int] = None
36
+ train_limit: int = 500
37
+ val_limit: int = 1000
38
+ seed: int = 42
39
+
40
+
41
+ @dataclasses.dataclass
42
+ class ModelArgs:
43
+ patch_size: Tuple[int, int] = (1, 1)
44
+ phase_mode: str = "real_imag"
45
+ embed_dim: int = 32
46
+ depth: int = 12
47
+ num_heads: int = 8
48
+ mlp_ratio: float = 4.0
49
+ same_frame_window: int = 2
50
+ temporal_offsets: Sequence[int] = dataclasses.field(default_factory=lambda: (-1, -2, -3, -4, -5, -6, -7))
51
+ temporal_spatial_window: int = 2
52
+ temporal_drift_h: int = 1
53
+ temporal_drift_w: int = 1
54
+ routing_topk_enable: bool = True
55
+ routing_topk_fraction: float = 0.2
56
+ routing_topk_min: int = 8
57
+ routing_topk_max: int = 32
58
+ topk_per_head: bool = True
59
+ posenc: str = "learned"
60
+ rope_base: float = 10000.0
61
+ global_cls: bool = False
62
+ pretrained: Optional[Path] = None
63
+ finetune_last_n: int = 0
64
+ train_head_only: bool = False
65
+
66
+
67
+ @dataclasses.dataclass
68
+ class TrainingArgs:
69
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
70
+ epochs: int = 3
71
+ batch_size: int = 16
72
+ lr: float = 1e-4
73
+ weight_decay: float = 1e-4
74
+ warmup_ratio: float = 0.1
75
+ loss: str = "nmse"
76
+ use_dataparallel: bool = False
77
+ grad_clip: float = 1.0
78
+ log_interval: int = 10
79
+ save_dir: Path = Path("models")
80
+ save_prefix: str = "channel_prediction"
81
+ inference_only: bool = False
82
+ inference_split: str = "val"
83
+ verbose_inference: bool = False
84
+ log_dir: Path = Path("logs")
85
+ use_wandb: bool = False
86
+ wandb_project: Optional[str] = None
87
+ wandb_entity: Optional[str] = None
88
+ wandb_run_name: Optional[str] = None
89
+
90
+
91
+ @dataclasses.dataclass
92
+ class PredictionArgs:
93
+ Tpast: int = 10
94
+ horizon: int = 1
95
+ num_visual_samples: int = 4
96
+ viz_dir: Path = Path("figs/predictions")
97
+
98
+
99
+ @dataclasses.dataclass
100
+ class ChannelPredictionArgs:
101
+ dataset: DatasetArgs
102
+ model: ModelArgs
103
+ training: TrainingArgs
104
+ prediction: PredictionArgs
105
+
106
+
107
+ class ChannelPredictionDataModule:
108
+ def __init__(self, args: DatasetArgs, patch_size: Tuple[int, int], phase_mode: str) -> None:
109
+ cfg = AngleDelayDatasetConfig(
110
+ raw_path=args.data_path,
111
+ keep_percentage=args.keep_percentage,
112
+ normalize=args.normalize,
113
+ cache_dir=args.cache_dir,
114
+ use_cache=args.use_cache,
115
+ overwrite_cache=args.overwrite_cache,
116
+ snr_db=args.snr_db,
117
+ noise_seed=args.noise_seed,
118
+ max_time_steps=args.max_time_steps,
119
+ patch_size=patch_size,
120
+ phase_mode=phase_mode,
121
+ )
122
+ self.dataset = AngleDelaySequenceDataset(cfg)
123
+ generator = torch.Generator().manual_seed(args.seed)
124
+ indices = torch.randperm(len(self.dataset), generator=generator).tolist()
125
+ train_len = min(args.train_limit, len(indices))
126
+ val_len = min(args.val_limit, max(0, len(indices) - train_len))
127
+ self.train_indices = indices[:train_len]
128
+ self.val_indices = indices[train_len:train_len + val_len]
129
+ self.patch_size = patch_size
130
+ self.phase_mode = phase_mode
131
+
132
+ def train_loader(self, batch_size: int, drop_last: bool = True) -> DataLoader:
133
+ subset = Subset(self.dataset, self.train_indices)
134
+ return DataLoader(subset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
135
+
136
+ def val_loader(self, batch_size: int, drop_last: bool = False) -> Optional[DataLoader]:
137
+ if not self.val_indices:
138
+ return None
139
+ subset = Subset(self.dataset, self.val_indices)
140
+ return DataLoader(subset, batch_size=batch_size, shuffle=False, drop_last=drop_last)
141
+
142
+
143
+ class AutoregressiveEngine:
144
+ def __init__(self, patch_size: Tuple[int, int], phase_mode: str) -> None:
145
+ self.patch_size = patch_size
146
+ self.phase_mode = phase_mode
147
+
148
+ def detokenize(self, tokens: Tensor, T: int, H: int, W: int) -> Tensor:
149
+ B = tokens.size(0)
150
+ ph, pw = self.patch_size
151
+ patches = tokens.view(B, T, H, W, ph * pw * 2)
152
+ patches = patches.view(B, T, H, W, ph, pw, 2)
153
+ patches = patches.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
154
+ recon = patches.view(B, T, H * ph, W * pw, 2)
155
+ if self.phase_mode == "real_imag":
156
+ real = recon[..., 0]
157
+ imag = recon[..., 1]
158
+ return torch.complex(real, imag)
159
+ magnitude = recon[..., 0]
160
+ phase = recon[..., 1]
161
+ real = magnitude * torch.cos(phase)
162
+ imag = magnitude * torch.sin(phase)
163
+ return torch.complex(real, imag)
164
+
165
+ def autoregressive_rollout(
166
+ self,
167
+ model: LWMBackbone,
168
+ tokens: Tensor,
169
+ Tpast: int,
170
+ horizon: int,
171
+ H: int,
172
+ W: int,
173
+ ) -> Tuple[Tensor, Tensor, Tensor]:
174
+ B, S_full, D = tokens.shape
175
+ S_per_time = H * W
176
+ if S_full % S_per_time != 0:
177
+ raise ValueError("Token sequence length incompatible with H and W")
178
+ T_total = S_full // S_per_time
179
+ S_per_time = H * W
180
+ window_tokens = Tpast + 1
181
+ if T_total < Tpast + horizon:
182
+ raise ValueError("sequence shorter than Tpast + horizon")
183
+ mask_window = torch.zeros((window_tokens, H, W), dtype=torch.bool, device=tokens.device)
184
+ mask_window[Tpast, :, :] = True
185
+ mask_window = mask_window.view(window_tokens * S_per_time)
186
+ mask_future = torch.zeros((T_total, H, W), dtype=torch.bool, device=tokens.device)
187
+ mask_future[Tpast:Tpast + horizon, :, :] = True
188
+ mask_flat = mask_future.view(1, T_total * S_per_time).expand(B, -1)
189
+
190
+ source_tokens = tokens.clone()
191
+ pred_tokens = torch.zeros_like(tokens)
192
+
193
+ for step in range(horizon):
194
+ start_time = step
195
+ end_time = step + window_tokens
196
+ abs_start = start_time * S_per_time
197
+ abs_end = end_time * S_per_time
198
+ window_slice = source_tokens[:, abs_start:abs_end, :].clone() # Clone to avoid in-place modification
199
+ mask_slice = mask_window.unsqueeze(0).expand(B, -1)
200
+ # Zero masked tokens before model forward (matching original implementation)
201
+ window_slice = window_slice.masked_fill(mask_slice.unsqueeze(-1), 0.0)
202
+ outputs = model.forward_tokens(window_slice, mask_slice, window_tokens, H, W, return_cls=False)
203
+ predicted_window = outputs["reconstruction"]
204
+ # Extract predictions for the last time position in the window using slicing
205
+ win_last_start = Tpast * S_per_time
206
+ win_last_end = (Tpast + 1) * S_per_time
207
+ step_pred_last = predicted_window[:, win_last_start:win_last_end, :]
208
+ # Write back into absolute position
209
+ target_range_start = (Tpast + step) * S_per_time
210
+ target_range_end = target_range_start + S_per_time
211
+ source_tokens[:, target_range_start:target_range_end, :] = step_pred_last
212
+ pred_tokens[:, target_range_start:target_range_end, :] = step_pred_last
213
+
214
+ target_tokens = tokens
215
+ return pred_tokens, target_tokens, mask_flat
216
+
217
+
218
+ class PredictionVisualizer:
219
+ def __init__(self, engine: AutoregressiveEngine, save_dir: Path, num_samples: int) -> None:
220
+ self.engine = engine
221
+ self.save_dir = save_dir
222
+ self.num_samples = num_samples
223
+ self.save_dir.mkdir(parents=True, exist_ok=True)
224
+
225
+ def save(self, model: LWMBackbone, tokens: Tensor, H: int, W: int, args: PredictionArgs) -> None:
226
+ model.eval()
227
+ with torch.no_grad():
228
+ preds, tgt, mask = self.engine.autoregressive_rollout(
229
+ model,
230
+ tokens,
231
+ args.Tpast,
232
+ args.horizon,
233
+ H,
234
+ W,
235
+ )
236
+ tokens_per_time = H * W
237
+ T_total = tokens.size(1) // tokens_per_time
238
+ B = tokens.size(0)
239
+ for idx in range(min(B, self.num_samples)):
240
+ pred_seq = preds[idx].view(T_total, tokens_per_time, -1)
241
+ tgt_seq = tgt[idx].view(T_total, tokens_per_time, -1)
242
+ pred_complex = self.engine.detokenize(pred_seq.unsqueeze(0), T_total, H, W)[0]
243
+ tgt_complex = self.engine.detokenize(tgt_seq.unsqueeze(0), T_total, H, W)[0]
244
+ self._plot_sample(pred_complex, tgt_complex, args, sample_idx=idx)
245
+
246
+ def _plot_sample(self, pred: Tensor, tgt: Tensor, args: PredictionArgs, sample_idx: int) -> None:
247
+ import matplotlib.pyplot as plt
248
+
249
+ fig, axes = plt.subplots(args.horizon, 2, figsize=(8, 3 * args.horizon), squeeze=False)
250
+ for step in range(args.horizon):
251
+ t_idx = args.Tpast + step
252
+ gt_mag = tgt[t_idx].abs().cpu().numpy()
253
+ pred_mag = pred[t_idx].abs().cpu().numpy()
254
+ ax_gt, ax_pred = axes[step]
255
+ im0 = ax_gt.imshow(gt_mag, cmap="viridis", aspect="auto")
256
+ im1 = ax_pred.imshow(pred_mag, cmap="viridis", aspect="auto")
257
+ ax_gt.set_title(f"GT t={t_idx}")
258
+ ax_pred.set_title(f"Pred t={t_idx}")
259
+ for ax in (ax_gt, ax_pred):
260
+ ax.set_xticks([])
261
+ ax.set_yticks([])
262
+ fig.colorbar(im0, ax=ax_gt, fraction=0.046, pad=0.04)
263
+ fig.colorbar(im1, ax=ax_pred, fraction=0.046, pad=0.04)
264
+ fig.tight_layout()
265
+ out_path = self.save_dir / f"sample_{sample_idx}.png"
266
+ fig.savefig(out_path)
267
+ plt.close(fig)
268
+
269
+
270
+ class ChannelPredictionTrainer:
271
+ def __init__(self, args: ChannelPredictionArgs, *, logger: Optional[logging.Logger] = None) -> None:
272
+ self.args = args
273
+ torch.manual_seed(args.dataset.seed)
274
+ np.random.seed(args.dataset.seed)
275
+ self.device = torch.device(args.training.device)
276
+ self.engine = AutoregressiveEngine(args.model.patch_size, args.model.phase_mode)
277
+ self.data = ChannelPredictionDataModule(args.dataset, args.model.patch_size, args.model.phase_mode)
278
+ self.model = self._build_model().to(self.device)
279
+ self.model.eval() # Set to eval mode immediately after loading
280
+ if args.training.use_dataparallel and torch.cuda.device_count() > 1:
281
+ self.model = nn.DataParallel(self.model)
282
+ if hasattr(self.model, 'module'):
283
+ self.model.module.eval()
284
+ self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=args.training.lr, weight_decay=args.training.weight_decay)
285
+ self.scheduler = self._build_scheduler()
286
+ self.scaler = GradScaler()
287
+ self.viz = PredictionVisualizer(self.engine, args.prediction.viz_dir, args.prediction.num_visual_samples)
288
+ self.logger = logger or logging.getLogger(__name__)
289
+ self.global_step = 0
290
+ self._wandb_run = self._maybe_init_wandb()
291
+
292
+ def _wandb_enabled(self) -> bool:
293
+ return self._wandb_run is not None
294
+
295
+ def _maybe_init_wandb(self) -> Optional["wandb.sdk.wandb_run.Run"]:
296
+ training = self.args.training
297
+ if not training.use_wandb:
298
+ return None
299
+ if wandb is None:
300
+ self.logger.warning("Weights & Biases not installed; disabling wandb logging.")
301
+ return None
302
+ config = {
303
+ "dataset": dataclasses.asdict(self.args.dataset),
304
+ "model": dataclasses.asdict(self.args.model),
305
+ "training": dataclasses.asdict(self.args.training),
306
+ "prediction": dataclasses.asdict(self.args.prediction),
307
+ }
308
+ run = wandb.init(
309
+ project=training.wandb_project,
310
+ entity=training.wandb_entity,
311
+ name=training.wandb_run_name,
312
+ config=config,
313
+ )
314
+ wandb.watch(self.model, log="all", log_freq=self.args.training.log_interval)
315
+ self.logger.info("Initialized Weights & Biases run: %s", run.name)
316
+ return run
317
+
318
+ def _wandb_log(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
319
+ if not self._wandb_enabled():
320
+ return
321
+ wandb.log(metrics, step=step)
322
+
323
+ def _finish_wandb(self) -> None:
324
+ if self._wandb_enabled():
325
+ wandb.finish()
326
+
327
+ def _build_model(self) -> LWMBackbone:
328
+ # Calculate max_seq_len based on window size (matching original implementation)
329
+ # This is critical for channel prediction with autoregressive rollout
330
+ sample_batch = next(iter(self.data.val_loader(1) or self.data.train_loader(1)))
331
+ _, _, H, W = self._prepare_batch(sample_batch)
332
+ max_seq_len = (self.args.prediction.Tpast + 1) * H * W
333
+
334
+ cfg = LWMConfig(
335
+ patch_size=self.args.model.patch_size,
336
+ phase_mode=self.args.model.phase_mode,
337
+ embed_dim=self.args.model.embed_dim,
338
+ depth=self.args.model.depth,
339
+ num_heads=self.args.model.num_heads,
340
+ mlp_ratio=self.args.model.mlp_ratio,
341
+ same_frame_window=self.args.model.same_frame_window,
342
+ temporal_offsets=self.args.model.temporal_offsets,
343
+ temporal_spatial_window=self.args.model.temporal_spatial_window,
344
+ temporal_drift_h=self.args.model.temporal_drift_h,
345
+ temporal_drift_w=self.args.model.temporal_drift_w,
346
+ routing_topk_enable=self.args.model.routing_topk_enable,
347
+ routing_topk_fraction=self.args.model.routing_topk_fraction,
348
+ routing_topk_min=self.args.model.routing_topk_min,
349
+ routing_topk_max=self.args.model.routing_topk_max,
350
+ topk_per_head=self.args.model.topk_per_head,
351
+ posenc=self.args.model.posenc,
352
+ rope_base=self.args.model.rope_base,
353
+ global_cls=self.args.model.global_cls,
354
+ max_seq_len=max_seq_len,
355
+ )
356
+ model = LWMBackbone(cfg)
357
+ if self.args.model.pretrained is not None and self.args.model.pretrained.exists():
358
+ model = LWMBackbone.from_pretrained(self.args.model.pretrained, config=cfg)
359
+ if self.args.model.train_head_only:
360
+ for param in model.parameters():
361
+ param.requires_grad = False
362
+ for param in model.head.parameters():
363
+ param.requires_grad = True
364
+ elif self.args.model.finetune_last_n > 0:
365
+ model.freeze_backbone()
366
+ if hasattr(model, "encoder"):
367
+ layers = model.encoder.layers
368
+ for layer in layers[-self.args.model.finetune_last_n:]:
369
+ for param in layer.parameters():
370
+ param.requires_grad = True
371
+ for param in model.head.parameters():
372
+ param.requires_grad = True
373
+ return model
374
+
375
+ def _build_scheduler(self) -> torch.optim.lr_scheduler.LambdaLR:
376
+ train_loader = self.data.train_loader(self.args.training.batch_size)
377
+ steps_per_epoch = max(1, len(train_loader))
378
+ total_steps = steps_per_epoch * max(1, self.args.training.epochs)
379
+ warmup_steps = int(self.args.training.warmup_ratio * total_steps)
380
+
381
+ def schedule(step: int) -> float:
382
+ if step < warmup_steps:
383
+ return float(step) / max(1, warmup_steps)
384
+ progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
385
+ return 0.5 * (1.0 + math.cos(math.pi * progress))
386
+
387
+ return torch.optim.lr_scheduler.LambdaLR(self.optimizer, schedule)
388
+
389
+ def _prepare_batch(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, int, int]:
390
+ tokens = batch["tokens"].to(self.device)
391
+ base_mask = batch["base_mask"].to(self.device)
392
+ shapes = batch["shape"]
393
+ if not isinstance(shapes, torch.Tensor):
394
+ shapes = torch.tensor(shapes)
395
+ if shapes.dim() == 1:
396
+ shapes = shapes.unsqueeze(0)
397
+ ref_shape = shapes[0]
398
+ if not torch.all(shapes.eq(ref_shape)):
399
+ raise ValueError("Mixed sequence shapes within the same batch are not supported")
400
+ T = int(ref_shape[0].item())
401
+ H = int(ref_shape[1].item())
402
+ W = int(ref_shape[2].item())
403
+ T_needed = self.args.prediction.Tpast + self.args.prediction.horizon
404
+ if T < T_needed:
405
+ raise ValueError("Sequence shorter than required Tpast+horizon frames")
406
+ S_per_time = H * W
407
+ tokens = tokens[:, : T_needed * S_per_time, :]
408
+ mask = base_mask[:, : T_needed * S_per_time]
409
+ return tokens, mask, H, W
410
+
411
+ def _compute_loss(self, pred: Tensor, tgt: Tensor, mask: Tensor) -> Tensor:
412
+ if self.args.training.loss == "mse":
413
+ return masked_mse_loss(pred, tgt, mask)
414
+ return masked_nmse_loss(pred, tgt, mask)
415
+
416
+ def train(self) -> None:
417
+ if self.args.training.inference_only:
418
+ # self.logger.info(
419
+ # "Running inference-only evaluation on split '%s'", self.args.training.inference_split
420
+ # )
421
+ self.evaluate(split=self.args.training.inference_split)
422
+ self._finish_wandb()
423
+ return
424
+ train_loader = self.data.train_loader(self.args.training.batch_size)
425
+ val_loader = self.data.val_loader(self.args.training.batch_size)
426
+ for epoch in range(1, self.args.training.epochs + 1):
427
+ self.model.train()
428
+ running_loss = 0.0
429
+ running_nmse: List[float] = []
430
+ loader_len = len(train_loader)
431
+ for step, batch in enumerate(train_loader, start=1):
432
+ tokens, _, H, W = self._prepare_batch(batch)
433
+ with autocast():
434
+ preds, target, mask = self.engine.autoregressive_rollout(
435
+ self.model, tokens, self.args.prediction.Tpast, self.args.prediction.horizon, H, W
436
+ )
437
+ loss = self._compute_loss(preds, target, mask)
438
+ self.scaler.scale(loss).backward()
439
+ self.scaler.unscale_(self.optimizer)
440
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.training.grad_clip)
441
+ self.scaler.step(self.optimizer)
442
+ self.scaler.update()
443
+ self.optimizer.zero_grad()
444
+ self.scheduler.step()
445
+ running_loss += loss.item()
446
+ running_nmse.append(compute_nmse(preds, target, mask))
447
+ self.global_step += 1
448
+ if step % self.args.training.log_interval == 0:
449
+ avg_loss = running_loss / step
450
+ avg_nmse = float(np.mean(running_nmse)) if running_nmse else float("nan")
451
+ lr_cur = self.optimizer.param_groups[0]["lr"]
452
+ self.logger.info(
453
+ "Train: [%d/%d][%d/%d] loss=%0.6f nmse=%0.6f lr=%0.2e",
454
+ epoch,
455
+ self.args.training.epochs,
456
+ step,
457
+ loader_len,
458
+ avg_loss,
459
+ avg_nmse,
460
+ lr_cur,
461
+ )
462
+ self._wandb_log(
463
+ {
464
+ "train/loss": avg_loss,
465
+ "train/nmse": avg_nmse,
466
+ "train/lr": lr_cur,
467
+ },
468
+ step=self.global_step,
469
+ )
470
+ avg_train_loss = running_loss / max(1, len(train_loader))
471
+ avg_train_nmse = float(np.mean(running_nmse)) if running_nmse else float("nan")
472
+ self.logger.info(
473
+ "Train Epoch %d/%d Summary: loss=%0.6f nmse=%0.6f",
474
+ epoch,
475
+ self.args.training.epochs,
476
+ avg_train_loss,
477
+ avg_train_nmse,
478
+ )
479
+ self._wandb_log(
480
+ {
481
+ "train/epoch_loss": avg_train_loss,
482
+ "train/epoch_nmse": avg_train_nmse,
483
+ },
484
+ step=self.global_step,
485
+ )
486
+ if val_loader is not None:
487
+ self.evaluate(loader=val_loader, split="val", epoch=epoch)
488
+ first_batch = next(iter(train_loader))
489
+ tokens_vis, _, H_vis, W_vis = self._prepare_batch(first_batch)
490
+ self.viz.save(self.model, tokens_vis, H_vis, W_vis, self.args.prediction)
491
+ self._finish_wandb()
492
+
493
+ def evaluate(
494
+ self,
495
+ loader: Optional[DataLoader] = None,
496
+ split: str = "val",
497
+ epoch: Optional[int] = None,
498
+ ) -> None:
499
+ if loader is None:
500
+ if split == "train":
501
+ subset = Subset(self.data.dataset, self.data.train_indices)
502
+ loader = DataLoader(subset, batch_size=self.args.training.batch_size, shuffle=False, drop_last=False)
503
+ elif split == "val":
504
+ subset = Subset(self.data.dataset, self.data.val_indices)
505
+ loader = DataLoader(subset, batch_size=self.args.training.batch_size, shuffle=False, drop_last=False)
506
+ else:
507
+ loader = DataLoader(self.data.dataset, batch_size=self.args.training.batch_size, shuffle=False)
508
+ if loader is None:
509
+ self.logger.warning("No %s loader available", split)
510
+ return
511
+ self.model.eval()
512
+ losses: List[float] = []
513
+ nmses: List[float] = []
514
+ per_step_nmses: List[List[float]] = [] # List of lists: [batch][step]
515
+ with torch.no_grad():
516
+ total_steps = len(loader)
517
+ for step, batch in enumerate(loader, start=1):
518
+ tokens, _, H, W = self._prepare_batch(batch)
519
+ preds, target, mask = self.engine.autoregressive_rollout(
520
+ self.model, tokens, self.args.prediction.Tpast, self.args.prediction.horizon, H, W
521
+ )
522
+ loss = self._compute_loss(preds, target, mask)
523
+ batch_loss = loss.item()
524
+ batch_nmse = compute_nmse(preds, target, mask)
525
+ losses.append(batch_loss)
526
+ nmses.append(batch_nmse)
527
+
528
+ # Compute per-step NMSE for this batch
529
+ S_per_time = H * W
530
+ Tpast = self.args.prediction.Tpast
531
+ horizon = self.args.prediction.horizon
532
+ step_nmses = []
533
+ for h in range(horizon):
534
+ t_idx = Tpast + h
535
+ step_start = t_idx * S_per_time
536
+ step_end = step_start + S_per_time
537
+ step_mask = mask[:, step_start:step_end]
538
+ if step_mask.sum() > 0:
539
+ step_pred = preds[:, step_start:step_end, :]
540
+ step_target = target[:, step_start:step_end, :]
541
+ step_nmse = compute_nmse(step_pred, step_target, step_mask)
542
+ step_nmses.append(step_nmse)
543
+ else:
544
+ step_nmses.append(float('nan'))
545
+ per_step_nmses.append(step_nmses)
546
+
547
+ # Report per-step NMSE for this batch (matching original package format)
548
+ per_step_strs = []
549
+ for h, step_nmse in enumerate(step_nmses):
550
+ if not math.isnan(step_nmse):
551
+ t = Tpast + h + 1 # t=11, 12, ... (1-indexed)
552
+ nmse_db = 10.0 * math.log10(max(step_nmse, 1e-12))
553
+ per_step_strs.append(f"t={t}: {nmse_db:.3f} dB")
554
+ if per_step_strs:
555
+ self.logger.info(
556
+ "[%s] per-step NMSE dB: %s",
557
+ split,
558
+ ", ".join(per_step_strs),
559
+ )
560
+
561
+ if self.args.training.verbose_inference:
562
+ tag = split.upper()
563
+ nmse_db = 10.0 * math.log10(max(batch_nmse, 1e-12))
564
+ self.logger.info(
565
+ "%s: [%d/%d] loss=%0.6f nmse=%0.6f (%0.2f dB)",
566
+ tag,
567
+ step,
568
+ total_steps,
569
+ batch_loss,
570
+ batch_nmse,
571
+ nmse_db,
572
+ )
573
+
574
+ avg_loss = float(np.mean(losses)) if losses else float("nan")
575
+ avg_nmse = float(np.mean(nmses)) if nmses else float("nan")
576
+ tag = f"[{split}]" if epoch is None else f"Epoch {epoch} [{split}]"
577
+ avg_nmse_db = 10.0 * math.log10(max(avg_nmse, 1e-12))
578
+ self.logger.info(
579
+ "Inference [%s] NMSE=%e (%0.3f dB) over %d batches",
580
+ split,
581
+ avg_nmse,
582
+ avg_nmse_db,
583
+ len(losses),
584
+ )
585
+
586
+ # Compute per-step average in dB scale (matching original implementation)
587
+ if per_step_nmses:
588
+ horizon = len(per_step_nmses[0])
589
+ per_step_avg_db = []
590
+ Tpast = self.args.prediction.Tpast
591
+ for h in range(horizon):
592
+ # Average dB values (not linear values!)
593
+ step_dbs = []
594
+ for batch_nmses in per_step_nmses:
595
+ if not math.isnan(batch_nmses[h]):
596
+ step_db = 10.0 * math.log10(max(batch_nmses[h], 1e-12))
597
+ step_dbs.append(step_db)
598
+ if step_dbs:
599
+ avg_db = float(np.mean(step_dbs))
600
+ per_step_avg_db.append(f"t={Tpast + h + 1}: {avg_db:.3f} dB")
601
+ if per_step_avg_db:
602
+ self.logger.info(
603
+ "Inference [%s] per-step average NMSE dB: %s",
604
+ split,
605
+ ", ".join(per_step_avg_db),
606
+ )
607
+
608
+ metrics = {
609
+ f"{split}/loss": avg_loss,
610
+ f"{split}/nmse": avg_nmse,
611
+ f"{split}/nmse_db": avg_nmse_db,
612
+ }
613
+ self._wandb_log(metrics, step=self.global_step)
614
+
615
+ def _save_checkpoint(self, epoch: int, metric: float) -> None:
616
+ self.args.training.save_dir.mkdir(parents=True, exist_ok=True)
617
+ filename = f"{self.args.training.save_prefix}_epoch{epoch:02d}.pth"
618
+ path = self.args.training.save_dir / filename
619
+ state = {
620
+ "epoch": epoch,
621
+ "metric": metric,
622
+ "model_state_dict": self.model.state_dict(),
623
+ "optimizer_state_dict": self.optimizer.state_dict(),
624
+ "scheduler_state_dict": self.scheduler.state_dict(),
625
+ "config": dataclasses.asdict(self.args),
626
+ }
627
+ torch.save(state, path)
628
+ print(f"Saved checkpoint to {path}")
629
+
630
+
631
+ __all__ = [
632
+ "DatasetArgs",
633
+ "ModelArgs",
634
+ "TrainingArgs",
635
+ "PredictionArgs",
636
+ "ChannelPredictionArgs",
637
+ "ChannelPredictionDataModule",
638
+ "AutoregressiveEngine",
639
+ "PredictionVisualizer",
640
+ "ChannelPredictionTrainer",
641
+ ]
LWMTemporal/tasks/pretraining.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import dataclasses
5
+ import logging
6
+ import math
7
+ import pickle
8
+ import random
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional, Sequence, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import Tensor
15
+ from torch.cuda.amp import GradScaler, autocast
16
+ from torch.utils.data import DataLoader, Dataset
17
+
18
+ from ..data.angle_delay import AngleDelayConfig, AngleDelayProcessor
19
+ from ..models import LWMBackbone, LWMConfig
20
+ from ..models.lwm import ComplexPatchTokenizer, masked_nmse_loss
21
+
22
+ try:
23
+ import wandb # type: ignore
24
+ except ImportError: # pragma: no cover
25
+ wandb = None # type: ignore
26
+
27
@dataclasses.dataclass
class DataArgs:
    """Configuration for loading and preprocessing channel sequences."""

    # Directory scanned for pickled (*.p) channel arrays.
    data_dir: Path
    # Forwarded to AngleDelayConfig — presumably the fraction of delay bins
    # retained by truncate_delay_bins; confirm against AngleDelayProcessor.
    keep_percentage: float = 0.25
    # Normalization mode: "global_rms", "per_sample_rms", or "none".
    normalize: str = "global_rms"
    # Optional cap on the number of time steps kept per sequence.
    max_time_steps: Optional[int] = None
33
+
34
+
35
@dataclasses.dataclass
class MaskArgs:
    """Configuration for token-mask generation during pretraining."""

    # Target fraction of tokens to mask (used directly by the random mask).
    mask_ratio: float = 0.75
    # Mask family: "auto", "random", "rect", "tube", or "comb".
    mask_mode: str = "auto"
    # In "auto" mode, probability of drawing a uniform random mask.
    random_fraction: float = 0.2
40
+
41
+
42
@dataclasses.dataclass
class CurriculumArgs:
    """Mask-ratio curriculum schedule for early epochs."""

    # "mask" ramps mask_ratio linearly during warmup; "none" disables it.
    strategy: str = "mask"
    # Number of epochs over which the ratio ramps from min to max.
    warmup_epochs: int = 4
    min_mask_ratio: float = 0.3
    max_mask_ratio: float = 0.75
48
+
49
+
50
@dataclasses.dataclass
class AugmentationArgs:
    """Probabilities and ranges for complex-channel augmentations.

    All probabilities default to 0.0, i.e. augmentation is off by default.
    """

    # Probability of applying a global complex phase rotation.
    phase_p: float = 0.0
    # Probability and range of a scalar amplitude scaling.
    amp_p: float = 0.0
    amp_min: float = 0.7
    amp_max: float = 1.3
    # Probability of adding AWGN, with SNR sampled uniformly in dB.
    awgn_p: float = 0.0
    awgn_snr_min: float = 20.0
    awgn_snr_max: float = 30.0
59
+
60
+
61
@dataclasses.dataclass
class LoggingArgs:
    """File-logging and optional Weights & Biases settings."""

    log_dir: Path = Path("logs")
    # When True (and wandb is installed), metrics are mirrored to W&B.
    use_wandb: bool = False
    wandb_project: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_run_name: Optional[str] = None
68
+
69
+
70
@dataclasses.dataclass
class OptimizationArgs:
    """Optimizer, schedule, and checkpointing settings for pretraining."""

    # NOTE: resolved once at import time; a GPU attached later is not seen.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    epochs: int = 20
    batch_size: int = 32
    lr: float = 2e-4
    weight_decay: float = 1e-4
    # Fraction of total steps used for linear LR warmup before cosine decay.
    warmup_ratio: float = 0.1
    # Gradient-norm clipping threshold (applied after AMP unscaling).
    grad_clip: float = 1.0
    # Log every N steps.
    log_interval: int = 1
    save_dir: Path = Path("models")
    # A ".bin" suffix in the prefix selects weights-only checkpoints.
    save_prefix: str = "lwm_pretrain"
    # Optional checkpoint path to resume training from.
    resume_from: Optional[Path] = None
83
+
84
+
85
@dataclasses.dataclass
class ModelArgs:
    """Backbone hyper-parameters; fields map one-to-one onto LWMConfig."""

    # (height, width) of each complex patch token.
    patch_size: Tuple[int, int] = (1, 1)
    # Complex-to-real token encoding: "real_imag" or "mag_phase".
    phase_mode: str = "real_imag"
    embed_dim: int = 32
    depth: int = 12
    num_heads: int = 8
    mlp_ratio: float = 4.0
    same_frame_window: int = 2
    # Relative frame offsets attended across time — TODO confirm semantics
    # against LWMConfig documentation.
    temporal_offsets: Sequence[int] = dataclasses.field(default_factory=lambda: (-4, -3, -2, -1, 1, 2, 3))
    temporal_spatial_window: int = 2
    temporal_drift_h: int = 1
    temporal_drift_w: int = 1
    # Top-k sparse routing of attention links.
    routing_topk_enable: bool = True
    routing_topk_fraction: float = 0.2
    routing_topk_min: int = 8
    routing_topk_max: int = 32
    topk_per_head: bool = True
    # Positional encoding: "learned" or "rope_sincos".
    posenc: str = "learned"
    rope_base: float = 10000.0
    global_cls: bool = False
106
+
107
+
108
@dataclasses.dataclass
class PretrainingArgs:
    """Aggregate of all pretraining configuration groups."""

    data: DataArgs
    mask: MaskArgs
    curriculum: CurriculumArgs
    augment: AugmentationArgs
    optim: OptimizationArgs
    model: ModelArgs
    logging: LoggingArgs
117
+
118
+
119
class PretrainingDataset(Dataset):
    """Angle-delay channel sequences prepared for masked pretraining.

    Each item yields the tokenized sequence, a flattened boolean token mask,
    and the (T, H, W) patch-grid shape of the sequence.
    """

    def __init__(
        self,
        args: DataArgs,
        tokenizer: ComplexPatchTokenizer,
        augmenter: Augmenter,
        masker: MaskGenerator,
        patch_size: Tuple[int, int],
    ) -> None:
        self.args = args
        self.tokenizer = tokenizer
        self.augmenter = augmenter
        self.masker = masker
        self.patch_size = patch_size
        # All sequences are loaded into memory once up front.
        self.samples = self._load_sequences()
        if args.normalize != "none":
            self.samples = [self._normalize(sample, args.normalize) for sample in self.samples]

    def _load_sequences(self) -> List[Tensor]:
        """Load every ``*.p`` pickle under ``data_dir`` as complex sequences.

        Payloads may be raw arrays or dicts with a "channel" key; 3-D arrays
        are treated as a single sequence.
        """
        processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.args.keep_percentage))
        samples: List[Tensor] = []
        for path in sorted(self.args.data_dir.glob("*.p")):
            with path.open("rb") as handle:
                payload = pickle.load(handle)
            if isinstance(payload, dict) and "channel" in payload:
                tensor = torch.as_tensor(payload["channel"], dtype=torch.complex64)
            else:
                tensor = torch.as_tensor(payload, dtype=torch.complex64)
            if tensor.ndim == 3:
                # Promote a single (T, N, M) sequence to a batch of one.
                tensor = tensor.unsqueeze(0)
            for seq in tensor:
                ad = processor.forward(seq)
                truncated, _ = processor.truncate_delay_bins(ad)
                if self.args.max_time_steps is not None and truncated.size(0) > self.args.max_time_steps:
                    truncated = truncated[: self.args.max_time_steps]
                samples.append(truncated)
        return samples

    def _normalize(self, tensor: Tensor, mode: str) -> Tensor:
        """Scale a sequence to unit RMS magnitude.

        Fix: the "global_rms" and "per_sample_rms" branches were byte-identical
        duplicates; they are collapsed here with unchanged behavior.  NOTE: a
        true dataset-wide ("global") RMS is not implemented — both modes
        normalize per sequence.
        """
        if mode in {"global_rms", "per_sample_rms"}:
            rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8)
            return tensor / rms.to(tensor.dtype)
        return tensor

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return tokens, flattened token mask, and the (T, H, W) grid shape."""
        sample = self.samples[index]
        if self.augmenter is not None:
            sample = self.augmenter(sample)
        tokens, _ = self.tokenizer(sample.unsqueeze(0), self.patch_size)
        tokens = tokens.squeeze(0)
        ph, pw = self.patch_size
        T, N, M = sample.shape
        # Assumes N and M are divisible by the patch size — TODO confirm the
        # tokenizer pads otherwise.
        H = N // ph
        W = M // pw
        mask = self.masker(T, H, W, device=tokens.device).view(-1)
        shape = torch.tensor([T, H, W], dtype=torch.long)
        return {
            "tokens": tokens,
            "mask": mask,
            "shape": shape,
        }
186
+
187
+
188
class MaskGenerator:
    """Generate boolean masks over a (T, H, W) patch grid.

    Explicit modes ("random", "rect", "tube", "comb") always produce the
    requested mask family.  "auto" samples a family per call:
    ``random_fraction`` of the time a uniform random mask, otherwise rect,
    tube, or comb with the original cascade probabilities.
    """

    def __init__(self, args: MaskArgs) -> None:
        self.args = args

    def __call__(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
        """Return a (T, H, W) boolean mask; True marks masked tokens."""
        mode = self.args.mask_mode
        # Fix: explicit modes previously fell through the probabilistic
        # cascade, so requesting "rect" returned a rect mask only ~33% of
        # the time (and "tube" likewise).  They now dispatch directly.
        if mode == "random":
            return self.random_mask(T, H, W, device)
        if mode == "rect":
            return self.rect_mask(T, H, W, device)
        if mode == "tube":
            return self.tube_mask(T, H, W, device)
        if mode == "comb":
            return self.comb_mask(T, H, W, device)
        # "auto": original sampling cascade, probabilities unchanged.
        if random.random() < self.args.random_fraction:
            return self.random_mask(T, H, W, device)
        if random.random() < 0.33:
            return self.rect_mask(T, H, W, device)
        if random.random() < 0.33:
            return self.tube_mask(T, H, W, device)
        return self.comb_mask(T, H, W, device)

    def random_mask(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
        """Mask exactly ``mask_ratio`` of all tokens uniformly at random."""
        total = T * H * W
        num_mask = int(self.args.mask_ratio * total)
        mask = torch.zeros(total, dtype=torch.bool, device=device)
        idx = torch.randperm(total, device=device)[:num_mask]
        mask[idx] = True
        return mask.view(T, H, W)

    def rect_mask(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
        """Mask random spatial rectangles, one per sampled frame."""
        mask = torch.zeros((T, H, W), dtype=torch.bool, device=device)
        blocks = max(1, int(self.args.mask_ratio * T))
        for _ in range(blocks):
            t = random.randrange(T)
            h_size = random.randint(1, max(1, H // 2))
            w_size = random.randint(1, max(1, W // 2))
            h0 = random.randint(0, H - h_size)
            w0 = random.randint(0, W - w_size)
            mask[t, h0:h0 + h_size, w0:w0 + w_size] = True
        return mask

    def tube_mask(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
        """Mask a drifting 3x3 spatial neighborhood along the time axis."""
        mask = torch.zeros((T, H, W), dtype=torch.bool, device=device)
        start_t = random.randrange(T)
        h = random.randrange(H)
        w = random.randrange(W)
        length = random.randint(max(1, T // 2), T)
        for k in range(length):
            t_idx = (start_t + k) % T  # wraps around the sequence end
            mask[t_idx, max(0, h - 1):min(H, h + 2), max(0, w - 1):min(W, w + 2)] = True
            # Random walk of the tube center, clamped to the grid.
            h = max(0, min(H - 1, h + random.randint(-1, 1)))
            w = max(0, min(W - 1, w + random.randint(-1, 1)))
        return mask

    def comb_mask(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
        """Keep a regular comb of (t, w) positions visible; mask the rest."""
        mask = torch.zeros((T, H, W), dtype=torch.bool, device=device)
        stride_t = random.choice([2, 3]) if T >= 2 else 1
        stride_w = random.choice([3, 4, 6]) if W >= 3 else 1
        offset_t = random.randrange(stride_t)
        offset_w = random.randrange(stride_w)
        for t in range(T):
            for w in range(W):
                visible = (t % stride_t == offset_t) and (w % stride_w == offset_w)
                if not visible:
                    mask[t, :, w] = True
        return mask
246
+
247
+
248
class Augmenter:
    """Stochastic augmentations for complex channel tensors.

    Applies, each with its own probability: a global phase rotation, a
    scalar amplitude scaling, and additive white Gaussian noise at a
    uniformly sampled SNR.  The input tensor is never modified in place.
    """

    def __init__(self, args: AugmentationArgs) -> None:
        self.args = args

    def __call__(self, tensor: Tensor) -> Tensor:
        cfg = self.args
        out = tensor.clone()
        # Global phase rotation by a unit-magnitude complex factor.
        if torch.rand(()) < cfg.phase_p:
            angle = torch.rand((), device=out.device) * 2 * math.pi - math.pi
            out = out * (torch.cos(angle) + 1j * torch.sin(angle))
        # Scalar amplitude scaling drawn from [amp_min, amp_max].
        if torch.rand(()) < cfg.amp_p:
            gain = cfg.amp_min + (cfg.amp_max - cfg.amp_min) * torch.rand((), device=out.device)
            out = out * gain
        # AWGN: sample an SNR in dB, derive the per-component noise std.
        if torch.rand(()) < cfg.awgn_p:
            snr_db = torch.empty((), device=out.device).uniform_(cfg.awgn_snr_min, cfg.awgn_snr_max)
            snr_lin = 10 ** (snr_db / 10.0)
            signal_power = (out.real.float().pow(2) + out.imag.float().pow(2)).mean().item()
            if signal_power > 0:
                noise_var = signal_power / snr_lin
                sigma = math.sqrt(noise_var / 2.0)  # split across real/imag
                re_noise = torch.randn_like(out.real.float()) * sigma
                im_noise = torch.randn_like(out.imag.float()) * sigma
                out = out + torch.complex(re_noise.to(out.dtype), im_noise.to(out.dtype))
        return out
273
+
274
+
275
class PretrainingTrainer:
    """Masked-reconstruction pretraining loop for the LWM backbone.

    Builds the dataset/dataloader, model, AdamW optimizer, warmup+cosine LR
    schedule, and AMP scaler; optionally mirrors metrics to Weights & Biases
    and can resume from a checkpoint.
    """

    def __init__(self, args: PretrainingArgs, *, logger: Optional[logging.Logger] = None) -> None:
        self.args = args
        self.logger = logger or logging.getLogger(__name__)
        self.device = torch.device(args.optim.device)
        self.tokenizer = ComplexPatchTokenizer(args.model.phase_mode)
        self.masker = MaskGenerator(args.mask)
        self.augmenter = Augmenter(args.augment)
        self.dataset = PretrainingDataset(args.data, self.tokenizer, self.augmenter, self.masker, args.model.patch_size)
        self.dataloader = DataLoader(self.dataset, batch_size=args.optim.batch_size, shuffle=True, drop_last=True)
        self.model = self._build_model().to(self.device)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=args.optim.lr, weight_decay=args.optim.weight_decay)
        self.scheduler = self._build_scheduler()
        self.scaler = GradScaler()
        self.global_step = 0
        self._wandb_run = self._maybe_init_wandb()
        if self.args.optim.resume_from is not None:
            self._load_checkpoint(self.args.optim.resume_from)

    def _wandb_enabled(self) -> bool:
        """True when a live wandb run exists."""
        return self._wandb_run is not None

    def _maybe_init_wandb(self) -> Optional["wandb.sdk.wandb_run.Run"]:
        """Start a wandb run when requested and the package is installed."""
        logging_args = self.args.logging
        if not logging_args.use_wandb:
            return None
        if wandb is None:
            self.logger.warning("Weights & Biases not installed; disabling wandb logging.")
            return None
        config = dataclasses.asdict(self.args)
        run = wandb.init(
            project=logging_args.wandb_project,
            entity=logging_args.wandb_entity,
            name=logging_args.wandb_run_name,
            config=config,
        )
        wandb.watch(self.model, log="all", log_freq=self.args.optim.log_interval)
        self.logger.info("Initialized Weights & Biases run: %s", run.name)
        return run

    def _wandb_log(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Forward metrics to wandb; no-op when wandb is disabled."""
        if not self._wandb_enabled():
            return
        wandb.log(metrics, step=step)

    def _finish_wandb(self) -> None:
        """Close the wandb run, if any."""
        if self._wandb_enabled():
            wandb.finish()

    def _build_model(self) -> LWMBackbone:
        """Instantiate the backbone from the model hyper-parameters."""
        cfg = LWMConfig(
            patch_size=self.args.model.patch_size,
            phase_mode=self.args.model.phase_mode,
            embed_dim=self.args.model.embed_dim,
            depth=self.args.model.depth,
            num_heads=self.args.model.num_heads,
            mlp_ratio=self.args.model.mlp_ratio,
            same_frame_window=self.args.model.same_frame_window,
            temporal_offsets=self.args.model.temporal_offsets,
            temporal_spatial_window=self.args.model.temporal_spatial_window,
            temporal_drift_h=self.args.model.temporal_drift_h,
            temporal_drift_w=self.args.model.temporal_drift_w,
            routing_topk_enable=self.args.model.routing_topk_enable,
            routing_topk_fraction=self.args.model.routing_topk_fraction,
            routing_topk_min=self.args.model.routing_topk_min,
            routing_topk_max=self.args.model.routing_topk_max,
            topk_per_head=self.args.model.topk_per_head,
            posenc=self.args.model.posenc,
            rope_base=self.args.model.rope_base,
            global_cls=self.args.model.global_cls,
        )
        return LWMBackbone(cfg)

    def _build_scheduler(self) -> torch.optim.lr_scheduler.LambdaLR:
        """Linear warmup followed by cosine decay, stepped once per batch."""
        steps_per_epoch = max(1, len(self.dataloader))
        total_steps = steps_per_epoch * max(1, self.args.optim.epochs)
        warmup_steps = int(self.args.optim.warmup_ratio * total_steps)

        def schedule(step: int) -> float:
            if step < warmup_steps:
                return float(step) / max(1, warmup_steps)
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, schedule)

    def _adjust_curriculum(self, epoch: int) -> None:
        """Ramp the mask ratio linearly over the curriculum warmup epochs."""
        if self.args.curriculum.strategy == "mask" and epoch <= self.args.curriculum.warmup_epochs:
            ratio = np.interp(
                epoch,
                [0, self.args.curriculum.warmup_epochs],
                [self.args.curriculum.min_mask_ratio, self.args.curriculum.max_mask_ratio],
            )
            self.masker.args.mask_ratio = float(ratio)
            # Fix: the original emitted this identical log line twice.
            self.logger.info(
                "Curriculum update | epoch=%d/%d mask_ratio=%0.2f",
                epoch,
                self.args.optim.epochs,
                self.masker.args.mask_ratio,
            )
            self._wandb_log(
                {"train/curriculum_mask_ratio": self.masker.args.mask_ratio},
                step=self.global_step,
            )

    def train(self) -> None:
        """Run the full pretraining loop, checkpointing once per epoch."""
        for epoch in range(1, self.args.optim.epochs + 1):
            self._adjust_curriculum(epoch)
            # Robustness: guarantee training mode each epoch, even if an
            # external caller switched the model to eval().
            self.model.train()
            running_loss = 0.0
            loader_len = len(self.dataloader)
            for step, batch in enumerate(self.dataloader, start=1):
                tokens = batch["tokens"].to(self.device)
                mask_tokens = batch["mask"].to(self.device)
                shapes = batch["shape"]
                if not isinstance(shapes, torch.Tensor):
                    shapes = torch.tensor(shapes)
                if shapes.dim() == 1:
                    shapes = shapes.unsqueeze(0)
                # All sequences in a batch must share one (T, H, W) grid.
                ref_shape = shapes[0]
                if not torch.all(shapes.eq(ref_shape)):
                    raise ValueError("Mixed sequence shapes within the same batch are not supported")
                T = int(ref_shape[0].item())
                H = int(ref_shape[1].item())
                W = int(ref_shape[2].item())
                with autocast():
                    outputs = self.model.forward_tokens(tokens, mask_tokens, T, H, W, return_cls=False)
                    preds = outputs["reconstruction"]
                    loss = masked_nmse_loss(preds, tokens, mask_tokens)
                self.scaler.scale(loss).backward()
                # Unscale first so the clip threshold applies to true grads.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.optim.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                self.scheduler.step()
                running_loss += loss.item()
                self.global_step += 1
                if step % self.args.optim.log_interval == 0:
                    avg_loss = running_loss / step
                    lr_cur = self.optimizer.param_groups[0]["lr"]
                    self.logger.info(
                        "Train: [%d/%d][%d/%d] loss=%0.6f mask=%0.2f lr=%0.2e",
                        epoch,
                        self.args.optim.epochs,
                        step,
                        loader_len,
                        avg_loss,
                        self.masker.args.mask_ratio,
                        lr_cur,
                    )
                    self._wandb_log(
                        {
                            "train/loss": avg_loss,
                            "train/mask_ratio": self.masker.args.mask_ratio,
                            "train/lr": lr_cur,
                        },
                        step=self.global_step,
                    )
            avg_epoch_loss = running_loss / max(1, len(self.dataloader))
            self.logger.info(
                "Train Epoch %d/%d Summary: loss=%0.6f",
                epoch,
                self.args.optim.epochs,
                avg_epoch_loss,
            )
            self._wandb_log(
                {
                    "train/epoch_loss": avg_epoch_loss,
                },
                step=self.global_step,
            )
            self._save_checkpoint(epoch, avg_epoch_loss)
        self._finish_wandb()

    def _save_checkpoint(self, epoch: int, metric: float) -> None:
        """Write a weights-only (.bin) or full training checkpoint.

        The file extension comes from ``save_prefix`` (default ".pth");
        a ".bin" suffix selects an HF-style weights-only file.
        """
        self.args.optim.save_dir.mkdir(parents=True, exist_ok=True)
        save_prefix = Path(self.args.optim.save_prefix)
        suffix = save_prefix.suffix or ".pth"
        stem = save_prefix.stem if save_prefix.suffix else save_prefix.name
        filename = f"{stem}_epoch{epoch:03d}{suffix}"
        path = self.args.optim.save_dir / filename
        # Fix: suffix was computed twice (once from the prefix, again from
        # the path); compare the single value case-insensitively instead.
        if suffix.lower() == ".bin":
            torch.save(self.model.state_dict(), path)
            self.logger.info("Saved weights-only checkpoint to %s", path)
        else:
            torch.save(
                {
                    "epoch": epoch,
                    "metric": metric,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "scheduler_state_dict": self.scheduler.state_dict(),
                    "config": dataclasses.asdict(self.args),
                },
                path,
            )
            self.logger.info("Saved checkpoint to %s", path)
        if self._wandb_enabled():
            wandb.save(str(path))

    def _load_model_state(self, model_state) -> None:
        """Load model weights non-strictly, warning about key mismatches."""
        missing, unexpected = self.model.load_state_dict(model_state, strict=False)
        if missing:
            self.logger.warning("Missing keys when loading model: %s", missing)
        if unexpected:
            self.logger.warning("Unexpected keys when loading model: %s", unexpected)

    def _load_checkpoint(self, checkpoint_path: Path) -> None:
        """Restore model (and, when present, optimizer/scheduler) state."""
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
        self.logger.info("Resuming from checkpoint %s", checkpoint_path)

        payload = torch.load(checkpoint_path, map_location=self.device)

        if checkpoint_path.suffix.lower() == ".bin":
            # Weights-only checkpoint: nothing else to restore.
            self._load_model_state(payload)
            self.logger.info("Loaded weights-only checkpoint.")
            return

        model_state = payload.get("model_state_dict")
        if model_state is not None:
            self._load_model_state(model_state)

        opt_state = payload.get("optimizer_state_dict")
        if opt_state is not None:
            try:
                self.optimizer.load_state_dict(opt_state)
            except Exception as exc:
                # Best-effort: a shape/param mismatch should not abort resume.
                self.logger.warning("Failed to load optimizer state: %s", exc)

        sched_state = payload.get("scheduler_state_dict")
        if sched_state is not None:
            try:
                self.scheduler.load_state_dict(sched_state)
            except Exception as exc:
                self.logger.warning("Failed to load scheduler state: %s", exc)

        epoch = payload.get("epoch", 0)
        metric = payload.get("metric")
        self.logger.info("Checkpoint contained epoch=%s metric=%s", epoch, metric)
525
+
526
+
527
def build_pretraining_args(ns: argparse.Namespace) -> PretrainingArgs:
    """Map a parsed CLI namespace onto the grouped dataclass config.

    NOTE(review): ``ModelArgs.global_cls`` is left at its default — no CLI
    flag exists for it in ``build_parser``.
    """
    data_args = DataArgs(
        data_dir=ns.data_dir,
        keep_percentage=ns.keep_percentage,
        normalize=ns.normalize,
        max_time_steps=ns.max_time_steps,
    )
    mask_args = MaskArgs(mask_ratio=ns.mask_ratio, mask_mode=ns.mask_mode, random_fraction=ns.mask_random_fraction)
    curriculum_args = CurriculumArgs(
        strategy=ns.curriculum_strategy,
        warmup_epochs=ns.curriculum_warmup_epochs,
        min_mask_ratio=ns.curriculum_min_mask,
        max_mask_ratio=ns.curriculum_max_mask,
    )
    augment_args = AugmentationArgs(
        phase_p=ns.aug_phase_p,
        amp_p=ns.aug_amp_p,
        amp_min=ns.aug_amp_min,
        amp_max=ns.aug_amp_max,
        awgn_p=ns.aug_awgn_p,
        awgn_snr_min=ns.aug_awgn_snr_min,
        awgn_snr_max=ns.aug_awgn_snr_max,
    )
    optim_args = OptimizationArgs(
        device=ns.device,
        epochs=ns.epochs,
        batch_size=ns.batch_size,
        lr=ns.lr,
        weight_decay=ns.weight_decay,
        warmup_ratio=ns.warmup_ratio,
        grad_clip=ns.grad_clip,
        log_interval=ns.log_interval,
        save_dir=ns.save_dir,
        save_prefix=ns.save_prefix,
        resume_from=ns.resume_from,
    )
    logging_args = LoggingArgs(
        log_dir=ns.log_dir,
        use_wandb=ns.use_wandb,
        wandb_project=ns.wandb_project,
        wandb_entity=ns.wandb_entity,
        wandb_run_name=ns.wandb_run_name,
    )
    model_args = ModelArgs(
        patch_size=tuple(ns.patch_size),
        phase_mode=ns.phase_mode,
        embed_dim=ns.embed_dim,
        depth=ns.depth,
        num_heads=ns.num_heads,
        mlp_ratio=ns.mlp_ratio,
        same_frame_window=ns.same_frame_window,
        temporal_offsets=tuple(ns.temporal_offsets),
        temporal_spatial_window=ns.temporal_spatial_window,
        temporal_drift_h=ns.temporal_drift_h,
        temporal_drift_w=ns.temporal_drift_w,
        routing_topk_enable=ns.routing_topk_enable,
        routing_topk_fraction=ns.routing_topk_fraction,
        routing_topk_min=ns.routing_topk_min,
        routing_topk_max=ns.routing_topk_max,
        topk_per_head=ns.topk_per_head,
        posenc=ns.posenc,
        rope_base=ns.rope_base,
    )
    return PretrainingArgs(
        data=data_args,
        mask=mask_args,
        curriculum=curriculum_args,
        augment=augment_args,
        optim=optim_args,
        model=model_args,
        logging=logging_args,
    )
599
+
600
+
601
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for LWM pretraining.

    NOTE(review): the CLI default for ``--mask_ratio`` (0.60) differs from
    ``MaskArgs.mask_ratio`` (0.75); the CLI value wins when the script is
    used — confirm which default is intended.
    """
    parser = argparse.ArgumentParser(description="Pretrain LWM foundation model")

    # Data loading / preprocessing.
    parser.add_argument("--data_dir", type=Path, required=True)
    parser.add_argument("--keep_percentage", type=float, default=0.25)
    parser.add_argument("--normalize", type=str, default="global_rms", choices=["global_rms", "per_sample_rms", "none"])
    parser.add_argument("--max_time_steps", type=int, default=None)

    # Masking.
    parser.add_argument("--mask_ratio", type=float, default=0.60)
    parser.add_argument("--mask_mode", type=str, default="auto", choices=["auto", "random", "rect", "tube", "comb"])
    parser.add_argument("--mask_random_fraction", type=float, default=0.2)

    # Curriculum.
    parser.add_argument("--curriculum_strategy", type=str, default="mask", choices=["none", "mask"])
    parser.add_argument("--curriculum_warmup_epochs", type=int, default=4)
    parser.add_argument("--curriculum_min_mask", type=float, default=0.3)
    parser.add_argument("--curriculum_max_mask", type=float, default=0.75)

    # Logging / Weights & Biases.
    parser.add_argument("--log_dir", type=Path, default=Path("logs"))
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--wandb_entity", type=str, default=None)
    parser.add_argument("--wandb_run_name", type=str, default=None)

    # Model hyper-parameters.
    parser.add_argument("--phase_mode", type=str, default="real_imag", choices=["real_imag", "mag_phase"])
    parser.add_argument("--patch_size", type=int, nargs=2, default=(1, 1))
    parser.add_argument("--embed_dim", type=int, default=32)
    parser.add_argument("--depth", type=int, default=12)
    parser.add_argument("--num_heads", type=int, default=8)
    parser.add_argument("--mlp_ratio", type=float, default=4.0)
    parser.add_argument("--same_frame_window", type=int, default=2)
    parser.add_argument("--temporal_offsets", type=int, nargs="*", default=[-4, -3, -2, -1, 1, 2, 3])
    parser.add_argument("--temporal_spatial_window", type=int, default=2)
    parser.add_argument("--temporal_drift_h", type=int, default=1)
    parser.add_argument("--temporal_drift_w", type=int, default=1)
    # Fix: these were `store_true` with default=True, so the CLI could never
    # turn them off.  BooleanOptionalAction keeps `--routing_topk_enable`
    # working and adds `--no-routing_topk_enable` (same for topk_per_head).
    parser.add_argument("--routing_topk_enable", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--routing_topk_fraction", type=float, default=0.2)
    parser.add_argument("--routing_topk_min", type=int, default=8)
    parser.add_argument("--routing_topk_max", type=int, default=32)
    parser.add_argument("--topk_per_head", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--posenc", type=str, default="learned", choices=["learned", "rope_sincos"])
    parser.add_argument("--rope_base", type=float, default=10000.0)

    # Optimization / checkpointing.
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--log_interval", type=int, default=1)
    parser.add_argument("--save_dir", type=Path, default=Path("models"))
    parser.add_argument("--save_prefix", type=str, default="lwm_pretrain")
    parser.add_argument("--resume_from", type=Path, default=None, help="Path to checkpoint to resume from")

    # Augmentation.
    parser.add_argument("--aug_phase_p", type=float, default=0.0)
    parser.add_argument("--aug_amp_p", type=float, default=0.0)
    parser.add_argument("--aug_amp_min", type=float, default=0.7)
    parser.add_argument("--aug_amp_max", type=float, default=1.3)
    parser.add_argument("--aug_awgn_p", type=float, default=0.0)
    parser.add_argument("--aug_awgn_snr_min", type=float, default=20.0)
    parser.add_argument("--aug_awgn_snr_max", type=float, default=30.0)

    return parser
662
+
663
+
664
def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: parse arguments, build the config, run pretraining."""
    arg_list = None if argv is None else list(argv)
    namespace = build_parser().parse_args(args=arg_list)
    config = build_pretraining_args(namespace)
    PretrainingTrainer(config).train()
669
+
670
+
671
# Public API of the pretraining task module.
__all__ = [
    "DataArgs",
    "MaskArgs",
    "CurriculumArgs",
    "AugmentationArgs",
    "OptimizationArgs",
    "ModelArgs",
    "PretrainingArgs",
    "PretrainingDataset",
    "PretrainingTrainer",
    "build_pretraining_args",
    "build_parser",
    "main",
]
LWMTemporal/training/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Training utilities for LWM foundation models."""
LWMTemporal/utils/logging.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ import logging
5
+
6
LOG_FORMAT = "[%(asctime)s,%(msecs)03d %(levelname)s %(name)s line %(lineno)d %(process)d] %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


def setup_logging(
    name: str = "LWMTemporal",
    log_dir: Path | None = None,
    level: int = logging.INFO,
) -> logging.Logger:
    """Return a configured logger in the package's standard format.

    When ``log_dir`` is given, messages also go to ``<log_dir>/<name>.log``.
    Repeated calls reconfigure the same named logger instead of stacking
    handlers.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Reconfigure from scratch so repeated calls never duplicate handlers.
    if logger.hasHandlers():
        logger.handlers.clear()

    formatter = logging.Formatter(LOG_FORMAT, DATE_FORMAT)

    def _attach(handler: logging.Handler) -> None:
        # Every handler shares one format and threshold.
        handler.setFormatter(formatter)
        handler.setLevel(level)
        logger.addHandler(handler)

    if log_dir is not None:
        target = Path(log_dir)
        target.mkdir(parents=True, exist_ok=True)
        _attach(logging.FileHandler(target / f"{name}.log"))

    _attach(logging.StreamHandler())

    return logger
39
+
MANIFEST.in ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ include README.md
2
+ include LICENSE
3
+ include requirements.txt
4
+ include LWMTemporal/models/config.json
5
+ recursive-exclude * __pycache__
6
+ recursive-exclude * *.py[co]
7
+ recursive-exclude * .DS_Store
8
+ prune cache
+ prune logs
+ prune figs
+ prune wandb
+ prune checkpoints
+ exclude test.py
14
+
README.md ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LWMTemporal
2
+
3
+ Large Wireless Model (LWM) with sparse spatio-temporal attention for wireless channel prediction and forecasting.
4
+
5
+ This package provides a transformer-based model for spatio-temporal wireless channel prediction with support for both pretraining and fine-tuning tasks. It follows Hugging Face conventions for model checkpoints and configurations.
6
+
7
+ ---
8
+
9
+ ## Installation
10
+
11
+ ### From PyPI (Recommended)
12
+
13
+ ```bash
14
+ pip install lwm-temporal
15
+ ```
16
+
17
+ ### From Source
18
+
19
+ ```bash
20
+ git clone https://github.com/yourusername/lwm-temporal.git
21
+ cd lwm-temporal
22
+ pip install -e .
23
+ ```
24
+
25
+ ### Optional Dependencies
26
+
27
+ For Weights & Biases logging:
28
+ ```bash
29
+ pip install lwm-temporal[wandb]
30
+ ```
31
+
32
+ For development:
33
+ ```bash
34
+ pip install lwm-temporal[dev]
35
+ ```
36
+
37
+ ---
38
+
39
+ ## Quick Start
40
+
41
+ ### Python API
42
+
43
+ ```python
44
+ from pathlib import Path
45
+ from LWMTemporal import LWMBackbone, LWMConfig
46
+
47
+ # Load pretrained model
48
+ model = LWMBackbone.from_pretrained("checkpoints/m18_cp.pth")
49
+ model.eval()
50
+
51
+ # Use for inference - see examples/ for complete scripts
52
+ ```
53
+
54
+ ### Command Line Interface
55
+
56
+ ```bash
57
+ # Run channel prediction inference
58
+ python -m LWMTemporal.cli.channel_prediction \
59
+ --data_path examples/data/city_8_tempe_3p5_20_32_32.p \
60
+ --pretrained checkpoints/m18_cp.pth \
61
+ --inference_only \
62
+ --device cpu
63
+ ```
64
+
65
+ See `examples/` directory for more detailed usage examples.
66
+
67
+ ---
68
+
69
+ ## 1. Environment Setup
70
+
71
+ ### Requirements
72
+
73
+ - Python >= 3.9
74
+ - PyTorch >= 2.0.0
75
+ - NumPy >= 1.21.0
76
+ - Matplotlib >= 3.5.0
77
+
78
+ Verify that your PyTorch build matches your hardware (CPU vs CUDA). Mixed-precision (AMP) is optional; on CPU it will automatically disable itself.
79
+
80
+ ---
81
+
82
+ ## 2. Repository Layout
83
+
84
+ ```
85
+ LWMTemporal/ # Main package
86
+ cli/ # Command-line entry points
87
+ data/ # Dataset loaders & preprocessing utilities
88
+ models/ # LWM model + backbone + configs
89
+ tasks/ # High-level training/inference orchestration
90
+ utils/ # Logging and helper utilities
91
+ examples/ # Example scripts and sample data
92
+ data/ # Example datasets
93
+ checkpoints/ # Pretrained model checkpoints
94
+ cache/ # Optional dataset cache (auto-created)
95
+ figs/predictions/ # Visualization output (auto-created)
96
+ logs/ # Training logs (auto-created)
97
+ ```
98
+
99
+ ---
100
+
101
+ ## 3. Checkpoint Format (Hugging Face Compatible)
102
+
103
+ The code supports two checkpoint formats:
104
+
105
+ **Format 1: Directory (Hugging Face style)**
106
+ ```
107
+ checkpoints/my_model/
108
+ config.json # Model configuration
109
+ pytorch_model.bin # Model weights
110
+ ```
111
+
112
+ **Format 2: Single file**
113
+ ```
114
+ checkpoints/my_model.pth # Contains both weights and optional config
115
+ ```
116
+
117
+ The package automatically detects and loads either format.
118
+
119
+ > **Tip:** If you only have a single file (e.g. `model_best.pth`), move it to a directory and rename to `pytorch_model.bin`. Copy or recreate a matching `config.json`. The loader infers `max_seq_len` when it sees a longer positional embedding in the checkpoint, so older weights continue to work.
120
+
121
+ The directory can be uploaded to Hugging Face Hub as-is and loaded via `AutoModel.from_pretrained` if you create a thin wrapper.
122
+
123
+ ---
124
+
125
+ ## 4. Dataset Preparation
126
+
127
+ - The pipeline consumes pickle (`.p`) payloads with a `channel` key (complex tensor) and optional metadata (`pos`, `dt`).
128
+ - `AngleDelaySequenceDataset` normalizes, truncates, and caches angle-delay representations on demand.
129
+ - Configure preprocessing through `DatasetArgs`:
130
+ - `keep_percentage` – fraction of strongest taps to keep.
131
+ - `normalize` – `global_rms`, `per_sample_rms`, or `none`.
132
+ - `cache_dir`, `use_cache`, `overwrite_cache` – caching behavior.
133
+ - `snr_db`, `noise_seed` – synthetic AWGN injection.
134
+ - `max_time_steps` – optional temporal truncation.
135
+
136
+ Cached tensors are stored under `cache/adseq_<stem>_keepXX_<normalize>.pt`.
137
+
138
+ ---
139
+
140
+ ## 5. Command-Line Usage
141
+
142
+ The CLI mirrors the Hugging Face workflow (`python -m package.cli ...`).
143
+
144
+ ### 5.1 Inference / Evaluation
145
+
146
+ ```bash
147
+ python -m LWMTemporal.cli.channel_prediction \
148
+ --data_path examples/data/parow.p \
149
+ --pretrained checkpoints/m18_cp.pth \
150
+ --inference_only \
151
+ --inference_split val \
152
+ --Tpast 10 \
153
+ --horizon 1
154
+ ```
155
+
156
+ - `--Tpast` / `--horizon` define the autoregressive roll-out window.
157
+ - `--inference_split` selects which subset to score (`train`, `val`, `all`).
158
+ - Visualizations are written to `figs/predictions/`.
159
+
160
+ ### 5.2 Training / Fine-Tuning
161
+
162
+ Remove `--inference_only` to launch training:
163
+
164
+ ```bash
165
+ python -m LWMTemporal.cli.channel_prediction \
166
+ --data_path examples/data/parow.p \
167
+ --save_dir models/finetune_run \
168
+ --epochs 5 \
169
+ --batch_size 8 \
170
+ --lr 3e-4 \
171
+ --Tpast 10 \
172
+ --horizon 2
173
+ ```
174
+
175
+ Notable flags:
176
+
177
+ - `--pretrained` – resume from existing weights.
178
+ - `--train_head_only` – freeze encoder, train output head.
179
+ - `--finetune_last_n` – unfreeze last N transformer blocks.
180
+ - `--global_cls` – enable CLS token for global prediction heads.
181
+ - `--routing_topk_enable`, `--topk_per_head`, etc. – sparse attention controls.
182
+ - `--temporal_offsets` – defaults to `[-4, -3, -2, -1]` so the attention only reaches the previous four frames.
183
+ - `--use_wandb` (with `--wandb_project`, `--wandb_run_name`, `--wandb_entity`) – stream training/eval metrics to Weights & Biases.
184
+
185
+ Checkpoints are saved as `save_dir/<prefix>_epochXX.pth` together with optimizer state. Use `ChannelPredictionTrainer._save_checkpoint` for custom logic.
186
+
187
+ ---
188
+
189
+ ## 6. Python API Usage
190
+
191
+ Construct arguments programmatically and drive training/evaluation via `ChannelPredictionTrainer`:
192
+
193
+ ```python
194
+ from pathlib import Path
195
+ from LWMTemporal.tasks.channel_prediction import (
196
+ ChannelPredictionArgs,
197
+ DatasetArgs,
198
+ ModelArgs,
199
+ TrainingArgs,
200
+ PredictionArgs,
201
+ ChannelPredictionTrainer,
202
+ )
203
+
204
+ args = ChannelPredictionArgs(
205
+ dataset=DatasetArgs(
206
+ data_path=Path("examples/data/parow.p"),
207
+ keep_percentage=0.25,
208
+ train_limit=500,
209
+ val_limit=1000,
210
+ ),
211
+ model=ModelArgs(
212
+ patch_size=(1, 1),
213
+ phase_mode="real_imag",
214
+ pretrained=Path("checkpoints/m18_cp.pth"),
215
+ ),
216
+ training=TrainingArgs(
217
+ inference_only=True,
218
+ device="cpu",
219
+ batch_size=4,
220
+ ),
221
+ prediction=PredictionArgs(Tpast=10, horizon=1),
222
+ )
223
+
224
+ trainer = ChannelPredictionTrainer(args)
225
+ trainer.train() # runs evaluate() because inference_only=True
226
+ ```
227
+
228
+ From here you can:
229
+
230
+ - Access `trainer.model` (an `LWMBackbone`) for custom forward passes.
231
+ - Call `trainer.data.train_loader(...)` / `val_loader(...)` for raw dataloaders.
232
+ - Use `trainer.engine.autoregressive_rollout(...)` to obtain `(pred_tokens, target_tokens, mask)` tensors for downstream metrics.
233
+ - Generate visualizations with `trainer.viz.save(...)`.
234
+
235
+ ---
236
+
237
+ ## 7. Working With `LWMBackbone`
238
+
239
+ - Instantiate from scratch: `LWMBackbone(LWMConfig(...))`.
240
+ - Load checkpoints:
241
+ ```python
242
+ from LWMTemporal.models.lwm import LWMBackbone, LWMConfig
243
+
244
+ cfg = LWMConfig(patch_size=(1, 1), embed_dim=32, max_seq_len=2816)
245
+ model = LWMBackbone.from_pretrained("checkpoints/m18_cp.pth", config=cfg)
246
+ ```
247
+ - Save checkpoints: `model.save_pretrained("path/to/output")`.
248
+ - The loader automatically adjusts `config.max_seq_len` when the checkpoint’s positional embedding is longer than the provided config.
249
+
250
+ `LWMModel.forward(seq, mask=None, return_cls=False)` accepts complex tensors shaped `(B, T, N, M)` and returns reconstruction tokens along with an optional CLS embedding when enabled.
251
+
252
+ ---
253
+
254
+ ## 8. Visualization & Metrics
255
+
256
+ - `PredictionVisualizer` renders magnitude plots (`|H|`) for predicted vs. ground-truth angle-delay grids.
257
+ - Metrics:
258
+ - `masked_nmse_loss` / `compute_nmse` – Normalized MSE over valid tokens.
259
+ - `masked_mse_loss` – standard MSE with masking support.
260
+
261
+ Configure masking via token `mask` tensors (boolean) where `True` indicates dropped tokens.
262
+
263
+ ---
264
+
265
+ ## 9. Advanced Configuration
266
+
267
+ - **Sparse Attention Windows:** Control spatial/temporal neighborhoods via `same_frame_window`, `temporal_offsets`, `temporal_spatial_window`, `temporal_drift_*`, and dilation parameters. The default `temporal_offsets = (-4, -3, -2, -1)` limits attention to the previous four frames.
268
+ - **Routing & Top-k Pruning:** Enable dynamic neighbour pruning with `routing_topk_enable`, `routing_topk_fraction`, `routing_topk_min/max`, or fallback to static `topk_neighbors`.
269
+ - **Positional Encoding:** `posenc` supports `learned` or `rope_sincos`. Additional `rope_base_*` parameters adjust RoPE scaling.
270
+ - **CLS Token:** Toggle `global_cls`; the autoregressive rollout handles CLS automatically when present.
271
+ - **Detokenization:** `AutoregressiveEngine.detokenize` converts predicted tokens back to complex-valued channel coefficients.
272
+
273
+ ---
274
+
275
+ ## 10. Troubleshooting
276
+
277
+ - **Circular Imports:** The project avoids cross-imports by keeping tokenizers in `models`. Ensure you are on the latest code if you encounter import errors.
278
+ - **Checkpoint Shape Mismatch:** Confirm `patch_size`, `phase_mode`, and positional embedding lengths match between config and weights.
279
+ - **Neighbor Padding Errors:** Patched `NeighborIndexer` pads ragged neighbour lists with `-1`, so any older ValueError is resolved once you update to the current code.
280
+ - **AMP Warnings:** On CPU you may see `GradScaler` warnings; they are benign because AMP disables itself.
281
+ - **Data Shape Mismatch:** Sequences must have consistent `(T, H, W)` dimensions within each batch. The trainer raises a descriptive error otherwise.
282
+
283
+ ---
284
+
285
+ ## 11. Hugging Face Integration
286
+
287
+ - Because checkpoints follow the standard `config.json` + `pytorch_model.bin` scheme, you can do:
288
+ ```python
289
+ from transformers import AutoConfig, AutoModel
290
+
291
+ cfg = AutoConfig.from_pretrained("path/to/model_best")
292
+ model = AutoModel.from_pretrained("path/to/model_best", config=cfg)
293
+ ```
294
+ - Wrap `LWMBackbone` in a custom `transformers.PreTrainedModel` subclass if you need full pipeline compatibility.
295
+ - Use the same directory structure when publishing to the Hugging Face Hub.
296
+
297
+ ---
298
+
299
+ ## 12. Reproducibility Checklist
300
+
301
+ - Seed control: `DatasetArgs.seed` (for train/val splits); manual seeding via `torch.manual_seed` and `np.random.seed` happens inside the trainer.
302
+ - Log frequency: `TrainingArgs.log_interval`.
303
+ - Gradient clipping: `TrainingArgs.grad_clip` (defaults to 1.0).
304
+ - Warm-up / Scheduler: cosine decay after a configurable warm-up fraction (`TrainingArgs.warmup_ratio`).
305
+
306
+ ---
307
+
308
+ ## 13. Getting Help
309
+
310
+ - Issues related to data format, training instabilities, or new features can be logged on your preferred tracking system or discussed with collaborators.
311
+ - For general transformer best practices, refer to the Hugging Face BERT documentation and friends ([link](https://huggingface.co/docs/transformers/en/model_doc/bert?usage=Pipeline)). The workflow above mirrors that style for LWMTemporal.
312
+
313
+ Happy experimenting!
314
+
315
+ ## Citation
316
+
317
+ If you use LWMTemporal in your research, please cite:
318
+
319
+ ```bibtex
320
+ @article{lwmtemporal2025,
321
+ title={Large Wireless Model for Spatio-Temporal Channel Prediction},
322
+ author={Alikhani, Sadjad and others},
323
+ journal={arXiv preprint arXiv:XXXX.XXXXX},
324
+ year={2025}
325
+ }
326
+ ```
327
+
328
+ ## License
329
+
330
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
331
+
332
+ ## Acknowledgments
333
+
334
+ - Built with PyTorch
335
+ - Inspired by Vision Transformer architectures
336
+ - Supports Hugging Face model hub integration
337
+
338
+ ## Contact
339
+
340
+ For questions or issues, please:
341
+ - Open an issue on GitHub
342
+ - Contact: sadjad.alikhani@asu.edu
343
+
344
+ ## Contributing
345
+
346
+ Contributions are welcome! Please:
347
+ 1. Fork the repository
348
+ 2. Create a feature branch
349
+ 3. Make your changes
350
+ 4. Submit a pull request
351
+
352
+ For major changes, please open an issue first to discuss the proposed changes.
353
+
checkpoints/README.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Checkpoints
2
+
3
+ This directory contains pretrained model checkpoints.
4
+
5
+ ## Available Checkpoints
6
+
7
+ ### `m18_cp.pth`
8
+ - **Task**: Channel prediction (fine-tuned)
9
+ - **Architecture**: 12-layer transformer with 32-dim embeddings
10
+ - **Temporal Attention**: Causal (attends to past 7 frames)
11
+ - **Performance**: ~-20 dB NMSE on validation set
12
+
13
+ ### `pytorch_model.bin`
14
+ - **Task**: Pretrained backbone
15
+ - **Architecture**: Same as above
16
+ - **Temporal Attention**: Bidirectional
17
+
18
+ ## Loading Checkpoints
19
+
20
+ ### Python API
21
+
22
+ ```python
23
+ from pathlib import Path
24
+ from LWMTemporal import LWMBackbone, LWMConfig
25
+
26
+ # Load with default config from checkpoint
27
+ model = LWMBackbone.from_pretrained("checkpoints/m18_cp.pth")
28
+
29
+ # Or override config for fine-tuning
30
+ cfg = LWMConfig(
31
+ temporal_offsets=(-1, -2, -3, -4), # Override for different task
32
+ )
33
+ model = LWMBackbone.from_pretrained("checkpoints/m18_cp.pth", config=cfg)
34
+ ```
35
+
36
+ ### CLI
37
+
38
+ ```bash
39
+ python -m LWMTemporal.cli.channel_prediction \
40
+ --pretrained checkpoints/m18_cp.pth \
41
+ --data_path examples/data/parow.p \
42
+ --inference_only
43
+ ```
44
+
45
+ ## Hosting on Hugging Face Hub (Recommended)
46
+
47
+ For production use, upload checkpoints to Hugging Face Hub:
48
+
49
+ ```bash
50
+ huggingface-cli login
51
+ huggingface-cli upload your-username/lwm-temporal checkpoints/m18_cp.pth
52
+ ```
53
+
54
+ Then load directly from the hub:
55
+ ```python
56
+ model = LWMBackbone.from_pretrained("your-username/lwm-temporal")
57
+ ```
58
+
checkpoints/config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "patch_size": [1, 1],
3
+ "phase_mode": "real_imag",
4
+ "embed_dim": 32,
5
+ "depth": 12,
6
+ "num_heads": 8,
7
+ "mlp_ratio": 4.0,
8
+ "same_frame_window": 2,
9
+ "same_frame_window_h": null,
10
+ "same_frame_window_w": null,
11
+ "same_frame_dilation_h": 1,
12
+ "same_frame_dilation_w": 1,
13
+ "temporal_offsets": [-4, -3, -2, -1, 1, 2, 3],
14
+ "temporal_spatial_window": 2,
15
+ "temporal_spatial_window_h": null,
16
+ "temporal_spatial_window_w": null,
17
+ "temporal_spatial_dilation_h": 1,
18
+ "temporal_spatial_dilation_w": 1,
19
+ "temporal_drift_h": 1,
20
+ "temporal_drift_w": 1,
21
+ "spatial_only": false,
22
+ "routing_topk_enable": true,
23
+ "routing_topk_fraction": 0.2,
24
+ "routing_topk_min": 8,
25
+ "routing_topk_max": 32,
26
+ "routing_topk_per_head": true,
27
+ "topk_neighbors": null,
28
+ "topk_per_head": true,
29
+ "global_cls": false,
30
+ "posenc": "learned",
31
+ "rope_base": 10000.0,
32
+ "rope_mode": "flat",
33
+ "rope_base_t": null,
34
+ "rope_base_h": null,
35
+ "rope_base_w": null,
36
+ "max_seq_len": null
37
+ }
checkpoints/hist/config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "patch_size": [1, 1],
3
+ "phase_mode": "real_imag",
4
+ "embed_dim": 32,
5
+ "depth": 12,
6
+ "num_heads": 8,
7
+ "mlp_ratio": 4.0,
8
+ "same_frame_window": 2,
9
+ "same_frame_window_h": null,
10
+ "same_frame_window_w": null,
11
+ "same_frame_dilation_h": 1,
12
+ "same_frame_dilation_w": 1,
13
+ "temporal_offsets": [-4, -3, -2, -1, 1, 2, 3],
14
+ "temporal_spatial_window": 2,
15
+ "temporal_spatial_window_h": null,
16
+ "temporal_spatial_window_w": null,
17
+ "temporal_spatial_dilation_h": 1,
18
+ "temporal_spatial_dilation_w": 1,
19
+ "temporal_drift_h": 1,
20
+ "temporal_drift_w": 1,
21
+ "spatial_only": false,
22
+ "routing_topk_enable": true,
23
+ "routing_topk_fraction": 0.2,
24
+ "routing_topk_min": 8,
25
+ "routing_topk_max": 32,
26
+ "routing_topk_per_head": true,
27
+ "topk_neighbors": null,
28
+ "topk_per_head": true,
29
+ "global_cls": false,
30
+ "posenc": "learned",
31
+ "rope_base": 10000.0,
32
+ "rope_mode": "flat",
33
+ "rope_base_t": null,
34
+ "rope_base_h": null,
35
+ "rope_base_w": null,
36
+ "max_seq_len": null
37
+ }
examples/README.md ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LWMTemporal Examples
2
+
3
+ This directory contains example scripts demonstrating how to use the LWMTemporal package.
4
+
5
+ ## Quick Start Examples
6
+
7
+ ### 1. Masked Reconstruction (`example_reconstruction.py`)
8
+
9
+ Demonstrates how to:
10
+ - Load wireless channel data
11
+ - Tokenize complex channels
12
+ - Mask random positions
13
+ - Reconstruct using the pretrained model
14
+
15
+ ```bash
16
+ python examples/example_reconstruction.py
17
+ ```
18
+
19
+ ### 2. Channel Prediction Inference (`inference_channel_prediction.py`)
20
+
21
+ Run inference with a fine-tuned channel prediction model:
22
+
23
+ ```bash
24
+ python examples/inference_channel_prediction.py
25
+ ```
26
+
27
+ Expected output: Per-step NMSE around -20 dB
28
+
29
+ ### 3. Train Channel Prediction (`train_channel_prediction.py`)
30
+
31
+ Fine-tune the model for channel prediction:
32
+
33
+ ```bash
34
+ python examples/train_channel_prediction.py
35
+ ```
36
+
37
+ This will:
38
+ - Load pretrained weights
39
+ - Fine-tune on your dataset
40
+ - Save checkpoints to `models/`
41
+ - Generate visualizations in `figs/predictions/`
42
+
43
+ ## Using the CLI
44
+
45
+ The package also provides command-line interfaces:
46
+
47
+ ### Channel Prediction
48
+
49
+ ```bash
50
+ python -m LWMTemporal.cli.channel_prediction \
51
+ --data_path examples/data/city_8_tempe_3p5_20_32_32.p \
52
+ --pretrained checkpoints/m18_cp.pth \
53
+ --inference_only \
54
+ --val_limit 100 \
55
+ --device cpu
56
+ ```
57
+
58
+ ### Pretraining
59
+
60
+ ```bash
61
+ python -m LWMTemporal.cli.pretrain \
62
+ --data_dir examples/data/ \
63
+ --save_prefix models/pretrained \
64
+ --epochs 100 \
65
+ --batch_size 32 \
66
+ --device cuda
67
+ ```
68
+
69
+ ## Data Format
70
+
71
+ Example data files are in `examples/data/`. See `examples/data/README.md` for details on the expected format.
72
+
73
+ ## Checkpoints
74
+
75
+ Pretrained checkpoints are in `checkpoints/`. See `checkpoints/README.md` for available models and loading instructions.
76
+
examples/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """LWMTemporal usage examples."""
2
+
examples/example_reconstruction.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Masked-reconstruction example: hide random tokens and score NMSE."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
import numpy as np

from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
from LWMTemporal.models.lwm import (
    LWMBackbone,
    LWMConfig,
    ComplexPatchTokenizer,
    masked_nmse_loss,
)

# Step 1: load a single complex channel sequence and keep 11 frames.
data_cfg = AngleDelayDatasetConfig(raw_path=Path("examples/data/parow.p"))
dataset = AngleDelaySequenceDataset(data_cfg)

sequence = dataset[0]["sequence"].unsqueeze(0)[:, :11]  # (1, T, N, M), first 11 steps
print("Sequence shape:", sequence.shape)  # expect (1, 11, 32, 8)

# Step 2: tokenize the sequence and pick random positions to hide.
tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
tokens, base_mask = tokenizer(sequence, patch_size=(1, 1))  # tokens: (B, S, D)

batch, seq_len, _ = tokens.shape
mask_ratio = 0.60  # fraction of tokens to hide
mask = base_mask.clone()

for b in range(batch):
    hidden_idx = torch.randperm(seq_len)[: int(mask_ratio * seq_len)]
    mask[b, hidden_idx] = True

# The corrupted model input is the token tensor with hidden positions zeroed.
corrupted_tokens = tokens.clone()
corrupted_tokens[mask] = 0.0

# Step 3: load the pretrained backbone.
# max_seq_len must cover S = 11 * 32 * 8 = 2816 tokens.
cfg = LWMConfig(
    patch_size=(1, 1),
    phase_mode="real_imag",
    embed_dim=32,
    depth=12,
    num_heads=8,
    mlp_ratio=4.0,
    same_frame_window=2,
    temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
    temporal_spatial_window=2,
    temporal_drift_h=1,
    temporal_drift_w=1,
    routing_topk_enable=True,
    topk_per_head=True,
    max_seq_len=2816,
)

backbone = LWMBackbone.from_pretrained(Path("checkpoints/m18_cp.pth"), config=cfg)
backbone.eval()

# Step 4: reconstruct and compute NMSE on the hidden positions only.
with torch.no_grad():
    # The backbone needs the frame count and spatial grid of the sequence.
    n_frames, height, width = sequence.shape[1], sequence.shape[2], sequence.shape[3]

    outputs = backbone.forward_tokens(corrupted_tokens, mask, n_frames, height, width, return_cls=False)
    reconstructed = outputs["reconstruction"]

    nmse = masked_nmse_loss(reconstructed, tokens, mask)
    nmse_db = 10 * torch.log10(nmse)

print(f"Masked {mask_ratio*100:.1f}% of tokens ({mask.sum().item()} / {seq_len})")
print(f"NMSE (linear): {nmse.item():.6f}")
print(f"NMSE (dB): {nmse_db.item():.2f} dB")
examples/inference_channel_prediction.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
"""Example: Run inference with a trained channel prediction model."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from LWMTemporal.tasks.channel_prediction import (
    ChannelPredictionArgs,
    ChannelPredictionTrainer,
    DatasetArgs,
    ModelArgs,
    TrainingArgs,
    PredictionArgs,
)
from LWMTemporal.utils.logging import setup_logging


def main() -> None:
    """Build the inference configuration and run evaluation.

    Wrapped in a function with a ``__main__`` guard so importing this
    module (``examples`` is a package) does not trigger a full inference
    run as a side effect.
    """
    logger = setup_logging("channel_prediction_inference", log_dir=Path("logs"))

    # Dataset: small validation slice so the example runs quickly on CPU.
    dataset_args = DatasetArgs(
        data_path=Path("examples/data/city_8_tempe_3p5_20_32_32.p"),
        keep_percentage=0.25,
        normalize="global_rms",
        seed=0,
        val_limit=100,
    )

    # Model: architecture must match the checkpoint being loaded.
    model_args = ModelArgs(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        temporal_offsets=(-1, -2, -3, -4, -5, -6, -7),  # Causal attention
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        routing_topk_fraction=0.2,
        routing_topk_max=32,
        pretrained=Path("checkpoints/m18_cp.pth"),
    )

    # Training config in inference-only mode: no optimizer steps are taken.
    training_args = TrainingArgs(
        device="cpu",
        batch_size=2,
        inference_only=True,
        inference_split="val",
    )

    # Autoregressive roll-out window: 10 past frames, predict 1 ahead.
    prediction_args = PredictionArgs(
        Tpast=10,
        horizon=1,
    )

    args = ChannelPredictionArgs(
        dataset=dataset_args,
        model=model_args,
        training=training_args,
        prediction=prediction_args,
    )

    trainer = ChannelPredictionTrainer(args, logger=logger)
    trainer.train()  # train() handles inference_only mode

    logger.info("Inference complete!")


if __name__ == "__main__":
    main()
76
+
examples/train_channel_prediction.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
"""Example: Train a channel prediction model."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

import torch

from LWMTemporal.tasks.channel_prediction import (
    ChannelPredictionArgs,
    ChannelPredictionTrainer,
    DatasetArgs,
    ModelArgs,
    TrainingArgs,
    PredictionArgs,
)
from LWMTemporal.utils.logging import setup_logging


def main() -> None:
    """Build the training configuration and launch fine-tuning.

    Wrapped in a function with a ``__main__`` guard so importing this
    module (``examples`` is a package) does not kick off a training run
    as a side effect.
    """
    logger = setup_logging("channel_prediction_example", log_dir=Path("logs"))

    # Dataset: capped train/val sizes keep the example run short.
    dataset_args = DatasetArgs(
        data_path=Path("examples/data/city_8_tempe_3p5_20_32_32.p"),
        keep_percentage=0.25,
        normalize="global_rms",
        seed=0,
        train_limit=500,
        val_limit=100,
    )

    # Model: architecture must match the pretrained checkpoint below.
    model_args = ModelArgs(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        temporal_offsets=(-1, -2, -3, -4, -5, -6, -7),  # Causal attention
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        routing_topk_fraction=0.2,
        routing_topk_max=32,
        pretrained=Path("checkpoints/m18_cp.pth"),  # Load pretrained weights
    )

    # Training: pick GPU when available, otherwise fall back to CPU.
    training_args = TrainingArgs(
        device="cuda" if torch.cuda.is_available() else "cpu",
        epochs=10,
        batch_size=16,
        lr=1e-4,
        weight_decay=1e-4,
        warmup_ratio=0.1,
        save_dir=Path("models"),
        use_wandb=False,  # Set to True to enable Weights & Biases logging
    )

    # Prediction: 10 past frames, one-step horizon, plots under figs/.
    prediction_args = PredictionArgs(
        Tpast=10,
        horizon=1,
        viz_dir=Path("figs/predictions"),
    )

    args = ChannelPredictionArgs(
        dataset=dataset_args,
        model=model_args,
        training=training_args,
        prediction=prediction_args,
    )

    trainer = ChannelPredictionTrainer(args, logger=logger)
    trainer.train()

    logger.info("Training complete!")


if __name__ == "__main__":
    main()
83
+
pyproject.toml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "lwm-temporal"
7
+ version = "0.1.0"
8
+ description = "Large Wireless Model (LWM) for spatio-temporal wireless channel representation learning"
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ license = {text = "MIT"}
12
+ authors = [
13
+ {name = "Sadjad Alikhani", email = "alikhani@asu.edu"}
14
+ ]
15
+ keywords = ["wireless", "sparse-spatiotemporal-attention", "transformer", "deep-learning", "pytorch"]
16
+ classifiers = [
17
+ "Development Status :: 4 - Beta",
18
+ "Intended Audience :: Science/Research",
19
+ "License :: OSI Approved :: MIT License",
20
+ "Programming Language :: Python :: 3",
21
+ "Programming Language :: Python :: 3.9",
22
+ "Programming Language :: Python :: 3.10",
23
+ "Programming Language :: Python :: 3.11",
24
+ "Programming Language :: Python :: 3.12",
25
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
26
+ ]
27
+
28
+ dependencies = [
29
+ "torch>=2.0.0",
30
+ "numpy>=1.21.0",
31
+ "matplotlib>=3.5.0",
32
+ ]
33
+
34
+ [project.optional-dependencies]
35
+ dev = [
36
+ "pytest>=7.0",
37
+ "black>=22.0",
38
+ "flake8>=4.0",
39
+ "mypy>=0.950",
40
+ ]
41
+ wandb = [
42
+ "wandb>=0.13.0",
43
+ ]
44
+
45
+ [project.urls]
46
+ Homepage = "https://github.com/yourusername/lwm-temporal"
47
+ Repository = "https://github.com/yourusername/lwm-temporal"
48
+ Documentation = "https://github.com/yourusername/lwm-temporal#readme"
49
+
50
+ [project.scripts]
51
+ lwm-pretrain = "LWMTemporal.cli.pretrain:main"
52
+ lwm-channel-prediction = "LWMTemporal.cli.channel_prediction:main"
53
+
54
+ [tool.setuptools.packages.find]
55
+ include = ["LWMTemporal*"]
56
+ exclude = ["tests*", "examples*", "checkpoints*", "cache*", "logs*", "figs*", "wandb*"]
57
+
58
+ [tool.setuptools.package-data]
59
+ LWMTemporal = ["models/config.json"]
60
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0.0
2
+ numpy>=1.21.0
3
+ matplotlib>=3.5.0
4
+
setup.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Setup script for LWMTemporal package.
This is kept for backward compatibility; the package primarily uses pyproject.toml.
"""

from setuptools import setup, find_packages
from pathlib import Path

# Long description comes from the README; fall back to an empty string so
# `python setup.py ...` still works in a checkout without README.md.
this_directory = Path(__file__).parent
_readme = this_directory / "README.md"
long_description = _readme.read_text(encoding="utf-8") if _readme.exists() else ""

setup(
    name="lwm-temporal",
    version="0.1.0",
    author="Sadjad Alikhani",
    author_email="alikhani@asu.edu",
    # Kept in sync with pyproject.toml's [project].description.
    description="Large Wireless Model (LWM) for spatio-temporal wireless channel representation learning",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/yourusername/lwm-temporal",
    packages=find_packages(include=["LWMTemporal", "LWMTemporal.*"]),
    package_data={
        "LWMTemporal": ["models/config.json"],
    },
    install_requires=[
        "torch>=2.0.0",
        "numpy>=1.21.0",
        "matplotlib>=3.5.0",
    ],
    extras_require={
        "dev": [
            "pytest>=7.0",
            "black>=22.0",
            "flake8>=4.0",
            "mypy>=0.950",
        ],
        "wandb": ["wandb>=0.13.0"],
    },
    entry_points={
        "console_scripts": [
            "lwm-pretrain=LWMTemporal.cli.pretrain:main",
            "lwm-channel-prediction=LWMTemporal.cli.channel_prediction:main",
        ],
    },
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    python_requires=">=3.9",
    license="MIT",
    # Kept in sync with pyproject.toml's [project].keywords.
    keywords="wireless sparse-spatiotemporal-attention transformer deep-learning pytorch",
)
61
+
test_package.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Test script to verify the LWMTemporal package is properly structured and functional.
4
+ Run this before releasing to ensure everything works.
5
+ """
6
+
7
+ import sys
8
+ from pathlib import Path
9
+
10
def test_imports():
    """Verify that the package's core public API is importable."""
    print("Testing imports...")
    try:
        from LWMTemporal import LWMBackbone, LWMConfig, LWMModel, __version__
        from LWMTemporal.data import AngleDelaySequenceDataset, AngleDelayDatasetConfig
        from LWMTemporal.tasks.channel_prediction import ChannelPredictionTrainer
        from LWMTemporal.tasks.pretraining import PretrainingTrainer
    except ImportError as e:
        print(f" βœ— Import failed: {e}")
        return False
    print(f" βœ“ All imports successful (version {__version__})")
    return True
23
+
24
def test_file_structure():
    """Check that every release-critical file is present on disk."""
    print("\nTesting file structure...")
    required_files = [
        "README.md",
        "LICENSE",
        "pyproject.toml",
        "setup.py",
        "requirements.txt",
        "MANIFEST.in",
        ".gitignore",
        "CHANGELOG.md",
        "LWMTemporal/__init__.py",
        "LWMTemporal/models/lwm.py",
        "LWMTemporal/models/config.json",
        "examples/README.md",
        "checkpoints/README.md",
    ]

    # Collect what is missing instead of flipping a boolean flag.
    missing = []
    for file in required_files:
        if Path(file).exists():
            print(f" βœ“ {file}")
        else:
            print(f" βœ— {file} NOT FOUND")
            missing.append(file)

    return not missing
53
+
54
def test_checkpoints():
    """Check that at least one checkpoint (*.pth / *.bin) exists under checkpoints/.

    Returns:
        bool: True if the directory exists and contains checkpoint files.
    """
    print("\nTesting checkpoints...")
    checkpoint_dir = Path("checkpoints")

    if not checkpoint_dir.exists():
        # Plain string: the message has no placeholders (fixes F541 f-strings).
        print(" βœ— Checkpoints directory not found")
        return False

    checkpoints = list(checkpoint_dir.glob("*.pth")) + list(checkpoint_dir.glob("*.bin"))
    if checkpoints:
        print(f" βœ“ Found {len(checkpoints)} checkpoint(s)")
        for ckpt in checkpoints:
            print(f" - {ckpt.name}")
        return True
    print(" βœ— No checkpoint files found")
    return False
72
+
73
def test_examples():
    """Check that the examples/ directory contains at least one Python script.

    Returns:
        bool: True if examples/ exists and holds *.py files.
    """
    print("\nTesting examples...")
    examples_dir = Path("examples")

    if not examples_dir.exists():
        # Plain string: the message has no placeholders (fixes F541 f-strings).
        print(" βœ— Examples directory not found")
        return False

    py_files = list(examples_dir.glob("*.py"))
    if not py_files:
        print(" βœ— No example scripts found")
        return False
    print(f" βœ“ Found {len(py_files)} example script(s)")
    for script in py_files:
        print(f" - {script.name}")
    return True
91
+
92
def test_data():
    """Check for optional example data files (*.p) under examples/data.

    Returns:
        bool: False only when the data directory itself is missing; an empty
        directory still passes because sample data is optional.
    """
    print("\nTesting example data...")
    data_dir = Path("examples/data")

    if not data_dir.exists():
        # Plain string: the message has no placeholders (fixes F541 f-strings).
        print(" βœ— Example data directory not found")
        return False

    data_files = list(data_dir.glob("*.p"))
    if not data_files:
        print(" ⚠ No example data files found (optional)")
        return True  # Not critical
    print(f" βœ“ Found {len(data_files)} data file(s)")
    for data_file in data_files:
        size_mb = data_file.stat().st_size / (1024 * 1024)
        print(f" - {data_file.name} ({size_mb:.1f} MB)")
    return True
111
+
112
def test_no_data_in_package():
    """Ensure no data or checkpoint artifacts were committed inside LWMTemporal/.

    Returns:
        bool: True when the package tree contains no stray *.p / *.pth / *.bin files.
    """
    print("\nTesting package cleanliness...")
    package_dir = Path("LWMTemporal")

    # Path.rglob on a missing directory simply yields nothing, so an absent
    # package dir reports as "clean" rather than raising.
    data_files = list(package_dir.rglob("*.p"))
    weight_files = list(package_dir.rglob("*.pth")) + list(package_dir.rglob("*.bin"))

    issues = []
    if data_files:
        issues.append(f"Found {len(data_files)} .p files in package (should be in examples/)")
    # Files whose path contains 'hist' are tolerated (presumably training-history
    # dumps) -- NOTE(review): confirm this exemption is still intended.
    checkpoint_files = [f for f in weight_files if 'hist' not in str(f)]
    if checkpoint_files:
        issues.append("Found checkpoint files in package (should be in checkpoints/)")

    if not issues:
        print(" βœ“ Package directory is clean")
        return True
    for issue in issues:
        print(f" ⚠ {issue}")
    return False
136
+
137
def main():
    """Run every package-structure check and report a pass/fail summary."""
    banner = "=" * 60
    print(banner)
    print("LWMTemporal Package Structure Test")
    print(banner)

    # Each entry pairs a display name with its check's outcome; the calls run
    # in the same order as the original sequential appends.
    results = [
        ("Imports", test_imports()),
        ("File Structure", test_file_structure()),
        ("Checkpoints", test_checkpoints()),
        ("Examples", test_examples()),
        ("Example Data", test_data()),
        ("Package Cleanliness", test_no_data_in_package()),
    ]

    print("\n" + banner)
    print("SUMMARY")
    print(banner)

    passed = sum(ok for _, ok in results)
    total = len(results)

    for name, ok in results:
        status = "βœ“ PASS" if ok else "βœ— FAIL"
        print(f"{status:8} | {name}")

    print(banner)
    print(f"Result: {passed}/{total} tests passed")

    if passed == total:
        print("\nπŸŽ‰ Package is ready for release!")
        return 0
    print("\n⚠️ Some tests failed. Please review and fix.")
    return 1

if __name__ == "__main__":
    sys.exit(main())
174
+