Commit ·
7bbc799
1
Parent(s): 7e470a6
Updated splitting function so now it can do SXR and AIA individually... it can also repartition an already partioned folder... updated the megs models to include warm restarts, weight decay, and adaptive weights
Browse files
data/split_data.py
CHANGED
|
@@ -1,72 +1,183 @@
|
|
| 1 |
import os
|
| 2 |
import pandas as pd
|
| 3 |
import shutil
|
|
|
|
| 4 |
from datetime import datetime
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
for split in ["train", "val", "test"]:
|
| 18 |
-
os.makedirs(os.path.join(
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
try:
|
| 36 |
-
aia_time = pd.to_datetime(file.split(".")[0])
|
| 37 |
-
except ValueError:
|
| 38 |
-
print(f"Skipping file {file}: Invalid timestamp format")
|
| 39 |
-
continue
|
| 40 |
-
|
| 41 |
-
# Determine if the file is during a flare event
|
| 42 |
-
is_flaring = any(start <= aia_time <= end for start, end in flaring_eve_list)
|
| 43 |
-
base_dir = flares_event_dir if is_flaring else non_flares_event_dir
|
| 44 |
-
|
| 45 |
-
month = aia_time.month
|
| 46 |
-
|
| 47 |
-
if month in [2, 3, 4, 5, 6, 7, 9, 10, 11, 12]:
|
| 48 |
-
split_dir = "train"
|
| 49 |
-
elif month == 1:
|
| 50 |
-
split_dir = "val"
|
| 51 |
-
elif month == 8:
|
| 52 |
-
split_dir = "test"
|
| 53 |
-
else:
|
| 54 |
-
print(f"Skipping file {file}: Unexpected month {month}")
|
| 55 |
-
continue
|
| 56 |
-
|
| 57 |
-
# Copy file to appropriate directory
|
| 58 |
-
src_aia = os.path.join(aia_data_dir, file)
|
| 59 |
-
src_sxr = os.path.join(sxr_data_dir, file)
|
| 60 |
-
dst_aia = os.path.join(base_dir, "AIA", split_dir, file)
|
| 61 |
-
dst_sxr = os.path.join(base_dir, "SXR", split_dir, file)
|
| 62 |
-
|
| 63 |
-
if not os.path.exists(dst_aia):
|
| 64 |
-
shutil.copy(src_aia, dst_aia)
|
| 65 |
-
print(f"Copied {file} to {dst_aia} and {dst_sxr}")
|
| 66 |
else:
|
| 67 |
-
print(
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
else:
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import pandas as pd
|
| 3 |
import shutil
|
| 4 |
+
import argparse
|
| 5 |
from datetime import datetime
|
| 6 |
+
from pathlib import Path
|
| 7 |
|
| 8 |
+
def split_data(input_folder, output_dir, data_type, flare_events_csv=None, repartition=False):
|
| 9 |
+
"""
|
| 10 |
+
Split data from input folder into train/val/test based on month.
|
| 11 |
+
Optionally use flare events for additional classification.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
input_folder (str): Path to the input folder containing data files
|
| 15 |
+
output_dir (str): Path to the output directory where split data will be saved
|
| 16 |
+
data_type (str): Type of data ('aia' or 'sxr')
|
| 17 |
+
flare_events_csv (str, optional): Path to the flare events CSV file
|
| 18 |
+
repartition (bool): If True, treat input_folder as already partitioned (has train/val/test subdirs)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
# Validate input folder
|
| 22 |
+
if not os.path.exists(input_folder):
|
| 23 |
+
raise ValueError(f"Input folder does not exist: {input_folder}")
|
| 24 |
+
|
| 25 |
+
# Validate data type
|
| 26 |
+
if data_type.lower() not in ['aia', 'sxr']:
|
| 27 |
+
raise ValueError("data_type must be 'aia' or 'sxr'")
|
| 28 |
+
|
| 29 |
+
# Create output directory structure
|
| 30 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 31 |
for split in ["train", "val", "test"]:
|
| 32 |
+
os.makedirs(os.path.join(output_dir, split), exist_ok=True)
|
| 33 |
+
|
| 34 |
+
print(f"Processing {data_type.upper()} data from: {input_folder}")
|
| 35 |
+
print(f"Output directory: {output_dir}")
|
| 36 |
+
|
| 37 |
+
# Load flare events if provided
|
| 38 |
+
flaring_eve_list = []
|
| 39 |
+
if flare_events_csv and os.path.exists(flare_events_csv):
|
| 40 |
+
print(f"Loading flare events from: {flare_events_csv}")
|
| 41 |
+
flare_event = pd.read_csv(flare_events_csv)
|
| 42 |
+
|
| 43 |
+
# Create list of flare event time ranges
|
| 44 |
+
for i, row in flare_event.iterrows():
|
| 45 |
+
start_time = pd.to_datetime(row['event_starttime'])
|
| 46 |
+
end_time = pd.to_datetime(row['event_endtime'])
|
| 47 |
+
flaring_eve_list.append((start_time, end_time))
|
| 48 |
+
print(f"Loaded {len(flaring_eve_list)} flare events")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
else:
|
| 50 |
+
print("No flare events CSV provided or file not found. Skipping flare classification.")
|
| 51 |
+
|
| 52 |
+
# Get list of files in input folder
|
| 53 |
+
if repartition:
|
| 54 |
+
# For repartitioning, collect files from train/val/test subdirectories
|
| 55 |
+
data_list = []
|
| 56 |
+
for split in ["train", "val", "test"]:
|
| 57 |
+
split_dir = os.path.join(input_folder, split)
|
| 58 |
+
if os.path.exists(split_dir):
|
| 59 |
+
split_files = os.listdir(split_dir)
|
| 60 |
+
# Add split information to each file for tracking
|
| 61 |
+
for file in split_files:
|
| 62 |
+
data_list.append((file, split))
|
| 63 |
+
print(f"Found {len(split_files)} files in {split}/ directory")
|
| 64 |
+
print(f"Total files to repartition: {len(data_list)}")
|
| 65 |
else:
|
| 66 |
+
# For normal splitting, get files directly from input folder
|
| 67 |
+
data_list = os.listdir(input_folder)
|
| 68 |
+
print(f"Found {len(data_list)} files to process")
|
| 69 |
+
|
| 70 |
+
moved_count = 0
|
| 71 |
+
skipped_count = 0
|
| 72 |
+
|
| 73 |
+
for file_info in data_list:
|
| 74 |
+
if repartition:
|
| 75 |
+
file, original_split = file_info
|
| 76 |
+
else:
|
| 77 |
+
file = file_info
|
| 78 |
+
original_split = None
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
# Extract timestamp from filename (assuming format like "2012-01-01T00:00:00.npy")
|
| 82 |
+
file_time = pd.to_datetime(file.split(".")[0])
|
| 83 |
+
except ValueError:
|
| 84 |
+
print(f"Skipping file {file}: Invalid timestamp format")
|
| 85 |
+
skipped_count += 1
|
| 86 |
+
continue
|
| 87 |
+
|
| 88 |
+
# Determine if the file is during a flare event (if flare events are available)
|
| 89 |
+
is_flaring = False
|
| 90 |
+
if flaring_eve_list:
|
| 91 |
+
is_flaring = any(start <= file_time <= end for start, end in flaring_eve_list)
|
| 92 |
+
|
| 93 |
+
# Determine split based on month
|
| 94 |
+
month = file_time.month
|
| 95 |
+
|
| 96 |
+
if month in [2, 3, 4, 5, 6, 7, 9, 10, 11, 12]:
|
| 97 |
+
new_split_dir = "train"
|
| 98 |
+
elif month == 1:
|
| 99 |
+
new_split_dir = "val"
|
| 100 |
+
elif month == 8:
|
| 101 |
+
new_split_dir = "test"
|
| 102 |
+
else:
|
| 103 |
+
print(f"Skipping file {file}: Unexpected month {month}")
|
| 104 |
+
skipped_count += 1
|
| 105 |
+
continue
|
| 106 |
+
|
| 107 |
+
# Determine source and destination paths
|
| 108 |
+
if repartition:
|
| 109 |
+
src_path = os.path.join(input_folder, original_split, file)
|
| 110 |
+
else:
|
| 111 |
+
src_path = os.path.join(input_folder, file)
|
| 112 |
+
|
| 113 |
+
dst_path = os.path.join(output_dir, new_split_dir, file)
|
| 114 |
+
|
| 115 |
+
# Skip if file is already in the correct split and we're repartitioning
|
| 116 |
+
if repartition and original_split == new_split_dir and os.path.exists(dst_path):
|
| 117 |
+
print(f"File {file} already in correct split ({new_split_dir}), skipping.")
|
| 118 |
+
skipped_count += 1
|
| 119 |
+
continue
|
| 120 |
+
|
| 121 |
+
if not os.path.exists(dst_path):
|
| 122 |
+
try:
|
| 123 |
+
shutil.move(src_path, dst_path)
|
| 124 |
+
if repartition:
|
| 125 |
+
if flaring_eve_list:
|
| 126 |
+
print(f"Moved {file} from {original_split}/ to {new_split_dir}/ (flaring: {is_flaring})")
|
| 127 |
+
else:
|
| 128 |
+
print(f"Moved {file} from {original_split}/ to {new_split_dir}/")
|
| 129 |
+
else:
|
| 130 |
+
if flaring_eve_list:
|
| 131 |
+
print(f"Moved {file} to {new_split_dir}/ (flaring: {is_flaring})")
|
| 132 |
+
else:
|
| 133 |
+
print(f"Moved {file} to {new_split_dir}/")
|
| 134 |
+
moved_count += 1
|
| 135 |
+
except Exception as e:
|
| 136 |
+
print(f"Error moving {file}: {e}")
|
| 137 |
+
skipped_count += 1
|
| 138 |
+
else:
|
| 139 |
+
print(f"File {dst_path} already exists, skipping move.")
|
| 140 |
+
skipped_count += 1
|
| 141 |
+
|
| 142 |
+
print(f"\nProcessing complete!")
|
| 143 |
+
print(f"Files moved: {moved_count}")
|
| 144 |
+
print(f"Files skipped: {skipped_count}")
|
| 145 |
+
print(f"Total files processed: {moved_count + skipped_count}")
|
| 146 |
+
|
| 147 |
+
def main():
|
| 148 |
+
parser = argparse.ArgumentParser(description='Split AIA or SXR data into train/val/test sets based on month')
|
| 149 |
+
parser.add_argument('--input_folder', type=str, required=True,
|
| 150 |
+
help='Path to the input folder containing data files (or partitioned folder for repartition)')
|
| 151 |
+
parser.add_argument('--output_dir', type=str, required=True,
|
| 152 |
+
help='Path to the output directory where split data will be saved')
|
| 153 |
+
parser.add_argument('--data_type', type=str, choices=['aia', 'sxr'], required=True,
|
| 154 |
+
help='Type of data: "aia" or "sxr"')
|
| 155 |
+
parser.add_argument('--flare_events_csv', type=str, default=None,
|
| 156 |
+
help='Path to the flare events CSV file (optional)')
|
| 157 |
+
parser.add_argument('--repartition', action='store_true',
|
| 158 |
+
help='Repartition an already partitioned folder (input_folder should have train/val/test subdirs)')
|
| 159 |
+
|
| 160 |
+
args = parser.parse_args()
|
| 161 |
+
|
| 162 |
+
# Convert to absolute paths
|
| 163 |
+
input_folder = os.path.abspath(args.input_folder)
|
| 164 |
+
output_dir = os.path.abspath(args.output_dir)
|
| 165 |
+
flare_events_csv = os.path.abspath(args.flare_events_csv) if args.flare_events_csv else None
|
| 166 |
+
|
| 167 |
+
# Validate repartition mode
|
| 168 |
+
if args.repartition:
|
| 169 |
+
# Check if input folder has train/val/test subdirectories
|
| 170 |
+
expected_dirs = ['train', 'val', 'test']
|
| 171 |
+
missing_dirs = []
|
| 172 |
+
for dir_name in expected_dirs:
|
| 173 |
+
if not os.path.exists(os.path.join(input_folder, dir_name)):
|
| 174 |
+
missing_dirs.append(dir_name)
|
| 175 |
+
|
| 176 |
+
if missing_dirs:
|
| 177 |
+
print(f"Warning: Input folder is missing expected subdirectories: {missing_dirs}")
|
| 178 |
+
print("Continuing with available directories...")
|
| 179 |
+
|
| 180 |
+
split_data(input_folder, output_dir, args.data_type, flare_events_csv, args.repartition)
|
| 181 |
|
| 182 |
+
if __name__ == "__main__":
|
| 183 |
+
main()
|
forecasting/data_loaders/sxr_normalization.py
CHANGED
|
@@ -51,7 +51,7 @@ def compute_sxr_norm(sxr_dir):
|
|
| 51 |
|
| 52 |
if __name__ == "__main__":
|
| 53 |
# Update this path to your real data SXR directory
|
| 54 |
-
sxr_dir = "/mnt/data/ML-
|
| 55 |
sxr_norm = compute_sxr_norm(sxr_dir)
|
| 56 |
-
np.save("/mnt/data/ML-
|
| 57 |
#print(f"Saved SXR normalization to /mnt/data/ML-Ready-Data-No-Intensity-Cut/normalized_sxr")
|
|
|
|
| 51 |
|
| 52 |
if __name__ == "__main__":
|
| 53 |
# Update this path to your real data SXR directory
|
| 54 |
+
sxr_dir = "/mnt/data/ML-READY/SXR/train" # Replace with actual path
|
| 55 |
sxr_norm = compute_sxr_norm(sxr_dir)
|
| 56 |
+
np.save("/mnt/data/ML-READY/SXR/normalized_sxr.npy", sxr_norm)
|
| 57 |
#print(f"Saved SXR normalization to /mnt/data/ML-Ready-Data-No-Intensity-Cut/normalized_sxr")
|
forecasting/inference/evaluation.py
CHANGED
|
@@ -890,7 +890,7 @@ class SolarFlareEvaluator:
|
|
| 890 |
# Sort frame paths by timestamp to ensure correct order
|
| 891 |
frame_paths.sort(key=lambda x: os.path.basename(x))
|
| 892 |
|
| 893 |
-
movie_path = os.path.join(self.output_dir, "
|
| 894 |
with imageio.get_writer(movie_path, fps=30, codec='libx264', format='ffmpeg') as writer:
|
| 895 |
for frame_path in frame_paths:
|
| 896 |
if os.path.exists(frame_path):
|
|
|
|
| 890 |
# Sort frame paths by timestamp to ensure correct order
|
| 891 |
frame_paths.sort(key=lambda x: os.path.basename(x))
|
| 892 |
|
| 893 |
+
movie_path = os.path.join(self.output_dir, f"AIA_{timestamps[0].split('T')[0]}.mp4")
|
| 894 |
with imageio.get_writer(movie_path, fps=30, codec='libx264', format='ffmpeg') as writer:
|
| 895 |
for frame_path in frame_paths:
|
| 896 |
if os.path.exists(frame_path):
|
forecasting/models/base_model.py
CHANGED
|
@@ -1,30 +1,72 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
|
|
|
| 3 |
from pytorch_lightning import LightningModule
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
class BaseModel(LightningModule):
|
| 6 |
-
def __init__(self, model, loss_func, lr
|
|
|
|
| 7 |
super().__init__()
|
| 8 |
self.model = model
|
| 9 |
self.loss_func = loss_func
|
| 10 |
self.lr = lr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def forward(self, x):
|
| 13 |
return self.model(x)
|
| 14 |
|
| 15 |
def configure_optimizers(self):
|
| 16 |
-
optimizer = torch.optim.
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
}
|
| 22 |
-
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
|
| 23 |
|
| 24 |
def training_step(self, batch, batch_idx):
|
| 25 |
x, target = batch
|
| 26 |
pred = self(x)
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
self.log('train_loss', loss)
|
| 29 |
current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
|
| 30 |
self.log('learning_rate', current_lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
|
|
@@ -33,13 +75,35 @@ class BaseModel(LightningModule):
|
|
| 33 |
def validation_step(self, batch, batch_idx):
|
| 34 |
x, target = batch
|
| 35 |
pred = self(x)
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
self.log('val_loss', loss)
|
| 38 |
return loss
|
| 39 |
|
| 40 |
def test_step(self, batch, batch_idx):
|
| 41 |
x, target = batch
|
| 42 |
pred = self(x)
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
self.log('test_loss', loss)
|
| 45 |
return loss
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
from pytorch_lightning import LightningModule
|
| 5 |
+
from collections import deque
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
# Import adaptive loss and normalization functions
|
| 9 |
+
from .vit_patch_model import SXRRegressionDynamicLoss, normalize_sxr, unnormalize_sxr
|
| 10 |
|
| 11 |
class BaseModel(LightningModule):
|
| 12 |
+
def __init__(self, model, loss_func, lr, sxr_norm=None, weight_decay=1e-5,
|
| 13 |
+
cosine_restart_T0=50, cosine_restart_Tmult=2, cosine_eta_min=1e-7):
|
| 14 |
super().__init__()
|
| 15 |
self.model = model
|
| 16 |
self.loss_func = loss_func
|
| 17 |
self.lr = lr
|
| 18 |
+
self.sxr_norm = sxr_norm
|
| 19 |
+
self.weight_decay = weight_decay
|
| 20 |
+
self.cosine_restart_T0 = cosine_restart_T0
|
| 21 |
+
self.cosine_restart_Tmult = cosine_restart_Tmult
|
| 22 |
+
self.cosine_eta_min = cosine_eta_min
|
| 23 |
+
|
| 24 |
+
# Initialize adaptive loss if sxr_norm is provided
|
| 25 |
+
if sxr_norm is not None:
|
| 26 |
+
self.adaptive_loss = SXRRegressionDynamicLoss(window_size=1500)
|
| 27 |
+
else:
|
| 28 |
+
self.adaptive_loss = None
|
| 29 |
|
| 30 |
def forward(self, x):
|
| 31 |
return self.model(x)
|
| 32 |
|
| 33 |
def configure_optimizers(self):
|
| 34 |
+
optimizer = torch.optim.AdamW(
|
| 35 |
+
self.parameters(),
|
| 36 |
+
lr=self.lr,
|
| 37 |
+
weight_decay=self.weight_decay,
|
| 38 |
+
)
|
| 39 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
| 40 |
+
optimizer,
|
| 41 |
+
T_0=self.cosine_restart_T0,
|
| 42 |
+
T_mult=self.cosine_restart_Tmult,
|
| 43 |
+
eta_min=self.cosine_eta_min,
|
| 44 |
+
)
|
| 45 |
+
return {
|
| 46 |
+
'optimizer': optimizer,
|
| 47 |
+
'lr_scheduler': {
|
| 48 |
+
'scheduler': scheduler,
|
| 49 |
+
'interval': 'epoch',
|
| 50 |
+
'frequency': 1,
|
| 51 |
+
'name': 'learning_rate'
|
| 52 |
+
}
|
| 53 |
}
|
|
|
|
| 54 |
|
| 55 |
def training_step(self, batch, batch_idx):
|
| 56 |
x, target = batch
|
| 57 |
pred = self(x)
|
| 58 |
+
|
| 59 |
+
# Use adaptive loss if available and sxr_norm is provided
|
| 60 |
+
if self.adaptive_loss is not None and self.sxr_norm is not None:
|
| 61 |
+
raw_preds_squeezed = torch.squeeze(pred)
|
| 62 |
+
target_un = unnormalize_sxr(target, self.sxr_norm)
|
| 63 |
+
norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
|
| 64 |
+
loss, weights = self.adaptive_loss.calculate_loss(
|
| 65 |
+
norm_preds_squeezed, target, target_un, raw_preds_squeezed
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
loss = self.loss_func(torch.squeeze(pred), target)
|
| 69 |
+
|
| 70 |
self.log('train_loss', loss)
|
| 71 |
current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
|
| 72 |
self.log('learning_rate', current_lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
|
|
|
|
| 75 |
def validation_step(self, batch, batch_idx):
|
| 76 |
x, target = batch
|
| 77 |
pred = self(x)
|
| 78 |
+
|
| 79 |
+
# Use adaptive loss if available and sxr_norm is provided
|
| 80 |
+
if self.adaptive_loss is not None and self.sxr_norm is not None:
|
| 81 |
+
raw_preds_squeezed = torch.squeeze(pred)
|
| 82 |
+
target_un = unnormalize_sxr(target, self.sxr_norm)
|
| 83 |
+
norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
|
| 84 |
+
loss, weights = self.adaptive_loss.calculate_loss(
|
| 85 |
+
norm_preds_squeezed, target, target_un, raw_preds_squeezed
|
| 86 |
+
)
|
| 87 |
+
else:
|
| 88 |
+
loss = self.loss_func(torch.squeeze(pred), target)
|
| 89 |
+
|
| 90 |
self.log('val_loss', loss)
|
| 91 |
return loss
|
| 92 |
|
| 93 |
def test_step(self, batch, batch_idx):
|
| 94 |
x, target = batch
|
| 95 |
pred = self(x)
|
| 96 |
+
|
| 97 |
+
# Use adaptive loss if available and sxr_norm is provided
|
| 98 |
+
if self.adaptive_loss is not None and self.sxr_norm is not None:
|
| 99 |
+
raw_preds_squeezed = torch.squeeze(pred)
|
| 100 |
+
target_un = unnormalize_sxr(target, self.sxr_norm)
|
| 101 |
+
norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
|
| 102 |
+
loss, weights = self.adaptive_loss.calculate_loss(
|
| 103 |
+
norm_preds_squeezed, target, target_un, raw_preds_squeezed
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
loss = self.loss_func(torch.squeeze(pred), target)
|
| 107 |
+
|
| 108 |
self.log('test_loss', loss)
|
| 109 |
return loss
|
forecasting/models/linear_and_hybrid.py
CHANGED
|
@@ -10,11 +10,14 @@ from forecasting.models.base_model import BaseModel
|
|
| 10 |
from torchvision.models import resnet18
|
| 11 |
|
| 12 |
class LinearIrradianceModel(BaseModel):
|
| 13 |
-
def __init__(self, d_input, d_output, loss_func=HuberLoss(), lr=1e-4
|
|
|
|
| 14 |
self.n_channels = d_input
|
| 15 |
self.outSize = d_output
|
| 16 |
model = nn.Linear(2 * self.n_channels, self.outSize)
|
| 17 |
-
super().__init__(model=model, loss_func=loss_func, lr=lr
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def forward(self, x, **kwargs):
|
| 20 |
|
|
@@ -48,14 +51,21 @@ class LinearIrradianceModel(BaseModel):
|
|
| 48 |
return self.model(input_features)
|
| 49 |
|
| 50 |
class HybridIrradianceModel(BaseModel):
|
| 51 |
-
def __init__(self, d_input, d_output, cnn_model='resnet', ln_model=True, ln_params=None, lr=1e-4, cnn_dp=0.75, loss_func=HuberLoss()
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
| 53 |
self.n_channels = d_input
|
| 54 |
self.outSize = d_output
|
| 55 |
self.ln_params = ln_params
|
| 56 |
self.ln_model = None
|
| 57 |
if ln_model:
|
| 58 |
-
self.ln_model = LinearIrradianceModel(d_input, d_output, loss_func=loss_func, lr=lr
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
if self.ln_params is not None and self.ln_model is not None:
|
| 60 |
self.ln_model.model.weight = nn.Parameter(self.ln_params['weight'])
|
| 61 |
self.ln_model.model.bias = nn.Parameter(self.ln_params['bias'])
|
|
|
|
| 10 |
from torchvision.models import resnet18
|
| 11 |
|
| 12 |
class LinearIrradianceModel(BaseModel):
|
| 13 |
+
def __init__(self, d_input, d_output, loss_func=HuberLoss(), lr=1e-4, sxr_norm=None,
|
| 14 |
+
weight_decay=1e-5, cosine_restart_T0=50, cosine_restart_Tmult=2, cosine_eta_min=1e-7):
|
| 15 |
self.n_channels = d_input
|
| 16 |
self.outSize = d_output
|
| 17 |
model = nn.Linear(2 * self.n_channels, self.outSize)
|
| 18 |
+
super().__init__(model=model, loss_func=loss_func, lr=lr, sxr_norm=sxr_norm,
|
| 19 |
+
weight_decay=weight_decay, cosine_restart_T0=cosine_restart_T0,
|
| 20 |
+
cosine_restart_Tmult=cosine_restart_Tmult, cosine_eta_min=cosine_eta_min)
|
| 21 |
|
| 22 |
def forward(self, x, **kwargs):
|
| 23 |
|
|
|
|
| 51 |
return self.model(input_features)
|
| 52 |
|
| 53 |
class HybridIrradianceModel(BaseModel):
|
| 54 |
+
def __init__(self, d_input, d_output, cnn_model='resnet', ln_model=True, ln_params=None, lr=1e-4, cnn_dp=0.75, loss_func=HuberLoss(),
|
| 55 |
+
sxr_norm=None, weight_decay=1e-5, cosine_restart_T0=50, cosine_restart_Tmult=2, cosine_eta_min=1e-7):
|
| 56 |
+
super().__init__(model=None, loss_func=loss_func, lr=lr, sxr_norm=sxr_norm,
|
| 57 |
+
weight_decay=weight_decay, cosine_restart_T0=cosine_restart_T0,
|
| 58 |
+
cosine_restart_Tmult=cosine_restart_Tmult, cosine_eta_min=cosine_eta_min)
|
| 59 |
self.n_channels = d_input
|
| 60 |
self.outSize = d_output
|
| 61 |
self.ln_params = ln_params
|
| 62 |
self.ln_model = None
|
| 63 |
if ln_model:
|
| 64 |
+
self.ln_model = LinearIrradianceModel(d_input, d_output, loss_func=loss_func, lr=lr,
|
| 65 |
+
sxr_norm=sxr_norm, weight_decay=weight_decay,
|
| 66 |
+
cosine_restart_T0=cosine_restart_T0,
|
| 67 |
+
cosine_restart_Tmult=cosine_restart_Tmult,
|
| 68 |
+
cosine_eta_min=cosine_eta_min)
|
| 69 |
if self.ln_params is not None and self.ln_model is not None:
|
| 70 |
self.ln_model.model.weight = nn.Parameter(self.ln_params['weight'])
|
| 71 |
self.ln_model.model.bias = nn.Parameter(self.ln_params['bias'])
|
forecasting/training/config.yaml
CHANGED
|
@@ -4,9 +4,9 @@ base_data_dir: "/mnt/data/ML-READY" # Change this line for different datasets
|
|
| 4 |
base_checkpoint_dir: "/mnt/data/ML-READY" # Change this line for different datasets
|
| 5 |
wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
|
| 6 |
# Model configuration
|
| 7 |
-
selected_model: "
|
| 8 |
-
batch_size:
|
| 9 |
-
epochs:
|
| 10 |
oversample: false
|
| 11 |
balance_strategy: "upsample_minority"
|
| 12 |
|
|
@@ -14,8 +14,12 @@ megsai:
|
|
| 14 |
architecture: "cnn"
|
| 15 |
seed: 42
|
| 16 |
lr: 0.0001
|
| 17 |
-
cnn_model: "
|
| 18 |
cnn_dp: 0.2
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
vit_custom:
|
| 21 |
embed_dim: 512
|
|
@@ -57,11 +61,11 @@ data:
|
|
| 57 |
|
| 58 |
wandb:
|
| 59 |
entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
|
| 60 |
-
project: ViT
|
| 61 |
job_type: training
|
| 62 |
tags:
|
| 63 |
- aia
|
| 64 |
- sxr
|
| 65 |
- regression
|
| 66 |
-
wb_name:
|
| 67 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
|
|
|
| 4 |
base_checkpoint_dir: "/mnt/data/ML-READY" # Change this line for different datasets
|
| 5 |
wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
|
| 6 |
# Model configuration
|
| 7 |
+
selected_model: "cnn" # Options: "cnn", "vit",
|
| 8 |
+
batch_size: 256
|
| 9 |
+
epochs: 250
|
| 10 |
oversample: false
|
| 11 |
balance_strategy: "upsample_minority"
|
| 12 |
|
|
|
|
| 14 |
architecture: "cnn"
|
| 15 |
seed: 42
|
| 16 |
lr: 0.0001
|
| 17 |
+
cnn_model: "updated"
|
| 18 |
cnn_dp: 0.2
|
| 19 |
+
weight_decay: 1e-5
|
| 20 |
+
cosine_restart_T0: 50
|
| 21 |
+
cosine_restart_Tmult: 2
|
| 22 |
+
cosine_eta_min: 1e-7
|
| 23 |
|
| 24 |
vit_custom:
|
| 25 |
embed_dim: 512
|
|
|
|
| 61 |
|
| 62 |
wandb:
|
| 63 |
entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
|
| 64 |
+
project: ViT Proper Scale
|
| 65 |
job_type: training
|
| 66 |
tags:
|
| 67 |
- aia
|
| 68 |
- sxr
|
| 69 |
- regression
|
| 70 |
+
wb_name: baseline-model
|
| 71 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
forecasting/training/train.py
CHANGED
|
@@ -190,8 +190,13 @@ if config_data['selected_model'] == 'linear':
|
|
| 190 |
model = LinearIrradianceModel(
|
| 191 |
d_input= len(config_data['wavelengths']),
|
| 192 |
d_output=1,
|
| 193 |
-
lr= config_data['
|
| 194 |
-
loss_func=MSELoss()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
)
|
| 196 |
elif config_data['selected_model'] == 'hybrid':
|
| 197 |
model = HybridIrradianceModel(
|
|
@@ -201,6 +206,11 @@ elif config_data['selected_model'] == 'hybrid':
|
|
| 201 |
ln_model=True,
|
| 202 |
cnn_dp=config_data['megsai']['cnn_dp'],
|
| 203 |
lr=config_data['megsai']['lr'],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
)
|
| 205 |
elif config_data['selected_model'] == 'ViT':
|
| 206 |
model = ViT(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)
|
|
|
|
| 190 |
model = LinearIrradianceModel(
|
| 191 |
d_input= len(config_data['wavelengths']),
|
| 192 |
d_output=1,
|
| 193 |
+
lr= config_data['megsai']['lr'],
|
| 194 |
+
loss_func=MSELoss(),
|
| 195 |
+
sxr_norm=sxr_norm,
|
| 196 |
+
weight_decay=config_data['megsai']['weight_decay'],
|
| 197 |
+
cosine_restart_T0=config_data['megsai']['cosine_restart_T0'],
|
| 198 |
+
cosine_restart_Tmult=config_data['megsai']['cosine_restart_Tmult'],
|
| 199 |
+
cosine_eta_min=config_data['megsai']['cosine_eta_min']
|
| 200 |
)
|
| 201 |
elif config_data['selected_model'] == 'hybrid':
|
| 202 |
model = HybridIrradianceModel(
|
|
|
|
| 206 |
ln_model=True,
|
| 207 |
cnn_dp=config_data['megsai']['cnn_dp'],
|
| 208 |
lr=config_data['megsai']['lr'],
|
| 209 |
+
sxr_norm=sxr_norm,
|
| 210 |
+
weight_decay=config_data['megsai']['weight_decay'],
|
| 211 |
+
cosine_restart_T0=config_data['megsai']['cosine_restart_T0'],
|
| 212 |
+
cosine_restart_Tmult=config_data['megsai']['cosine_restart_Tmult'],
|
| 213 |
+
cosine_eta_min=config_data['megsai']['cosine_eta_min']
|
| 214 |
)
|
| 215 |
elif config_data['selected_model'] == 'ViT':
|
| 216 |
model = ViT(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)
|