griffingoodwin04 commited on
Commit
f02d855
·
1 Parent(s): b1cac57

push model changes

Browse files
flaring/MEGS_AI_baseline/base_model.py CHANGED
@@ -20,8 +20,8 @@ class BaseModel(LightningModule):
20
  def training_step(self, batch, batch_idx):
21
  (x, sxr), target = batch
22
  pred = self(x, sxr)
23
- pred = pred * self.eve_norm[1] + self.eve_norm[0] # Denormalize for loss
24
- target = target * self.eve_norm[1] + self.eve_norm[0] # Denormalize target
25
  loss = self.loss_func(pred, target)
26
  self.log('train_loss', loss)
27
  return loss
@@ -29,8 +29,8 @@ class BaseModel(LightningModule):
29
  def validation_step(self, batch, batch_idx):
30
  (x, sxr), target = batch
31
  pred = self(x, sxr)
32
- pred = pred * self.eve_norm[1] + self.eve_norm[0]
33
- target = target * self.eve_norm[1] + self.eve_norm[0]
34
  loss = self.loss_func(pred, target)
35
  self.log('valid_loss', loss)
36
  return loss
 
20
  def training_step(self, batch, batch_idx):
21
  (x, sxr), target = batch
22
  pred = self(x, sxr)
23
+ # pred = pred * self.eve_norm[1] + self.eve_norm[0] # Denormalize for loss
24
+ # target = target * self.eve_norm[1] + self.eve_norm[0] # Denormalize target
25
  loss = self.loss_func(pred, target)
26
  self.log('train_loss', loss)
27
  return loss
 
29
  def validation_step(self, batch, batch_idx):
30
  (x, sxr), target = batch
31
  pred = self(x, sxr)
32
+ # pred = pred * self.eve_norm[1] + self.eve_norm[0]
33
+ # target = target * self.eve_norm[1] + self.eve_norm[0]
34
  loss = self.loss_func(pred, target)
35
  self.log('valid_loss', loss)
36
  return loss
flaring/MEGS_AI_baseline/callback.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wandb
2
+ from pytorch_lightning import Callback
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+ import sunpy.visualization.colormaps as cm
8
+ import astropy.units as u
9
+
10
+ # Custom Callback
11
+ sdoaia94 = matplotlib.colormaps['sdoaia94']
12
+
13
+ def unnormalize(y, eve_norm):
14
+ eve_norm = torch.tensor(eve_norm).float()
15
+ norm_mean = eve_norm[0]
16
+ norm_stdev = eve_norm[1]
17
+ y = y * norm_stdev[None].to(y) + norm_mean[None].to(y)
18
+ return y
19
+
20
+
21
+ class ImagePredictionLogger_SXR(Callback):
22
+
23
+ def __init__(self, val_aia, val_sxr, sxr_norm, val_samples):
24
+ super().__init__()
25
+ self.val_aia = val_aia
26
+ self.val_sxr = val_sxr
27
+ self.sxr_norm = sxr_norm
28
+ self.val_samples = val_samples
29
+
30
+ def unnormalize_sxr(self, normalized_values):
31
+
32
+ print(normalized_values)
33
+ print(self.sxr_norm)
34
+ if isinstance(normalized_values, torch.Tensor):
35
+ normalized_values = normalized_values.cpu().numpy()
36
+ normalized_values = np.array(normalized_values, dtype=np.float32)
37
+ return 10 ** (normalized_values * float(self.sxr_norm[1].item()) + float(self.sxr_norm[0].item()))
38
+
39
+ def on_validation_epoch_end(self, trainer, pl_module):
40
+
41
+ aia_images = []
42
+ true_sxr = []
43
+ pred_sxr = []
44
+ # print(self.val_samples)
45
+ for (aia, _), target in self.val_samples:
46
+ #device = torch.device("cuda:0")
47
+ aia = aia.to(pl_module.device).unsqueeze(0)
48
+ # Get prediction
49
+
50
+ pred = pl_module(aia)
51
+ #pred = self.unnormalize_sxr(pred)
52
+ pred_sxr.append(pred.item())
53
+ aia_images.append(aia.squeeze(0).cpu().numpy())
54
+ true_sxr.append(target.item())
55
+
56
+ true_unorm = self.unnormalize_sxr(true_sxr)
57
+ pred_unnorm = self.unnormalize_sxr(pred_sxr)
58
+ print("Aia images:", aia_images)
59
+ print("Sxr images:", true_unorm)
60
+ print("Sxr images:", pred_unnorm)
61
+ fig = self.plot_aia_sxr(aia_images,true_unorm, pred_unnorm)
62
+ trainer.logger.experiment.log({"AIA 94Å Images and Soft X-ray flux plots": wandb.Image(fig)})
63
+ plt.close(fig)
64
+
65
+ def plot_aia_sxr(self, val_aia, val_sxr, pred_sxr):
66
+ num_samples = len(val_aia)
67
+ fig, axes = plt.subplots(num_samples, 2, figsize=(10, 10))
68
+
69
+
70
+
71
+ for i in range(num_samples):
72
+ #print("Aia images:", val_aia[i])
73
+ print(val_aia[i].shape)
74
+ axes[i, 0].imshow(val_aia[i][:, :, 0], cmap=sdoaia94, origin='lower')
75
+ axes[i, 0].set_title("AIA 94Å Index" + str(i))
76
+ axes[i, 1].scatter(i, val_sxr[i])
77
+ axes[i, 1].scatter(i, pred_sxr[i])
78
+ axes[i, 1].set_xlabel("Index")
79
+ axes[i, 1].set_ylabel("Soft x-ray flux [W/m2]")
80
+ axes[i, 1].set_yscale('log')
81
+
82
+ fig.tight_layout()
83
+ return fig
84
+
85
+
86
+ class ImagePredictionLogger(Callback):
87
+ def __init__(self, val_imgs, val_eve, names, aia_wavelengths):
88
+ super().__init__()
89
+ self.val_imgs, self.val_eve = val_imgs, val_eve
90
+ self.names = names
91
+ self.aia_wavelengths = aia_wavelengths
92
+
93
+ def on_validation_epoch_end(self, trainer, pl_module):
94
+ # Bring the tensors to CPU
95
+ val_imgs = self.val_imgs.to(device=pl_module.device)
96
+ # Get model prediction
97
+ # pred_eve = pl_module.forward(val_imgs).cpu().numpy()
98
+ pred_eve = pl_module.forward_unnormalize(val_imgs).cpu().numpy()
99
+ val_eve = unnormalize(self.val_eve, pl_module.eve_norm).numpy()
100
+ val_imgs = val_imgs.cpu().numpy()
101
+
102
+ # create matplotlib figure
103
+ fig = self.plot_aia_eve(val_imgs, val_eve, pred_eve)
104
+ # Log the images to wandb
105
+ trainer.logger.experiment.log({"AIA Images and EVE bar plots": wandb.Image(fig)})
106
+ plt.close(fig)
107
+
108
+ def plot_aia_eve(self, val_imgs, val_eve, pred_eve):
109
+ """
110
+ Function to plot a 4 channel AIA stack and the EVE barplots
111
+
112
+ Arguments:
113
+ ----------
114
+ val_imgs: numpy array
115
+ Stack with 4 image channels
116
+ val_eve: numpy array
117
+ Stack of ground-truth eve channels
118
+ pred_eve: numpy array
119
+ Stack of predicted eve channels
120
+ Returns:
121
+ --------
122
+ fig: matplotlib figure
123
+ figure with plots
124
+ """
125
+ samples = pred_eve.shape[0]
126
+ n_aia_wavelengths = len(self.aia_wavelengths)
127
+ wspace = 0.2
128
+ hspace = 0.125
129
+ dpi = 100
130
+
131
+ if n_aia_wavelengths < 3:
132
+ nrows = 1
133
+ ncols = n_aia_wavelengths
134
+ fig = plt.figure(figsize=(9 + 9 / 4 * n_aia_wavelengths, 3 * samples), dpi=dpi)
135
+ gs = fig.add_gridspec(samples, n_aia_wavelengths + 3, wspace=wspace, hspace=hspace)
136
+ elif n_aia_wavelengths < 5:
137
+ nrows = 2
138
+ ncols = 2
139
+ fig = plt.figure(figsize=(9 + 9 / 4 * 2, 6 * samples), dpi=dpi)
140
+ gs = fig.add_gridspec(2 * samples, 5, wspace=wspace, hspace=hspace)
141
+ elif n_aia_wavelengths < 7:
142
+ nrows = 2
143
+ ncols = 3
144
+ fig = plt.figure(figsize=(9 + 9 / 4 * 3, 6 * samples), dpi=dpi)
145
+ gs = fig.add_gridspec(2 * samples, 6, wspace=wspace, hspace=hspace)
146
+ else:
147
+ nrows = 2
148
+ ncols = 4
149
+ fig = plt.figure(figsize=(15, 5 * samples), dpi=dpi)
150
+ gs = fig.add_gridspec(2 * samples, 7, wspace=wspace, hspace=hspace)
151
+
152
+ cmaps_all = ['sdoaia94', 'sdoaia131', 'sdoaia171', 'sdoaia193', 'sdoaia211',
153
+ 'sdoaia304', 'sdoaia335', 'sdoaia1600', 'sdoaia1700']
154
+ cmaps = [cmaps_all[i] for i in self.aia_wavelengths]
155
+ n_plots = 0
156
+
157
+ for s in range(samples):
158
+ for i in range(nrows):
159
+ for j in range(ncols):
160
+ if n_plots < n_aia_wavelengths:
161
+ ax = fig.add_subplot(gs[s * nrows + i, j])
162
+ ax.imshow(val_imgs[s, i * ncols + j], cmap=plt.get_cmap(cmaps[i * ncols + j]), origin='lower')
163
+ ax.text(0.01, 0.99, cmaps[i * ncols + j], horizontalalignment='left', verticalalignment='top',
164
+ color='w', transform=ax.transAxes)
165
+ ax.set_axis_off()
166
+ n_plots += 1
167
+ n_plots = 0
168
+ # eve data
169
+ ax5 = fig.add_subplot(gs[s * nrows, ncols:])
170
+ if self.names is not None:
171
+ ax5.bar(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth')
172
+ ax5.bar(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], width=0.5, label='prediction', alpha=0.5)
173
+ ax5.set_xticks(np.arange(0, len(val_eve[s, :])))
174
+ ax5.set_xticklabels(self.names, rotation=45)
175
+ else:
176
+ ax5.plot(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth', alpha=0.5,
177
+ drawstyle='steps-mid')
178
+ ax5.plot(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], label='prediction', alpha=0.5,
179
+ drawstyle='steps-mid')
180
+ ax5.set_yscale('log')
181
+ ax5.legend()
182
+
183
+ ax6 = fig.add_subplot(gs[s * nrows + 1, ncols:])
184
+ if self.names is not None:
185
+ ax6.bar(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
186
+ label='relative error (%)')
187
+ ax6.set_xticks(np.arange(0, len(val_eve[s, :])))
188
+ ax6.set_xticklabels(self.names, rotation=45)
189
+ else:
190
+ ax6.plot(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
191
+ label='relative error (%)', alpha=0.5, drawstyle='steps-mid')
192
+ ax6.set_yscale('log')
193
+ ax6.legend()
194
+
195
+ fig.tight_layout()
196
+ return fig
197
+
198
+
199
+ class SpectrumPredictionLogger(ImagePredictionLogger):
200
+ def __init__(self, val_imgs, val_eve, names, aia_wavelengths):
201
+ super().__init__(val_imgs, val_eve, names, aia_wavelengths)
202
+
203
+ def plot_aia_eve(self, val_imgs, val_eve, pred_eve):
204
+ """
205
+ Function to plot a 4 channel AIA stack and the EVE barplots
206
+
207
+ Arguments:
208
+ ----------
209
+ val_imgs: numpy array
210
+ Stack with 4 image channels
211
+ val_eve: numpy array
212
+ Stack of ground-truth eve channels
213
+ pred_eve: numpy array
214
+ Stack of predicted eve channels
215
+ Returns:
216
+ --------
217
+ fig: matplotlib figure
218
+ figure with plots
219
+ """
220
+ samples = pred_eve.shape[0]
221
+ n_aia_wavelengths = len(self.aia_wavelengths)
222
+ wspace = 0.2
223
+ hspace = 0.125
224
+ dpi = 200
225
+
226
+ fig = plt.figure(figsize=(5, 5), dpi=dpi)
227
+ gs = fig.add_gridspec(2, 1, wspace=wspace, hspace=hspace)
228
+
229
+ # eve data
230
+ s = 0
231
+ ax5 = fig.add_subplot(gs[0, 0])
232
+ if self.names is not None:
233
+ ax5.bar(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth')
234
+ ax5.bar(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], width=0.5, label='prediction', alpha=0.5)
235
+ ax5.set_xticks(np.arange(0, len(val_eve[s, :])))
236
+ ax5.set_xticklabels(self.names, rotation=45)
237
+ else:
238
+ ax5.plot(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth', alpha=0.5,
239
+ drawstyle='steps-mid')
240
+ ax5.plot(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], label='prediction', alpha=0.5,
241
+ drawstyle='steps-mid')
242
+ ax5.set_yscale('log')
243
+ ax5.legend()
244
+
245
+ ax6 = fig.add_subplot(gs[1, 0])
246
+ if self.names is not None:
247
+ ax6.bar(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
248
+ label='relative error (%)')
249
+ ax6.set_xticks(np.arange(0, len(val_eve[s, :])))
250
+ ax6.set_xticklabels(self.names, rotation=45)
251
+ else:
252
+ ax6.plot(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
253
+ label='relative error (%)', alpha=0.5, drawstyle='steps-mid')
254
+ ax6.set_yscale('log')
255
+ ax6.legend()
256
+
257
+ fig.tight_layout()
258
+ return fig
flaring/MEGS_AI_baseline/config.yaml CHANGED
@@ -14,7 +14,7 @@
14
  epochs:
15
  - 10
16
  wandb:
17
- entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
18
  project: MEGS-AI flaring # Lowercase, no spaces
19
  job_type: training
20
  tags:
 
14
  epochs:
15
  - 10
16
  wandb:
17
+ entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
18
  project: MEGS-AI flaring # Lowercase, no spaces
19
  job_type: training
20
  tags:
flaring/MEGS_AI_baseline/sxr_normalization.py CHANGED
@@ -53,5 +53,5 @@ if __name__ == "__main__":
53
  # Update this path to your real data SXR directory
54
  sxr_dir = "/mnt/data/ML-Ready-Data-No-Intensity-Cut/GOES-18-SXR-B/" # Replace with actual path
55
  sxr_norm = compute_sxr_norm(sxr_dir)
56
- np.save("/home/jayantbiradar619/sxr_norm2.npy", sxr_norm)
57
- print(f"Saved SXR normalization to /home/jayantbiradar619/sxr_norm.npy")
 
53
  # Update this path to your real data SXR directory
54
  sxr_dir = "/mnt/data/ML-Ready-Data-No-Intensity-Cut/GOES-18-SXR-B/" # Replace with actual path
55
  sxr_norm = compute_sxr_norm(sxr_dir)
56
+ np.save("/mnt/data/ML-Ready-Data-No-Intensity-Cut/normalized_sxr.npy", sxr_norm)
57
+ print(f"Saved SXR normalization to /mnt/data/ML-Ready-Data-No-Intensity-Cut/normalized_sxr")
flaring/MEGS_AI_baseline/train.py CHANGED
@@ -14,22 +14,23 @@ from pytorch_lightning.callbacks import ModelCheckpoint, Callback
14
  from torch.nn import HuberLoss
15
  from SDOAIA_dataloader import AIA_GOESDataModule
16
  from linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
 
17
 
18
  # SXR Prediction Logger
19
- class SXRPredictionLogger(Callback):
20
- def __init__(self, val_samples):
21
- super().__init__()
22
- self.val_samples = val_samples
23
-
24
- def on_validation_epoch_end(self, trainer, pl_module):
25
- # val_samples is a list of ((aia, sxr), target)
26
- for (aia, sxr), target in self.val_samples:
27
- aia, sxr, target = aia.to(pl_module.device), sxr.to(pl_module.device), target.to(pl_module.device)
28
- pred = pl_module(aia.unsqueeze(0)) # Add batch dimension
29
- trainer.logger.experiment.log({
30
- "val_pred_sxr": pred.cpu().numpy(),
31
- "val_target_sxr": target.cpu().numpy()
32
- })
33
 
34
  # Compute SXR normalization
35
  def compute_sxr_norm(sxr_dir):
@@ -117,7 +118,10 @@ for parameter_set in combined_parameters:
117
  total_n_valid = len(data_loader.valid_ds)
118
  plot_data = [data_loader.valid_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
119
  plot_samples = plot_data # Keep as list of ((aia, sxr), target)
120
- sxr_callback = SXRPredictionLogger(plot_samples)
 
 
 
121
 
122
  # Checkpoint callback
123
  checkpoint_callback = ModelCheckpoint(
 
14
  from torch.nn import HuberLoss
15
  from SDOAIA_dataloader import AIA_GOESDataModule
16
  from linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
17
+ from callback import ImagePredictionLogger_SXR
18
 
19
  # SXR Prediction Logger
20
+ # class SXRPredictionLogger(Callback):
21
+ # def __init__(self, val_samples):
22
+ # super().__init__()
23
+ # self.val_samples = val_samples
24
+ #
25
+ # def on_validation_epoch_end(self, trainer, pl_module):
26
+ # # val_samples is a list of ((aia, sxr), target)
27
+ # for (aia, sxr), target in self.val_samples:
28
+ # aia, sxr, target = aia.to(pl_module.device), sxr.to(pl_module.device), target.to(pl_module.device)
29
+ # pred = pl_module(aia.unsqueeze(0)) # Add batch dimension
30
+ # trainer.logger.experiment.log({
31
+ # "val_pred_sxr": pred.cpu().numpy(),
32
+ # "val_target_sxr": target.cpu().numpy()
33
+ # })
34
 
35
  # Compute SXR normalization
36
  def compute_sxr_norm(sxr_dir):
 
118
  total_n_valid = len(data_loader.valid_ds)
119
  plot_data = [data_loader.valid_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
120
  plot_samples = plot_data # Keep as list of ((aia, sxr), target)
121
+ #sxr_callback = SXRPredictionLogger(plot_samples)
122
+
123
+ sxr_plot_callback = ImagePredictionLogger_SXR(plot_data[0][0], plot_data[0][1], sxr_norm, plot_samples)
124
+
125
 
126
  # Checkpoint callback
127
  checkpoint_callback = ModelCheckpoint(