Image Classification
kenfus commited on
Commit
b2df42b
·
verified ·
1 Parent(s): 4a36246
Files changed (1) hide show
  1. main.py +333 -0
main.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FlareSense v2 - Simple Usage Example
3
+
4
+ This script demonstrates how to use the FlareSense model to predict solar radio bursts
5
+ on e-Callisto data. The model is automatically downloaded from HuggingFace and cached locally.
6
+
7
+ Usage:
8
+ python example_usage.py
9
+
10
+ The model will predict on a 15-minute window of data from a specific instrument.
11
+ """
12
+
13
+ import torch
14
+ import numpy as np
15
+ from datetime import datetime
16
+ from huggingface_hub import hf_hub_download
17
+ from ecallisto_ng.data_download.downloader import get_ecallisto_data
18
+ from ecallisto_ng.data_processing.utils import subtract_constant_background
19
+ from ecallisto_ng.plotting.plotting import plot_spectrogram
20
+ from plotly.io import show
21
+ import torch.nn as nn
22
+ from torchvision import models
23
+ import os
24
+
25
+ # ============================================================================
26
+ # Model Definition
27
+ # ============================================================================
28
+
29
+ class GrayScaleResNet(nn.Module):
30
+ """ResNet model adapted for grayscale images (single channel)."""
31
+
32
+ def __init__(self, n_classes=1, resnet_type="resnet34"):
33
+ super().__init__()
34
+
35
+ # Load pretrained ResNet (without num_classes parameter)
36
+ if resnet_type == "resnet34":
37
+ self.resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
38
+ elif resnet_type == "resnet18":
39
+ self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
40
+ elif resnet_type == "resnet50":
41
+ self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
42
+ else:
43
+ raise ValueError(f"Unsupported resnet_type: {resnet_type}")
44
+
45
+ # Replace the final fully connected layer for our number of classes
46
+ num_features = self.resnet.fc.in_features
47
+ self.resnet.fc = nn.Linear(num_features, n_classes)
48
+
49
+ def forward(self, x):
50
+ # Convert grayscale (1 channel) to 3 channels by expanding
51
+ if x.size(1) == 1:
52
+ x = x.expand(-1, 3, -1, -1)
53
+ return self.resnet(x)
54
+
55
+
56
+ # ============================================================================
57
+ # Data Processing Functions
58
+ # ============================================================================
59
+
60
+ def remove_background(df_spectrogram) -> torch.Tensor:
61
+ """
62
+ Remove constant background from spectrogram DataFrame.
63
+ Uses the median of the first 300 timepoints as the background.
64
+
65
+ Args:
66
+ df_spectrogram: Pandas DataFrame with time as index and frequency as columns
67
+
68
+ Returns:
69
+ Torch tensor with background removed (frequency x time)
70
+ """
71
+ # Subtract constant background using ecallisto_ng function
72
+ df_processed = subtract_constant_background(df_spectrogram, n=300)
73
+
74
+ # Convert to numpy and transpose to (frequency, time)
75
+ # DataFrame is (time, frequency), we need (frequency, time)
76
+ array_processed = df_processed.values.T
77
+
78
+ # Convert to torch tensor
79
+ tensor = torch.from_numpy(array_processed).float()
80
+
81
+ return tensor
82
+
83
+
84
+ def remove_background_median(spectrogram_tensor: torch.Tensor) -> torch.Tensor:
85
+ """
86
+ Remove row-wise median background from spectrogram tensor.
87
+ This is applied AFTER the constant background subtraction.
88
+
89
+ Args:
90
+ spectrogram_tensor: Tensor of shape (frequency, time)
91
+
92
+ Returns:
93
+ Tensor with median background removed
94
+ """
95
+ # Calculate the median of each row (frequency band)
96
+ median_values = torch.median(spectrogram_tensor, dim=1).values
97
+
98
+ # Subtract the median from each row
99
+ background_removed = spectrogram_tensor - median_values[:, None]
100
+
101
+ return background_removed
102
+
103
+
104
+ def resize_spectrogram(spectrogram_tensor: torch.Tensor, target_size=(128, 512)) -> torch.Tensor:
105
+ """
106
+ Resize spectrogram to target size using bilinear interpolation.
107
+
108
+ Args:
109
+ spectrogram_tensor: Input tensor (frequency, time)
110
+ target_size: Target size (height, width)
111
+
112
+ Returns:
113
+ Resized tensor (1, height, width)
114
+ """
115
+ # Add batch and channel dimensions for interpolation
116
+ x = spectrogram_tensor.unsqueeze(0).unsqueeze(0)
117
+
118
+ # Resize using bilinear interpolation
119
+ resized = torch.nn.functional.interpolate(
120
+ x, size=target_size, mode='bilinear', align_corners=False
121
+ )
122
+
123
+ # Remove batch dimension, keep channel dimension (1, H, W)
124
+ return resized.squeeze(0)
125
+
126
+
127
+ def min_max_scale(tensor: torch.Tensor, feature_range=(0, 1)) -> torch.Tensor:
128
+ """
129
+ Apply Min-Max scaling to a tensor.
130
+
131
+ Args:
132
+ tensor: Input tensor
133
+ feature_range: Desired range (default: (0, 1))
134
+
135
+ Returns:
136
+ Scaled tensor
137
+ """
138
+ min_val, max_val = feature_range
139
+ tensor_min = tensor.min()
140
+ tensor_max = tensor.max()
141
+
142
+ # Avoid division by zero
143
+ if tensor_max - tensor_min == 0:
144
+ return torch.zeros_like(tensor)
145
+
146
+ scaled_tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
147
+ scaled_tensor = scaled_tensor * (max_val - min_val) + min_val
148
+
149
+ return scaled_tensor
150
+
151
+
152
+ def preprocess_spectrogram(df_spectrogram) -> torch.Tensor:
153
+ """
154
+ Complete preprocessing pipeline for a spectrogram DataFrame.
155
+ This follows the exact same pipeline as the training code.
156
+
157
+ Args:
158
+ df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data
159
+
160
+ Returns:
161
+ Preprocessed tensor ready for model input (1, 128, 512)
162
+ """
163
+ # Step 1: Remove constant background and convert to tensor (frequency x time)
164
+ tensor = remove_background(df_spectrogram)
165
+
166
+ # Step 2: Remove row-wise median background
167
+ tensor = remove_background_median(tensor)
168
+
169
+ # Step 3: Resize to target size (128, 512)
170
+ # This uses normal_resize since custom_resize is False in config
171
+ tensor = resize_spectrogram(tensor, target_size=(128, 512))
172
+
173
+ # Step 4: Min-max scale to [0, 1]
174
+ tensor = min_max_scale(tensor, feature_range=(0, 1))
175
+
176
+ return tensor
177
+
178
+
179
+ # ============================================================================
180
+ # Model Loading and Prediction
181
+ # ============================================================================
182
+
183
+ def load_flaresense_model(device="cpu"):
184
+ """
185
+ Load the FlareSense model from HuggingFace Hub.
186
+ The model is automatically downloaded and cached locally.
187
+
188
+ Args:
189
+ device: Device to load model on ('cpu' or 'cuda')
190
+
191
+ Returns:
192
+ Loaded model in evaluation mode
193
+ """
194
+ # Model configuration (from best_v2.yml)
195
+ REPO_ID = "i4ds/flaresense-v2"
196
+ MODEL_FILENAME = "model.ckpt"
197
+ RESNET_TYPE = "resnet34"
198
+
199
+ print(f"Downloading model from {REPO_ID}...")
200
+ checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
201
+ print(f"Model cached at: {checkpoint_path}")
202
+
203
+ # Initialize model
204
+ model = GrayScaleResNet(n_classes=1, resnet_type=RESNET_TYPE)
205
+
206
+ # Load checkpoint
207
+ checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
208
+ if "state_dict" in checkpoint:
209
+ state_dict = checkpoint["state_dict"]
210
+ else:
211
+ state_dict = checkpoint
212
+
213
+ # Remove '_orig_mod.' prefix from keys (added by torch.compile)
214
+ new_state_dict = {}
215
+ for key, value in state_dict.items():
216
+ new_key = key.replace("_orig_mod.", "")
217
+ new_state_dict[new_key] = value
218
+
219
+ model.load_state_dict(new_state_dict)
220
+
221
+ # Set to evaluation mode and move to device
222
+ model.eval()
223
+ model.to(device)
224
+
225
+ print(f"Model loaded successfully on {device}")
226
+ return model
227
+
228
+
229
+ def sigmoid(x, temperature=0.4974):
230
+ """
231
+ Convert logit to probability using temperature-scaled sigmoid.
232
+
233
+ Args:
234
+ x: Logit value
235
+ temperature: Temperature parameter for calibration
236
+
237
+ Returns:
238
+ Probability [0, 1]
239
+ """
240
+ return 1 / (1 + np.exp(-x / temperature))
241
+
242
+
243
+ def predict_burst(model, df_spectrogram, device="cpu"):
244
+ """
245
+ Predict solar radio burst on a single spectrogram DataFrame.
246
+
247
+ Args:
248
+ model: Loaded FlareSense model
249
+ df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data
250
+ device: Device to run prediction on
251
+
252
+ Returns:
253
+ tuple: (logit, probability)
254
+ - logit: Raw model output
255
+ - probability: Calibrated probability [0, 1]
256
+ """
257
+ # Preprocess the DataFrame
258
+ input_tensor = preprocess_spectrogram(df_spectrogram)
259
+
260
+ # Add batch dimension and move to device
261
+ input_batch = input_tensor.unsqueeze(0).to(device)
262
+
263
+ # Predict
264
+ with torch.no_grad():
265
+ logit = model(input_batch).squeeze().item()
266
+
267
+ # Convert to probability
268
+ probability = sigmoid(logit)
269
+
270
+ return logit, probability
271
+
272
+
273
+ # ============================================================================
274
+ # Main Example
275
+ # ============================================================================
276
+
277
+ def main():
278
+ """Main example demonstrating how to use FlareSense for prediction."""
279
+
280
+ # Configuration
281
+ device = "cuda" if torch.cuda.is_available() else "cpu"
282
+ print(f"Using device: {device}\n")
283
+
284
+ # Example: Predict on data from May 7, 2021
285
+ # Create a 15-minute window centered around 03:40:30
286
+ # This gives us exactly 15 minutes: 03:33:00 to 03:48:00
287
+ start_time = datetime(2021, 5, 7, 3, 33, 0)
288
+ end_time = datetime(2021, 5, 7, 3, 48, 0)
289
+
290
+ instrument = "Australia-ASSA_01"
291
+
292
+ print(f"Example prediction on instrument: {instrument}")
293
+
294
+ # Load model (downloaded and cached automatically)
295
+ model = load_flaresense_model(device=device)
296
+
297
+ # Fetch data from e-Callisto
298
+ print(f"Fetching data from e-Callisto...")
299
+ df_dict = get_ecallisto_data(start_time, end_time, instrument)
300
+
301
+
302
+ df_spectrogram = df_dict[instrument]
303
+
304
+ print(f"Data shape: {df_spectrogram.shape} (time x frequency)")
305
+ print(f"Time range: {df_spectrogram.index[0]} to {df_spectrogram.index[-1]}")
306
+ print(f"Frequency range: {df_spectrogram.columns[0]:.2f} - {df_spectrogram.columns[-1]:.2f} MHz\n")
307
+
308
+ # Predict (pass the DataFrame directly)
309
+ print("Running prediction...")
310
+ logit, probability = predict_burst(model, df_spectrogram, device=device)
311
+
312
+ # Display results
313
+ print("\n" + "="*60)
314
+ print("PREDICTION RESULTS")
315
+ print("="*60)
316
+ print(f"Logit: {logit:.4f}")
317
+ print(f"Probability: {probability:.4f} ({probability*100:.2f}%)")
318
+ burst_detected = probability > 0.5
319
+ print(f"Prediction: {'BURST DETECTED ☀️' if burst_detected else 'No burst'}")
320
+ print("="*60)
321
+
322
+ # Plot and save the spectrogram
323
+ print("\nGenerating spectrogram plot...")
324
+ df_processed = subtract_constant_background(df_dict[instrument])
325
+
326
+
327
+ # Show the plot
328
+ fig = plot_spectrogram(df_processed)
329
+ show(fig)
330
+
331
+
332
+ if __name__ == "__main__":
333
+ main()