griffingoodwin04 committed on
Commit
7bbc799
·
1 Parent(s): 7e470a6

Updated the splitting function so it can now process SXR and AIA individually; it can also repartition an already partitioned folder. Updated the MEGS models to include warm restarts, weight decay, and adaptive weights.

Browse files
data/split_data.py CHANGED
@@ -1,72 +1,183 @@
1
  import os
2
  import pandas as pd
3
  import shutil
 
4
  from datetime import datetime
 
5
 
6
- aia_data_dir = "/mnt/data2/AIA_processed/"
7
- sxr_data_dir = "/mnt/data2/ML-Ready_clean/GOES-18-SXR-B/"
8
- flares_event_dir = "/mnt/data2/ML-Ready_clean/flares_event_dir/"
9
- non_flares_event_dir = "/mnt/data2/ML-Ready_clean/non_flares_event_dir/"
10
- mixed_data_dir = "/mnt/data2/ML-Ready_clean/mixed_data/"
11
- flare_events_csv = "/mnt/data2/SDO-AIA-flaring/FlareEvents/flare_events_2012-01-01_2015-03-25.csv"
12
-
13
- # Create train, val, test subdirectories under flaring and non-flaring
14
- for base_dir in [flares_event_dir, non_flares_event_dir, mixed_data_dir]:
15
- os.makedirs(os.path.join(base_dir, "AIA"), exist_ok=True)
16
- os.makedirs(os.path.join(base_dir, "SXR"), exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
17
  for split in ["train", "val", "test"]:
18
- os.makedirs(os.path.join(base_dir, "AIA", split), exist_ok=True)
19
- os.makedirs(os.path.join(base_dir, "SXR", split), exist_ok=True)
20
-
21
- # Load flare events
22
- flare_event = pd.read_csv(flare_events_csv)
23
-
24
- # Create list of flare event time ranges
25
- flaring_eve_list = []
26
- for i, row in flare_event.iterrows():
27
- start_time = pd.to_datetime(row['event_starttime'])
28
- end_time = pd.to_datetime(row['event_endtime'])
29
- flaring_eve_list.append((start_time, end_time))
30
-
31
- # Get list of files in data_dir
32
- data_list = os.listdir(aia_data_dir)
33
-
34
- for file in data_list:
35
- try:
36
- aia_time = pd.to_datetime(file.split(".")[0])
37
- except ValueError:
38
- print(f"Skipping file {file}: Invalid timestamp format")
39
- continue
40
-
41
- # Determine if the file is during a flare event
42
- is_flaring = any(start <= aia_time <= end for start, end in flaring_eve_list)
43
- base_dir = flares_event_dir if is_flaring else non_flares_event_dir
44
-
45
- month = aia_time.month
46
-
47
- if month in [2, 3, 4, 5, 6, 7, 9, 10, 11, 12]:
48
- split_dir = "train"
49
- elif month == 1:
50
- split_dir = "val"
51
- elif month == 8:
52
- split_dir = "test"
53
- else:
54
- print(f"Skipping file {file}: Unexpected month {month}")
55
- continue
56
-
57
- # Copy file to appropriate directory
58
- src_aia = os.path.join(aia_data_dir, file)
59
- src_sxr = os.path.join(sxr_data_dir, file)
60
- dst_aia = os.path.join(base_dir, "AIA", split_dir, file)
61
- dst_sxr = os.path.join(base_dir, "SXR", split_dir, file)
62
-
63
- if not os.path.exists(dst_aia):
64
- shutil.copy(src_aia, dst_aia)
65
- print(f"Copied {file} to {dst_aia} and {dst_sxr}")
66
  else:
67
- print(f"File {dst_aia} already exists, skipping copy.")
68
- if not os.path.exists(dst_sxr):
69
- shutil.copy(src_sxr, dst_sxr)
 
 
 
 
 
 
 
 
 
 
 
 
70
  else:
71
- print(f"File {dst_sxr} already exists, skipping copy.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
 
 
 
1
  import os
2
  import pandas as pd
3
  import shutil
4
+ import argparse
5
  from datetime import datetime
6
+ from pathlib import Path
7
 
8
+ def split_data(input_folder, output_dir, data_type, flare_events_csv=None, repartition=False):
9
+ """
10
+ Split data from input folder into train/val/test based on month.
11
+ Optionally use flare events for additional classification.
12
+
13
+ Args:
14
+ input_folder (str): Path to the input folder containing data files
15
+ output_dir (str): Path to the output directory where split data will be saved
16
+ data_type (str): Type of data ('aia' or 'sxr')
17
+ flare_events_csv (str, optional): Path to the flare events CSV file
18
+ repartition (bool): If True, treat input_folder as already partitioned (has train/val/test subdirs)
19
+ """
20
+
21
+ # Validate input folder
22
+ if not os.path.exists(input_folder):
23
+ raise ValueError(f"Input folder does not exist: {input_folder}")
24
+
25
+ # Validate data type
26
+ if data_type.lower() not in ['aia', 'sxr']:
27
+ raise ValueError("data_type must be 'aia' or 'sxr'")
28
+
29
+ # Create output directory structure
30
+ os.makedirs(output_dir, exist_ok=True)
31
  for split in ["train", "val", "test"]:
32
+ os.makedirs(os.path.join(output_dir, split), exist_ok=True)
33
+
34
+ print(f"Processing {data_type.upper()} data from: {input_folder}")
35
+ print(f"Output directory: {output_dir}")
36
+
37
+ # Load flare events if provided
38
+ flaring_eve_list = []
39
+ if flare_events_csv and os.path.exists(flare_events_csv):
40
+ print(f"Loading flare events from: {flare_events_csv}")
41
+ flare_event = pd.read_csv(flare_events_csv)
42
+
43
+ # Create list of flare event time ranges
44
+ for i, row in flare_event.iterrows():
45
+ start_time = pd.to_datetime(row['event_starttime'])
46
+ end_time = pd.to_datetime(row['event_endtime'])
47
+ flaring_eve_list.append((start_time, end_time))
48
+ print(f"Loaded {len(flaring_eve_list)} flare events")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  else:
50
+ print("No flare events CSV provided or file not found. Skipping flare classification.")
51
+
52
+ # Get list of files in input folder
53
+ if repartition:
54
+ # For repartitioning, collect files from train/val/test subdirectories
55
+ data_list = []
56
+ for split in ["train", "val", "test"]:
57
+ split_dir = os.path.join(input_folder, split)
58
+ if os.path.exists(split_dir):
59
+ split_files = os.listdir(split_dir)
60
+ # Add split information to each file for tracking
61
+ for file in split_files:
62
+ data_list.append((file, split))
63
+ print(f"Found {len(split_files)} files in {split}/ directory")
64
+ print(f"Total files to repartition: {len(data_list)}")
65
  else:
66
+ # For normal splitting, get files directly from input folder
67
+ data_list = os.listdir(input_folder)
68
+ print(f"Found {len(data_list)} files to process")
69
+
70
+ moved_count = 0
71
+ skipped_count = 0
72
+
73
+ for file_info in data_list:
74
+ if repartition:
75
+ file, original_split = file_info
76
+ else:
77
+ file = file_info
78
+ original_split = None
79
+
80
+ try:
81
+ # Extract timestamp from filename (assuming format like "2012-01-01T00:00:00.npy")
82
+ file_time = pd.to_datetime(file.split(".")[0])
83
+ except ValueError:
84
+ print(f"Skipping file {file}: Invalid timestamp format")
85
+ skipped_count += 1
86
+ continue
87
+
88
+ # Determine if the file is during a flare event (if flare events are available)
89
+ is_flaring = False
90
+ if flaring_eve_list:
91
+ is_flaring = any(start <= file_time <= end for start, end in flaring_eve_list)
92
+
93
+ # Determine split based on month
94
+ month = file_time.month
95
+
96
+ if month in [2, 3, 4, 5, 6, 7, 9, 10, 11, 12]:
97
+ new_split_dir = "train"
98
+ elif month == 1:
99
+ new_split_dir = "val"
100
+ elif month == 8:
101
+ new_split_dir = "test"
102
+ else:
103
+ print(f"Skipping file {file}: Unexpected month {month}")
104
+ skipped_count += 1
105
+ continue
106
+
107
+ # Determine source and destination paths
108
+ if repartition:
109
+ src_path = os.path.join(input_folder, original_split, file)
110
+ else:
111
+ src_path = os.path.join(input_folder, file)
112
+
113
+ dst_path = os.path.join(output_dir, new_split_dir, file)
114
+
115
+ # Skip if file is already in the correct split and we're repartitioning
116
+ if repartition and original_split == new_split_dir and os.path.exists(dst_path):
117
+ print(f"File {file} already in correct split ({new_split_dir}), skipping.")
118
+ skipped_count += 1
119
+ continue
120
+
121
+ if not os.path.exists(dst_path):
122
+ try:
123
+ shutil.move(src_path, dst_path)
124
+ if repartition:
125
+ if flaring_eve_list:
126
+ print(f"Moved {file} from {original_split}/ to {new_split_dir}/ (flaring: {is_flaring})")
127
+ else:
128
+ print(f"Moved {file} from {original_split}/ to {new_split_dir}/")
129
+ else:
130
+ if flaring_eve_list:
131
+ print(f"Moved {file} to {new_split_dir}/ (flaring: {is_flaring})")
132
+ else:
133
+ print(f"Moved {file} to {new_split_dir}/")
134
+ moved_count += 1
135
+ except Exception as e:
136
+ print(f"Error moving {file}: {e}")
137
+ skipped_count += 1
138
+ else:
139
+ print(f"File {dst_path} already exists, skipping move.")
140
+ skipped_count += 1
141
+
142
+ print(f"\nProcessing complete!")
143
+ print(f"Files moved: {moved_count}")
144
+ print(f"Files skipped: {skipped_count}")
145
+ print(f"Total files processed: {moved_count + skipped_count}")
146
+
147
+ def main():
148
+ parser = argparse.ArgumentParser(description='Split AIA or SXR data into train/val/test sets based on month')
149
+ parser.add_argument('--input_folder', type=str, required=True,
150
+ help='Path to the input folder containing data files (or partitioned folder for repartition)')
151
+ parser.add_argument('--output_dir', type=str, required=True,
152
+ help='Path to the output directory where split data will be saved')
153
+ parser.add_argument('--data_type', type=str, choices=['aia', 'sxr'], required=True,
154
+ help='Type of data: "aia" or "sxr"')
155
+ parser.add_argument('--flare_events_csv', type=str, default=None,
156
+ help='Path to the flare events CSV file (optional)')
157
+ parser.add_argument('--repartition', action='store_true',
158
+ help='Repartition an already partitioned folder (input_folder should have train/val/test subdirs)')
159
+
160
+ args = parser.parse_args()
161
+
162
+ # Convert to absolute paths
163
+ input_folder = os.path.abspath(args.input_folder)
164
+ output_dir = os.path.abspath(args.output_dir)
165
+ flare_events_csv = os.path.abspath(args.flare_events_csv) if args.flare_events_csv else None
166
+
167
+ # Validate repartition mode
168
+ if args.repartition:
169
+ # Check if input folder has train/val/test subdirectories
170
+ expected_dirs = ['train', 'val', 'test']
171
+ missing_dirs = []
172
+ for dir_name in expected_dirs:
173
+ if not os.path.exists(os.path.join(input_folder, dir_name)):
174
+ missing_dirs.append(dir_name)
175
+
176
+ if missing_dirs:
177
+ print(f"Warning: Input folder is missing expected subdirectories: {missing_dirs}")
178
+ print("Continuing with available directories...")
179
+
180
+ split_data(input_folder, output_dir, args.data_type, flare_events_csv, args.repartition)
181
 
182
+ if __name__ == "__main__":
183
+ main()
forecasting/data_loaders/sxr_normalization.py CHANGED
@@ -51,7 +51,7 @@ def compute_sxr_norm(sxr_dir):
51
 
52
  if __name__ == "__main__":
53
  # Update this path to your real data SXR directory
54
- sxr_dir = "/mnt/data/ML-Ready-mixed/ML-Ready-mixed/SXR/train" # Replace with actual path
55
  sxr_norm = compute_sxr_norm(sxr_dir)
56
- np.save("/mnt/data/ML-Ready-mixed/ML-Ready-mixed/SXR/normalized_sxr.npy", sxr_norm)
57
  #print(f"Saved SXR normalization to /mnt/data/ML-Ready-Data-No-Intensity-Cut/normalized_sxr")
 
51
 
52
  if __name__ == "__main__":
53
  # Update this path to your real data SXR directory
54
+ sxr_dir = "/mnt/data/ML-READY/SXR/train" # Replace with actual path
55
  sxr_norm = compute_sxr_norm(sxr_dir)
56
+ np.save("/mnt/data/ML-READY/SXR/normalized_sxr.npy", sxr_norm)
57
  #print(f"Saved SXR normalization to /mnt/data/ML-Ready-Data-No-Intensity-Cut/normalized_sxr")
forecasting/inference/evaluation.py CHANGED
@@ -890,7 +890,7 @@ class SolarFlareEvaluator:
890
  # Sort frame paths by timestamp to ensure correct order
891
  frame_paths.sort(key=lambda x: os.path.basename(x))
892
 
893
- movie_path = os.path.join(self.output_dir, "AIA_video_with_uncertainties.mp4")
894
  with imageio.get_writer(movie_path, fps=30, codec='libx264', format='ffmpeg') as writer:
895
  for frame_path in frame_paths:
896
  if os.path.exists(frame_path):
 
890
  # Sort frame paths by timestamp to ensure correct order
891
  frame_paths.sort(key=lambda x: os.path.basename(x))
892
 
893
+ movie_path = os.path.join(self.output_dir, f"AIA_{timestamps[0].split('T')[0]}.mp4")
894
  with imageio.get_writer(movie_path, fps=30, codec='libx264', format='ffmpeg') as writer:
895
  for frame_path in frame_paths:
896
  if os.path.exists(frame_path):
forecasting/models/base_model.py CHANGED
@@ -1,30 +1,72 @@
1
  import torch
2
  import torch.nn as nn
 
3
  from pytorch_lightning import LightningModule
 
 
 
 
 
4
 
5
  class BaseModel(LightningModule):
6
- def __init__(self, model, loss_func, lr):
 
7
  super().__init__()
8
  self.model = model
9
  self.loss_func = loss_func
10
  self.lr = lr
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def forward(self, x):
13
  return self.model(x)
14
 
15
  def configure_optimizers(self):
16
- optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
17
- scheduler = {
18
- 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3),
19
- 'monitor': 'val_loss', # name of the metric to monitor
20
- 'interval': 'epoch',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  }
22
- return {'optimizer': optimizer, 'lr_scheduler': scheduler}
23
 
24
  def training_step(self, batch, batch_idx):
25
  x, target = batch
26
  pred = self(x)
27
- loss = self.loss_func(torch.squeeze(pred), target)
 
 
 
 
 
 
 
 
 
 
 
28
  self.log('train_loss', loss)
29
  current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
30
  self.log('learning_rate', current_lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
@@ -33,13 +75,35 @@ class BaseModel(LightningModule):
33
  def validation_step(self, batch, batch_idx):
34
  x, target = batch
35
  pred = self(x)
36
- loss = self.loss_func(torch.squeeze(pred), target)
 
 
 
 
 
 
 
 
 
 
 
37
  self.log('val_loss', loss)
38
  return loss
39
 
40
  def test_step(self, batch, batch_idx):
41
  x, target = batch
42
  pred = self(x)
43
- loss = self.loss_func(torch.squeeze(pred), target)
 
 
 
 
 
 
 
 
 
 
 
44
  self.log('test_loss', loss)
45
  return loss
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
  from pytorch_lightning import LightningModule
5
+ from collections import deque
6
+ import numpy as np
7
+
8
+ # Import adaptive loss and normalization functions
9
+ from .vit_patch_model import SXRRegressionDynamicLoss, normalize_sxr, unnormalize_sxr
10
 
11
  class BaseModel(LightningModule):
12
+ def __init__(self, model, loss_func, lr, sxr_norm=None, weight_decay=1e-5,
13
+ cosine_restart_T0=50, cosine_restart_Tmult=2, cosine_eta_min=1e-7):
14
  super().__init__()
15
  self.model = model
16
  self.loss_func = loss_func
17
  self.lr = lr
18
+ self.sxr_norm = sxr_norm
19
+ self.weight_decay = weight_decay
20
+ self.cosine_restart_T0 = cosine_restart_T0
21
+ self.cosine_restart_Tmult = cosine_restart_Tmult
22
+ self.cosine_eta_min = cosine_eta_min
23
+
24
+ # Initialize adaptive loss if sxr_norm is provided
25
+ if sxr_norm is not None:
26
+ self.adaptive_loss = SXRRegressionDynamicLoss(window_size=1500)
27
+ else:
28
+ self.adaptive_loss = None
29
 
30
  def forward(self, x):
31
  return self.model(x)
32
 
33
  def configure_optimizers(self):
34
+ optimizer = torch.optim.AdamW(
35
+ self.parameters(),
36
+ lr=self.lr,
37
+ weight_decay=self.weight_decay,
38
+ )
39
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
40
+ optimizer,
41
+ T_0=self.cosine_restart_T0,
42
+ T_mult=self.cosine_restart_Tmult,
43
+ eta_min=self.cosine_eta_min,
44
+ )
45
+ return {
46
+ 'optimizer': optimizer,
47
+ 'lr_scheduler': {
48
+ 'scheduler': scheduler,
49
+ 'interval': 'epoch',
50
+ 'frequency': 1,
51
+ 'name': 'learning_rate'
52
+ }
53
  }
 
54
 
55
  def training_step(self, batch, batch_idx):
56
  x, target = batch
57
  pred = self(x)
58
+
59
+ # Use adaptive loss if available and sxr_norm is provided
60
+ if self.adaptive_loss is not None and self.sxr_norm is not None:
61
+ raw_preds_squeezed = torch.squeeze(pred)
62
+ target_un = unnormalize_sxr(target, self.sxr_norm)
63
+ norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
64
+ loss, weights = self.adaptive_loss.calculate_loss(
65
+ norm_preds_squeezed, target, target_un, raw_preds_squeezed
66
+ )
67
+ else:
68
+ loss = self.loss_func(torch.squeeze(pred), target)
69
+
70
  self.log('train_loss', loss)
71
  current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
72
  self.log('learning_rate', current_lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
 
75
  def validation_step(self, batch, batch_idx):
76
  x, target = batch
77
  pred = self(x)
78
+
79
+ # Use adaptive loss if available and sxr_norm is provided
80
+ if self.adaptive_loss is not None and self.sxr_norm is not None:
81
+ raw_preds_squeezed = torch.squeeze(pred)
82
+ target_un = unnormalize_sxr(target, self.sxr_norm)
83
+ norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
84
+ loss, weights = self.adaptive_loss.calculate_loss(
85
+ norm_preds_squeezed, target, target_un, raw_preds_squeezed
86
+ )
87
+ else:
88
+ loss = self.loss_func(torch.squeeze(pred), target)
89
+
90
  self.log('val_loss', loss)
91
  return loss
92
 
93
  def test_step(self, batch, batch_idx):
94
  x, target = batch
95
  pred = self(x)
96
+
97
+ # Use adaptive loss if available and sxr_norm is provided
98
+ if self.adaptive_loss is not None and self.sxr_norm is not None:
99
+ raw_preds_squeezed = torch.squeeze(pred)
100
+ target_un = unnormalize_sxr(target, self.sxr_norm)
101
+ norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
102
+ loss, weights = self.adaptive_loss.calculate_loss(
103
+ norm_preds_squeezed, target, target_un, raw_preds_squeezed
104
+ )
105
+ else:
106
+ loss = self.loss_func(torch.squeeze(pred), target)
107
+
108
  self.log('test_loss', loss)
109
  return loss
forecasting/models/linear_and_hybrid.py CHANGED
@@ -10,11 +10,14 @@ from forecasting.models.base_model import BaseModel
10
  from torchvision.models import resnet18
11
 
12
  class LinearIrradianceModel(BaseModel):
13
- def __init__(self, d_input, d_output, loss_func=HuberLoss(), lr=1e-4):
 
14
  self.n_channels = d_input
15
  self.outSize = d_output
16
  model = nn.Linear(2 * self.n_channels, self.outSize)
17
- super().__init__(model=model, loss_func=loss_func, lr=lr)
 
 
18
 
19
  def forward(self, x, **kwargs):
20
 
@@ -48,14 +51,21 @@ class LinearIrradianceModel(BaseModel):
48
  return self.model(input_features)
49
 
50
  class HybridIrradianceModel(BaseModel):
51
- def __init__(self, d_input, d_output, cnn_model='resnet', ln_model=True, ln_params=None, lr=1e-4, cnn_dp=0.75, loss_func=HuberLoss()):
52
- super().__init__(model=None, loss_func=loss_func, lr=lr)
 
 
 
53
  self.n_channels = d_input
54
  self.outSize = d_output
55
  self.ln_params = ln_params
56
  self.ln_model = None
57
  if ln_model:
58
- self.ln_model = LinearIrradianceModel(d_input, d_output, loss_func=loss_func, lr=lr)
 
 
 
 
59
  if self.ln_params is not None and self.ln_model is not None:
60
  self.ln_model.model.weight = nn.Parameter(self.ln_params['weight'])
61
  self.ln_model.model.bias = nn.Parameter(self.ln_params['bias'])
 
10
  from torchvision.models import resnet18
11
 
12
  class LinearIrradianceModel(BaseModel):
13
+ def __init__(self, d_input, d_output, loss_func=HuberLoss(), lr=1e-4, sxr_norm=None,
14
+ weight_decay=1e-5, cosine_restart_T0=50, cosine_restart_Tmult=2, cosine_eta_min=1e-7):
15
  self.n_channels = d_input
16
  self.outSize = d_output
17
  model = nn.Linear(2 * self.n_channels, self.outSize)
18
+ super().__init__(model=model, loss_func=loss_func, lr=lr, sxr_norm=sxr_norm,
19
+ weight_decay=weight_decay, cosine_restart_T0=cosine_restart_T0,
20
+ cosine_restart_Tmult=cosine_restart_Tmult, cosine_eta_min=cosine_eta_min)
21
 
22
  def forward(self, x, **kwargs):
23
 
 
51
  return self.model(input_features)
52
 
53
  class HybridIrradianceModel(BaseModel):
54
+ def __init__(self, d_input, d_output, cnn_model='resnet', ln_model=True, ln_params=None, lr=1e-4, cnn_dp=0.75, loss_func=HuberLoss(),
55
+ sxr_norm=None, weight_decay=1e-5, cosine_restart_T0=50, cosine_restart_Tmult=2, cosine_eta_min=1e-7):
56
+ super().__init__(model=None, loss_func=loss_func, lr=lr, sxr_norm=sxr_norm,
57
+ weight_decay=weight_decay, cosine_restart_T0=cosine_restart_T0,
58
+ cosine_restart_Tmult=cosine_restart_Tmult, cosine_eta_min=cosine_eta_min)
59
  self.n_channels = d_input
60
  self.outSize = d_output
61
  self.ln_params = ln_params
62
  self.ln_model = None
63
  if ln_model:
64
+ self.ln_model = LinearIrradianceModel(d_input, d_output, loss_func=loss_func, lr=lr,
65
+ sxr_norm=sxr_norm, weight_decay=weight_decay,
66
+ cosine_restart_T0=cosine_restart_T0,
67
+ cosine_restart_Tmult=cosine_restart_Tmult,
68
+ cosine_eta_min=cosine_eta_min)
69
  if self.ln_params is not None and self.ln_model is not None:
70
  self.ln_model.model.weight = nn.Parameter(self.ln_params['weight'])
71
  self.ln_model.model.bias = nn.Parameter(self.ln_params['bias'])
forecasting/training/config.yaml CHANGED
@@ -4,9 +4,9 @@ base_data_dir: "/mnt/data/ML-READY" # Change this line for different datasets
4
  base_checkpoint_dir: "/mnt/data/ML-READY" # Change this line for different datasets
5
  wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
6
  # Model configuration
7
- selected_model: "FusionViTHybrid" # Options: "cnn", "vit",
8
- batch_size: 16
9
- epochs: 500
10
  oversample: false
11
  balance_strategy: "upsample_minority"
12
 
@@ -14,8 +14,12 @@ megsai:
14
  architecture: "cnn"
15
  seed: 42
16
  lr: 0.0001
17
- cnn_model: "original"
18
  cnn_dp: 0.2
 
 
 
 
19
 
20
  vit_custom:
21
  embed_dim: 512
@@ -57,11 +61,11 @@ data:
57
 
58
  wandb:
59
  entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
60
- project: ViT Loss Function Testing For Patches
61
  job_type: training
62
  tags:
63
  - aia
64
  - sxr
65
  - regression
66
- wb_name: vit-fused-model
67
  notes: Regression from AIA images (6 channels) to GOES SXR flux
 
4
  base_checkpoint_dir: "/mnt/data/ML-READY" # Change this line for different datasets
5
  wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
6
  # Model configuration
7
+ selected_model: "cnn" # Options: "cnn", "vit",
8
+ batch_size: 256
9
+ epochs: 250
10
  oversample: false
11
  balance_strategy: "upsample_minority"
12
 
 
14
  architecture: "cnn"
15
  seed: 42
16
  lr: 0.0001
17
+ cnn_model: "updated"
18
  cnn_dp: 0.2
19
+ weight_decay: 1e-5
20
+ cosine_restart_T0: 50
21
+ cosine_restart_Tmult: 2
22
+ cosine_eta_min: 1e-7
23
 
24
  vit_custom:
25
  embed_dim: 512
 
61
 
62
  wandb:
63
  entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
64
+ project: ViT Proper Scale
65
  job_type: training
66
  tags:
67
  - aia
68
  - sxr
69
  - regression
70
+ wb_name: baseline-model
71
  notes: Regression from AIA images (6 channels) to GOES SXR flux
forecasting/training/train.py CHANGED
@@ -190,8 +190,13 @@ if config_data['selected_model'] == 'linear':
190
  model = LinearIrradianceModel(
191
  d_input= len(config_data['wavelengths']),
192
  d_output=1,
193
- lr= config_data['model']['lr'],
194
- loss_func=MSELoss()
 
 
 
 
 
195
  )
196
  elif config_data['selected_model'] == 'hybrid':
197
  model = HybridIrradianceModel(
@@ -201,6 +206,11 @@ elif config_data['selected_model'] == 'hybrid':
201
  ln_model=True,
202
  cnn_dp=config_data['megsai']['cnn_dp'],
203
  lr=config_data['megsai']['lr'],
 
 
 
 
 
204
  )
205
  elif config_data['selected_model'] == 'ViT':
206
  model = ViT(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)
 
190
  model = LinearIrradianceModel(
191
  d_input= len(config_data['wavelengths']),
192
  d_output=1,
193
+ lr= config_data['megsai']['lr'],
194
+ loss_func=MSELoss(),
195
+ sxr_norm=sxr_norm,
196
+ weight_decay=config_data['megsai']['weight_decay'],
197
+ cosine_restart_T0=config_data['megsai']['cosine_restart_T0'],
198
+ cosine_restart_Tmult=config_data['megsai']['cosine_restart_Tmult'],
199
+ cosine_eta_min=config_data['megsai']['cosine_eta_min']
200
  )
201
  elif config_data['selected_model'] == 'hybrid':
202
  model = HybridIrradianceModel(
 
206
  ln_model=True,
207
  cnn_dp=config_data['megsai']['cnn_dp'],
208
  lr=config_data['megsai']['lr'],
209
+ sxr_norm=sxr_norm,
210
+ weight_decay=config_data['megsai']['weight_decay'],
211
+ cosine_restart_T0=config_data['megsai']['cosine_restart_T0'],
212
+ cosine_restart_Tmult=config_data['megsai']['cosine_restart_Tmult'],
213
+ cosine_eta_min=config_data['megsai']['cosine_eta_min']
214
  )
215
  elif config_data['selected_model'] == 'ViT':
216
  model = ViT(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)