|
|
--- |
|
|
license: mit |
|
|
datasets: |
|
|
- i4ds/ecallisto_radio_sunburst |
|
|
metrics: |
|
|
- recall |
|
|
- precision |
|
|
pipeline_tag: image-classification |
|
|
--- |
|
|
|
|
|
## FlareSense-v2 |
|
|
This model predicts whether a 15-minute e-Callisto spectrogram contains a solar radio burst. For details, see the paper:
|
|
|
|
|
|
|
|
## Usage
|
|
|
|
|
```bash |
|
|
pip install torch torchvision huggingface_hub ecallisto_ng plotly
|
|
``` |
|
|
|
|
|
|
|
|
```python |
|
|
""" |
|
|
FlareSense v2 - Simple Usage Example |
|
|
|
|
|
This script demonstrates how to use the FlareSense model to predict solar radio bursts |
|
|
on e-Callisto data. The model is automatically downloaded from HuggingFace and cached locally. |
|
|
|
|
|
Usage: |
|
|
python example_usage.py |
|
|
|
|
|
The model will predict on a 15-minute window of data from a specific instrument. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from datetime import datetime |
|
|
from huggingface_hub import hf_hub_download |
|
|
from ecallisto_ng.data_download.downloader import get_ecallisto_data |
|
|
from ecallisto_ng.data_processing.utils import subtract_constant_background |
|
|
from ecallisto_ng.plotting.plotting import plot_spectrogram |
|
|
from plotly.io import show |
|
|
import torch.nn as nn |
|
|
from torchvision import models |
|
|
import os |
|
|
|
|
|
# ============================================================================ |
|
|
# Model Definition |
|
|
# ============================================================================ |
|
|
|
|
|
class GrayScaleResNet(nn.Module):
    """ResNet classifier adapted for single-channel (grayscale) spectrogram input.

    The backbone is a torchvision ResNet whose final fully connected layer is
    replaced by a fresh ``nn.Linear`` with ``n_classes`` outputs. Single-channel
    batches are broadcast to three channels in :meth:`forward`, so the stock
    three-channel ResNet stem is used unchanged.

    Args:
        n_classes: Number of output logits (1 for burst / no-burst).
        resnet_type: One of ``"resnet18"``, ``"resnet34"`` or ``"resnet50"``.
        pretrained: If True (default, the original behaviour), initialise the
            backbone with torchvision's ``DEFAULT`` ImageNet weights. These are
            downloaded on first use. Pass False when the weights will be
            replaced from a checkpoint anyway, to skip that download.

    Raises:
        ValueError: If ``resnet_type`` is not one of the supported names.
    """

    def __init__(self, n_classes=1, resnet_type="resnet34", pretrained=True):
        super().__init__()

        # Supported backbones: name -> (torchvision builder, weights enum).
        backbones = {
            "resnet18": (models.resnet18, models.ResNet18_Weights),
            "resnet34": (models.resnet34, models.ResNet34_Weights),
            "resnet50": (models.resnet50, models.ResNet50_Weights),
        }
        if resnet_type not in backbones:
            raise ValueError(f"Unsupported resnet_type: {resnet_type}")
        builder, weights_enum = backbones[resnet_type]

        # weights=None gives a randomly initialised backbone (no download).
        self.resnet = builder(weights=weights_enum.DEFAULT if pretrained else None)

        # Replace the ImageNet head with one sized for our number of classes.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, n_classes)

    def forward(self, x):
        """Return logits for a batch ``x`` laid out as (N, C, H, W).

        A single-channel batch is broadcast to three channels with ``expand``
        (a view, not a copy) to match the ResNet stem.
        """
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)
        return self.resnet(x)
|
|
|
|
|
|
|
|
# ============================================================================ |
|
|
# Data Processing Functions |
|
|
# ============================================================================ |
|
|
|
|
|
def remove_background(df_spectrogram) -> torch.Tensor:
    """
    Subtract a constant background from a spectrogram and return it as a tensor.

    The subtraction is delegated to ecallisto_ng's ``subtract_constant_background``
    with ``n=300``. Per the original notes, this uses the median of the first 300
    time samples as the background; confirm against the ecallisto_ng docs. The
    result is transposed so that frequency runs along the first axis.

    Args:
        df_spectrogram: Pandas DataFrame indexed by time, one column per frequency.

    Returns:
        float32 tensor of shape (frequency, time) with the background removed.
    """
    background_free = subtract_constant_background(df_spectrogram, n=300)

    # The DataFrame is (time, frequency); the rest of the pipeline wants (frequency, time).
    freq_by_time = background_free.values.T

    return torch.from_numpy(freq_by_time).float()
|
|
|
|
|
|
|
|
def remove_background_median(spectrogram_tensor: torch.Tensor) -> torch.Tensor:
    """
    Subtract each frequency row's median from that row.

    This is the second background step, applied after the constant background
    subtraction done by ``remove_background``.

    Args:
        spectrogram_tensor: Tensor of shape (frequency, time).

    Returns:
        Tensor of the same shape with every row's median subtracted. For
        even-length rows, torch.median picks the lower of the two middle values
        rather than their mean.
    """
    # keepdim=True leaves a (frequency, 1) column that broadcasts across time.
    row_medians = spectrogram_tensor.median(dim=1, keepdim=True).values
    return spectrogram_tensor - row_medians
|
|
|
|
|
|
|
|
def resize_spectrogram(spectrogram_tensor: torch.Tensor, target_size=(128, 512)) -> torch.Tensor:
    """
    Bilinearly resample a 2-D spectrogram to a fixed (height, width).

    Args:
        spectrogram_tensor: 2-D tensor laid out as (frequency, time).
        target_size: Output (height, width); defaults to (128, 512).

    Returns:
        Tensor of shape (1, height, width). The leading channel axis is kept so
        the result can be batched directly for the model.
    """
    # interpolate expects (batch, channel, H, W): prepend two singleton axes.
    batched = spectrogram_tensor[None, None]

    resampled = torch.nn.functional.interpolate(
        batched,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )

    # Drop only the size-one batch axis; the channel axis stays.
    return resampled[0]
|
|
|
|
|
|
|
|
def min_max_scale(tensor: torch.Tensor, feature_range=(0, 1)) -> torch.Tensor:
    """
    Linearly rescale a tensor so its global min and max map onto ``feature_range``.

    Args:
        tensor: Input tensor of any shape; min and max are taken over all elements.
        feature_range: ``(low, high)`` target interval (default ``(0, 1)``).

    Returns:
        Tensor of the same shape, rescaled into ``feature_range``.

        - A constant tensor (max == min) has no spread to stretch, so every
          element maps to ``low``. With the default range this is all zeros.
        - An empty tensor is returned as an empty copy.
    """
    min_val, max_val = feature_range

    # tensor.min()/max() raise on empty tensors, and there is nothing to scale.
    if tensor.numel() == 0:
        return tensor.clone()

    tensor_min = tensor.min()
    tensor_max = tensor.max()
    span = tensor_max - tensor_min

    # Constant input: avoid 0/0, and honour the requested range by mapping
    # every element to its lower bound instead of a hard-coded zero.
    if span == 0:
        return torch.full_like(tensor, min_val)

    scaled_tensor = (tensor - tensor_min) / span
    return scaled_tensor * (max_val - min_val) + min_val
|
|
|
|
|
|
|
|
def preprocess_spectrogram(df_spectrogram) -> torch.Tensor:
    """
    Turn a raw e-Callisto spectrogram DataFrame into a FlareSense model input.

    Per the original notes, the steps mirror the training pipeline:

    1. Remove the constant background.
    2. Remove each frequency row's median.
    3. Bilinearly resize to 128 x 512.
    4. Min-max scale to [0, 1].

    Args:
        df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data.

    Returns:
        Tensor of shape (1, 128, 512) with values in [0, 1].
    """
    # (time, freq) DataFrame -> (freq, time) tensor, constant background removed.
    spec = remove_background(df_spectrogram)

    # Remove the remaining per-frequency baseline.
    spec = remove_background_median(spec)

    # Plain bilinear resize (the original notes say custom_resize is False in the config).
    spec = resize_spectrogram(spec, target_size=(128, 512))

    return min_max_scale(spec, feature_range=(0, 1))
|
|
|
|
|
|
|
|
# ============================================================================ |
|
|
# Model Loading and Prediction |
|
|
# ============================================================================ |
|
|
|
|
|
def load_flaresense_model(device="cpu"):
    """
    Fetch the FlareSense v2 checkpoint from the HuggingFace Hub and build the model.

    hf_hub_download caches the checkpoint locally, so later calls reuse the file.

    Args:
        device: Target device, e.g. 'cpu' or 'cuda'.

    Returns:
        GrayScaleResNet (resnet34 backbone, one output logit) in evaluation
        mode, moved to ``device``.
    """
    # Model configuration (from best_v2.yml)
    REPO_ID = "i4ds/flaresense-v2"
    MODEL_FILENAME = "model.ckpt"
    RESNET_TYPE = "resnet34"

    print(f"Downloading model from {REPO_ID}...")
    checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    print(f"Model cached at: {checkpoint_path}")

    model = GrayScaleResNet(n_classes=1, resnet_type=RESNET_TYPE)

    checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)

    # Some checkpoints nest the weights under "state_dict"; otherwise the
    # loaded object is itself the state dict.
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    # torch.compile inserts "_orig_mod." into parameter names; strip it so the
    # keys match the uncompiled model.
    cleaned_state = {
        name.replace("_orig_mod.", ""): weights
        for name, weights in state_dict.items()
    }
    model.load_state_dict(cleaned_state)

    model.eval()
    model.to(device)

    print(f"Model loaded successfully on {device}")
    return model
|
|
|
|
|
|
|
|
def sigmoid(x, temperature=0.4974):
    """
    Convert a logit to a probability with a temperature-scaled sigmoid.

    Computes ``1 / (1 + exp(-x / temperature))``. It is evaluated as
    ``exp(-log(1 + exp(-z)))`` through ``np.logaddexp``. Strongly negative
    logits therefore give ~0.0 without the overflow RuntimeWarning that a
    direct ``np.exp(-z)`` emits for large ``-z``.

    Args:
        x: Logit value (scalar or NumPy array).
        temperature: Divisor applied to the logit before the sigmoid. The
            default 0.4974 is presumably a calibration fit for this model;
            it cannot be derived from this file.

    Returns:
        Probability in [0, 1] (NumPy float, or an array for array input).
    """
    z = x / temperature
    # np.logaddexp(0, -z) == log(1 + exp(-z)), computed without overflow.
    return np.exp(-np.logaddexp(0.0, -z))
|
|
|
|
|
|
|
|
def predict_burst(model, df_spectrogram, device="cpu"):
    """
    Score one spectrogram DataFrame for the presence of a solar radio burst.

    Args:
        model: FlareSense model, e.g. from load_flaresense_model. This function
            does not switch it to evaluation mode.
        df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data.
        device: Device the input batch is moved to; should match the model's device.

    Returns:
        tuple: (logit, probability)
            - logit: Raw model output as a Python float.
            - probability: Temperature-calibrated sigmoid of the logit, in [0, 1].
    """
    # (1, 128, 512) -> (1, 1, 128, 512): prepend a batch axis of size one.
    batch = preprocess_spectrogram(df_spectrogram)[None].to(device)

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        raw_output = model(batch)
    logit = raw_output.squeeze().item()

    return logit, sigmoid(logit)
|
|
|
|
|
|
|
|
# ============================================================================ |
|
|
# Main Example |
|
|
# ============================================================================ |
|
|
|
|
|
def main():
    """Run FlareSense on one 15-minute e-Callisto window and show the spectrogram.

    The function:

    - loads the model (downloaded once, then taken from the local cache);
    - fetches data for a single instrument;
    - prints the logit and calibrated probability;
    - plots the background-subtracted spectrogram.

    If no data comes back for the instrument, it prints a message and returns
    early.
    """
    # Configuration
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}\n")

    # Example: May 7, 2021, a 15-minute window 03:33:00-03:48:00
    # (centred on 03:40:30).
    start_time = datetime(2021, 5, 7, 3, 33, 0)
    end_time = datetime(2021, 5, 7, 3, 48, 0)
    instrument = "Australia-ASSA_01"

    print(f"Example prediction on instrument: {instrument}")

    # Load model (downloaded and cached automatically)
    model = load_flaresense_model(device=device)

    print("Fetching data from e-Callisto...")
    df_dict = get_ecallisto_data(start_time, end_time, instrument)

    # If the archive has nothing for this instrument/time range, report it
    # clearly instead of crashing with a KeyError or IndexError further down.
    try:
        df_spectrogram = df_dict[instrument]
    except (KeyError, TypeError):
        df_spectrogram = None
    if df_spectrogram is None or df_spectrogram.empty:
        print(f"No data returned for {instrument} between {start_time} and {end_time}.")
        return

    print(f"Data shape: {df_spectrogram.shape} (time x frequency)")
    print(f"Time range: {df_spectrogram.index[0]} to {df_spectrogram.index[-1]}")
    print(f"Frequency range: {df_spectrogram.columns[0]:.2f} - {df_spectrogram.columns[-1]:.2f} MHz\n")

    # Predict (pass the DataFrame directly)
    print("Running prediction...")
    logit, probability = predict_burst(model, df_spectrogram, device=device)

    # Display results
    separator = "=" * 60
    print("\n" + separator)
    print("PREDICTION RESULTS")
    print(separator)
    print(f"Logit: {logit:.4f}")
    print(f"Probability: {probability:.4f} ({probability*100:.2f}%)")
    burst_detected = probability > 0.5
    print(f"Prediction: {'BURST DETECTED ☀️' if burst_detected else 'No burst'}")
    print(separator)

    # Show the constant-background-subtracted spectrogram (nothing is saved).
    # NOTE(review): preprocessing passes n=300, while this call uses the
    # ecallisto_ng default. Confirm whether the plot should match the model input.
    print("\nGenerating spectrogram plot...")
    df_processed = subtract_constant_background(df_spectrogram)
    fig = plot_spectrogram(df_processed)
    show(fig)


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
``` |