Image Classification
kenfus commited on
Commit
fe8f292
·
verified ·
1 Parent(s): 43dca64

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +370 -0
README.md ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ datasets:
4
+ - i4ds/ecallisto_radio_sunburst
5
+ metrics:
6
+ - recall
7
+ - precision
8
+ pipeline_tag: image-classification
9
+ ---
10
+
11
+ ## FlareSense-v2
12
+ This model predicts on 15 minutes spectrograms if they contain a burst or not, see paper:
13
+
14
+
15
+ ## Usage
16
+
17
+ ```bash
18
+ pip install torch torchvision huggingface_hub ecallisto_ng
19
+ ```
20
+
21
+
22
+ ```python
23
+ """
24
+ FlareSense v2 - Simple Usage Example
25
+
26
+ This script demonstrates how to use the FlareSense model to predict solar radio bursts
27
+ on e-Callisto data. The model is automatically downloaded from HuggingFace and cached locally.
28
+
29
+ Usage:
30
+ python example_usage.py
31
+
32
+ The model will predict on a 15-minute window of data from a specific instrument.
33
+ """
34
+
35
+ import torch
36
+ import numpy as np
37
+ from datetime import datetime
38
+ from huggingface_hub import hf_hub_download
39
+ from ecallisto_ng.data_download.downloader import get_ecallisto_data
40
+ from ecallisto_ng.data_processing.utils import subtract_constant_background
41
+ from ecallisto_ng.plotting.plotting import plot_spectrogram
42
+ import torch.nn as nn
43
+ from torchvision import models
44
+ import os
45
+
46
+
47
+ # ============================================================================
48
+ # Model Definition
49
+ # ============================================================================
50
+
51
+ class GrayScaleResNet(nn.Module):
52
+ """ResNet model adapted for grayscale images (single channel)."""
53
+
54
+ def __init__(self, n_classes=1, resnet_type="resnet34"):
55
+ super().__init__()
56
+
57
+ # Load pretrained ResNet (without num_classes parameter)
58
+ if resnet_type == "resnet34":
59
+ self.resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
60
+ elif resnet_type == "resnet18":
61
+ self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
62
+ elif resnet_type == "resnet50":
63
+ self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
64
+ else:
65
+ raise ValueError(f"Unsupported resnet_type: {resnet_type}")
66
+
67
+ # Replace the final fully connected layer for our number of classes
68
+ num_features = self.resnet.fc.in_features
69
+ self.resnet.fc = nn.Linear(num_features, n_classes)
70
+
71
+ def forward(self, x):
72
+ # Convert grayscale (1 channel) to 3 channels by expanding
73
+ if x.size(1) == 1:
74
+ x = x.expand(-1, 3, -1, -1)
75
+ return self.resnet(x)
76
+
77
+
78
+ # ============================================================================
79
+ # Data Processing Functions
80
+ # ============================================================================
81
+
82
+ def remove_background(df_spectrogram) -> torch.Tensor:
83
+ """
84
+ Remove constant background from spectrogram DataFrame.
85
+ Uses the median of the first 300 timepoints as the background.
86
+
87
+ Args:
88
+ df_spectrogram: Pandas DataFrame with time as index and frequency as columns
89
+
90
+ Returns:
91
+ Torch tensor with background removed (frequency x time)
92
+ """
93
+ # Subtract constant background using ecallisto_ng function
94
+ df_processed = subtract_constant_background(df_spectrogram, n=300)
95
+
96
+ # Convert to numpy and transpose to (frequency, time)
97
+ # DataFrame is (time, frequency), we need (frequency, time)
98
+ array_processed = df_processed.values.T
99
+
100
+ # Convert to torch tensor
101
+ tensor = torch.from_numpy(array_processed).float()
102
+
103
+ return tensor
104
+
105
+
106
+ def remove_background_median(spectrogram_tensor: torch.Tensor) -> torch.Tensor:
107
+ """
108
+ Remove row-wise median background from spectrogram tensor.
109
+ This is applied AFTER the constant background subtraction.
110
+
111
+ Args:
112
+ spectrogram_tensor: Tensor of shape (frequency, time)
113
+
114
+ Returns:
115
+ Tensor with median background removed
116
+ """
117
+ # Calculate the median of each row (frequency band)
118
+ median_values = torch.median(spectrogram_tensor, dim=1).values
119
+
120
+ # Subtract the median from each row
121
+ background_removed = spectrogram_tensor - median_values[:, None]
122
+
123
+ return background_removed
124
+
125
+
126
+ def resize_spectrogram(spectrogram_tensor: torch.Tensor, target_size=(128, 512)) -> torch.Tensor:
127
+ """
128
+ Resize spectrogram to target size using bilinear interpolation.
129
+
130
+ Args:
131
+ spectrogram_tensor: Input tensor (frequency, time)
132
+ target_size: Target size (height, width)
133
+
134
+ Returns:
135
+ Resized tensor (1, height, width)
136
+ """
137
+ # Add batch and channel dimensions for interpolation
138
+ x = spectrogram_tensor.unsqueeze(0).unsqueeze(0)
139
+
140
+ # Resize using bilinear interpolation
141
+ resized = torch.nn.functional.interpolate(
142
+ x, size=target_size, mode='bilinear', align_corners=False
143
+ )
144
+
145
+ # Remove batch dimension, keep channel dimension (1, H, W)
146
+ return resized.squeeze(0)
147
+
148
+
149
+ def min_max_scale(tensor: torch.Tensor, feature_range=(0, 1)) -> torch.Tensor:
150
+ """
151
+ Apply Min-Max scaling to a tensor.
152
+
153
+ Args:
154
+ tensor: Input tensor
155
+ feature_range: Desired range (default: (0, 1))
156
+
157
+ Returns:
158
+ Scaled tensor
159
+ """
160
+ min_val, max_val = feature_range
161
+ tensor_min = tensor.min()
162
+ tensor_max = tensor.max()
163
+
164
+ # Avoid division by zero
165
+ if tensor_max - tensor_min == 0:
166
+ return torch.zeros_like(tensor)
167
+
168
+ scaled_tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
169
+ scaled_tensor = scaled_tensor * (max_val - min_val) + min_val
170
+
171
+ return scaled_tensor
172
+
173
+
174
+ def preprocess_spectrogram(df_spectrogram) -> torch.Tensor:
175
+ """
176
+ Complete preprocessing pipeline for a spectrogram DataFrame.
177
+ This follows the exact same pipeline as the training code.
178
+
179
+ Args:
180
+ df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data
181
+
182
+ Returns:
183
+ Preprocessed tensor ready for model input (1, 128, 512)
184
+ """
185
+ # Step 1: Remove constant background and convert to tensor (frequency x time)
186
+ tensor = remove_background(df_spectrogram)
187
+
188
+ # Step 2: Remove row-wise median background
189
+ tensor = remove_background_median(tensor)
190
+
191
+ # Step 3: Resize to target size (128, 512)
192
+ # This uses normal_resize since custom_resize is False in config
193
+ tensor = resize_spectrogram(tensor, target_size=(128, 512))
194
+
195
+ # Step 4: Min-max scale to [0, 1]
196
+ tensor = min_max_scale(tensor, feature_range=(0, 1))
197
+
198
+ return tensor
199
+
200
+
201
+ # ============================================================================
202
+ # Model Loading and Prediction
203
+ # ============================================================================
204
+
205
+ def load_flaresense_model(device="cpu"):
206
+ """
207
+ Load the FlareSense model from HuggingFace Hub.
208
+ The model is automatically downloaded and cached locally.
209
+
210
+ Args:
211
+ device: Device to load model on ('cpu' or 'cuda')
212
+
213
+ Returns:
214
+ Loaded model in evaluation mode
215
+ """
216
+ # Model configuration (from best_v2.yml)
217
+ REPO_ID = "i4ds/flaresense-v2"
218
+ MODEL_FILENAME = "model.ckpt"
219
+ RESNET_TYPE = "resnet34"
220
+
221
+ print(f"Downloading model from {REPO_ID}...")
222
+ checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
223
+ print(f"Model cached at: {checkpoint_path}")
224
+
225
+ # Initialize model
226
+ model = GrayScaleResNet(n_classes=1, resnet_type=RESNET_TYPE)
227
+
228
+ # Load checkpoint
229
+ checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
230
+ if "state_dict" in checkpoint:
231
+ state_dict = checkpoint["state_dict"]
232
+ else:
233
+ state_dict = checkpoint
234
+
235
+ # Remove '_orig_mod.' prefix from keys (added by torch.compile)
236
+ new_state_dict = {}
237
+ for key, value in state_dict.items():
238
+ new_key = key.replace("_orig_mod.", "")
239
+ new_state_dict[new_key] = value
240
+
241
+ model.load_state_dict(new_state_dict)
242
+
243
+ # Set to evaluation mode and move to device
244
+ model.eval()
245
+ model.to(device)
246
+
247
+ print(f"Model loaded successfully on {device}")
248
+ return model
249
+
250
+
251
+ def sigmoid(x, temperature=0.4974):
252
+ """
253
+ Convert logit to probability using temperature-scaled sigmoid.
254
+
255
+ Args:
256
+ x: Logit value
257
+ temperature: Temperature parameter for calibration
258
+
259
+ Returns:
260
+ Probability [0, 1]
261
+ """
262
+ return 1 / (1 + np.exp(-x / temperature))
263
+
264
+
265
+ def predict_burst(model, df_spectrogram, device="cpu"):
266
+ """
267
+ Predict solar radio burst on a single spectrogram DataFrame.
268
+
269
+ Args:
270
+ model: Loaded FlareSense model
271
+ df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data
272
+ device: Device to run prediction on
273
+
274
+ Returns:
275
+ tuple: (logit, probability)
276
+ - logit: Raw model output
277
+ - probability: Calibrated probability [0, 1]
278
+ """
279
+ # Preprocess the DataFrame
280
+ input_tensor = preprocess_spectrogram(df_spectrogram)
281
+
282
+ # Add batch dimension and move to device
283
+ input_batch = input_tensor.unsqueeze(0).to(device)
284
+
285
+ # Predict
286
+ with torch.no_grad():
287
+ logit = model(input_batch).squeeze().item()
288
+
289
+ # Convert to probability
290
+ probability = sigmoid(logit)
291
+
292
+ return logit, probability
293
+
294
+
295
+ # ============================================================================
296
+ # Main Example
297
+ # ============================================================================
298
+
299
+ def main():
300
+ """Main example demonstrating how to use FlareSense for prediction."""
301
+
302
+ # Configuration
303
+ device = "cuda" if torch.cuda.is_available() else "cpu"
304
+ print(f"Using device: {device}\n")
305
+
306
+ # Example: Predict on data from May 7, 2021
307
+ # Create a 15-minute window centered around 03:40:30
308
+ # This gives us exactly 15 minutes: 03:33:00 to 03:48:00
309
+ start_time = datetime(2021, 5, 7, 3, 33, 0)
310
+ end_time = datetime(2021, 5, 7, 3, 48, 0)
311
+
312
+ instrument = "Australia-ASSA_01"
313
+
314
+ print(f"Example prediction on instrument: {instrument}")
315
+ duration_minutes = (end_time - start_time).total_seconds() / 60
316
+ print(f"Time window: {start_time} to {end_time} ({duration_minutes:.0f} minutes)\n")
317
+
318
+ # Load model (downloaded and cached automatically)
319
+ model = load_flaresense_model(device=device)
320
+
321
+ # Fetch data from e-Callisto
322
+ print(f"Fetching data from e-Callisto...")
323
+ df_dict = get_ecallisto_data(start_time, end_time, instrument)
324
+
325
+ if instrument not in df_dict:
326
+ print(f"Error: No data found for instrument {instrument}")
327
+ return
328
+
329
+ df_spectrogram = df_dict[instrument]
330
+ print(f"Data shape: {df_spectrogram.shape} (time x frequency)")
331
+ print(f"Time range: {df_spectrogram.index[0]} to {df_spectrogram.index[-1]}")
332
+ print(f"Frequency range: {df_spectrogram.columns[0]:.2f} - {df_spectrogram.columns[-1]:.2f} MHz\n")
333
+
334
+ # Predict (pass the DataFrame directly)
335
+ print("Running prediction...")
336
+ logit, probability = predict_burst(model, df_spectrogram, device=device)
337
+
338
+ # Display results
339
+ print("\n" + "="*60)
340
+ print("PREDICTION RESULTS")
341
+ print("="*60)
342
+ print(f"Logit: {logit:.4f}")
343
+ print(f"Probability: {probability:.4f} ({probability*100:.2f}%)")
344
+ burst_detected = probability > 0.5
345
+ print(f"Prediction: {'BURST DETECTED ☀️' if burst_detected else 'No burst'}")
346
+ print("="*60)
347
+
348
+ # Plot and save the spectrogram
349
+ print("\nGenerating spectrogram plot...")
350
+ df_processed = subtract_constant_background(df_spectrogram)
351
+ fig = plot_spectrogram(df_processed)
352
+
353
+ # Create output filename
354
+ burst_label = "burst" if burst_detected else "no_burst"
355
+ date_str = start_time.strftime("%Y-%m-%d_%H-%M-%S")
356
+ output_filename = f"{instrument}_{date_str}_{burst_label}.png"
357
+ output_path = os.path.join("ecallisto_ng", output_filename)
358
+
359
+ # Create directory if it doesn't exist
360
+ os.makedirs("ecallisto_ng", exist_ok=True)
361
+
362
+ # Save the plot
363
+ fig.write_image(output_path)
364
+ print(f"Spectrogram saved to: {output_path}")
365
+
366
+
367
+ if __name__ == "__main__":
368
+ main()
369
+
370
+ ```