WeatherPEFT / aurora /rollout.py
bidulki-99's picture
Add files using upload-large-folder tool
9d66a40 verified
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
import dataclasses
from typing import Generator
import torch
from aurora.batch import Batch
from aurora.model.aurora import Aurora
__all__ = ["rollout"]
def rollout(model: Aurora, batch: Batch, steps: int) -> Generator[Batch, None, None]:
"""Perform a roll-out to make long-term predictions.
Args:
model (:class:`aurora.model.aurora.Aurora`): The model to roll out.
batch (:class:`aurora.batch.Batch`): The batch to start the roll-out from.
steps (int): The number of roll-out steps.
Yields:
:class:`aurora.batch.Batch`: The prediction after every step.
"""
# We will need to concatenate data, so ensure that everything is already of the right form.
# Use an arbitary parameter of the model to derive the data type and device.
p = next(model.parameters())
batch = batch.type(p.dtype)
batch = batch.crop(model.patch_size)
batch = batch.to(p.device)
for _ in range(steps):
pred = model.forward(batch)
yield pred
# Add the appropriate history so the model can be run on the prediction.
batch = dataclasses.replace(
pred,
surf_vars={
k: torch.cat([batch.surf_vars[k][:, 1:], v], dim=1)
for k, v in pred.surf_vars.items()
},
atmos_vars={
k: torch.cat([batch.atmos_vars[k][:, 1:], v], dim=1)
for k, v in pred.atmos_vars.items()
},
)