"""
FlareSense v2 - Simple Usage Example

This script demonstrates how to use the FlareSense model to predict solar
radio bursts on e-Callisto data. The model checkpoint is downloaded from the
HuggingFace Hub and cached locally.

Usage:
    python example_usage.py

The model will predict on a 15-minute window of data from a specific
instrument.
"""

import os
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from plotly.io import show
from torchvision import models

from ecallisto_ng.data_download.downloader import get_ecallisto_data
from ecallisto_ng.data_processing.utils import subtract_constant_background
from ecallisto_ng.plotting.plotting import plot_spectrogram


# ============================================================================
# Model Definition
# ============================================================================


class GrayScaleResNet(nn.Module):
    """ResNet classifier adapted for grayscale (single-channel) images.

    Wraps a torchvision ResNet and replaces its final fully connected layer
    with ``nn.Linear(in_features, n_classes)``. Single-channel inputs are
    expanded to three channels in :meth:`forward` so the stock 3-channel stem
    can be reused unchanged.

    Args:
        n_classes: Number of output logits (default 1).
        resnet_type: Backbone name: ``"resnet18"``, ``"resnet34"`` (default)
            or ``"resnet50"``.
        pretrained: Keyword-only. If True (default, the historical behavior),
            initialise the backbone with torchvision's ``DEFAULT`` weights
            (fetched by torchvision on first use). Pass False when every
            parameter will be overwritten from a checkpoint anyway, to skip
            that download.

    Raises:
        ValueError: If ``resnet_type`` is not one of the supported names.
    """

    def __init__(self, n_classes=1, resnet_type="resnet34", *, pretrained=True):
        super().__init__()

        # resnet_type -> (torchvision builder, its weights enum). The weights
        # enum is only dereferenced for the selected backbone.
        backbones = {
            "resnet18": (models.resnet18, models.ResNet18_Weights),
            "resnet34": (models.resnet34, models.ResNet34_Weights),
            "resnet50": (models.resnet50, models.ResNet50_Weights),
        }
        if resnet_type not in backbones:
            raise ValueError(f"Unsupported resnet_type: {resnet_type}")

        build_backbone, weights_enum = backbones[resnet_type]
        # The builders take no num_classes here; the head is swapped below.
        self.resnet = build_backbone(
            weights=weights_enum.DEFAULT if pretrained else None
        )

        # Replace the final fully connected layer for our number of classes.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, n_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, n_classes)``.

        Args:
            x: Rank-4 image batch indexed as ``(batch, channels, height,
               width)`` (rank 4 is required by the ``expand`` call below).
               A single channel is broadcast to the three channels the
               ResNet stem expects; other channel counts pass through as-is.
        """
        if x.size(1) == 1:
            # expand() returns a broadcast view, so no data is copied.
            x = x.expand(-1, 3, -1, -1)
        return self.resnet(x)
# ============================================================================
# Data Processing Functions
# ============================================================================

# Number of leading timepoints handed to ``subtract_constant_background``.
# Shared by the model preprocessing and by the plot in ``main`` so both show
# the same background-subtracted data.
_BACKGROUND_TIMEPOINTS = 300


def remove_background(df_spectrogram, n=_BACKGROUND_TIMEPOINTS) -> torch.Tensor:
    """
    Remove the constant background from a spectrogram DataFrame.

    Delegates to ecallisto_ng's ``subtract_constant_background`` with ``n``
    timepoints; per the pipeline's original notes, the median of the first
    ``n`` timepoints is used as the background.

    Args:
        df_spectrogram: Pandas DataFrame with time as index and frequency as
            columns.
        n: Number of leading timepoints used to estimate the background.
            The default (300) is the value this pipeline has always used.

    Returns:
        float32 tensor of shape (frequency, time) with the background removed.
    """
    df_processed = subtract_constant_background(df_spectrogram, n=n)

    # The DataFrame is (time, frequency); the rest of the pipeline works on
    # (frequency, time). torch.tensor copies the transposed array, so the
    # result is contiguous and does not alias the DataFrame's buffer (which
    # may be a read-only view, e.g. under pandas Copy-on-Write).
    return torch.tensor(df_processed.to_numpy().T, dtype=torch.float32)


def remove_background_median(spectrogram_tensor: torch.Tensor) -> torch.Tensor:
    """
    Subtract each frequency row's median (taken over time) from that row.

    Applied AFTER the constant background subtraction. Note that
    ``torch.median`` returns the lower of the two middle values for rows
    with an even number of samples, not their mean.

    Args:
        spectrogram_tensor: Tensor of shape (frequency, time).

    Returns:
        Tensor of the same shape with the per-row median removed.
    """
    row_medians = torch.median(spectrogram_tensor, dim=1).values
    # Broadcast the (frequency,) medians across the time axis.
    return spectrogram_tensor - row_medians[:, None]


def resize_spectrogram(spectrogram_tensor: torch.Tensor, target_size=(128, 512)) -> torch.Tensor:
    """
    Resize a spectrogram to ``target_size`` using bilinear interpolation.

    Args:
        spectrogram_tensor: Input tensor (frequency, time).
        target_size: Target size (height, width).

    Returns:
        Resized tensor of shape (1, height, width).
    """
    # interpolate() expects (batch, channel, H, W): add both leading dims.
    x = spectrogram_tensor.unsqueeze(0).unsqueeze(0)
    resized = torch.nn.functional.interpolate(
        x, size=target_size, mode='bilinear', align_corners=False
    )
    # Drop the batch dimension, keep the channel dimension -> (1, H, W).
    return resized.squeeze(0)


def min_max_scale(tensor: torch.Tensor, feature_range=(0, 1)) -> torch.Tensor:
    """
    Linearly rescale a tensor so its global min/max map onto ``feature_range``.

    Args:
        tensor: Input tensor (min/max are taken over all elements).
        feature_range: Desired (lower, upper) output range. Default (0, 1).

    Returns:
        Scaled tensor. A constant input (max == min) has no contrast to
        stretch and is mapped entirely to the lower bound of
        ``feature_range`` (0 for the default range).
    """
    lower, upper = feature_range
    tensor_min = tensor.min()
    tensor_max = tensor.max()
    span = tensor_max - tensor_min

    # Avoid division by zero. Previously this always returned zeros, which
    # falls outside feature_range whenever its lower bound is not 0.
    if span == 0:
        return torch.full_like(tensor, lower)

    # Same operation order as before (subtract, divide, stretch, shift) so
    # results match the training pipeline bit-for-bit.
    return (tensor - tensor_min) / span * (upper - lower) + lower


def preprocess_spectrogram(df_spectrogram) -> torch.Tensor:
    """
    Complete preprocessing pipeline for a spectrogram DataFrame.

    This follows the exact same pipeline as the training code.

    Args:
        df_spectrogram: Pandas DataFrame (time x frequency) from
            get_ecallisto_data.

    Returns:
        Preprocessed tensor ready for model input, shape (1, 128, 512).
    """
    # Step 1: Remove constant background and convert to tensor (frequency x time)
    tensor = remove_background(df_spectrogram)

    # Step 2: Remove row-wise median background
    tensor = remove_background_median(tensor)

    # Step 3: Resize to target size (128, 512)
    # This uses normal_resize since custom_resize is False in config
    tensor = resize_spectrogram(tensor, target_size=(128, 512))

    # Step 4: Min-max scale to [0, 1]
    tensor = min_max_scale(tensor, feature_range=(0, 1))

    return tensor


# ============================================================================
# Model Loading and Prediction
# ============================================================================


def load_flaresense_model(device="cpu"):
    """
    Load the FlareSense model from the HuggingFace Hub.

    The checkpoint is fetched with ``hf_hub_download``, which caches it
    locally.

    Args:
        device: Device to load the model on ('cpu' or 'cuda').

    Returns:
        ``GrayScaleResNet`` (resnet34, one logit) in evaluation mode on
        ``device``.
    """
    # Model configuration (from best_v2.yml)
    REPO_ID = "i4ds/flaresense-v2"
    MODEL_FILENAME = "model.ckpt"
    RESNET_TYPE = "resnet34"

    print(f"Downloading model from {REPO_ID}...")
    checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    print(f"Model cached at: {checkpoint_path}")

    # NOTE: default GrayScaleResNet construction requests torchvision's
    # pretrained weights; load_state_dict below (strict by default) then
    # replaces every parameter and buffer with the checkpoint's values.
    model = GrayScaleResNet(n_classes=1, resnet_type=RESNET_TYPE)

    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    # Remove the '_orig_mod.' segment that torch.compile adds to parameter
    # names (replace(), not removeprefix(), since it may appear mid-key).
    state_dict = {
        key.replace("_orig_mod.", ""): value for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)

    model.eval()
    model.to(device)

    print(f"Model loaded successfully on {device}")
    return model


def sigmoid(x, temperature=0.4974):
    """
    Convert a logit to a probability using a temperature-scaled sigmoid.

    Computes ``1 / (1 + exp(-x / temperature))`` in an overflow-free form.

    Args:
        x: Logit value (scalar or array-like).
        temperature: Positive calibration temperature.

    Returns:
        Probability in [0, 1] (NumPy scalar for scalar input, array otherwise).

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    z = np.asarray(x) / temperature
    # 1 / (1 + e^-z) == exp(-log(1 + e^-z)). logaddexp evaluates the log term
    # without materialising e^-z, so very negative logits no longer overflow
    # (and no RuntimeWarning is emitted); the result simply underflows to 0.
    return np.exp(-np.logaddexp(0.0, -z))


def predict_burst(model, df_spectrogram, device="cpu"):
    """
    Predict a solar radio burst on a single spectrogram DataFrame.

    Args:
        model: Loaded FlareSense model (see ``load_flaresense_model``).
        df_spectrogram: Pandas DataFrame (time x frequency) from
            get_ecallisto_data.
        device: Device to run the prediction on.

    Returns:
        tuple: (logit, probability)
            - logit: Raw model output as a Python float.
            - probability: Temperature-calibrated probability in [0, 1].
    """
    input_tensor = preprocess_spectrogram(df_spectrogram)

    # (1, 128, 512) -> (1, 1, 128, 512): add the batch dimension.
    input_batch = input_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        logit = model(input_batch).squeeze().item()

    probability = sigmoid(logit)

    return logit, probability


# ============================================================================
# Main Example
# ============================================================================


def main():
    """Run FlareSense on a fixed 15-minute e-Callisto window and plot it."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}\n")

    # Example: predict on data from May 7, 2021.
    # A 15-minute window centred on 03:40:30, i.e. 03:33:00 to 03:48:00.
    start_time = datetime(2021, 5, 7, 3, 33, 0)
    end_time = datetime(2021, 5, 7, 3, 48, 0)
    instrument = "Australia-ASSA_01"

    print(f"Example prediction on instrument: {instrument}")

    # Load model (downloaded and cached automatically)
    model = load_flaresense_model(device=device)

    # Fetch data from e-Callisto
    print("Fetching data from e-Callisto...")
    df_dict = get_ecallisto_data(start_time, end_time, instrument)
    if instrument not in df_dict:
        # Report missing data clearly instead of failing with a bare KeyError.
        print(f"No data returned for {instrument} between {start_time} and {end_time}.")
        return
    df_spectrogram = df_dict[instrument]

    print(f"Data shape: {df_spectrogram.shape} (time x frequency)")
    print(f"Time range: {df_spectrogram.index[0]} to {df_spectrogram.index[-1]}")
    print(f"Frequency range: {df_spectrogram.columns[0]:.2f} - {df_spectrogram.columns[-1]:.2f} MHz\n")

    # Predict (pass the DataFrame directly)
    print("Running prediction...")
    logit, probability = predict_burst(model, df_spectrogram, device=device)

    # Display results
    print("\n" + "="*60)
    print("PREDICTION RESULTS")
    print("="*60)
    print(f"Logit: {logit:.4f}")
    print(f"Probability: {probability:.4f} ({probability*100:.2f}%)")

    burst_detected = probability > 0.5
    print(f"Prediction: {'BURST DETECTED ☀️' if burst_detected else 'No burst'}")
    print("="*60)

    # Plot the spectrogram (shown interactively, not saved).
    print("\nGenerating spectrogram plot...")
    # Use the same background window as the model preprocessing so the plot
    # shows the data the classifier saw (before median removal/resize/scale).
    df_processed = subtract_constant_background(df_spectrogram, n=_BACKGROUND_TIMEPOINTS)

    fig = plot_spectrogram(df_processed)
    show(fig)


if __name__ == "__main__":
    main()