lwm-temporal / examples /inference_channel_prediction.py
Sadjad Alikhani
Initial commit
164610c
#!/usr/bin/env python
"""Example: Run inference with a trained channel prediction model."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from LWMTemporal.tasks.channel_prediction import (
ChannelPredictionArgs,
ChannelPredictionTrainer,
DatasetArgs,
ModelArgs,
TrainingArgs,
PredictionArgs,
)
from LWMTemporal.utils.logging import setup_logging
# Anchor every filesystem path at the repository root (the parent of the
# directory holding this file) instead of the current working directory. This
# matches how the ``sys.path`` entry at the top of the file is computed, so the
# example works no matter which directory it is launched from.
REPO_ROOT = Path(__file__).resolve().parent.parent
DATA_PATH = REPO_ROOT / "examples" / "data" / "city_8_tempe_3p5_20_32_32.p"
CHECKPOINT_PATH = REPO_ROOT / "checkpoints" / "m18_cp.pth"
LOG_DIR = REPO_ROOT / "logs"


def main() -> None:
    """Run channel-prediction inference with a pretrained checkpoint.

    Builds the dataset, model, training and prediction configs, wraps them in
    ``ChannelPredictionArgs`` and hands them to ``ChannelPredictionTrainer``.
    ``TrainingArgs(inference_only=True, inference_split="val")`` makes
    ``trainer.train()`` run in inference mode on the validation split (per this
    example's own comment; the exact behaviour lives in the trainer).
    """
    logger = setup_logging("channel_prediction_inference", log_dir=LOG_DIR)

    # Dataset. NOTE(review): keep_percentage / val_limit presumably subsample
    # the data and cap validation samples; confirm in DatasetArgs.
    dataset_args = DatasetArgs(
        data_path=DATA_PATH,
        keep_percentage=0.25,
        normalize="global_rms",
        seed=0,
        val_limit=100,
    )

    # Model. These hyperparameters must match the architecture the checkpoint
    # at CHECKPOINT_PATH was trained with, or loading it will presumably fail.
    model_args = ModelArgs(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        # Only negative (past) offsets are listed -> causal temporal attention.
        temporal_offsets=(-1, -2, -3, -4, -5, -6, -7),
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        routing_topk_fraction=0.2,
        routing_topk_max=32,
        pretrained=CHECKPOINT_PATH,
    )

    # Training config, used here only to select inference-only mode on CPU.
    training_args = TrainingArgs(
        device="cpu",
        batch_size=2,
        inference_only=True,
        inference_split="val",
    )

    # Prediction setup. NOTE(review): presumably Tpast past frames are used as
    # context to predict `horizon` future frame(s); confirm in PredictionArgs.
    prediction_args = PredictionArgs(
        Tpast=10,
        horizon=1,
    )

    args = ChannelPredictionArgs(
        dataset=dataset_args,
        model=model_args,
        training=training_args,
        prediction=prediction_args,
    )

    trainer = ChannelPredictionTrainer(args, logger=logger)
    trainer.train()  # train() handles inference_only mode
    logger.info("Inference complete!")


# Guard the entry point so importing this module (e.g. during test collection
# or doc builds) does not launch a full inference run.
if __name__ == "__main__":
    main()