"""Example: Run inference with a trained channel prediction model.

The script builds the dataset, model, training and prediction configuration
objects, bundles them into a ``ChannelPredictionArgs`` and hands that to a
``ChannelPredictionTrainer`` configured with ``inference_only=True``.

All filesystem paths used below are relative, so they are resolved against
the current working directory of the process, not against the location of
this file.
"""

import sys
from pathlib import Path

# Put the parent of this file's directory first on ``sys.path`` so that the
# ``LWMTemporal`` imports below can be resolved from a source checkout.
# This must happen before those imports, hence the late-import markers.
sys.path.insert(0, str(Path(__file__).parent.parent))

from LWMTemporal.tasks.channel_prediction import (  # noqa: E402
    ChannelPredictionArgs,
    ChannelPredictionTrainer,
    DatasetArgs,
    ModelArgs,
    TrainingArgs,
    PredictionArgs,
)
from LWMTemporal.utils.logging import setup_logging  # noqa: E402


def main() -> None:
    """Assemble the configuration and run the trainer in inference-only mode."""
    # Logger for this run; ``log_dir`` is the relative directory ``logs``.
    logger = setup_logging("channel_prediction_inference", log_dir=Path("logs"))

    # Dataset configuration.
    dataset_args = DatasetArgs(
        data_path=Path("examples/data/city_8_tempe_3p5_20_32_32.p"),
        keep_percentage=0.25,
        normalize="global_rms",
        seed=0,
        val_limit=100,
    )

    # Model configuration.
    # NOTE(review): these hyper-parameters presumably have to match the
    # architecture stored in the ``pretrained`` checkpoint -- confirm before
    # changing any of them.
    model_args = ModelArgs(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        temporal_offsets=(-1, -2, -3, -4, -5, -6, -7),
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        routing_topk_fraction=0.2,
        routing_topk_max=32,
        pretrained=Path("checkpoints/m18_cp.pth"),
    )

    # Run configuration: CPU, inference only, on the "val" split.
    training_args = TrainingArgs(
        device="cpu",
        batch_size=2,
        inference_only=True,
        inference_split="val",
    )

    # Prediction configuration.
    prediction_args = PredictionArgs(
        Tpast=10,
        horizon=1,
    )

    # Bundle the four configuration groups into the single object the
    # trainer takes.
    args = ChannelPredictionArgs(
        dataset=dataset_args,
        model=model_args,
        training=training_args,
        prediction=prediction_args,
    )

    # ``train()`` is the entry point used here even though no training is
    # requested.
    # NOTE(review): with ``inference_only=True`` it is presumably expected to
    # skip optimisation and only evaluate ``inference_split`` -- confirm
    # against ``ChannelPredictionTrainer``.
    trainer = ChannelPredictionTrainer(args, logger=logger)
    trainer.train()

    logger.info("Inference complete!")


# Guard the entry point so that importing this module (including the
# re-import done by spawned worker processes) does not start another run.
if __name__ == "__main__":
    main()
|
|