|
|
"""
|
|
|
FlareSense v2 - Simple Usage Example
|
|
|
|
|
|
This script demonstrates how to use the FlareSense model to predict solar radio bursts
|
|
|
on e-Callisto data. The model checkpoint is downloaded from the Hugging Face Hub on first use and cached locally.
|
|
|
|
|
|
Usage:
|
|
|
python example_usage.py
|
|
|
|
|
|
The script runs the model on a 15-minute window of data from a single instrument and then plots the background-subtracted spectrogram.
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
from datetime import datetime
|
|
|
from huggingface_hub import hf_hub_download
|
|
|
from ecallisto_ng.data_download.downloader import get_ecallisto_data
|
|
|
from ecallisto_ng.data_processing.utils import subtract_constant_background
|
|
|
from ecallisto_ng.plotting.plotting import plot_spectrogram
|
|
|
from plotly.io import show
|
|
|
import torch.nn as nn
|
|
|
from torchvision import models
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GrayScaleResNet(nn.Module):
    """ResNet classifier adapted for grayscale (single-channel) images.

    Wraps a torchvision ResNet backbone and replaces its final fully connected
    layer with ``nn.Linear(num_features, n_classes)``. Single-channel inputs
    are broadcast to three channels in ``forward`` so the backbone's
    unmodified 3-channel stem can be reused.

    Args:
        n_classes: Number of output logits produced by the new head.
        resnet_type: One of ``"resnet18"``, ``"resnet34"`` or ``"resnet50"``.
        pretrained: If True (default, same as before), initialize the backbone
            with torchvision's ``DEFAULT`` weights, which torchvision may have
            to download. Pass False for a randomly initialized backbone, e.g.
            when the weights are about to be replaced by a checkpoint.

    Raises:
        ValueError: If ``resnet_type`` is not one of the supported names.
    """

    def __init__(self, n_classes=1, resnet_type="resnet34", pretrained=True):
        super().__init__()

        # Supported backbones: name -> (constructor, weights enum).
        backbones = {
            "resnet18": (models.resnet18, models.ResNet18_Weights),
            "resnet34": (models.resnet34, models.ResNet34_Weights),
            "resnet50": (models.resnet50, models.ResNet50_Weights),
        }
        try:
            constructor, weights_enum = backbones[resnet_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which the old if/elif chain
            # also rejected with ValueError.
            raise ValueError(f"Unsupported resnet_type: {resnet_type}") from None

        weights = weights_enum.DEFAULT if pretrained else None
        self.resnet = constructor(weights=weights)

        # Replace the backbone's classification head with a task-specific one.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, n_classes)

    def forward(self, x):
        """Return logits for a batch of images.

        Args:
            x: 4-D tensor indexed as (batch, channels, height, width). A
                single-channel input is expanded to three channels.

        Returns:
            Output of the wrapped ResNet, one column per class.
        """
        # expand() returns a broadcast view (no copy) repeating the grey channel.
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)
        return self.resnet(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_background(df_spectrogram) -> torch.Tensor:
    """
    Subtract a constant background from a spectrogram and convert it to a tensor.

    The background is removed by ``subtract_constant_background(..., n=300)``.
    Presumably the background is estimated from the first 300 time samples,
    as the original notes state; confirm against ecallisto_ng.

    Args:
        df_spectrogram: Pandas DataFrame with time as index and frequency as columns.

    Returns:
        float32 torch tensor laid out as (frequency, time).
    """
    cleaned_df = subtract_constant_background(df_spectrogram, n=300)

    # DataFrame rows are time and columns are frequency; transpose so rows are frequency.
    freq_by_time = cleaned_df.values.T
    return torch.from_numpy(freq_by_time).float()
|
|
|
|
|
|
|
|
|
def remove_background_median(spectrogram_tensor: torch.Tensor) -> torch.Tensor:
    """
    Subtract each row's median from that row.

    Meant to run after the constant background subtraction. With a
    (frequency, time) layout, each frequency channel is centred on its own
    median over time. For an even number of columns, ``torch.median`` picks
    the lower of the two middle values.

    Args:
        spectrogram_tensor: 2-D tensor of shape (frequency, time).

    Returns:
        Tensor of the same shape with the per-row median removed.
    """
    # keepdim=True leaves a (rows, 1) column that broadcasts across time.
    row_medians = spectrogram_tensor.median(dim=1, keepdim=True).values
    return spectrogram_tensor - row_medians
|
|
|
|
|
|
|
|
|
def resize_spectrogram(spectrogram_tensor: torch.Tensor, target_size=(128, 512)) -> torch.Tensor:
    """
    Bilinearly resample a 2-D spectrogram to a fixed size.

    Args:
        spectrogram_tensor: 2-D tensor of shape (frequency, time).
        target_size: Output (height, width); defaults to (128, 512).

    Returns:
        Tensor of shape (1, height, width), with a single channel axis in front.
    """
    # interpolate() wants (batch, channel, H, W): prepend two singleton axes.
    batched = spectrogram_tensor[None, None]
    resampled = torch.nn.functional.interpolate(
        batched,
        size=target_size,
        mode="bilinear",
        align_corners=False,
    )
    # Drop only the batch axis and keep the channel axis.
    return resampled[0]
|
|
|
|
|
|
|
|
|
def min_max_scale(tensor: torch.Tensor, feature_range=(0, 1)) -> torch.Tensor:
    """
    Linearly rescale a tensor so its global min and max map onto ``feature_range``.

    Args:
        tensor: Input tensor. The min and max are taken over all elements.
        feature_range: ``(min_val, max_val)`` of the output (default ``(0, 1)``).

    Returns:
        Scaled tensor with the same shape as ``tensor``.
        - A constant tensor (max == min) has no spread to scale, so every
          element is set to ``min_val``. This keeps the result inside
          ``feature_range`` and equals zeros for the default range.
        - An empty tensor is returned as an unchanged copy.
    """
    min_val, max_val = feature_range

    # min()/max() raise on empty tensors, and there is nothing to scale anyway.
    if tensor.numel() == 0:
        return tensor.clone()

    tensor_min = tensor.min()
    tensor_max = tensor.max()
    spread = tensor_max - tensor_min

    # Avoid 0/0 on constant input: map every element to the lower bound.
    if spread == 0:
        return torch.full_like(tensor, min_val)

    scaled = (tensor - tensor_min) / spread
    return scaled * (max_val - min_val) + min_val
|
|
|
|
|
|
|
|
|
def preprocess_spectrogram(df_spectrogram) -> torch.Tensor:
    """
    Turn a raw e-Callisto spectrogram DataFrame into a model-ready tensor.

    Steps, in order:
    1. Subtract the constant background.
    2. Subtract each row's median.
    3. Resize to 128 x 512.
    4. Min-max scale to [0, 1].

    Per the original authors, this mirrors the preprocessing used in training.

    Args:
        df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data.

    Returns:
        Tensor of shape (1, 128, 512) ready for model input.
    """
    # Two-stage background removal: constant background on the DataFrame,
    # then a per-row median on the (frequency, time) tensor.
    denoised = remove_background_median(remove_background(df_spectrogram))

    # Bring to the fixed model input size, then normalise intensities.
    resized = resize_spectrogram(denoised, target_size=(128, 512))
    return min_max_scale(resized, feature_range=(0, 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_flaresense_model(
    device="cpu",
    repo_id="i4ds/flaresense-v2",
    filename="model.ckpt",
    resnet_type="resnet34",
):
    """
    Load the FlareSense model from the Hugging Face Hub.

    The checkpoint is fetched with ``hf_hub_download``, which caches it
    locally so later calls reuse the cached file.

    Args:
        device: Device to load the model on (``'cpu'`` or ``'cuda'``).
        repo_id: Hub repository containing the checkpoint
            (default: the FlareSense v2 repository).
        filename: Checkpoint file name inside ``repo_id``.
        resnet_type: Backbone architecture the checkpoint was trained with.

    Returns:
        A ``GrayScaleResNet`` holding the checkpoint weights, in evaluation
        mode and moved to ``device``.

    Raises:
        RuntimeError: From ``load_state_dict`` if checkpoint keys or shapes
            do not match the model.
    """
    print(f"Downloading model from {repo_id}...")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f"Model cached at: {checkpoint_path}")

    # NOTE(review): GrayScaleResNet's constructor initializes the backbone with
    # torchvision's DEFAULT pretrained weights, which are overwritten just below.
    model = GrayScaleResNet(n_classes=1, resnet_type=resnet_type)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code.
    checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)

    # Accept both a wrapped checkpoint ({"state_dict": ...}) and a bare state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Remove the "_orig_mod." prefix from parameter names (presumably added by
    # torch.compile during training) so keys match the uncompiled module.
    state_dict = {
        key.replace("_orig_mod.", ""): value for key, value in state_dict.items()
    }

    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    print(f"Model loaded successfully on {device}")
    return model
|
|
|
|
|
|
|
|
|
def sigmoid(x, temperature=0.4974):
    """
    Convert a logit to a probability using a temperature-scaled sigmoid.

    Mathematically ``1 / (1 + exp(-x / temperature))``. It is evaluated as
    ``exp(-logaddexp(0, -x / temperature))``, which cannot overflow, so very
    negative logits give 0.0 without a RuntimeWarning.

    Args:
        x: Logit value (scalar or NumPy array).
        temperature: Strictly positive calibration temperature. The default
            0.4974 is presumably a fitted calibration value (not verifiable
            here).

    Returns:
        Probability in [0, 1], as a NumPy scalar or array matching ``x``.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if np.any(np.asarray(temperature) <= 0):
        raise ValueError(f"temperature must be positive, got {temperature!r}")

    z = x / temperature
    # log(1 + exp(-z)) computed stably; exp(-that) equals 1 / (1 + exp(-z)).
    return np.exp(-np.logaddexp(0.0, -z))
|
|
|
|
|
|
|
|
|
def predict_burst(model, df_spectrogram, device="cpu"):
    """
    Score a single spectrogram DataFrame for a solar radio burst.

    Args:
        model: Loaded FlareSense model.
        df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data.
        device: Device the input batch is moved to before inference.

    Returns:
        tuple: (logit, probability)
            - logit: Raw model output as a Python float.
            - probability: Temperature-calibrated probability in [0, 1].
    """
    # Preprocessed input is (1, 128, 512); prepend a batch axis -> (1, 1, 128, 512).
    batch = preprocess_spectrogram(df_spectrogram)[None].to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        raw_output = model(batch)
        logit = raw_output.squeeze().item()

    return logit, sigmoid(logit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run the FlareSense demo end to end.

    Loads the model, fetches a 15-minute e-Callisto window for one
    instrument, prints the burst prediction, and shows the spectrogram.
    If no data comes back for the instrument, it prints a message and
    returns early instead of crashing.
    """
    # Prefer a GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}\n")

    # 15-minute observation window. The datetimes are naive; presumably UTC,
    # to be confirmed against ecallisto_ng.
    start_time = datetime(2021, 5, 7, 3, 33, 0)
    end_time = datetime(2021, 5, 7, 3, 48, 0)
    instrument = "Australia-ASSA_01"
    print(f"Example prediction on instrument: {instrument}")

    model = load_flaresense_model(device=device)

    print("Fetching data from e-Callisto...")
    df_dict = get_ecallisto_data(start_time, end_time, instrument)

    # Without this guard, a missing key raises KeyError and an empty frame
    # raises IndexError on index[0] below.
    try:
        df_spectrogram = df_dict[instrument]
    except (KeyError, TypeError):
        df_spectrogram = None
    if df_spectrogram is None or df_spectrogram.empty:
        print(f"No data returned for {instrument} between {start_time} and {end_time}.")
        return

    print(f"Data shape: {df_spectrogram.shape} (time x frequency)")
    print(f"Time range: {df_spectrogram.index[0]} to {df_spectrogram.index[-1]}")
    print(f"Frequency range: {df_spectrogram.columns[0]:.2f} - {df_spectrogram.columns[-1]:.2f} MHz\n")

    print("Running prediction...")
    logit, probability = predict_burst(model, df_spectrogram, device=device)

    print("\n" + "="*60)
    print("PREDICTION RESULTS")
    print("="*60)
    print(f"Logit: {logit:.4f}")
    print(f"Probability: {probability:.4f} ({probability*100:.2f}%)")
    burst_detected = probability > 0.5
    print(f"Prediction: {'BURST DETECTED ☀️' if burst_detected else 'No burst'}")
    print("="*60)

    print("\nGenerating spectrogram plot...")
    # NOTE(review): the plot uses subtract_constant_background's default `n`,
    # while the model input uses n=300; confirm this difference is intended.
    df_processed = subtract_constant_background(df_spectrogram)
    fig = plot_spectrogram(df_processed)
    show(fig)
|
|
|
|
|
|
|
|
|
# Run the demo only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|
|