Spaces:
Runtime error
Runtime error
| import yaml | |
| import logging | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| import matplotlib.pyplot as plt | |
| import sunpy.visualization.colormaps as sunpy_cm | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel | |
| from surya.models.helio_spectformer import HelioSpectFormer | |
| from surya.utils.data import build_scalers, custom_collate_fn | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Channel names offered by the demo. The position of a name in this list is
# the channel index used elsewhere in the file, so the order matters:
# first the "aia*" entries, then the "hmi_*" entries.
SDO_CHANNELS = [
    *(f"aia{band}" for band in (94, 131, 171, 193, 211, 304, 335, 1600)),
    *(f"hmi_{product}" for product in ("m", "bx", "by", "bz", "v")),
]
@dataclass(eq=False)
class SDOImage:
    """Simple record bundling one image array with its descriptive labels.

    The class only declares annotated fields, so ``@dataclass`` is applied to
    generate the constructor and ``repr``. ``eq=False`` keeps the default
    identity-based equality and hashing: a generated ``__eq__`` would compare
    the ``data`` arrays element-wise and raise on the ambiguous truth value.
    """

    # Channel name -- presumably one of SDO_CHANNELS; confirm with callers.
    channel: str
    # Image pixel values.
    data: np.ndarray
    # Timestamp label; the format is not fixed by this class.
    timestamp: str
    # Free-form kind label (NOTE(review): allowed values are not visible here).
    type: str
def download_data():
    """Download the model files and the validation samples used by the demo.

    Issues two ``snapshot_download`` calls, in this order:

    1. Config, scalers and checkpoint from the ``Surya-1.0`` repository into
       ``data/Surya-1.0``.
    2. The validation files matching ``20140107_1[5-9]??.nc`` from the
       ``Surya-1.0_validation_data`` dataset repository into
       ``data/Surya-1.0_validation_data``.

    Both calls pass ``token=None``.
    """
    requests = (
        {
            "repo_id": "nasa-ibm-ai4science/Surya-1.0",
            "local_dir": "data/Surya-1.0",
            "allow_patterns": ["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
        },
        {
            "repo_id": "nasa-ibm-ai4science/Surya-1.0_validation_data",
            "repo_type": "dataset",
            "local_dir": "data/Surya-1.0_validation_data",
            "allow_patterns": "20140107_1[5-9]??.nc",
        },
    )
    for request in requests:
        snapshot_download(token=None, **request)
def get_dataset(config, scalers) -> HelioNetCDFDataset:
    """Build the validation-phase dataset used by the demo.

    Args:
        config: Loaded configuration mapping; only its ``"data"`` section is
            read here.
        scalers: Scalers object forwarded unchanged to the dataset.

    Returns:
        A ``HelioNetCDFDataset`` reading ``tests/test_surya_index.csv`` with
        ``phase="valid"`` and ``rollout_steps=0``.
    """
    data_cfg = config["data"]

    # Settings copied from the config under the same keyword name.
    forwarded = {
        name: data_cfg[name]
        for name in (
            "time_delta_input_minutes",
            "time_delta_target_minutes",
            "drop_hmi_probability",
            "num_mask_aia_channels",
            "use_latitude_in_learned_flow",
            "pooling",
            "random_vert_flip",
        )
    }

    dataset = HelioNetCDFDataset(
        index_path="tests/test_surya_index.csv",
        # One input timestamp per configured input time delta.
        n_input_timestamps=len(data_cfg["time_delta_input_minutes"]),
        rollout_steps=0,
        channels=data_cfg["sdo_channels"],
        scalers=scalers,
        phase="valid",
        **forwarded,
    )

    logger.info(f"Initialized the dataset. {len(dataset)} samples.")
    return dataset
def get_scalers() -> dict:
    """Load the scaler description from disk and build the scalers.

    Reads ``data/Surya-1.0/scalers.yaml`` and passes the parsed content to
    ``build_scalers``.

    Returns:
        The object returned by ``build_scalers`` (annotated as ``dict``).
    """
    # Use a context manager so the file handle is closed deterministically,
    # even if YAML parsing raises.
    with open("data/Surya-1.0/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)

    scalers = build_scalers(info=scalers_info)
    logger.info("Built the scalers.")
    return scalers
def get_model_from_config(config) -> HelioSpectFormer:
    """Instantiate a ``HelioSpectFormer`` from the loaded configuration.

    Args:
        config: Loaded configuration mapping. The ``"model"`` section supplies
            the architecture settings; the ``"data"`` section supplies the
            channel list and the input time deltas.

    Returns:
        The constructed model. This function does not load a checkpoint.
    """
    model_cfg = config["model"]
    data_cfg = config["data"]

    # Architecture settings copied from the config under the same keyword name.
    forwarded = {
        name: model_cfg[name]
        for name in (
            "img_size",
            "patch_size",
            "embed_dim",
            "depth",
            "n_spectral_blocks",
            "num_heads",
            "mlp_ratio",
            "drop_rate",
            "window_size",
            "dp_rank",
            "learned_flow",
            "rpe",
            "ensemble",
            "finetune",
        )
    }

    model = HelioSpectFormer(
        in_chans=len(data_cfg["sdo_channels"]),
        time_embedding={
            "type": "linear",
            "time_dim": len(data_cfg["time_delta_input_minutes"]),
        },
        dtype=torch.bfloat16,
        # NOTE(review): this is fed from the "learned_flow" key, not from a
        # "use_latitude_in_learned_flow" key -- confirm that this is intended.
        use_latitude_in_learned_flow=model_cfg["learned_flow"],
        init_weights=False,
        checkpoint_layers=list(range(10)),
        **forwarded,
    )

    logger.info("Initialized the model.")
    return model
def get_config() -> dict:
    """Parse ``data/Surya-1.0/config.yaml`` and return its content."""
    with open("data/Surya-1.0/config.yaml") as config_file:
        return yaml.safe_load(config_file)
def setup():
    """Download the assets and prepare the dataset and model for inference.

    Steps, in order: download data, load the config and scalers, build the
    dataset, build the model, move it to the selected device, load the
    checkpoint ``data/Surya-1.0/surya.366m.v1.pt`` with ``strict=True`` and
    switch the model to evaluation mode.

    Returns:
        Tuple ``(dataset, model, device)``. ``device`` is the current CUDA
        device index when a GPU is available and the string ``"cpu"``
        otherwise.
    """
    logger.info("Loading data ...")
    download_data()
    config = get_config()
    scalers = get_scalers()

    logger.info("Initializing dataset ...")
    dataset = get_dataset(config, scalers)

    logger.info("Initializing model ...")
    model = get_model_from_config(config)

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        logger.info(f"GPU detected. Running the test on device {device}.")
    else:
        device = "cpu"
        logger.warning("No GPU detected. Running the test on CPU.")

    model.to(device)
    n_parameters = sum(p.numel() for p in model.parameters()) / 1e6
    logger.info(f"Surya FM: {n_parameters:.2f} M total parameters.")

    path_weights = "data/Surya-1.0/surya.366m.v1.pt"
    # weights_only=True restricts unpickling to tensors and plain containers.
    weights = torch.load(
        path_weights, map_location=torch.device(device), weights_only=True
    )
    model.load_state_dict(weights, strict=True)
    logger.info("Loaded weights.")

    # The model is only used for inference in this file; evaluation mode
    # disables train-time behaviour such as dropout (the model is built with a
    # configurable drop_rate).
    model.eval()

    return dataset, model, device
| def batch_step( | |
| model: HelioSpectFormer, | |
| sample_data: dict, | |
| sample_metadata: dict, | |
| device: int | str, | |
| hours_ahead: int = 1, | |
| ) -> np.ndarray: | |
| """ | |
| Perform a single batch step for the given model, batch data, metadata, and device. | |
| Args: | |
| model: The PyTorch model to use for prediction. | |
| sample_data: A dictionary containing input and target data for the batch. | |
| sample_metadata: A dictionary containing metadata for the batch, including timestamps. | |
| device: The device to use for computation ('cpu', 'cuda' or device number). | |
| hours_ahead: The number of steps to forecast ahead. Defaults to 1. | |
| Returns: | |
| np.ndarray: Output data. | |
| """ | |
| data_returned = [] | |
| forecast_hat = None # Initialize forecast_hat | |
| for step in range(1, hours_ahead + 1): | |
| if step == 1: | |
| curr_batch = { | |
| key: torch.from_numpy(sample_data[key]).unsqueeze(0).to(device) | |
| for key in ["ts", "time_delta_input"] | |
| } | |
| else: | |
| # Use the previous forecast_hat from the previous iteration | |
| if forecast_hat is not None: | |
| curr_batch["ts"] = torch.cat( | |
| (curr_batch["ts"][:, :, 1:, ...], forecast_hat[:, :, None, ...]), | |
| dim=2, | |
| ) | |
| forecast_hat = model(curr_batch) | |
| data_returned = forecast_hat.to(dtype=torch.float32).cpu().squeeze(0).numpy() | |
| return data_returned | |
def run_inference(init_time_idx, plt_channel_idx, hours_ahead):
    """Run a forecast for one sample and render inputs and prediction as images.

    Uses the module-level ``dataset``, ``model`` and ``device``.

    Args:
        init_time_idx: Index into ``dataset.valid_indices`` / ``dataset``.
        plt_channel_idx: Index into ``SDO_CHANNELS`` of the channel to plot.
        hours_ahead: Number of forecast steps; coerced to ``int``.

    Returns:
        Tuple ``(input_timestamp_0, input_image_0, input_timestamp_1,
        input_image_1, output_timestamp, output_image)``. Timestamps are
        strings formatted as ``%Y-%m-%d %H:%M``; images are the colour-mapped
        arrays returned by the sunpy colour map with ``bytes=True``.

    Raises:
        gr.Error: If no initialization time or channel has been selected.
    """
    # Index-type dropdowns yield None while nothing is selected; fail with a
    # readable message instead of an obscure indexing error.
    if init_time_idx is None or plt_channel_idx is None:
        raise gr.Error("Please select an initialization time and an SDO band.")

    # The slider may deliver a float; use one integer step count everywhere.
    n_hours = int(hours_ahead)

    plt_channel_str = SDO_CHANNELS[plt_channel_idx]

    time_format = "%Y-%m-%d %H:%M"
    init_time = dataset.valid_indices[init_time_idx]
    input_timestamp_0 = (init_time - pd.Timedelta(1, "h")).strftime(time_format)
    input_timestamp_1 = init_time.strftime(time_format)
    output_timestamp = (init_time + pd.Timedelta(n_hours, "h")).strftime(time_format)

    sample_data, sample_metadata = dataset[init_time_idx]

    with torch.no_grad():
        model_output = batch_step(
            model,
            sample_data,
            sample_metadata,
            device,
            n_hours,
        )

    means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()

    def undo_scaling(channel_data):
        """Apply the inverse transform with the selected channel's statistics."""
        return inverse_transform_single_channel(
            channel_data,
            mean=means[plt_channel_idx],
            std=stds[plt_channel_idx],
            epsilon=epsilons[plt_channel_idx],
            sl_scale_factor=sl_scale_factors[plt_channel_idx],
        )

    # Shared colour range derived from the two input frames only: the larger
    # of the two minima and the smaller of the two 99% quantiles.
    vmin = float("-inf")
    vmax = float("inf")
    input_frames = []
    for i in range(2):
        frame = undo_scaling(sample_data["ts"][plt_channel_idx, i])
        input_frames.append(frame)
        vmin = max(vmin, frame.min())
        vmax = min(vmax, np.quantile(frame, 0.99))

    # Guard against a degenerate (constant or inverted) range, which would
    # otherwise divide by zero or flip the colour scale.
    span = vmax - vmin
    if not span > 0:
        span = 1.0

    # "aia*" channels use the colour map named "sdo<channel>"; every other
    # channel uses "hmimag".
    if plt_channel_str.startswith("aia"):
        cm_name = "sdo" + plt_channel_str
    else:
        cm_name = "hmimag"
    colormap = sunpy_cm.cmlist[cm_name]

    def colorize(frame):
        """Reverse the first axis, normalise to the shared range, colour-map."""
        return colormap((frame[::-1] - vmin) / span, bytes=True)

    input_image = [colorize(frame) for frame in input_frames]
    output_image = colorize(undo_scaling(model_output[plt_channel_idx]))

    return (
        input_timestamp_0,
        input_image[0],
        input_timestamp_1,
        input_image[1],
        output_timestamp,
        output_image,
    )
# Script entry: configure logging, prepare dataset/model once, build the UI.
logging.basicConfig(level=logging.INFO)

# Module-level state shared with run_inference.
dataset, model, device = setup()

with gr.Blocks() as demo:
    gr.Markdown(value="# Surya 1.0 - Visual forecasting demo")

    # Selection controls. Both dropdowns use type="index", so the callback
    # receives list positions rather than the displayed labels.
    with gr.Row():
        with gr.Column():
            init_time = gr.Dropdown(
                [v.strftime("%Y-%m-%d %H:%M") for v in dataset.valid_indices],
                label="Initialization time",
                multiselect=False,
                type="index",
            )
        with gr.Column():
            plt_channel = gr.Dropdown(
                [c.upper() for c in SDO_CHANNELS],
                label="SDO Band",
                value="AIA94",
                multiselect=False,
                type="index",
            )
    with gr.Row():
        hours_ahead = gr.Slider(
            minimum=1.0,
            maximum=6.0,
            step=1.0,
            label="Forecast step [hours ahead]",
        )
    with gr.Row():
        btn = gr.Button("Run")

    # Result panels: the two input frames and the prediction, each with its
    # timestamp shown above the image.
    with gr.Row():
        with gr.Column():
            input_timestamp_0 = gr.Textbox(label="Input 0")
            input_image_0 = gr.Image()
        with gr.Column():
            input_timestamp_1 = gr.Textbox(label="Input 1")
            input_image_1 = gr.Image()
        with gr.Column():
            output_timestamp = gr.Textbox(label="Prediction")
            output_image = gr.Image()

    # The button and the examples share the same wiring; the output order
    # matches the tuple returned by run_inference.
    inference_inputs = [init_time, plt_channel, hours_ahead]
    inference_outputs = [
        input_timestamp_0,
        input_image_0,
        input_timestamp_1,
        input_image_1,
        output_timestamp,
        output_image,
    ]

    btn.click(
        fn=run_inference,
        inputs=inference_inputs,
        outputs=inference_outputs,
    )

    with gr.Row():
        gr.Examples(
            examples=[
                ["2014-01-07 17:24", "AIA94", 2],
                ["2014-01-07 16:12", "AIA94", 6],
                ["2014-01-07 16:00", "AIA131", 1],
                ["2014-01-07 16:00", "HMI_M", 2],
            ],
            fn=run_inference,
            inputs=inference_inputs,
            outputs=inference_outputs,
            cache_examples=False,
        )

demo.launch()