Commit ·
afe9cc0
1
Parent(s): b9d1035
Enhance FlareDownloadProcessor to support time span mode for 1-minute cadence downloads, adding new arguments and methods for data existence checks. Update command-line arguments for improved usability. Refactor evaluation configurations and model training settings, including adjustments to checkpoint paths and model parameters for better performance.
Browse files- download/flare_download_processor.py +225 -48
- forecasting/inference/checkpoint_list.yaml +13 -16
- forecasting/inference/evaluation.py +15 -11
- forecasting/inference/evaluation_config.yaml +7 -2
- forecasting/inference/inference_stereo.yaml +81 -0
- forecasting/inference/inference_template.yaml +3 -3
- forecasting/models/vit_patch_model.py +3 -0
- forecasting/models/vit_patch_model_uncertainty.py +529 -0
- forecasting/training/callback.py +6 -1
- forecasting/training/config5.yaml +2 -2
- forecasting/training/config6.yaml +1 -1
- forecasting/training/train.py +5 -0
- forecasting/training/vituncertainty.yaml +75 -0
download/flare_download_processor.py
CHANGED
|
@@ -11,15 +11,42 @@ import flare_event_downloader as fed
|
|
| 11 |
import sxr_downloader as sxr
|
| 12 |
|
| 13 |
class FlareDownloadProcessor:
|
| 14 |
-
def __init__(self, FlareEventDownloader, SDODownloader, SXRDownloader, flaring_data=True):
|
| 15 |
"""
|
| 16 |
Initialize the FlareDownloadProcessor.
|
| 17 |
This class is responsible for processing AIA flare downloads.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
self.FlareEventDownloader = FlareEventDownloader
|
| 20 |
self.SDODownloader = SDODownloader
|
| 21 |
self.SXRDownloader = SXRDownloader
|
| 22 |
self.flaring_data = flaring_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def retry_download_with_backoff(self, download_func, *args, max_retries=5, base_delay=60, **kwargs):
|
| 25 |
"""
|
|
@@ -56,10 +83,18 @@ class FlareDownloadProcessor:
|
|
| 56 |
# Re-raise non-HTTP errors immediately
|
| 57 |
raise
|
| 58 |
|
| 59 |
-
def process_download(self, time_before_start=timedelta(minutes=
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
[os.makedirs(os.path.join(self.SDODownloader.ds_path, str(c)), exist_ok=True) for c in
|
| 64 |
[94, 131, 171, 193, 211, 304]]
|
| 65 |
|
|
@@ -71,6 +106,93 @@ class FlareDownloadProcessor:
|
|
| 71 |
completed_dates = set(line.strip() for line in f)
|
| 72 |
print(f"Resuming from {len(completed_dates)} previously completed downloads")
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
if self.flaring_data == True:
|
| 75 |
print("Processing flare events...")
|
| 76 |
if fl_events.empty:
|
|
@@ -88,8 +210,13 @@ class FlareDownloadProcessor:
|
|
| 88 |
range((end_time - start_time) // timedelta(minutes=1))]:
|
| 89 |
# Only download if we haven't processed this date yet
|
| 90 |
if d.isoformat() not in processed_dates:
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
logging.info(f"Processed flare event {i + 1}/{len(fl_events)}: {event['event_starttime']} to {event['event_endtime']}")
|
| 94 |
elif self.flaring_data == False:
|
| 95 |
print("Processing non-flare events...")
|
|
@@ -136,29 +263,39 @@ class FlareDownloadProcessor:
|
|
| 136 |
for j, d in enumerate(batch):
|
| 137 |
# Only download if we haven't processed this date yet
|
| 138 |
if d.isoformat() not in processed_dates and d.isoformat() not in completed_dates:
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
processed_dates.add(d.isoformat())
|
| 143 |
completed_dates.add(d.isoformat())
|
| 144 |
|
| 145 |
# Update progress file
|
| 146 |
with open(progress_file, 'a') as f:
|
| 147 |
f.write(f"{d.isoformat()}\n")
|
| 148 |
-
|
| 149 |
-
print(f" ✓ Successfully downloaded {d}")
|
| 150 |
-
|
| 151 |
-
# Add small delay between individual downloads
|
| 152 |
-
if j < len(batch) - 1:
|
| 153 |
-
time.sleep(.2)
|
| 154 |
-
|
| 155 |
-
except Exception as e:
|
| 156 |
-
print(f" ✗ Failed to download data for {d}: {e}")
|
| 157 |
-
# If it's a connection error, wait longer before retrying
|
| 158 |
-
if "Connection refused" in str(e) or "timeout" in str(e).lower():
|
| 159 |
-
print(f" Waiting 10 seconds before continuing...")
|
| 160 |
-
time.sleep(5)
|
| 161 |
-
continue
|
| 162 |
elif d.isoformat() in completed_dates:
|
| 163 |
print(f" ⏭ Skipping {d} (already completed)")
|
| 164 |
processed_dates.add(d.isoformat())
|
|
@@ -174,17 +311,25 @@ class FlareDownloadProcessor:
|
|
| 174 |
if __name__ == '__main__':
|
| 175 |
parser = argparse.ArgumentParser(description='Download flare events and associated SDO data.')
|
| 176 |
parser.add_argument('--start_date', type=str, default='2023-6-15',
|
| 177 |
-
help='Start date for downloading
|
| 178 |
parser.add_argument('--end_date', type=str, default='2023-07-15',
|
| 179 |
-
help='End date for downloading
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
parser.add_argument('--chunk_size', type=int, default=2000,
|
| 181 |
help='Number of days per chunk for processing (default: 180)')
|
| 182 |
parser.add_argument('--download_dir', type=str, default='/mnt/data',
|
| 183 |
help='Directory to save downloaded data (default: /mnt/data)')
|
|
|
|
|
|
|
| 184 |
parser.add_argument('--flaring_data', dest='flaring_data', action='store_true',
|
| 185 |
help='Download flaring data (default)')
|
| 186 |
parser.add_argument('--non_flaring_data', dest='flaring_data', action='store_false',
|
| 187 |
help='Download non-flaring data')
|
|
|
|
|
|
|
| 188 |
parser.set_defaults(flaring_data=True)
|
| 189 |
args = parser.parse_args()
|
| 190 |
|
|
@@ -193,29 +338,61 @@ if __name__ == '__main__':
|
|
| 193 |
end_date = args.end_date
|
| 194 |
chunk_size = args.chunk_size
|
| 195 |
flaring_data = args.flaring_data
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
|
| 221 |
-
|
|
|
|
| 11 |
import sxr_downloader as sxr
|
| 12 |
|
| 13 |
class FlareDownloadProcessor:
|
| 14 |
+
def __init__(self, FlareEventDownloader, SDODownloader, SXRDownloader, flaring_data=True, time_span_mode=False):
|
| 15 |
"""
|
| 16 |
Initialize the FlareDownloadProcessor.
|
| 17 |
This class is responsible for processing AIA flare downloads.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
FlareEventDownloader: Downloader for flare events
|
| 21 |
+
SDODownloader: Downloader for SDO data
|
| 22 |
+
SXRDownloader: Downloader for SXR data
|
| 23 |
+
flaring_data: Whether to download flaring data (legacy mode)
|
| 24 |
+
time_span_mode: Whether to use time span mode for 1-minute cadence downloads
|
| 25 |
"""
|
| 26 |
self.FlareEventDownloader = FlareEventDownloader
|
| 27 |
self.SDODownloader = SDODownloader
|
| 28 |
self.SXRDownloader = SXRDownloader
|
| 29 |
self.flaring_data = flaring_data
|
| 30 |
+
self.time_span_mode = time_span_mode
|
| 31 |
+
|
| 32 |
+
def check_existing_data(self, date):
|
| 33 |
+
"""
|
| 34 |
+
Check if data already exists for the given date in all required wavelengths.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
date (datetime): The date to check for existing data
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
bool: True if data exists for all wavelengths, False otherwise
|
| 41 |
+
"""
|
| 42 |
+
wavelengths = ['94', '131', '171', '193', '211', '304']
|
| 43 |
+
date_str = date.strftime('%Y-%m-%dT%H:%M:%S')
|
| 44 |
+
|
| 45 |
+
for wl in wavelengths:
|
| 46 |
+
file_path = os.path.join(self.SDODownloader.ds_path, wl, f"{date_str}.fits")
|
| 47 |
+
if not os.path.exists(file_path):
|
| 48 |
+
return False
|
| 49 |
+
return True
|
| 50 |
|
| 51 |
def retry_download_with_backoff(self, download_func, *args, max_retries=5, base_delay=60, **kwargs):
|
| 52 |
"""
|
|
|
|
| 83 |
# Re-raise non-HTTP errors immediately
|
| 84 |
raise
|
| 85 |
|
| 86 |
+
def process_download(self, time_before_start=timedelta(minutes=120), time_after_end=timedelta(minutes=120),
|
| 87 |
+
start_time=None, end_time=None):
|
| 88 |
+
"""
|
| 89 |
+
Process downloads either in flare mode or time span mode.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
time_before_start: Time before flare start to download (legacy mode)
|
| 93 |
+
time_after_end: Time after flare end to download (legacy mode)
|
| 94 |
+
start_time: Start time for time span mode (datetime object)
|
| 95 |
+
end_time: End time for time span mode (datetime object)
|
| 96 |
+
"""
|
| 97 |
+
# Create directories for SDO data
|
| 98 |
[os.makedirs(os.path.join(self.SDODownloader.ds_path, str(c)), exist_ok=True) for c in
|
| 99 |
[94, 131, 171, 193, 211, 304]]
|
| 100 |
|
|
|
|
| 106 |
completed_dates = set(line.strip() for line in f)
|
| 107 |
print(f"Resuming from {len(completed_dates)} previously completed downloads")
|
| 108 |
|
| 109 |
+
if self.time_span_mode:
|
| 110 |
+
# Time span mode - download 1-minute cadence data for specified time range
|
| 111 |
+
if start_time is None or end_time is None:
|
| 112 |
+
raise ValueError("start_time and end_time must be provided for time span mode")
|
| 113 |
+
|
| 114 |
+
print(f"Processing time span mode: {start_time} to {end_time}")
|
| 115 |
+
print("Downloading 1-minute cadence data...")
|
| 116 |
+
|
| 117 |
+
# Download SXR data for the entire time span
|
| 118 |
+
self.retry_download_with_backoff(self.SXRDownloader.download_and_save_goes_data,
|
| 119 |
+
start_time.strftime('%Y-%m-%d'),
|
| 120 |
+
end_time.strftime('%Y-%m-%d'), max_workers=os.cpu_count()-1)
|
| 121 |
+
|
| 122 |
+
# Generate 1-minute intervals for the time span
|
| 123 |
+
processed_dates = set()
|
| 124 |
+
current_time = start_time
|
| 125 |
+
total_minutes = int((end_time - start_time).total_seconds() / 60)
|
| 126 |
+
|
| 127 |
+
print(f"Total time span: {total_minutes} minutes")
|
| 128 |
+
|
| 129 |
+
# Process in batches to avoid overwhelming the server
|
| 130 |
+
batch_size = 100 # Process 100 minutes at a time
|
| 131 |
+
batch_count = 0
|
| 132 |
+
|
| 133 |
+
while current_time < end_time:
|
| 134 |
+
batch_count += 1
|
| 135 |
+
batch_end = min(current_time + timedelta(minutes=batch_size), end_time)
|
| 136 |
+
batch_dates = []
|
| 137 |
+
|
| 138 |
+
# Generate dates for this batch
|
| 139 |
+
temp_time = current_time
|
| 140 |
+
while temp_time < batch_end:
|
| 141 |
+
if temp_time.isoformat() not in completed_dates:
|
| 142 |
+
# Check if data already exists in the download directory
|
| 143 |
+
if not self.check_existing_data(temp_time):
|
| 144 |
+
batch_dates.append(temp_time)
|
| 145 |
+
else:
|
| 146 |
+
print(f" ⏭ Data already exists for {temp_time}, skipping download")
|
| 147 |
+
completed_dates.add(temp_time.isoformat())
|
| 148 |
+
# Update progress file
|
| 149 |
+
with open(progress_file, 'a') as f:
|
| 150 |
+
f.write(f"{temp_time.isoformat()}\n")
|
| 151 |
+
temp_time += timedelta(minutes=1)
|
| 152 |
+
|
| 153 |
+
if batch_dates:
|
| 154 |
+
print(f"Processing batch {batch_count}: {len(batch_dates)} minutes from {current_time} to {batch_end}")
|
| 155 |
+
|
| 156 |
+
for i, d in enumerate(batch_dates):
|
| 157 |
+
try:
|
| 158 |
+
print(f" Downloading data for {d} ({i+1}/{len(batch_dates)})")
|
| 159 |
+
self.retry_download_with_backoff(self.SDODownloader.downloadDate, d)
|
| 160 |
+
processed_dates.add(d.isoformat())
|
| 161 |
+
completed_dates.add(d.isoformat())
|
| 162 |
+
|
| 163 |
+
# Update progress file
|
| 164 |
+
with open(progress_file, 'a') as f:
|
| 165 |
+
f.write(f"{d.isoformat()}\n")
|
| 166 |
+
|
| 167 |
+
print(f" ✓ Successfully downloaded {d}")
|
| 168 |
+
|
| 169 |
+
# Small delay between downloads
|
| 170 |
+
if i < len(batch_dates) - 1:
|
| 171 |
+
time.sleep(0.01)
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
print(f" ✗ Failed to download data for {d}: {e}")
|
| 175 |
+
if "Connection refused" in str(e) or "timeout" in str(e).lower():
|
| 176 |
+
print(f" Waiting 5 seconds before continuing...")
|
| 177 |
+
time.sleep(5)
|
| 178 |
+
continue
|
| 179 |
+
else:
|
| 180 |
+
print(f" ⏭ Skipping batch {batch_count} (all dates already completed)")
|
| 181 |
+
|
| 182 |
+
# Delay between batches
|
| 183 |
+
if batch_end < end_time:
|
| 184 |
+
print("Waiting 5 seconds before next batch...")
|
| 185 |
+
time.sleep(5)
|
| 186 |
+
|
| 187 |
+
current_time = batch_end
|
| 188 |
+
|
| 189 |
+
print(f"Time span processing completed. Downloaded {len(processed_dates)} data points.")
|
| 190 |
+
return
|
| 191 |
+
|
| 192 |
+
# Legacy flare mode
|
| 193 |
+
fl_events = self.FlareEventDownloader.download_events()
|
| 194 |
+
print(fl_events)
|
| 195 |
+
|
| 196 |
if self.flaring_data == True:
|
| 197 |
print("Processing flare events...")
|
| 198 |
if fl_events.empty:
|
|
|
|
| 210 |
range((end_time - start_time) // timedelta(minutes=1))]:
|
| 211 |
# Only download if we haven't processed this date yet
|
| 212 |
if d.isoformat() not in processed_dates:
|
| 213 |
+
# Check if data already exists in the download directory
|
| 214 |
+
if not self.check_existing_data(d):
|
| 215 |
+
self.retry_download_with_backoff(self.SDODownloader.downloadDate, d)
|
| 216 |
+
processed_dates.add(d.isoformat())
|
| 217 |
+
else:
|
| 218 |
+
print(f" ⏭ Data already exists for {d}, skipping download")
|
| 219 |
+
processed_dates.add(d.isoformat())
|
| 220 |
logging.info(f"Processed flare event {i + 1}/{len(fl_events)}: {event['event_starttime']} to {event['event_endtime']}")
|
| 221 |
elif self.flaring_data == False:
|
| 222 |
print("Processing non-flare events...")
|
|
|
|
| 263 |
for j, d in enumerate(batch):
|
| 264 |
# Only download if we haven't processed this date yet
|
| 265 |
if d.isoformat() not in processed_dates and d.isoformat() not in completed_dates:
|
| 266 |
+
# Check if data already exists in the download directory
|
| 267 |
+
if not self.check_existing_data(d):
|
| 268 |
+
try:
|
| 269 |
+
print(f" Downloading data for {d} ({j+1}/{len(batch)})")
|
| 270 |
+
self.retry_download_with_backoff(self.SDODownloader.downloadDate, d)
|
| 271 |
+
processed_dates.add(d.isoformat())
|
| 272 |
+
completed_dates.add(d.isoformat())
|
| 273 |
+
|
| 274 |
+
# Update progress file
|
| 275 |
+
with open(progress_file, 'a') as f:
|
| 276 |
+
f.write(f"{d.isoformat()}\n")
|
| 277 |
+
|
| 278 |
+
print(f" ✓ Successfully downloaded {d}")
|
| 279 |
+
|
| 280 |
+
# Add small delay between individual downloads
|
| 281 |
+
if j < len(batch) - 1:
|
| 282 |
+
time.sleep(.2)
|
| 283 |
+
|
| 284 |
+
except Exception as e:
|
| 285 |
+
print(f" ✗ Failed to download data for {d}: {e}")
|
| 286 |
+
# If it's a connection error, wait longer before retrying
|
| 287 |
+
if "Connection refused" in str(e) or "timeout" in str(e).lower():
|
| 288 |
+
print(f" Waiting 10 seconds before continuing...")
|
| 289 |
+
time.sleep(5)
|
| 290 |
+
continue
|
| 291 |
+
else:
|
| 292 |
+
print(f" ⏭ Data already exists for {d}, skipping download")
|
| 293 |
processed_dates.add(d.isoformat())
|
| 294 |
completed_dates.add(d.isoformat())
|
| 295 |
|
| 296 |
# Update progress file
|
| 297 |
with open(progress_file, 'a') as f:
|
| 298 |
f.write(f"{d.isoformat()}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
elif d.isoformat() in completed_dates:
|
| 300 |
print(f" ⏭ Skipping {d} (already completed)")
|
| 301 |
processed_dates.add(d.isoformat())
|
|
|
|
| 311 |
if __name__ == '__main__':
|
| 312 |
parser = argparse.ArgumentParser(description='Download flare events and associated SDO data.')
|
| 313 |
parser.add_argument('--start_date', type=str, default='2023-6-15',
|
| 314 |
+
help='Start date for downloading data (YYYY-MM-DD)')
|
| 315 |
parser.add_argument('--end_date', type=str, default='2023-07-15',
|
| 316 |
+
help='End date for downloading data (YYYY-MM-DD)')
|
| 317 |
+
parser.add_argument('--start_time', type=str, default=None,
|
| 318 |
+
help='Start time for time span mode (YYYY-MM-DD HH:MM:SS)')
|
| 319 |
+
parser.add_argument('--end_time', type=str, default=None,
|
| 320 |
+
help='End time for time span mode (YYYY-MM-DD HH:MM:SS)')
|
| 321 |
parser.add_argument('--chunk_size', type=int, default=2000,
|
| 322 |
help='Number of days per chunk for processing (default: 180)')
|
| 323 |
parser.add_argument('--download_dir', type=str, default='/mnt/data',
|
| 324 |
help='Directory to save downloaded data (default: /mnt/data)')
|
| 325 |
+
parser.add_argument('--time_span_mode', action='store_true',
|
| 326 |
+
help='Use time span mode for 1-minute cadence downloads')
|
| 327 |
parser.add_argument('--flaring_data', dest='flaring_data', action='store_true',
|
| 328 |
help='Download flaring data (default)')
|
| 329 |
parser.add_argument('--non_flaring_data', dest='flaring_data', action='store_false',
|
| 330 |
help='Download non-flaring data')
|
| 331 |
+
parser.add_argument('--email', type=str, default='ggoodwin5@gsu.edu',
|
| 332 |
+
help='Email for SDO data download')
|
| 333 |
parser.set_defaults(flaring_data=True)
|
| 334 |
args = parser.parse_args()
|
| 335 |
|
|
|
|
| 338 |
end_date = args.end_date
|
| 339 |
chunk_size = args.chunk_size
|
| 340 |
flaring_data = args.flaring_data
|
| 341 |
+
time_span_mode = args.time_span_mode
|
| 342 |
+
email = args.email
|
| 343 |
+
if time_span_mode:
|
| 344 |
+
# Time span mode - use precise start and end times
|
| 345 |
+
if args.start_time is None or args.end_time is None:
|
| 346 |
+
print("Error: --start_time and --end_time must be provided for time span mode")
|
| 347 |
+
print("Example: --start_time '2023-06-15 00:00:00' --end_time '2023-06-15 23:59:59'")
|
| 348 |
+
exit(1)
|
| 349 |
+
|
| 350 |
+
try:
|
| 351 |
+
start_time = datetime.strptime(args.start_time, "%Y-%m-%d %H:%M:%S")
|
| 352 |
+
end_time = datetime.strptime(args.end_time, "%Y-%m-%d %H:%M:%S")
|
| 353 |
+
except ValueError:
|
| 354 |
+
print("Error: Invalid time format. Use YYYY-MM-DD HH:MM:SS")
|
| 355 |
+
exit(1)
|
| 356 |
+
|
| 357 |
+
print(f"Time span mode: {start_time} to {end_time}")
|
| 358 |
+
|
| 359 |
+
# Initialize downloaders
|
| 360 |
+
sxr_downloader = sxr.SXRDownloader(f"{download_dir}/GOES-timespan",
|
| 361 |
+
f"{download_dir}/GOES-timespan/combined")
|
| 362 |
+
sdo_downloader = sdo.SDODownloader(f"{download_dir}/SDO-AIA-timespan", email)
|
| 363 |
+
|
| 364 |
+
# Create a dummy flare event downloader (not used in time span mode)
|
| 365 |
+
flare_event = None
|
| 366 |
+
|
| 367 |
+
processor = FlareDownloadProcessor(flare_event, sdo_downloader, sxr_downloader,
|
| 368 |
+
flaring_data=flaring_data, time_span_mode=True)
|
| 369 |
+
processor.process_download(start_time=start_time, end_time=end_time)
|
| 370 |
+
|
| 371 |
+
else:
|
| 372 |
+
# Legacy flare mode
|
| 373 |
+
# Parse start and end dates
|
| 374 |
+
start = datetime.strptime(start_date, "%Y-%m-%d")
|
| 375 |
+
end = datetime.strptime(end_date, "%Y-%m-%d")
|
| 376 |
|
| 377 |
+
# Process in chunks
|
| 378 |
+
current_start = start
|
| 379 |
+
while current_start < end:
|
| 380 |
+
current_end = min(current_start + timedelta(days=chunk_size), end)
|
| 381 |
+
print(f"Processing chunk: {current_start.strftime('%Y-%m-%d')} to {current_end.strftime('%Y-%m-%d')}")
|
| 382 |
|
| 383 |
+
sxr_downloader = sxr.SXRDownloader(f"{download_dir}/GOES-flaring",
|
| 384 |
+
f"{download_dir}/GOES-flaring/combined")
|
| 385 |
+
flare_event = fed.FlareEventDownloader(
|
| 386 |
+
current_start.strftime("%Y-%m-%d"),
|
| 387 |
+
current_end.strftime("%Y-%m-%d"),
|
| 388 |
+
event_type="FL",
|
| 389 |
+
GOESCls="M1.0",
|
| 390 |
+
directory=f"{download_dir}/SDO-AIA-flaring/FlareEvents"
|
| 391 |
+
)
|
| 392 |
+
sdo_downloader = sdo.SDODownloader(f"{download_dir}/SDO-AIA-flaring", email)
|
| 393 |
|
| 394 |
+
processor = FlareDownloadProcessor(flare_event, sdo_downloader, sxr_downloader,
|
| 395 |
+
flaring_data=flaring_data, time_span_mode=False)
|
| 396 |
+
processor.process_download()
|
| 397 |
|
| 398 |
+
current_start = current_end
|
forecasting/inference/checkpoint_list.yaml
CHANGED
|
@@ -2,22 +2,19 @@
|
|
| 2 |
# This file contains a list of model checkpoints to evaluate
|
| 3 |
|
| 4 |
checkpoints:
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# - name: "
|
| 10 |
-
# checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-changed-base-weights-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
# - name: "
|
| 16 |
-
# checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
# - name: "rs-epoch50"
|
| 20 |
-
# checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-epoch=50-val_total_loss=0.0382.ckpt"
|
| 21 |
|
| 22 |
# Add more checkpoints as needed
|
| 23 |
# Each checkpoint should have:
|
|
|
|
| 2 |
# This file contains a list of model checkpoints to evaluate
|
| 3 |
|
| 4 |
checkpoints:
|
| 5 |
+
# - name: "claude-final"
|
| 6 |
+
# checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-claude-suggested-weights-final-20250921_225446.pth"
|
| 7 |
+
# - name: "rs-final"
|
| 8 |
+
# checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-final-20250921_185953.pth"
|
| 9 |
+
# - name: "baseweights-final"
|
| 10 |
+
# checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-changed-base-weights-final-20250921_223323.pth"
|
| 11 |
+
- name: "claude-mse"
|
| 12 |
+
checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-mse-claude-epoch=62-val_total_loss=0.1904.ckpt"
|
| 13 |
+
- name: "baseweights-mse"
|
| 14 |
+
checkpoint_path: /mnt/data/COMBINED/new-checkpoint/vit-mse-base-weights-epoch=62-val_total_loss=0.2893.ckpt"
|
| 15 |
+
# - name: "stereo-final"
|
| 16 |
+
# checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-STEREO-final-20250921_183739.pth"
|
| 17 |
+
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# Add more checkpoints as needed
|
| 20 |
# Each checkpoint should have:
|
forecasting/inference/evaluation.py
CHANGED
|
@@ -344,7 +344,7 @@ class SolarFlareEvaluator:
|
|
| 344 |
'B': (1e-7, 1e-6, "#FFAAA5"),
|
| 345 |
'C': (1e-6, 1e-5, "#FFAAA5"),
|
| 346 |
'M': (1e-5, 1e-4, "#FFAAA5"),
|
| 347 |
-
'X': (1e-4, 1e-
|
| 348 |
}
|
| 349 |
|
| 350 |
for class_name, (min_flux, max_flux, color) in flare_classes_mae.items():
|
|
@@ -646,10 +646,10 @@ class SolarFlareEvaluator:
|
|
| 646 |
# Create figure with transparent background
|
| 647 |
fig = plt.figure(figsize=(10, 5))
|
| 648 |
fig.patch.set_alpha(0.0) # Transparent background
|
| 649 |
-
gs_left = fig.add_gridspec(1, 1, left=0.0, right=0.35, width_ratios=[1], hspace=0, wspace=0.
|
| 650 |
|
| 651 |
# Right gridspec for SXR plot (column 3) with more padding
|
| 652 |
-
gs_right = fig.add_gridspec(2, 1, left=0.45, right=1, hspace=0)
|
| 653 |
|
| 654 |
wavs = ['94', '131', '171', '193', '211', '304']
|
| 655 |
att_max = np.percentile(attention_data, 100)
|
|
@@ -675,9 +675,9 @@ class SolarFlareEvaluator:
|
|
| 675 |
# Plot SXR data with uncertainty bands
|
| 676 |
sxr_ax = fig.add_subplot(gs_right[:, 0])
|
| 677 |
|
| 678 |
-
# Set SXR plot background to
|
| 679 |
-
sxr_ax.set_facecolor('#FFEEE6') # Light background for SXR plot
|
| 680 |
-
sxr_ax.patch.set_alpha(1.0) # Make
|
| 681 |
|
| 682 |
if sxr_window is not None and not sxr_window.empty:
|
| 683 |
# Plot ground truth (no uncertainty)
|
|
@@ -787,8 +787,10 @@ class SolarFlareEvaluator:
|
|
| 787 |
text.set_fontfamily('Barlow')
|
| 788 |
|
| 789 |
sxr_ax.grid(True, alpha=0.3, color='black')
|
| 790 |
-
sxr_ax.tick_params(axis='x', rotation=15, labelsize=12, colors='white'
|
| 791 |
-
|
|
|
|
|
|
|
| 792 |
|
| 793 |
# Set tick labels to Barlow font and white color
|
| 794 |
for label in sxr_ax.get_xticklabels():
|
|
@@ -797,8 +799,10 @@ class SolarFlareEvaluator:
|
|
| 797 |
for label in sxr_ax.get_yticklabels():
|
| 798 |
label.set_fontfamily('Barlow')
|
| 799 |
label.set_color('white')
|
| 800 |
-
|
| 801 |
-
#
|
|
|
|
|
|
|
| 802 |
try:
|
| 803 |
sxr_ax.set_yscale('log')
|
| 804 |
except:
|
|
@@ -814,7 +818,7 @@ class SolarFlareEvaluator:
|
|
| 814 |
|
| 815 |
#plt.suptitle(f'Timestamp: {timestamp}', fontsize=14)
|
| 816 |
#plt.tight_layout()
|
| 817 |
-
plt.savefig(save_path, dpi=500, facecolor='none',
|
| 818 |
plt.close()
|
| 819 |
|
| 820 |
print(f"Worker {os.getpid()}: Completed {timestamp}")
|
|
|
|
| 344 |
'B': (1e-7, 1e-6, "#FFAAA5"),
|
| 345 |
'C': (1e-6, 1e-5, "#FFAAA5"),
|
| 346 |
'M': (1e-5, 1e-4, "#FFAAA5"),
|
| 347 |
+
'X': (1e-4, 1e-2, "#FFAAA5")
|
| 348 |
}
|
| 349 |
|
| 350 |
for class_name, (min_flux, max_flux, color) in flare_classes_mae.items():
|
|
|
|
| 646 |
# Create figure with transparent background
|
| 647 |
fig = plt.figure(figsize=(10, 5))
|
| 648 |
fig.patch.set_alpha(0.0) # Transparent background
|
| 649 |
+
gs_left = fig.add_gridspec(1, 1, left=0.0, right=0.35, width_ratios=[1], hspace=0, wspace=0.1)
|
| 650 |
|
| 651 |
# Right gridspec for SXR plot (column 3) with more padding
|
| 652 |
+
gs_right = fig.add_gridspec(2, 1, left=0.45, right=1, hspace=0.1)
|
| 653 |
|
| 654 |
wavs = ['94', '131', '171', '193', '211', '304']
|
| 655 |
att_max = np.percentile(attention_data, 100)
|
|
|
|
| 675 |
# Plot SXR data with uncertainty bands
|
| 676 |
sxr_ax = fig.add_subplot(gs_right[:, 0])
|
| 677 |
|
| 678 |
+
# Set SXR plot background to have light background inside plot area
|
| 679 |
+
sxr_ax.set_facecolor('#FFEEE6') # Light background for SXR plot area
|
| 680 |
+
sxr_ax.patch.set_alpha(1.0) # Make axes patch opaque
|
| 681 |
|
| 682 |
if sxr_window is not None and not sxr_window.empty:
|
| 683 |
# Plot ground truth (no uncertainty)
|
|
|
|
| 787 |
text.set_fontfamily('Barlow')
|
| 788 |
|
| 789 |
sxr_ax.grid(True, alpha=0.3, color='black')
|
| 790 |
+
sxr_ax.tick_params(axis='x', rotation=15, labelsize=12, colors='white',
|
| 791 |
+
)
|
| 792 |
+
sxr_ax.tick_params(axis='y', labelsize=12, colors='white',
|
| 793 |
+
)
|
| 794 |
|
| 795 |
# Set tick labels to Barlow font and white color
|
| 796 |
for label in sxr_ax.get_xticklabels():
|
|
|
|
| 799 |
for label in sxr_ax.get_yticklabels():
|
| 800 |
label.set_fontfamily('Barlow')
|
| 801 |
label.set_color('white')
|
| 802 |
+
|
| 803 |
+
# Set graph border (spines) to white
|
| 804 |
+
for spine in sxr_ax.spines.values():
|
| 805 |
+
spine.set_color('white')
|
| 806 |
try:
|
| 807 |
sxr_ax.set_yscale('log')
|
| 808 |
except:
|
|
|
|
| 818 |
|
| 819 |
#plt.suptitle(f'Timestamp: {timestamp}', fontsize=14)
|
| 820 |
#plt.tight_layout()
|
| 821 |
+
plt.savefig(save_path, dpi=500, facecolor='none',bbox_inches='tight')
|
| 822 |
plt.close()
|
| 823 |
|
| 824 |
print(f"Worker {os.getpid()}: Completed {timestamp}")
|
forecasting/inference/evaluation_config.yaml
CHANGED
|
@@ -21,9 +21,14 @@ evaluation:
|
|
| 21 |
# Examples: 1e-6 (C-class and above), 1e-5 (M-class and above), 1e-4 (X-class only)
|
| 22 |
|
| 23 |
# Time range for evaluation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
time_range:
|
| 25 |
-
start_time: "
|
| 26 |
-
end_time: "
|
| 27 |
interval_minutes: 1
|
| 28 |
|
| 29 |
# Plotting parameters
|
|
|
|
| 21 |
# Examples: 1e-6 (C-class and above), 1e-5 (M-class and above), 1e-4 (X-class only)
|
| 22 |
|
| 23 |
# Time range for evaluation
|
| 24 |
+
# time_range:
|
| 25 |
+
# start_time: "2023-08-05T20:30:00"
|
| 26 |
+
# end_time: "2023-08-05T23:56:00"
|
| 27 |
+
# interval_minutes: 1
|
| 28 |
+
|
| 29 |
time_range:
|
| 30 |
+
start_time: "2014-08-01T00:00:00"
|
| 31 |
+
end_time: "2014-08-31T23:59:00"
|
| 32 |
interval_minutes: 1
|
| 33 |
|
| 34 |
# Plotting parameters
|
forecasting/inference/inference_stereo.yaml
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Base inference configuration template
|
| 2 |
+
# This will be used as a template for each checkpoint evaluation
|
| 3 |
+
|
| 4 |
+
# Base directories
|
| 5 |
+
base_data_dir: "/mnt/data/COMBINED"
|
| 6 |
+
output_path: "PLACEHOLDER_OUTPUT_PATH" # Will be replaced by batch script
|
| 7 |
+
weight_path: "PLACEHOLDER_WEIGHT_PATH" # Will be replaced by batch script
|
| 8 |
+
|
| 9 |
+
# Dataset configuration
|
| 10 |
+
SolO: "false"
|
| 11 |
+
Stereo: "false"
|
| 12 |
+
|
| 13 |
+
# Model configuration
|
| 14 |
+
model: "vit" # Options: "vit", "hybrid", "vitpatch", "fusion"
|
| 15 |
+
wavelengths: [171, 193, 211, 304] # AIA wavelengths in Angstroms
|
| 16 |
+
|
| 17 |
+
# MC Dropout configuration
|
| 18 |
+
mc:
|
| 19 |
+
active: "false"
|
| 20 |
+
runs: 5
|
| 21 |
+
|
| 22 |
+
# SolO data configuration (if using SolO dataset)
|
| 23 |
+
SolO_data:
|
| 24 |
+
solo_img_dir: "/mnt/data/ML-Ready_clean/SolO/SolO/ML-Ready-SolO"
|
| 25 |
+
sxr_dir: "${base_data_dir}/SXR"
|
| 26 |
+
sxr_norm_path: "${base_data_dir}/SolO/SXR/normalized_sxr.npy"
|
| 27 |
+
|
| 28 |
+
# Stereo data configuration (if using Stereo dataset)
|
| 29 |
+
Stereo_data:
|
| 30 |
+
stereo_img_dir: "/mnt/data/ML-Ready-mixed/STEREO_processed"
|
| 31 |
+
sxr_dir: "/mnt/data/ML-Ready-mixed/ML-Ready-mixed/SXR"
|
| 32 |
+
sxr_norm_path: "/mnt/data/ML-READY/SXR/normalized_sxr.npy"
|
| 33 |
+
|
| 34 |
+
# Model parameters
|
| 35 |
+
model_params:
|
| 36 |
+
input_size: 512
|
| 37 |
+
patch_size: 16
|
| 38 |
+
batch_size: 16
|
| 39 |
+
no_weights: false # Set to true to skip saving attention weights
|
| 40 |
+
|
| 41 |
+
# Model architecture parameters (should match training config)
|
| 42 |
+
vit_custom:
|
| 43 |
+
embed_dim: 512
|
| 44 |
+
num_channels: 4
|
| 45 |
+
num_classes: 1
|
| 46 |
+
patch_size: 16
|
| 47 |
+
num_patches: 1024
|
| 48 |
+
hidden_dim: 512
|
| 49 |
+
num_heads: 8
|
| 50 |
+
num_layers: 6
|
| 51 |
+
dropout: 0.1
|
| 52 |
+
|
| 53 |
+
# Data paths
|
| 54 |
+
data:
|
| 55 |
+
aia_dir: "${base_data_dir}/AIA-SPLIT/"
|
| 56 |
+
sxr_dir: "${base_data_dir}/SXR-SPLIT/"
|
| 57 |
+
sxr_norm_path: "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
|
| 58 |
+
checkpoint_path: "PLACEHOLDER_CHECKPOINT_PATH" # Will be replaced by batch script
|
| 59 |
+
|
| 60 |
+
# MEGSAI parameters (should match training config)
|
| 61 |
+
megsai:
|
| 62 |
+
cnn_model: "updated"
|
| 63 |
+
cnn_dp: 0.2
|
| 64 |
+
weight_decay: 1e-5
|
| 65 |
+
cosine_restart_T0: 50
|
| 66 |
+
cosine_restart_Tmult: 2
|
| 67 |
+
cosine_eta_min: 1e-7
|
| 68 |
+
|
| 69 |
+
# Fusion parameters (if using fusion model)
|
| 70 |
+
fusion:
|
| 71 |
+
scalar_branch: "hybrid"
|
| 72 |
+
lr: 0.0001
|
| 73 |
+
lambda_vit_to_target: 0.3
|
| 74 |
+
lambda_scalar_to_target: 0.1
|
| 75 |
+
learnable_gate: true
|
| 76 |
+
gate_init_bias: 5.0
|
| 77 |
+
scalar_kwargs:
|
| 78 |
+
d_input: 6
|
| 79 |
+
d_output: 1
|
| 80 |
+
cnn_model: "updated"
|
| 81 |
+
cnn_dp: 0.75
|
forecasting/inference/inference_template.yaml
CHANGED
|
@@ -44,10 +44,10 @@ vit_custom:
|
|
| 44 |
num_channels: 6
|
| 45 |
num_classes: 1
|
| 46 |
patch_size: 16
|
| 47 |
-
num_patches:
|
| 48 |
hidden_dim: 512
|
| 49 |
-
num_heads:
|
| 50 |
-
num_layers:
|
| 51 |
dropout: 0.1
|
| 52 |
|
| 53 |
# Data paths
|
|
|
|
| 44 |
num_channels: 6
|
| 45 |
num_classes: 1
|
| 46 |
patch_size: 16
|
| 47 |
+
num_patches: 1024
|
| 48 |
hidden_dim: 512
|
| 49 |
+
num_heads: 8
|
| 50 |
+
num_layers: 6
|
| 51 |
dropout: 0.1
|
| 52 |
|
| 53 |
# Data paths
|
forecasting/models/vit_patch_model.py
CHANGED
|
@@ -84,6 +84,7 @@ class ViT(pl.LightningModule):
|
|
| 84 |
|
| 85 |
#Also calculate huber loss for logging
|
| 86 |
huber_loss = F.huber_loss(norm_preds_squeezed, sxr, delta=.3)
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
# Log adaptation info
|
|
@@ -349,6 +350,7 @@ class SXRRegressionDynamicLoss:
|
|
| 349 |
|
| 350 |
def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
|
| 351 |
base_loss = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
|
|
|
|
| 352 |
weights = self._get_adaptive_weights(sxr_un)
|
| 353 |
self._update_tracking(sxr_un, sxr_norm, preds_norm)
|
| 354 |
weighted_loss = base_loss * weights
|
|
@@ -462,6 +464,7 @@ class SXRRegressionDynamicLoss:
|
|
| 462 |
|
| 463 |
#Huber loss
|
| 464 |
error = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
|
|
|
|
| 465 |
error = error.detach().cpu().numpy()
|
| 466 |
|
| 467 |
|
|
|
|
| 84 |
|
| 85 |
#Also calculate huber loss for logging
|
| 86 |
huber_loss = F.huber_loss(norm_preds_squeezed, sxr, delta=.3)
|
| 87 |
+
#huber_loss = F.mse_loss(norm_preds_squeezed, sxr)
|
| 88 |
|
| 89 |
|
| 90 |
# Log adaptation info
|
|
|
|
| 350 |
|
| 351 |
def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
|
| 352 |
base_loss = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
|
| 353 |
+
#base_loss = F.mse_loss(preds_norm, sxr_norm, reduction='none')
|
| 354 |
weights = self._get_adaptive_weights(sxr_un)
|
| 355 |
self._update_tracking(sxr_un, sxr_norm, preds_norm)
|
| 356 |
weighted_loss = base_loss * weights
|
|
|
|
| 464 |
|
| 465 |
#Huber loss
|
| 466 |
error = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
|
| 467 |
+
#error = F.mse_loss(preds_norm, sxr_norm, reduction='none')
|
| 468 |
error = error.detach().cpu().numpy()
|
| 469 |
|
| 470 |
|
forecasting/models/vit_patch_model_uncertainty.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
import torch.utils.data as data
|
| 10 |
+
import torchvision
|
| 11 |
+
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
| 12 |
+
from torchvision import transforms
|
| 13 |
+
import pytorch_lightning as pl
|
| 14 |
+
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
| 15 |
+
|
| 16 |
+
#norm = np.load("/mnt/data/ML-Ready_clean/mixed_data/SXR/normalized_sxr.npy")
|
| 17 |
+
|
| 18 |
+
def normalize_sxr(unnormalized_values, sxr_norm):
    """Map raw SXR flux values into normalized (z-scored) log10 space.

    Args:
        unnormalized_values: Tensor of raw flux values (non-negative).
        sxr_norm: Two-element tensor holding (mean, std) of the log10 flux.

    Returns:
        Tensor of normalized log10 flux values.
    """
    mean = float(sxr_norm[0].item())
    std = float(sxr_norm[1].item())
    # The small epsilon keeps log10 finite when the flux is exactly zero.
    return (torch.log10(unnormalized_values + 1e-8) - mean) / std
|
| 23 |
+
|
| 24 |
+
def unnormalize_sxr(normalized_values, sxr_norm):
    """Invert normalize_sxr: map z-scored log10 flux back to raw flux."""
    mean = float(sxr_norm[0].item())
    std = float(sxr_norm[1].item())
    # Undo the z-score, exponentiate, and remove the epsilon added by normalize_sxr.
    return 10 ** (normalized_values * std + mean) - 1e-8
|
| 26 |
+
|
| 27 |
+
class ViTUncertainty(pl.LightningModule):
    """Lightning wrapper around a patch-wise VisionTransformer that regresses
    global SXR flux and an auxiliary per-sample uncertainty estimate, trained
    with the adaptive class-weighted SXRRegressionDynamicLoss.
    """

    def __init__(self, model_kwargs, sxr_norm, base_weights=None):
        """
        Args:
            model_kwargs: VisionTransformer kwargs plus 'lr'; 'lr' and
                'num_classes' are stripped before model construction since
                VisionTransformer does not accept them.
            sxr_norm: (mean, std) of log10 SXR flux used for (un)normalization.
            base_weights: optional dict of per-class loss weights
                ('quiet'/'c_class'/'m_class'/'x_class'); when None,
                SXRRegressionDynamicLoss falls back to its own defaults.
        """
        super().__init__()
        self.model_kwargs = model_kwargs
        self.lr = model_kwargs['lr']
        self.save_hyperparameters()
        # Remove kwargs that VisionTransformer.__init__ does not accept.
        filtered_kwargs = dict(model_kwargs)
        filtered_kwargs.pop('lr', None)
        filtered_kwargs.pop('num_classes', None)
        self.model = VisionTransformer(**filtered_kwargs)
        #Set the base weights based on the number of samples in each class within training data
        self.base_weights = base_weights
        self.adaptive_loss = SXRRegressionDynamicLoss(window_size=15000, base_weights=self.base_weights)
        self.sxr_norm = sxr_norm


    def forward(self, x, return_attention=True):
        # Returns (global_flux_raw, attention_weights, patch_flux_raw, total_error)
        # when return_attention=True, else (global_flux_raw, patch_flux_raw, total_error).
        return self.model(x, self.sxr_norm, return_attention=return_attention)

    def forward_for_callback(self, x, return_attention=True):
        """Forward method compatible with AttentionMapCallback"""
        global_flux_raw, attention_weights, patch_flux_raw, patch_error = self.forward(x, return_attention=return_attention)
        return global_flux_raw, attention_weights

    def configure_optimizers(self):
        """Set up AdamW with a cosine warm-restart schedule stepped per epoch."""
        # Use AdamW with weight decay for better regularization
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=0.00001,
        )

        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=50,       # First restart after 50 epochs
            T_mult=2,     # Double the cycle length after each restart
            eta_min=1e-7  # Minimum learning rate
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1,
                'name': 'learning_rate'
            }
        }

    # M/X Class Flare Detection Optimized Weights

    def _calculate_loss(self, batch, mode="train"):
        """Shared loss computation + logging for train/val/test steps.

        Args:
            batch: (imgs, sxr) pair; sxr targets are assumed to arrive in
                normalized log10 space (they are unnormalized below for
                class bucketing) -- confirm against the dataloader.
            mode: "train", "val" or "test"; controls which metrics are logged.

        Returns:
            Combined weighted regression + uncertainty loss.
        """
        imgs, sxr = batch
        # Model returns raw (unnormalized) flux predictions and an error estimate.
        raw_preds, raw_patch_contributions, raw_error = self.model(imgs,self.sxr_norm)
        raw_preds_squeezed = torch.squeeze(raw_preds)
        sxr_un = unnormalize_sxr(sxr, self.sxr_norm)

        norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
        raw_error_squeezed = torch.squeeze(raw_error)
        # Use adaptive rare event loss
        loss, error_loss, weights = self.adaptive_loss.calculate_loss(
            norm_preds_squeezed, sxr, sxr_un, raw_error_squeezed
        )

        #Also calculate huber loss for logging
        huber_loss = F.huber_loss(norm_preds_squeezed, sxr, delta=.3)

        # Log adaptation info
        if mode == "train":
            # Always log learning rate (every step)
            current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
            self.log('learning_rate', current_lr, on_step=True, on_epoch=False,
                     prog_bar=True, logger=True, sync_dist=True)

            self.log("train_total_loss", loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            self.log("train_huber_loss", huber_loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            self.log("train_error_loss", error_loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            # Detailed diagnostics only every 200 steps
            if self.global_step % 200 == 0:
                multipliers = self.adaptive_loss.get_current_multipliers()
                for key, value in multipliers.items():
                    self.log(f"adaptive/{key}", value, on_step=True, on_epoch=False)

                self.log("adaptive/avg_weight", weights.mean(), on_step=True, on_epoch=False)
                self.log("adaptive/max_weight", weights.max(), on_step=True, on_epoch=False)

        if mode == "val":
            # Validation: typically only log epoch aggregates
            multipliers = self.adaptive_loss.get_current_multipliers()
            for key, value in multipliers.items():
                self.log(f"val/adaptive/{key}", value, on_step=False, on_epoch=True)
            self.log("val_total_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            self.log("val_huber_loss", huber_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            self.log("val_error_loss", error_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        # NOTE(review): mode "test" logs nothing in _calculate_loss -- confirm intended.
        self._calculate_loss(batch, mode="test")

    def apply_wavelength_dropout(self, x, dropout_prob=0.3):
        """Randomly zero out some wavelengths during training"""
        if self.training and torch.rand(1).item() < dropout_prob:
            # x shape: [B, H, W, num_channels]
            num_keep = torch.randint(1, self.model_kwargs['num_channels'], (1,)).item()
            keep_indices = torch.randperm(self.model_kwargs['num_channels'])[:num_keep]

            mask = torch.zeros(self.model_kwargs['num_channels'], device=x.device)
            mask[keep_indices] = 1.0

            x = x * mask.view(1, 1, 1, -1)
        return x
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class VisionTransformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        patch_size,
        num_patches,
        dropout
    ):
        """Vision Transformer that outputs flux contributions per patch.

        Args:
            embed_dim: Dimensionality of the input feature vectors to the Transformer
            hidden_dim: Dimensionality of the hidden layer in the feed-forward networks
                within the Transformer
            num_channels: Number of channels of the input (3 for RGB)
            num_heads: Number of heads to use in the Multi-Head Attention block
            num_layers: Number of layers to use in the Transformer
            patch_size: Number of pixels that the patches have per dimension
            num_patches: Maximum number of patches an image can have
                (assumed to be a perfect square for the 2D positional grid)
            dropout: Amount of dropout to apply in the feed-forward network and
                on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size ** 2), embed_dim)

        self.transformer_blocks = nn.ModuleList([
            AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        # Two heads per patch embedding: flux prediction and error (uncertainty) prediction.
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        self.error_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        # NOTE(review): cls_token and 1D pos_embedding are unused by forward()
        # (the CLS path is commented out there) -- kept for checkpoint compatibility?
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))
        self.grid_h = int(math.sqrt(num_patches))
        self.grid_w = int(math.sqrt(num_patches))
        self.pos_embedding_2d = nn.Parameter(torch.randn(1, self.grid_h, self.grid_w, embed_dim))


    def forward(self, x, sxr_norm, return_attention=False):
        """Predict per-patch and global raw SXR flux plus a total error estimate.

        Args:
            x: Image batch of shape [B, H, W, C] (channels-last).
            sxr_norm: (mean, std) of log10 flux used to map patch logits to raw flux.
            return_attention: If True, also return per-layer attention weights.

        Returns:
            (global_flux_raw, attention_weights, patch_flux_raw, total_error)
            when return_attention=True, otherwise
            (global_flux_raw, patch_flux_raw, total_error).
        """
        # Preprocess input
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        #cls_token = self.cls_token.repeat(B, 1, 1)
        #x = torch.cat([cls_token, x], dim=1)
        #x = x + self.pos_embedding[:, : T + 1]
        x = self._add_2d_positional_encoding(x)

        # Apply Transformer blocks
        x = self.dropout(x)
        x = x.transpose(0, 1)  # [T, B, embed_dim]

        attention_weights = []
        for block in self.transformer_blocks:
            if return_attention:
                x, attn_weights = block(x, return_attention=True)
                attention_weights.append(attn_weights)
            else:
                x = block(x)
        #Extract patch logits and total error
        patch_embeddings = x.transpose(0, 1)  # [B, num_patches, embed_dim]
        patch_logits = self.mlp_head(patch_embeddings).squeeze(-1)  # normalized log predictions [B, num_patches]
        patch_error = self.error_head(patch_embeddings).squeeze(-1)  # [B, num_patches]

        # --- Convert to raw SXR ---
        mean, std = sxr_norm  # in log10 space
        patch_flux_raw = torch.clamp(10 ** (patch_logits * std + mean)- 1e-8, min=1e-15, max=1)
        patch_error_raw = torch.clamp(10 ** (patch_error * std + mean)- 1e-8, min=1e-30, max=1)

        # Sum over patches for raw global flux
        global_flux_raw = patch_flux_raw.sum(dim=1, keepdim=True)
        #Calculate total error as sqrt of sum of squares of patch errors
        total_error = torch.sqrt(patch_error_raw.pow(2).sum(dim=1, keepdim=True))

        # Ensure global flux is never zero (add small epsilon if needed)
        global_flux_raw = torch.clamp(global_flux_raw, min=1e-15)

        if return_attention:
            return global_flux_raw, attention_weights, patch_flux_raw, total_error
        else:
            return global_flux_raw, patch_flux_raw, total_error

    def _add_2d_positional_encoding(self, x):
        """Add learned 2D positional encoding to patch embeddings"""
        B, T, embed_dim = x.shape
        num_patches = T  # No CLS token in the sequence

        # Reshape patches to 2D grid: [B, grid_h, grid_w, embed_dim]
        patch_embeddings = x.reshape(B, self.grid_h, self.grid_w, embed_dim)

        # Add learned 2D positional encoding
        # Broadcasting: [B, grid_h, grid_w, embed_dim] + [1, grid_h, grid_w, embed_dim]
        patch_embeddings = patch_embeddings + self.pos_embedding_2d

        # Reshape back to sequence format: [B, num_patches, embed_dim]
        patch_embeddings = patch_embeddings.reshape(B, num_patches, embed_dim)

        return patch_embeddings
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """Pre-norm Transformer encoder block (self-attention + MLP, residuals).

        Args:
            embed_dim: Dimensionality of input and attention feature vectors
            hidden_dim: Dimensionality of hidden layer in feed-forward network
                (usually 2-4x larger than embed_dim)
            num_heads: Number of heads to use in the Multi-Head Attention block
            dropout: Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, return_attention=False):
        """Apply self-attention and the feed-forward network with residuals.

        Args:
            x: Tensor of shape [T, B, embed_dim] (sequence-first).
            return_attention: If True, also return per-head attention weights.

        Returns:
            Updated tensor of the same shape, plus attention weights when requested.
        """
        inp_x = self.layer_norm_1(x)

        if return_attention:
            # average_attn_weights=False preserves per-head attention maps.
            attn_output, attn_weights = self.attn(inp_x, inp_x, inp_x, average_attn_weights=False)
            x = x + attn_output
            x = x + self.linear(self.layer_norm_2(x))
            return x, attn_weights
        else:
            # need_weights=False skips materializing the attention matrix we
            # would otherwise discard; outputs are identical and the fast
            # (fused) attention path can be used.
            attn_output, _ = self.attn(inp_x, inp_x, inp_x, need_weights=False)
            x = x + attn_output
            x = x + self.linear(self.layer_norm_2(x))
            return x
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def img_to_patch(x, patch_size, flatten_channels=True):
    """Split a channels-last image batch into a sequence of square patches.

    Args:
        x: Tensor representing the image of shape [B, H, W, C]
        patch_size: Number of pixels per dimension of the patches (integer)
        flatten_channels: If True, the patches will be returned in a flattened
            format as a feature vector instead of a image grid.

    Returns:
        Tensor of shape [B, H'*W', C*p_H*p_W] when flatten_channels is True,
        otherwise [B, H'*W', C, p_H, p_W].
    """
    imgs = x.permute(0, 3, 1, 2)  # channels-last -> [B, C, H, W]
    batch, channels, height, width = imgs.shape
    grid_h = height // patch_size
    grid_w = width // patch_size
    patches = imgs.reshape(batch, channels, grid_h, patch_size, grid_w, patch_size)
    patches = patches.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    patches = patches.flatten(1, 2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        patches = patches.flatten(2, 4)          # [B, H'*W', C*p_H*p_W]
    return patches
|
| 330 |
+
|
| 331 |
+
class SXRRegressionDynamicLoss:
    """Adaptive, class-weighted regression loss with an auxiliary uncertainty term.

    Samples are bucketed by raw flux into quiet / C / M / X classes using
    standard GOES thresholds. Each class's weight is a configured base weight
    scaled by a performance multiplier derived from a rolling window of recent
    per-class Huber errors, so under-performing classes get up-weighted.
    """

    def __init__(self, window_size=15000, base_weights=None):
        # GOES flux class thresholds (W/m^2): quiet < C < M < X
        self.c_threshold = 1e-6
        self.m_threshold = 1e-5
        self.x_threshold = 1e-4

        # Rolling per-class error histories (one mean error per batch appended)
        self.window_size = window_size
        self.quiet_errors = deque(maxlen=window_size)
        self.c_errors = deque(maxlen=window_size)
        self.m_errors = deque(maxlen=window_size)
        self.x_errors = deque(maxlen=window_size)

        #Calculate the base weights based on the number of samples in each class within training data
        if base_weights is None:
            self.base_weights = self._get_base_weights()
        else:
            self.base_weights = base_weights

    def _get_base_weights(self):
        # Fallback base weights when no class-frequency-derived weights are provided.
        return {
            'quiet': 1.5,    # Increase from current value
            'c_class': 1.0,  # Keep as baseline
            'm_class': 8.0,  # Maintain M-class focus
            'x_class': 20.0  # Maintain X-class focus
        }

    def calculate_loss(self, preds_norm, sxr_norm, sxr_un, raw_error):
        """Compute the weighted regression loss plus the uncertainty-head loss.

        Args:
            preds_norm: Normalized flux predictions.
            sxr_norm: Normalized flux targets.
            sxr_un: Unnormalized (raw) flux targets, used for class bucketing.
            raw_error: Model's predicted error (uncertainty head output).

        Returns:
            (total_loss, scaled_error_loss, per-sample weights). Note the
            returned error_loss is already scaled by error_weight.
        """
        base_loss = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
        # Train the uncertainty head to predict the absolute residual
        # (default Huber delta=1.0 here, unlike the regression term's 0.3).
        error = abs(sxr_norm - preds_norm)
        error_loss = F.huber_loss(raw_error, error, reduction='none')

        weights = self._get_adaptive_weights(sxr_un)
        self._update_tracking(sxr_un, sxr_norm, preds_norm)
        weighted_loss = base_loss * weights
        error_weight = .2
        error_loss = error_weight * error_loss.mean()
        loss = weighted_loss.mean() + error_loss
        return loss, error_loss, weights

    def _get_adaptive_weights(self, sxr_un):
        """Per-sample weights: base class weight x rolling-performance multiplier,
        normalized so the batch-mean weight is ~1.0."""
        device = sxr_un.device

        # Get continuous multipliers per class with custom params
        quiet_mult = self._get_performance_multiplier(
            self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.05, sxrclass='quiet'
        )
        c_mult = self._get_performance_multiplier(
            self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.08, sxrclass='c_class'
        )
        m_mult = self._get_performance_multiplier(
            self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.1, sxrclass='m_class'
        )
        x_mult = self._get_performance_multiplier(
            self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=0.12, sxrclass='x_class'
        )

        quiet_weight = self.base_weights['quiet'] * quiet_mult
        c_weight = self.base_weights['c_class'] * c_mult
        m_weight = self.base_weights['m_class'] * m_mult
        x_weight = self.base_weights['x_class'] * x_mult

        # Assign each sample its class weight by raw-flux threshold.
        weights = torch.ones_like(sxr_un, device=device)
        weights = torch.where(sxr_un < self.c_threshold, quiet_weight, weights)
        weights = torch.where((sxr_un >= self.c_threshold) & (sxr_un < self.m_threshold), c_weight, weights)
        weights = torch.where((sxr_un >= self.m_threshold) & (sxr_un < self.x_threshold), m_weight, weights)
        weights = torch.where(sxr_un >= self.x_threshold, x_weight, weights)

        # Normalize so mean weight ~1.0 (optional, helps stability)
        mean_weight = torch.mean(weights)
        weights = weights / (mean_weight)

        # Save for logging
        self.current_multipliers = {
            'quiet_mult': quiet_mult,
            'c_mult': c_mult,
            'm_mult': m_mult,
            'x_mult': x_mult,
            'quiet_weight': quiet_weight,
            'c_weight': c_weight,
            'm_weight': m_weight,
            'x_weight': x_weight
        }

        return weights

    def _get_performance_multiplier(self, error_history, max_multiplier=10.0, min_multiplier=0.5, sensitivity=3.0, sxrclass='quiet'):
        """Class-dependent performance multiplier.

        Returns exp(sensitivity * (recent/overall - 1)) clipped to
        [min_multiplier, max_multiplier]: >1 when recent error exceeds the
        long-run mean, <1 when recent performance is better. Returns 1.0
        until enough history has accumulated for the class.
        """
        class_params = {
            'quiet': {'min_samples': 2500, 'recent_window': 800},
            'c_class': {'min_samples': 2500, 'recent_window': 800},
            'm_class': {'min_samples': 1500, 'recent_window': 500},
            'x_class': {'min_samples': 1000, 'recent_window': 300}
        }

        if len(error_history) < class_params[sxrclass]['min_samples']:
            return 1.0

        recent_window = class_params[sxrclass]['recent_window']
        recent = np.mean(list(error_history)[-recent_window:])
        overall = np.mean(list(error_history))

        # NOTE(review): no guard against overall == 0 before dividing --
        # Huber errors are non-negative so a uniformly-zero window would
        # raise/produce inf; confirm this cannot occur in practice.
        ratio = recent / overall
        multiplier = np.exp(sensitivity * (ratio - 1))
        return np.clip(multiplier, min_multiplier, max_multiplier)

    def _update_tracking(self, sxr_un, sxr_norm, preds_norm):
        """Append this batch's mean per-class Huber error to the rolling histories."""
        sxr_un_np = sxr_un.detach().cpu().numpy()

        #Huber loss
        error = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
        error = error.detach().cpu().numpy()

        quiet_mask = sxr_un_np < self.c_threshold
        if quiet_mask.sum() > 0:
            self.quiet_errors.append(float(np.mean(error[quiet_mask])))

        c_mask = (sxr_un_np >= self.c_threshold) & (sxr_un_np < self.m_threshold)
        if c_mask.sum() > 0:
            self.c_errors.append(float(np.mean(error[c_mask])))

        m_mask = (sxr_un_np >= self.m_threshold) & (sxr_un_np < self.x_threshold)
        if m_mask.sum() > 0:
            self.m_errors.append(float(np.mean(error[m_mask])))

        x_mask = sxr_un_np >= self.x_threshold
        if x_mask.sum() > 0:
            self.x_errors.append(float(np.mean(error[x_mask])))


    def get_current_multipliers(self):
        """Get current performance multipliers for logging"""
        # NOTE(review): the sensitivities used here (0.2/0.3/0.8/1.0) differ
        # from those used in _get_adaptive_weights (0.05/0.08/0.1/0.12), so
        # logged multipliers do not match the ones applied to the loss --
        # confirm whether this divergence is intentional.
        return {
            'quiet_mult': self._get_performance_multiplier(
                self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.2, sxrclass='quiet'
            ),
            'c_mult': self._get_performance_multiplier(
                self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.3, sxrclass='c_class'
            ),
            'm_mult': self._get_performance_multiplier(
                self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.8, sxrclass='m_class'
            ),
            'x_mult': self._get_performance_multiplier(
                self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=1.0, sxrclass='x_class'
            ),
            'quiet_count': len(self.quiet_errors),
            'c_count': len(self.c_errors),
            'm_count': len(self.m_errors),
            'x_count': len(self.x_errors),
            'quiet_error': np.mean(self.quiet_errors) if self.quiet_errors else 0.0,
            'c_error': np.mean(self.c_errors) if self.c_errors else 0.0,
            'm_error': np.mean(self.m_errors) if self.m_errors else 0.0,
            'x_error': np.mean(self.x_errors) if self.x_errors else 0.0,
            'quiet_weight': getattr(self, 'current_multipliers', {}).get('quiet_weight', 0.0),
            'c_weight': getattr(self, 'current_multipliers', {}).get('c_weight', 0.0),
            'm_weight': getattr(self, 'current_multipliers', {}).get('m_weight', 0.0),
            'x_weight': getattr(self, 'current_multipliers', {}).get('x_weight', 0.0)
        }
|
forecasting/training/callback.py
CHANGED
|
@@ -132,7 +132,12 @@ class AttentionMapCallback(Callback):
|
|
| 132 |
except:
|
| 133 |
# For ViT patch model, we need to call the model's forward method directly
|
| 134 |
if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'forward'):
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
else:
|
| 137 |
outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
|
| 138 |
|
|
|
|
| 132 |
except:
|
| 133 |
# For ViT patch model, we need to call the model's forward method directly
|
| 134 |
if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'forward'):
|
| 135 |
+
try:
|
| 136 |
+
print("Using model's forward method")
|
| 137 |
+
outputs, attention_weights, _ = pl_module.model(imgs, pl_module.sxr_norm, return_attention=True)
|
| 138 |
+
except:
|
| 139 |
+
print("Using model's forward method failed")
|
| 140 |
+
outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
|
| 141 |
else:
|
| 142 |
outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
|
| 143 |
|
forecasting/training/config5.yaml
CHANGED
|
@@ -12,7 +12,7 @@ batch_size: 64
|
|
| 12 |
epochs: 250
|
| 13 |
oversample: false
|
| 14 |
balance_strategy: "upsample_minority"
|
| 15 |
-
calculate_base_weights:
|
| 16 |
|
| 17 |
megsai:
|
| 18 |
architecture: "cnn"
|
|
@@ -71,5 +71,5 @@ wandb:
|
|
| 71 |
- aia
|
| 72 |
- sxr
|
| 73 |
- regression
|
| 74 |
-
wb_name: vit-
|
| 75 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
|
|
|
| 12 |
epochs: 250
|
| 13 |
oversample: false
|
| 14 |
balance_strategy: "upsample_minority"
|
| 15 |
+
calculate_base_weights: true # Whether to calculate class-based weights for loss function
|
| 16 |
|
| 17 |
megsai:
|
| 18 |
architecture: "cnn"
|
|
|
|
| 71 |
- aia
|
| 72 |
- sxr
|
| 73 |
- regression
|
| 74 |
+
wb_name: vit-mse-base-weights
|
| 75 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
forecasting/training/config6.yaml
CHANGED
|
@@ -71,5 +71,5 @@ wandb:
|
|
| 71 |
- aia
|
| 72 |
- sxr
|
| 73 |
- regression
|
| 74 |
-
wb_name: vit-
|
| 75 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
|
|
|
| 71 |
- aia
|
| 72 |
- sxr
|
| 73 |
- regression
|
| 74 |
+
wb_name: vit-mse-claude
|
| 75 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
forecasting/training/train.py
CHANGED
|
@@ -24,6 +24,7 @@ from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataModule
|
|
| 24 |
from forecasting.models.vision_transformer_custom import ViT
|
| 25 |
from forecasting.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
|
| 26 |
from forecasting.models.vit_patch_model import ViT as ViTPatch
|
|
|
|
| 27 |
from forecasting.models import FusionViTHybrid
|
| 28 |
from callback import ImagePredictionLogger_SXR, AttentionMapCallback
|
| 29 |
from pytorch_lightning.callbacks import Callback
|
|
@@ -344,6 +345,10 @@ elif config_data['selected_model'] == 'ViTPatch':
|
|
| 344 |
base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
|
| 345 |
model = ViTPatch(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm, base_weights=base_weights)
|
| 346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
elif config_data['selected_model'] == 'FusionViTHybrid':
|
| 348 |
# Expect a 'fusion' section in YAML
|
| 349 |
fusion_cfg = config_data.get('fusion', {})
|
|
|
|
| 24 |
from forecasting.models.vision_transformer_custom import ViT
|
| 25 |
from forecasting.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
|
| 26 |
from forecasting.models.vit_patch_model import ViT as ViTPatch
|
| 27 |
+
from forecasting.models.vit_patch_model_uncertainty import ViTUncertainty
|
| 28 |
from forecasting.models import FusionViTHybrid
|
| 29 |
from callback import ImagePredictionLogger_SXR, AttentionMapCallback
|
| 30 |
from pytorch_lightning.callbacks import Callback
|
|
|
|
| 345 |
base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
|
| 346 |
model = ViTPatch(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm, base_weights=base_weights)
|
| 347 |
|
| 348 |
+
elif config_data['selected_model'] == 'ViTUncertainty':
|
| 349 |
+
base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
|
| 350 |
+
model = ViTUncertainty(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm, base_weights=base_weights)
|
| 351 |
+
|
| 352 |
elif config_data['selected_model'] == 'FusionViTHybrid':
|
| 353 |
# Expect a 'fusion' section in YAML
|
| 354 |
fusion_cfg = config_data.get('fusion', {})
|
forecasting/training/vituncertainty.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#Base directories - change these to switch datasets
|
| 3 |
+
base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 4 |
+
base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 5 |
+
wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
|
| 6 |
+
|
| 7 |
+
# GPU configuration
|
| 8 |
+
gpu_id: 0 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
|
| 9 |
+
# Model configuration
|
| 10 |
+
selected_model: "ViTUncertainty" # Options: "hybrid", "vit", "fusion", "vitpatch"
|
| 11 |
+
batch_size: 64
|
| 12 |
+
epochs: 250
|
| 13 |
+
oversample: false
|
| 14 |
+
balance_strategy: "upsample_minority"
|
| 15 |
+
calculate_base_weights: false # Whether to calculate class-based weights for loss function
|
| 16 |
+
|
| 17 |
+
megsai:
|
| 18 |
+
architecture: "cnn"
|
| 19 |
+
seed: 42
|
| 20 |
+
lr: 0.0001
|
| 21 |
+
cnn_model: "updated"
|
| 22 |
+
cnn_dp: 0.2
|
| 23 |
+
weight_decay: 1e-5
|
| 24 |
+
cosine_restart_T0: 50
|
| 25 |
+
cosine_restart_Tmult: 2
|
| 26 |
+
cosine_eta_min: 1e-7
|
| 27 |
+
|
| 28 |
+
vit_custom:
|
| 29 |
+
embed_dim: 512
|
| 30 |
+
num_channels: 6
|
| 31 |
+
num_classes: 2
|
| 32 |
+
patch_size: 16
|
| 33 |
+
num_patches: 1024
|
| 34 |
+
hidden_dim: 512
|
| 35 |
+
num_heads: 8
|
| 36 |
+
num_layers: 6
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
lr: 0.0001
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
fusion:
|
| 42 |
+
scalar_branch: "hybrid" # or "linear"
|
| 43 |
+
lr: 0.0001
|
| 44 |
+
lambda_vit_to_target: 0.3
|
| 45 |
+
lambda_scalar_to_target: 0.1
|
| 46 |
+
learnable_gate: true
|
| 47 |
+
gate_init_bias: 5.0
|
| 48 |
+
scalar_kwargs:
|
| 49 |
+
d_input: 6
|
| 50 |
+
d_output: 1
|
| 51 |
+
cnn_model: "updated"
|
| 52 |
+
cnn_dp: 0.75
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Data paths (automatically constructed from base directories)
|
| 56 |
+
data:
|
| 57 |
+
aia_dir:
|
| 58 |
+
"${base_data_dir}/AIA-SPLIT"
|
| 59 |
+
sxr_dir:
|
| 60 |
+
"${base_data_dir}/SXR-SPLIT"
|
| 61 |
+
sxr_norm_path:
|
| 62 |
+
"${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
|
| 63 |
+
checkpoints_dir:
|
| 64 |
+
"${base_checkpoint_dir}/new-checkpoint/"
|
| 65 |
+
|
| 66 |
+
wandb:
|
| 67 |
+
entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
|
| 68 |
+
project: Model Testing
|
| 69 |
+
job_type: training
|
| 70 |
+
tags:
|
| 71 |
+
- aia
|
| 72 |
+
- sxr
|
| 73 |
+
- regression
|
| 74 |
+
wb_name: vit-uncertainty-claude
|
| 75 |
+
notes: Regression from AIA images (6 channels) to GOES SXR flux
|