griffingoodwin04 commited on
Commit
ac3e767
·
1 Parent(s): 858f1e5

bug fixes

Browse files
flaring/MEGS_AI_baseline/callback.py CHANGED
@@ -10,12 +10,12 @@ import astropy.units as u
10
  # Custom Callback
11
  sdoaia94 = matplotlib.colormaps['sdoaia94']
12
 
13
- def unnormalize(y, eve_norm):
14
- eve_norm = torch.tensor(eve_norm).float()
15
- norm_mean = eve_norm[0]
16
- norm_stdev = eve_norm[1]
17
- y = y * norm_stdev[None].to(y) + norm_mean[None].to(y)
18
- return y
19
 
20
 
21
  class ImagePredictionLogger_SXR(Callback):
@@ -27,12 +27,6 @@ class ImagePredictionLogger_SXR(Callback):
27
  self.val_sxr = data_samples[1]
28
  self.sxr_norm = sxr_norm
29
 
30
- def unnormalize_sxr(self, normalized_values):
31
- if isinstance(normalized_values, torch.Tensor):
32
- normalized_values = normalized_values.cpu().numpy()
33
- normalized_values = np.array(normalized_values, dtype=np.float32)
34
- return 10 ** (normalized_values * float(self.sxr_norm[1].item()) + float(self.sxr_norm[0].item())) - 1e-8
35
-
36
  def on_validation_epoch_end(self, trainer, pl_module):
37
 
38
  aia_images = []
@@ -50,8 +44,8 @@ class ImagePredictionLogger_SXR(Callback):
50
  aia_images.append(aia.squeeze(0).cpu().numpy())
51
  true_sxr.append(target.item())
52
 
53
- true_unorm = self.unnormalize_sxr(true_sxr)
54
- pred_unnorm = self.unnormalize_sxr(pred_sxr)
55
  fig1 = self.plot_aia_sxr(aia_images,true_unorm, pred_unnorm)
56
  trainer.logger.experiment.log({"Soft X-ray flux plots": wandb.Image(fig1)})
57
  plt.close(fig1)
@@ -85,178 +79,3 @@ class ImagePredictionLogger_SXR(Callback):
85
 
86
  fig.tight_layout()
87
  return fig
88
-
89
-
90
- class ImagePredictionLogger(Callback):
91
- def __init__(self, val_imgs, val_eve, names, aia_wavelengths):
92
- super().__init__()
93
- self.val_imgs, self.val_eve = val_imgs, val_eve
94
- self.names = names
95
- self.aia_wavelengths = aia_wavelengths
96
-
97
- def on_validation_epoch_end(self, trainer, pl_module):
98
- # Bring the tensors to CPU
99
- val_imgs = self.val_imgs.to(device=pl_module.device)
100
- # Get model prediction
101
- # pred_eve = pl_module.forward(val_imgs).cpu().numpy()
102
- pred_eve = pl_module.forward_unnormalize(val_imgs).cpu().numpy()
103
- val_eve = unnormalize(self.val_eve, pl_module.eve_norm).numpy()
104
- val_imgs = val_imgs.cpu().numpy()
105
-
106
- # create matplotlib figure
107
- fig = self.plot_aia_eve(val_imgs, val_eve, pred_eve)
108
- # Log the images to wandb
109
- trainer.logger.experiment.log({"AIA Images and EVE bar plots": wandb.Image(fig)})
110
- plt.close(fig)
111
-
112
- def plot_aia_eve(self, val_imgs, val_eve, pred_eve):
113
- """
114
- Function to plot a 4 channel AIA stack and the EVE barplots
115
-
116
- Arguments:
117
- ----------
118
- val_imgs: numpy array
119
- Stack with 4 image channels
120
- val_eve: numpy array
121
- Stack of ground-truth eve channels
122
- pred_eve: numpy array
123
- Stack of predicted eve channels
124
- Returns:
125
- --------
126
- fig: matplotlib figure
127
- figure with plots
128
- """
129
- samples = pred_eve.shape[0]
130
- n_aia_wavelengths = len(self.aia_wavelengths)
131
- wspace = 0.2
132
- hspace = 0.125
133
- dpi = 100
134
-
135
- if n_aia_wavelengths < 3:
136
- nrows = 1
137
- ncols = n_aia_wavelengths
138
- fig = plt.figure(figsize=(9 + 9 / 4 * n_aia_wavelengths, 3 * samples), dpi=dpi)
139
- gs = fig.add_gridspec(samples, n_aia_wavelengths + 3, wspace=wspace, hspace=hspace)
140
- elif n_aia_wavelengths < 5:
141
- nrows = 2
142
- ncols = 2
143
- fig = plt.figure(figsize=(9 + 9 / 4 * 2, 6 * samples), dpi=dpi)
144
- gs = fig.add_gridspec(2 * samples, 5, wspace=wspace, hspace=hspace)
145
- elif n_aia_wavelengths < 7:
146
- nrows = 2
147
- ncols = 3
148
- fig = plt.figure(figsize=(9 + 9 / 4 * 3, 6 * samples), dpi=dpi)
149
- gs = fig.add_gridspec(2 * samples, 6, wspace=wspace, hspace=hspace)
150
- else:
151
- nrows = 2
152
- ncols = 4
153
- fig = plt.figure(figsize=(15, 5 * samples), dpi=dpi)
154
- gs = fig.add_gridspec(2 * samples, 7, wspace=wspace, hspace=hspace)
155
-
156
- cmaps_all = ['sdoaia94', 'sdoaia131', 'sdoaia171', 'sdoaia193', 'sdoaia211',
157
- 'sdoaia304', 'sdoaia335', 'sdoaia1600', 'sdoaia1700']
158
- cmaps = [cmaps_all[i] for i in self.aia_wavelengths]
159
- n_plots = 0
160
-
161
- for s in range(samples):
162
- for i in range(nrows):
163
- for j in range(ncols):
164
- if n_plots < n_aia_wavelengths:
165
- ax = fig.add_subplot(gs[s * nrows + i, j])
166
- ax.imshow(val_imgs[s, i * ncols + j], cmap=plt.get_cmap(cmaps[i * ncols + j]), origin='lower')
167
- ax.text(0.01, 0.99, cmaps[i * ncols + j], horizontalalignment='left', verticalalignment='top',
168
- color='w', transform=ax.transAxes)
169
- ax.set_axis_off()
170
- n_plots += 1
171
- n_plots = 0
172
- # eve data
173
- ax5 = fig.add_subplot(gs[s * nrows, ncols:])
174
- if self.names is not None:
175
- ax5.bar(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth')
176
- ax5.bar(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], width=0.5, label='prediction', alpha=0.5)
177
- ax5.set_xticks(np.arange(0, len(val_eve[s, :])))
178
- ax5.set_xticklabels(self.names, rotation=45)
179
- else:
180
- ax5.plot(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth', alpha=0.5,
181
- drawstyle='steps-mid')
182
- ax5.plot(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], label='prediction', alpha=0.5,
183
- drawstyle='steps-mid')
184
- ax5.set_yscale('log')
185
- ax5.legend()
186
-
187
- ax6 = fig.add_subplot(gs[s * nrows + 1, ncols:])
188
- if self.names is not None:
189
- ax6.bar(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
190
- label='relative error (%)')
191
- ax6.set_xticks(np.arange(0, len(val_eve[s, :])))
192
- ax6.set_xticklabels(self.names, rotation=45)
193
- else:
194
- ax6.plot(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
195
- label='relative error (%)', alpha=0.5, drawstyle='steps-mid')
196
- ax6.set_yscale('log')
197
- ax6.legend()
198
-
199
- fig.tight_layout()
200
- return fig
201
-
202
-
203
- class SpectrumPredictionLogger(ImagePredictionLogger):
204
- def __init__(self, val_imgs, val_eve, names, aia_wavelengths):
205
- super().__init__(val_imgs, val_eve, names, aia_wavelengths)
206
-
207
- def plot_aia_eve(self, val_imgs, val_eve, pred_eve):
208
- """
209
- Function to plot a 4 channel AIA stack and the EVE barplots
210
-
211
- Arguments:
212
- ----------
213
- val_imgs: numpy array
214
- Stack with 4 image channels
215
- val_eve: numpy array
216
- Stack of ground-truth eve channels
217
- pred_eve: numpy array
218
- Stack of predicted eve channels
219
- Returns:
220
- --------
221
- fig: matplotlib figure
222
- figure with plots
223
- """
224
- samples = pred_eve.shape[0]
225
- n_aia_wavelengths = len(self.aia_wavelengths)
226
- wspace = 0.2
227
- hspace = 0.125
228
- dpi = 200
229
-
230
- fig = plt.figure(figsize=(5, 5), dpi=dpi)
231
- gs = fig.add_gridspec(2, 1, wspace=wspace, hspace=hspace)
232
-
233
- # eve data
234
- s = 0
235
- ax5 = fig.add_subplot(gs[0, 0])
236
- if self.names is not None:
237
- ax5.bar(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth')
238
- ax5.bar(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], width=0.5, label='prediction', alpha=0.5)
239
- ax5.set_xticks(np.arange(0, len(val_eve[s, :])))
240
- ax5.set_xticklabels(self.names, rotation=45)
241
- else:
242
- ax5.plot(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth', alpha=0.5,
243
- drawstyle='steps-mid')
244
- ax5.plot(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], label='prediction', alpha=0.5,
245
- drawstyle='steps-mid')
246
- ax5.set_yscale('log')
247
- ax5.legend()
248
-
249
- ax6 = fig.add_subplot(gs[1, 0])
250
- if self.names is not None:
251
- ax6.bar(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
252
- label='relative error (%)')
253
- ax6.set_xticks(np.arange(0, len(val_eve[s, :])))
254
- ax6.set_xticklabels(self.names, rotation=45)
255
- else:
256
- ax6.plot(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
257
- label='relative error (%)', alpha=0.5, drawstyle='steps-mid')
258
- ax6.set_yscale('log')
259
- ax6.legend()
260
-
261
- fig.tight_layout()
262
- return fig
 
10
  # Custom Callback
11
  sdoaia94 = matplotlib.colormaps['sdoaia94']
12
 
13
+
14
+ def unnormalize_sxr(normalized_values, sxr_norm):
15
+ if isinstance(normalized_values, torch.Tensor):
16
+ normalized_values = normalized_values.cpu().numpy()
17
+ normalized_values = np.array(normalized_values, dtype=np.float32)
18
+ return 10 ** (normalized_values * float(sxr_norm[1].item()) + float(sxr_norm[0].item())) - 1e-8
19
 
20
 
21
  class ImagePredictionLogger_SXR(Callback):
 
27
  self.val_sxr = data_samples[1]
28
  self.sxr_norm = sxr_norm
29
 
 
 
 
 
 
 
30
  def on_validation_epoch_end(self, trainer, pl_module):
31
 
32
  aia_images = []
 
44
  aia_images.append(aia.squeeze(0).cpu().numpy())
45
  true_sxr.append(target.item())
46
 
47
+ true_unorm = unnormalize_sxr(true_sxr,self.sxr_norm)
48
+ pred_unnorm = unnormalize_sxr(pred_sxr,self.sxr_norm)
49
  fig1 = self.plot_aia_sxr(aia_images,true_unorm, pred_unnorm)
50
  trainer.logger.experiment.log({"Soft X-ray flux plots": wandb.Image(fig1)})
51
  plt.close(fig1)
 
79
 
80
  fig.tight_layout()
81
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flaring/MEGS_AI_baseline/config.yaml CHANGED
@@ -11,7 +11,9 @@
11
  cnn_dp:
12
  - 0.75
13
  epochs:
14
- - 100
 
 
15
  wandb:
16
  entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
17
  project: MEGS-AI flaring # Lowercase, no spaces
 
11
  cnn_dp:
12
  - 0.75
13
  epochs:
14
+ - 1
15
+ save_dictionary:
16
+ -
17
  wandb:
18
  entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
19
  project: MEGS-AI flaring # Lowercase, no spaces
flaring/MEGS_AI_baseline/inference.py CHANGED
@@ -3,28 +3,41 @@ import torch
3
  import numpy as np
4
  from torch.utils.data import DataLoader
5
  from SDOAIA_dataloader import AIA_GOESDataset
 
 
6
 
 
7
  def predict_log_outputs(model, dataset, batch_size=8):
8
- """Generator yielding raw log-space model outputs"""
9
  model.eval()
10
  loader = DataLoader(dataset, batch_size=batch_size)
11
 
 
 
 
12
  with torch.no_grad():
13
  for batch in loader:
14
- # Handle different dataset formats
15
  if isinstance(batch, tuple) and len(batch) == 2:
16
- aia_imgs = batch[0][0] # Unpack ((aia, sxr), target)
 
17
  else:
18
- aia_imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
 
 
 
 
 
 
 
19
 
20
- aia_imgs = aia_imgs.to(next(model.parameters()).device)
21
- log_outputs = model(aia_imgs) # Get raw log-space outputs
22
  yield from log_outputs.cpu().numpy()
23
 
24
  def main():
25
  parser = argparse.ArgumentParser(description='Save raw log-space model outputs')
26
- parser.add_argument('--model', required=True, help='Path to trained model')
27
- parser.add_argument('--aia-dir', required=True, help='Directory of AIA images')
 
28
  parser.add_argument('--output', default='log_predictions.txt',
29
  help='Output file for log-space predictions')
30
  parser.add_argument('--batch-size', type=int, default=8,
@@ -33,14 +46,24 @@ def main():
33
  args = parser.parse_args()
34
 
35
  # Setup
36
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37
- model = torch.load(args.model, map_location=device).to(device)
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # Dataset without any output transformation
40
  dataset = AIA_GOESDataset(
41
  aia_dir=args.aia_dir,
42
- sxr_dir='', # No SXR files needed
43
- sxr_norm=None, # Skip any normalization
44
  transform=None # No input transforms
45
  )
46
 
@@ -48,7 +71,7 @@ def main():
48
  with open(args.output, 'w') as f:
49
  f.write("# Log-space SXR predictions (log10(W/m²))\n")
50
  for log_pred in predict_log_outputs(model, dataset, args.batch_size):
51
- f.write(f"{log_pred:.6f}\n") # Write with 6 decimal places
52
 
53
  print(f"Log-space predictions saved to {args.output}")
54
  print("These are raw model outputs in log10 space before any exponentiation")
 
3
  import numpy as np
4
  from torch.utils.data import DataLoader
5
  from SDOAIA_dataloader import AIA_GOESDataset
6
+ from models.linear_and_hybrid import HybridIrradianceModel
7
+ from callback import ImagePredictionLogger_SXR
8
 
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
  def predict_log_outputs(model, dataset, batch_size=8):
 
11
  model.eval()
12
  loader = DataLoader(dataset, batch_size=batch_size)
13
 
14
+ # Get device from model
15
+ device = next(model.parameters()).device
16
+
17
  with torch.no_grad():
18
  for batch in loader:
19
+ # Correct unpacking based on your data structure
20
  if isinstance(batch, tuple) and len(batch) == 2:
21
+ # batch = (inputs, targets) where inputs = [aia_imgs, sxr_imgs]
22
+ aia_imgs = batch[0][0] # Get aia_imgs from inputs
23
  else:
24
+ # Fallback for other formats
25
+ aia_imgs = batch[0][0] if isinstance(batch[0], list) else batch[0]
26
+
27
+ # Move to device (it's already a tensor)
28
+ aia_imgs = aia_imgs.to(device)
29
+
30
+ # Get model predictions
31
+ log_outputs = model(aia_imgs)
32
 
33
+ # Move to CPU and convert to numpy before yielding
 
34
  yield from log_outputs.cpu().numpy()
35
 
36
  def main():
37
  parser = argparse.ArgumentParser(description='Save raw log-space model outputs')
38
+ parser.add_argument('--ckpt_path', required=True, help='Path to model checkpoint')
39
+ parser.add_argument('--aia_dir', required=True, help='Directory of AIA images')
40
+ parser.add_argument('--sxr_dir', required=True, help='Directory of target SXR images')
41
  parser.add_argument('--output', default='log_predictions.txt',
42
  help='Output file for log-space predictions')
43
  parser.add_argument('--batch-size', type=int, default=8,
 
46
  args = parser.parse_args()
47
 
48
  # Setup
49
+ state = torch.load(args.ckpt_path, map_location=device, weights_only=False)
50
+ model = state['model']
51
+ model.to(device)
52
+ # Checkpoint stores the full pickled model under 'model'; state_dict-based loading kept below for reference
53
+
54
+ # model = HybridIrradianceModel(6,1)
55
+ # state_dict = checkpoint.get('state_dict', checkpoint)
56
+ #
57
+ # # Handle potential key mismatches (e.g., PyTorch Lightning prefixes)
58
+ # state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
59
+ # model.load_state_dict(state_dict, strict=False)
60
+
61
+
62
 
63
  # Dataset without any output transformation
64
  dataset = AIA_GOESDataset(
65
  aia_dir=args.aia_dir,
66
+ sxr_dir=args.sxr_dir, # directory of target SXR files (now required)
 
67
  transform=None # No input transforms
68
  )
69
 
 
71
  with open(args.output, 'w') as f:
72
  f.write("# Log-space SXR predictions (log10(W/m²))\n")
73
  for log_pred in predict_log_outputs(model, dataset, args.batch_size):
74
+ print(log_pred)
75
 
76
  print(f"Log-space predictions saved to {args.output}")
77
  print("These are raw model outputs in log10 space before any exponentiation")
flaring/MEGS_AI_baseline/models/linear_and_hybrid.py CHANGED
@@ -108,9 +108,6 @@ class HybridIrradianceModel(BaseModel):
108
  if isinstance(x, (list, tuple)):
109
  x = x[0]
110
 
111
- # Debug: Print input shape
112
- print(f"Input shape to HybridIrradianceModel.forward: {x.shape}")
113
-
114
  # Expect x shape: (batch_size, H, W, C)
115
  if len(x.shape) != 4:
116
  raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
 
108
  if isinstance(x, (list, tuple)):
109
  x = x[0]
110
 
 
 
 
111
  # Expect x shape: (batch_size, H, W, C)
112
  if len(x.shape) != 4:
113
  raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
flaring/MEGS_AI_baseline/train.py CHANGED
@@ -1,6 +1,8 @@
1
 
2
  import argparse
3
  import os
 
 
4
  import yaml
5
  import itertools
6
  import wandb
@@ -11,7 +13,7 @@ from pytorch_lightning.loggers import WandbLogger
11
  from pytorch_lightning.callbacks import ModelCheckpoint
12
  from torch.nn import MSELoss
13
  from SDOAIA_dataloader import AIA_GOESDataModule
14
- from flaring.MEGS_AI_baseline.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
15
  from callback import ImagePredictionLogger_SXR
16
 
17
  # Parser
@@ -115,7 +117,7 @@ for parameter_set in combined_parameters:
115
  default_root_dir=checkpoint_dir,
116
  accelerator="gpu" if torch.cuda.is_available() else "cpu",
117
  devices=1,
118
- max_epochs=run_config.get('epochs', 10),
119
  callbacks=[sxr_plot_callback, checkpoint_callback],
120
  logger=wandb_logger,
121
  log_every_n_steps=10
@@ -131,9 +133,15 @@ for parameter_set in combined_parameters:
131
  full_checkpoint_path = os.path.join(checkpoint_dir, f"{wb_name}_{n}.ckpt")
132
  torch.save(save_dictionary, full_checkpoint_path)
133
 
134
- # Test
135
- trainer.test(model, dataloaders=data_loader.test_dataloader())
136
-
 
 
 
 
 
 
137
  # Finalize
138
  wandb.finish()
139
  n += 1
 
1
 
2
  import argparse
3
  import os
4
+ from datetime import datetime
5
+
6
  import yaml
7
  import itertools
8
  import wandb
 
13
  from pytorch_lightning.callbacks import ModelCheckpoint
14
  from torch.nn import MSELoss
15
  from SDOAIA_dataloader import AIA_GOESDataModule
16
+ from models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
17
  from callback import ImagePredictionLogger_SXR
18
 
19
  # Parser
 
117
  default_root_dir=checkpoint_dir,
118
  accelerator="gpu" if torch.cuda.is_available() else "cpu",
119
  devices=1,
120
+ max_epochs=run_config['epochs'],
121
  callbacks=[sxr_plot_callback, checkpoint_callback],
122
  logger=wandb_logger,
123
  log_every_n_steps=10
 
133
  full_checkpoint_path = os.path.join(checkpoint_dir, f"{wb_name}_{n}.ckpt")
134
  torch.save(save_dictionary, full_checkpoint_path)
135
 
136
+ # Save final PyTorch checkpoint with model and state_dict
137
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
138
+ final_checkpoint_path = os.path.join(checkpoint_dir, f"{wb_name}-final-{timestamp}.pth")
139
+ torch.save({
140
+ 'model': model,
141
+ 'state_dict': model.state_dict()
142
+ }, final_checkpoint_path)
143
+ print(f"Saved final PyTorch checkpoint: {final_checkpoint_path}")
144
+ n += 1
145
  # Finalize
146
  wandb.finish()
147
  n += 1
flaring/outputs/outputs.txt ADDED
File without changes