Chromaniquej1 commited on
Commit
f6e753d
·
2 Parent(s): 614e50075ed062

changed structure

Browse files
flaring/MEGS_AI_baseline/SDOAIA_dataloader.py CHANGED
@@ -57,19 +57,19 @@ class AIA_GOESDataset(torch.utils.data.Dataset):
57
  self.target_size[0]/aia_img.shape[1],
58
  self.target_size[1]/aia_img.shape[2]))
59
 
60
- #Apply cut and normalize:
61
- cuts_dict = {
62
- 0: np.float32(16.560747),
63
- 1: np.float32(75.84181),
64
- 2: np.float32(1536.1443),
65
- 3: np.float32(2288.1),
66
- 4: np.float32(1163.9178),
67
- 5: np.float32(401.82352)
68
- }
69
-
70
- for channel in range(6):
71
- aia_img[channel] = np.clip(aia_img[channel], 0, cuts_dict[channel])
72
- aia_img[channel] = aia_img[channel] / cuts_dict[channel] # Normalize each channel to [0, 1]
73
 
74
  # Convert to torch for transforms
75
  aia_img = torch.tensor(aia_img, dtype=torch.float32) # (6, H, W)
 
57
  self.target_size[0]/aia_img.shape[1],
58
  self.target_size[1]/aia_img.shape[2]))
59
 
60
+ # #Apply cut and normalize:
61
+ # cuts_dict = {
62
+ # 0: np.float32(16.560747),
63
+ # 1: np.float32(75.84181),
64
+ # 2: np.float32(1536.1443),
65
+ # 3: np.float32(2288.1),
66
+ # 4: np.float32(1163.9178),
67
+ # 5: np.float32(401.82352)
68
+ # }
69
+ #
70
+ # for channel in range(6):
71
+ # aia_img[channel] = np.clip(aia_img[channel], 0, cuts_dict[channel])
72
+ # aia_img[channel] = aia_img[channel] / cuts_dict[channel] # Normalize each channel to [0, 1]
73
 
74
  # Convert to torch for transforms
75
  aia_img = torch.tensor(aia_img, dtype=torch.float32) # (6, H, W)
flaring/MEGS_AI_baseline/callback.py CHANGED
@@ -10,12 +10,12 @@ import astropy.units as u
10
  # Custom Callback
11
  sdoaia94 = matplotlib.colormaps['sdoaia94']
12
 
13
- def unnormalize(y, eve_norm):
14
- eve_norm = torch.tensor(eve_norm).float()
15
- norm_mean = eve_norm[0]
16
- norm_stdev = eve_norm[1]
17
- y = y * norm_stdev[None].to(y) + norm_mean[None].to(y)
18
- return y
19
 
20
 
21
  class ImagePredictionLogger_SXR(Callback):
@@ -27,12 +27,6 @@ class ImagePredictionLogger_SXR(Callback):
27
  self.val_sxr = data_samples[1]
28
  self.sxr_norm = sxr_norm
29
 
30
- def unnormalize_sxr(self, normalized_values):
31
- if isinstance(normalized_values, torch.Tensor):
32
- normalized_values = normalized_values.cpu().numpy()
33
- normalized_values = np.array(normalized_values, dtype=np.float32)
34
- return 10 ** (normalized_values * float(self.sxr_norm[1].item()) + float(self.sxr_norm[0].item())) - 1e-8
35
-
36
  def on_validation_epoch_end(self, trainer, pl_module):
37
 
38
  aia_images = []
@@ -50,8 +44,8 @@ class ImagePredictionLogger_SXR(Callback):
50
  aia_images.append(aia.squeeze(0).cpu().numpy())
51
  true_sxr.append(target.item())
52
 
53
- true_unorm = self.unnormalize_sxr(true_sxr)
54
- pred_unnorm = self.unnormalize_sxr(pred_sxr)
55
  fig1 = self.plot_aia_sxr(aia_images,true_unorm, pred_unnorm)
56
  trainer.logger.experiment.log({"Soft X-ray flux plots": wandb.Image(fig1)})
57
  plt.close(fig1)
@@ -85,178 +79,3 @@ class ImagePredictionLogger_SXR(Callback):
85
 
86
  fig.tight_layout()
87
  return fig
88
-
89
-
90
- class ImagePredictionLogger(Callback):
91
- def __init__(self, val_imgs, val_eve, names, aia_wavelengths):
92
- super().__init__()
93
- self.val_imgs, self.val_eve = val_imgs, val_eve
94
- self.names = names
95
- self.aia_wavelengths = aia_wavelengths
96
-
97
- def on_validation_epoch_end(self, trainer, pl_module):
98
- # Bring the tensors to CPU
99
- val_imgs = self.val_imgs.to(device=pl_module.device)
100
- # Get model prediction
101
- # pred_eve = pl_module.forward(val_imgs).cpu().numpy()
102
- pred_eve = pl_module.forward_unnormalize(val_imgs).cpu().numpy()
103
- val_eve = unnormalize(self.val_eve, pl_module.eve_norm).numpy()
104
- val_imgs = val_imgs.cpu().numpy()
105
-
106
- # create matplotlib figure
107
- fig = self.plot_aia_eve(val_imgs, val_eve, pred_eve)
108
- # Log the images to wandb
109
- trainer.logger.experiment.log({"AIA Images and EVE bar plots": wandb.Image(fig)})
110
- plt.close(fig)
111
-
112
- def plot_aia_eve(self, val_imgs, val_eve, pred_eve):
113
- """
114
- Function to plot a 4 channel AIA stack and the EVE barplots
115
-
116
- Arguments:
117
- ----------
118
- val_imgs: numpy array
119
- Stack with 4 image channels
120
- val_eve: numpy array
121
- Stack of ground-truth eve channels
122
- pred_eve: numpy array
123
- Stack of predicted eve channels
124
- Returns:
125
- --------
126
- fig: matplotlib figure
127
- figure with plots
128
- """
129
- samples = pred_eve.shape[0]
130
- n_aia_wavelengths = len(self.aia_wavelengths)
131
- wspace = 0.2
132
- hspace = 0.125
133
- dpi = 100
134
-
135
- if n_aia_wavelengths < 3:
136
- nrows = 1
137
- ncols = n_aia_wavelengths
138
- fig = plt.figure(figsize=(9 + 9 / 4 * n_aia_wavelengths, 3 * samples), dpi=dpi)
139
- gs = fig.add_gridspec(samples, n_aia_wavelengths + 3, wspace=wspace, hspace=hspace)
140
- elif n_aia_wavelengths < 5:
141
- nrows = 2
142
- ncols = 2
143
- fig = plt.figure(figsize=(9 + 9 / 4 * 2, 6 * samples), dpi=dpi)
144
- gs = fig.add_gridspec(2 * samples, 5, wspace=wspace, hspace=hspace)
145
- elif n_aia_wavelengths < 7:
146
- nrows = 2
147
- ncols = 3
148
- fig = plt.figure(figsize=(9 + 9 / 4 * 3, 6 * samples), dpi=dpi)
149
- gs = fig.add_gridspec(2 * samples, 6, wspace=wspace, hspace=hspace)
150
- else:
151
- nrows = 2
152
- ncols = 4
153
- fig = plt.figure(figsize=(15, 5 * samples), dpi=dpi)
154
- gs = fig.add_gridspec(2 * samples, 7, wspace=wspace, hspace=hspace)
155
-
156
- cmaps_all = ['sdoaia94', 'sdoaia131', 'sdoaia171', 'sdoaia193', 'sdoaia211',
157
- 'sdoaia304', 'sdoaia335', 'sdoaia1600', 'sdoaia1700']
158
- cmaps = [cmaps_all[i] for i in self.aia_wavelengths]
159
- n_plots = 0
160
-
161
- for s in range(samples):
162
- for i in range(nrows):
163
- for j in range(ncols):
164
- if n_plots < n_aia_wavelengths:
165
- ax = fig.add_subplot(gs[s * nrows + i, j])
166
- ax.imshow(val_imgs[s, i * ncols + j], cmap=plt.get_cmap(cmaps[i * ncols + j]), origin='lower')
167
- ax.text(0.01, 0.99, cmaps[i * ncols + j], horizontalalignment='left', verticalalignment='top',
168
- color='w', transform=ax.transAxes)
169
- ax.set_axis_off()
170
- n_plots += 1
171
- n_plots = 0
172
- # eve data
173
- ax5 = fig.add_subplot(gs[s * nrows, ncols:])
174
- if self.names is not None:
175
- ax5.bar(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth')
176
- ax5.bar(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], width=0.5, label='prediction', alpha=0.5)
177
- ax5.set_xticks(np.arange(0, len(val_eve[s, :])))
178
- ax5.set_xticklabels(self.names, rotation=45)
179
- else:
180
- ax5.plot(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth', alpha=0.5,
181
- drawstyle='steps-mid')
182
- ax5.plot(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], label='prediction', alpha=0.5,
183
- drawstyle='steps-mid')
184
- ax5.set_yscale('log')
185
- ax5.legend()
186
-
187
- ax6 = fig.add_subplot(gs[s * nrows + 1, ncols:])
188
- if self.names is not None:
189
- ax6.bar(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
190
- label='relative error (%)')
191
- ax6.set_xticks(np.arange(0, len(val_eve[s, :])))
192
- ax6.set_xticklabels(self.names, rotation=45)
193
- else:
194
- ax6.plot(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
195
- label='relative error (%)', alpha=0.5, drawstyle='steps-mid')
196
- ax6.set_yscale('log')
197
- ax6.legend()
198
-
199
- fig.tight_layout()
200
- return fig
201
-
202
-
203
- class SpectrumPredictionLogger(ImagePredictionLogger):
204
- def __init__(self, val_imgs, val_eve, names, aia_wavelengths):
205
- super().__init__(val_imgs, val_eve, names, aia_wavelengths)
206
-
207
- def plot_aia_eve(self, val_imgs, val_eve, pred_eve):
208
- """
209
- Function to plot a 4 channel AIA stack and the EVE barplots
210
-
211
- Arguments:
212
- ----------
213
- val_imgs: numpy array
214
- Stack with 4 image channels
215
- val_eve: numpy array
216
- Stack of ground-truth eve channels
217
- pred_eve: numpy array
218
- Stack of predicted eve channels
219
- Returns:
220
- --------
221
- fig: matplotlib figure
222
- figure with plots
223
- """
224
- samples = pred_eve.shape[0]
225
- n_aia_wavelengths = len(self.aia_wavelengths)
226
- wspace = 0.2
227
- hspace = 0.125
228
- dpi = 200
229
-
230
- fig = plt.figure(figsize=(5, 5), dpi=dpi)
231
- gs = fig.add_gridspec(2, 1, wspace=wspace, hspace=hspace)
232
-
233
- # eve data
234
- s = 0
235
- ax5 = fig.add_subplot(gs[0, 0])
236
- if self.names is not None:
237
- ax5.bar(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth')
238
- ax5.bar(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], width=0.5, label='prediction', alpha=0.5)
239
- ax5.set_xticks(np.arange(0, len(val_eve[s, :])))
240
- ax5.set_xticklabels(self.names, rotation=45)
241
- else:
242
- ax5.plot(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth', alpha=0.5,
243
- drawstyle='steps-mid')
244
- ax5.plot(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], label='prediction', alpha=0.5,
245
- drawstyle='steps-mid')
246
- ax5.set_yscale('log')
247
- ax5.legend()
248
-
249
- ax6 = fig.add_subplot(gs[1, 0])
250
- if self.names is not None:
251
- ax6.bar(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
252
- label='relative error (%)')
253
- ax6.set_xticks(np.arange(0, len(val_eve[s, :])))
254
- ax6.set_xticklabels(self.names, rotation=45)
255
- else:
256
- ax6.plot(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
257
- label='relative error (%)', alpha=0.5, drawstyle='steps-mid')
258
- ax6.set_yscale('log')
259
- ax6.legend()
260
-
261
- fig.tight_layout()
262
- return fig
 
10
  # Custom Callback
11
  sdoaia94 = matplotlib.colormaps['sdoaia94']
12
 
13
+
14
+ def unnormalize_sxr(normalized_values, sxr_norm):
15
+ if isinstance(normalized_values, torch.Tensor):
16
+ normalized_values = normalized_values.cpu().numpy()
17
+ normalized_values = np.array(normalized_values, dtype=np.float32)
18
+ return 10 ** (normalized_values * float(sxr_norm[1].item()) + float(sxr_norm[0].item())) - 1e-8
19
 
20
 
21
  class ImagePredictionLogger_SXR(Callback):
 
27
  self.val_sxr = data_samples[1]
28
  self.sxr_norm = sxr_norm
29
 
 
 
 
 
 
 
30
  def on_validation_epoch_end(self, trainer, pl_module):
31
 
32
  aia_images = []
 
44
  aia_images.append(aia.squeeze(0).cpu().numpy())
45
  true_sxr.append(target.item())
46
 
47
+ true_unorm = unnormalize_sxr(true_sxr,self.sxr_norm)
48
+ pred_unnorm = unnormalize_sxr(pred_sxr,self.sxr_norm)
49
  fig1 = self.plot_aia_sxr(aia_images,true_unorm, pred_unnorm)
50
  trainer.logger.experiment.log({"Soft X-ray flux plots": wandb.Image(fig1)})
51
  plt.close(fig1)
 
79
 
80
  fig.tight_layout()
81
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flaring/MEGS_AI_baseline/config.yaml CHANGED
@@ -11,7 +11,9 @@
11
  cnn_dp:
12
  - 0.75
13
  epochs:
14
- - 100
 
 
15
  wandb:
16
  entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
17
  project: MEGS-AI flaring # Lowercase, no spaces
 
11
  cnn_dp:
12
  - 0.75
13
  epochs:
14
+ - 1
15
+ save_dictionary:
16
+ -
17
  wandb:
18
  entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
19
  project: MEGS-AI flaring # Lowercase, no spaces
flaring/MEGS_AI_baseline/inference.py CHANGED
@@ -3,28 +3,42 @@ import torch
3
  import numpy as np
4
  from torch.utils.data import DataLoader
5
  from SDOAIA_dataloader import AIA_GOESDataset
 
 
6
 
 
7
  def predict_log_outputs(model, dataset, batch_size=8):
8
- """Generator yielding raw log-space model outputs"""
9
  model.eval()
10
  loader = DataLoader(dataset, batch_size=batch_size)
11
 
 
 
 
12
  with torch.no_grad():
13
  for batch in loader:
14
- # Handle different dataset formats
15
  if isinstance(batch, tuple) and len(batch) == 2:
16
- aia_imgs = batch[0][0] # Unpack ((aia, sxr), target)
 
17
  else:
18
- aia_imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
 
 
 
 
19
 
20
- aia_imgs = aia_imgs.to(next(model.parameters()).device)
21
- log_outputs = model(aia_imgs) # Get raw log-space outputs
 
 
22
  yield from log_outputs.cpu().numpy()
23
 
24
  def main():
25
  parser = argparse.ArgumentParser(description='Save raw log-space model outputs')
26
- parser.add_argument('--model', required=True, help='Path to trained model')
27
- parser.add_argument('--aia-dir', required=True, help='Directory of AIA images')
 
 
28
  parser.add_argument('--output', default='log_predictions.txt',
29
  help='Output file for log-space predictions')
30
  parser.add_argument('--batch-size', type=int, default=8,
@@ -32,23 +46,42 @@ def main():
32
 
33
  args = parser.parse_args()
34
 
 
 
35
  # Setup
36
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37
- model = torch.load(args.model, map_location=device).to(device)
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # Dataset without any output transformation
40
  dataset = AIA_GOESDataset(
41
  aia_dir=args.aia_dir,
 
42
  sxr_dir='',
43
  sxr_norm=None,
44
  transform=None
 
 
 
 
45
  )
46
 
47
  # Save log-space predictions
48
  with open(args.output, 'w') as f:
49
  f.write("# Log-space SXR predictions (log10(W/m²))\n")
50
  for log_pred in predict_log_outputs(model, dataset, args.batch_size):
51
- f.write(f"{log_pred:.6f}\n") # Write with 6 decimal places
 
52
 
53
  print(f"Log-space predictions saved to {args.output}")
54
  print("These are raw model outputs in log10 space before any exponentiation")
 
3
  import numpy as np
4
  from torch.utils.data import DataLoader
5
  from SDOAIA_dataloader import AIA_GOESDataset
6
+ from models.linear_and_hybrid import HybridIrradianceModel
7
+ from callback import unnormalize_sxr
8
 
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
  def predict_log_outputs(model, dataset, batch_size=8):
 
11
  model.eval()
12
  loader = DataLoader(dataset, batch_size=batch_size)
13
 
14
+ # Get device from model
15
+ device = next(model.parameters()).device
16
+
17
  with torch.no_grad():
18
  for batch in loader:
19
+ # Correct unpacking based on your data structure
20
  if isinstance(batch, tuple) and len(batch) == 2:
21
+ # batch = (inputs, targets) where inputs = [aia_imgs, sxr_imgs]
22
+ aia_imgs = batch[0][0] # Get aia_imgs from inputs
23
  else:
24
+ # Fallback for other formats
25
+ aia_imgs = batch[0][0] if isinstance(batch[0], list) else batch[0]
26
+
27
+ # Move to device (it's already a tensor)
28
+ aia_imgs = aia_imgs.to(device)
29
 
30
+ # Get model predictions
31
+ log_outputs = model(aia_imgs)
32
+
33
+ # Move to CPU and convert to numpy before yielding
34
  yield from log_outputs.cpu().numpy()
35
 
36
  def main():
37
  parser = argparse.ArgumentParser(description='Save raw log-space model outputs')
38
+ parser.add_argument('--ckpt_path', required=True, help='Path to model checkpoint')
39
+ parser.add_argument('--aia_dir', required=True, help='Directory of AIA images')
40
+ parser.add_argument('--sxr_dir', required=True, help='Directory of target SXR data')
41
+ parser.add_argument('--sxr_norm', required=True, help='Path to SXR normalization parameters (mean, std)')
42
  parser.add_argument('--output', default='log_predictions.txt',
43
  help='Output file for log-space predictions')
44
  parser.add_argument('--batch-size', type=int, default=8,
 
46
 
47
  args = parser.parse_args()
48
 
49
+ sxr_norm = np.load(args.sxr_norm)
50
+
51
  # Setup
52
+ state = torch.load(args.ckpt_path, map_location=device, weights_only=False)
53
+ model = state['model']
54
+ model.to(device)
55
+ # Assume it's a checkpoint with state_dict
56
+
57
+ # model = HybridIrradianceModel(6,1)
58
+ # state_dict = checkpoint.get('state_dict', checkpoint)
59
+ #
60
+ # # Handle potential key mismatches (e.g., PyTorch Lightning prefixes)
61
+ # state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
62
+ # model.load_state_dict(state_dict, strict=False)
63
+
64
+
65
 
66
  # Dataset without any output transformation
67
  dataset = AIA_GOESDataset(
68
  aia_dir=args.aia_dir,
69
+ <<<<<<< HEAD
70
  sxr_dir='',
71
  sxr_norm=None,
72
  transform=None
73
+ =======
74
+ sxr_dir=args.sxr_dir, # No SXR files needed
75
+ transform=None # No input transforms
76
+ >>>>>>> 22f4a17192a3a77fa4d4fe1ae3a2aa8c0bbdb539
77
  )
78
 
79
  # Save log-space predictions
80
  with open(args.output, 'w') as f:
81
  f.write("# Log-space SXR predictions (log10(W/m²))\n")
82
  for log_pred in predict_log_outputs(model, dataset, args.batch_size):
83
+ pred = unnormalize_sxr(log_pred, sxr_norm)
84
+ print(pred)
85
 
86
  print(f"Log-space predictions saved to {args.output}")
87
  print("These are raw model outputs in log10 space before any exponentiation")
flaring/MEGS_AI_baseline/models/__init__.py ADDED
File without changes
flaring/MEGS_AI_baseline/{base_model.py → models/base_model.py} RENAMED
File without changes
flaring/MEGS_AI_baseline/{efficientnet.py → models/efficientnet.py} RENAMED
File without changes
flaring/MEGS_AI_baseline/{kan_success.py → models/kan_success.py} RENAMED
File without changes
flaring/MEGS_AI_baseline/{linear_and_hybrid.py → models/linear_and_hybrid.py} RENAMED
@@ -108,9 +108,6 @@ class HybridIrradianceModel(BaseModel):
108
  if isinstance(x, (list, tuple)):
109
  x = x[0]
110
 
111
- # Debug: Print input shape
112
- print(f"Input shape to HybridIrradianceModel.forward: {x.shape}")
113
-
114
  # Expect x shape: (batch_size, H, W, C)
115
  if len(x.shape) != 4:
116
  raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
 
108
  if isinstance(x, (list, tuple)):
109
  x = x[0]
110
 
 
 
 
111
  # Expect x shape: (batch_size, H, W, C)
112
  if len(x.shape) != 4:
113
  raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
flaring/MEGS_AI_baseline/train.py CHANGED
@@ -1,19 +1,19 @@
1
 
2
  import argparse
3
  import os
 
 
4
  import yaml
5
  import itertools
6
  import wandb
7
  import torch
8
  import numpy as np
9
- from pathlib import Path
10
- import torchvision.transforms as transforms
11
  from pytorch_lightning import Trainer
12
  from pytorch_lightning.loggers import WandbLogger
13
- from pytorch_lightning.callbacks import ModelCheckpoint, Callback
14
- from torch.nn import HuberLoss, MSELoss
15
  from SDOAIA_dataloader import AIA_GOESDataModule
16
- from linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
17
  from callback import ImagePredictionLogger_SXR
18
 
19
  # Parser
@@ -117,7 +117,7 @@ for parameter_set in combined_parameters:
117
  default_root_dir=checkpoint_dir,
118
  accelerator="gpu" if torch.cuda.is_available() else "cpu",
119
  devices=1,
120
- max_epochs=run_config.get('epochs', 10),
121
  callbacks=[sxr_plot_callback, checkpoint_callback],
122
  logger=wandb_logger,
123
  log_every_n_steps=10
@@ -127,15 +127,16 @@ for parameter_set in combined_parameters:
127
  trainer.fit(model, data_loader)
128
 
129
  # Save checkpoint
130
- save_dictionary = run_config
131
- save_dictionary['model'] = model
132
- save_dictionary['instrument'] = instrument
133
- full_checkpoint_path = os.path.join(checkpoint_dir, f"{wb_name}_{n}.ckpt")
134
- torch.save(save_dictionary, full_checkpoint_path)
135
-
136
- # Test
137
- trainer.test(model, dataloaders=data_loader.test_dataloader())
138
 
 
 
 
 
 
 
 
 
 
139
  # Finalize
140
- wandb.finish()
141
- n += 1
 
1
 
2
  import argparse
3
  import os
4
+ from datetime import datetime
5
+
6
  import yaml
7
  import itertools
8
  import wandb
9
  import torch
10
  import numpy as np
 
 
11
  from pytorch_lightning import Trainer
12
  from pytorch_lightning.loggers import WandbLogger
13
+ from pytorch_lightning.callbacks import ModelCheckpoint
14
+ from torch.nn import MSELoss
15
  from SDOAIA_dataloader import AIA_GOESDataModule
16
+ from models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
17
  from callback import ImagePredictionLogger_SXR
18
 
19
  # Parser
 
117
  default_root_dir=checkpoint_dir,
118
  accelerator="gpu" if torch.cuda.is_available() else "cpu",
119
  devices=1,
120
+ max_epochs=run_config['epochs'],
121
  callbacks=[sxr_plot_callback, checkpoint_callback],
122
  logger=wandb_logger,
123
  log_every_n_steps=10
 
127
  trainer.fit(model, data_loader)
128
 
129
  # Save checkpoint
130
+ trainer.fit(model, data_loader)
 
 
 
 
 
 
 
131
 
132
+ # Save final PyTorch checkpoint with model and state_dict
133
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
134
+ final_checkpoint_path = os.path.join(checkpoint_dir, f"{wb_name}-final-{timestamp}.pth")
135
+ torch.save({
136
+ 'model': model,
137
+ 'state_dict': model.state_dict()
138
+ }, final_checkpoint_path)
139
+ print(f"Saved final PyTorch checkpoint: {final_checkpoint_path}")
140
+ n += 1
141
  # Finalize
142
+ wandb.finish()
 
flaring/normalization_and_aligning_data.py CHANGED
@@ -6,11 +6,13 @@ from astropy.io import fits
6
  import warnings
7
  import pandas as pd
8
  from astropy.visualization import ImageNormalize, AsinhStretch
 
 
 
 
9
 
10
  warnings.filterwarnings('ignore')
11
 
12
- import pandas as pd
13
-
14
  # Directory paths for each wavelength folder.
15
  wavelength_dirs = {
16
  "94": "/mnt/data2/AIA_processed_data/94",
@@ -22,50 +24,8 @@ wavelength_dirs = {
22
  }
23
 
24
  # Regular expression to extract timestamp from file names.
25
- # Adjust this pattern to match your file naming scheme.
26
  timestamp_pattern = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
27
 
28
- # Collect timestamps found in each wavelength directory.
29
- timestamps_found = defaultdict(set)
30
-
31
- for wavelength, dir_path in wavelength_dirs.items():
32
- try:
33
- for filename in os.listdir(dir_path):
34
- match = timestamp_pattern.search(filename)
35
- if match:
36
- ts = match.group(0)
37
- timestamps_found[ts].add(wavelength)
38
- except Exception as e:
39
- print(f"Could not read directory {dir_path}: {e}")
40
-
41
- # Identify timestamps that exist in all wavelength folders.
42
- all_wavelengths = set(wavelength_dirs.keys())
43
- common_timestamps = [ts for ts, waves in timestamps_found.items() if waves == all_wavelengths]
44
-
45
- # Identify which timestamps are missing files for some wavelengths.
46
- missing_files = {
47
- ts: list(all_wavelengths - waves)
48
- for ts, waves in timestamps_found.items() if waves != all_wavelengths
49
- }
50
-
51
- print("Timestamps present in all wavelength folders:")
52
- for ts in sorted(common_timestamps):
53
- print(ts)
54
-
55
- print("\nTimestamps with missing wavelength files:")
56
- for ts, missing in missing_files.items():
57
- print(f"{ts}: missing {', '.join(sorted(missing))}")
58
-
59
-
60
- goes = pd.read_csv("/mnt/data/goes_combined/combined_g18_avg1m_20230701_20230815.csv")
61
- # Convert 'time' column to datetime
62
- goes['time'] = pd.to_datetime(goes['time'], format='%Y-%m-%d %H:%M:%S')
63
-
64
-
65
- # Initialize the array to store all wavelength data
66
- data_shape = (6, 512, 512)
67
-
68
-
69
  # Map wavelengths to array indices
70
  wavelength_to_idx = {
71
  '94': 0,
@@ -76,52 +36,171 @@ wavelength_to_idx = {
76
  '304': 5
77
  }
78
 
79
- sdo_norms = {0: ImageNormalize(vmin=0, vmax= np.float32(16.560747), stretch=AsinhStretch(0.005), clip=True),
80
- 1: ImageNormalize(vmin=0, vmax= np.float32(75.84181), stretch=AsinhStretch(0.005), clip=True),
81
- 2: ImageNormalize(vmin=0, vmax= np.float32(1536.1443), stretch=AsinhStretch(0.005), clip=True),
82
- 3: ImageNormalize(vmin=0, vmax= np.float32(2288.1), stretch=AsinhStretch(0.005), clip=True),
83
- 4: ImageNormalize(vmin=0, vmax=np.float32(1163.9178), stretch=AsinhStretch(0.005), clip=True),
84
- 5: ImageNormalize(vmin=0, vmax=np.float32(401.82352), stretch=AsinhStretch(0.001), clip=True),
85
- }
86
-
87
-
88
-
89
- # Load data for each timestamp and wavelength
90
- for time_idx, timestamp in enumerate(common_timestamps):
91
- sxr = goes[goes['time'] == pd.to_datetime(timestamp)]
92
- sxr_a = sxr['xrsa_flux'].values[0] if not sxr.empty else None
93
- sxr_b = sxr['xrsb_flux'].values[0] if not sxr.empty else None
94
- if sxr_a is None or sxr_b is None:
95
- print(f"Missing SXR data for timestamp {timestamp}, skipping...")
96
- continue
97
- wavelength_data = np.zeros(data_shape, dtype=np.float32)
98
- sxr_a_data = np.zeros(1, dtype=np.float32)
99
- sxr_b_data = np.zeros(1, dtype=np.float32)
100
- sxr_a_data[0] = sxr_a if sxr_a is not None else np.nan
101
- sxr_b_data[0] = sxr_b if sxr_b is not None else np.nan
102
- print(f"Processing timestamp: {timestamp} (Index: {time_idx})")
103
- for wavelength, wave_idx in wavelength_to_idx.items():
104
- filepath = os.path.join(wavelength_dirs[wavelength], f"{timestamp}.fits")
105
- with fits.open(filepath) as hdul:
106
- raw_data = hdul[0].data
107
-
108
- # Apply the appropriate normalization for this wavelength
109
- if wave_idx in sdo_norms:
110
- # Get the normalizer for this wavelength index
111
- normalizer = sdo_norms[wave_idx]
112
-
113
- # Apply normalization and convert to [-1, 1] range
114
- normalized_data = normalizer(raw_data)
115
- wavelength_data[wave_idx] = normalized_data * 2 - 1
116
- else:
117
- # Fallback if no normalizer exists for this wavelength
118
- print(f"Warning: No normalizer found for wavelength index {wave_idx}")
119
- wavelength_data[wave_idx] = raw_data
120
-
121
- # Store the wavelength data for this timestamp
122
- np.save(f"/mnt/data2/ML-Ready/AIA-Data/{timestamp}.npy", wavelength_data)
123
- # Store the SXR data
124
- np.save(f"/mnt/data2/ML-Ready/GOES-18-SXR-A/{timestamp}.npy", sxr_a_data)
125
- np.save(f"/mnt/data2/ML-Ready/GOES-18-SXR-B/{timestamp}.npy", sxr_b_data)
126
- print(f"Saved data for timestamp {timestamp} to disk.")
127
- print(f"Percent: {time_idx + 1} / {len(common_timestamps)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import warnings
7
  import pandas as pd
8
  from astropy.visualization import ImageNormalize, AsinhStretch
9
+ from multiprocessing import Pool, cpu_count
10
+ from functools import partial
11
+ import time
12
+ from tqdm import tqdm
13
 
14
  warnings.filterwarnings('ignore')
15
 
 
 
16
  # Directory paths for each wavelength folder.
17
  wavelength_dirs = {
18
  "94": "/mnt/data2/AIA_processed_data/94",
 
24
  }
25
 
26
  # Regular expression to extract timestamp from file names.
 
27
  timestamp_pattern = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Map wavelengths to array indices
30
  wavelength_to_idx = {
31
  '94': 0,
 
36
  '304': 5
37
  }
38
 
39
+ # Initialize the array to store all wavelength data
40
+ data_shape = (6, 512, 512)
41
+
42
+ sdo_norms = {
43
+ 0: ImageNormalize(vmin=0, vmax=np.float32(16.560747), stretch=AsinhStretch(0.005), clip=True),
44
+ 1: ImageNormalize(vmin=0, vmax=np.float32(75.84181), stretch=AsinhStretch(0.005), clip=True),
45
+ 2: ImageNormalize(vmin=0, vmax=np.float32(1536.1443), stretch=AsinhStretch(0.005), clip=True),
46
+ 3: ImageNormalize(vmin=0, vmax=np.float32(2288.1), stretch=AsinhStretch(0.005), clip=True),
47
+ 4: ImageNormalize(vmin=0, vmax=np.float32(1163.9178), stretch=AsinhStretch(0.005), clip=True),
48
+ 5: ImageNormalize(vmin=0, vmax=np.float32(401.82352), stretch=AsinhStretch(0.001), clip=True),
49
+ }
50
+
51
+
52
+ def process_timestamp(args):
53
+ """
54
+ Process a single timestamp: load wavelength data, apply normalization,
55
+ and save to disk along with SXR data.
56
+ """
57
+ timestamp, goes_data = args
58
+ try:
59
+ # Get SXR data for this timestamp
60
+ sxr = goes_data[goes_data['time'] == pd.to_datetime(timestamp)]
61
+ sxr_a = sxr['xrsa_flux'].values[0] if not sxr.empty else None
62
+ sxr_b = sxr['xrsb_flux'].values[0] if not sxr.empty else None
63
+
64
+ if sxr_a is None or sxr_b is None:
65
+ return (timestamp, False, f"Missing SXR data for timestamp {timestamp}")
66
+
67
+ # Initialize arrays
68
+ wavelength_data = np.zeros(data_shape, dtype=np.float32)
69
+ sxr_a_data = np.zeros(1, dtype=np.float32)
70
+ sxr_b_data = np.zeros(1, dtype=np.float32)
71
+ sxr_a_data[0] = sxr_a
72
+ sxr_b_data[0] = sxr_b
73
+
74
+ # Process each wavelength
75
+ for wavelength, wave_idx in wavelength_to_idx.items():
76
+ filepath = os.path.join(wavelength_dirs[wavelength], f"{timestamp}.fits")
77
+
78
+ with fits.open(filepath) as hdul:
79
+ raw_data = hdul[0].data
80
+
81
+ # Apply the appropriate normalization for this wavelength
82
+ if wave_idx in sdo_norms:
83
+ normalizer = sdo_norms[wave_idx]
84
+ normalized_data = normalizer(raw_data)
85
+ wavelength_data[wave_idx] = normalized_data * 2 - 1
86
+ else:
87
+ wavelength_data[wave_idx] = raw_data
88
+
89
+ # Save data to disk
90
+ np.save(f"/mnt/data2/ML-Ready/AIA-Data/{timestamp}.npy", wavelength_data)
91
+ np.save(f"/mnt/data2/ML-Ready/GOES-18-SXR-A/{timestamp}.npy", sxr_a_data)
92
+ np.save(f"/mnt/data2/ML-Ready/GOES-18-SXR-B/{timestamp}.npy", sxr_b_data)
93
+
94
+ return (timestamp, True, "Success")
95
+
96
+ except Exception as e:
97
+ return (timestamp, False, f"Error processing timestamp {timestamp}: {e}")
98
+
99
+
100
+ def update_progress(result):
101
+ """Callback function to update progress bar"""
102
+ global pbar, successful_count, failed_count
103
+ timestamp, success, message = result
104
+
105
+ if success:
106
+ successful_count += 1
107
+ pbar.set_postfix(success=successful_count, failed=failed_count)
108
+ else:
109
+ failed_count += 1
110
+ pbar.set_postfix(success=successful_count, failed=failed_count)
111
+ tqdm.write(f"Failed: {message}")
112
+
113
+ pbar.update(1)
114
+
115
+
116
+ def main():
117
+ global pbar, successful_count, failed_count
118
+
119
+ # Collect timestamps found in each wavelength directory.
120
+ timestamps_found = defaultdict(set)
121
+
122
+ print("Scanning directories for timestamps...")
123
+ for wavelength, dir_path in tqdm(wavelength_dirs.items(), desc="Scanning directories"):
124
+ try:
125
+ for filename in os.listdir(dir_path):
126
+ match = timestamp_pattern.search(filename)
127
+ if match:
128
+ ts = match.group(0)
129
+ timestamps_found[ts].add(wavelength)
130
+ except Exception as e:
131
+ print(f"Could not read directory {dir_path}: {e}")
132
+
133
+ # Identify timestamps that exist in all wavelength folders.
134
+ all_wavelengths = set(wavelength_dirs.keys())
135
+ common_timestamps = [ts for ts, waves in timestamps_found.items() if waves == all_wavelengths]
136
+
137
+ # Identify which timestamps are missing files for some wavelengths.
138
+ missing_files = {
139
+ ts: list(all_wavelengths - waves)
140
+ for ts, waves in timestamps_found.items() if waves != all_wavelengths
141
+ }
142
+
143
+ print(f"\nFound {len(common_timestamps)} timestamps present in all wavelength folders")
144
+ print(f"Found {len(missing_files)} timestamps with missing wavelength files")
145
+
146
+ # Load GOES data
147
+ print("Loading GOES data...")
148
+ goes = pd.read_csv("/mnt/data/goes_combined/combined_g18_avg1m_20230701_20230815.csv")
149
+ goes['time'] = pd.to_datetime(goes['time'], format='%Y-%m-%d %H:%M:%S')
150
+
151
+ # Create output directories if they don't exist
152
+ os.makedirs("/mnt/data2/ML-Ready/AIA-Data", exist_ok=True)
153
+ os.makedirs("/mnt/data2/ML-Ready/GOES-18-SXR-A", exist_ok=True)
154
+ os.makedirs("/mnt/data2/ML-Ready/GOES-18-SXR-B", exist_ok=True)
155
+
156
+ # Use all available CPU cores
157
+ num_processes = cpu_count()
158
+ print(f"Using {num_processes} CPU cores for processing")
159
+ print(f"Processing {len(common_timestamps)} timestamps...")
160
+
161
+ # Initialize global counters for progress tracking
162
+ successful_count = 0
163
+ failed_count = 0
164
+
165
+ # Create arguments for multiprocessing (timestamp, goes_data pairs)
166
+ args_list = [(timestamp, goes) for timestamp in common_timestamps]
167
+
168
+ # Start timing
169
+ start_time = time.time()
170
+
171
+ # Create progress bar
172
+ pbar = tqdm(total=len(common_timestamps), desc="Processing timestamps",
173
+ unit="timestamp", dynamic_ncols=True)
174
+
175
+ # Process timestamps in parallel with progress tracking
176
+ with Pool(processes=num_processes) as pool:
177
+ # Use map with callback for real-time progress updates
178
+ results = []
179
+ for args in args_list:
180
+ result = pool.apply_async(process_timestamp, (args,), callback=update_progress)
181
+ results.append(result)
182
+
183
+ # Wait for all processes to complete
184
+ for result in results:
185
+ result.wait()
186
+
187
+ # Close progress bar
188
+ pbar.close()
189
+
190
+ # Calculate statistics
191
+ end_time = time.time()
192
+ total_time = end_time - start_time
193
+
194
+ print(f"\nProcessing complete!")
195
+ print(f"Total time: {total_time:.2f} seconds")
196
+ print(f"Average time per timestamp: {total_time / len(common_timestamps):.2f} seconds")
197
+ print(f"Successfully processed: {successful_count}/{len(common_timestamps)} timestamps")
198
+ print(f"Failed processes: {failed_count}")
199
+ print(f"Processing rate: {len(common_timestamps) / total_time:.2f} timestamps/second")
200
+
201
+ if failed_count > 0:
202
+ print(f"\n{failed_count} timestamps failed processing (see messages above)")
203
+
204
+
205
+ if __name__ == "__main__":
206
+ main()
flaring/outputs/outputs.txt ADDED
File without changes
flaring/split_data.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import pandas as pd
3
  import shutil
@@ -75,9 +76,94 @@ for base_dir in [flares_event_dir, non_flares_event_dir]:
75
  else:
76
  print(f"Skipping file {file} in {base_dir}: Outside date range")
77
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  # Move file to appropriate split directory
80
  src = os.path.join(base_dir, file)
81
  dst = os.path.join(base_dir, split_dir, file)
82
  shutil.move(src, dst)
83
- print(f"Moved {file} to {base_dir}/{split_dir}")
 
 
 
 
 
1
+ <<<<<<< HEAD
2
  import os
3
  import pandas as pd
4
  import shutil
 
76
  else:
77
  print(f"Skipping file {file} in {base_dir}: Outside date range")
78
  continue
79
+ =======
80
+ #
81
+
82
+
83
+ #
84
+ # data_dir = "/mnt/data/ML-Ready/AIA-Data"
85
+ # flares_event_dir = "/mnt/data/ML-Ready/flares_event_dir"
86
+ # non_flares_event_dir = "/mnt/data/ML-Ready/non_flares_event_dir"
87
+ # flare_events_csv = "/mnt/data/flare_list/flare_events_2023-07-01_2023-08-15.csv"
88
+ #
89
+ # train_range = (datetime(2023, 7, 1), datetime(2023, 7, 25))
90
+ # val_range = (datetime(2023, 7, 27), datetime(2023, 7, 30))
91
+ # test_range = (datetime(2023, 8, 1), datetime(2023, 8, 15))
92
+ #
93
+ # os.makedirs(flares_event_dir, exist_ok=True)
94
+ # os.makedirs(non_flares_event_dir, exist_ok=True)
95
+ #
96
+ # os.makedirs(os.path.join(flares_event_dir, "train"), exist_ok=True)
97
+ # os.makedirs(os.path.join(flares_event_dir, "val"), exist_ok=True)
98
+ # os.makedirs(os.path.join(flares_event_dir, "test"), exist_ok=True)
99
+ #
100
+ # os.makedirs(os.path.join(non_flares_event_dir, "train"), exist_ok=True)
101
+ # os.makedirs(os.path.join(non_flares_event_dir, "val"), exist_ok=True)
102
+ # os.makedirs(os.path.join(non_flares_event_dir, "test"), exist_ok=True)
103
+ #
104
+ #
105
+ # flare_event = pd.read_csv(flare_events_csv)
106
+ # print(f"Found {len(flare_event)} flare events")
107
+ # flaring_eve_list = []
108
+ # for i, row in flare_event.iterrows():
109
+ # start_time = pd.to_datetime(row['event_starttime'])
110
+ # end_time = pd.to_datetime(row['event_endtime'])
111
+ # flaring_eve_list.append((start_time, end_time))
112
+ #
113
+ # data_list = os.listdir(data_dir)
114
+ # print(f"Found {len(data_list)} files in {data_dir}")
115
+ # for file in data_list:
116
+ # try:
117
+ # aia_time = pd.to_datetime(file.split(".")[0])
118
+ # except ValueError:
119
+ # print(f"Skipping file {file}: Invalid timestamp format")
120
+ # continue
121
+ #
122
+ # # Check if the file's time falls within any flare event
123
+ # is_flaring = any(start <= aia_time <= end for start, end in flaring_eve_list)
124
+ # if is_flaring:
125
+ # src = os.path.join(data_dir, file)
126
+ # dst = os.path.join(flares_event_dir, file)
127
+ #
128
+ # if train_range[0] <= aia_time <= train_range[1]:
129
+ # dst = os.path.join(flares_event_dir, "train")
130
+ # shutil.copy(src, dst)
131
+ # elif val_range[0] <= aia_time <= val_range[1]:
132
+ # dst = os.path.join(flares_event_dir, "val")
133
+ # shutil.copy(src, dst)
134
+ # elif test_range[0] <= aia_time <= test_range[1]:
135
+ # dst = os.path.join(flares_event_dir, "test")
136
+ # shutil.copy(src, dst)
137
+ # else:
138
+ # print(f"Skipping {file}: Time {aia_time} not in any defined range")
139
+ # continue
140
+ # print(f"Copied {file} to {dst}")
141
+ # else:
142
+ # print("Skipping non-flaring event file:", file)
143
+ # else:
144
+ # src = os.path.join(data_dir, file)
145
+ # dst = os.path.join(non_flares_event_dir, file)
146
+ # print(aia_time)
147
+ # print(train_range[0], train_range[1])
148
+ # if train_range[0] <= aia_time <= train_range[1]:
149
+ # split_dir = "train"
150
+ # elif val_range[0] <= aia_time <= val_range[1]:
151
+ # split_dir = "val"
152
+ # elif test_range[0] <= aia_time <= test_range[1]:
153
+ # split_dir = "test"
154
+ # dst = os.path.join(flares_event_dir, split_dir)
155
+ # # shutil.copy(src, dst)
156
+ # print(f"Copied {file} to {dst}")
157
+
158
+
159
+ >>>>>>> 22f4a17192a3a77fa4d4fe1ae3a2aa8c0bbdb539
160
 
161
  # Move file to appropriate split directory
162
  src = os.path.join(base_dir, file)
163
  dst = os.path.join(base_dir, split_dir, file)
164
  shutil.move(src, dst)
165
+ <<<<<<< HEAD
166
+ print(f"Moved {file} to {base_dir}/{split_dir}")
167
+ =======
168
+ print(f"Moved {file} to {base_dir}/{split_dir}")
169
+ >>>>>>> 22f4a17192a3a77fa4d4fe1ae3a2aa8c0bbdb539