griffingoodwin04 commited on
Commit
4b96a19
Β·
1 Parent(s): 64fe3c0

updated tree structure

Browse files
Files changed (41) hide show
  1. {flaring β†’ data}/__init__.py +0 -0
  2. {flaring/data β†’ data}/align_data.py +0 -0
  3. {flaring/data β†’ data}/euv_data_cleaning.py +0 -0
  4. {flaring/data β†’ data}/frame-to-movie.py +0 -0
  5. {flaring/data β†’ data}/iti_data_processing.py +0 -0
  6. {flaring/data β†’ data}/split_data.py +0 -0
  7. {flaring/data β†’ data}/sxr_data_processing.py +0 -0
  8. {flaring/data β†’ data}/visualize_euv.py +0 -0
  9. {flaring/data β†’ download}/__init__.py +0 -0
  10. {flaring/download β†’ download}/download_sdo.py +0 -0
  11. {flaring/download β†’ download}/download_solo.py +0 -0
  12. {flaring/download β†’ download}/download_stereo.py +0 -0
  13. {flaring/download β†’ download}/flare_download_processor.py +0 -0
  14. {flaring/download β†’ download}/flare_event_downloader.py +0 -0
  15. {flaring/download β†’ download}/sxr_downloader.py +0 -0
  16. flaring/utils/__init__.py +0 -0
  17. flaring/vision_transformers/__init__.py +0 -1
  18. flaring/vision_transformers/callback.py +0 -264
  19. {flaring/download β†’ forecasting}/__init__.py +0 -0
  20. {flaring/forecasting β†’ forecasting}/data_loaders/SDOAIA_dataloader.py +0 -0
  21. {flaring/forecasting β†’ forecasting/data_loaders}/__init__.py +0 -0
  22. {flaring/forecasting β†’ forecasting}/data_loaders/sxr_normalization.py +0 -0
  23. {flaring/forecasting/data_loaders β†’ forecasting/inference}/__init__.py +0 -0
  24. {flaring/forecasting β†’ forecasting}/inference/evaluation.py +0 -0
  25. {flaring/forecasting β†’ forecasting}/inference/inference.py +0 -0
  26. {flaring/forecasting β†’ forecasting}/inference/inference_config.yaml +0 -0
  27. {flaring/forecasting β†’ forecasting}/inference/inference_on_patch.py +5 -5
  28. {flaring/forecasting β†’ forecasting}/inference/inference_on_patch_config.yaml +0 -0
  29. {flaring/forecasting β†’ forecasting}/inference/plotting.py +0 -0
  30. {flaring/forecasting β†’ forecasting}/models/FastSpectralNet.py +0 -0
  31. {flaring/forecasting/inference β†’ forecasting/models}/__init__.py +0 -0
  32. {flaring/forecasting β†’ forecasting}/models/base_model.py +0 -0
  33. {flaring/forecasting β†’ forecasting}/models/linear_and_hybrid.py +0 -0
  34. {flaring/forecasting β†’ forecasting}/models/vision_transformer_custom.py +0 -0
  35. {flaring/forecasting β†’ forecasting}/models/vit_patch_model.py +0 -0
  36. {flaring/forecasting/models β†’ forecasting/training}/__init__.py +0 -0
  37. {flaring/forecasting β†’ forecasting}/training/callback.py +0 -0
  38. {flaring/forecasting β†’ forecasting}/training/config.yaml +0 -0
  39. {flaring/forecasting β†’ forecasting}/training/train.py +5 -5
  40. {flaring/forecasting/training β†’ utils}/__init__.py +0 -0
  41. {flaring/utils β†’ utils}/cut_off_aia.py +0 -0
{flaring β†’ data}/__init__.py RENAMED
File without changes
{flaring/data β†’ data}/align_data.py RENAMED
File without changes
{flaring/data β†’ data}/euv_data_cleaning.py RENAMED
File without changes
{flaring/data β†’ data}/frame-to-movie.py RENAMED
File without changes
{flaring/data β†’ data}/iti_data_processing.py RENAMED
File without changes
{flaring/data β†’ data}/split_data.py RENAMED
File without changes
{flaring/data β†’ data}/sxr_data_processing.py RENAMED
File without changes
{flaring/data β†’ data}/visualize_euv.py RENAMED
File without changes
{flaring/data β†’ download}/__init__.py RENAMED
File without changes
{flaring/download β†’ download}/download_sdo.py RENAMED
File without changes
{flaring/download β†’ download}/download_solo.py RENAMED
File without changes
{flaring/download β†’ download}/download_stereo.py RENAMED
File without changes
{flaring/download β†’ download}/flare_download_processor.py RENAMED
File without changes
{flaring/download β†’ download}/flare_event_downloader.py RENAMED
File without changes
{flaring/download β†’ download}/sxr_downloader.py RENAMED
File without changes
flaring/utils/__init__.py DELETED
File without changes
flaring/vision_transformers/__init__.py DELETED
@@ -1 +0,0 @@
1
- # Vision Transformers module
 
 
flaring/vision_transformers/callback.py DELETED
@@ -1,264 +0,0 @@
1
- import wandb
2
- from pytorch_lightning import Callback
3
- import matplotlib
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
- import torch
7
- import sunpy.visualization.colormaps as cm
8
- import astropy.units as u
9
-
10
- # Custom Callback
11
- sdoaia94 = matplotlib.colormaps['sdoaia94']
12
-
13
- def unnormalize(y, eve_norm):
14
- eve_norm = torch.tensor(eve_norm).float()
15
- norm_mean = eve_norm[0]
16
- norm_stdev = eve_norm[1]
17
- y = y * norm_stdev[None].to(y) + norm_mean[None].to(y)
18
- return y
19
-
20
-
21
- class ImagePredictionLogger_SXR(Callback):
22
-
23
- def __init__(self, data_samples, sxr_norm):
24
- super().__init__()
25
- self.data_samples = data_samples
26
- self.val_aia = data_samples[0][0]
27
- self.val_sxr = data_samples[1]
28
- self.sxr_norm = sxr_norm
29
-
30
- def unnormalize_sxr(self, normalized_values):
31
- if isinstance(normalized_values, torch.Tensor):
32
- normalized_values = normalized_values.cpu().numpy()
33
- normalized_values = np.array(normalized_values, dtype=np.float32)
34
- return 10 ** (normalized_values * float(self.sxr_norm[1].item()) + float(self.sxr_norm[0].item())) - 1e-8
35
-
36
- def on_validation_epoch_end(self, trainer, pl_module):
37
-
38
- aia_images = []
39
- true_sxr = []
40
- pred_sxr = []
41
- # print(self.val_samples)
42
- for (aia, _), target in self.data_samples:
43
- #device = torch.device("cuda:0")
44
- aia = aia.to(pl_module.device).unsqueeze(0)
45
- # Get prediction
46
-
47
- pred = pl_module(aia)
48
- #pred = self.unnormalize_sxr(pred)
49
- pred_sxr.append(pred.item())
50
- aia_images.append(aia.squeeze(0).cpu().numpy())
51
- true_sxr.append(target.item())
52
-
53
- true_unorm = self.unnormalize_sxr(true_sxr)
54
- pred_unnorm = self.unnormalize_sxr(pred_sxr)
55
- print("Aia images:", aia_images)
56
- print("Sxr images:", true_unorm)
57
- print("Sxr images:", pred_unnorm)
58
- fig = self.plot_aia_sxr(aia_images,true_unorm, pred_unnorm)
59
- trainer.logger.experiment.log({"AIA 94Γ… Images and Soft X-ray flux plots": wandb.Image(fig)})
60
- plt.close(fig)
61
-
62
- def plot_aia_sxr(self, val_aia, val_sxr, pred_sxr):
63
- num_samples = len(val_aia)
64
- fig, axes = plt.subplots(num_samples, 2, figsize=(10, 10))
65
-
66
-
67
-
68
- for i in range(num_samples):
69
- axes.scatter(i, val_sxr[i], label='Ground truth' if i == 0 else "", color='blue')
70
- axes.scatter(i, pred_sxr[i], label='Prediction' if i == 0 else "", color='orange')
71
- axes.set_xlabel("Index")
72
- axes.set_ylabel("Soft x-ray flux [W/m2]")
73
- axes.set_yscale('log')
74
- axes.legend()
75
-
76
- fig.tight_layout()
77
- return fig
78
-
79
- def plot_aia_sxr_difference(self, val_aia, val_sxr, pred_sxr):
80
- num_samples = len(val_aia)
81
- fig, axes = plt.subplots(1, 1, figsize=(5, 2))
82
- for i in range(num_samples):
83
- # print("Aia images:", val_aia[i])
84
- axes.scatter(i, val_sxr[i]-pred_sxr[i], label='Soft X-ray Flux Difference', color='blue')
85
- axes.set_xlabel("Index")
86
- axes.set_ylabel("Soft X-ray Flux Difference (True - Pred.) [W/m2]")
87
-
88
- fig.tight_layout()
89
- return fig
90
-
91
-
92
- class ImagePredictionLogger(Callback):
93
- def __init__(self, val_imgs, val_eve, names, aia_wavelengths):
94
- super().__init__()
95
- self.val_imgs, self.val_eve = val_imgs, val_eve
96
- self.names = names
97
- self.aia_wavelengths = aia_wavelengths
98
-
99
- def on_validation_epoch_end(self, trainer, pl_module):
100
- # Bring the tensors to CPU
101
- val_imgs = self.val_imgs.to(device=pl_module.device)
102
- # Get model prediction
103
- # pred_eve = pl_module.forward(val_imgs).cpu().numpy()
104
- pred_eve = pl_module.forward_unnormalize(val_imgs).cpu().numpy()
105
- val_eve = unnormalize(self.val_eve, pl_module.eve_norm).numpy()
106
- val_imgs = val_imgs.cpu().numpy()
107
-
108
- # create matplotlib figure
109
- fig = self.plot_aia_eve(val_imgs, val_eve, pred_eve)
110
- # Log the images to wandb
111
- trainer.logger.experiment.log({"AIA Images and EVE bar plots": wandb.Image(fig)})
112
- plt.close(fig)
113
-
114
- def plot_aia_eve(self, val_imgs, val_eve, pred_eve):
115
- """
116
- Function to plot a 4 channel AIA stack and the EVE barplots
117
-
118
- Arguments:
119
- ----------
120
- val_imgs: numpy array
121
- Stack with 4 image channels
122
- val_eve: numpy array
123
- Stack of ground-truth eve channels
124
- pred_eve: numpy array
125
- Stack of predicted eve channels
126
- Returns:
127
- --------
128
- fig: matplotlib figure
129
- figure with plots
130
- """
131
- samples = pred_eve.shape[0]
132
- n_aia_wavelengths = len(self.aia_wavelengths)
133
- wspace = 0.2
134
- hspace = 0.125
135
- dpi = 100
136
-
137
- if n_aia_wavelengths < 3:
138
- nrows = 1
139
- ncols = n_aia_wavelengths
140
- fig = plt.figure(figsize=(9 + 9 / 4 * n_aia_wavelengths, 3 * samples), dpi=dpi)
141
- gs = fig.add_gridspec(samples, n_aia_wavelengths + 3, wspace=wspace, hspace=hspace)
142
- elif n_aia_wavelengths < 5:
143
- nrows = 2
144
- ncols = 2
145
- fig = plt.figure(figsize=(9 + 9 / 4 * 2, 6 * samples), dpi=dpi)
146
- gs = fig.add_gridspec(2 * samples, 5, wspace=wspace, hspace=hspace)
147
- elif n_aia_wavelengths < 7:
148
- nrows = 2
149
- ncols = 3
150
- fig = plt.figure(figsize=(9 + 9 / 4 * 3, 6 * samples), dpi=dpi)
151
- gs = fig.add_gridspec(2 * samples, 6, wspace=wspace, hspace=hspace)
152
- else:
153
- nrows = 2
154
- ncols = 4
155
- fig = plt.figure(figsize=(15, 5 * samples), dpi=dpi)
156
- gs = fig.add_gridspec(2 * samples, 7, wspace=wspace, hspace=hspace)
157
-
158
- cmaps_all = ['sdoaia94', 'sdoaia131', 'sdoaia171', 'sdoaia193', 'sdoaia211',
159
- 'sdoaia304', 'sdoaia335', 'sdoaia1600', 'sdoaia1700']
160
- cmaps = [cmaps_all[i] for i in self.aia_wavelengths]
161
- n_plots = 0
162
-
163
- for s in range(samples):
164
- for i in range(nrows):
165
- for j in range(ncols):
166
- if n_plots < n_aia_wavelengths:
167
- ax = fig.add_subplot(gs[s * nrows + i, j])
168
- ax.imshow(val_imgs[s, i * ncols + j], cmap=plt.get_cmap(cmaps[i * ncols + j]), origin='lower')
169
- ax.text(0.01, 0.99, cmaps[i * ncols + j], horizontalalignment='left', verticalalignment='top',
170
- color='w', transform=ax.transAxes)
171
- ax.set_axis_off()
172
- n_plots += 1
173
- n_plots = 0
174
- # eve data
175
- ax5 = fig.add_subplot(gs[s * nrows, ncols:])
176
- if self.names is not None:
177
- ax5.bar(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth')
178
- ax5.bar(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], width=0.5, label='prediction', alpha=0.5)
179
- ax5.set_xticks(np.arange(0, len(val_eve[s, :])))
180
- ax5.set_xticklabels(self.names, rotation=45)
181
- else:
182
- ax5.plot(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth', alpha=0.5,
183
- drawstyle='steps-mid')
184
- ax5.plot(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], label='prediction', alpha=0.5,
185
- drawstyle='steps-mid')
186
- ax5.set_yscale('log')
187
- ax5.legend()
188
-
189
- ax6 = fig.add_subplot(gs[s * nrows + 1, ncols:])
190
- if self.names is not None:
191
- ax6.bar(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
192
- label='relative error (%)')
193
- ax6.set_xticks(np.arange(0, len(val_eve[s, :])))
194
- ax6.set_xticklabels(self.names, rotation=45)
195
- else:
196
- ax6.plot(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
197
- label='relative error (%)', alpha=0.5, drawstyle='steps-mid')
198
- ax6.set_yscale('log')
199
- ax6.legend()
200
-
201
- fig.tight_layout()
202
- return fig
203
-
204
-
205
- class SpectrumPredictionLogger(ImagePredictionLogger):
206
- def __init__(self, val_imgs, val_eve, names, aia_wavelengths):
207
- super().__init__(val_imgs, val_eve, names, aia_wavelengths)
208
-
209
- def plot_aia_eve(self, val_imgs, val_eve, pred_eve):
210
- """
211
- Function to plot a 4 channel AIA stack and the EVE barplots
212
-
213
- Arguments:
214
- ----------
215
- val_imgs: numpy array
216
- Stack with 4 image channels
217
- val_eve: numpy array
218
- Stack of ground-truth eve channels
219
- pred_eve: numpy array
220
- Stack of predicted eve channels
221
- Returns:
222
- --------
223
- fig: matplotlib figure
224
- figure with plots
225
- """
226
- samples = pred_eve.shape[0]
227
- n_aia_wavelengths = len(self.aia_wavelengths)
228
- wspace = 0.2
229
- hspace = 0.125
230
- dpi = 200
231
-
232
- fig = plt.figure(figsize=(5, 5), dpi=dpi)
233
- gs = fig.add_gridspec(2, 1, wspace=wspace, hspace=hspace)
234
-
235
- # eve data
236
- s = 0
237
- ax5 = fig.add_subplot(gs[0, 0])
238
- if self.names is not None:
239
- ax5.bar(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth')
240
- ax5.bar(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], width=0.5, label='prediction', alpha=0.5)
241
- ax5.set_xticks(np.arange(0, len(val_eve[s, :])))
242
- ax5.set_xticklabels(self.names, rotation=45)
243
- else:
244
- ax5.plot(np.arange(0, len(val_eve[s, :])), val_eve[s, :], label='ground truth', alpha=0.5,
245
- drawstyle='steps-mid')
246
- ax5.plot(np.arange(0, len(pred_eve[s, :])), pred_eve[s, :], label='prediction', alpha=0.5,
247
- drawstyle='steps-mid')
248
- ax5.set_yscale('log')
249
- ax5.legend()
250
-
251
- ax6 = fig.add_subplot(gs[1, 0])
252
- if self.names is not None:
253
- ax6.bar(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
254
- label='relative error (%)')
255
- ax6.set_xticks(np.arange(0, len(val_eve[s, :])))
256
- ax6.set_xticklabels(self.names, rotation=45)
257
- else:
258
- ax6.plot(np.arange(0, len(val_eve[s, :])), np.abs(pred_eve[s, :] - val_eve[s, :]) / val_eve[s, :] * 100,
259
- label='relative error (%)', alpha=0.5, drawstyle='steps-mid')
260
- ax6.set_yscale('log')
261
- ax6.legend()
262
-
263
- fig.tight_layout()
264
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{flaring/download β†’ forecasting}/__init__.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/data_loaders/SDOAIA_dataloader.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting/data_loaders}/__init__.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/data_loaders/sxr_normalization.py RENAMED
File without changes
{flaring/forecasting/data_loaders β†’ forecasting/inference}/__init__.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/inference/evaluation.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/inference/inference.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/inference/inference_config.yaml RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/inference/inference_on_patch.py RENAMED
@@ -12,11 +12,11 @@ import torch
12
  import numpy as np
13
  from torch.utils.checkpoint import checkpoint
14
  from torch.utils.data import DataLoader
15
- from flaring.forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataset
16
- import flaring.forecasting.models as models
17
- from flaring.forecasting.models.vit_patch_model import ViT
18
- from flaring.forecasting.models.linear_and_hybrid import HybridIrradianceModel # Add your hybrid model import
19
- from flaring.forecasting.training.callback import unnormalize_sxr
20
  import yaml
21
  import torch.nn.functional as F
22
  from concurrent.futures import ThreadPoolExecutor
 
12
  import numpy as np
13
  from torch.utils.checkpoint import checkpoint
14
  from torch.utils.data import DataLoader
15
+ from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataset
16
+ import forecasting.models as models
17
+ from forecasting.models.vit_patch_model import ViT
18
+ from forecasting.models.linear_and_hybrid import HybridIrradianceModel # Add your hybrid model import
19
+ from forecasting.training.callback import unnormalize_sxr
20
  import yaml
21
  import torch.nn.functional as F
22
  from concurrent.futures import ThreadPoolExecutor
{flaring/forecasting β†’ forecasting}/inference/inference_on_patch_config.yaml RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/inference/plotting.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/models/FastSpectralNet.py RENAMED
File without changes
{flaring/forecasting/inference β†’ forecasting/models}/__init__.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/models/base_model.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/models/linear_and_hybrid.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/models/vision_transformer_custom.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/models/vit_patch_model.py RENAMED
File without changes
{flaring/forecasting/models β†’ forecasting/training}/__init__.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/training/callback.py RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/training/config.yaml RENAMED
File without changes
{flaring/forecasting β†’ forecasting}/training/train.py RENAMED
@@ -18,14 +18,14 @@ import sys
18
  PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute()
19
  sys.path.insert(0, str(PROJECT_ROOT))
20
 
21
- from flaring.forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataModule
22
- from flaring.forecasting.models.vision_transformer_custom import ViT
23
- from flaring.forecasting.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
24
- from flaring.forecasting.models.vit_patch_model import ViT as ViTPatch
25
  from callback import ImagePredictionLogger_SXR, AttentionMapCallback
26
  from pytorch_lightning.callbacks import Callback
27
 
28
- from flaring.forecasting.models.FastSpectralNet import FastViTFlaringModel
29
 
30
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
31
  os.environ["NCCL_DEBUG"] = "WARN"
 
18
  PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute()
19
  sys.path.insert(0, str(PROJECT_ROOT))
20
 
21
+ from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataModule
22
+ from forecasting.models.vision_transformer_custom import ViT
23
+ from forecasting.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
24
+ from forecasting.models.vit_patch_model import ViT as ViTPatch
25
  from callback import ImagePredictionLogger_SXR, AttentionMapCallback
26
  from pytorch_lightning.callbacks import Callback
27
 
28
+ from forecasting.models.FastSpectralNet import FastViTFlaringModel
29
 
30
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
31
  os.environ["NCCL_DEBUG"] = "WARN"
{flaring/forecasting/training β†’ utils}/__init__.py RENAMED
File without changes
{flaring/utils β†’ utils}/cut_off_aia.py RENAMED
File without changes