griffingoodwin04 committed on
Commit
afe9cc0
·
1 Parent(s): b9d1035

Enhance FlareDownloadProcessor to support time span mode for 1-minute cadence downloads, adding new arguments and methods for data existence checks. Update command-line arguments for improved usability. Refactor evaluation configurations and model training settings, including adjustments to checkpoint paths and model parameters for better performance.

Browse files
download/flare_download_processor.py CHANGED
@@ -11,15 +11,42 @@ import flare_event_downloader as fed
11
  import sxr_downloader as sxr
12
 
13
  class FlareDownloadProcessor:
14
- def __init__(self, FlareEventDownloader, SDODownloader, SXRDownloader, flaring_data=True):
15
  """
16
  Initialize the FlareDownloadProcessor.
17
  This class is responsible for processing AIA flare downloads.
 
 
 
 
 
 
 
18
  """
19
  self.FlareEventDownloader = FlareEventDownloader
20
  self.SDODownloader = SDODownloader
21
  self.SXRDownloader = SXRDownloader
22
  self.flaring_data = flaring_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def retry_download_with_backoff(self, download_func, *args, max_retries=5, base_delay=60, **kwargs):
25
  """
@@ -56,10 +83,18 @@ class FlareDownloadProcessor:
56
  # Re-raise non-HTTP errors immediately
57
  raise
58
 
59
- def process_download(self, time_before_start=timedelta(minutes=5), time_after_end=timedelta(minutes=0)):
60
-
61
- fl_events = self.FlareEventDownloader.download_events()
62
- print(fl_events)
 
 
 
 
 
 
 
 
63
  [os.makedirs(os.path.join(self.SDODownloader.ds_path, str(c)), exist_ok=True) for c in
64
  [94, 131, 171, 193, 211, 304]]
65
 
@@ -71,6 +106,93 @@ class FlareDownloadProcessor:
71
  completed_dates = set(line.strip() for line in f)
72
  print(f"Resuming from {len(completed_dates)} previously completed downloads")
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  if self.flaring_data == True:
75
  print("Processing flare events...")
76
  if fl_events.empty:
@@ -88,8 +210,13 @@ class FlareDownloadProcessor:
88
  range((end_time - start_time) // timedelta(minutes=1))]:
89
  # Only download if we haven't processed this date yet
90
  if d.isoformat() not in processed_dates:
91
- self.retry_download_with_backoff(self.SDODownloader.downloadDate, d)
92
- processed_dates.add(d.isoformat())
 
 
 
 
 
93
  logging.info(f"Processed flare event {i + 1}/{len(fl_events)}: {event['event_starttime']} to {event['event_endtime']}")
94
  elif self.flaring_data == False:
95
  print("Processing non-flare events...")
@@ -136,29 +263,39 @@ class FlareDownloadProcessor:
136
  for j, d in enumerate(batch):
137
  # Only download if we haven't processed this date yet
138
  if d.isoformat() not in processed_dates and d.isoformat() not in completed_dates:
139
- try:
140
- print(f" Downloading data for {d} ({j+1}/{len(batch)})")
141
- self.retry_download_with_backoff(self.SDODownloader.downloadDate, d)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  processed_dates.add(d.isoformat())
143
  completed_dates.add(d.isoformat())
144
 
145
  # Update progress file
146
  with open(progress_file, 'a') as f:
147
  f.write(f"{d.isoformat()}\n")
148
-
149
- print(f" ✓ Successfully downloaded {d}")
150
-
151
- # Add small delay between individual downloads
152
- if j < len(batch) - 1:
153
- time.sleep(.2)
154
-
155
- except Exception as e:
156
- print(f" ✗ Failed to download data for {d}: {e}")
157
- # If it's a connection error, wait longer before retrying
158
- if "Connection refused" in str(e) or "timeout" in str(e).lower():
159
- print(f" Waiting 10 seconds before continuing...")
160
- time.sleep(5)
161
- continue
162
  elif d.isoformat() in completed_dates:
163
  print(f" ⏭ Skipping {d} (already completed)")
164
  processed_dates.add(d.isoformat())
@@ -174,17 +311,25 @@ class FlareDownloadProcessor:
174
  if __name__ == '__main__':
175
  parser = argparse.ArgumentParser(description='Download flare events and associated SDO data.')
176
  parser.add_argument('--start_date', type=str, default='2023-6-15',
177
- help='Start date for downloading flare events (YYYY-MM-DD)')
178
  parser.add_argument('--end_date', type=str, default='2023-07-15',
179
- help='End date for downloading flare events (YYYY-MM-DD)')
 
 
 
 
180
  parser.add_argument('--chunk_size', type=int, default=2000,
181
  help='Number of days per chunk for processing (default: 180)')
182
  parser.add_argument('--download_dir', type=str, default='/mnt/data',
183
  help='Directory to save downloaded data (default: /mnt/data)')
 
 
184
  parser.add_argument('--flaring_data', dest='flaring_data', action='store_true',
185
  help='Download flaring data (default)')
186
  parser.add_argument('--non_flaring_data', dest='flaring_data', action='store_false',
187
  help='Download non-flaring data')
 
 
188
  parser.set_defaults(flaring_data=True)
189
  args = parser.parse_args()
190
 
@@ -193,29 +338,61 @@ if __name__ == '__main__':
193
  end_date = args.end_date
194
  chunk_size = args.chunk_size
195
  flaring_data = args.flaring_data
196
- # Parse start and end dates
197
- start = datetime.strptime(start_date, "%Y-%m-%d")
198
- end = datetime.strptime(end_date, "%Y-%m-%d")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- # Process in chunks
201
- current_start = start
202
- while current_start < end:
203
- current_end = min(current_start + timedelta(days=chunk_size), end)
204
- print(f"Processing chunk: {current_start.strftime('%Y-%m-%d')} to {current_end.strftime('%Y-%m-%d')}")
205
 
206
- sxr_downloader = sxr.SXRDownloader(f"{download_dir}/GOES-flaring",
207
- f"{download_dir}/GOES-flaring/combined")
208
- flare_event = fed.FlareEventDownloader(
209
- current_start.strftime("%Y-%m-%d"),
210
- current_end.strftime("%Y-%m-%d"),
211
- event_type="FL",
212
- GOESCls="M1.0",
213
- directory=f"{download_dir}/SDO-AIA-flaring/FlareEvents"
214
- )
215
- sdo_downloader = sdo.SDODownloader(f"{download_dir}/SDO-AIA-flaring", "ggoodwin5@gsu.edu")
216
 
217
- processor = FlareDownloadProcessor(flare_event, sdo_downloader, sxr_downloader,
218
- flaring_data=flaring_data)
219
- processor.process_download()
220
 
221
- current_start = current_end
 
11
  import sxr_downloader as sxr
12
 
13
  class FlareDownloadProcessor:
14
+ def __init__(self, FlareEventDownloader, SDODownloader, SXRDownloader, flaring_data=True, time_span_mode=False):
15
  """
16
  Initialize the FlareDownloadProcessor.
17
  This class is responsible for processing AIA flare downloads.
18
+
19
+ Args:
20
+ FlareEventDownloader: Downloader for flare events
21
+ SDODownloader: Downloader for SDO data
22
+ SXRDownloader: Downloader for SXR data
23
+ flaring_data: Whether to download flaring data (legacy mode)
24
+ time_span_mode: Whether to use time span mode for 1-minute cadence downloads
25
  """
26
  self.FlareEventDownloader = FlareEventDownloader
27
  self.SDODownloader = SDODownloader
28
  self.SXRDownloader = SXRDownloader
29
  self.flaring_data = flaring_data
30
+ self.time_span_mode = time_span_mode
31
+
32
+ def check_existing_data(self, date):
33
+ """
34
+ Check if data already exists for the given date in all required wavelengths.
35
+
36
+ Args:
37
+ date (datetime): The date to check for existing data
38
+
39
+ Returns:
40
+ bool: True if data exists for all wavelengths, False otherwise
41
+ """
42
+ wavelengths = ['94', '131', '171', '193', '211', '304']
43
+ date_str = date.strftime('%Y-%m-%dT%H:%M:%S')
44
+
45
+ for wl in wavelengths:
46
+ file_path = os.path.join(self.SDODownloader.ds_path, wl, f"{date_str}.fits")
47
+ if not os.path.exists(file_path):
48
+ return False
49
+ return True
50
 
51
  def retry_download_with_backoff(self, download_func, *args, max_retries=5, base_delay=60, **kwargs):
52
  """
 
83
  # Re-raise non-HTTP errors immediately
84
  raise
85
 
86
+ def process_download(self, time_before_start=timedelta(minutes=120), time_after_end=timedelta(minutes=120),
87
+ start_time=None, end_time=None):
88
+ """
89
+ Process downloads either in flare mode or time span mode.
90
+
91
+ Args:
92
+ time_before_start: Time before flare start to download (legacy mode)
93
+ time_after_end: Time after flare end to download (legacy mode)
94
+ start_time: Start time for time span mode (datetime object)
95
+ end_time: End time for time span mode (datetime object)
96
+ """
97
+ # Create directories for SDO data
98
  [os.makedirs(os.path.join(self.SDODownloader.ds_path, str(c)), exist_ok=True) for c in
99
  [94, 131, 171, 193, 211, 304]]
100
 
 
106
  completed_dates = set(line.strip() for line in f)
107
  print(f"Resuming from {len(completed_dates)} previously completed downloads")
108
 
109
+ if self.time_span_mode:
110
+ # Time span mode - download 1-minute cadence data for specified time range
111
+ if start_time is None or end_time is None:
112
+ raise ValueError("start_time and end_time must be provided for time span mode")
113
+
114
+ print(f"Processing time span mode: {start_time} to {end_time}")
115
+ print("Downloading 1-minute cadence data...")
116
+
117
+ # Download SXR data for the entire time span
118
+ self.retry_download_with_backoff(self.SXRDownloader.download_and_save_goes_data,
119
+ start_time.strftime('%Y-%m-%d'),
120
+ end_time.strftime('%Y-%m-%d'), max_workers=os.cpu_count()-1)
121
+
122
+ # Generate 1-minute intervals for the time span
123
+ processed_dates = set()
124
+ current_time = start_time
125
+ total_minutes = int((end_time - start_time).total_seconds() / 60)
126
+
127
+ print(f"Total time span: {total_minutes} minutes")
128
+
129
+ # Process in batches to avoid overwhelming the server
130
+ batch_size = 100 # Process 100 minutes at a time
131
+ batch_count = 0
132
+
133
+ while current_time < end_time:
134
+ batch_count += 1
135
+ batch_end = min(current_time + timedelta(minutes=batch_size), end_time)
136
+ batch_dates = []
137
+
138
+ # Generate dates for this batch
139
+ temp_time = current_time
140
+ while temp_time < batch_end:
141
+ if temp_time.isoformat() not in completed_dates:
142
+ # Check if data already exists in the download directory
143
+ if not self.check_existing_data(temp_time):
144
+ batch_dates.append(temp_time)
145
+ else:
146
+ print(f" ⏭ Data already exists for {temp_time}, skipping download")
147
+ completed_dates.add(temp_time.isoformat())
148
+ # Update progress file
149
+ with open(progress_file, 'a') as f:
150
+ f.write(f"{temp_time.isoformat()}\n")
151
+ temp_time += timedelta(minutes=1)
152
+
153
+ if batch_dates:
154
+ print(f"Processing batch {batch_count}: {len(batch_dates)} minutes from {current_time} to {batch_end}")
155
+
156
+ for i, d in enumerate(batch_dates):
157
+ try:
158
+ print(f" Downloading data for {d} ({i+1}/{len(batch_dates)})")
159
+ self.retry_download_with_backoff(self.SDODownloader.downloadDate, d)
160
+ processed_dates.add(d.isoformat())
161
+ completed_dates.add(d.isoformat())
162
+
163
+ # Update progress file
164
+ with open(progress_file, 'a') as f:
165
+ f.write(f"{d.isoformat()}\n")
166
+
167
+ print(f" ✓ Successfully downloaded {d}")
168
+
169
+ # Small delay between downloads
170
+ if i < len(batch_dates) - 1:
171
+ time.sleep(0.01)
172
+
173
+ except Exception as e:
174
+ print(f" ✗ Failed to download data for {d}: {e}")
175
+ if "Connection refused" in str(e) or "timeout" in str(e).lower():
176
+ print(f" Waiting 5 seconds before continuing...")
177
+ time.sleep(5)
178
+ continue
179
+ else:
180
+ print(f" ⏭ Skipping batch {batch_count} (all dates already completed)")
181
+
182
+ # Delay between batches
183
+ if batch_end < end_time:
184
+ print("Waiting 5 seconds before next batch...")
185
+ time.sleep(5)
186
+
187
+ current_time = batch_end
188
+
189
+ print(f"Time span processing completed. Downloaded {len(processed_dates)} data points.")
190
+ return
191
+
192
+ # Legacy flare mode
193
+ fl_events = self.FlareEventDownloader.download_events()
194
+ print(fl_events)
195
+
196
  if self.flaring_data == True:
197
  print("Processing flare events...")
198
  if fl_events.empty:
 
210
  range((end_time - start_time) // timedelta(minutes=1))]:
211
  # Only download if we haven't processed this date yet
212
  if d.isoformat() not in processed_dates:
213
+ # Check if data already exists in the download directory
214
+ if not self.check_existing_data(d):
215
+ self.retry_download_with_backoff(self.SDODownloader.downloadDate, d)
216
+ processed_dates.add(d.isoformat())
217
+ else:
218
+ print(f" ⏭ Data already exists for {d}, skipping download")
219
+ processed_dates.add(d.isoformat())
220
  logging.info(f"Processed flare event {i + 1}/{len(fl_events)}: {event['event_starttime']} to {event['event_endtime']}")
221
  elif self.flaring_data == False:
222
  print("Processing non-flare events...")
 
263
  for j, d in enumerate(batch):
264
  # Only download if we haven't processed this date yet
265
  if d.isoformat() not in processed_dates and d.isoformat() not in completed_dates:
266
+ # Check if data already exists in the download directory
267
+ if not self.check_existing_data(d):
268
+ try:
269
+ print(f" Downloading data for {d} ({j+1}/{len(batch)})")
270
+ self.retry_download_with_backoff(self.SDODownloader.downloadDate, d)
271
+ processed_dates.add(d.isoformat())
272
+ completed_dates.add(d.isoformat())
273
+
274
+ # Update progress file
275
+ with open(progress_file, 'a') as f:
276
+ f.write(f"{d.isoformat()}\n")
277
+
278
+ print(f" ✓ Successfully downloaded {d}")
279
+
280
+ # Add small delay between individual downloads
281
+ if j < len(batch) - 1:
282
+ time.sleep(.2)
283
+
284
+ except Exception as e:
285
+ print(f" ✗ Failed to download data for {d}: {e}")
286
+ # If it's a connection error, wait longer before retrying
287
+ if "Connection refused" in str(e) or "timeout" in str(e).lower():
288
+ print(f" Waiting 5 seconds before continuing...")
289
+ time.sleep(5)
290
+ continue
291
+ else:
292
+ print(f" ⏭ Data already exists for {d}, skipping download")
293
  processed_dates.add(d.isoformat())
294
  completed_dates.add(d.isoformat())
295
 
296
  # Update progress file
297
  with open(progress_file, 'a') as f:
298
  f.write(f"{d.isoformat()}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  elif d.isoformat() in completed_dates:
300
  print(f" ⏭ Skipping {d} (already completed)")
301
  processed_dates.add(d.isoformat())
 
311
  if __name__ == '__main__':
312
  parser = argparse.ArgumentParser(description='Download flare events and associated SDO data.')
313
  parser.add_argument('--start_date', type=str, default='2023-6-15',
314
+ help='Start date for downloading data (YYYY-MM-DD)')
315
  parser.add_argument('--end_date', type=str, default='2023-07-15',
316
+ help='End date for downloading data (YYYY-MM-DD)')
317
+ parser.add_argument('--start_time', type=str, default=None,
318
+ help='Start time for time span mode (YYYY-MM-DD HH:MM:SS)')
319
+ parser.add_argument('--end_time', type=str, default=None,
320
+ help='End time for time span mode (YYYY-MM-DD HH:MM:SS)')
321
  parser.add_argument('--chunk_size', type=int, default=2000,
322
 help='Number of days per chunk for processing (default: 2000)')
323
  parser.add_argument('--download_dir', type=str, default='/mnt/data',
324
  help='Directory to save downloaded data (default: /mnt/data)')
325
+ parser.add_argument('--time_span_mode', action='store_true',
326
+ help='Use time span mode for 1-minute cadence downloads')
327
  parser.add_argument('--flaring_data', dest='flaring_data', action='store_true',
328
  help='Download flaring data (default)')
329
  parser.add_argument('--non_flaring_data', dest='flaring_data', action='store_false',
330
  help='Download non-flaring data')
331
+ parser.add_argument('--email', type=str, default='ggoodwin5@gsu.edu',
332
+ help='Email for SDO data download')
333
  parser.set_defaults(flaring_data=True)
334
  args = parser.parse_args()
335
 
 
338
  end_date = args.end_date
339
  chunk_size = args.chunk_size
340
  flaring_data = args.flaring_data
341
+ time_span_mode = args.time_span_mode
342
+ email = args.email
343
+ if time_span_mode:
344
+ # Time span mode - use precise start and end times
345
+ if args.start_time is None or args.end_time is None:
346
+ print("Error: --start_time and --end_time must be provided for time span mode")
347
+ print("Example: --start_time '2023-06-15 00:00:00' --end_time '2023-06-15 23:59:59'")
348
+ exit(1)
349
+
350
+ try:
351
+ start_time = datetime.strptime(args.start_time, "%Y-%m-%d %H:%M:%S")
352
+ end_time = datetime.strptime(args.end_time, "%Y-%m-%d %H:%M:%S")
353
+ except ValueError:
354
+ print("Error: Invalid time format. Use YYYY-MM-DD HH:MM:SS")
355
+ exit(1)
356
+
357
+ print(f"Time span mode: {start_time} to {end_time}")
358
+
359
+ # Initialize downloaders
360
+ sxr_downloader = sxr.SXRDownloader(f"{download_dir}/GOES-timespan",
361
+ f"{download_dir}/GOES-timespan/combined")
362
+ sdo_downloader = sdo.SDODownloader(f"{download_dir}/SDO-AIA-timespan", email)
363
+
364
+ # Create a dummy flare event downloader (not used in time span mode)
365
+ flare_event = None
366
+
367
+ processor = FlareDownloadProcessor(flare_event, sdo_downloader, sxr_downloader,
368
+ flaring_data=flaring_data, time_span_mode=True)
369
+ processor.process_download(start_time=start_time, end_time=end_time)
370
+
371
+ else:
372
+ # Legacy flare mode
373
+ # Parse start and end dates
374
+ start = datetime.strptime(start_date, "%Y-%m-%d")
375
+ end = datetime.strptime(end_date, "%Y-%m-%d")
376
 
377
+ # Process in chunks
378
+ current_start = start
379
+ while current_start < end:
380
+ current_end = min(current_start + timedelta(days=chunk_size), end)
381
+ print(f"Processing chunk: {current_start.strftime('%Y-%m-%d')} to {current_end.strftime('%Y-%m-%d')}")
382
 
383
+ sxr_downloader = sxr.SXRDownloader(f"{download_dir}/GOES-flaring",
384
+ f"{download_dir}/GOES-flaring/combined")
385
+ flare_event = fed.FlareEventDownloader(
386
+ current_start.strftime("%Y-%m-%d"),
387
+ current_end.strftime("%Y-%m-%d"),
388
+ event_type="FL",
389
+ GOESCls="M1.0",
390
+ directory=f"{download_dir}/SDO-AIA-flaring/FlareEvents"
391
+ )
392
+ sdo_downloader = sdo.SDODownloader(f"{download_dir}/SDO-AIA-flaring", email)
393
 
394
+ processor = FlareDownloadProcessor(flare_event, sdo_downloader, sxr_downloader,
395
+ flaring_data=flaring_data, time_span_mode=False)
396
+ processor.process_download()
397
 
398
+ current_start = current_end
forecasting/inference/checkpoint_list.yaml CHANGED
@@ -2,22 +2,19 @@
2
  # This file contains a list of model checkpoints to evaluate
3
 
4
  checkpoints:
5
- <<<<<<< HEAD
6
- - name: "rs-epoch66"
7
- checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-epoch=66-val_total_loss=0.0342.ckpt"
8
-
9
- # - name: "rs-base-epoch43"
10
- # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-changed-base-weights-epoch=43-val_total_loss=0.0470.ckpt"
11
- =======
12
- - name: "2d-embed"
13
- checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-epoch=70-val_total_loss=0.0360.ckpt"
14
-
15
- # - name: "rs-base-epoch39"
16
- # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-changed-base-weights-epoch=39-val_total_loss=0.0467.ckpt"
17
- >>>>>>> 3d436f6d4c6b15827a7e8e923a105d7ba89b2c7c
18
-
19
- # - name: "rs-epoch50"
20
- # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-epoch=50-val_total_loss=0.0382.ckpt"
21
 
22
  # Add more checkpoints as needed
23
  # Each checkpoint should have:
 
2
  # This file contains a list of model checkpoints to evaluate
3
 
4
  checkpoints:
5
+ # - name: "claude-final"
6
+ # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-claude-suggested-weights-final-20250921_225446.pth"
7
+ # - name: "rs-final"
8
+ # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-final-20250921_185953.pth"
9
+ # - name: "baseweights-final"
10
+ # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-changed-base-weights-final-20250921_223323.pth"
11
+ - name: "claude-mse"
12
+ checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-mse-claude-epoch=62-val_total_loss=0.1904.ckpt"
13
+ - name: "baseweights-mse"
14
+ checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-mse-base-weights-epoch=62-val_total_loss=0.2893.ckpt"
15
+ # - name: "stereo-final"
16
+ # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-STEREO-final-20250921_183739.pth"
17
+
 
 
 
18
 
19
  # Add more checkpoints as needed
20
  # Each checkpoint should have:
forecasting/inference/evaluation.py CHANGED
@@ -344,7 +344,7 @@ class SolarFlareEvaluator:
344
  'B': (1e-7, 1e-6, "#FFAAA5"),
345
  'C': (1e-6, 1e-5, "#FFAAA5"),
346
  'M': (1e-5, 1e-4, "#FFAAA5"),
347
- 'X': (1e-4, 1e-3, "#FFAAA5")
348
  }
349
 
350
  for class_name, (min_flux, max_flux, color) in flare_classes_mae.items():
@@ -646,10 +646,10 @@ class SolarFlareEvaluator:
646
  # Create figure with transparent background
647
  fig = plt.figure(figsize=(10, 5))
648
  fig.patch.set_alpha(0.0) # Transparent background
649
- gs_left = fig.add_gridspec(1, 1, left=0.0, right=0.35, width_ratios=[1], hspace=0, wspace=0.0)
650
 
651
  # Right gridspec for SXR plot (column 3) with more padding
652
- gs_right = fig.add_gridspec(2, 1, left=0.45, right=1, hspace=0)
653
 
654
  wavs = ['94', '131', '171', '193', '211', '304']
655
  att_max = np.percentile(attention_data, 100)
@@ -675,9 +675,9 @@ class SolarFlareEvaluator:
675
  # Plot SXR data with uncertainty bands
676
  sxr_ax = fig.add_subplot(gs_right[:, 0])
677
 
678
- # Set SXR plot background to match regression plot
679
- sxr_ax.set_facecolor('#FFEEE6') # Light background for SXR plot
680
- sxr_ax.patch.set_alpha(1.0) # Make sure axes patch is opaque
681
 
682
  if sxr_window is not None and not sxr_window.empty:
683
  # Plot ground truth (no uncertainty)
@@ -787,8 +787,10 @@ class SolarFlareEvaluator:
787
  text.set_fontfamily('Barlow')
788
 
789
  sxr_ax.grid(True, alpha=0.3, color='black')
790
- sxr_ax.tick_params(axis='x', rotation=15, labelsize=12, colors='white')
791
- sxr_ax.tick_params(axis='y', labelsize=12, colors='white')
 
 
792
 
793
  # Set tick labels to Barlow font and white color
794
  for label in sxr_ax.get_xticklabels():
@@ -797,8 +799,10 @@ class SolarFlareEvaluator:
797
  for label in sxr_ax.get_yticklabels():
798
  label.set_fontfamily('Barlow')
799
  label.set_color('white')
800
- # for spine in sxr_ax.spines.values():
801
- # spine.set_color('white')
 
 
802
  try:
803
  sxr_ax.set_yscale('log')
804
  except:
@@ -814,7 +818,7 @@ class SolarFlareEvaluator:
814
 
815
  #plt.suptitle(f'Timestamp: {timestamp}', fontsize=14)
816
  #plt.tight_layout()
817
- plt.savefig(save_path, dpi=500, facecolor='none', transparent=True)
818
  plt.close()
819
 
820
  print(f"Worker {os.getpid()}: Completed {timestamp}")
 
344
  'B': (1e-7, 1e-6, "#FFAAA5"),
345
  'C': (1e-6, 1e-5, "#FFAAA5"),
346
  'M': (1e-5, 1e-4, "#FFAAA5"),
347
+ 'X': (1e-4, 1e-2, "#FFAAA5")
348
  }
349
 
350
  for class_name, (min_flux, max_flux, color) in flare_classes_mae.items():
 
646
  # Create figure with transparent background
647
  fig = plt.figure(figsize=(10, 5))
648
  fig.patch.set_alpha(0.0) # Transparent background
649
+ gs_left = fig.add_gridspec(1, 1, left=0.0, right=0.35, width_ratios=[1], hspace=0, wspace=0.1)
650
 
651
  # Right gridspec for SXR plot (column 3) with more padding
652
+ gs_right = fig.add_gridspec(2, 1, left=0.45, right=1, hspace=0.1)
653
 
654
  wavs = ['94', '131', '171', '193', '211', '304']
655
  att_max = np.percentile(attention_data, 100)
 
675
  # Plot SXR data with uncertainty bands
676
  sxr_ax = fig.add_subplot(gs_right[:, 0])
677
 
678
+ # Set SXR plot background to have light background inside plot area
679
+ sxr_ax.set_facecolor('#FFEEE6') # Light background for SXR plot area
680
+ sxr_ax.patch.set_alpha(1.0) # Make axes patch opaque
681
 
682
  if sxr_window is not None and not sxr_window.empty:
683
  # Plot ground truth (no uncertainty)
 
787
  text.set_fontfamily('Barlow')
788
 
789
  sxr_ax.grid(True, alpha=0.3, color='black')
790
+ sxr_ax.tick_params(axis='x', rotation=15, labelsize=12, colors='white',
791
+ )
792
+ sxr_ax.tick_params(axis='y', labelsize=12, colors='white',
793
+ )
794
 
795
  # Set tick labels to Barlow font and white color
796
  for label in sxr_ax.get_xticklabels():
 
799
  for label in sxr_ax.get_yticklabels():
800
  label.set_fontfamily('Barlow')
801
  label.set_color('white')
802
+
803
+ # Set graph border (spines) to white
804
+ for spine in sxr_ax.spines.values():
805
+ spine.set_color('white')
806
  try:
807
  sxr_ax.set_yscale('log')
808
  except:
 
818
 
819
  #plt.suptitle(f'Timestamp: {timestamp}', fontsize=14)
820
  #plt.tight_layout()
821
+ plt.savefig(save_path, dpi=500, facecolor='none',bbox_inches='tight')
822
  plt.close()
823
 
824
  print(f"Worker {os.getpid()}: Completed {timestamp}")
forecasting/inference/evaluation_config.yaml CHANGED
@@ -21,9 +21,14 @@ evaluation:
21
  # Examples: 1e-6 (C-class and above), 1e-5 (M-class and above), 1e-4 (X-class only)
22
 
23
  # Time range for evaluation
 
 
 
 
 
24
  time_range:
25
- start_time: "2023-08-05T20:30:00"
26
- end_time: "2023-08-05T23:56:00"
27
  interval_minutes: 1
28
 
29
  # Plotting parameters
 
21
  # Examples: 1e-6 (C-class and above), 1e-5 (M-class and above), 1e-4 (X-class only)
22
 
23
  # Time range for evaluation
24
+ # time_range:
25
+ # start_time: "2023-08-05T20:30:00"
26
+ # end_time: "2023-08-05T23:56:00"
27
+ # interval_minutes: 1
28
+
29
  time_range:
30
+ start_time: "2014-08-01T00:00:00"
31
+ end_time: "2014-08-31T23:59:00"
32
  interval_minutes: 1
33
 
34
  # Plotting parameters
forecasting/inference/inference_stereo.yaml ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base inference configuration template
2
+ # This will be used as a template for each checkpoint evaluation
3
+
4
+ # Base directories
5
+ base_data_dir: "/mnt/data/COMBINED"
6
+ output_path: "PLACEHOLDER_OUTPUT_PATH" # Will be replaced by batch script
7
+ weight_path: "PLACEHOLDER_WEIGHT_PATH" # Will be replaced by batch script
8
+
9
+ # Dataset configuration
10
+ SolO: "false"
11
+ Stereo: "false"
12
+
13
+ # Model configuration
14
+ model: "vit" # Options: "vit", "hybrid", "vitpatch", "fusion"
15
+ wavelengths: [171, 193, 211, 304] # AIA wavelengths in Angstroms
16
+
17
+ # MC Dropout configuration
18
+ mc:
19
+ active: "false"
20
+ runs: 5
21
+
22
+ # SolO data configuration (if using SolO dataset)
23
+ SolO_data:
24
+ solo_img_dir: "/mnt/data/ML-Ready_clean/SolO/SolO/ML-Ready-SolO"
25
+ sxr_dir: "${base_data_dir}/SXR"
26
+ sxr_norm_path: "${base_data_dir}/SolO/SXR/normalized_sxr.npy"
27
+
28
+ # Stereo data configuration (if using Stereo dataset)
29
+ Stereo_data:
30
+ stereo_img_dir: "/mnt/data/ML-Ready-mixed/STEREO_processed"
31
+ sxr_dir: "/mnt/data/ML-Ready-mixed/ML-Ready-mixed/SXR"
32
+ sxr_norm_path: "/mnt/data/ML-READY/SXR/normalized_sxr.npy"
33
+
34
+ # Model parameters
35
+ model_params:
36
+ input_size: 512
37
+ patch_size: 16
38
+ batch_size: 16
39
+ no_weights: false # Set to true to skip saving attention weights
40
+
41
+ # Model architecture parameters (should match training config)
42
+ vit_custom:
43
+ embed_dim: 512
44
+ num_channels: 4
45
+ num_classes: 1
46
+ patch_size: 16
47
+ num_patches: 1024
48
+ hidden_dim: 512
49
+ num_heads: 8
50
+ num_layers: 6
51
+ dropout: 0.1
52
+
53
+ # Data paths
54
+ data:
55
+ aia_dir: "${base_data_dir}/AIA-SPLIT/"
56
+ sxr_dir: "${base_data_dir}/SXR-SPLIT/"
57
+ sxr_norm_path: "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
58
+ checkpoint_path: "PLACEHOLDER_CHECKPOINT_PATH" # Will be replaced by batch script
59
+
60
+ # MEGSAI parameters (should match training config)
61
+ megsai:
62
+ cnn_model: "updated"
63
+ cnn_dp: 0.2
64
+ weight_decay: 1e-5
65
+ cosine_restart_T0: 50
66
+ cosine_restart_Tmult: 2
67
+ cosine_eta_min: 1e-7
68
+
69
+ # Fusion parameters (if using fusion model)
70
+ fusion:
71
+ scalar_branch: "hybrid"
72
+ lr: 0.0001
73
+ lambda_vit_to_target: 0.3
74
+ lambda_scalar_to_target: 0.1
75
+ learnable_gate: true
76
+ gate_init_bias: 5.0
77
+ scalar_kwargs:
78
+ d_input: 6
79
+ d_output: 1
80
+ cnn_model: "updated"
81
+ cnn_dp: 0.75
forecasting/inference/inference_template.yaml CHANGED
@@ -44,10 +44,10 @@ vit_custom:
44
  num_channels: 6
45
  num_classes: 1
46
  patch_size: 16
47
- num_patches: 4096
48
  hidden_dim: 512
49
- num_heads: 16
50
- num_layers: 3
51
  dropout: 0.1
52
 
53
  # Data paths
 
44
  num_channels: 6
45
  num_classes: 1
46
  patch_size: 16
47
+ num_patches: 1024
48
  hidden_dim: 512
49
+ num_heads: 8
50
+ num_layers: 6
51
  dropout: 0.1
52
 
53
  # Data paths
forecasting/models/vit_patch_model.py CHANGED
@@ -84,6 +84,7 @@ class ViT(pl.LightningModule):
84
 
85
  #Also calculate huber loss for logging
86
  huber_loss = F.huber_loss(norm_preds_squeezed, sxr, delta=.3)
 
87
 
88
 
89
  # Log adaptation info
@@ -349,6 +350,7 @@ class SXRRegressionDynamicLoss:
349
 
350
  def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
351
  base_loss = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
 
352
  weights = self._get_adaptive_weights(sxr_un)
353
  self._update_tracking(sxr_un, sxr_norm, preds_norm)
354
  weighted_loss = base_loss * weights
@@ -462,6 +464,7 @@ class SXRRegressionDynamicLoss:
462
 
463
  #Huber loss
464
  error = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
 
465
  error = error.detach().cpu().numpy()
466
 
467
 
 
84
 
85
  #Also calculate huber loss for logging
86
  huber_loss = F.huber_loss(norm_preds_squeezed, sxr, delta=.3)
87
+ #huber_loss = F.mse_loss(norm_preds_squeezed, sxr)
88
 
89
 
90
  # Log adaptation info
 
350
 
351
  def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
352
  base_loss = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
353
+ #base_loss = F.mse_loss(preds_norm, sxr_norm, reduction='none')
354
  weights = self._get_adaptive_weights(sxr_un)
355
  self._update_tracking(sxr_un, sxr_norm, preds_norm)
356
  weighted_loss = base_loss * weights
 
464
 
465
  #Huber loss
466
  error = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
467
+ #error = F.mse_loss(preds_norm, sxr_norm, reduction='none')
468
  error = error.detach().cpu().numpy()
469
 
470
 
forecasting/models/vit_patch_model_uncertainty.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.optim as optim
9
+ import torch.utils.data as data
10
+ import torchvision
11
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
12
+ from torchvision import transforms
13
+ import pytorch_lightning as pl
14
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
15
+
16
+ #norm = np.load("/mnt/data/ML-Ready_clean/mixed_data/SXR/normalized_sxr.npy")
17
+
18
def normalize_sxr(unnormalized_values, sxr_norm):
    """Map raw SXR flux values into normalized log10 space.

    Args:
        unnormalized_values: Tensor of raw flux values (>= 0).
        sxr_norm: Two-element tensor/sequence holding (mean, std) of the
            log10-transformed flux used for normalization.

    Returns:
        Tensor of normalized values: (log10(x + 1e-8) - mean) / std.
    """
    mean = float(sxr_norm[0].item())
    std = float(sxr_norm[1].item())
    # Small epsilon keeps log10 finite when the flux is exactly zero.
    return (torch.log10(unnormalized_values + 1e-8) - mean) / std
23
+
24
def unnormalize_sxr(normalized_values, sxr_norm):
    """Invert ``normalize_sxr``: map normalized values back to raw flux.

    Args:
        normalized_values: Tensor in normalized log10 space.
        sxr_norm: Two-element tensor/sequence of (mean, std) in log10 space.

    Returns:
        Tensor of raw flux values: 10 ** (x * std + mean) - 1e-8.
    """
    mean = float(sxr_norm[0].item())
    std = float(sxr_norm[1].item())
    return 10 ** (normalized_values * std + mean) - 1e-8
26
+
27
class ViTUncertainty(pl.LightningModule):
    """LightningModule wrapping a patch-based VisionTransformer that regresses
    GOES SXR flux from AIA images and additionally predicts an uncertainty
    (error) estimate per sample, trained with an adaptive class-weighted loss.
    """

    def __init__(self, model_kwargs, sxr_norm, base_weights=None):
        """
        Args:
            model_kwargs: Dict of VisionTransformer kwargs; must also contain
                'lr' (stripped before model construction) and may contain
                'num_classes' (ignored by the regressor).
            sxr_norm: (mean, std) of log10 SXR flux used for (un)normalization.
            base_weights: Optional dict of per-class base loss weights; if
                None, SXRRegressionDynamicLoss falls back to its defaults.
        """
        super().__init__()
        self.model_kwargs = model_kwargs
        self.lr = model_kwargs['lr']
        self.save_hyperparameters()
        # 'lr' and 'num_classes' are training-level options, not model kwargs.
        filtered_kwargs = dict(model_kwargs)
        filtered_kwargs.pop('lr', None)
        filtered_kwargs.pop('num_classes', None)
        self.model = VisionTransformer(**filtered_kwargs)
        # Set the base weights based on the number of samples in each class
        # within the training data (computed by the caller, or None).
        self.base_weights = base_weights
        self.adaptive_loss = SXRRegressionDynamicLoss(window_size=15000, base_weights=self.base_weights)
        self.sxr_norm = sxr_norm

    def forward(self, x, return_attention=True):
        """Run the underlying ViT; also returns attention maps when requested."""
        return self.model(x, self.sxr_norm, return_attention=return_attention)

    def forward_for_callback(self, x, return_attention=True):
        """Forward method compatible with AttentionMapCallback (2-tuple return)."""
        global_flux_raw, attention_weights, patch_flux_raw, patch_error = self.forward(x, return_attention=return_attention)
        return global_flux_raw, attention_weights

    def configure_optimizers(self):
        """AdamW + cosine warm restarts; LR is logged under 'learning_rate'."""
        # Use AdamW with weight decay for better regularization
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=0.00001,
        )

        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=50,  # First restart after 50 epochs
            T_mult=2,  # Double the cycle length after each restart
            eta_min=1e-7  # Minimum learning rate
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1,
                'name': 'learning_rate'
            }
        }

    def _calculate_loss(self, batch, mode="train"):
        """Shared train/val/test step: compute adaptive loss and log metrics.

        Args:
            batch: (imgs, sxr) where sxr is the normalized target flux.
            mode: 'train', 'val' or 'test'; controls what gets logged
                (note: 'test' currently logs nothing).

        Returns:
            Scalar total loss (weighted regression loss + error-head loss).
        """
        imgs, sxr = batch
        # Model returns raw (unnormalized) flux predictions, per-patch
        # contributions, and a raw-space error estimate.
        raw_preds, raw_patch_contributions, raw_error = self.model(imgs,self.sxr_norm)
        raw_preds_squeezed = torch.squeeze(raw_preds)
        sxr_un = unnormalize_sxr(sxr, self.sxr_norm)

        norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
        raw_error_squeezed = torch.squeeze(raw_error)
        # Use adaptive rare event loss
        loss, error_loss, weights = self.adaptive_loss.calculate_loss(
            norm_preds_squeezed, sxr, sxr_un, raw_error_squeezed
        )

        # Also compute the plain (unweighted) Huber loss purely for logging.
        huber_loss = F.huber_loss(norm_preds_squeezed, sxr, delta=.3)
        #huber_loss = F.mse_loss(norm_preds_squeezed, sxr)

        # Log adaptation info
        if mode == "train":
            # Always log learning rate (every step)
            current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
            self.log('learning_rate', current_lr, on_step=True, on_epoch=False,
                     prog_bar=True, logger=True, sync_dist=True)

            #self.log("sparsity_entropy_loss", sparsity_or_entropy, on_step=True, on_epoch=True, )
            self.log("train_total_loss", loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            self.log("train_huber_loss", huber_loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            self.log("train_error_loss", error_loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            # Detailed diagnostics only every 200 steps
            if self.global_step % 200 == 0:
                multipliers = self.adaptive_loss.get_current_multipliers()
                for key, value in multipliers.items():
                    self.log(f"adaptive/{key}", value, on_step=True, on_epoch=False)

                self.log("adaptive/avg_weight", weights.mean(), on_step=True, on_epoch=False)
                self.log("adaptive/max_weight", weights.max(), on_step=True, on_epoch=False)

        if mode == "val":
            # Validation: typically only log epoch aggregates
            multipliers = self.adaptive_loss.get_current_multipliers()
            for key, value in multipliers.items():
                self.log(f"val/adaptive/{key}", value, on_step=False, on_epoch=True)
            self.log("val_total_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            self.log("val_huber_loss", huber_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            self.log("val_error_loss", error_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

    def apply_wavelength_dropout(self, x, dropout_prob=0.3):
        """Randomly zero out some wavelengths (channels) during training.

        With probability ``dropout_prob`` keeps a random non-empty subset of
        channels and zeroes the rest; a no-op in eval mode.
        """
        if self.training and torch.rand(1).item() < dropout_prob:
            # x shape: [B, H, W, num_channels]
            num_keep = torch.randint(1, self.model_kwargs['num_channels'], (1,)).item()
            keep_indices = torch.randperm(self.model_kwargs['num_channels'])[:num_keep]

            mask = torch.zeros(self.model_kwargs['num_channels'], device=x.device)
            mask[keep_indices] = 1.0

            x = x * mask.view(1, 1, 1, -1)
        return x
153
+
154
+
155
class VisionTransformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        patch_size,
        num_patches,
        dropout

    ):
        """Vision Transformer that outputs flux contributions per patch.

        Args:
            embed_dim: Dimensionality of the input feature vectors to the Transformer
            hidden_dim: Dimensionality of the hidden layer in the feed-forward networks
                within the Transformer
            num_channels: Number of channels of the input (e.g. 6 AIA wavelengths)
            num_heads: Number of heads to use in the Multi-Head Attention block
            num_layers: Number of layers to use in the Transformer
            patch_size: Number of pixels that the patches have per dimension
            num_patches: Maximum number of patches an image can have; assumed to
                form a square grid (sqrt(num_patches) patches per side)
            dropout: Amount of dropout to apply in the feed-forward network and
                on the input encoding

        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size ** 2), embed_dim)

        self.transformer_blocks = nn.ModuleList([
            AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        # Per-patch flux head and per-patch error (uncertainty) head.
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        self.error_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        # NOTE(review): cls_token and the 1D pos_embedding are currently unused
        # (forward uses the learned 2D positional encoding instead); kept as
        # parameters, presumably for checkpoint compatibility -- confirm.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))
        self.grid_h = int(math.sqrt(num_patches))
        self.grid_w = int(math.sqrt(num_patches))
        self.pos_embedding_2d = nn.Parameter(torch.randn(1, self.grid_h, self.grid_w, embed_dim))


    def forward(self, x, sxr_norm, return_attention=False):
        """Predict raw global SXR flux as a sum of per-patch contributions.

        Args:
            x: Input images of shape [B, H, W, C] (channels last).
            sxr_norm: (mean, std) of log10 flux, used to decode head outputs
                back into raw flux space.
            return_attention: If True, also return per-layer attention maps.

        Returns:
            (global_flux_raw, [attention_weights,] patch_flux_raw, total_error)
        """
        # Preprocess input
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # CLS-token + 1D positional-encoding path is disabled in favor of the
        # learned 2D positional encoding.
        #cls_token = self.cls_token.repeat(B, 1, 1)
        #x = torch.cat([cls_token, x], dim=1)
        #x = x + self.pos_embedding[:, : T + 1]
        x = self._add_2d_positional_encoding(x)

        # Apply Transformer blocks
        x = self.dropout(x)
        x = x.transpose(0, 1)  # [T, B, embed_dim]

        attention_weights = []
        for block in self.transformer_blocks:
            if return_attention:
                x, attn_weights = block(x, return_attention=True)
                attention_weights.append(attn_weights)
            else:
                x = block(x)
        # Extract per-patch logits and per-patch error estimates.
        patch_embeddings = x.transpose(0, 1)  # [B, num_patches, embed_dim]
        patch_logits = self.mlp_head(patch_embeddings).squeeze(-1)  # normalized log predictions [B, num_patches]
        patch_error = self.error_head(patch_embeddings).squeeze(-1)  # [B, num_patches]


        # --- Convert to raw SXR ---
        mean, std = sxr_norm  # in log10 space
        patch_flux_raw = torch.clamp(10 ** (patch_logits * std + mean)- 1e-8, min=1e-15, max=1)
        patch_error_raw = torch.clamp(10 ** (patch_error * std + mean)- 1e-8, min=1e-30, max=1)

        # Sum over patches for raw global flux
        global_flux_raw = patch_flux_raw.sum(dim=1, keepdim=True)
        # Combine patch errors in quadrature (sqrt of sum of squares).
        total_error = torch.sqrt(patch_error_raw.pow(2).sum(dim=1, keepdim=True))

        # Ensure global flux is never zero (add small epsilon if needed)
        global_flux_raw = torch.clamp(global_flux_raw, min=1e-15)

        if return_attention:
            return global_flux_raw, attention_weights, patch_flux_raw, total_error
        else:
            return global_flux_raw, patch_flux_raw, total_error

    def _add_2d_positional_encoding(self, x):
        """Add learned 2D positional encoding to patch embeddings.

        Assumes T == grid_h * grid_w (no CLS token is present in x).
        """
        B, T, embed_dim = x.shape
        num_patches = T

        # Reshape patches to 2D grid: [B, grid_h, grid_w, embed_dim]
        patch_embeddings = x.reshape(B, self.grid_h, self.grid_w, embed_dim)

        # Add learned 2D positional encoding
        # Broadcasting: [B, grid_h, grid_w, embed_dim] + [1, grid_h, grid_w, embed_dim]
        patch_embeddings = patch_embeddings + self.pos_embedding_2d

        # Reshape back to sequence format: [B, num_patches, embed_dim]
        patch_embeddings = patch_embeddings.reshape(B, num_patches, embed_dim)

        return patch_embeddings
270
+
271
+
272
+
273
+
274
class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """Pre-norm Transformer encoder block (self-attention + MLP).

        Args:
            embed_dim: Dimensionality of input and attention feature vectors
            hidden_dim: Dimensionality of hidden layer in feed-forward network
                (usually 2-4x larger than embed_dim)
            num_heads: Number of heads to use in the Multi-Head Attention block
            dropout: Amount of dropout to apply in the feed-forward network

        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, return_attention=False):
        """Apply self-attention and MLP sub-layers, each with a residual.

        Input/output shape is [T, B, embed_dim] (batch_first=False). When
        return_attention is True, per-head attention maps are also returned.
        """
        # Pre-norm: normalize the input before self-attention.
        normed = self.layer_norm_1(x)

        if not return_attention:
            out = x + self.attn(normed, normed, normed)[0]
            return out + self.linear(self.layer_norm_2(out))

        # average_attn_weights=False keeps one attention map per head.
        attn_out, attn_maps = self.attn(normed, normed, normed, average_attn_weights=False)
        out = x + attn_out
        out = out + self.linear(self.layer_norm_2(out))
        return out, attn_maps
312
+
313
+
314
def img_to_patch(x, patch_size, flatten_channels=True):
    """Split a channels-last image batch into a sequence of patches.

    Args:
        x: Tensor representing the image of shape [B, H, W, C]
        patch_size: Number of pixels per dimension of the patches (integer)
        flatten_channels: If True, the patches will be returned in a flattened
            format as a feature vector instead of an image grid.

    Returns:
        Tensor of shape [B, H'*W', C*p*p] when flatten_channels is True,
        otherwise [B, H'*W', C, p, p], where H' = H // p and W' = W // p.
    """
    # Move to channels-first layout for the grid reshape.
    x = x.permute(0, 3, 1, 2)
    batch, channels, height, width = x.shape
    rows = height // patch_size
    cols = width // patch_size
    x = x.reshape(batch, channels, rows, patch_size, cols, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)          # [B, H'*W', C*p_H*p_W]
    return x
330
+
331
class SXRRegressionDynamicLoss:
    """Adaptive, class-weighted Huber loss for SXR flux regression.

    Samples are bucketed by raw (unnormalized) flux into quiet / C / M / X
    flare classes using standard GOES thresholds.  Rolling per-class error
    histories drive a performance multiplier per class: when a class's recent
    error trends above its long-run average, its weight grows, keeping
    training pressure on rare flare classes.
    """

    def __init__(self, window_size=15000, base_weights=None):
        # GOES flux thresholds (W/m^2) separating quiet / C / M / X classes.
        self.c_threshold = 1e-6
        self.m_threshold = 1e-5
        self.x_threshold = 1e-4

        # Bounded rolling error histories, one per flare class.
        self.window_size = window_size
        self.quiet_errors = deque(maxlen=window_size)
        self.c_errors = deque(maxlen=window_size)
        self.m_errors = deque(maxlen=window_size)
        self.x_errors = deque(maxlen=window_size)

        # Calculate the base weights based on the number of samples in each
        # class within training data (or use caller-supplied weights).
        if base_weights is None:
            self.base_weights = self._get_base_weights()
        else:
            self.base_weights = base_weights

    def _get_base_weights(self):
        # Default hand-tuned per-class base weights (used when the caller
        # does not supply sample-count-derived weights).
        return {
            'quiet': 1.5,  # Increase from current value
            'c_class': 1.0,  # Keep as baseline
            'm_class': 8.0,  # Maintain M-class focus
            'x_class': 20.0  # Maintain X-class focus
        }

    def calculate_loss(self, preds_norm, sxr_norm, sxr_un, raw_error):
        """Compute the weighted regression loss plus the error-head loss.

        Args:
            preds_norm: Normalized flux predictions.
            sxr_norm: Normalized flux targets.
            sxr_un: Unnormalized (raw) flux targets, used for class bucketing.
            raw_error: Model's error-head output, trained toward |target - pred|.

        Returns:
            (total_loss, weighted_error_loss, per_sample_weights)
        """
        base_loss = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
        # Train the error head toward the absolute residual.
        error = abs(sxr_norm - preds_norm)
        # NOTE(review): `raw_error` is produced by the model in raw flux units
        # while `error` is in normalized space -- confirm the error head is
        # meant to regress the normalized residual.
        error_loss = F.huber_loss(raw_error, error, reduction='none')

        #base_loss = F.mse_loss(preds_norm, sxr_norm, reduction='none')
        weights = self._get_adaptive_weights(sxr_un)
        self._update_tracking(sxr_un, sxr_norm, preds_norm)
        weighted_loss = base_loss * weights
        error_weight = .2  # relative contribution of the uncertainty loss
        error_loss = error_weight * error_loss.mean()
        loss = weighted_loss.mean() + error_loss
        return loss, error_loss, weights

    def _get_adaptive_weights(self, sxr_un):
        """Return per-sample weights: base class weight x performance multiplier,
        normalized so the batch-mean weight is ~1.
        """
        device = sxr_un.device

        # Get continuous multipliers per class with custom params
        quiet_mult = self._get_performance_multiplier(
            self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.05, sxrclass='quiet'  # Was 0.2
        )
        c_mult = self._get_performance_multiplier(
            self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.08, sxrclass='c_class'  # Was 0.3
        )
        m_mult = self._get_performance_multiplier(
            self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.1, sxrclass='m_class'  # Was 0.4
        )
        x_mult = self._get_performance_multiplier(
            self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=0.12, sxrclass='x_class'  # Was 0.5
        )

        quiet_weight = self.base_weights['quiet'] * quiet_mult
        c_weight = self.base_weights['c_class'] * c_mult
        m_weight = self.base_weights['m_class'] * m_mult
        x_weight = self.base_weights['x_class'] * x_mult

        # Assign each sample the weight of its flux class.
        weights = torch.ones_like(sxr_un, device=device)
        weights = torch.where(sxr_un < self.c_threshold, quiet_weight, weights)
        weights = torch.where((sxr_un >= self.c_threshold) & (sxr_un < self.m_threshold), c_weight, weights)
        weights = torch.where((sxr_un >= self.m_threshold) & (sxr_un < self.x_threshold), m_weight, weights)
        weights = torch.where(sxr_un >= self.x_threshold, x_weight, weights)

        # Normalize so mean weight ~1.0 (optional, helps stability)
        mean_weight = torch.mean(weights)
        weights = weights / (mean_weight)

        # Clamp extreme weights
        #weights = torch.clamp(weights, min=0.01, max=40.0)

        # Save for logging
        self.current_multipliers = {
            'quiet_mult': quiet_mult,
            'c_mult': c_mult,
            'm_mult': m_mult,
            'x_mult': x_mult,
            'quiet_weight': quiet_weight,
            'c_weight': c_weight,
            'm_weight': m_weight,
            'x_weight': x_weight
        }

        return weights

    def _get_performance_multiplier(self, error_history, max_multiplier=10.0, min_multiplier=0.5, sensitivity=3.0, sxrclass='quiet'):
        """Class-dependent performance multiplier.

        Compares the mean of the most recent errors against the mean over the
        whole history: a worsening trend (ratio > 1) grows the multiplier
        exponentially with `sensitivity`; an improving trend shrinks it.
        Returns 1.0 until enough samples are collected for the class.
        """

        class_params = {
            'quiet': {'min_samples': 2500, 'recent_window': 800},
            'c_class': {'min_samples': 2500, 'recent_window': 800},
            'm_class': {'min_samples': 1500, 'recent_window': 500},
            'x_class': {'min_samples': 1000, 'recent_window': 300}
        }

        if len(error_history) < class_params[sxrclass]['min_samples']:
            return 1.0

        recent_window = class_params[sxrclass]['recent_window']
        recent = np.mean(list(error_history)[-recent_window:])
        overall = np.mean(list(error_history))

        ratio = recent / overall
        multiplier = np.exp(sensitivity * (ratio - 1))
        return np.clip(multiplier, min_multiplier, max_multiplier)

    def _update_tracking(self, sxr_un, sxr_norm, preds_norm):
        """Append the current batch's mean per-class Huber error to each
        class's rolling history (classes absent from the batch are skipped).
        """
        sxr_un_np = sxr_un.detach().cpu().numpy()

        # Per-sample Huber error in normalized space.
        error = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
        #error = F.mse_loss(preds_norm, sxr_norm, reduction='none')
        error = error.detach().cpu().numpy()


        quiet_mask = sxr_un_np < self.c_threshold
        if quiet_mask.sum() > 0:
            self.quiet_errors.append(float(np.mean(error[quiet_mask])))

        c_mask = (sxr_un_np >= self.c_threshold) & (sxr_un_np < self.m_threshold)
        if c_mask.sum() > 0:
            self.c_errors.append(float(np.mean(error[c_mask])))

        m_mask = (sxr_un_np >= self.m_threshold) & (sxr_un_np < self.x_threshold)
        if m_mask.sum() > 0:
            self.m_errors.append(float(np.mean(error[m_mask])))

        x_mask = sxr_un_np >= self.x_threshold
        if x_mask.sum() > 0:
            self.x_errors.append(float(np.mean(error[x_mask])))


    def get_current_multipliers(self):
        """Get current performance multipliers, error stats and last-used
        weights for logging.

        NOTE(review): the sensitivities here (0.2/0.3/0.8/1.0) differ from the
        ones used in _get_adaptive_weights (0.05/0.08/0.1/0.12), so the logged
        multipliers do not exactly match the ones applied -- confirm intended.
        """
        return {
            'quiet_mult': self._get_performance_multiplier(
                self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.2, sxrclass='quiet'
            ),
            'c_mult': self._get_performance_multiplier(
                self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.3, sxrclass='c_class'
            ),
            'm_mult': self._get_performance_multiplier(
                self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.8, sxrclass='m_class'
            ),
            'x_mult': self._get_performance_multiplier(
                self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=1.0, sxrclass='x_class'
            ),
            'quiet_count': len(self.quiet_errors),
            'c_count': len(self.c_errors),
            'm_count': len(self.m_errors),
            'x_count': len(self.x_errors),
            'quiet_error': np.mean(self.quiet_errors) if self.quiet_errors else 0.0,
            'c_error': np.mean(self.c_errors) if self.c_errors else 0.0,
            'm_error': np.mean(self.m_errors) if self.m_errors else 0.0,
            'x_error': np.mean(self.x_errors) if self.x_errors else 0.0,
            'quiet_weight': getattr(self, 'current_multipliers', {}).get('quiet_weight', 0.0),
            'c_weight': getattr(self, 'current_multipliers', {}).get('c_weight', 0.0),
            'm_weight': getattr(self, 'current_multipliers', {}).get('m_weight', 0.0),
            'x_weight': getattr(self, 'current_multipliers', {}).get('x_weight', 0.0)
        }
forecasting/training/callback.py CHANGED
@@ -132,7 +132,12 @@ class AttentionMapCallback(Callback):
132
  except:
133
  # For ViT patch model, we need to call the model's forward method directly
134
  if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'forward'):
135
- outputs, attention_weights, _ = pl_module.model(imgs, pl_module.sxr_norm, return_attention=True)
 
 
 
 
 
136
  else:
137
  outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
138
 
 
132
  except:
133
  # For ViT patch model, we need to call the model's forward method directly
134
  if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'forward'):
135
+ try:
136
+ print("Using model's forward method")
137
+ outputs, attention_weights, _ = pl_module.model(imgs, pl_module.sxr_norm, return_attention=True)
138
+ except:
139
+ print("Using model's forward method failed")
140
+ outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
141
  else:
142
  outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
143
 
forecasting/training/config5.yaml CHANGED
@@ -12,7 +12,7 @@ batch_size: 64
12
  epochs: 250
13
  oversample: false
14
  balance_strategy: "upsample_minority"
15
- calculate_base_weights: false # Whether to calculate class-based weights for loss function
16
 
17
  megsai:
18
  architecture: "cnn"
@@ -71,5 +71,5 @@ wandb:
71
  - aia
72
  - sxr
73
  - regression
74
- wb_name: vit-patch-model-2d-embeddings-reduced-sensitivity-changed-base-weights
75
  notes: Regression from AIA images (6 channels) to GOES SXR flux
 
12
  epochs: 250
13
  oversample: false
14
  balance_strategy: "upsample_minority"
15
+ calculate_base_weights: true # Whether to calculate class-based weights for loss function
16
 
17
  megsai:
18
  architecture: "cnn"
 
71
  - aia
72
  - sxr
73
  - regression
74
+ wb_name: vit-mse-base-weights
75
  notes: Regression from AIA images (6 channels) to GOES SXR flux
forecasting/training/config6.yaml CHANGED
@@ -71,5 +71,5 @@ wandb:
71
  - aia
72
  - sxr
73
  - regression
74
- wb_name: vit-patch-model-2d-embeddings-claude-suggested-weights
75
  notes: Regression from AIA images (6 channels) to GOES SXR flux
 
71
  - aia
72
  - sxr
73
  - regression
74
+ wb_name: vit-mse-claude
75
  notes: Regression from AIA images (6 channels) to GOES SXR flux
forecasting/training/train.py CHANGED
@@ -24,6 +24,7 @@ from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataModule
24
  from forecasting.models.vision_transformer_custom import ViT
25
  from forecasting.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
26
  from forecasting.models.vit_patch_model import ViT as ViTPatch
 
27
  from forecasting.models import FusionViTHybrid
28
  from callback import ImagePredictionLogger_SXR, AttentionMapCallback
29
  from pytorch_lightning.callbacks import Callback
@@ -344,6 +345,10 @@ elif config_data['selected_model'] == 'ViTPatch':
344
  base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
345
  model = ViTPatch(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm, base_weights=base_weights)
346
 
 
 
 
 
347
  elif config_data['selected_model'] == 'FusionViTHybrid':
348
  # Expect a 'fusion' section in YAML
349
  fusion_cfg = config_data.get('fusion', {})
 
24
  from forecasting.models.vision_transformer_custom import ViT
25
  from forecasting.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
26
  from forecasting.models.vit_patch_model import ViT as ViTPatch
27
+ from forecasting.models.vit_patch_model_uncertainty import ViTUncertainty
28
  from forecasting.models import FusionViTHybrid
29
  from callback import ImagePredictionLogger_SXR, AttentionMapCallback
30
  from pytorch_lightning.callbacks import Callback
 
345
  base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
346
  model = ViTPatch(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm, base_weights=base_weights)
347
 
348
+ elif config_data['selected_model'] == 'ViTUncertainty':
349
+ base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
350
+ model = ViTUncertainty(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm, base_weights=base_weights)
351
+
352
  elif config_data['selected_model'] == 'FusionViTHybrid':
353
  # Expect a 'fusion' section in YAML
354
  fusion_cfg = config_data.get('fusion', {})
forecasting/training/vituncertainty.yaml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #Base directories - change these to switch datasets
3
+ base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
4
+ base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
5
+ wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
6
+
7
+ # GPU configuration
8
+ gpu_id: 0 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
9
+ # Model configuration
10
+ selected_model: "ViTUncertainty" # Options: "hybrid", "vit", "fusion", "vitpatch", "ViTUncertainty"
11
+ batch_size: 64
12
+ epochs: 250
13
+ oversample: false
14
+ balance_strategy: "upsample_minority"
15
+ calculate_base_weights: false # Whether to calculate class-based weights for loss function
16
+
17
+ megsai:
18
+ architecture: "cnn"
19
+ seed: 42
20
+ lr: 0.0001
21
+ cnn_model: "updated"
22
+ cnn_dp: 0.2
23
+ weight_decay: 1e-5
24
+ cosine_restart_T0: 50
25
+ cosine_restart_Tmult: 2
26
+ cosine_eta_min: 1e-7
27
+
28
+ vit_custom:
29
+ embed_dim: 512
30
+ num_channels: 6
31
+ num_classes: 2
32
+ patch_size: 16
33
+ num_patches: 1024
34
+ hidden_dim: 512
35
+ num_heads: 8
36
+ num_layers: 6
37
+ dropout: 0.1
38
+ lr: 0.0001
39
+
40
+
41
+ fusion:
42
+ scalar_branch: "hybrid" # or "linear"
43
+ lr: 0.0001
44
+ lambda_vit_to_target: 0.3
45
+ lambda_scalar_to_target: 0.1
46
+ learnable_gate: true
47
+ gate_init_bias: 5.0
48
+ scalar_kwargs:
49
+ d_input: 6
50
+ d_output: 1
51
+ cnn_model: "updated"
52
+ cnn_dp: 0.75
53
+
54
+
55
+ # Data paths (automatically constructed from base directories)
56
+ data:
57
+ aia_dir:
58
+ "${base_data_dir}/AIA-SPLIT"
59
+ sxr_dir:
60
+ "${base_data_dir}/SXR-SPLIT"
61
+ sxr_norm_path:
62
+ "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
63
+ checkpoints_dir:
64
+ "${base_checkpoint_dir}/new-checkpoint/"
65
+
66
+ wandb:
67
+ entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
68
+ project: Model Testing
69
+ job_type: training
70
+ tags:
71
+ - aia
72
+ - sxr
73
+ - regression
74
+ wb_name: vit-uncertainty-claude
75
+ notes: Regression from AIA images (6 channels) to GOES SXR flux