Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class HelioFlowModel(nn.Module): | |
| def __init__(self, img_size=(4096, 4096), use_latitude_in_learned_flow=False): | |
| super().__init__() | |
| self.use_latitude_in_learned_flow = use_latitude_in_learned_flow | |
| u = torch.linspace(-1, 1, img_size[0]) | |
| v = torch.linspace(-1, 1, img_size[1]) | |
| u, v = torch.meshgrid(u, v, indexing="xy") | |
| self.register_buffer( | |
| "grid", torch.stack((u, v), dim=2).view(1, *img_size, 2) | |
| ) # B, H, W, 2 | |
| # Higher modes can be used for explicit feature engineering for flow features. | |
| if self.use_latitude_in_learned_flow: | |
| higher_modes = [u, v, torch.ones_like(u)] | |
| else: | |
| higher_modes = [ | |
| u, | |
| v, | |
| ] | |
| self.register_buffer( | |
| "higher_modes", torch.stack(higher_modes, dim=2).view(1, *img_size, -1) | |
| ) | |
| self.flow_generator = nn.Sequential( | |
| nn.Linear(self.higher_modes.shape[3], 128), | |
| nn.GELU(), | |
| nn.Linear(128, 2), | |
| ) | |
| def forward(self, batch): | |
| """ | |
| Args: | |
| batch: Dictionary containing keys `ts` and | |
| `forecast_latitude` (optionally). | |
| ts (torch.Tensor): B, C, T, H, W | |
| forecast_latitude (torch.Tensor): B, L | |
| B - Batch size, C - Channels, T - Input times, H - Image height, | |
| W - Image width, L - Lead time. | |
| """ | |
| x = batch["ts"] | |
| B, C, T, H, W = x.shape | |
| if T == 1: | |
| x = x[:, :, -1, :, :] | |
| else: | |
| # Taking the average of the last two time stamps | |
| x = (x[:, :, -1, :, :] + x[:, :, -2, :, :]) / 2 | |
| # Flow fields have the shape B, H_out, W_out, 2 | |
| if self.use_latitude_in_learned_flow: | |
| broadcast_lat = batch["forecast_latitude"] / 7 | |
| broadcast_lat = torch.concatenate( | |
| [ | |
| torch.ones_like(broadcast_lat), | |
| torch.ones_like(broadcast_lat), | |
| broadcast_lat, | |
| ], | |
| 1, | |
| )[:, None, None, :] | |
| higher_modes = self.higher_modes * broadcast_lat | |
| flow_field = self.grid + self.flow_generator(higher_modes) | |
| else: | |
| flow_field = self.grid + self.flow_generator(self.higher_modes) | |
| flow_field = flow_field.expand(B, H, W, 2) | |
| y_hat = F.grid_sample( | |
| x, | |
| flow_field, | |
| mode="bilinear", | |
| padding_mode="border", # Possible values: zeros, border, or reflection. | |
| align_corners=False, | |
| ) | |
| return y_hat | |