Sadjad Alikhani commited on
Commit Β·
164610c
0
Parent(s):
Initial commit
Browse files- .cursorignore +17 -0
- .gitattributes +3 -0
- .gitignore +76 -0
- LICENSE +22 -0
- LWMTemporal/__init__.py +10 -0
- LWMTemporal/cli/__init__.py +1 -0
- LWMTemporal/cli/channel_prediction.py +175 -0
- LWMTemporal/cli/pretrain.py +25 -0
- LWMTemporal/models/__init__.py +3 -0
- LWMTemporal/models/config.json +37 -0
- LWMTemporal/models/lwm.py +576 -0
- LWMTemporal/tasks/channel_prediction.py +641 -0
- LWMTemporal/tasks/pretraining.py +684 -0
- LWMTemporal/training/__init__.py +1 -0
- LWMTemporal/utils/logging.py +39 -0
- MANIFEST.in +14 -0
- README.md +353 -0
- checkpoints/README.md +58 -0
- checkpoints/config.json +37 -0
- checkpoints/hist/config.json +37 -0
- examples/README.md +76 -0
- examples/__init__.py +2 -0
- examples/example_reconstruction.py +138 -0
- examples/inference_channel_prediction.py +76 -0
- examples/train_channel_prediction.py +83 -0
- pyproject.toml +60 -0
- requirements.txt +4 -0
- setup.py +61 -0
- test_package.py +174 -0
.cursorignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python cache
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
|
| 5 |
+
# Experiment artifacts
|
| 6 |
+
cache/
|
| 7 |
+
logs/
|
| 8 |
+
wandb/
|
| 9 |
+
figs/
|
| 10 |
+
checkpoints/*.pth
|
| 11 |
+
checkpoints/*.bin
|
| 12 |
+
|
| 13 |
+
# Data files
|
| 14 |
+
examples/data/*.p
|
| 15 |
+
*.pkl
|
| 16 |
+
*.pickle
|
| 17 |
+
|
.gitattributes
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auto detect text files and perform LF normalization
|
| 2 |
+
* text=auto
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
MANIFEST
|
| 23 |
+
|
| 24 |
+
# PyTorch
|
| 25 |
+
*.pth
|
| 26 |
+
*.pt
|
| 27 |
+
*.bin
|
| 28 |
+
*.ckpt
|
| 29 |
+
!checkpoints/**/*.pth
|
| 30 |
+
!checkpoints/**/*.bin
|
| 31 |
+
!checkpoints/**/*.json
|
| 32 |
+
!LWMTemporal/models/config.json
|
| 33 |
+
|
| 34 |
+
# Data
|
| 35 |
+
*.p
|
| 36 |
+
*.pkl
|
| 37 |
+
*.pickle
|
| 38 |
+
*.h5
|
| 39 |
+
*.hdf5
|
| 40 |
+
cache/
|
| 41 |
+
data/
|
| 42 |
+
!examples/data/
|
| 43 |
+
!examples/data/*.p
|
| 44 |
+
!examples/data/README.md
|
| 45 |
+
|
| 46 |
+
# Experiments
|
| 47 |
+
logs/
|
| 48 |
+
figs/
|
| 49 |
+
wandb/
|
| 50 |
+
outputs/
|
| 51 |
+
# checkpoints/
|
| 52 |
+
runs/
|
| 53 |
+
|
| 54 |
+
# IDE
|
| 55 |
+
.vscode/
|
| 56 |
+
.idea/
|
| 57 |
+
*.swp
|
| 58 |
+
*.swo
|
| 59 |
+
*~
|
| 60 |
+
.DS_Store
|
| 61 |
+
|
| 62 |
+
# Testing
|
| 63 |
+
.pytest_cache/
|
| 64 |
+
.coverage
|
| 65 |
+
htmlcov/
|
| 66 |
+
.tox/
|
| 67 |
+
|
| 68 |
+
# Environment
|
| 69 |
+
.env
|
| 70 |
+
.venv
|
| 71 |
+
env/
|
| 72 |
+
venv/
|
| 73 |
+
ENV/
|
| 74 |
+
env.bak/
|
| 75 |
+
venv.bak/
|
| 76 |
+
|
LICENSE
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Sadjad Alikhani
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
| 22 |
+
|
LWMTemporal/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LWM Temporal Model package."""
|
| 2 |
+
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
warnings.filterwarnings("ignore")
|
| 6 |
+
|
| 7 |
+
from .models.lwm import LWMConfig, LWMModel, LWMBackbone
|
| 8 |
+
|
| 9 |
+
__version__ = "0.1.0"
|
| 10 |
+
__all__ = ["LWMConfig", "LWMModel", "LWMBackbone", "__version__"]
|
LWMTemporal/cli/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Command line entrypoints for the LWM foundation package."""
|
LWMTemporal/cli/channel_prediction.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional, Sequence
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from ..tasks.channel_prediction import (
|
| 9 |
+
ChannelPredictionArgs,
|
| 10 |
+
ChannelPredictionTrainer,
|
| 11 |
+
DatasetArgs,
|
| 12 |
+
ModelArgs,
|
| 13 |
+
PredictionArgs,
|
| 14 |
+
TrainingArgs,
|
| 15 |
+
)
|
| 16 |
+
from ..utils.logging import setup_logging
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def parse_args(argv: Optional[Sequence[str]] = None) -> ChannelPredictionArgs:
|
| 20 |
+
parser = argparse.ArgumentParser(description="Channel prediction trainer")
|
| 21 |
+
parser.add_argument("--data_path", type=Path, required=True)
|
| 22 |
+
parser.add_argument("--keep_percentage", type=float, default=0.25)
|
| 23 |
+
parser.add_argument("--normalize", type=str, default="global_rms", choices=["global_rms", "per_sample_rms", "none"])
|
| 24 |
+
parser.add_argument("--cache_dir", type=Path, default=Path("cache"))
|
| 25 |
+
parser.add_argument("--no_cache", action="store_true")
|
| 26 |
+
parser.add_argument("--overwrite_cache", action="store_true")
|
| 27 |
+
parser.add_argument("--snr_db", type=float, default=None)
|
| 28 |
+
parser.add_argument("--noise_seed", type=int, default=None)
|
| 29 |
+
parser.add_argument("--max_time_steps", type=int, default=None)
|
| 30 |
+
parser.add_argument("--train_limit", type=int, default=500)
|
| 31 |
+
parser.add_argument("--val_limit", type=int, default=100)
|
| 32 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 33 |
+
|
| 34 |
+
parser.add_argument("--patch_size", type=int, nargs=2, default=(1, 1))
|
| 35 |
+
parser.add_argument("--phase_mode", type=str, default="real_imag", choices=["real_imag", "mag_phase"])
|
| 36 |
+
parser.add_argument("--embed_dim", type=int, default=32)
|
| 37 |
+
parser.add_argument("--depth", type=int, default=12)
|
| 38 |
+
parser.add_argument("--num_heads", type=int, default=8)
|
| 39 |
+
parser.add_argument("--mlp_ratio", type=float, default=4.0)
|
| 40 |
+
parser.add_argument("--same_frame_window", type=int, default=2)
|
| 41 |
+
parser.add_argument("--temporal_offsets", type=int, nargs="*", default=[-1, -2, -3, -4, -5, -6, -7])
|
| 42 |
+
parser.add_argument("--temporal_spatial_window", type=int, default=2)
|
| 43 |
+
parser.add_argument("--temporal_drift_h", type=int, default=1)
|
| 44 |
+
parser.add_argument("--temporal_drift_w", type=int, default=1)
|
| 45 |
+
parser.add_argument("--routing_topk_enable", action="store_true", default=True)
|
| 46 |
+
parser.add_argument("--routing_topk_fraction", type=float, default=0.2)
|
| 47 |
+
parser.add_argument("--routing_topk_min", type=int, default=8)
|
| 48 |
+
parser.add_argument("--routing_topk_max", type=int, default=32)
|
| 49 |
+
parser.add_argument("--topk_per_head", action="store_true", default=True)
|
| 50 |
+
parser.add_argument("--posenc", type=str, default="learned", choices=["learned", "rope_sincos"])
|
| 51 |
+
parser.add_argument("--rope_base", type=float, default=10000.0)
|
| 52 |
+
parser.add_argument("--global_cls", action="store_true")
|
| 53 |
+
parser.add_argument("--pretrained", type=Path, default=None)
|
| 54 |
+
parser.add_argument("--finetune_last_n", type=int, default=0)
|
| 55 |
+
parser.add_argument("--train_head_only", action="store_true")
|
| 56 |
+
|
| 57 |
+
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
parser.add_argument("--epochs", type=int, default=3)
|
| 59 |
+
parser.add_argument("--batch_size", type=int, default=16)
|
| 60 |
+
parser.add_argument("--lr", type=float, default=1e-4)
|
| 61 |
+
parser.add_argument("--weight_decay", type=float, default=1e-4)
|
| 62 |
+
parser.add_argument("--warmup_ratio", type=float, default=0.1)
|
| 63 |
+
parser.add_argument("--loss", type=str, default="nmse", choices=["nmse", "mse"])
|
| 64 |
+
parser.add_argument("--use_dataparallel", action="store_true")
|
| 65 |
+
parser.add_argument("--grad_clip", type=float, default=1.0)
|
| 66 |
+
parser.add_argument("--log_interval", type=int, default=10)
|
| 67 |
+
parser.add_argument("--save_dir", type=Path, default=Path("models"))
|
| 68 |
+
parser.add_argument("--save_prefix", type=str, default="channel_prediction")
|
| 69 |
+
parser.add_argument("--inference_only", action="store_true")
|
| 70 |
+
parser.add_argument("--inference_split", type=str, default="val", choices=["train", "val", "all"])
|
| 71 |
+
parser.add_argument("--verbose_inference", action="store_true")
|
| 72 |
+
parser.add_argument("--log_dir", type=Path, default=Path("logs"))
|
| 73 |
+
parser.add_argument("--use_wandb", action="store_true")
|
| 74 |
+
parser.add_argument("--wandb_project", type=str, default=None)
|
| 75 |
+
parser.add_argument("--wandb_entity", type=str, default=None)
|
| 76 |
+
parser.add_argument("--wandb_run_name", type=str, default=None)
|
| 77 |
+
|
| 78 |
+
parser.add_argument("--Tpast", type=int, default=10)
|
| 79 |
+
parser.add_argument("--horizon", type=int, default=1)
|
| 80 |
+
parser.add_argument("--num_visual_samples", type=int, default=4)
|
| 81 |
+
parser.add_argument("--viz_dir", type=Path, default=Path("figs/predictions"))
|
| 82 |
+
|
| 83 |
+
ns = parser.parse_args(argv)
|
| 84 |
+
|
| 85 |
+
dataset = DatasetArgs(
|
| 86 |
+
data_path=ns.data_path,
|
| 87 |
+
keep_percentage=ns.keep_percentage,
|
| 88 |
+
normalize=ns.normalize,
|
| 89 |
+
cache_dir=ns.cache_dir,
|
| 90 |
+
use_cache=not ns.no_cache,
|
| 91 |
+
overwrite_cache=ns.overwrite_cache,
|
| 92 |
+
snr_db=ns.snr_db,
|
| 93 |
+
noise_seed=ns.noise_seed,
|
| 94 |
+
max_time_steps=ns.max_time_steps,
|
| 95 |
+
train_limit=ns.train_limit,
|
| 96 |
+
val_limit=ns.val_limit,
|
| 97 |
+
seed=ns.seed,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
model = ModelArgs(
|
| 101 |
+
patch_size=tuple(ns.patch_size),
|
| 102 |
+
phase_mode=ns.phase_mode,
|
| 103 |
+
embed_dim=ns.embed_dim,
|
| 104 |
+
depth=ns.depth,
|
| 105 |
+
num_heads=ns.num_heads,
|
| 106 |
+
mlp_ratio=ns.mlp_ratio,
|
| 107 |
+
same_frame_window=ns.same_frame_window,
|
| 108 |
+
temporal_offsets=tuple(ns.temporal_offsets),
|
| 109 |
+
temporal_spatial_window=ns.temporal_spatial_window,
|
| 110 |
+
temporal_drift_h=ns.temporal_drift_h,
|
| 111 |
+
temporal_drift_w=ns.temporal_drift_w,
|
| 112 |
+
routing_topk_enable=ns.routing_topk_enable,
|
| 113 |
+
routing_topk_fraction=ns.routing_topk_fraction,
|
| 114 |
+
routing_topk_min=ns.routing_topk_min,
|
| 115 |
+
routing_topk_max=ns.routing_topk_max,
|
| 116 |
+
topk_per_head=ns.topk_per_head,
|
| 117 |
+
posenc=ns.posenc,
|
| 118 |
+
rope_base=ns.rope_base,
|
| 119 |
+
global_cls=ns.global_cls,
|
| 120 |
+
pretrained=ns.pretrained,
|
| 121 |
+
finetune_last_n=ns.finetune_last_n,
|
| 122 |
+
train_head_only=ns.train_head_only,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
training = TrainingArgs(
|
| 126 |
+
device=ns.device,
|
| 127 |
+
epochs=ns.epochs,
|
| 128 |
+
batch_size=ns.batch_size,
|
| 129 |
+
lr=ns.lr,
|
| 130 |
+
weight_decay=ns.weight_decay,
|
| 131 |
+
warmup_ratio=ns.warmup_ratio,
|
| 132 |
+
loss=ns.loss,
|
| 133 |
+
use_dataparallel=ns.use_dataparallel,
|
| 134 |
+
grad_clip=ns.grad_clip,
|
| 135 |
+
log_interval=ns.log_interval,
|
| 136 |
+
save_dir=ns.save_dir,
|
| 137 |
+
save_prefix=ns.save_prefix,
|
| 138 |
+
inference_only=ns.inference_only,
|
| 139 |
+
inference_split=ns.inference_split,
|
| 140 |
+
verbose_inference=ns.verbose_inference,
|
| 141 |
+
log_dir=ns.log_dir,
|
| 142 |
+
use_wandb=ns.use_wandb,
|
| 143 |
+
wandb_project=ns.wandb_project,
|
| 144 |
+
wandb_entity=ns.wandb_entity,
|
| 145 |
+
wandb_run_name=ns.wandb_run_name,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
prediction = PredictionArgs(
|
| 149 |
+
Tpast=ns.Tpast,
|
| 150 |
+
horizon=ns.horizon,
|
| 151 |
+
num_visual_samples=ns.num_visual_samples,
|
| 152 |
+
viz_dir=ns.viz_dir,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
return ChannelPredictionArgs(dataset=dataset, model=model, training=training, prediction=prediction)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def main(argv: Optional[Sequence[str]] = None) -> None:
|
| 159 |
+
args = parse_args(argv)
|
| 160 |
+
logger = setup_logging("LWMTemporal.channel_prediction", args.training.log_dir)
|
| 161 |
+
logger.info(
|
| 162 |
+
"Starting channel prediction run | device=%s inference_only=%s use_wandb=%s",
|
| 163 |
+
args.training.device,
|
| 164 |
+
args.training.inference_only,
|
| 165 |
+
args.training.use_wandb,
|
| 166 |
+
)
|
| 167 |
+
trainer = ChannelPredictionTrainer(args, logger=logger)
|
| 168 |
+
trainer.train()
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
__all__ = ["parse_args", "main"]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
main()
|
LWMTemporal/cli/pretrain.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Sequence
|
| 4 |
+
|
| 5 |
+
from ..tasks.pretraining import build_parser, build_pretraining_args, PretrainingTrainer
|
| 6 |
+
from ..utils.logging import setup_logging
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def main(argv: Optional[Sequence[str]] = None) -> None:
|
| 10 |
+
parser = build_parser()
|
| 11 |
+
args_ns = parser.parse_args(args=list(argv) if argv is not None else None)
|
| 12 |
+
args = build_pretraining_args(args_ns)
|
| 13 |
+
logger = setup_logging("LWMTemporal.pretraining", args.logging.log_dir)
|
| 14 |
+
logger.info(
|
| 15 |
+
"Starting pretraining run | device=%s epochs=%d batch_size=%d use_wandb=%s",
|
| 16 |
+
args.optim.device,
|
| 17 |
+
args.optim.epochs,
|
| 18 |
+
args.optim.batch_size,
|
| 19 |
+
args.logging.use_wandb,
|
| 20 |
+
)
|
| 21 |
+
trainer = PretrainingTrainer(args, logger=logger)
|
| 22 |
+
trainer.train()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
__all__ = ["main"]
|
LWMTemporal/models/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .lwm import LWMConfig, LWMModel, LWMBackbone
|
| 2 |
+
|
| 3 |
+
__all__ = ["LWMConfig", "LWMModel", "LWMBackbone"]
|
LWMTemporal/models/config.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"patch_size": [1, 1],
|
| 3 |
+
"phase_mode": "real_imag",
|
| 4 |
+
"embed_dim": 32,
|
| 5 |
+
"depth": 12,
|
| 6 |
+
"num_heads": 8,
|
| 7 |
+
"mlp_ratio": 4.0,
|
| 8 |
+
"same_frame_window": 2,
|
| 9 |
+
"same_frame_window_h": null,
|
| 10 |
+
"same_frame_window_w": null,
|
| 11 |
+
"same_frame_dilation_h": 1,
|
| 12 |
+
"same_frame_dilation_w": 1,
|
| 13 |
+
"temporal_offsets": [-4, -3, -2, -1, 1, 2, 3],
|
| 14 |
+
"temporal_spatial_window": 2,
|
| 15 |
+
"temporal_spatial_window_h": null,
|
| 16 |
+
"temporal_spatial_window_w": null,
|
| 17 |
+
"temporal_spatial_dilation_h": 1,
|
| 18 |
+
"temporal_spatial_dilation_w": 1,
|
| 19 |
+
"temporal_drift_h": 1,
|
| 20 |
+
"temporal_drift_w": 1,
|
| 21 |
+
"spatial_only": false,
|
| 22 |
+
"routing_topk_enable": true,
|
| 23 |
+
"routing_topk_fraction": 0.2,
|
| 24 |
+
"routing_topk_min": 8,
|
| 25 |
+
"routing_topk_max": 32,
|
| 26 |
+
"routing_topk_per_head": true,
|
| 27 |
+
"topk_neighbors": null,
|
| 28 |
+
"topk_per_head": true,
|
| 29 |
+
"global_cls": false,
|
| 30 |
+
"posenc": "learned",
|
| 31 |
+
"rope_base": 10000.0,
|
| 32 |
+
"rope_mode": "flat",
|
| 33 |
+
"rope_base_t": null,
|
| 34 |
+
"rope_base_h": null,
|
| 35 |
+
"rope_base_w": null,
|
| 36 |
+
"max_seq_len": null
|
| 37 |
+
}
|
LWMTemporal/models/lwm.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
from dataclasses import dataclass, asdict, fields
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
from torch.utils.data import Dataset
|
| 15 |
+
|
| 16 |
+
# -----------------------------------------------------------------------------
|
| 17 |
+
# Tokenization
|
| 18 |
+
# -----------------------------------------------------------------------------
|
| 19 |
+
class ComplexPatchTokenizer:
|
| 20 |
+
def __init__(self, phase_mode: str = "real_imag") -> None:
|
| 21 |
+
if phase_mode not in {"real_imag", "mag_phase"}:
|
| 22 |
+
raise ValueError("phase_mode must be 'real_imag' or 'mag_phase'")
|
| 23 |
+
self.phase_mode = phase_mode
|
| 24 |
+
|
| 25 |
+
def _split_channels(self, tensor: Tensor) -> Tensor:
|
| 26 |
+
if self.phase_mode == "real_imag":
|
| 27 |
+
real = tensor.real.unsqueeze(-1)
|
| 28 |
+
imag = tensor.imag.unsqueeze(-1)
|
| 29 |
+
return torch.cat([real, imag], dim=-1)
|
| 30 |
+
magnitude = tensor.abs().unsqueeze(-1)
|
| 31 |
+
phase = torch.angle(tensor).unsqueeze(-1)
|
| 32 |
+
return torch.cat([magnitude, phase], dim=-1)
|
| 33 |
+
|
| 34 |
+
def __call__(self, seq: Tensor, patch_size: Tuple[int, int]) -> Tuple[Tensor, Tensor]:
|
| 35 |
+
if not torch.is_complex(seq):
|
| 36 |
+
raise TypeError("expected complex tensor shaped (B, T, N, M)")
|
| 37 |
+
ph, pw = patch_size
|
| 38 |
+
if seq.size(2) % ph != 0 or seq.size(3) % pw != 0:
|
| 39 |
+
raise ValueError("patch_size must evenly divide channel dimensions")
|
| 40 |
+
channels = self._split_channels(seq)
|
| 41 |
+
b, t, n, m, c = channels.shape
|
| 42 |
+
h = n // ph
|
| 43 |
+
w = m // pw
|
| 44 |
+
channels = channels.view(b, t, h, ph, w, pw, c)
|
| 45 |
+
channels = channels.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
|
| 46 |
+
tokens = channels.view(b, t * h * w, ph * pw * c)
|
| 47 |
+
mask = torch.zeros((b, tokens.size(1)), dtype=torch.bool, device=tokens.device)
|
| 48 |
+
return tokens, mask
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# -----------------------------------------------------------------------------
|
| 52 |
+
# Sparse spatio-temporal attention
|
| 53 |
+
# -----------------------------------------------------------------------------
|
| 54 |
+
@dataclass(frozen=True)
|
| 55 |
+
class AttentionCacheKey:
|
| 56 |
+
temporal: int
|
| 57 |
+
height: int
|
| 58 |
+
width: int
|
| 59 |
+
same_frame_window: int
|
| 60 |
+
same_frame_window_h: Optional[int]
|
| 61 |
+
same_frame_window_w: Optional[int]
|
| 62 |
+
same_frame_dilation_h: int
|
| 63 |
+
same_frame_dilation_w: int
|
| 64 |
+
temporal_offsets: Tuple[int, ...]
|
| 65 |
+
temporal_spatial_window: int
|
| 66 |
+
temporal_spatial_window_h: Optional[int]
|
| 67 |
+
temporal_spatial_window_w: Optional[int]
|
| 68 |
+
temporal_spatial_dilation_h: int
|
| 69 |
+
temporal_spatial_dilation_w: int
|
| 70 |
+
temporal_drift_h: int
|
| 71 |
+
temporal_drift_w: int
|
| 72 |
+
include_cls: bool
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class NeighborIndexer:
|
| 76 |
+
def __init__(self) -> None:
|
| 77 |
+
self._cache: Dict[Tuple[int, int, int, AttentionCacheKey], Tensor] = {}
|
| 78 |
+
|
| 79 |
+
def get(self, T: int, H: int, W: int, include_cls: bool, config: "LWMConfig", device: torch.device) -> Tensor:
|
| 80 |
+
key = (
|
| 81 |
+
T,
|
| 82 |
+
H,
|
| 83 |
+
W,
|
| 84 |
+
AttentionCacheKey(
|
| 85 |
+
temporal=T,
|
| 86 |
+
height=H,
|
| 87 |
+
width=W,
|
| 88 |
+
same_frame_window=config.same_frame_window,
|
| 89 |
+
same_frame_window_h=config.same_frame_window_h,
|
| 90 |
+
same_frame_window_w=config.same_frame_window_w,
|
| 91 |
+
same_frame_dilation_h=config.same_frame_dilation_h,
|
| 92 |
+
same_frame_dilation_w=config.same_frame_dilation_w,
|
| 93 |
+
temporal_offsets=config.temporal_offsets,
|
| 94 |
+
temporal_spatial_window=config.temporal_spatial_window,
|
| 95 |
+
temporal_spatial_window_h=config.temporal_spatial_window_h,
|
| 96 |
+
temporal_spatial_window_w=config.temporal_spatial_window_w,
|
| 97 |
+
temporal_spatial_dilation_h=config.temporal_spatial_dilation_h,
|
| 98 |
+
temporal_spatial_dilation_w=config.temporal_spatial_dilation_w,
|
| 99 |
+
temporal_drift_h=config.temporal_drift_h,
|
| 100 |
+
temporal_drift_w=config.temporal_drift_w,
|
| 101 |
+
include_cls=include_cls,
|
| 102 |
+
),
|
| 103 |
+
)
|
| 104 |
+
if key in self._cache:
|
| 105 |
+
tensor = self._cache[key]
|
| 106 |
+
return tensor if tensor.device == device else tensor.to(device)
|
| 107 |
+
indices = self._build_indices(T, H, W, include_cls, config)
|
| 108 |
+
if indices:
|
| 109 |
+
max_len = max(len(neighbors) for neighbors in indices)
|
| 110 |
+
if any(len(neighbors) != max_len for neighbors in indices):
|
| 111 |
+
padded = []
|
| 112 |
+
for neighbors in indices:
|
| 113 |
+
if len(neighbors) < max_len:
|
| 114 |
+
neighbors = neighbors + [-1] * (max_len - len(neighbors))
|
| 115 |
+
padded.append(neighbors)
|
| 116 |
+
indices = padded
|
| 117 |
+
tensor = torch.as_tensor(indices, dtype=torch.long, device=device)
|
| 118 |
+
self._cache[key] = tensor
|
| 119 |
+
return tensor
|
| 120 |
+
|
| 121 |
+
def _build_indices(self, T: int, H: int, W: int, include_cls: bool, config: "LWMConfig") -> List[List[int]]:
|
| 122 |
+
neighbors: List[List[int]] = []
|
| 123 |
+
same_h = config.same_frame_window if config.same_frame_window_h is None else config.same_frame_window_h
|
| 124 |
+
same_w = config.same_frame_window if config.same_frame_window_w is None else config.same_frame_window_w
|
| 125 |
+
|
| 126 |
+
def frame_base(frame: int) -> int:
|
| 127 |
+
return frame * H * W
|
| 128 |
+
|
| 129 |
+
for t_idx in range(T):
|
| 130 |
+
base = frame_base(t_idx)
|
| 131 |
+
for h_idx in range(H):
|
| 132 |
+
for w_idx in range(W):
|
| 133 |
+
current = base + h_idx * W + w_idx
|
| 134 |
+
local: List[int] = []
|
| 135 |
+
if config.same_frame_window < 0:
|
| 136 |
+
local.extend(range(base, base + H * W))
|
| 137 |
+
else:
|
| 138 |
+
for dh in range(-same_h, same_h + 1, config.same_frame_dilation_h):
|
| 139 |
+
for dw in range(-same_w, same_w + 1, config.same_frame_dilation_w):
|
| 140 |
+
nh = h_idx + dh
|
| 141 |
+
nw = w_idx + dw
|
| 142 |
+
if 0 <= nh < H and 0 <= nw < W:
|
| 143 |
+
local.append(base + nh * W + nw)
|
| 144 |
+
if not config.spatial_only:
|
| 145 |
+
for dt in config.temporal_offsets:
|
| 146 |
+
other_t = t_idx + dt
|
| 147 |
+
if other_t < 0 or other_t >= T:
|
| 148 |
+
continue
|
| 149 |
+
other_base = frame_base(other_t)
|
| 150 |
+
drift_h = config.temporal_spatial_window if config.temporal_drift_h == 0 else min(config.temporal_spatial_window, abs(dt) * config.temporal_drift_h)
|
| 151 |
+
drift_w = config.temporal_spatial_window if config.temporal_drift_w == 0 else min(config.temporal_spatial_window, abs(dt) * config.temporal_drift_w)
|
| 152 |
+
window_h = config.temporal_spatial_window if config.temporal_spatial_window_h is None else config.temporal_spatial_window_h
|
| 153 |
+
window_w = config.temporal_spatial_window if config.temporal_spatial_window_w is None else config.temporal_spatial_window_w
|
| 154 |
+
for dh in range(-min(window_h, drift_h), min(window_h, drift_h) + 1, config.temporal_spatial_dilation_h):
|
| 155 |
+
for dw in range(-min(window_w, drift_w), min(window_w, drift_w) + 1, config.temporal_spatial_dilation_w):
|
| 156 |
+
nh = max(0, min(H - 1, h_idx + dh))
|
| 157 |
+
nw = max(0, min(W - 1, w_idx + dw))
|
| 158 |
+
local.append(other_base + nh * W + nw)
|
| 159 |
+
if include_cls:
|
| 160 |
+
local.append(T * H * W)
|
| 161 |
+
if not local:
|
| 162 |
+
local.append(current)
|
| 163 |
+
neighbors.append(sorted(set(local)))
|
| 164 |
+
if include_cls:
|
| 165 |
+
neighbors.append(list(range(T * H * W)))
|
| 166 |
+
return neighbors
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class SparseSpatioTemporalAttention(nn.Module):
    """Multi-head attention restricted to a per-token spatio-temporal
    neighborhood.

    Keys/values are gathered only at each query's neighbor indices (supplied
    by ``NeighborIndexer``), so cost scales with the neighborhood size rather
    than the full sequence length.  Optionally applies RoPE and a score-based
    top-k sparsification of the neighborhood.
    """

    def __init__(self, config: "LWMConfig", embed_dim: int, num_heads: int) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.indexer = NeighborIndexer()

    def _apply_rope(self, x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        """Rotate even/odd channel pairs by position-dependent angles (RoPE)."""
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rotated_first = x1 * cos - x2 * sin
        rotated_second = x1 * sin + x2 * cos
        return torch.stack([rotated_first, rotated_second], dim=-1).flatten(-2)

    def _rope_factors(self, S: int, device: torch.device) -> Tuple[Tensor, Tensor]:
        """Return broadcastable (cos, sin) tables of shape (1, 1, S, head_dim // 2)."""
        half = self.head_dim // 2
        inv_freq = 1.0 / (self.config.rope_base ** (torch.arange(0, half, dtype=torch.float32, device=device) / max(1, half)))
        positions = torch.arange(S, dtype=torch.float32, device=device)
        angles = positions[:, None] * inv_freq[None, :]
        return torch.cos(angles)[None, None, :, :], torch.sin(angles)[None, None, :, :]

    @staticmethod
    def _topk_mask(scores: Tensor, keep: int, per_head: bool) -> Tensor:
        """Boolean mask keeping the ``keep`` highest-scoring neighbors.

        When ``per_head`` is False the selection is shared across heads by
        ranking the head-averaged scores instead.
        """
        if per_head:
            _, idx = torch.topk(scores, keep, dim=-1)
            topk_mask = torch.zeros_like(scores, dtype=torch.bool)
            topk_mask.scatter_(-1, idx, True)
        else:
            avg_scores = scores.mean(dim=1, keepdim=True)
            _, idx = torch.topk(avg_scores, keep, dim=-1)
            topk_mask = torch.zeros_like(scores, dtype=torch.bool)
            topk_mask.scatter_(-1, idx.expand_as(scores), True)
        return topk_mask

    def forward(self, hidden_states: Tensor, T: int, H: int, W: int, include_cls: bool) -> Tensor:
        bsz, seq_len, _ = hidden_states.shape
        neighbors = self.indexer.get(T, H, W, include_cls, self.config, hidden_states.device)
        qkv = self.qkv(hidden_states)
        qkv = qkv.view(bsz, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.config.posenc == "rope_sincos":
            cos, sin = self._rope_factors(seq_len, hidden_states.device)
            q = self._apply_rope(q, cos, sin)
            k = self._apply_rope(k, cos, sin)

        # Negative neighbor indices mark padding: clamp them for the gather
        # and exclude them from the attention scores afterwards.
        gather_idx = neighbors.clamp_min(0)
        valid_mask = neighbors >= 0
        k = k[:, :, gather_idx, :]
        v = v[:, :, gather_idx, :]
        scores = torch.einsum("bhqd,bhqkd->bhqk", q, k) * self.scale
        scores = scores.masked_fill(~valid_mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        # Optional routing: keep only the strongest neighbors per query.
        # ``routing_topk_*`` is the primary (fractional) knob; ``topk_neighbors``
        # is the legacy fixed-count variant and only applies when routing is off.
        if self.config.routing_topk_enable:
            K = scores.size(-1)
            keep = min(self.config.routing_topk_max, max(self.config.routing_topk_min, int(self.config.routing_topk_fraction * K)))
            scores = scores.masked_fill(
                ~self._topk_mask(scores, keep, self.config.routing_topk_per_head), float("-inf")
            )
        elif self.config.topk_neighbors is not None:
            keep = min(self.config.topk_neighbors, scores.size(-1))
            scores = scores.masked_fill(
                ~self._topk_mask(scores, keep, self.config.topk_per_head), float("-inf")
            )

        attn = torch.softmax(scores, dim=-1)
        # Re-zero padded slots in case a fully-masked row produced NaNs.
        attn = attn.masked_fill(~valid_mask.unsqueeze(0).unsqueeze(0), 0.0)
        context = torch.einsum("bhqk,bhqkd->bhqd", attn, v)
        context = context.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_dim)
        return self.proj(context)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class LWMEncoderLayer(nn.Module):
    """Pre-norm transformer block: sparse spatio-temporal attention followed
    by a GELU MLP, each wrapped in a residual connection."""

    def __init__(self, config: "LWMConfig") -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.embed_dim)
        self.attn = SparseSpatioTemporalAttention(config, config.embed_dim, config.num_heads)
        self.norm2 = nn.LayerNorm(config.embed_dim)
        hidden_dim = int(config.embed_dim * config.mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(config.embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, config.embed_dim),
        )

    def forward(self, x: Tensor, T: int, H: int, W: int, include_cls: bool) -> Tensor:
        attn_out = self.attn(self.norm1(x), T, H, W, include_cls)
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class LWMEncoder(nn.Module):
    """Stack of ``LWMEncoderLayer`` blocks with a final LayerNorm."""

    def __init__(self, config: "LWMConfig") -> None:
        super().__init__()
        self.layers = nn.ModuleList(LWMEncoderLayer(config) for _ in range(config.depth))
        self.norm = nn.LayerNorm(config.embed_dim)

    def forward(self, x: Tensor, T: int, H: int, W: int, include_cls: bool) -> Tensor:
        for block in self.layers:
            x = block(x, T, H, W, include_cls)
        return self.norm(x)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# -----------------------------------------------------------------------------
|
| 281 |
+
# Hugging Face configuration and model definitions
|
| 282 |
+
# -----------------------------------------------------------------------------
|
| 283 |
+
@dataclass
class LWMConfig:
    """Hyper-parameters for the LWM backbone: patch/embedding geometry,
    encoder depth, spatio-temporal neighborhood shape, top-k routing, and
    positional-encoding settings."""

    patch_size: Tuple[int, int] = (1, 1)
    phase_mode: str = "real_imag"
    embed_dim: int = 32
    depth: int = 12
    num_heads: int = 8
    mlp_ratio: float = 4.0
    same_frame_window: int = 2
    same_frame_window_h: Optional[int] = None
    same_frame_window_w: Optional[int] = None
    same_frame_dilation_h: int = 1
    same_frame_dilation_w: int = 1
    temporal_offsets: Tuple[int, ...] = (-4, -3, -2, -1, 1, 2, 3)
    temporal_spatial_window: int = 2
    temporal_spatial_window_h: Optional[int] = None
    temporal_spatial_window_w: Optional[int] = None
    temporal_spatial_dilation_h: int = 1
    temporal_spatial_dilation_w: int = 1
    temporal_drift_h: int = 1
    temporal_drift_w: int = 1
    spatial_only: bool = False
    routing_topk_enable: bool = True
    routing_topk_fraction: float = 0.2
    routing_topk_min: int = 8
    routing_topk_max: int = 32
    routing_topk_per_head: bool = True
    topk_neighbors: Optional[int] = None
    topk_per_head: bool = True
    global_cls: bool = False
    posenc: str = "learned"
    rope_base: float = 10000.0
    rope_mode: str = "flat"
    rope_base_t: Optional[float] = None
    rope_base_h: Optional[float] = None
    rope_base_w: Optional[float] = None
    max_seq_len: Optional[int] = None

    def __post_init__(self) -> None:
        # Normalise sequence-like fields to int tuples so configs loaded
        # from JSON (lists) compare equal to programmatic ones.
        ph, pw = self.patch_size[0], self.patch_size[1]
        self.patch_size = (int(ph), int(pw))
        self.temporal_offsets = tuple(map(int, self.temporal_offsets))

    def to_dict(self) -> Dict[str, Any]:
        """Serialise all fields into a plain dict (JSON-friendly)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LWMConfig":
        """Rebuild a config from :meth:`to_dict` output."""
        return cls(**data)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class LWMModel(nn.Module):
    """Masked-reconstruction backbone for complex channel sequences.

    Pipeline: complex patch tokenization -> linear patch embedding
    (plus optional CLS token and positional encoding) -> sparse
    spatio-temporal encoder -> linear head projecting back to patch space.
    """

    def __init__(self, config: LWMConfig) -> None:
        super().__init__()
        self.config = config
        # Two scalar channels per patch element (real/imag or mag/phase).
        patch_dim = config.patch_size[0] * config.patch_size[1] * 2
        self.tokenizer = ComplexPatchTokenizer(config.phase_mode)
        self.patch_embed = nn.Linear(patch_dim, config.embed_dim)
        self.global_cls = config.global_cls
        pos_len = (config.max_seq_len or 0) + (1 if self.global_cls else 0)
        if pos_len == 0:
            # max_seq_len unknown at build time: keep a 1-slot placeholder so
            # the parameter/buffer always exists.
            pos_len = 1
        if config.posenc == "learned":
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_len, config.embed_dim))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            # Non-learned schemes (e.g. RoPE inside attention): register a
            # zero buffer so state_dict shapes stay comparable across variants.
            self.register_buffer("pos_embed", torch.zeros(1, pos_len, config.embed_dim), persistent=False)
        if self.global_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.encoder = LWMEncoder(config)
        self.head = nn.Linear(config.embed_dim, patch_dim)
        self._init_weights()

    def _init_weights(self) -> None:
        """Truncated-normal init for linear weights; ones/zeros for LayerNorm."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _add_positional(self, tokens: Tensor) -> Tensor:
        # Learned embeddings are added here; other schemes are applied inside
        # the attention layers, so tokens pass through unchanged.
        if self.config.posenc == "learned":
            return tokens + self.pos_embed[:, : tokens.size(1)]
        return tokens

    def forward_tokens(
        self,
        tokens: Tensor,
        mask: Tensor,
        T: int,
        H: int,
        W: int,
        *,
        return_cls: bool = False,
    ) -> Dict[str, Optional[Tensor]]:
        """Encode pre-tokenized input.

        ``tokens``: (B, T*H*W, patch_dim) patch tokens; ``mask``: (B, T*H*W)
        boolean, True at positions to hide from the encoder.  Returns a dict
        with per-token ``reconstruction`` and the ``cls`` embedding (None
        unless ``return_cls`` and ``global_cls`` are both set).
        """
        embeddings = self.patch_embed(tokens)
        include_cls = self.global_cls
        if include_cls:
            # The CLS token is appended at the END of the sequence and is
            # never masked.
            cls_tokens = self.cls_token.expand(embeddings.size(0), -1, -1)
            embeddings = torch.cat([embeddings, cls_tokens], dim=1)
            cls_mask = torch.zeros((embeddings.size(0), 1), dtype=torch.bool, device=embeddings.device)
            mask = torch.cat([mask, cls_mask], dim=1)
        # Add positional embeddings BEFORE masking (matching original implementation)
        embeddings = self._add_positional(embeddings)
        # Then mask embeddings (zeros out both token embedding AND positional embedding)
        embeddings = embeddings.masked_fill(mask.unsqueeze(-1), 0.0)
        encoded = self.encoder(embeddings, T, H, W, include_cls)
        if include_cls:
            reconstruction = self.head(encoded[:, :-1, :])
            cls = encoded[:, -1, :]
        else:
            reconstruction = self.head(encoded)
            cls = None
        return {"reconstruction": reconstruction, "cls": cls if return_cls else None}

    def forward(self, seq: Tensor, mask: Optional[Tensor] = None, *, return_cls: bool = False) -> Dict[str, Optional[Tensor]]:
        """Tokenize a complex sequence and encode it.

        ``seq`` is indexed as (B, T, height, width); height/width must be
        divisible by the patch size.  NOTE(review): an explicit ``mask``
        REPLACES the tokenizer's base mask rather than being OR-ed with it —
        confirm this is intended for callers that pass their own mask.
        """
        tokens, base_mask = self.tokenizer(seq, self.config.patch_size)
        total_mask = base_mask if mask is None else mask
        ph, pw = self.config.patch_size
        T = seq.size(1)
        H = seq.size(2) // ph
        W = seq.size(3) // pw
        return self.forward_tokens(tokens, total_mask, T, H, W, return_cls=return_cls)

    @torch.no_grad()
    def forward_features(self, seq: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Gradient-free helper returning (reconstruction, cls)."""
        outputs = self.forward(seq, return_cls=True)
        return outputs["reconstruction"], outputs["cls"]
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class LWMBackbone(LWMModel):
    """Minor alias kept for backwards compatibility with legacy scripts."""

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | Path,
        *model_args: Any,
        config: Optional[LWMConfig] = None,
        map_location: str | torch.device = "cpu",
        **kwargs: Any,
    ) -> "LWMBackbone":
        """Load weights from an HF-style directory or a raw checkpoint file.

        A directory must contain ``pytorch_model.bin`` and may contain
        ``config.json`` (merged with an explicitly provided ``config``; the
        explicit values win).  A file may be either a plain state_dict or a
        training checkpoint carrying ``model_state_dict`` and ``config``
        entries.  Weights are loaded non-strictly so extra/missing keys from
        legacy checkpoints are tolerated.
        """
        path = Path(pretrained_model_name_or_path)
        state: Dict[str, Tensor]
        checkpoint_config: Optional[Dict[str, Any]] = None

        if path.is_dir():
            directory = path
            state_path = directory / "pytorch_model.bin"
            if not state_path.exists():
                raise FileNotFoundError(f"Pretrained weights not found at {state_path}")
            raw = torch.load(state_path, map_location=map_location)
            if isinstance(raw, dict) and any(isinstance(v, torch.Tensor) for v in raw.values()):
                # Keep only tensor entries; drop any metadata the file carries.
                state = {k: v for k, v in raw.items() if isinstance(v, torch.Tensor)}
            else:
                raise ValueError(f"Unexpected checkpoint format at {state_path}")
            # Always try to load checkpoint config first, then merge with provided config
            checkpoint_config_dict = None
            config_path = directory / "config.json"
            if config_path.exists():
                with config_path.open("r") as handle:
                    checkpoint_config_dict = json.load(handle)
                checkpoint_config = LWMConfig.from_dict(checkpoint_config_dict)
                if config is None:
                    config = checkpoint_config
                else:
                    # Merge: use checkpoint config as base, override with provided config
                    checkpoint_dict = checkpoint_config.to_dict()
                    provided_dict = config.to_dict()
                    merged_dict = {**checkpoint_dict, **provided_dict}
                    config = LWMConfig.from_dict(merged_dict)
        else:
            if not path.exists():
                raise FileNotFoundError(f"Pretrained weights not found at {path}")
            raw = torch.load(path, map_location=map_location)
            if isinstance(raw, dict) and "model_state_dict" in raw:
                # Training checkpoint: extract weights and the embedded config.
                state = raw["model_state_dict"]
                checkpoint_config = raw.get("config")
            elif isinstance(raw, dict):
                state = {k: v for k, v in raw.items() if isinstance(v, torch.Tensor)}
            else:
                raise ValueError("Unsupported checkpoint format; expected a state_dict or training checkpoint.")

        if config is None and checkpoint_config is not None:
            config = cls._config_from_checkpoint(checkpoint_config)
        if config is None:
            config = LWMConfig()

        # Infer max_seq_len from the stored positional table when unset,
        # discounting the CLS slot appended at the end of the table.
        if config.max_seq_len is None and "pos_embed" in state:
            pos_len = int(state["pos_embed"].shape[1])
            cls_tokens = 1 if config.global_cls else 0
            inferred = max(0, pos_len - cls_tokens)
            if inferred > 0:
                config.max_seq_len = inferred
        remapped_state = cls._remap_state_dict(state)
        model = cls(config, *model_args, **kwargs)
        # Non-strict load: tolerate keys that legacy checkpoints lack or add.
        model.load_state_dict(remapped_state, strict=False)
        return model

    def save_pretrained(self, save_directory: str | Path, **kwargs: Any) -> None:
        """Write ``config.json`` and ``pytorch_model.bin`` into a directory
        (created if needed), mirroring the layout ``from_pretrained`` reads."""
        directory = Path(save_directory)
        directory.mkdir(parents=True, exist_ok=True)
        config_path = directory / "config.json"
        with config_path.open("w") as handle:
            json.dump(self.config.to_dict(), handle, indent=2)
        state_path = directory / "pytorch_model.bin"
        torch.save(self.state_dict(), state_path)

    @staticmethod
    def _config_from_checkpoint(data: Any) -> Optional[LWMConfig]:
        """Build an ``LWMConfig`` from a checkpoint's config payload.

        Accepts either a flat dict or one nested under a ``"model"`` key.
        Keys unknown to ``LWMConfig`` are silently dropped; returns None when
        nothing usable remains.
        """
        if not isinstance(data, dict):
            return None
        model_cfg = data.get("model", data)
        if not isinstance(model_cfg, dict):
            return None
        allowed = {field.name for field in fields(LWMConfig)}
        kwargs: Dict[str, Any] = {}
        for key, value in model_cfg.items():
            if key not in allowed:
                continue
            # JSON stores tuples as lists; restore the tuple-of-int fields.
            if key == "patch_size" and isinstance(value, (list, tuple)):
                value = tuple(int(v) for v in value)
            if key == "temporal_offsets" and isinstance(value, (list, tuple)):
                value = tuple(int(v) for v in value)
            kwargs[key] = value
        if not kwargs:
            return None
        return LWMConfig(**kwargs)

    @staticmethod
    def _remap_state_dict(state: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Translate legacy parameter names to the current module layout:
        ``embed.*`` -> ``patch_embed.*``, ``blocks.*`` -> ``encoder.layers.*``,
        ``norm.*`` -> ``encoder.norm.*``."""
        remapped: Dict[str, Tensor] = {}
        for key, value in state.items():
            new_key = key
            if key.startswith("embed."):
                new_key = key.replace("embed", "patch_embed", 1)
            elif key.startswith("blocks."):
                new_key = key.replace("blocks", "encoder.layers", 1)
            elif key.startswith("norm."):
                new_key = key.replace("norm", "encoder.norm", 1)
            remapped[new_key] = value
        return remapped
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def compute_nmse(pred: Tensor, target: Tensor, mask: Tensor) -> float:
    """Per-sample NMSE averaged over the batch.

    For each sample b with at least one masked entry:
    ``nmse_b = sum((pred-target)^2 [mask]) / sum(target^2 [mask])``.
    Samples whose mask is empty are skipped; returns NaN when every sample
    is skipped.
    """
    per_sample = []
    for sample_pred, sample_tgt, sample_mask in zip(pred, target, mask):
        if not bool(sample_mask.any()):
            continue
        err = (sample_pred[sample_mask] - sample_tgt[sample_mask]).pow(2).sum()
        ref = sample_tgt[sample_mask].pow(2).sum().clamp_min(1e-12)
        per_sample.append((err / ref).item())
    return sum(per_sample) / len(per_sample) if per_sample else float('nan')
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def masked_nmse_loss(pred: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """Differentiable NMSE over masked positions.

    ``pred``/``target`` are (B, S, D); ``mask`` is (B, S) boolean.  Each
    sample's squared error and reference power are summed over its masked
    tokens, and the ratio is averaged over samples that have at least one
    masked token (all samples if none do).
    """
    sq_err = (pred - target).abs() ** 2
    sq_pow = target.abs() ** 2
    weights = mask.float()
    err_per_sample = (sq_err.sum(-1) * weights).sum(-1)
    pow_per_sample = (sq_pow.sum(-1) * weights).sum(-1).clamp_min(1e-12)
    ratio = err_per_sample / pow_per_sample
    has_tokens = mask.sum(-1) > 0
    if has_tokens.any():
        return ratio[has_tokens].mean()
    return ratio.mean()
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def masked_mse_loss(pred: Tensor, target: Tensor, mask: Tensor) -> Tensor:
    """Squared error summed over feature dims, averaged over the number of
    masked token positions (denominator clamped to at least 1)."""
    sq_err = ((pred - target).abs() ** 2).sum(-1)
    weights = mask.float()
    total = (sq_err * weights).sum()
    count = weights.sum().clamp_min(1.0)
    return total / count
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
# Explicit public API of this module.
__all__ = [
    "ComplexPatchTokenizer",
    "LWMConfig",
    "LWMModel",
    "LWMBackbone",
    "compute_nmse",
    "masked_nmse_loss",
    "masked_mse_loss",
]
|
LWMTemporal/tasks/channel_prediction.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 13 |
+
from torch.utils.data import DataLoader, Subset
|
| 14 |
+
|
| 15 |
+
from ..data import AngleDelayDatasetConfig, AngleDelaySequenceDataset
|
| 16 |
+
from ..models import LWMBackbone, LWMConfig
|
| 17 |
+
from ..models.lwm import masked_mse_loss, masked_nmse_loss, compute_nmse
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import wandb # type: ignore
|
| 21 |
+
except ImportError: # pragma: no cover
|
| 22 |
+
wandb = None # type: ignore
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclasses.dataclass
class DatasetArgs:
    """Dataset location, preprocessing, and split options.

    Most fields are forwarded verbatim to ``AngleDelayDatasetConfig`` (see
    ``ChannelPredictionDataModule``); ``train_limit``/``val_limit``/``seed``
    control the random train/val split.
    """

    data_path: Path  # raw dataset file (forwarded as raw_path)
    keep_percentage: float = 0.25  # forwarded; presumably fraction of data kept — confirm in dataset
    normalize: str = "global_rms"  # normalization scheme name (forwarded)
    cache_dir: Path = Path("cache")  # where preprocessed tensors are cached
    use_cache: bool = True
    overwrite_cache: bool = False
    snr_db: Optional[float] = None  # optional noise SNR (forwarded) — semantics defined by dataset
    noise_seed: Optional[int] = None
    max_time_steps: Optional[int] = None  # optional cap on sequence length (forwarded)
    train_limit: int = 500  # max number of training samples drawn from the shuffle
    val_limit: int = 1000  # max number of validation samples
    seed: int = 42  # seeds torch/numpy and the split permutation
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclasses.dataclass
class ModelArgs:
    """Backbone hyper-parameters (mirroring ``LWMConfig``) plus fine-tuning
    options such as a pretrained checkpoint and layer-freezing knobs."""

    patch_size: Tuple[int, int] = (1, 1)
    phase_mode: str = "real_imag"  # complex channels as real/imag (vs mag/phase)
    embed_dim: int = 32
    depth: int = 12
    num_heads: int = 8
    mlp_ratio: float = 4.0
    same_frame_window: int = 2
    # NOTE(review): this default is past-only offsets, unlike LWMConfig's
    # mixed past/future default — confirm the asymmetry is intended.
    temporal_offsets: Sequence[int] = dataclasses.field(default_factory=lambda: (-1, -2, -3, -4, -5, -6, -7))
    temporal_spatial_window: int = 2
    temporal_drift_h: int = 1
    temporal_drift_w: int = 1
    routing_topk_enable: bool = True
    routing_topk_fraction: float = 0.2
    routing_topk_min: int = 8
    routing_topk_max: int = 32
    topk_per_head: bool = True
    posenc: str = "learned"  # positional-encoding scheme name
    rope_base: float = 10000.0
    global_cls: bool = False
    pretrained: Optional[Path] = None  # checkpoint to initialise the backbone from
    finetune_last_n: int = 0  # presumably number of trailing layers to unfreeze — confirm in trainer
    train_head_only: bool = False  # presumably freezes all but the head — confirm in trainer
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclasses.dataclass
class TrainingArgs:
    """Optimisation, logging, checkpointing, and inference-mode options."""

    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # resolved at import time
    epochs: int = 3
    batch_size: int = 16
    lr: float = 1e-4
    weight_decay: float = 1e-4
    warmup_ratio: float = 0.1  # presumably fraction of steps for LR warmup — confirm in trainer
    loss: str = "nmse"  # presumably selects masked_nmse_loss vs masked_mse_loss — confirm
    use_dataparallel: bool = False  # wrap the model in nn.DataParallel when >1 GPU
    grad_clip: float = 1.0
    log_interval: int = 10
    save_dir: Path = Path("models")
    save_prefix: str = "channel_prediction"
    inference_only: bool = False  # skip training and only run evaluation
    inference_split: str = "val"
    verbose_inference: bool = False
    log_dir: Path = Path("logs")
    use_wandb: bool = False  # Weights & Biases logging (module import is optional)
    wandb_project: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_run_name: Optional[str] = None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclasses.dataclass
class PredictionArgs:
    """Autoregressive rollout settings (see ``AutoregressiveEngine``)."""

    Tpast: int = 10  # number of observed history frames per window
    horizon: int = 1  # number of future frames predicted step by step
    num_visual_samples: int = 4  # samples rendered by PredictionVisualizer
    viz_dir: Path = Path("figs/predictions")  # output directory for figures
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@dataclasses.dataclass
class ChannelPredictionArgs:
    """Top-level bundle of all argument groups for the channel-prediction task."""

    dataset: DatasetArgs  # data loading and splitting
    model: ModelArgs  # backbone and fine-tuning options
    training: TrainingArgs  # optimisation and logging
    prediction: PredictionArgs  # rollout and visualisation
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class ChannelPredictionDataModule:
    """Builds the angle-delay sequence dataset and deterministic
    train/validation splits, and exposes DataLoader factories."""

    def __init__(self, args: DatasetArgs, patch_size: Tuple[int, int], phase_mode: str) -> None:
        cfg = AngleDelayDatasetConfig(
            raw_path=args.data_path,
            keep_percentage=args.keep_percentage,
            normalize=args.normalize,
            cache_dir=args.cache_dir,
            use_cache=args.use_cache,
            overwrite_cache=args.overwrite_cache,
            snr_db=args.snr_db,
            noise_seed=args.noise_seed,
            max_time_steps=args.max_time_steps,
            patch_size=patch_size,
            phase_mode=phase_mode,
        )
        self.dataset = AngleDelaySequenceDataset(cfg)
        # Seeded permutation makes the split reproducible across runs.
        rng = torch.Generator().manual_seed(args.seed)
        shuffled = torch.randperm(len(self.dataset), generator=rng).tolist()
        n_train = min(args.train_limit, len(shuffled))
        n_val = min(args.val_limit, max(0, len(shuffled) - n_train))
        self.train_indices = shuffled[:n_train]
        self.val_indices = shuffled[n_train:n_train + n_val]
        self.patch_size = patch_size
        self.phase_mode = phase_mode

    def train_loader(self, batch_size: int, drop_last: bool = True) -> DataLoader:
        """Shuffled loader over the training split."""
        return DataLoader(
            Subset(self.dataset, self.train_indices),
            batch_size=batch_size,
            shuffle=True,
            drop_last=drop_last,
        )

    def val_loader(self, batch_size: int, drop_last: bool = False) -> Optional[DataLoader]:
        """Sequential loader over the validation split, or None when empty."""
        if not self.val_indices:
            return None
        return DataLoader(
            Subset(self.dataset, self.val_indices),
            batch_size=batch_size,
            shuffle=False,
            drop_last=drop_last,
        )
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class AutoregressiveEngine:
    """Token-space utilities for autoregressive channel prediction.

    Holds the patch geometry and phase convention needed to turn flat
    per-patch tokens back into complex channel maps and to roll a model
    forward one time step at a time.
    """

    def __init__(self, patch_size: Tuple[int, int], phase_mode: str) -> None:
        self.patch_size = patch_size
        self.phase_mode = phase_mode

    def detokenize(self, tokens: Tensor, T: int, H: int, W: int) -> Tensor:
        """Convert flat patch tokens back to a complex channel tensor.

        ``tokens`` holds B*T*H*W patches of ph*pw*2 scalars; the result has
        shape (B, T, H*ph, W*pw) and complex dtype.
        """
        B = tokens.size(0)
        ph, pw = self.patch_size
        patches = tokens.view(B, T, H, W, ph * pw * 2)
        patches = patches.view(B, T, H, W, ph, pw, 2)
        # Interleave patch rows/cols back into the full spatial grid.
        patches = patches.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
        recon = patches.view(B, T, H * ph, W * pw, 2)
        if self.phase_mode == "real_imag":
            return torch.complex(recon[..., 0], recon[..., 1])
        # Otherwise the two channels encode (magnitude, phase).
        magnitude = recon[..., 0]
        phase = recon[..., 1]
        return torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase))

    def autoregressive_rollout(
        self,
        model: LWMBackbone,
        tokens: Tensor,
        Tpast: int,
        horizon: int,
        H: int,
        W: int,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict ``horizon`` future frames one step at a time.

        A (Tpast+1)-frame window slides across the sequence; at each step the
        window's last frame is masked, predicted by ``model``, and written
        back so later steps condition on earlier predictions.

        Returns ``(pred_tokens, target_tokens, mask_flat)`` where
        ``mask_flat`` (B, T_total*H*W) marks the predicted frames.
        """
        B, S_full, _ = tokens.shape
        S_per_time = H * W
        if S_full % S_per_time != 0:
            raise ValueError("Token sequence length incompatible with H and W")
        T_total = S_full // S_per_time
        window_tokens = Tpast + 1
        if T_total < Tpast + horizon:
            raise ValueError("sequence shorter than Tpast + horizon")

        # Window mask: only the final frame of each window is hidden.
        mask_window = torch.zeros((window_tokens, H, W), dtype=torch.bool, device=tokens.device)
        mask_window[Tpast, :, :] = True
        mask_window = mask_window.view(window_tokens * S_per_time)
        # Full-sequence mask marking every predicted frame (for loss/metrics).
        mask_future = torch.zeros((T_total, H, W), dtype=torch.bool, device=tokens.device)
        mask_future[Tpast:Tpast + horizon, :, :] = True
        mask_flat = mask_future.view(1, T_total * S_per_time).expand(B, -1)

        source_tokens = tokens.clone()
        pred_tokens = torch.zeros_like(tokens)

        for step in range(horizon):
            abs_start = step * S_per_time
            abs_end = (step + window_tokens) * S_per_time
            # Clone so zeroing the masked frame never touches source_tokens.
            window_slice = source_tokens[:, abs_start:abs_end, :].clone()
            mask_slice = mask_window.unsqueeze(0).expand(B, -1)
            # Zero masked tokens before model forward (matching original implementation)
            window_slice = window_slice.masked_fill(mask_slice.unsqueeze(-1), 0.0)
            outputs = model.forward_tokens(window_slice, mask_slice, window_tokens, H, W, return_cls=False)
            predicted_window = outputs["reconstruction"]
            # Extract the prediction for the window's final (masked) frame.
            win_last_start = Tpast * S_per_time
            step_pred_last = predicted_window[:, win_last_start:win_last_start + S_per_time, :]
            # Write it back at its absolute position so the next window sees it.
            target_range_start = (Tpast + step) * S_per_time
            target_range_end = target_range_start + S_per_time
            source_tokens[:, target_range_start:target_range_end, :] = step_pred_last
            pred_tokens[:, target_range_start:target_range_end, :] = step_pred_last

        return pred_tokens, tokens, mask_flat
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class PredictionVisualizer:
    """Renders ground-truth vs predicted channel magnitudes for a few samples
    after an autoregressive rollout."""

    def __init__(self, engine: AutoregressiveEngine, save_dir: Path, num_samples: int) -> None:
        self.engine = engine
        self.save_dir = save_dir
        self.num_samples = num_samples
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def save(self, model: LWMBackbone, tokens: Tensor, H: int, W: int, args: PredictionArgs) -> None:
        """Run a rollout on ``tokens`` and save one figure per sample
        (up to ``num_samples``)."""
        model.eval()
        with torch.no_grad():
            preds, tgt, mask = self.engine.autoregressive_rollout(
                model,
                tokens,
                args.Tpast,
                args.horizon,
                H,
                W,
            )
        tokens_per_time = H * W
        T_total = tokens.size(1) // tokens_per_time
        B = tokens.size(0)
        for idx in range(min(B, self.num_samples)):
            # Detokenize one sample at a time back to complex channel maps.
            pred_seq = preds[idx].view(T_total, tokens_per_time, -1)
            tgt_seq = tgt[idx].view(T_total, tokens_per_time, -1)
            pred_complex = self.engine.detokenize(pred_seq.unsqueeze(0), T_total, H, W)[0]
            tgt_complex = self.engine.detokenize(tgt_seq.unsqueeze(0), T_total, H, W)[0]
            self._plot_sample(pred_complex, tgt_complex, args, sample_idx=idx)

    def _plot_sample(self, pred: Tensor, tgt: Tensor, args: PredictionArgs, sample_idx: int) -> None:
        """Plot GT and predicted magnitude maps side by side, one row per
        horizon step, and save to ``save_dir/sample_<idx>.png``."""
        # Imported lazily so headless training runs don't require matplotlib.
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(args.horizon, 2, figsize=(8, 3 * args.horizon), squeeze=False)
        for step in range(args.horizon):
            t_idx = args.Tpast + step
            gt_mag = tgt[t_idx].abs().cpu().numpy()
            pred_mag = pred[t_idx].abs().cpu().numpy()
            ax_gt, ax_pred = axes[step]
            im0 = ax_gt.imshow(gt_mag, cmap="viridis", aspect="auto")
            im1 = ax_pred.imshow(pred_mag, cmap="viridis", aspect="auto")
            ax_gt.set_title(f"GT t={t_idx}")
            ax_pred.set_title(f"Pred t={t_idx}")
            for ax in (ax_gt, ax_pred):
                ax.set_xticks([])
                ax.set_yticks([])
            fig.colorbar(im0, ax=ax_gt, fraction=0.046, pad=0.04)
            fig.colorbar(im1, ax=ax_pred, fraction=0.046, pad=0.04)
        fig.tight_layout()
        out_path = self.save_dir / f"sample_{sample_idx}.png"
        fig.savefig(out_path)
        plt.close(fig)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class ChannelPredictionTrainer:
    """End-to-end driver for the channel-prediction task.

    Wires together the data module, the LWM backbone, the autoregressive
    rollout engine, AMP-based optimization (AdamW + warmup/cosine schedule),
    periodic evaluation with per-step NMSE reporting, visualization, and
    optional Weights & Biases logging.
    """

    def __init__(self, args: ChannelPredictionArgs, *, logger: Optional[logging.Logger] = None) -> None:
        """Build all training components from ``args``; seeds torch/numpy RNGs."""
        self.args = args
        torch.manual_seed(args.dataset.seed)
        np.random.seed(args.dataset.seed)
        self.device = torch.device(args.training.device)
        self.engine = AutoregressiveEngine(args.model.patch_size, args.model.phase_mode)
        self.data = ChannelPredictionDataModule(args.dataset, args.model.patch_size, args.model.phase_mode)
        self.model = self._build_model().to(self.device)
        self.model.eval()  # Set to eval mode immediately after loading
        if args.training.use_dataparallel and torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            # DataParallel wraps the model; keep the wrapped module in eval too.
            if hasattr(self.model, 'module'):
                self.model.module.eval()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=args.training.lr, weight_decay=args.training.weight_decay)
        self.scheduler = self._build_scheduler()
        self.scaler = GradScaler()  # AMP loss scaler used in train()
        self.viz = PredictionVisualizer(self.engine, args.prediction.viz_dir, args.prediction.num_visual_samples)
        self.logger = logger or logging.getLogger(__name__)
        # Monotonic optimizer-step counter; used as the wandb x-axis.
        self.global_step = 0
        self._wandb_run = self._maybe_init_wandb()

    def _wandb_enabled(self) -> bool:
        """True when a live wandb run was created in ``_maybe_init_wandb``."""
        return self._wandb_run is not None

    def _maybe_init_wandb(self) -> Optional["wandb.sdk.wandb_run.Run"]:
        """Start a wandb run if requested and the package is available; else None."""
        training = self.args.training
        if not training.use_wandb:
            return None
        if wandb is None:
            # wandb import failed at module load; degrade gracefully.
            self.logger.warning("Weights & Biases not installed; disabling wandb logging.")
            return None
        # Log the full (dataclass) configuration for reproducibility.
        config = {
            "dataset": dataclasses.asdict(self.args.dataset),
            "model": dataclasses.asdict(self.args.model),
            "training": dataclasses.asdict(self.args.training),
            "prediction": dataclasses.asdict(self.args.prediction),
        }
        run = wandb.init(
            project=training.wandb_project,
            entity=training.wandb_entity,
            name=training.wandb_run_name,
            config=config,
        )
        wandb.watch(self.model, log="all", log_freq=self.args.training.log_interval)
        self.logger.info("Initialized Weights & Biases run: %s", run.name)
        return run

    def _wandb_log(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Forward ``metrics`` to wandb; silently a no-op when wandb is disabled."""
        if not self._wandb_enabled():
            return
        wandb.log(metrics, step=step)

    def _finish_wandb(self) -> None:
        """Close the wandb run (if any) so buffered data is flushed."""
        if self._wandb_enabled():
            wandb.finish()

    def _build_model(self) -> LWMBackbone:
        """Construct the backbone, optionally load pretrained weights, set freezing.

        Returns the (possibly partially frozen) ``LWMBackbone`` on CPU;
        the caller moves it to the target device.
        """
        # Calculate max_seq_len based on window size (matching original implementation)
        # This is critical for channel prediction with autoregressive rollout
        sample_batch = next(iter(self.data.val_loader(1) or self.data.train_loader(1)))
        _, _, H, W = self._prepare_batch(sample_batch)
        # The rollout feeds windows of Tpast frames plus one frame being predicted.
        max_seq_len = (self.args.prediction.Tpast + 1) * H * W

        cfg = LWMConfig(
            patch_size=self.args.model.patch_size,
            phase_mode=self.args.model.phase_mode,
            embed_dim=self.args.model.embed_dim,
            depth=self.args.model.depth,
            num_heads=self.args.model.num_heads,
            mlp_ratio=self.args.model.mlp_ratio,
            same_frame_window=self.args.model.same_frame_window,
            temporal_offsets=self.args.model.temporal_offsets,
            temporal_spatial_window=self.args.model.temporal_spatial_window,
            temporal_drift_h=self.args.model.temporal_drift_h,
            temporal_drift_w=self.args.model.temporal_drift_w,
            routing_topk_enable=self.args.model.routing_topk_enable,
            routing_topk_fraction=self.args.model.routing_topk_fraction,
            routing_topk_min=self.args.model.routing_topk_min,
            routing_topk_max=self.args.model.routing_topk_max,
            topk_per_head=self.args.model.topk_per_head,
            posenc=self.args.model.posenc,
            rope_base=self.args.model.rope_base,
            global_cls=self.args.model.global_cls,
            max_seq_len=max_seq_len,
        )
        model = LWMBackbone(cfg)
        # NOTE(review): the freshly built model is discarded when a pretrained
        # checkpoint exists; from_pretrained rebuilds it with the same cfg.
        if self.args.model.pretrained is not None and self.args.model.pretrained.exists():
            model = LWMBackbone.from_pretrained(self.args.model.pretrained, config=cfg)
        if self.args.model.train_head_only:
            # Linear-probe style: freeze everything, then re-enable the head.
            for param in model.parameters():
                param.requires_grad = False
            for param in model.head.parameters():
                param.requires_grad = True
        elif self.args.model.finetune_last_n > 0:
            # Partial fine-tuning: unfreeze only the last N encoder layers + head.
            model.freeze_backbone()
            if hasattr(model, "encoder"):
                layers = model.encoder.layers
                for layer in layers[-self.args.model.finetune_last_n:]:
                    for param in layer.parameters():
                        param.requires_grad = True
            for param in model.head.parameters():
                param.requires_grad = True
        return model

    def _build_scheduler(self) -> torch.optim.lr_scheduler.LambdaLR:
        """Linear-warmup + cosine-decay LR schedule over all training steps.

        NOTE(review): a train loader is constructed here only to count steps
        per epoch; the actual training loader is created again in train().
        """
        train_loader = self.data.train_loader(self.args.training.batch_size)
        steps_per_epoch = max(1, len(train_loader))
        total_steps = steps_per_epoch * max(1, self.args.training.epochs)
        warmup_steps = int(self.args.training.warmup_ratio * total_steps)

        def schedule(step: int) -> float:
            # Multiplier applied to the base LR by LambdaLR.
            if step < warmup_steps:
                return float(step) / max(1, warmup_steps)
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, schedule)

    def _prepare_batch(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, int, int]:
        """Move a batch to the device and trim it to Tpast+horizon frames.

        Returns ``(tokens, mask, H, W)`` where tokens/mask cover exactly the
        frames required for one rollout. Raises ValueError when the batch
        mixes different (T, H, W) shapes or is too short.
        """
        tokens = batch["tokens"].to(self.device)
        base_mask = batch["base_mask"].to(self.device)
        shapes = batch["shape"]
        if not isinstance(shapes, torch.Tensor):
            shapes = torch.tensor(shapes)
        if shapes.dim() == 1:
            # Single-sample batch: promote to (1, 3).
            shapes = shapes.unsqueeze(0)
        ref_shape = shapes[0]
        if not torch.all(shapes.eq(ref_shape)):
            raise ValueError("Mixed sequence shapes within the same batch are not supported")
        T = int(ref_shape[0].item())
        H = int(ref_shape[1].item())
        W = int(ref_shape[2].item())
        T_needed = self.args.prediction.Tpast + self.args.prediction.horizon
        if T < T_needed:
            raise ValueError("Sequence shorter than required Tpast+horizon frames")
        S_per_time = H * W
        # Keep only the leading frames needed for the rollout.
        tokens = tokens[:, : T_needed * S_per_time, :]
        mask = base_mask[:, : T_needed * S_per_time]
        return tokens, mask, H, W

    def _compute_loss(self, pred: Tensor, tgt: Tensor, mask: Tensor) -> Tensor:
        """Masked MSE or (default) masked NMSE, selected by training.loss."""
        if self.args.training.loss == "mse":
            return masked_mse_loss(pred, tgt, mask)
        return masked_nmse_loss(pred, tgt, mask)

    def train(self) -> None:
        """Main training loop; in inference_only mode, just evaluates once."""
        if self.args.training.inference_only:
            # self.logger.info(
            #     "Running inference-only evaluation on split '%s'", self.args.training.inference_split
            # )
            self.evaluate(split=self.args.training.inference_split)
            self._finish_wandb()
            return
        train_loader = self.data.train_loader(self.args.training.batch_size)
        val_loader = self.data.val_loader(self.args.training.batch_size)
        for epoch in range(1, self.args.training.epochs + 1):
            self.model.train()
            running_loss = 0.0
            running_nmse: List[float] = []
            loader_len = len(train_loader)
            for step, batch in enumerate(train_loader, start=1):
                tokens, _, H, W = self._prepare_batch(batch)
                # Forward + loss under autocast; backward via the AMP scaler.
                with autocast():
                    preds, target, mask = self.engine.autoregressive_rollout(
                        self.model, tokens, self.args.prediction.Tpast, self.args.prediction.horizon, H, W
                    )
                    loss = self._compute_loss(preds, target, mask)
                self.scaler.scale(loss).backward()
                # Unscale before clipping so the clip threshold applies to true grads.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.training.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                self.scheduler.step()
                running_loss += loss.item()
                running_nmse.append(compute_nmse(preds, target, mask))
                self.global_step += 1
                if step % self.args.training.log_interval == 0:
                    # Running averages over the epoch so far.
                    avg_loss = running_loss / step
                    avg_nmse = float(np.mean(running_nmse)) if running_nmse else float("nan")
                    lr_cur = self.optimizer.param_groups[0]["lr"]
                    self.logger.info(
                        "Train: [%d/%d][%d/%d] loss=%0.6f nmse=%0.6f lr=%0.2e",
                        epoch,
                        self.args.training.epochs,
                        step,
                        loader_len,
                        avg_loss,
                        avg_nmse,
                        lr_cur,
                    )
                    self._wandb_log(
                        {
                            "train/loss": avg_loss,
                            "train/nmse": avg_nmse,
                            "train/lr": lr_cur,
                        },
                        step=self.global_step,
                    )
            avg_train_loss = running_loss / max(1, len(train_loader))
            avg_train_nmse = float(np.mean(running_nmse)) if running_nmse else float("nan")
            self.logger.info(
                "Train Epoch %d/%d Summary: loss=%0.6f nmse=%0.6f",
                epoch,
                self.args.training.epochs,
                avg_train_loss,
                avg_train_nmse,
            )
            self._wandb_log(
                {
                    "train/epoch_loss": avg_train_loss,
                    "train/epoch_nmse": avg_train_nmse,
                },
                step=self.global_step,
            )
            if val_loader is not None:
                self.evaluate(loader=val_loader, split="val", epoch=epoch)
            # Visualize predictions on the first training batch each epoch.
            first_batch = next(iter(train_loader))
            tokens_vis, _, H_vis, W_vis = self._prepare_batch(first_batch)
            self.viz.save(self.model, tokens_vis, H_vis, W_vis, self.args.prediction)
        self._finish_wandb()

    def evaluate(
        self,
        loader: Optional[DataLoader] = None,
        split: str = "val",
        epoch: Optional[int] = None,
    ) -> None:
        """Evaluate one split: average loss/NMSE plus per-horizon-step NMSE in dB.

        Builds a loader for the requested split when none is passed.
        Results are logged and forwarded to wandb; nothing is returned.
        """
        if loader is None:
            if split == "train":
                subset = Subset(self.data.dataset, self.data.train_indices)
                loader = DataLoader(subset, batch_size=self.args.training.batch_size, shuffle=False, drop_last=False)
            elif split == "val":
                subset = Subset(self.data.dataset, self.data.val_indices)
                loader = DataLoader(subset, batch_size=self.args.training.batch_size, shuffle=False, drop_last=False)
            else:
                # Any other split name evaluates the full dataset.
                loader = DataLoader(self.data.dataset, batch_size=self.args.training.batch_size, shuffle=False)
        # NOTE(review): this guard is effectively dead — the branches above
        # always assign a DataLoader; kept for safety.
        if loader is None:
            self.logger.warning("No %s loader available", split)
            return
        self.model.eval()
        losses: List[float] = []
        nmses: List[float] = []
        per_step_nmses: List[List[float]] = []  # List of lists: [batch][step]
        with torch.no_grad():
            total_steps = len(loader)
            for step, batch in enumerate(loader, start=1):
                tokens, _, H, W = self._prepare_batch(batch)
                preds, target, mask = self.engine.autoregressive_rollout(
                    self.model, tokens, self.args.prediction.Tpast, self.args.prediction.horizon, H, W
                )
                loss = self._compute_loss(preds, target, mask)
                batch_loss = loss.item()
                batch_nmse = compute_nmse(preds, target, mask)
                losses.append(batch_loss)
                nmses.append(batch_nmse)

                # Compute per-step NMSE for this batch
                S_per_time = H * W
                Tpast = self.args.prediction.Tpast
                horizon = self.args.prediction.horizon
                step_nmses = []
                for h in range(horizon):
                    t_idx = Tpast + h
                    step_start = t_idx * S_per_time
                    step_end = step_start + S_per_time
                    step_mask = mask[:, step_start:step_end]
                    if step_mask.sum() > 0:
                        step_pred = preds[:, step_start:step_end, :]
                        step_target = target[:, step_start:step_end, :]
                        step_nmse = compute_nmse(step_pred, step_target, step_mask)
                        step_nmses.append(step_nmse)
                    else:
                        # No masked tokens at this step: record NaN, skip in averages.
                        step_nmses.append(float('nan'))
                per_step_nmses.append(step_nmses)

                # Report per-step NMSE for this batch (matching original package format)
                per_step_strs = []
                for h, step_nmse in enumerate(step_nmses):
                    if not math.isnan(step_nmse):
                        t = Tpast + h + 1  # t=11, 12, ... (1-indexed)
                        nmse_db = 10.0 * math.log10(max(step_nmse, 1e-12))
                        per_step_strs.append(f"t={t}: {nmse_db:.3f} dB")
                if per_step_strs:
                    self.logger.info(
                        "[%s] per-step NMSE dB: %s",
                        split,
                        ", ".join(per_step_strs),
                    )

                if self.args.training.verbose_inference:
                    tag = split.upper()
                    nmse_db = 10.0 * math.log10(max(batch_nmse, 1e-12))
                    self.logger.info(
                        "%s: [%d/%d] loss=%0.6f nmse=%0.6f (%0.2f dB)",
                        tag,
                        step,
                        total_steps,
                        batch_loss,
                        batch_nmse,
                        nmse_db,
                    )

        avg_loss = float(np.mean(losses)) if losses else float("nan")
        avg_nmse = float(np.mean(nmses)) if nmses else float("nan")
        # NOTE(review): `tag` is computed here but never used below.
        tag = f"[{split}]" if epoch is None else f"Epoch {epoch} [{split}]"
        avg_nmse_db = 10.0 * math.log10(max(avg_nmse, 1e-12))
        self.logger.info(
            "Inference [%s] NMSE=%e (%0.3f dB) over %d batches",
            split,
            avg_nmse,
            avg_nmse_db,
            len(losses),
        )

        # Compute per-step average in dB scale (matching original implementation)
        if per_step_nmses:
            horizon = len(per_step_nmses[0])
            per_step_avg_db = []
            Tpast = self.args.prediction.Tpast
            for h in range(horizon):
                # Average dB values (not linear values!)
                step_dbs = []
                for batch_nmses in per_step_nmses:
                    if not math.isnan(batch_nmses[h]):
                        step_db = 10.0 * math.log10(max(batch_nmses[h], 1e-12))
                        step_dbs.append(step_db)
                if step_dbs:
                    avg_db = float(np.mean(step_dbs))
                    per_step_avg_db.append(f"t={Tpast + h + 1}: {avg_db:.3f} dB")
            if per_step_avg_db:
                self.logger.info(
                    "Inference [%s] per-step average NMSE dB: %s",
                    split,
                    ", ".join(per_step_avg_db),
                )

        metrics = {
            f"{split}/loss": avg_loss,
            f"{split}/nmse": avg_nmse,
            f"{split}/nmse_db": avg_nmse_db,
        }
        self._wandb_log(metrics, step=self.global_step)

    def _save_checkpoint(self, epoch: int, metric: float) -> None:
        """Persist model/optimizer/scheduler state plus the full config.

        NOTE(review): not called anywhere in this chunk — confirm the caller,
        and consider routing the final message through self.logger.
        """
        self.args.training.save_dir.mkdir(parents=True, exist_ok=True)
        filename = f"{self.args.training.save_prefix}_epoch{epoch:02d}.pth"
        path = self.args.training.save_dir / filename
        state = {
            "epoch": epoch,
            "metric": metric,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": dataclasses.asdict(self.args),
        }
        torch.save(state, path)
        print(f"Saved checkpoint to {path}")
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
# Public API of this module.
__all__ = [
    "DatasetArgs",
    "ModelArgs",
    "TrainingArgs",
    "PredictionArgs",
    "ChannelPredictionArgs",
    "ChannelPredictionDataModule",
    "AutoregressiveEngine",
    "PredictionVisualizer",
    "ChannelPredictionTrainer",
]
|
LWMTemporal/tasks/pretraining.py
ADDED
|
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import dataclasses
|
| 5 |
+
import logging
|
| 6 |
+
import math
|
| 7 |
+
import pickle
|
| 8 |
+
import random
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 16 |
+
from torch.utils.data import DataLoader, Dataset
|
| 17 |
+
|
| 18 |
+
from ..data.angle_delay import AngleDelayConfig, AngleDelayProcessor
|
| 19 |
+
from ..models import LWMBackbone, LWMConfig
|
| 20 |
+
from ..models.lwm import ComplexPatchTokenizer, masked_nmse_loss
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
import wandb # type: ignore
|
| 24 |
+
except ImportError: # pragma: no cover
|
| 25 |
+
wandb = None # type: ignore
|
| 26 |
+
|
| 27 |
+
@dataclasses.dataclass
class DataArgs:
    """Configuration for loading and preprocessing the pretraining channel data."""

    data_dir: Path  # directory scanned (non-recursively) for pickled sequences (*.p)
    keep_percentage: float = 0.25  # fraction passed to AngleDelayConfig for delay-bin truncation
    normalize: str = "global_rms"  # "global_rms", "per_sample_rms", or "none" (see PretrainingDataset._normalize)
    max_time_steps: Optional[int] = None  # if set, each sequence is cut to this many leading frames
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclasses.dataclass
class MaskArgs:
    """Controls how token masks are generated during pretraining (see MaskGenerator)."""

    mask_ratio: float = 0.75  # target fraction of tokens hidden (used by the random mask)
    mask_mode: str = "auto"  # "random", "rect", "tube", "comb", or "auto" (probabilistic mix)
    random_fraction: float = 0.2  # in "auto" mode, probability of a purely random mask
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclasses.dataclass
class CurriculumArgs:
    """Masking-curriculum schedule.

    NOTE(review): the consumer of these fields is outside this chunk —
    presumably the pretraining loop ramps the mask ratio; confirm there.
    """

    strategy: str = "mask"  # curriculum dimension; "mask" presumably ramps the mask ratio
    warmup_epochs: int = 4  # epochs over which the curriculum ramps
    min_mask_ratio: float = 0.3  # ratio at the start of the ramp
    max_mask_ratio: float = 0.75  # ratio after the ramp completes
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclasses.dataclass
class AugmentationArgs:
    """Probabilities and ranges for complex-channel augmentation (see Augmenter)."""

    phase_p: float = 0.0  # probability of a global random phase rotation
    amp_p: float = 0.0  # probability of a random amplitude scale
    amp_min: float = 0.7  # lower bound of the amplitude scale
    amp_max: float = 1.3  # upper bound of the amplitude scale
    awgn_p: float = 0.0  # probability of adding complex AWGN
    awgn_snr_min: float = 20.0  # minimum SNR (dB) of the injected noise
    awgn_snr_max: float = 30.0  # maximum SNR (dB) of the injected noise
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclasses.dataclass
class LoggingArgs:
    """Destinations for log output and optional Weights & Biases metadata."""

    log_dir: Path = Path("logs")  # directory for file-based logs
    use_wandb: bool = False  # enable wandb logging when the package is installed
    wandb_project: Optional[str] = None  # wandb project name
    wandb_entity: Optional[str] = None  # wandb team/user
    wandb_run_name: Optional[str] = None  # explicit run name (wandb autogenerates when None)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclasses.dataclass
class OptimizationArgs:
    """Optimizer and training-loop hyperparameters for pretraining."""

    # NOTE: evaluated once at import time; a GPU appearing later is not picked up.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    epochs: int = 20
    batch_size: int = 32
    lr: float = 2e-4  # peak learning rate
    weight_decay: float = 1e-4
    warmup_ratio: float = 0.1  # fraction of total steps spent warming up
    grad_clip: float = 1.0  # max gradient norm
    log_interval: int = 1  # log every N optimizer steps
    save_dir: Path = Path("models")  # checkpoint output directory
    save_prefix: str = "lwm_pretrain"  # checkpoint filename prefix
    resume_from: Optional[Path] = None  # optional checkpoint to resume from
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclasses.dataclass
class ModelArgs:
    """Architecture hyperparameters forwarded to LWMConfig / LWMBackbone."""

    patch_size: Tuple[int, int] = (1, 1)  # (height, width) of each complex patch
    phase_mode: str = "real_imag"  # complex-token representation mode
    embed_dim: int = 32  # transformer embedding width
    depth: int = 12  # number of encoder layers
    num_heads: int = 8  # attention heads
    mlp_ratio: float = 4.0  # FFN hidden size = mlp_ratio * embed_dim (presumably; confirm in LWMConfig)
    same_frame_window: int = 2  # spatial attention window within a frame (semantics defined by the model)
    temporal_offsets: Sequence[int] = dataclasses.field(default_factory=lambda: (-4, -3, -2, -1, 1, 2, 3))  # relative frames attended to
    temporal_spatial_window: int = 2  # spatial window for cross-frame attention
    temporal_drift_h: int = 1  # allowed per-frame spatial drift (height)
    temporal_drift_w: int = 1  # allowed per-frame spatial drift (width)
    routing_topk_enable: bool = True  # enable top-k attention routing
    routing_topk_fraction: float = 0.2  # fraction of candidates kept by routing
    routing_topk_min: int = 8  # lower clamp on k
    routing_topk_max: int = 32  # upper clamp on k
    topk_per_head: bool = True  # apply top-k per attention head
    posenc: str = "learned"  # positional-encoding variant
    rope_base: float = 10000.0  # RoPE frequency base (used when posenc selects RoPE)
    global_cls: bool = False  # prepend a global CLS token
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@dataclasses.dataclass
class PretrainingArgs:
    """Top-level container bundling every pretraining configuration group."""

    data: DataArgs  # dataset loading/preprocessing
    mask: MaskArgs  # mask generation
    curriculum: CurriculumArgs  # masking curriculum
    augment: AugmentationArgs  # data augmentation
    optim: OptimizationArgs  # optimizer / loop hyperparameters
    model: ModelArgs  # backbone architecture
    logging: LoggingArgs  # log/wandb configuration
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class PretrainingDataset(Dataset):
    """Map-style dataset producing masked token sequences for LWM pretraining.

    Each item is a dict with flat patch tokens, a flat boolean mask, and the
    (T, H, W) token-grid shape of the underlying sequence.
    """

    def __init__(
        self,
        args: DataArgs,
        tokenizer: ComplexPatchTokenizer,
        augmenter: Augmenter,  # forward reference; resolved lazily via `from __future__ import annotations`
        masker: MaskGenerator,
        patch_size: Tuple[int, int],
    ) -> None:
        self.args = args
        self.tokenizer = tokenizer
        self.augmenter = augmenter
        self.masker = masker
        self.patch_size = patch_size
        # All sequences are loaded and preprocessed eagerly at construction.
        self.samples = self._load_sequences()
        if args.normalize != "none":
            self.samples = [self._normalize(sample, args.normalize) for sample in self.samples]

    def _load_sequences(self) -> List[Tensor]:
        """Load every ``*.p`` pickle in data_dir into angle-delay sequences.

        Accepts either a dict payload with a "channel" key or a raw array;
        3-D payloads are treated as a single sequence (promoted to 4-D).
        SECURITY NOTE: pickle.load executes arbitrary code — only point
        data_dir at trusted files.
        """
        processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.args.keep_percentage))
        samples: List[Tensor] = []
        for path in sorted(self.args.data_dir.glob("*.p")):
            with path.open("rb") as handle:
                payload = pickle.load(handle)
            if isinstance(payload, dict) and "channel" in payload:
                tensor = torch.as_tensor(payload["channel"], dtype=torch.complex64)
            else:
                tensor = torch.as_tensor(payload, dtype=torch.complex64)
            if tensor.ndim == 3:
                tensor = tensor.unsqueeze(0)
            for seq in tensor:
                # Transform to the angle-delay domain and drop negligible delay bins.
                ad = processor.forward(seq)
                truncated, _ = processor.truncate_delay_bins(ad)
                if self.args.max_time_steps is not None and truncated.size(0) > self.args.max_time_steps:
                    truncated = truncated[: self.args.max_time_steps]
                samples.append(truncated)
        return samples

    def _normalize(self, tensor: Tensor, mode: str) -> Tensor:
        """Scale the complex tensor to unit RMS power (no-op for unknown modes).

        NOTE(review): the "global_rms" and "per_sample_rms" branches are
        currently identical (both compute one RMS over the whole tensor);
        "per_sample_rms" presumably should normalize per time step — confirm.
        """
        if mode == "global_rms":
            rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8)
            return tensor / rms.to(tensor.dtype)
        if mode == "per_sample_rms":
            rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8)
            return tensor / rms.to(tensor.dtype)
        return tensor

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Augment, tokenize, and mask one sequence; mask is freshly sampled per call."""
        sample = self.samples[index]
        if self.augmenter is not None:
            sample = self.augmenter(sample)
        # Tokenizer operates on a batch dimension; add and strip it around the call.
        tokens, _ = self.tokenizer(sample.unsqueeze(0), self.patch_size)
        tokens = tokens.squeeze(0)
        ph, pw = self.patch_size
        T, N, M = sample.shape
        # Token-grid size after patching (assumes N % ph == 0 and M % pw == 0 — TODO confirm).
        H = N // ph
        W = M // pw
        mask = self.masker(T, H, W, device=tokens.device).view(-1)
        shape = torch.tensor([T, H, W], dtype=torch.long)
        return {
            "tokens": tokens,
            "mask": mask,
            "shape": shape,
        }
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class MaskGenerator:
|
| 189 |
+
def __init__(self, args: MaskArgs) -> None:
|
| 190 |
+
self.args = args
|
| 191 |
+
|
| 192 |
+
def __call__(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
|
| 193 |
+
if self.args.mask_mode == "random" or (self.args.mask_mode == "auto" and random.random() < self.args.random_fraction):
|
| 194 |
+
return self.random_mask(T, H, W, device)
|
| 195 |
+
if self.args.mask_mode in {"rect", "auto"} and random.random() < 0.33:
|
| 196 |
+
return self.rect_mask(T, H, W, device)
|
| 197 |
+
if self.args.mask_mode in {"tube", "auto"} and random.random() < 0.33:
|
| 198 |
+
return self.tube_mask(T, H, W, device)
|
| 199 |
+
return self.comb_mask(T, H, W, device)
|
| 200 |
+
|
| 201 |
+
def random_mask(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
|
| 202 |
+
total = T * H * W
|
| 203 |
+
num_mask = int(self.args.mask_ratio * total)
|
| 204 |
+
mask = torch.zeros(total, dtype=torch.bool, device=device)
|
| 205 |
+
idx = torch.randperm(total, device=device)[:num_mask]
|
| 206 |
+
mask[idx] = True
|
| 207 |
+
return mask.view(T, H, W)
|
| 208 |
+
|
| 209 |
+
def rect_mask(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
|
| 210 |
+
mask = torch.zeros((T, H, W), dtype=torch.bool, device=device)
|
| 211 |
+
blocks = max(1, int(self.args.mask_ratio * T))
|
| 212 |
+
for _ in range(blocks):
|
| 213 |
+
t = random.randrange(T)
|
| 214 |
+
h_size = random.randint(1, max(1, H // 2))
|
| 215 |
+
w_size = random.randint(1, max(1, W // 2))
|
| 216 |
+
h0 = random.randint(0, H - h_size)
|
| 217 |
+
w0 = random.randint(0, W - w_size)
|
| 218 |
+
mask[t, h0:h0 + h_size, w0:w0 + w_size] = True
|
| 219 |
+
return mask
|
| 220 |
+
|
| 221 |
+
def tube_mask(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
|
| 222 |
+
mask = torch.zeros((T, H, W), dtype=torch.bool, device=device)
|
| 223 |
+
start_t = random.randrange(T)
|
| 224 |
+
h = random.randrange(H)
|
| 225 |
+
w = random.randrange(W)
|
| 226 |
+
length = random.randint(max(1, T // 2), T)
|
| 227 |
+
for k in range(length):
|
| 228 |
+
t_idx = (start_t + k) % T
|
| 229 |
+
mask[t_idx, max(0, h - 1):min(H, h + 2), max(0, w - 1):min(W, w + 2)] = True
|
| 230 |
+
h = max(0, min(H - 1, h + random.randint(-1, 1)))
|
| 231 |
+
w = max(0, min(W - 1, w + random.randint(-1, 1)))
|
| 232 |
+
return mask
|
| 233 |
+
|
| 234 |
+
def comb_mask(self, T: int, H: int, W: int, device: torch.device) -> torch.BoolTensor:
    """Mask everything except a regular "comb" of (time, width) positions.

    A position (t, w) is *visible* when both t and w land on randomly
    offset strides; every other position is masked across the full height
    dimension.  Random draws happen in the same order as the reference
    implementation (stride_t, stride_w, offset_t, offset_w).
    """
    stride_t = random.choice([2, 3]) if T >= 2 else 1
    stride_w = random.choice([3, 4, 6]) if W >= 3 else 1
    offset_t = random.randrange(stride_t)
    offset_w = random.randrange(stride_w)
    t_visible = torch.arange(T, device=device) % stride_t == offset_t
    w_visible = torch.arange(W, device=device) % stride_w == offset_w
    # Broadcast (T, 1, 1) & (1, 1, W) over H, then invert: masked = not visible.
    visible = t_visible.view(T, 1, 1) & w_visible.view(1, 1, W)
    return ~visible.expand(T, H, W)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class Augmenter:
    """Stochastic augmentations for complex channel tensors.

    Three independent augmentations, each gated by its own probability in
    ``args``: a global phase rotation, a global amplitude rescaling, and
    additive white Gaussian noise at a randomly drawn SNR.
    """

    def __init__(self, args: "AugmentationArgs") -> None:
        self.args = args

    def _rotate_phase(self, x):
        # Uniform phase in [-pi, pi); multiply by the unit phasor e^{j*theta}.
        theta = (torch.rand((), device=x.device) * 2 * math.pi) - math.pi
        return x * (torch.cos(theta) + 1j * torch.sin(theta))

    def _scale_amplitude(self, x):
        lo, hi = self.args.amp_min, self.args.amp_max
        return x * (lo + (hi - lo) * torch.rand((), device=x.device))

    def _add_awgn(self, x):
        # Draw a target SNR in dB, convert to linear, derive per-component std.
        snr_db = torch.empty((), device=x.device).uniform_(self.args.awgn_snr_min, self.args.awgn_snr_max)
        snr_lin = 10 ** (snr_db / 10.0)
        power = (x.real.float().pow(2) + x.imag.float().pow(2)).mean().item()
        if power <= 0:
            return x  # all-zero input: no meaningful signal power to scale against
        std = math.sqrt((power / snr_lin) / 2.0)
        noise_re = torch.randn_like(x.real.float()) * std
        noise_im = torch.randn_like(x.imag.float()) * std
        return x + torch.complex(noise_re.to(x.dtype), noise_im.to(x.dtype))

    def __call__(self, tensor: "Tensor") -> "Tensor":
        x = tensor.clone()
        if torch.rand(()) < self.args.phase_p:
            x = self._rotate_phase(x)
        if torch.rand(()) < self.args.amp_p:
            x = self._scale_amplitude(x)
        if torch.rand(()) < self.args.awgn_p:
            x = self._add_awgn(x)
        return x
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class PretrainingTrainer:
|
| 276 |
+
def __init__(self, args: "PretrainingArgs", *, logger: Optional[logging.Logger] = None) -> None:
    """Wire up the data pipeline, model, optimization state and logging."""
    self.args = args
    self.logger = logger if logger is not None else logging.getLogger(__name__)
    self.device = torch.device(args.optim.device)
    # Data pipeline: tokenizer -> mask generator -> augmenter -> dataset/loader.
    self.tokenizer = ComplexPatchTokenizer(args.model.phase_mode)
    self.masker = MaskGenerator(args.mask)
    self.augmenter = Augmenter(args.augment)
    self.dataset = PretrainingDataset(
        args.data, self.tokenizer, self.augmenter, self.masker, args.model.patch_size
    )
    self.dataloader = DataLoader(
        self.dataset, batch_size=args.optim.batch_size, shuffle=True, drop_last=True
    )
    # Model and optimization state.
    self.model = self._build_model().to(self.device)
    self.optimizer = torch.optim.AdamW(
        self.model.parameters(), lr=args.optim.lr, weight_decay=args.optim.weight_decay
    )
    self.scheduler = self._build_scheduler()
    self.scaler = GradScaler()
    self.global_step = 0
    self._wandb_run = self._maybe_init_wandb()
    if self.args.optim.resume_from is not None:
        self._load_checkpoint(self.args.optim.resume_from)
|
| 293 |
+
|
| 294 |
+
def _wandb_enabled(self) -> bool:
|
| 295 |
+
return self._wandb_run is not None
|
| 296 |
+
|
| 297 |
+
def _maybe_init_wandb(self) -> Optional["wandb.sdk.wandb_run.Run"]:
    """Create a wandb run when requested and available; otherwise return None."""
    cfg = self.args.logging
    if not cfg.use_wandb:
        return None
    if wandb is None:
        # Logging was requested but the optional dependency is not installed.
        self.logger.warning("Weights & Biases not installed; disabling wandb logging.")
        return None
    run = wandb.init(
        project=cfg.wandb_project,
        entity=cfg.wandb_entity,
        name=cfg.wandb_run_name,
        config=dataclasses.asdict(self.args),
    )
    # Track gradients/parameters at the same cadence as console logging.
    wandb.watch(self.model, log="all", log_freq=self.args.optim.log_interval)
    self.logger.info("Initialized Weights & Biases run: %s", run.name)
    return run
|
| 314 |
+
|
| 315 |
+
def _wandb_log(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
    """Forward metrics to wandb when a run is active; otherwise a no-op."""
    if self._wandb_enabled():
        wandb.log(metrics, step=step)
|
| 319 |
+
|
| 320 |
+
def _finish_wandb(self) -> None:
    """Close the wandb run, if one was started."""
    if not self._wandb_enabled():
        return
    wandb.finish()
|
| 323 |
+
|
| 324 |
+
def _build_model(self) -> "LWMBackbone":
    """Translate ModelArgs into an LWMConfig and instantiate the backbone."""
    m = self.args.model
    cfg = LWMConfig(
        patch_size=m.patch_size,
        phase_mode=m.phase_mode,
        embed_dim=m.embed_dim,
        depth=m.depth,
        num_heads=m.num_heads,
        mlp_ratio=m.mlp_ratio,
        same_frame_window=m.same_frame_window,
        temporal_offsets=m.temporal_offsets,
        temporal_spatial_window=m.temporal_spatial_window,
        temporal_drift_h=m.temporal_drift_h,
        temporal_drift_w=m.temporal_drift_w,
        routing_topk_enable=m.routing_topk_enable,
        routing_topk_fraction=m.routing_topk_fraction,
        routing_topk_min=m.routing_topk_min,
        routing_topk_max=m.routing_topk_max,
        topk_per_head=m.topk_per_head,
        posenc=m.posenc,
        rope_base=m.rope_base,
        global_cls=m.global_cls,
    )
    return LWMBackbone(cfg)
|
| 347 |
+
|
| 348 |
+
def _build_scheduler(self) -> torch.optim.lr_scheduler.LambdaLR:
|
| 349 |
+
steps_per_epoch = max(1, len(self.dataloader))
|
| 350 |
+
total_steps = steps_per_epoch * max(1, self.args.optim.epochs)
|
| 351 |
+
warmup_steps = int(self.args.optim.warmup_ratio * total_steps)
|
| 352 |
+
|
| 353 |
+
def schedule(step: int) -> float:
|
| 354 |
+
if step < warmup_steps:
|
| 355 |
+
return float(step) / max(1, warmup_steps)
|
| 356 |
+
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
|
| 357 |
+
return 0.5 * (1.0 + math.cos(math.pi * progress))
|
| 358 |
+
|
| 359 |
+
return torch.optim.lr_scheduler.LambdaLR(self.optimizer, schedule)
|
| 360 |
+
|
| 361 |
+
def _adjust_curriculum(self, epoch: int) -> None:
|
| 362 |
+
if self.args.curriculum.strategy == "mask" and epoch <= self.args.curriculum.warmup_epochs:
|
| 363 |
+
ratio = np.interp(
|
| 364 |
+
epoch,
|
| 365 |
+
[0, self.args.curriculum.warmup_epochs],
|
| 366 |
+
[self.args.curriculum.min_mask_ratio, self.args.curriculum.max_mask_ratio],
|
| 367 |
+
)
|
| 368 |
+
self.masker.args.mask_ratio = float(ratio)
|
| 369 |
+
self.logger.info(
|
| 370 |
+
"Curriculum update | epoch=%d/%d mask_ratio=%0.2f",
|
| 371 |
+
epoch,
|
| 372 |
+
self.args.optim.epochs,
|
| 373 |
+
self.masker.args.mask_ratio,
|
| 374 |
+
)
|
| 375 |
+
self._wandb_log(
|
| 376 |
+
{"train/curriculum_mask_ratio": self.masker.args.mask_ratio},
|
| 377 |
+
step=self.global_step,
|
| 378 |
+
)
|
| 379 |
+
self.logger.info(
|
| 380 |
+
"Curriculum update | epoch=%d/%d mask_ratio=%0.2f",
|
| 381 |
+
epoch,
|
| 382 |
+
self.args.optim.epochs,
|
| 383 |
+
self.masker.args.mask_ratio,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
def train(self) -> None:
    """Run the full masked-reconstruction pretraining loop.

    For each epoch: update the mask curriculum, iterate batches with AMP,
    clip gradients, step the LR schedule per batch, periodically log, then
    save a checkpoint.  Closes the wandb run at the end.
    """
    n_batches = len(self.dataloader)
    for epoch in range(1, self.args.optim.epochs + 1):
        self._adjust_curriculum(epoch)
        loss_sum = 0.0
        for batch_idx, batch in enumerate(self.dataloader, start=1):
            tokens = batch["tokens"].to(self.device)
            token_mask = batch["mask"].to(self.device)
            shapes = batch["shape"]
            if not isinstance(shapes, torch.Tensor):
                shapes = torch.tensor(shapes)
            if shapes.dim() == 1:
                shapes = shapes.unsqueeze(0)
            ref_shape = shapes[0]
            # Every sample in the batch must share one (T, H, W) geometry.
            if not torch.all(shapes.eq(ref_shape)):
                raise ValueError("Mixed sequence shapes within the same batch are not supported")
            T, H, W = (int(ref_shape[i].item()) for i in range(3))
            with autocast():
                outputs = self.model.forward_tokens(tokens, token_mask, T, H, W, return_cls=False)
                loss = masked_nmse_loss(outputs["reconstruction"], tokens, token_mask)
            # AMP backward/step with gradient clipping on unscaled gradients.
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.optim.grad_clip)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
            self.scheduler.step()
            loss_sum += loss.item()
            self.global_step += 1
            if batch_idx % self.args.optim.log_interval == 0:
                avg_loss = loss_sum / batch_idx
                lr_now = self.optimizer.param_groups[0]["lr"]
                self.logger.info(
                    "Train: [%d/%d][%d/%d] loss=%0.6f mask=%0.2f lr=%0.2e",
                    epoch,
                    self.args.optim.epochs,
                    batch_idx,
                    n_batches,
                    avg_loss,
                    self.masker.args.mask_ratio,
                    lr_now,
                )
                self._wandb_log(
                    {
                        "train/loss": avg_loss,
                        "train/mask_ratio": self.masker.args.mask_ratio,
                        "train/lr": lr_now,
                    },
                    step=self.global_step,
                )
        epoch_loss = loss_sum / max(1, n_batches)
        self.logger.info(
            "Train Epoch %d/%d Summary: loss=%0.6f",
            epoch,
            self.args.optim.epochs,
            epoch_loss,
        )
        self._wandb_log(
            {
                "train/epoch_loss": epoch_loss,
            },
            step=self.global_step,
        )
        self._save_checkpoint(epoch, epoch_loss)
    self._finish_wandb()
|
| 454 |
+
|
| 455 |
+
def _save_checkpoint(self, epoch: int, metric: float) -> None:
    """Persist model state under ``save_dir``.

    A ``.bin`` target stores a weights-only state dict (HF-style); any
    other extension stores a full resumable training checkpoint.
    """
    out_dir = self.args.optim.save_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    prefix = Path(self.args.optim.save_prefix)
    ext = prefix.suffix or ".pth"
    base = prefix.stem if prefix.suffix else prefix.name
    path = out_dir / f"{base}_epoch{epoch:03d}{ext}"
    if path.suffix.lower() == ".bin":
        # Weights-only file.
        torch.save(self.model.state_dict(), path)
        self.logger.info("Saved weights-only checkpoint to %s", path)
    else:
        # Full training-state checkpoint for later resumption.
        torch.save(
            {
                "epoch": epoch,
                "metric": metric,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "config": dataclasses.asdict(self.args),
            },
            path,
        )
        self.logger.info("Saved checkpoint to %s", path)
    if self._wandb_enabled():
        wandb.save(str(path))
|
| 481 |
+
|
| 482 |
+
def _load_checkpoint(self, checkpoint_path: "Path") -> None:
    """Restore model (and, for full checkpoints, optimizer/scheduler) state.

    ``.bin`` files are treated as raw state dicts; anything else is
    expected to follow the dict layout written by ``_save_checkpoint``.
    Optimizer/scheduler restore failures are logged, not fatal.
    """
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
    self.logger.info("Resuming from checkpoint %s", checkpoint_path)

    payload = torch.load(checkpoint_path, map_location=self.device)

    def restore_model(state) -> None:
        # strict=False so minor architecture changes do not hard-fail a resume.
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if missing:
            self.logger.warning("Missing keys when loading model: %s", missing)
        if unexpected:
            self.logger.warning("Unexpected keys when loading model: %s", unexpected)

    if checkpoint_path.suffix.lower() == ".bin":
        restore_model(payload)
        self.logger.info("Loaded weights-only checkpoint.")
        return

    model_state = payload.get("model_state_dict")
    if model_state is not None:
        restore_model(model_state)

    opt_state = payload.get("optimizer_state_dict")
    if opt_state is not None:
        try:
            self.optimizer.load_state_dict(opt_state)
        except Exception as exc:
            self.logger.warning("Failed to load optimizer state: %s", exc)

    sched_state = payload.get("scheduler_state_dict")
    if sched_state is not None:
        try:
            self.scheduler.load_state_dict(sched_state)
        except Exception as exc:
            self.logger.warning("Failed to load scheduler state: %s", exc)

    self.logger.info(
        "Checkpoint contained epoch=%s metric=%s",
        payload.get("epoch", 0),
        payload.get("metric"),
    )
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def build_pretraining_args(ns: argparse.Namespace) -> "PretrainingArgs":
    """Translate a parsed CLI namespace into the structured PretrainingArgs."""
    return PretrainingArgs(
        data=DataArgs(
            data_dir=ns.data_dir,
            keep_percentage=ns.keep_percentage,
            normalize=ns.normalize,
            max_time_steps=ns.max_time_steps,
        ),
        mask=MaskArgs(
            mask_ratio=ns.mask_ratio,
            mask_mode=ns.mask_mode,
            random_fraction=ns.mask_random_fraction,
        ),
        curriculum=CurriculumArgs(
            strategy=ns.curriculum_strategy,
            warmup_epochs=ns.curriculum_warmup_epochs,
            min_mask_ratio=ns.curriculum_min_mask,
            max_mask_ratio=ns.curriculum_max_mask,
        ),
        augment=AugmentationArgs(
            phase_p=ns.aug_phase_p,
            amp_p=ns.aug_amp_p,
            amp_min=ns.aug_amp_min,
            amp_max=ns.aug_amp_max,
            awgn_p=ns.aug_awgn_p,
            awgn_snr_min=ns.aug_awgn_snr_min,
            awgn_snr_max=ns.aug_awgn_snr_max,
        ),
        optim=OptimizationArgs(
            device=ns.device,
            epochs=ns.epochs,
            batch_size=ns.batch_size,
            lr=ns.lr,
            weight_decay=ns.weight_decay,
            warmup_ratio=ns.warmup_ratio,
            grad_clip=ns.grad_clip,
            log_interval=ns.log_interval,
            save_dir=ns.save_dir,
            save_prefix=ns.save_prefix,
            resume_from=ns.resume_from,
        ),
        model=ModelArgs(
            patch_size=tuple(ns.patch_size),
            phase_mode=ns.phase_mode,
            embed_dim=ns.embed_dim,
            depth=ns.depth,
            num_heads=ns.num_heads,
            mlp_ratio=ns.mlp_ratio,
            same_frame_window=ns.same_frame_window,
            temporal_offsets=tuple(ns.temporal_offsets),
            temporal_spatial_window=ns.temporal_spatial_window,
            temporal_drift_h=ns.temporal_drift_h,
            temporal_drift_w=ns.temporal_drift_w,
            routing_topk_enable=ns.routing_topk_enable,
            routing_topk_fraction=ns.routing_topk_fraction,
            routing_topk_min=ns.routing_topk_min,
            routing_topk_max=ns.routing_topk_max,
            topk_per_head=ns.topk_per_head,
            posenc=ns.posenc,
            rope_base=ns.rope_base,
        ),
        logging=LoggingArgs(
            log_dir=ns.log_dir,
            use_wandb=ns.use_wandb,
            wandb_project=ns.wandb_project,
            wandb_entity=ns.wandb_entity,
            wandb_run_name=ns.wandb_run_name,
        ),
    )
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for LWM pretraining.

    Fix: ``--routing_topk_enable`` and ``--topk_per_head`` previously used
    ``action="store_true"`` together with ``default=True``, which made the
    flags no-ops (always True, impossible to disable).  They now use
    ``argparse.BooleanOptionalAction`` (Python >= 3.9), which keeps the old
    spelling and default while adding ``--no-...`` variants to turn them off.
    """
    parser = argparse.ArgumentParser(description="Pretrain LWM foundation model")

    # --- data ---
    parser.add_argument("--data_dir", type=Path, required=True)
    parser.add_argument("--keep_percentage", type=float, default=0.25)
    parser.add_argument("--normalize", type=str, default="global_rms", choices=["global_rms", "per_sample_rms", "none"])
    parser.add_argument("--max_time_steps", type=int, default=None)

    # --- masking ---
    parser.add_argument("--mask_ratio", type=float, default=0.60)
    parser.add_argument("--mask_mode", type=str, default="auto", choices=["auto", "random", "rect", "tube", "comb"])
    parser.add_argument("--mask_random_fraction", type=float, default=0.2)

    # --- curriculum ---
    parser.add_argument("--curriculum_strategy", type=str, default="mask", choices=["none", "mask"])
    parser.add_argument("--curriculum_warmup_epochs", type=int, default=4)
    parser.add_argument("--curriculum_min_mask", type=float, default=0.3)
    parser.add_argument("--curriculum_max_mask", type=float, default=0.75)

    # --- logging ---
    parser.add_argument("--log_dir", type=Path, default=Path("logs"))
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--wandb_entity", type=str, default=None)
    parser.add_argument("--wandb_run_name", type=str, default=None)

    # --- model ---
    parser.add_argument("--phase_mode", type=str, default="real_imag", choices=["real_imag", "mag_phase"])
    parser.add_argument("--patch_size", type=int, nargs=2, default=(1, 1))
    parser.add_argument("--embed_dim", type=int, default=32)
    parser.add_argument("--depth", type=int, default=12)
    parser.add_argument("--num_heads", type=int, default=8)
    parser.add_argument("--mlp_ratio", type=float, default=4.0)
    parser.add_argument("--same_frame_window", type=int, default=2)
    parser.add_argument("--temporal_offsets", type=int, nargs="*", default=[-4, -3, -2, -1, 1, 2, 3])
    parser.add_argument("--temporal_spatial_window", type=int, default=2)
    parser.add_argument("--temporal_drift_h", type=int, default=1)
    parser.add_argument("--temporal_drift_w", type=int, default=1)
    # BooleanOptionalAction so these default-True switches can be disabled
    # via --no-routing_topk_enable / --no-topk_per_head.
    parser.add_argument("--routing_topk_enable", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--routing_topk_fraction", type=float, default=0.2)
    parser.add_argument("--routing_topk_min", type=int, default=8)
    parser.add_argument("--routing_topk_max", type=int, default=32)
    parser.add_argument("--topk_per_head", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--posenc", type=str, default="learned", choices=["learned", "rope_sincos"])
    parser.add_argument("--rope_base", type=float, default=10000.0)

    # --- optimization ---
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--log_interval", type=int, default=1)
    parser.add_argument("--save_dir", type=Path, default=Path("models"))
    parser.add_argument("--save_prefix", type=str, default="lwm_pretrain")
    parser.add_argument("--resume_from", type=Path, default=None, help="Path to checkpoint to resume from")

    # --- augmentation ---
    parser.add_argument("--aug_phase_p", type=float, default=0.0)
    parser.add_argument("--aug_amp_p", type=float, default=0.0)
    parser.add_argument("--aug_amp_min", type=float, default=0.7)
    parser.add_argument("--aug_amp_max", type=float, default=1.3)
    parser.add_argument("--aug_awgn_p", type=float, default=0.0)
    parser.add_argument("--aug_awgn_snr_min", type=float, default=20.0)
    parser.add_argument("--aug_awgn_snr_max", type=float, default=30.0)

    return parser
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: parse arguments, build the trainer, run pretraining."""
    parsed = build_parser().parse_args(args=None if argv is None else list(argv))
    trainer = PretrainingTrainer(build_pretraining_args(parsed))
    trainer.train()
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
__all__ = [
|
| 672 |
+
"DataArgs",
|
| 673 |
+
"MaskArgs",
|
| 674 |
+
"CurriculumArgs",
|
| 675 |
+
"AugmentationArgs",
|
| 676 |
+
"OptimizationArgs",
|
| 677 |
+
"ModelArgs",
|
| 678 |
+
"PretrainingArgs",
|
| 679 |
+
"PretrainingDataset",
|
| 680 |
+
"PretrainingTrainer",
|
| 681 |
+
"build_pretraining_args",
|
| 682 |
+
"build_parser",
|
| 683 |
+
"main",
|
| 684 |
+
]
|
LWMTemporal/training/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Training utilities for LWM foundation models."""
|
LWMTemporal/utils/logging.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
LOG_FORMAT = "[%(asctime)s,%(msecs)03d %(levelname)s %(name)s line %(lineno)d %(process)d] %(message)s"
|
| 7 |
+
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def setup_logging(
    name: str = "LWMTemporal",
    log_dir: Path | None = None,
    level: int = logging.INFO,
) -> logging.Logger:
    """Configure and return a package logger (stream + optional file output).

    Re-invoking with the same name resets the handlers first, so repeated
    calls do not produce duplicated log lines.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Drop handlers from any previous call to avoid duplicate output.
    if logger.hasHandlers():
        logger.handlers.clear()

    fmt = logging.Formatter(LOG_FORMAT, DATE_FORMAT)

    # File handler first (when requested), then the console handler,
    # mirroring the original attachment order.
    attach = [logging.StreamHandler()]
    if log_dir is not None:
        log_dir = Path(log_dir)
        log_dir.mkdir(parents=True, exist_ok=True)
        attach.insert(0, logging.FileHandler(log_dir / f"{name}.log"))

    for handler in attach:
        handler.setFormatter(fmt)
        handler.setLevel(level)
        logger.addHandler(handler)

    return logger
|
| 39 |
+
|
MANIFEST.in
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
include README.md
|
| 2 |
+
include LICENSE
|
| 3 |
+
include requirements.txt
|
| 4 |
+
include LWMTemporal/models/config.json
|
| 5 |
+
recursive-exclude * __pycache__
|
| 6 |
+
recursive-exclude * *.py[co]
|
| 7 |
+
recursive-exclude * .DS_Store
|
| 8 |
+
exclude cache
|
| 9 |
+
exclude logs
|
| 10 |
+
exclude figs
|
| 11 |
+
exclude wandb
|
| 12 |
+
exclude checkpoints
|
| 13 |
+
exclude test.py
|
| 14 |
+
|
README.md
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LWMTemporal
|
| 2 |
+
|
| 3 |
+
Large Wireless Model (LWM) with sparse spatio-temporal attention for wireless channel prediction and forecasting.
|
| 4 |
+
|
| 5 |
+
This package provides a transformer-based model for spatio-temporal wireless channel prediction with support for both pretraining and fine-tuning tasks. It follows Hugging Face conventions for model checkpoints and configurations.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Installation
|
| 10 |
+
|
| 11 |
+
### From PyPI (Recommended)
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
pip install lwm-temporal
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
### From Source
|
| 18 |
+
|
| 19 |
+
```bash
|
| 20 |
+
git clone https://github.com/yourusername/lwm-temporal.git
|
| 21 |
+
cd lwm-temporal
|
| 22 |
+
pip install -e .
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### Optional Dependencies
|
| 26 |
+
|
| 27 |
+
For Weights & Biases logging:
|
| 28 |
+
```bash
|
| 29 |
+
pip install lwm-temporal[wandb]
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
For development:
|
| 33 |
+
```bash
|
| 34 |
+
pip install lwm-temporal[dev]
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
## Quick Start
|
| 40 |
+
|
| 41 |
+
### Python API
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
from pathlib import Path
|
| 45 |
+
from LWMTemporal import LWMBackbone, LWMConfig
|
| 46 |
+
|
| 47 |
+
# Load pretrained model
|
| 48 |
+
model = LWMBackbone.from_pretrained("checkpoints/m18_cp.pth")
|
| 49 |
+
model.eval()
|
| 50 |
+
|
| 51 |
+
# Use for inference - see examples/ for complete scripts
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
### Command Line Interface
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
# Run channel prediction inference
|
| 58 |
+
python -m LWMTemporal.cli.channel_prediction \
|
| 59 |
+
--data_path examples/data/city_8_tempe_3p5_20_32_32.p \
|
| 60 |
+
--pretrained checkpoints/m18_cp.pth \
|
| 61 |
+
--inference_only \
|
| 62 |
+
--device cpu
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
See `examples/` directory for more detailed usage examples.
|
| 66 |
+
|
| 67 |
+
---
|
| 68 |
+
|
| 69 |
+
## 1. Environment Setup
|
| 70 |
+
|
| 71 |
+
### Requirements
|
| 72 |
+
|
| 73 |
+
- Python >= 3.9
|
| 74 |
+
- PyTorch >= 2.0.0
|
| 75 |
+
- NumPy >= 1.21.0
|
| 76 |
+
- Matplotlib >= 3.5.0
|
| 77 |
+
|
| 78 |
+
Verify that your PyTorch build matches your hardware (CPU vs CUDA). Mixed-precision (AMP) is optional; on CPU it will automatically disable itself.
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## 2. Repository Layout
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
LWMTemporal/ # Main package
|
| 86 |
+
cli/ # Command-line entry points
|
| 87 |
+
data/ # Dataset loaders & preprocessing utilities
|
| 88 |
+
models/ # LWM model + backbone + configs
|
| 89 |
+
tasks/ # High-level training/inference orchestration
|
| 90 |
+
utils/ # Logging and helper utilities
|
| 91 |
+
examples/ # Example scripts and sample data
|
| 92 |
+
data/ # Example datasets
|
| 93 |
+
checkpoints/ # Pretrained model checkpoints
|
| 94 |
+
cache/ # Optional dataset cache (auto-created)
|
| 95 |
+
figs/predictions/ # Visualization output (auto-created)
|
| 96 |
+
logs/ # Training logs (auto-created)
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
---
|
| 100 |
+
|
| 101 |
+
## 3. Checkpoint Format (Hugging Face Compatible)
|
| 102 |
+
|
| 103 |
+
The code supports two checkpoint formats:
|
| 104 |
+
|
| 105 |
+
**Format 1: Directory (Hugging Face style)**
|
| 106 |
+
```
|
| 107 |
+
checkpoints/my_model/
|
| 108 |
+
config.json # Model configuration
|
| 109 |
+
pytorch_model.bin # Model weights
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
**Format 2: Single file**
|
| 113 |
+
```
|
| 114 |
+
checkpoints/my_model.pth # Contains both weights and optional config
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
The package automatically detects and loads either format.
|
| 118 |
+
|
| 119 |
+
> **Tip:** If you only have a single file (e.g. `model_best.pth`), move it to a directory and rename to `pytorch_model.bin`. Copy or recreate a matching `config.json`. The loader infers `max_seq_len` when it sees a longer positional embedding in the checkpoint, so older weights continue to work.
|
| 120 |
+
|
| 121 |
+
The directory can be uploaded to Hugging Face Hub as-is and loaded via `AutoModel.from_pretrained` if you create a thin wrapper.
|
| 122 |
+
|
| 123 |
+
---
|
| 124 |
+
|
| 125 |
+
## 4. Dataset Preparation
|
| 126 |
+
|
| 127 |
+
- The pipeline consumes pickle (`.p`) payloads with a `channel` key (complex tensor) and optional metadata (`pos`, `dt`).
|
| 128 |
+
- `AngleDelaySequenceDataset` normalizes, truncates, and caches angle-delay representations on demand.
|
| 129 |
+
- Configure preprocessing through `DatasetArgs`:
|
| 130 |
+
- `keep_percentage` — fraction of strongest taps to keep.
|
| 131 |
+
- `normalize` β `global_rms`, `per_sample_rms`, or `none`.
|
| 132 |
+
- `cache_dir`, `use_cache`, `overwrite_cache` β caching behavior.
|
| 133 |
+
- `snr_db`, `noise_seed` β synthetic AWGN injection.
|
| 134 |
+
- `max_time_steps` β optional temporal truncation.
|
| 135 |
+
|
| 136 |
+
Cached tensors are stored under `cache/adseq_<stem>_keepXX_<normalize>.pt`.
|
| 137 |
+
|
| 138 |
+
---
|
| 139 |
+
|
| 140 |
+
## 5. Command-Line Usage
|
| 141 |
+
|
| 142 |
+
The CLI mirrors the Hugging Face workflow (`python -m package.cli ...`).
|
| 143 |
+
|
| 144 |
+
### 5.1 Inference / Evaluation
|
| 145 |
+
|
| 146 |
+
```bash
|
| 147 |
+
python -m LWMTemporal.cli.channel_prediction \
|
| 148 |
+
--data_path examples/data/parow.p \
|
| 149 |
+
--pretrained checkpoints/m18_cp.pth \
|
| 150 |
+
--inference_only \
|
| 151 |
+
--inference_split val \
|
| 152 |
+
--Tpast 10 \
|
| 153 |
+
--horizon 1
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
- `--Tpast` / `--horizon` define the autoregressive roll-out window.
|
| 157 |
+
- `--inference_split` selects which subset to score (`train`, `val`, `all`).
|
| 158 |
+
- Visualizations are written to `figs/predictions/`.
|
| 159 |
+
|
| 160 |
+
### 5.2 Training / Fine-Tuning
|
| 161 |
+
|
| 162 |
+
Remove `--inference_only` to launch training:
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
python -m LWMTemporal.cli.channel_prediction \
|
| 166 |
+
--data_path examples/data/parow.p \
|
| 167 |
+
--save_dir models/finetune_run \
|
| 168 |
+
--epochs 5 \
|
| 169 |
+
--batch_size 8 \
|
| 170 |
+
--lr 3e-4 \
|
| 171 |
+
--Tpast 10 \
|
| 172 |
+
--horizon 2
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
Notable flags:
|
| 176 |
+
|
| 177 |
+
- `--pretrained` — resume from existing weights.
|
| 178 |
+
- `--train_head_only` β freeze encoder, train output head.
|
| 179 |
+
- `--finetune_last_n` β unfreeze last N transformer blocks.
|
| 180 |
+
- `--global_cls` β enable CLS token for global prediction heads.
|
| 181 |
+
- `--routing_topk_enable`, `--topk_per_head`, etc. β sparse attention controls.
|
| 182 |
+
- `--temporal_offsets` — defaults to `[-4, -3, -2, -1]` so the attention only reaches the previous four frames.
|
| 183 |
+
- `--use_wandb` (with `--wandb_project`, `--wandb_run_name`, `--wandb_entity`) β stream training/eval metrics to Weights & Biases.
|
| 184 |
+
|
| 185 |
+
Checkpoints are saved as `save_dir/<prefix>_epochXX.pth` together with optimizer state. Use `ChannelPredictionTrainer._save_checkpoint` for custom logic.
|
| 186 |
+
|
| 187 |
+
---
|
| 188 |
+
|
| 189 |
+
## 6. Python API Usage
|
| 190 |
+
|
| 191 |
+
Construct arguments programmatically and drive training/evaluation via `ChannelPredictionTrainer`:
|
| 192 |
+
|
| 193 |
+
```python
|
| 194 |
+
from pathlib import Path
|
| 195 |
+
from LWMTemporal.tasks.channel_prediction import (
|
| 196 |
+
ChannelPredictionArgs,
|
| 197 |
+
DatasetArgs,
|
| 198 |
+
ModelArgs,
|
| 199 |
+
TrainingArgs,
|
| 200 |
+
PredictionArgs,
|
| 201 |
+
ChannelPredictionTrainer,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
args = ChannelPredictionArgs(
|
| 205 |
+
dataset=DatasetArgs(
|
| 206 |
+
data_path=Path("examples/data/parow.p"),
|
| 207 |
+
keep_percentage=0.25,
|
| 208 |
+
train_limit=500,
|
| 209 |
+
val_limit=1000,
|
| 210 |
+
),
|
| 211 |
+
model=ModelArgs(
|
| 212 |
+
patch_size=(1, 1),
|
| 213 |
+
phase_mode="real_imag",
|
| 214 |
+
pretrained=Path("checkpoints/m18_cp.pth"),
|
| 215 |
+
),
|
| 216 |
+
training=TrainingArgs(
|
| 217 |
+
inference_only=True,
|
| 218 |
+
device="cpu",
|
| 219 |
+
batch_size=4,
|
| 220 |
+
),
|
| 221 |
+
prediction=PredictionArgs(Tpast=10, horizon=1),
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
trainer = ChannelPredictionTrainer(args)
|
| 225 |
+
trainer.train() # runs evaluate() because inference_only=True
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
From here you can:
|
| 229 |
+
|
| 230 |
+
- Access `trainer.model` (an `LWMBackbone`) for custom forward passes.
|
| 231 |
+
- Call `trainer.data.train_loader(...)` / `val_loader(...)` for raw dataloaders.
|
| 232 |
+
- Use `trainer.engine.autoregressive_rollout(...)` to obtain `(pred_tokens, target_tokens, mask)` tensors for downstream metrics.
|
| 233 |
+
- Generate visualizations with `trainer.viz.save(...)`.
|
| 234 |
+
|
| 235 |
+
---
|
| 236 |
+
|
| 237 |
+
## 7. Working With `LWMBackbone`
|
| 238 |
+
|
| 239 |
+
- Instantiate from scratch: `LWMBackbone(LWMConfig(...))`.
|
| 240 |
+
- Load checkpoints:
|
| 241 |
+
```python
|
| 242 |
+
from LWMTemporal.models.lwm import LWMBackbone, LWMConfig
|
| 243 |
+
|
| 244 |
+
cfg = LWMConfig(patch_size=(1, 1), embed_dim=32, max_seq_len=2816)
|
| 245 |
+
model = LWMBackbone.from_pretrained("checkpoints/m18_cp.pth", config=cfg)
|
| 246 |
+
```
|
| 247 |
+
- Save checkpoints: `model.save_pretrained("path/to/output")`.
|
| 248 |
+
- The loader automatically adjusts `config.max_seq_len` when the checkpointβs positional embedding is longer than the provided config.
|
| 249 |
+
|
| 250 |
+
`LWMModel.forward(seq, mask=None, return_cls=False)` accepts complex tensors shaped `(B, T, N, M)` and returns reconstruction tokens along with an optional CLS embedding when enabled.
|
| 251 |
+
|
| 252 |
+
---
|
| 253 |
+
|
| 254 |
+
## 8. Visualization & Metrics
|
| 255 |
+
|
| 256 |
+
- `PredictionVisualizer` renders magnitude plots (`|H|`) for predicted vs. ground-truth angle-delay grids.
|
| 257 |
+
- Metrics:
|
| 258 |
+
- `masked_nmse_loss` / `compute_nmse` β Normalized MSE over valid tokens.
|
| 259 |
+
- `masked_mse_loss` β standard MSE with masking support.
|
| 260 |
+
|
| 261 |
+
Configure masking via token `mask` tensors (boolean) where `True` indicates dropped tokens.
|
| 262 |
+
|
| 263 |
+
---
|
| 264 |
+
|
| 265 |
+
## 9. Advanced Configuration
|
| 266 |
+
|
| 267 |
+
- **Sparse Attention Windows:** Control spatial/temporal neighborhoods via `same_frame_window`, `temporal_offsets`, `temporal_spatial_window`, `temporal_drift_*`, and dilation parameters. The default `temporal_offsets = (-4, -3, -2, -1)` limits attention to the previous four frames.
|
| 268 |
+
- **Routing & Top-k Pruning:** Enable dynamic neighbour pruning with `routing_topk_enable`, `routing_topk_fraction`, `routing_topk_min/max`, or fallback to static `topk_neighbors`.
|
| 269 |
+
- **Positional Encoding:** `posenc` supports `learned` or `rope_sincos`. Additional `rope_base_*` parameters adjust RoPE scaling.
|
| 270 |
+
- **CLS Token:** Toggle `global_cls`; the autoregressive rollout handles CLS automatically when present.
|
| 271 |
+
- **Detokenization:** `AutoregressiveEngine.detokenize` converts predicted tokens back to complex-valued channel coefficients.
|
| 272 |
+
|
| 273 |
+
---
|
| 274 |
+
|
| 275 |
+
## 10. Troubleshooting
|
| 276 |
+
|
| 277 |
+
- **Circular Imports:** The project avoids cross-imports by keeping tokenizers in `models`. Ensure you are on the latest code if you encounter import errors.
|
| 278 |
+
- **Checkpoint Shape Mismatch:** Confirm `patch_size`, `phase_mode`, and positional embedding lengths match between config and weights.
|
| 279 |
+
- **Neighbor Padding Errors:** Patched `NeighborIndexer` pads ragged neighbour lists with `-1`, so any older ValueError is resolved once you update to the current code.
|
| 280 |
+
- **AMP Warnings:** On CPU you may see `GradScaler` warnings; they are benign because AMP disables itself.
|
| 281 |
+
- **Data Shape Mismatch:** Sequences must have consistent `(T, H, W)` dimensions within each batch. The trainer raises a descriptive error otherwise.
|
| 282 |
+
|
| 283 |
+
---
|
| 284 |
+
|
| 285 |
+
## 11. Hugging Face Integration
|
| 286 |
+
|
| 287 |
+
- Because checkpoints follow the standard `config.json` + `pytorch_model.bin` scheme, you can do:
|
| 288 |
+
```python
|
| 289 |
+
from transformers import AutoConfig, AutoModel
|
| 290 |
+
|
| 291 |
+
cfg = AutoConfig.from_pretrained("path/to/model_best")
|
| 292 |
+
model = AutoModel.from_pretrained("path/to/model_best", config=cfg)
|
| 293 |
+
```
|
| 294 |
+
- Wrap `LWMBackbone` in a custom `transformers.PreTrainedModel` subclass if you need full pipeline compatibility.
|
| 295 |
+
- Use the same directory structure when publishing to the Hugging Face Hub.
|
| 296 |
+
|
| 297 |
+
---
|
| 298 |
+
|
| 299 |
+
## 12. Reproducibility Checklist
|
| 300 |
+
|
| 301 |
+
- Seed control: `DatasetArgs.seed` (for train/val splits); manual seeding via `torch.manual_seed` and `np.random.seed` happens inside the trainer.
|
| 302 |
+
- Log frequency: `TrainingArgs.log_interval`.
|
| 303 |
+
- Gradient clipping: `TrainingArgs.grad_clip` (defaults to 1.0).
|
| 304 |
+
- Warm-up / Scheduler: cosine decay after a configurable warm-up fraction (`TrainingArgs.warmup_ratio`).
|
| 305 |
+
|
| 306 |
+
---
|
| 307 |
+
|
| 308 |
+
## 13. Getting Help
|
| 309 |
+
|
| 310 |
+
- Issues related to data format, training instabilities, or new features can be logged on your preferred tracking system or discussed with collaborators.
|
| 311 |
+
- For general transformer best practices, refer to the Hugging Face BERT documentation and friends ([link](https://huggingface.co/docs/transformers/en/model_doc/bert?usage=Pipeline)). The workflow above mirrors that style for LWMTemporal.
|
| 312 |
+
|
| 313 |
+
Happy experimenting!
|
| 314 |
+
|
| 315 |
+
## Citation
|
| 316 |
+
|
| 317 |
+
If you use LWMTemporal in your research, please cite:
|
| 318 |
+
|
| 319 |
+
```bibtex
|
| 320 |
+
@article{lwmtemporal2025,
|
| 321 |
+
title={Large Wireless Model for Spatio-Temporal Channel Prediction},
|
| 322 |
+
author={Alikhani, Sadjad and others},
|
| 323 |
+
journal={arXiv preprint arXiv:XXXX.XXXXX},
|
| 324 |
+
year={2025}
|
| 325 |
+
}
|
| 326 |
+
```
|
| 327 |
+
|
| 328 |
+
## License
|
| 329 |
+
|
| 330 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 331 |
+
|
| 332 |
+
## Acknowledgments
|
| 333 |
+
|
| 334 |
+
- Built with PyTorch
|
| 335 |
+
- Inspired by Vision Transformer architectures
|
| 336 |
+
- Supports Hugging Face model hub integration
|
| 337 |
+
|
| 338 |
+
## Contact
|
| 339 |
+
|
| 340 |
+
For questions or issues, please:
|
| 341 |
+
- Open an issue on GitHub
|
| 342 |
+
- Contact: sadjad.alikhani@asu.edu
|
| 343 |
+
|
| 344 |
+
## Contributing
|
| 345 |
+
|
| 346 |
+
Contributions are welcome! Please:
|
| 347 |
+
1. Fork the repository
|
| 348 |
+
2. Create a feature branch
|
| 349 |
+
3. Make your changes
|
| 350 |
+
4. Submit a pull request
|
| 351 |
+
|
| 352 |
+
For major changes, please open an issue first to discuss the proposed changes.
|
| 353 |
+
|
checkpoints/README.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Checkpoints
|
| 2 |
+
|
| 3 |
+
This directory contains pretrained model checkpoints.
|
| 4 |
+
|
| 5 |
+
## Available Checkpoints
|
| 6 |
+
|
| 7 |
+
### `m18_cp.pth`
|
| 8 |
+
- **Task**: Channel prediction (fine-tuned)
|
| 9 |
+
- **Architecture**: 12-layer transformer with 32-dim embeddings
|
| 10 |
+
- **Temporal Attention**: Causal (attends to past 7 frames)
|
| 11 |
+
- **Performance**: ~-20 dB NMSE on validation set
|
| 12 |
+
|
| 13 |
+
### `pytorch_model.bin`
|
| 14 |
+
- **Task**: Pretrained backbone
|
| 15 |
+
- **Architecture**: Same as above
|
| 16 |
+
- **Temporal Attention**: Bidirectional
|
| 17 |
+
|
| 18 |
+
## Loading Checkpoints
|
| 19 |
+
|
| 20 |
+
### Python API
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from LWMTemporal import LWMBackbone, LWMConfig
|
| 25 |
+
|
| 26 |
+
# Load with default config from checkpoint
|
| 27 |
+
model = LWMBackbone.from_pretrained("checkpoints/m18_cp.pth")
|
| 28 |
+
|
| 29 |
+
# Or override config for fine-tuning
|
| 30 |
+
cfg = LWMConfig(
|
| 31 |
+
temporal_offsets=(-1, -2, -3, -4), # Override for different task
|
| 32 |
+
)
|
| 33 |
+
model = LWMBackbone.from_pretrained("checkpoints/m18_cp.pth", config=cfg)
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### CLI
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
python -m LWMTemporal.cli.channel_prediction \
|
| 40 |
+
--pretrained checkpoints/m18_cp.pth \
|
| 41 |
+
--data_path examples/data/parow.p \
|
| 42 |
+
--inference_only
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Hosting on Hugging Face Hub (Recommended)
|
| 46 |
+
|
| 47 |
+
For production use, upload checkpoints to Hugging Face Hub:
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
huggingface-cli login
|
| 51 |
+
huggingface-cli upload your-username/lwm-temporal checkpoints/m18_cp.pth
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
Then load directly from the hub:
|
| 55 |
+
```python
|
| 56 |
+
model = LWMBackbone.from_pretrained("your-username/lwm-temporal")
|
| 57 |
+
```
|
| 58 |
+
|
checkpoints/config.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"patch_size": [1, 1],
|
| 3 |
+
"phase_mode": "real_imag",
|
| 4 |
+
"embed_dim": 32,
|
| 5 |
+
"depth": 12,
|
| 6 |
+
"num_heads": 8,
|
| 7 |
+
"mlp_ratio": 4.0,
|
| 8 |
+
"same_frame_window": 2,
|
| 9 |
+
"same_frame_window_h": null,
|
| 10 |
+
"same_frame_window_w": null,
|
| 11 |
+
"same_frame_dilation_h": 1,
|
| 12 |
+
"same_frame_dilation_w": 1,
|
| 13 |
+
"temporal_offsets": [-4, -3, -2, -1, 1, 2, 3],
|
| 14 |
+
"temporal_spatial_window": 2,
|
| 15 |
+
"temporal_spatial_window_h": null,
|
| 16 |
+
"temporal_spatial_window_w": null,
|
| 17 |
+
"temporal_spatial_dilation_h": 1,
|
| 18 |
+
"temporal_spatial_dilation_w": 1,
|
| 19 |
+
"temporal_drift_h": 1,
|
| 20 |
+
"temporal_drift_w": 1,
|
| 21 |
+
"spatial_only": false,
|
| 22 |
+
"routing_topk_enable": true,
|
| 23 |
+
"routing_topk_fraction": 0.2,
|
| 24 |
+
"routing_topk_min": 8,
|
| 25 |
+
"routing_topk_max": 32,
|
| 26 |
+
"routing_topk_per_head": true,
|
| 27 |
+
"topk_neighbors": null,
|
| 28 |
+
"topk_per_head": true,
|
| 29 |
+
"global_cls": false,
|
| 30 |
+
"posenc": "learned",
|
| 31 |
+
"rope_base": 10000.0,
|
| 32 |
+
"rope_mode": "flat",
|
| 33 |
+
"rope_base_t": null,
|
| 34 |
+
"rope_base_h": null,
|
| 35 |
+
"rope_base_w": null,
|
| 36 |
+
"max_seq_len": null
|
| 37 |
+
}
|
checkpoints/hist/config.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"patch_size": [1, 1],
|
| 3 |
+
"phase_mode": "real_imag",
|
| 4 |
+
"embed_dim": 32,
|
| 5 |
+
"depth": 12,
|
| 6 |
+
"num_heads": 8,
|
| 7 |
+
"mlp_ratio": 4.0,
|
| 8 |
+
"same_frame_window": 2,
|
| 9 |
+
"same_frame_window_h": null,
|
| 10 |
+
"same_frame_window_w": null,
|
| 11 |
+
"same_frame_dilation_h": 1,
|
| 12 |
+
"same_frame_dilation_w": 1,
|
| 13 |
+
"temporal_offsets": [-4, -3, -2, -1, 1, 2, 3],
|
| 14 |
+
"temporal_spatial_window": 2,
|
| 15 |
+
"temporal_spatial_window_h": null,
|
| 16 |
+
"temporal_spatial_window_w": null,
|
| 17 |
+
"temporal_spatial_dilation_h": 1,
|
| 18 |
+
"temporal_spatial_dilation_w": 1,
|
| 19 |
+
"temporal_drift_h": 1,
|
| 20 |
+
"temporal_drift_w": 1,
|
| 21 |
+
"spatial_only": false,
|
| 22 |
+
"routing_topk_enable": true,
|
| 23 |
+
"routing_topk_fraction": 0.2,
|
| 24 |
+
"routing_topk_min": 8,
|
| 25 |
+
"routing_topk_max": 32,
|
| 26 |
+
"routing_topk_per_head": true,
|
| 27 |
+
"topk_neighbors": null,
|
| 28 |
+
"topk_per_head": true,
|
| 29 |
+
"global_cls": false,
|
| 30 |
+
"posenc": "learned",
|
| 31 |
+
"rope_base": 10000.0,
|
| 32 |
+
"rope_mode": "flat",
|
| 33 |
+
"rope_base_t": null,
|
| 34 |
+
"rope_base_h": null,
|
| 35 |
+
"rope_base_w": null,
|
| 36 |
+
"max_seq_len": null
|
| 37 |
+
}
|
examples/README.md
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LWMTemporal Examples
|
| 2 |
+
|
| 3 |
+
This directory contains example scripts demonstrating how to use the LWMTemporal package.
|
| 4 |
+
|
| 5 |
+
## Quick Start Examples
|
| 6 |
+
|
| 7 |
+
### 1. Masked Reconstruction (`example_reconstruction.py`)
|
| 8 |
+
|
| 9 |
+
Demonstrates how to:
|
| 10 |
+
- Load wireless channel data
|
| 11 |
+
- Tokenize complex channels
|
| 12 |
+
- Mask random positions
|
| 13 |
+
- Reconstruct using the pretrained model
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
python examples/example_reconstruction.py
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
### 2. Channel Prediction Inference (`inference_channel_prediction.py`)
|
| 20 |
+
|
| 21 |
+
Run inference with a fine-tuned channel prediction model:
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
python examples/inference_channel_prediction.py
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
Expected output: Per-step NMSE around -20 dB
|
| 28 |
+
|
| 29 |
+
### 3. Train Channel Prediction (`train_channel_prediction.py`)
|
| 30 |
+
|
| 31 |
+
Fine-tune the model for channel prediction:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
python examples/train_channel_prediction.py
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
This will:
|
| 38 |
+
- Load pretrained weights
|
| 39 |
+
- Fine-tune on your dataset
|
| 40 |
+
- Save checkpoints to `models/`
|
| 41 |
+
- Generate visualizations in `figs/predictions/`
|
| 42 |
+
|
| 43 |
+
## Using the CLI
|
| 44 |
+
|
| 45 |
+
The package also provides command-line interfaces:
|
| 46 |
+
|
| 47 |
+
### Channel Prediction
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
python -m LWMTemporal.cli.channel_prediction \
|
| 51 |
+
--data_path examples/data/city_8_tempe_3p5_20_32_32.p \
|
| 52 |
+
--pretrained checkpoints/m18_cp.pth \
|
| 53 |
+
--inference_only \
|
| 54 |
+
--val_limit 100 \
|
| 55 |
+
--device cpu
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### Pretraining
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python -m LWMTemporal.cli.pretrain \
|
| 62 |
+
--data_dir examples/data/ \
|
| 63 |
+
--save_prefix models/pretrained \
|
| 64 |
+
--epochs 100 \
|
| 65 |
+
--batch_size 32 \
|
| 66 |
+
--device cuda
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## Data Format
|
| 70 |
+
|
| 71 |
+
Example data files are in `examples/data/`. See `examples/data/README.md` for details on the expected format.
|
| 72 |
+
|
| 73 |
+
## Checkpoints
|
| 74 |
+
|
| 75 |
+
Pretrained checkpoints are in `checkpoints/`. See `checkpoints/README.md` for available models and loading instructions.
|
| 76 |
+
|
examples/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LWMTemporal usage examples."""
|
| 2 |
+
|
examples/example_reconstruction.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
|
| 9 |
+
from LWMTemporal.models.lwm import (
|
| 10 |
+
LWMBackbone,
|
| 11 |
+
LWMConfig,
|
| 12 |
+
ComplexPatchTokenizer,
|
| 13 |
+
masked_nmse_loss,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
# ----- 1. Load one sequence (complex tensor) -----
|
| 17 |
+
data_cfg = AngleDelayDatasetConfig(raw_path=Path("examples/data/parow.p"))
|
| 18 |
+
dataset = AngleDelaySequenceDataset(data_cfg)
|
| 19 |
+
|
| 20 |
+
sequence = dataset[0]["sequence"].unsqueeze(0) # (1, T, N, M)
|
| 21 |
+
sequence = sequence[:, :11] # keep only the first 11 time steps
|
| 22 |
+
print("Sequence shape:", sequence.shape) # expect (1, 11, 32, 8)
|
| 23 |
+
|
| 24 |
+
# ----- 2. Tokenise and select tokens to mask -----
|
| 25 |
+
tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
|
| 26 |
+
tokens, base_mask = tokenizer(sequence, patch_size=(1, 1)) # tokens: (B, S, D)
|
| 27 |
+
|
| 28 |
+
B, S, D = tokens.shape
|
| 29 |
+
mask_ratio = 0.60 # choose the fraction to hide
|
| 30 |
+
mask = base_mask.clone()
|
| 31 |
+
|
| 32 |
+
# randomly choose the positions that will be hidden
|
| 33 |
+
for b in range(B):
|
| 34 |
+
num_mask = int(mask_ratio * S)
|
| 35 |
+
masked_positions = torch.randperm(S)[:num_mask]
|
| 36 |
+
mask[b, masked_positions] = True
|
| 37 |
+
|
| 38 |
+
# create the corrupted input by zeroing the masked tokens
|
| 39 |
+
corrupted_tokens = tokens.clone()
|
| 40 |
+
corrupted_tokens[mask] = 0.0
|
| 41 |
+
|
| 42 |
+
# ----- 3. Load the pretrained backbone -----
|
| 43 |
+
# Need max_seq_len >= S (here 11 * 32 * 8 = 2816)
|
| 44 |
+
cfg = LWMConfig(
|
| 45 |
+
patch_size=(1, 1),
|
| 46 |
+
phase_mode="real_imag",
|
| 47 |
+
embed_dim=32,
|
| 48 |
+
depth=12,
|
| 49 |
+
num_heads=8,
|
| 50 |
+
mlp_ratio=4.0,
|
| 51 |
+
same_frame_window=2,
|
| 52 |
+
temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
|
| 53 |
+
temporal_spatial_window=2,
|
| 54 |
+
temporal_drift_h=1,
|
| 55 |
+
temporal_drift_w=1,
|
| 56 |
+
routing_topk_enable=True,
|
| 57 |
+
topk_per_head=True,
|
| 58 |
+
max_seq_len=2816, # 2816
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
backbone = LWMBackbone.from_pretrained(Path("checkpoints/m18_cp.pth"), config=cfg)
|
| 62 |
+
backbone.eval()
|
| 63 |
+
|
| 64 |
+
# ---- 4. Run reconstruction and compute NMSE on the masked positions -----
|
| 65 |
+
with torch.no_grad():
|
| 66 |
+
# compute H, W from the sequence (N and M dimensions)
|
| 67 |
+
T = sequence.size(1)
|
| 68 |
+
H = sequence.size(2)
|
| 69 |
+
W = sequence.size(3)
|
| 70 |
+
|
| 71 |
+
outputs = backbone.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
|
| 72 |
+
reconstructed = outputs["reconstruction"]
|
| 73 |
+
|
| 74 |
+
nmse = masked_nmse_loss(reconstructed, tokens, mask)
|
| 75 |
+
nmse_db = 10 * torch.log10(nmse)
|
| 76 |
+
|
| 77 |
+
print(f"Masked {mask_ratio*100:.1f}% of tokens ({mask.sum().item()} / {S})")
|
| 78 |
+
print(f"NMSE (linear): {nmse.item():.6f}")
|
| 79 |
+
print(f"NMSE (dB): {nmse_db.item():.2f} dB")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# import torch
|
| 86 |
+
# from pathlib import Path
|
| 87 |
+
# from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
|
| 88 |
+
# from LWMTemporal.models.lwm import (
|
| 89 |
+
# LWMBackbone,
|
| 90 |
+
# LWMConfig,
|
| 91 |
+
# ComplexPatchTokenizer,
|
| 92 |
+
# masked_nmse_loss,
|
| 93 |
+
# )
|
| 94 |
+
|
| 95 |
+
# # --- 1. Load one sample from the dataset and keep the first 11 frames ---
|
| 96 |
+
# cfg = AngleDelayDatasetConfig(raw_path=Path("LWMTemporal/data/parow.p"))
|
| 97 |
+
# dataset = AngleDelaySequenceDataset(cfg)
|
| 98 |
+
# sequence = dataset[0]["sequence"].unsqueeze(0)[:, :11] # (1, 11, 32, 8)
|
| 99 |
+
|
| 100 |
+
# # --- 2. Tokenise and randomly mask 40% of the tokens ---
|
| 101 |
+
# tokenizer = ComplexPatchTokenizer(phase_mode="real_imag")
|
| 102 |
+
# tokens, base_mask = tokenizer(sequence, patch_size=(1, 1))
|
| 103 |
+
# mask = base_mask.clone()
|
| 104 |
+
|
| 105 |
+
# B, S, _ = tokens.shape
|
| 106 |
+
# mask_fraction = 0.40
|
| 107 |
+
|
| 108 |
+
# for b in range(B):
|
| 109 |
+
# num_mask = int(mask_fraction * S)
|
| 110 |
+
# masked_positions = torch.randperm(S)[:num_mask]
|
| 111 |
+
# mask[b, masked_positions] = True
|
| 112 |
+
|
| 113 |
+
# corrupted_tokens = tokens.clone()
|
| 114 |
+
# corrupted_tokens[mask] = 0.0
|
| 115 |
+
|
| 116 |
+
# T = sequence.size(1)
|
| 117 |
+
# H = sequence.size(2)
|
| 118 |
+
# W = sequence.size(3)
|
| 119 |
+
|
| 120 |
+
# # --- 3. Helper to run a model and report NMSE ---
|
| 121 |
+
# def run_model(model: LWMBackbone, label: str) -> None:
|
| 122 |
+
# model.eval()
|
| 123 |
+
# with torch.no_grad():
|
| 124 |
+
# outputs = model.forward_tokens(corrupted_tokens, mask, T, H, W, return_cls=False)
|
| 125 |
+
# reconstructed = outputs["reconstruction"]
|
| 126 |
+
# nmse = masked_nmse_loss(reconstructed, tokens, mask)
|
| 127 |
+
# nmse_db = 10 * torch.log10(nmse)
|
| 128 |
+
# print(f"{label:>12}: NMSE = {nmse.item():.6f} ({nmse_db.item():.2f} dB)")
|
| 129 |
+
|
| 130 |
+
# # --- 4. Random-weights model ---
|
| 131 |
+
# cfg_random = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3))
|
| 132 |
+
# model_random = LWMBackbone(cfg_random)
|
| 133 |
+
# run_model(model_random, "random init")
|
| 134 |
+
|
| 135 |
+
# # --- 5. Pretrained checkpoint ---
|
| 136 |
+
# cfg_pretrained = LWMConfig(max_seq_len=11 * sequence.size(2) * sequence.size(3))
|
| 137 |
+
# model_ckpt = LWMBackbone.from_pretrained(Path("LWMTemporal/models"), config=cfg_pretrained)
|
| 138 |
+
# run_model(model_ckpt, "checkpoint")
|
examples/inference_channel_prediction.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""Example: Run inference with a trained channel prediction model."""
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 7 |
+
|
| 8 |
+
from LWMTemporal.tasks.channel_prediction import (
|
| 9 |
+
ChannelPredictionArgs,
|
| 10 |
+
ChannelPredictionTrainer,
|
| 11 |
+
DatasetArgs,
|
| 12 |
+
ModelArgs,
|
| 13 |
+
TrainingArgs,
|
| 14 |
+
PredictionArgs,
|
| 15 |
+
)
|
| 16 |
+
from LWMTemporal.utils.logging import setup_logging
|
| 17 |
+
|
| 18 |
+
# Setup logging
|
| 19 |
+
logger = setup_logging("channel_prediction_inference", log_dir=Path("logs"))
|
| 20 |
+
|
| 21 |
+
# Configure dataset
|
| 22 |
+
dataset_args = DatasetArgs(
|
| 23 |
+
data_path=Path("examples/data/city_8_tempe_3p5_20_32_32.p"),
|
| 24 |
+
keep_percentage=0.25,
|
| 25 |
+
normalize="global_rms",
|
| 26 |
+
seed=0,
|
| 27 |
+
val_limit=100,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# Configure model
|
| 31 |
+
model_args = ModelArgs(
|
| 32 |
+
patch_size=(1, 1),
|
| 33 |
+
phase_mode="real_imag",
|
| 34 |
+
embed_dim=32,
|
| 35 |
+
depth=12,
|
| 36 |
+
num_heads=8,
|
| 37 |
+
mlp_ratio=4.0,
|
| 38 |
+
same_frame_window=2,
|
| 39 |
+
temporal_offsets=(-1, -2, -3, -4, -5, -6, -7), # Causal attention
|
| 40 |
+
temporal_spatial_window=2,
|
| 41 |
+
temporal_drift_h=1,
|
| 42 |
+
temporal_drift_w=1,
|
| 43 |
+
routing_topk_enable=True,
|
| 44 |
+
routing_topk_fraction=0.2,
|
| 45 |
+
routing_topk_max=32,
|
| 46 |
+
pretrained=Path("checkpoints/m18_cp.pth"),
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Configure training (inference only)
|
| 50 |
+
training_args = TrainingArgs(
|
| 51 |
+
device="cpu",
|
| 52 |
+
batch_size=2,
|
| 53 |
+
inference_only=True,
|
| 54 |
+
inference_split="val",
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Configure prediction
|
| 58 |
+
prediction_args = PredictionArgs(
|
| 59 |
+
Tpast=10,
|
| 60 |
+
horizon=1,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Build full config
|
| 64 |
+
args = ChannelPredictionArgs(
|
| 65 |
+
dataset=dataset_args,
|
| 66 |
+
model=model_args,
|
| 67 |
+
training=training_args,
|
| 68 |
+
prediction=prediction_args,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Run inference
|
| 72 |
+
trainer = ChannelPredictionTrainer(args, logger=logger)
|
| 73 |
+
trainer.train() # train() handles inference_only mode
|
| 74 |
+
|
| 75 |
+
logger.info("Inference complete!")
|
| 76 |
+
|
examples/train_channel_prediction.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""Example: Train a channel prediction model."""
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from LWMTemporal.tasks.channel_prediction import (
|
| 10 |
+
ChannelPredictionArgs,
|
| 11 |
+
ChannelPredictionTrainer,
|
| 12 |
+
DatasetArgs,
|
| 13 |
+
ModelArgs,
|
| 14 |
+
TrainingArgs,
|
| 15 |
+
PredictionArgs,
|
| 16 |
+
)
|
| 17 |
+
from LWMTemporal.utils.logging import setup_logging
|
| 18 |
+
|
| 19 |
+
# Setup logging
|
| 20 |
+
logger = setup_logging("channel_prediction_example", log_dir=Path("logs"))
|
| 21 |
+
|
| 22 |
+
# Configure dataset
|
| 23 |
+
dataset_args = DatasetArgs(
|
| 24 |
+
data_path=Path("examples/data/city_8_tempe_3p5_20_32_32.p"),
|
| 25 |
+
keep_percentage=0.25,
|
| 26 |
+
normalize="global_rms",
|
| 27 |
+
seed=0,
|
| 28 |
+
train_limit=500,
|
| 29 |
+
val_limit=100,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# Configure model
|
| 33 |
+
model_args = ModelArgs(
|
| 34 |
+
patch_size=(1, 1),
|
| 35 |
+
phase_mode="real_imag",
|
| 36 |
+
embed_dim=32,
|
| 37 |
+
depth=12,
|
| 38 |
+
num_heads=8,
|
| 39 |
+
mlp_ratio=4.0,
|
| 40 |
+
same_frame_window=2,
|
| 41 |
+
temporal_offsets=(-1, -2, -3, -4, -5, -6, -7), # Causal attention
|
| 42 |
+
temporal_spatial_window=2,
|
| 43 |
+
temporal_drift_h=1,
|
| 44 |
+
temporal_drift_w=1,
|
| 45 |
+
routing_topk_enable=True,
|
| 46 |
+
routing_topk_fraction=0.2,
|
| 47 |
+
routing_topk_max=32,
|
| 48 |
+
pretrained=Path("checkpoints/m18_cp.pth"), # Load pretrained weights
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Configure training
|
| 52 |
+
training_args = TrainingArgs(
|
| 53 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 54 |
+
epochs=10,
|
| 55 |
+
batch_size=16,
|
| 56 |
+
lr=1e-4,
|
| 57 |
+
weight_decay=1e-4,
|
| 58 |
+
warmup_ratio=0.1,
|
| 59 |
+
save_dir=Path("models"),
|
| 60 |
+
use_wandb=False, # Set to True to enable Weights & Biases logging
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Configure prediction
|
| 64 |
+
prediction_args = PredictionArgs(
|
| 65 |
+
Tpast=10,
|
| 66 |
+
horizon=1,
|
| 67 |
+
viz_dir=Path("figs/predictions"),
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Build full config
|
| 71 |
+
args = ChannelPredictionArgs(
|
| 72 |
+
dataset=dataset_args,
|
| 73 |
+
model=model_args,
|
| 74 |
+
training=training_args,
|
| 75 |
+
prediction=prediction_args,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Train
|
| 79 |
+
trainer = ChannelPredictionTrainer(args, logger=logger)
|
| 80 |
+
trainer.train()
|
| 81 |
+
|
| 82 |
+
logger.info("Training complete!")
|
| 83 |
+
|
pyproject.toml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "lwm-temporal"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Large Wireless Model (LWM) for spatio-temporal wireless channel representation learning"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.9"
|
| 11 |
+
license = {text = "MIT"}
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "Sadjad Alikhani", email = "alikhani@asu.edu"}
|
| 14 |
+
]
|
| 15 |
+
keywords = ["wireless", "sparse-spatiotemporal-attention", "transformer", "deep-learning", "pytorch"]
|
| 16 |
+
classifiers = [
|
| 17 |
+
"Development Status :: 4 - Beta",
|
| 18 |
+
"Intended Audience :: Science/Research",
|
| 19 |
+
"License :: OSI Approved :: MIT License",
|
| 20 |
+
"Programming Language :: Python :: 3",
|
| 21 |
+
"Programming Language :: Python :: 3.9",
|
| 22 |
+
"Programming Language :: Python :: 3.10",
|
| 23 |
+
"Programming Language :: Python :: 3.11",
|
| 24 |
+
"Programming Language :: Python :: 3.12",
|
| 25 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
dependencies = [
|
| 29 |
+
"torch>=2.0.0",
|
| 30 |
+
"numpy>=1.21.0",
|
| 31 |
+
"matplotlib>=3.5.0",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
[project.optional-dependencies]
|
| 35 |
+
dev = [
|
| 36 |
+
"pytest>=7.0",
|
| 37 |
+
"black>=22.0",
|
| 38 |
+
"flake8>=4.0",
|
| 39 |
+
"mypy>=0.950",
|
| 40 |
+
]
|
| 41 |
+
wandb = [
|
| 42 |
+
"wandb>=0.13.0",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
[project.urls]
|
| 46 |
+
Homepage = "https://github.com/yourusername/lwm-temporal"
|
| 47 |
+
Repository = "https://github.com/yourusername/lwm-temporal"
|
| 48 |
+
Documentation = "https://github.com/yourusername/lwm-temporal#readme"
|
| 49 |
+
|
| 50 |
+
[project.scripts]
|
| 51 |
+
lwm-pretrain = "LWMTemporal.cli.pretrain:main"
|
| 52 |
+
lwm-channel-prediction = "LWMTemporal.cli.channel_prediction:main"
|
| 53 |
+
|
| 54 |
+
[tool.setuptools.packages.find]
|
| 55 |
+
include = ["LWMTemporal*"]
|
| 56 |
+
exclude = ["tests*", "examples*", "checkpoints*", "cache*", "logs*", "figs*", "wandb*"]
|
| 57 |
+
|
| 58 |
+
[tool.setuptools.package-data]
|
| 59 |
+
LWMTemporal = ["models/config.json"]
|
| 60 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
numpy>=1.21.0
|
| 3 |
+
matplotlib>=3.5.0
|
| 4 |
+
|
setup.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Setup script for LWMTemporal package.
|
| 3 |
+
This is kept for backward compatibility; the package primarily uses pyproject.toml.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from setuptools import setup, find_packages
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# Read the README
|
| 10 |
+
this_directory = Path(__file__).parent
|
| 11 |
+
long_description = (this_directory / "README.md").read_text(encoding="utf-8")
|
| 12 |
+
|
| 13 |
+
setup(
|
| 14 |
+
name="lwm-temporal",
|
| 15 |
+
version="0.1.0",
|
| 16 |
+
author="Sadjad Alikhani",
|
| 17 |
+
author_email="alikhani@asu.edu",
|
| 18 |
+
description="Large Wireless Model (LWM) for spatio-temporal wireless channel prediction",
|
| 19 |
+
long_description=long_description,
|
| 20 |
+
long_description_content_type="text/markdown",
|
| 21 |
+
url="https://github.com/yourusername/lwm-temporal",
|
| 22 |
+
packages=find_packages(include=["LWMTemporal", "LWMTemporal.*"]),
|
| 23 |
+
package_data={
|
| 24 |
+
"LWMTemporal": ["models/config.json"],
|
| 25 |
+
},
|
| 26 |
+
install_requires=[
|
| 27 |
+
"torch>=2.0.0",
|
| 28 |
+
"numpy>=1.21.0",
|
| 29 |
+
"matplotlib>=3.5.0",
|
| 30 |
+
],
|
| 31 |
+
extras_require={
|
| 32 |
+
"dev": [
|
| 33 |
+
"pytest>=7.0",
|
| 34 |
+
"black>=22.0",
|
| 35 |
+
"flake8>=4.0",
|
| 36 |
+
"mypy>=0.950",
|
| 37 |
+
],
|
| 38 |
+
"wandb": ["wandb>=0.13.0"],
|
| 39 |
+
},
|
| 40 |
+
entry_points={
|
| 41 |
+
"console_scripts": [
|
| 42 |
+
"lwm-pretrain=LWMTemporal.cli.pretrain:main",
|
| 43 |
+
"lwm-channel-prediction=LWMTemporal.cli.channel_prediction:main",
|
| 44 |
+
],
|
| 45 |
+
},
|
| 46 |
+
classifiers=[
|
| 47 |
+
"Development Status :: 4 - Beta",
|
| 48 |
+
"Intended Audience :: Science/Research",
|
| 49 |
+
"License :: OSI Approved :: MIT License",
|
| 50 |
+
"Programming Language :: Python :: 3",
|
| 51 |
+
"Programming Language :: Python :: 3.9",
|
| 52 |
+
"Programming Language :: Python :: 3.10",
|
| 53 |
+
"Programming Language :: Python :: 3.11",
|
| 54 |
+
"Programming Language :: Python :: 3.12",
|
| 55 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 56 |
+
],
|
| 57 |
+
python_requires=">=3.9",
|
| 58 |
+
license="MIT",
|
| 59 |
+
keywords="wireless channel-prediction transformer deep-learning pytorch",
|
| 60 |
+
)
|
| 61 |
+
|
test_package.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Test script to verify the LWMTemporal package is properly structured and functional.
|
| 4 |
+
Run this before releasing to ensure everything works.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
def test_imports():
    """Verify the core LWMTemporal components can be imported.

    Returns True when every import succeeds, False on the first ImportError.
    """
    print("Testing imports...")
    try:
        from LWMTemporal import LWMBackbone, LWMConfig, LWMModel, __version__
        from LWMTemporal.data import AngleDelaySequenceDataset, AngleDelayDatasetConfig
        from LWMTemporal.tasks.channel_prediction import ChannelPredictionTrainer
        from LWMTemporal.tasks.pretraining import PretrainingTrainer
    except ImportError as e:
        print(f" β Import failed: {e}")
        return False
    print(f" β All imports successful (version {__version__})")
    return True
|
| 23 |
+
|
| 24 |
+
def test_file_structure():
    """Check that every release-critical file is present on disk.

    Returns True only when all required files exist, printing one status
    line per file either way.
    """
    print("\nTesting file structure...")
    required_files = [
        "README.md",
        "LICENSE",
        "pyproject.toml",
        "setup.py",
        "requirements.txt",
        "MANIFEST.in",
        ".gitignore",
        "CHANGELOG.md",
        "LWMTemporal/__init__.py",
        "LWMTemporal/models/lwm.py",
        "LWMTemporal/models/config.json",
        "examples/README.md",
        "checkpoints/README.md",
    ]

    all_exist = True
    for file in required_files:
        if Path(file).exists():
            print(f" β {file}")
        else:
            print(f" β {file} NOT FOUND")
            all_exist = False

    return all_exist
|
| 53 |
+
|
| 54 |
+
def test_checkpoints():
    """Check that at least one model checkpoint (.pth/.bin) is present.

    Returns True and lists the checkpoints when found; False when the
    checkpoints directory is missing or empty.
    """
    print("\nTesting checkpoints...")
    ckpt_dir = Path("checkpoints")

    if not ckpt_dir.exists():
        print(" β Checkpoints directory not found")
        return False

    found = [*ckpt_dir.glob("*.pth"), *ckpt_dir.glob("*.bin")]
    if not found:
        print(" β No checkpoint files found")
        return False

    print(f" β Found {len(found)} checkpoint(s)")
    for ckpt in found:
        print(f" - {ckpt.name}")
    return True
|
| 72 |
+
|
| 73 |
+
def test_examples():
    """Check that example scripts ship under examples/.

    Returns True and lists the scripts when any *.py file is found;
    False when the directory is missing or contains no scripts.
    """
    print("\nTesting examples...")
    examples_dir = Path("examples")

    if not examples_dir.exists():
        print(" β Examples directory not found")
        return False

    scripts = list(examples_dir.glob("*.py"))
    if not scripts:
        print(" β No example scripts found")
        return False

    print(f" β Found {len(scripts)} example script(s)")
    for script in scripts:
        print(f" - {script.name}")
    return True
|
| 91 |
+
|
| 92 |
+
def test_data():
    """Check for bundled example data (optional for release).

    Example data is not required to ship the package, so this check always
    returns True; it only reports what was found. Previously a missing
    examples/data directory returned False while an *empty* directory
    returned True — inconsistent for the same optional artifact.
    """
    print("\nTesting example data...")
    data_dir = Path("examples/data")

    # Missing directory is now treated the same as an empty one: optional.
    if not data_dir.exists():
        print(" β Example data directory not found (optional)")
        return True

    data_files = list(data_dir.glob("*.p"))
    if data_files:
        print(f" β Found {len(data_files)} data file(s)")
        for data_file in data_files:
            size_mb = data_file.stat().st_size / (1024 * 1024)
            print(f" - {data_file.name} ({size_mb:.1f} MB)")
        return True

    print(f" β No example data files found (optional)")
    return True  # Not critical
|
| 111 |
+
|
| 112 |
+
def test_no_data_in_package():
    """Ensure data/checkpoint artifacts have not leaked into LWMTemporal/.

    Returns True when the package tree holds no stray *.p data files and no
    checkpoint files outside a 'hist' path; otherwise prints each problem
    and returns False.
    """
    print("\nTesting package cleanliness...")
    pkg_dir = Path("LWMTemporal")

    stray_data = list(pkg_dir.rglob("*.p"))
    stray_ckpts = list(pkg_dir.rglob("*.pth")) + list(pkg_dir.rglob("*.bin"))

    issues = []
    if stray_data:
        issues.append(f"Found {len(stray_data)} .p files in package (should be in examples/)")
    if stray_ckpts:
        # Files whose path contains 'hist' are tolerated; anything else is a leak.
        leaked = [f for f in stray_ckpts if 'hist' not in str(f)]
        if leaked:
            issues.append(f"Found checkpoint files in package (should be in checkpoints/)")

    if not issues:
        print(f" β Package directory is clean")
        return True

    for issue in issues:
        print(f" β {issue}")
    return False
|
| 136 |
+
|
| 137 |
+
def main():
    """Run every package check and print a pass/fail summary.

    Returns 0 when all checks pass, 1 otherwise (used as the process
    exit code by the __main__ guard).
    """
    banner = "=" * 60
    print(banner)
    print("LWMTemporal Package Structure Test")
    print(banner)

    checks = [
        ("Imports", test_imports),
        ("File Structure", test_file_structure),
        ("Checkpoints", test_checkpoints),
        ("Examples", test_examples),
        ("Example Data", test_data),
        ("Package Cleanliness", test_no_data_in_package),
    ]
    results = [(name, check()) for name, check in checks]

    print("\n" + banner)
    print("SUMMARY")
    print(banner)

    passed = sum(1 for _, ok in results if ok)
    total = len(results)

    for name, ok in results:
        status = "β PASS" if ok else "β FAIL"
        print(f"{status:8} | {name}")

    print(banner)
    print(f"Result: {passed}/{total} tests passed")

    if passed == total:
        print("\nπ Package is ready for release!")
        return 0
    print("\nβ οΈ Some tests failed. Please review and fix.")
    return 1
|
| 171 |
+
|
| 172 |
+
if __name__ == "__main__":
|
| 173 |
+
sys.exit(main())
|
| 174 |
+
|