| """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" |
|
|
| import dataclasses |
| from typing import Generator |
|
|
| import torch |
|
|
| from aurora.batch import Batch |
| from aurora.model.aurora import Aurora |
|
|
| __all__ = ["rollout"] |
|
|
|
|
| def rollout(model: Aurora, batch: Batch, steps: int) -> Generator[Batch, None, None]: |
| """Perform a roll-out to make long-term predictions. |
| |
| Args: |
| model (:class:`aurora.model.aurora.Aurora`): The model to roll out. |
| batch (:class:`aurora.batch.Batch`): The batch to start the roll-out from. |
| steps (int): The number of roll-out steps. |
| |
| Yields: |
| :class:`aurora.batch.Batch`: The prediction after every step. |
| """ |
| |
| |
| p = next(model.parameters()) |
| batch = batch.type(p.dtype) |
| batch = batch.crop(model.patch_size) |
| batch = batch.to(p.device) |
|
|
| for _ in range(steps): |
| pred = model.forward(batch) |
|
|
| yield pred |
|
|
| |
| batch = dataclasses.replace( |
| pred, |
| surf_vars={ |
| k: torch.cat([batch.surf_vars[k][:, 1:], v], dim=1) |
| for k, v in pred.surf_vars.items() |
| }, |
| atmos_vars={ |
| k: torch.cat([batch.atmos_vars[k][:, 1:], v], dim=1) |
| for k, v in pred.atmos_vars.items() |
| }, |
| ) |
|
|