griffingoodwin04 commited on
Commit
130c9b3
·
1 Parent(s): c14af1c

Made evaluation easier by adding an auto-evaluate script that generates the inference and evaluation configs on its own.

Browse files
data/iti_data_processing.py CHANGED
@@ -144,7 +144,7 @@ else:
144
  print(f"Processing {len(unprocessed_indices)} unprocessed samples")
145
 
146
  if unprocessed_indices:
147
- with Pool(processes=90) as pool:
148
  list(tqdm(pool.imap(save_sample, unprocessed_indices), total=len(unprocessed_indices)))
149
  print("AIA data processing completed.")
150
  else:
 
144
  print(f"Processing {len(unprocessed_indices)} unprocessed samples")
145
 
146
  if unprocessed_indices:
147
+ with Pool(processes=os.cpu_count()) as pool:
148
  list(tqdm(pool.imap(save_sample, unprocessed_indices), total=len(unprocessed_indices)))
149
  print("AIA data processing completed.")
150
  else:
forecasting/inference/auto_evaluate.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Automated Evaluation Script for Solar Flare Models
4
+
5
+ This script automatically generates inference and evaluation configs
6
+ and runs the complete evaluation pipeline based on a directory input.
7
+
8
+ Usage:
9
+ python auto_evaluate.py -checkpoint_dir /path/to/checkpoint/dir -model_name my_model
10
+ python auto_evaluate.py -checkpoint_path /path/to/checkpoint.pth -model_name my_model
11
+ """
12
+
13
+ import argparse
14
+ import os
15
+ import subprocess
16
+ import sys
17
+ import yaml
18
+ from pathlib import Path
19
+ from datetime import datetime
20
+ import glob
21
+
22
+ # Add project root to Python path
23
+ PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute()
24
+ sys.path.insert(0, str(PROJECT_ROOT))
25
+
26
+ def find_checkpoint_files(checkpoint_dir):
27
+ """Find checkpoint files in directory"""
28
+ patterns = ['*.pth', '*.ckpt', '*.pt']
29
+ checkpoints = []
30
+
31
+ for pattern in patterns:
32
+ checkpoints.extend(glob.glob(str(Path(checkpoint_dir) / pattern)))
33
+ checkpoints.extend(glob.glob(str(Path(checkpoint_dir) / '**' / pattern), recursive=True))
34
+
35
+ return sorted(checkpoints)
36
+
37
+ def detect_model_type(checkpoint_path):
38
+ """Detect model type from checkpoint filename or content"""
39
+ filename = Path(checkpoint_path).name.lower()
40
+
41
+ if 'local' in filename or 'localized' in filename:
42
+ return 'vitlocal'
43
+ elif 'patch' in filename:
44
+ return 'vitpatch'
45
+ elif 'fusion' in filename:
46
+ return 'fusion'
47
+ elif 'hybrid' in filename:
48
+ return 'hybrid'
49
+ elif 'linear' in filename:
50
+ return 'linear'
51
+ else:
52
+ # Default to vit for backward compatibility
53
+ return 'vit'
54
+
55
+ def create_inference_config(checkpoint_path, model_name, base_data_dir="/mnt/data/COMBINED"):
56
+ """Create inference config for checkpoint"""
57
+
58
+ # Detect model type
59
+ model_type = detect_model_type(checkpoint_path)
60
+
61
+ # Create output directory
62
+ output_dir = f"/mnt/data/batch_results/{model_name}"
63
+ os.makedirs(output_dir, exist_ok=True)
64
+ os.makedirs(f"{output_dir}/weights", exist_ok=True)
65
+
66
+ # Generate config
67
+ config = {
68
+ 'SolO': 'false',
69
+ 'Stereo': 'false',
70
+ 'base_data_dir': base_data_dir,
71
+ 'data': {
72
+ 'aia_dir': f"{base_data_dir}/AIA-SPLIT/",
73
+ 'checkpoint_path': checkpoint_path,
74
+ 'sxr_dir': f"{base_data_dir}/SXR-SPLIT/",
75
+ 'sxr_norm_path': f"{base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
76
+ },
77
+ 'model': model_type,
78
+ 'wavelengths': [94, 131, 171, 193, 211, 304],
79
+ 'mc': {
80
+ 'active': 'false',
81
+ 'runs': 5
82
+ },
83
+ 'model_params': {
84
+ 'batch_size': 16,
85
+ 'input_size': 512,
86
+ 'no_weights': False,
87
+ 'patch_size': 16
88
+ },
89
+ 'vit_custom': {
90
+ 'embed_dim': 512,
91
+ 'hidden_dim': 512,
92
+ 'num_channels': 6,
93
+ 'num_classes': 1,
94
+ 'patch_size': 16,
95
+ 'num_patches': 1024,
96
+ 'num_heads': 8,
97
+ 'num_layers': 6,
98
+ 'dropout': 0.1
99
+ },
100
+ 'megsai': {
101
+ 'cnn_model': 'updated',
102
+ 'cnn_dp': 0.2,
103
+ 'weight_decay': 1e-5,
104
+ 'cosine_restart_T0': 50,
105
+ 'cosine_restart_Tmult': 2,
106
+ 'cosine_eta_min': 1e-7
107
+ },
108
+ 'output_path': f"{output_dir}/{model_name}_predictions.csv",
109
+ 'weight_path': f"{output_dir}/weights"
110
+ }
111
+
112
+ # Add model-specific configs
113
+ if model_type == 'fusion':
114
+ config['fusion'] = {
115
+ 'scalar_branch': 'hybrid',
116
+ 'lr': 0.0001,
117
+ 'lambda_vit_to_target': 0.3,
118
+ 'lambda_scalar_to_target': 0.1,
119
+ 'learnable_gate': True,
120
+ 'gate_init_bias': 5.0,
121
+ 'scalar_kwargs': {
122
+ 'd_input': 6,
123
+ 'd_output': 1,
124
+ 'cnn_model': 'updated',
125
+ 'cnn_dp': 0.75
126
+ }
127
+ }
128
+
129
+ return config, output_dir
130
+
131
+ def create_evaluation_config(model_name, output_dir, base_data_dir="/mnt/data/COMBINED"):
132
+ """Create evaluation config"""
133
+
134
+ config = {
135
+ 'base_data_dir': base_data_dir,
136
+ 'output_base_dir': f"{base_data_dir}/solar_flare_comparison_results",
137
+ 'data': {
138
+ 'aia_dir': f"{base_data_dir}/AIA-SPLIT/test/",
139
+ 'weight_path': f"{output_dir}/weights"
140
+ },
141
+ 'model_predictions': {
142
+ 'main_model_csv': f"{output_dir}/{model_name}_predictions.csv",
143
+ 'baseline_csv': ''
144
+ },
145
+ 'evaluation': {
146
+ 'output_dir': output_dir,
147
+ 'sxr_cutoff': 1e-7
148
+ },
149
+ 'time_range': {
150
+ 'start_time': '2023-08-05T00:00:00',
151
+ 'end_time': '2023-08-07T23:59:00',
152
+ 'interval_minutes': 1
153
+ },
154
+ 'plotting': {
155
+ 'figure_size': [12, 8],
156
+ 'dpi': 300,
157
+ 'colormap': 'sdoaia171'
158
+ },
159
+ 'metrics': {
160
+ 'include_rmse': True,
161
+ 'include_mae': True,
162
+ 'include_r2': True,
163
+ 'include_correlation': True
164
+ }
165
+ }
166
+
167
+ return config
168
+
169
+ def run_inference(inference_config_path):
170
+ """Run inference with the generated config"""
171
+ print(f"Running inference with config: {inference_config_path}")
172
+
173
+ cmd = [
174
+ sys.executable,
175
+ str(PROJECT_ROOT / "forecasting/inference/inference.py"),
176
+ "-config", inference_config_path
177
+ ]
178
+
179
+ result = subprocess.run(cmd, capture_output=True, text=True)
180
+
181
+ if result.returncode != 0:
182
+ print(f"Error running inference: {result.stderr}")
183
+ return False
184
+
185
+ print("Inference completed successfully!")
186
+ return True
187
+
188
+ def run_evaluation(evaluation_config_path):
189
+ """Run evaluation with the generated config"""
190
+ print(f"Running evaluation with config: {evaluation_config_path}")
191
+
192
+ cmd = [
193
+ sys.executable,
194
+ str(PROJECT_ROOT / "forecasting/inference/evaluation.py"),
195
+ "-config", evaluation_config_path
196
+ ]
197
+
198
+ result = subprocess.run(cmd, capture_output=True, text=True)
199
+
200
+ if result.returncode != 0:
201
+ print(f"Error running evaluation: {result.stderr}")
202
+ return False
203
+
204
+ print("Evaluation completed successfully!")
205
+ return True
206
+
207
+ def main():
208
+ parser = argparse.ArgumentParser(description='Automated evaluation for solar flare models')
209
+ parser.add_argument('-checkpoint_dir', type=str, help='Directory containing checkpoint files')
210
+ parser.add_argument('-checkpoint_path', type=str, help='Specific checkpoint file path')
211
+ parser.add_argument('-model_name', type=str, required=True, help='Name for the model (used for output naming)')
212
+ parser.add_argument('-base_data_dir', type=str, default='/mnt/data/COMBINED', help='Base data directory')
213
+ parser.add_argument('-skip_inference', action='store_true', help='Skip inference and only run evaluation')
214
+ parser.add_argument('-skip_evaluation', action='store_true', help='Skip evaluation and only run inference')
215
+
216
+ args = parser.parse_args()
217
+
218
+ # Determine checkpoint path
219
+ if args.checkpoint_path:
220
+ checkpoint_path = args.checkpoint_path
221
+ if not os.path.exists(checkpoint_path):
222
+ print(f"Error: Checkpoint file not found: {checkpoint_path}")
223
+ sys.exit(1)
224
+ elif args.checkpoint_dir:
225
+ checkpoints = find_checkpoint_files(args.checkpoint_dir)
226
+ if not checkpoints:
227
+ print(f"Error: No checkpoint files found in {args.checkpoint_dir}")
228
+ sys.exit(1)
229
+ elif len(checkpoints) > 1:
230
+ print(f"Found multiple checkpoints: {checkpoints}")
231
+ print("Using the first one. Use -checkpoint_path to specify a specific file.")
232
+ checkpoint_path = checkpoints[0]
233
+ else:
234
+ print("Error: Must specify either -checkpoint_dir or -checkpoint_path")
235
+ sys.exit(1)
236
+
237
+ print(f"Using checkpoint: {checkpoint_path}")
238
+ print(f"Model name: {args.model_name}")
239
+
240
+ # Create configs
241
+ inference_config, output_dir = create_inference_config(checkpoint_path, args.model_name, args.base_data_dir)
242
+ evaluation_config = create_evaluation_config(args.model_name, output_dir, args.base_data_dir)
243
+
244
+ # Save configs
245
+ inference_config_path = f"/tmp/inference_config_{args.model_name}.yaml"
246
+ evaluation_config_path = f"/tmp/evaluation_config_{args.model_name}.yaml"
247
+
248
+ with open(inference_config_path, 'w') as f:
249
+ yaml.dump(inference_config, f, default_flow_style=False)
250
+
251
+ with open(evaluation_config_path, 'w') as f:
252
+ yaml.dump(evaluation_config, f, default_flow_style=False)
253
+
254
+ print(f"Configs saved to:")
255
+ print(f" Inference: {inference_config_path}")
256
+ print(f" Evaluation: {evaluation_config_path}")
257
+ print(f" Output directory: {output_dir}")
258
+
259
+ # Run inference
260
+ if not args.skip_inference:
261
+ if not run_inference(inference_config_path):
262
+ print("Inference failed. Stopping.")
263
+ sys.exit(1)
264
+ else:
265
+ print("Skipping inference...")
266
+
267
+ # Run evaluation
268
+ if not args.skip_evaluation:
269
+ if not run_evaluation(evaluation_config_path):
270
+ print("Evaluation failed. Stopping.")
271
+ sys.exit(1)
272
+ else:
273
+ print("Skipping evaluation...")
274
+
275
+ print(f"\n✅ Complete! Results saved to: {output_dir}")
276
+ print(f"📊 Check the plots and metrics in: {output_dir}")
277
+
278
+ if __name__ == '__main__':
279
+ main()
forecasting/inference/checkpoint_list.yaml CHANGED
@@ -8,10 +8,10 @@ checkpoints:
8
  # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-final-20250921_185953.pth"
9
  # - name: "baseweights-final"
10
  # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-changed-base-weights-final-20250921_223323.pth"
11
- - name: "claude-mse"
12
- checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-mse-claude-epoch=62-val_total_loss=0.1904.ckpt"
13
- - name: "baseweights-mse"
14
- checkpoint_path: /mnt/data/COMBINED/new-checkpoint/vit-mse-base-weights-epoch=62-val_total_loss=0.2893.ckpt"
15
  # - name: "stereo-final"
16
  # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-STEREO-final-20250921_183739.pth"
17
 
 
8
  # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-final-20250921_185953.pth"
9
  # - name: "baseweights-final"
10
  # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-changed-base-weights-final-20250921_223323.pth"
11
+ - name: "claude-localized"
12
+ checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-local-patch-mse-claude-final-20250929_050650.pth"
13
+ # - name: "baseweights-mse"
14
+ # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-mse-base-weights-epoch=62-val_total_loss=0.2893.ckpt"
15
  # - name: "stereo-final"
16
  # checkpoint_path: "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-STEREO-final-20250921_183739.pth"
17
 
forecasting/inference/evaluation.py CHANGED
@@ -622,7 +622,7 @@ class SolarFlareEvaluator:
622
  return None, None, None
623
 
624
  def generate_frame_worker(self, timestamp):
625
- """Worker function to generate a single frame with uncertainty bands"""
626
  try:
627
  print(f"Worker {os.getpid()}: Processing {timestamp}")
628
 
@@ -665,9 +665,7 @@ class SolarFlareEvaluator:
665
 
666
  ax.imshow(aia_img, cmap=cm.cmlist['sdoaia131'], origin='lower')
667
  ax.imshow(attention_data, cmap='hot', origin='lower', alpha=0.5,norm=att_norm)
668
- # Plot star at maximum attention location
669
- # ax.plot(max_x, max_y, marker='*', markersize=10, color='cyan',
670
- # markeredgecolor='white', markeredgewidth=1)
671
  ax.set_title(f'AIA {wavs[1]} Å', fontsize=12, fontfamily='Barlow', color='white')
672
  ax.axis('off')
673
 
@@ -687,67 +685,29 @@ class SolarFlareEvaluator:
687
  gt = sxr_window['groundtruth'].values
688
  uncertainties = sxr_window['groundtruth_uncertainty'].values
689
 
690
- # Create upper and lower bounds (assuming uncertainty is standard deviation)
691
- upper_bound = gt + uncertainties
692
- lower_bound = gt - uncertainties
693
-
694
  # Ensure bounds are positive for log scale
695
  lower_bound = np.maximum(lower_bound, 1e-12)
696
 
697
- #sxr_ax.fill_between(sxr_window['timestamp'], lower_bound, upper_bound,
698
- #alpha=0.3, color="#F78E69")
699
-
700
  # Plot model predictions with uncertainty bands
701
  model_label = 'Baseline Model' if self.baseline_only_mode else 'FOXES Model'
702
  model_color = "#94ECBE" if self.baseline_only_mode else "#C0B9DD"
703
- vit_prediction_line = sxr_ax.plot(sxr_window['timestamp'], sxr_window['predictions'],
704
  label=model_label, linewidth=2.5, alpha=1, markersize=5,
705
  color=model_color)
706
 
707
- # Add uncertainty bands for model if available
708
- if 'uncertainty' in sxr_window.columns and sxr_window['uncertainty'].notna().any():
709
- predictions = sxr_window['predictions'].values
710
- uncertainties = sxr_window['uncertainty'].values
711
-
712
- # Create upper and lower bounds (assuming uncertainty is standard deviation)
713
- upper_bound = predictions + uncertainties
714
- lower_bound = predictions - uncertainties
715
-
716
- # Ensure bounds are positive for log scale
717
- lower_bound = np.maximum(lower_bound, 1e-12)
718
-
719
- sxr_ax.fill_between(sxr_window['timestamp'], lower_bound, upper_bound,
720
- alpha=0.3, color=model_color)
721
-
722
- # Plot baseline predictions with uncertainty bands if available and not in baseline-only mode
723
  if not self.baseline_only_mode and 'baseline_predictions' in sxr_window.columns and sxr_window[
724
  'baseline_predictions'].notna().any():
725
  baseline_line = sxr_ax.plot(sxr_window['timestamp'], sxr_window['baseline_predictions'],
726
  label='Baseline Model', linewidth=1.5, alpha=1, markersize=5,
727
  color="#94ECBE")
728
 
729
- # Add uncertainty bands for baseline model if available
730
- if 'baseline_uncertainty' in sxr_window.columns and sxr_window[
731
- 'baseline_uncertainty'].notna().any():
732
- baseline_predictions = sxr_window['baseline_predictions'].values
733
- baseline_uncertainties = sxr_window['baseline_uncertainty'].values
734
-
735
- # Create upper and lower bounds
736
- baseline_upper = baseline_predictions + baseline_uncertainties
737
- baseline_lower = baseline_predictions - baseline_uncertainties
738
-
739
- # Ensure bounds are positive for log scale
740
- baseline_lower = np.maximum(baseline_lower, 1e-12)
741
-
742
- sxr_ax.fill_between(sxr_window['timestamp'], baseline_lower, baseline_upper,
743
- alpha=0.3, color="#94ECBE")
744
-
745
  # Mark current time
746
  if sxr_current is not None:
747
  sxr_ax.axvline(target_time, color='black', linestyle='--',
748
  linewidth=2, alpha=0.4, label='Current Time')
749
 
750
- # Create info text with all available values including uncertainties
751
  model_name = 'Baseline' if self.baseline_only_mode else 'FOXES'
752
  info_lines = ["Current Values:",
753
  f"Ground Truth: {sxr_current['groundtruth']:.2e}",
@@ -812,12 +772,6 @@ class SolarFlareEvaluator:
812
  transform=sxr_ax.transAxes, fontsize=12, fontfamily='Barlow',
813
  horizontalalignment='center', verticalalignment='center')
814
  sxr_ax.set_title('SXR Data Comparison with Uncertainties', fontsize=12, fontfamily='Barlow')
815
- #
816
- # for spine in sxr_ax.spines.values():
817
- # spine.set_color('white')
818
-
819
- #plt.suptitle(f'Timestamp: {timestamp}', fontsize=14)
820
- #plt.tight_layout()
821
  plt.savefig(save_path, dpi=500, facecolor='none',bbox_inches='tight')
822
  plt.close()
823
 
 
622
  return None, None, None
623
 
624
  def generate_frame_worker(self, timestamp):
625
+ """Worker function to generate a single frame"""
626
  try:
627
  print(f"Worker {os.getpid()}: Processing {timestamp}")
628
 
 
665
 
666
  ax.imshow(aia_img, cmap=cm.cmlist['sdoaia131'], origin='lower')
667
  ax.imshow(attention_data, cmap='hot', origin='lower', alpha=0.5,norm=att_norm)
668
+
 
 
669
  ax.set_title(f'AIA {wavs[1]} Å', fontsize=12, fontfamily='Barlow', color='white')
670
  ax.axis('off')
671
 
 
685
  gt = sxr_window['groundtruth'].values
686
  uncertainties = sxr_window['groundtruth_uncertainty'].values
687
 
 
 
 
 
688
  # Ensure bounds are positive for log scale
689
  lower_bound = np.maximum(lower_bound, 1e-12)
690
 
 
 
 
691
  # Plot model predictions with uncertainty bands
692
  model_label = 'Baseline Model' if self.baseline_only_mode else 'FOXES Model'
693
  model_color = "#94ECBE" if self.baseline_only_mode else "#C0B9DD"
694
+ sxr_ax.plot(sxr_window['timestamp'], sxr_window['predictions'],
695
  label=model_label, linewidth=2.5, alpha=1, markersize=5,
696
  color=model_color)
697
 
698
+ # Plot baseline predictions if available and not in baseline-only mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
699
  if not self.baseline_only_mode and 'baseline_predictions' in sxr_window.columns and sxr_window[
700
  'baseline_predictions'].notna().any():
701
  baseline_line = sxr_ax.plot(sxr_window['timestamp'], sxr_window['baseline_predictions'],
702
  label='Baseline Model', linewidth=1.5, alpha=1, markersize=5,
703
  color="#94ECBE")
704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
  # Mark current time
706
  if sxr_current is not None:
707
  sxr_ax.axvline(target_time, color='black', linestyle='--',
708
  linewidth=2, alpha=0.4, label='Current Time')
709
 
710
+ # Create info text with all available values
711
  model_name = 'Baseline' if self.baseline_only_mode else 'FOXES'
712
  info_lines = ["Current Values:",
713
  f"Ground Truth: {sxr_current['groundtruth']:.2e}",
 
772
  transform=sxr_ax.transAxes, fontsize=12, fontfamily='Barlow',
773
  horizontalalignment='center', verticalalignment='center')
774
  sxr_ax.set_title('SXR Data Comparison with Uncertainties', fontsize=12, fontfamily='Barlow')
 
 
 
 
 
 
775
  plt.savefig(save_path, dpi=500, facecolor='none',bbox_inches='tight')
776
  plt.close()
777
 
forecasting/inference/evaluation_config.yaml CHANGED
@@ -27,8 +27,8 @@ evaluation:
27
  # interval_minutes: 1
28
 
29
  time_range:
30
- start_time: "2014-08-01T00:00:00"
31
- end_time: "2014-08-31T23:59:00"
32
  interval_minutes: 1
33
 
34
  # Plotting parameters
 
27
  # interval_minutes: 1
28
 
29
  time_range:
30
+ start_time: "2023-08-05T00:00:00"
31
+ end_time: "2023-08-07T23:59:00"
32
  interval_minutes: 1
33
 
34
  # Plotting parameters
forecasting/inference/inference.py CHANGED
@@ -15,7 +15,9 @@ sys.path.insert(0, str(PROJECT_ROOT))
15
 
16
  from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataset
17
  import forecasting.models as models
18
- from forecasting.models.vision_transformer_custom import ViT
 
 
19
  from forecasting.models.linear_and_hybrid import HybridIrradianceModel, LinearIrradianceModel # Add your hybrid and linear model imports
20
  from torch.nn import HuberLoss
21
  from forecasting.training.callback import unnormalize_sxr
@@ -30,58 +32,11 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
 
31
  def has_attention_weights(model):
32
  """Check if model supports attention weights"""
33
- return hasattr(model, 'attention') or isinstance(model, ViT)
34
-
35
- #Does not return SXR data or use Dataloader for solo dataset
36
- def evaluate_solo_dataset(model, dataset, batch_size=16, times=None, config_data=None, save_weights=True, input_size = 512, patch_size = 16):
37
- """Optimized generator for SolO dataset without Dataloader"""
38
- model.eval()
39
- supports_attention = has_attention_weights(model) and save_weights
40
-
41
- with torch.no_grad():
42
- for batch_idx, batch in enumerate(dataset):
43
- # Correct unpacking based on your data structure
44
- aia_imgs = batch[0] # Get aia_img from inputs
45
- # Move to device (it's already a tensor)
46
- aia_imgs = aia_imgs.to(device, non_blocking=True)
47
-
48
- # Get model predictions for entire batch
49
- pred = model(aia_imgs)
50
-
51
- # Handle different model output formats
52
- if isinstance(pred, tuple) and len(pred) > 1:
53
- predictions = pred[0] # Shape: [batch_size, ...]
54
- weights = pred[1] if supports_attention else None # Shape: [batch_size, heads, L, S ...]
55
- else:
56
- predictions = pred
57
- weights = None
58
-
59
- # Process entire batch at once for weights if needed
60
- batch_weights = []
61
- if supports_attention and weights is not None:
62
- current_batch_size = predictions.shape[0]
63
- for i in range(current_batch_size):
64
- last_layer_attention = weights[-1][i] # Get i-th item from batch [num_heads, seq_len, seq_len]
65
- avg_attention = last_layer_attention.mean(dim=0) # [seq_len, seq_len]
66
-
67
- cls_attention = avg_attention[0, 1:].cpu() # [num_patches] - 1D array
68
-
69
- grid_h, grid_w = input_size // patch_size, input_size // patch_size # Should be 64, 64
70
-
71
- attention_map = cls_attention.reshape(grid_h, grid_w) # [64, 64]
72
-
73
- batch_weights.append(attention_map.numpy())
74
-
75
- if config_data and 'weight_path' in config_data:
76
- save_batch_weights(batch_weights, batch_idx, batch_size, times, config_data['weight_path'])
77
-
78
- current_batch_size = predictions.shape[0]
79
- for i in range(current_batch_size):
80
- global_idx = batch_idx * batch_size + i
81
- weight_data = batch_weights[i] if (supports_attention and batch_weights) else None
82
- yield (predictions[i].cpu().numpy(),
83
- weight_data, global_idx)
84
 
 
 
 
85
 
86
 
87
  def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_data=None, save_weights=True, input_size = 512, patch_size = 16):
@@ -101,7 +56,10 @@ def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_
101
  aia_imgs = aia_imgs.to(device, non_blocking=True)
102
 
103
  # Get model predictions for entire batch
104
- pred = model(aia_imgs)
 
 
 
105
 
106
  # Handle different model output formats
107
  if isinstance(pred, tuple) and len(pred) > 1:
@@ -115,23 +73,49 @@ def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_
115
  batch_weights = []
116
  if supports_attention and weights is not None:
117
  current_batch_size = predictions.shape[0]
 
 
118
  for i in range(current_batch_size):
119
- # Process attention weights for this item - matching callback approach
120
- #select last layer and appropriate item from batch
121
- last_layer_attention = weights[-1][i] # Get i-th item from batch [num_heads, seq_len, seq_len]
122
- # Average across attention heads
123
- avg_attention = last_layer_attention.mean(dim=0) # [seq_len, seq_len]
124
-
125
- # Get attention from CLS token to patches (exclude CLS->CLS)
126
- cls_attention = avg_attention[0, 1:].cpu() # [num_patches] - 1D array
127
-
128
- # Calculate grid size based on patch size (assuming 8x8 patches)
129
- grid_h, grid_w = input_size // patch_size, input_size // patch_size # Should be 64, 64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # Reshape CLS attention to spatial grid
132
- attention_map = cls_attention.reshape(grid_h, grid_w) # [64, 64]
133
 
134
- batch_weights.append(attention_map.numpy())
 
 
 
 
 
 
 
135
 
136
  # Save all weights in this batch at once
137
  if config_data and 'weight_path' in config_data:
@@ -145,98 +129,6 @@ def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_
145
  yield (predictions[i].cpu().numpy(), sxr[i].cpu().numpy(),
146
  weight_data, global_idx)
147
 
148
- #Evaluate model with batches using mc dropout
149
- def evaluate_model_on_dataset_mc_dropout(model, dataset, batch_size=16, times=None, config_data=None, save_weights=True,
150
- input_size=512, patch_size=16, runs=100, sxr_norm=None):
151
- """Streaming MC Dropout - processes each batch with multiple forward passes without loading all data"""
152
-
153
- loader = DataLoader(dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
154
- supports_attention = has_attention_weights(model) and save_weights
155
-
156
- print(f"Starting streaming MC Dropout with {runs} forward passes per batch...")
157
-
158
- for batch_idx, batch in enumerate(loader):
159
- aia_imgs = batch[0] # Shape: [batch_size, ...]
160
- sxr = batch[1]
161
- aia_imgs = aia_imgs.to(device, non_blocking=True)
162
- current_batch_size = aia_imgs.shape[0]
163
-
164
- if (batch_idx * batch_size) % 100 == 0:
165
- print(
166
- f"Processing batch {batch_idx + 1}, samples {batch_idx * batch_size + 1}-{batch_idx * batch_size + current_batch_size}")
167
-
168
- # Storage for this batch's MC predictions
169
- # Shape: [runs, batch_size, prediction_dims...]
170
- batch_predictions = []
171
- batch_weights = [] if supports_attention else None
172
-
173
- # Perform MC dropout runs for this batch
174
- for run in range(runs):
175
- #Set seed based on run
176
- torch.manual_seed(run) # Ensure different dropout masks for each run
177
-
178
- model.train() # Enable dropout
179
-
180
- with torch.no_grad():
181
- pred = model(aia_imgs)
182
-
183
- if isinstance(pred, tuple) and len(pred) > 1:
184
- predictions = pred[0] # [batch_size, ...]
185
- weights = pred[1] if supports_attention else None
186
- else:
187
- predictions = pred
188
- weights = None
189
-
190
- # Store predictions for this run
191
- batch_predictions.append(predictions.cpu().numpy())
192
-
193
- # Process attention weights for this run
194
- if supports_attention and weights is not None:
195
- run_weights = []
196
- for i in range(current_batch_size):
197
- last_layer_attention = weights[-1][i] # [num_heads, seq_len, seq_len]
198
- avg_attention = last_layer_attention.mean(dim=0) # [seq_len, seq_len]
199
- cls_attention = avg_attention[0, 1:].cpu() # [num_patches]
200
-
201
- grid_h, grid_w = input_size // patch_size, input_size // patch_size
202
- attention_map = cls_attention.reshape(grid_h, grid_w)
203
- run_weights.append(attention_map.numpy())
204
-
205
- if batch_weights is None:
206
- batch_weights = []
207
- batch_weights.append(run_weights) # [runs, batch_size, grid_h, grid_w]
208
-
209
- # Convert to numpy and compute statistics
210
- # batch_predictions: [runs, batch_size, prediction_dims...]
211
- batch_predictions = np.array(batch_predictions)
212
-
213
- # Compute mean and std across runs (axis=0)
214
- # Result shapes: [batch_size, prediction_dims...]
215
- mean_predictions = np.mean(unnormalize_sxr(batch_predictions,sxr_norm=sxr_norm), axis=0)
216
- uncertainties = np.std(unnormalize_sxr(batch_predictions,sxr_norm=sxr_norm), axis=0)
217
-
218
- # Process attention weights if available
219
- mean_weights = None
220
- if supports_attention and batch_weights:
221
- # batch_weights: [runs, batch_size, grid_h, grid_w]
222
- batch_weights = np.array(batch_weights)
223
- # mean_weights: [batch_size, grid_h, grid_w]
224
- mean_weights = np.mean(batch_weights, axis=0)
225
-
226
- # Save weights if required
227
- if config_data and 'weight_path' in config_data:
228
- save_batch_weights(list(mean_weights), batch_idx, batch_size, times, config_data['weight_path'])
229
-
230
- # Yield results for each sample in the batch
231
- for i in range(current_batch_size):
232
- global_idx = batch_idx * batch_size + i
233
- weight_data = mean_weights[i] if mean_weights is not None else None
234
-
235
- yield (mean_predictions[i], # Mean prediction across MC runs
236
- sxr[i].cpu().numpy(), # Ground truth
237
- uncertainties[i], # Uncertainty (std) across MC runs
238
- weight_data, # Mean attention weights
239
- global_idx) # Sample index
240
 
241
  def save_batch_weights(batch_weights, batch_idx, batch_size, times, weight_path):
242
  """Save all weights in a batch efficiently"""
@@ -252,8 +144,9 @@ def save_batch_weights(batch_weights, batch_idx, batch_size, times, weight_path)
252
  save_args = []
253
  for i, weight in enumerate(batch_weights):
254
  global_idx = batch_idx * batch_size + i
255
- if global_idx < len(times): # Make sure we don't go out of bounds
256
- filepath = weight_path + f"{times[global_idx]}"
 
257
  save_args.append((weight, filepath))
258
 
259
  # Save all weights in this batch in parallel
@@ -283,7 +176,11 @@ def load_model_from_config(config_data):
283
  if ".ckpt" in checkpoint_path:
284
  # Lightning checkpoint format
285
  if model_type.lower() == 'vit':
286
- model = ViT.load_from_checkpoint(checkpoint_path)
 
 
 
 
287
  elif model_type.lower() == 'hybrid' or model_type.lower() == 'hybridirradiancemodel':
288
  # Try to load with saved hyperparameters first, then fall back to config parameters
289
  try:
@@ -427,113 +324,51 @@ def main():
427
 
428
  print(f"Processing {total_samples} samples with batch size {batch_size}...")
429
 
430
- if config_data['mc']['active'] == "false":
431
- print("Running inference without MC Dropout")
432
- for prediction, sxr, weight, idx in evaluate_model_on_dataset(
433
- model, dataset, batch_size, times, config_data, save_weights, input_size, patch_size
434
- ):
435
- # Unnormalize prediction
436
  pred = unnormalize_sxr(prediction, sxr_norm)
 
 
437
 
438
- # Store results
439
- predictions.append(pred.item() if hasattr(pred, 'item') else float(pred))
440
- ground.append(sxr.item() if hasattr(sxr, 'item') else float(sxr))
441
- timestamp.append(str(times[idx]))
442
 
443
- # Progress update
444
- if (idx + 1) % 50 == 0:
445
- print(f"Processed {idx + 1}/{total_samples}")
446
 
447
- if save_weights:
448
- print("All weights saved during batch processing!")
449
- else:
450
- print("Inference completed (no weights saved)!")
451
-
452
- # Create and save results DataFrame
453
- print("Creating output DataFrame...")
454
- output_df = pd.DataFrame({
455
- 'timestamp': timestamp,
456
- 'predictions': predictions,
457
- 'groundtruth': ground
458
- })
459
-
460
- print(output_df.head())
461
- #Make output directory if it doesn't exist
462
- output_dir = Path(config_data['output_path']).parent
463
- output_dir.mkdir(parents=True, exist_ok=True)
464
- output_df.to_csv(config_data['output_path'], index=False)
465
- print(f"Predictions saved to {config_data['output_path']}")
466
  else:
467
- print("Running inference with MC Dropout")
468
- if config_data['mc']['active'] == "false":
469
- print("Running inference without MC Dropout")
470
- for prediction, sxr, weight, idx in evaluate_model_on_dataset(
471
- model, dataset, batch_size, times, config_data, save_weights, input_size, patch_size
472
- ):
473
- # Unnormalize prediction
474
- pred = unnormalize_sxr(prediction, sxr_norm)
475
-
476
- # Store results
477
- predictions.append(pred.item() if hasattr(pred, 'item') else float(pred))
478
- ground.append(sxr.item() if hasattr(sxr, 'item') else float(sxr))
479
- timestamp.append(str(times[idx]))
480
-
481
- # Progress update
482
- if (idx + 1) % 50 == 0:
483
- print(f"Processed {idx + 1}/{total_samples}")
484
-
485
- # Create and save results DataFrame
486
- print("Creating output DataFrame...")
487
- output_df = pd.DataFrame({
488
- 'timestamp': timestamp,
489
- 'predictions': predictions,
490
- 'groundtruth': ground
491
- })
492
-
493
- else:
494
- #print("Running inference with MC Dropout")
495
- uncertainties = [] # Add this to store uncertainties
496
- mc_runs = config_data['mc']['runs'] # Allow configurable MC runs
497
-
498
- # Choose between batch processing or single-sample processing
499
- # Use single-sample for very large datasets or memory constraints
500
-
501
- print(f"Using batch MC Dropout with {mc_runs} runs per batch")
502
- mc_generator = evaluate_model_on_dataset_mc_dropout(
503
- model, dataset, batch_size, times, config_data, save_weights,
504
- input_size, patch_size, runs=mc_runs, sxr_norm=sxr_norm
505
- )
506
-
507
- for prediction, sxr, uncertainty, weight, idx in mc_generator:
508
- # Unnormalize prediction and uncertainty
509
- #pred = unnormalize_sxr(prediction, sxr_norm)
510
- #unc = unnormalize_sxr(uncertainty, sxr_norm)
511
-
512
- # Store results
513
- predictions.append(prediction.item() if hasattr(prediction, 'item') else float(prediction))
514
- ground.append(sxr.item() if hasattr(sxr, 'item') else float(sxr))
515
- uncertainties.append(uncertainty.item() if hasattr(uncertainty, 'item') else float(uncertainty))
516
- timestamp.append(str(times[idx]))
517
-
518
- # Progress update
519
- if (idx + 1) % 50 == 0:
520
- print(f"Processed {idx + 1}/{total_samples}")
521
-
522
- # Create and save results DataFrame with uncertainty
523
- print("Creating output DataFrame with uncertainty...")
524
- output_df = pd.DataFrame({
525
- 'timestamp': timestamp,
526
- 'predictions': predictions,
527
- 'groundtruth': ground,
528
- 'uncertainty': uncertainties # Add uncertainty column
529
- })
530
-
531
- print(output_df.head())
532
- # Make output directory if it doesn't exist
533
- output_dir = Path(config_data['output_path']).parent
534
- output_dir.mkdir(parents=True, exist_ok=True)
535
- output_df.to_csv(config_data['output_path'], index=False)
536
- print(f"Predictions saved to {config_data['output_path']}")
537
 
538
 
539
  if __name__ == '__main__':
 
15
 
16
  from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataset
17
  import forecasting.models as models
18
+ from forecasting.models.vision_transformer_custom import ViT as ViTCustom
19
+ from forecasting.models.vit_patch_model import ViT as ViTPatch
20
+ from forecasting.models.vit_patch_model_local import ViTLocal
21
  from forecasting.models.linear_and_hybrid import HybridIrradianceModel, LinearIrradianceModel # Add your hybrid and linear model imports
22
  from torch.nn import HuberLoss
23
  from forecasting.training.callback import unnormalize_sxr
 
32
 
33
  def has_attention_weights(model):
34
  """Check if model supports attention weights"""
35
+ return hasattr(model, 'attention') or isinstance(model, ViTCustom) or isinstance(model, ViTPatch) or isinstance(model, ViTLocal)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ def is_localized_attention_model(model):
38
+ """Check if model uses localized attention (no CLS token)"""
39
+ return isinstance(model, ViTLocal)
40
 
41
 
42
  def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_data=None, save_weights=True, input_size = 512, patch_size = 16):
 
56
  aia_imgs = aia_imgs.to(device, non_blocking=True)
57
 
58
  # Get model predictions for entire batch
59
+ if supports_attention:
60
+ pred = model(aia_imgs, return_attention=True)
61
+ else:
62
+ pred = model(aia_imgs)
63
 
64
  # Handle different model output formats
65
  if isinstance(pred, tuple) and len(pred) > 1:
 
73
  batch_weights = []
74
  if supports_attention and weights is not None:
75
  current_batch_size = predictions.shape[0]
76
+ is_localized = is_localized_attention_model(model)
77
+
78
  for i in range(current_batch_size):
79
+ try:
80
+ # Process attention weights for this item
81
+ last_layer_attention = weights[-1][i] # Get i-th item from batch [num_heads, seq_len, seq_len]
82
+
83
+ # Check for None or invalid values
84
+ if last_layer_attention is None:
85
+ print(f"Warning: last_layer_attention is None for sample {i}")
86
+ continue
87
+
88
+ # Average across attention heads
89
+ avg_attention = last_layer_attention.mean(dim=0) # [seq_len, seq_len]
90
+
91
+ # Check for NaN or invalid values
92
+ if torch.isnan(avg_attention).any():
93
+ print(f"Warning: NaN values in avg_attention for sample {i}")
94
+ continue
95
+
96
+ if is_localized:
97
+ # For ViTLocal (no CLS token), create attention map by averaging attention TO each patch
98
+ # This gives us how much each patch is "attended to" by its neighbors
99
+ patch_attention = avg_attention.mean(dim=0).cpu() # [num_patches] - average attention received by each patch
100
+ else:
101
+ # For regular ViT (with CLS token), get attention from CLS token to patches
102
+ cls_attention = avg_attention[0, 1:].cpu() # [num_patches] - CLS token attention to patches
103
+ patch_attention = cls_attention
104
+
105
+ # Calculate grid size based on patch size
106
+ grid_h, grid_w = input_size // patch_size, input_size // patch_size
107
 
108
+ # Reshape patch attention to spatial grid
109
+ attention_map = patch_attention.reshape(grid_h, grid_w)
110
 
111
+ batch_weights.append(attention_map.numpy())
112
+
113
+ except Exception as e:
114
+ print(f"Error processing attention weights for sample {i}: {e}")
115
+ # Add a zero attention map as fallback
116
+ grid_h, grid_w = input_size // patch_size, input_size // patch_size
117
+ fallback_map = torch.zeros(grid_h * grid_w).reshape(grid_h, grid_w).numpy()
118
+ batch_weights.append(fallback_map)
119
 
120
  # Save all weights in this batch at once
121
  if config_data and 'weight_path' in config_data:
 
129
  yield (predictions[i].cpu().numpy(), sxr[i].cpu().numpy(),
130
  weight_data, global_idx)
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  def save_batch_weights(batch_weights, batch_idx, batch_size, times, weight_path):
134
  """Save all weights in a batch efficiently"""
 
144
  save_args = []
145
  for i, weight in enumerate(batch_weights):
146
  global_idx = batch_idx * batch_size + i
147
+ if global_idx < len(times):# Make sure we don't go out of bounds
148
+ #Save to weight path using os join
149
+ filepath = os.path.join(weight_path, f"{times[global_idx]}")
150
  save_args.append((weight, filepath))
151
 
152
  # Save all weights in this batch in parallel
 
176
  if ".ckpt" in checkpoint_path:
177
  # Lightning checkpoint format
178
  if model_type.lower() == 'vit':
179
+ model = ViTCustom.load_from_checkpoint(checkpoint_path)
180
+ elif model_type.lower() == 'vitpatch':
181
+ model = ViTPatch.load_from_checkpoint(checkpoint_path)
182
+ elif model_type.lower() == 'vitlocal':
183
+ model = ViTLocal.load_from_checkpoint(checkpoint_path)
184
  elif model_type.lower() == 'hybrid' or model_type.lower() == 'hybridirradiancemodel':
185
  # Try to load with saved hyperparameters first, then fall back to config parameters
186
  try:
 
324
 
325
  print(f"Processing {total_samples} samples with batch size {batch_size}...")
326
 
327
+ print("Running inference...")
328
+ for prediction, sxr, weight, idx in evaluate_model_on_dataset(
329
+ model, dataset, batch_size, times, config_data, save_weights, input_size, patch_size
330
+ ):
331
+ # Unnormalize prediction only if not ViTPatch / ViTLocal
332
+ if not isinstance(model, ViTPatch) and not isinstance(model, ViTLocal):
333
  pred = unnormalize_sxr(prediction, sxr_norm)
334
+ else:
335
+ pred = prediction
336
 
337
+ # Store results
338
+ predictions.append(pred.item() if hasattr(pred, 'item') else float(pred))
339
+ ground.append(sxr.item() if hasattr(sxr, 'item') else float(sxr))
340
+ timestamp.append(str(times[idx]))
341
 
342
+ # Progress update
343
+ if (idx + 1) % 50 == 0:
344
+ print(f"Processed {idx + 1}/{total_samples}")
345
 
346
+ if save_weights:
347
+ print("All weights saved during batch processing!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  else:
349
+ print("Inference completed (no weights saved)!")
350
+
351
+ # Create and save results DataFrame
352
+ print("Creating output DataFrame...")
353
+ output_df = pd.DataFrame({
354
+ 'timestamp': timestamp,
355
+ 'predictions': predictions,
356
+ 'groundtruth': ground
357
+ })
358
+
359
+ print(output_df.head())
360
+ #Make output directory if it doesn't exist
361
+ output_dir = Path(config_data['output_path']).parent
362
+ output_dir.mkdir(parents=True, exist_ok=True)
363
+ output_df.to_csv(config_data['output_path'], index=False)
364
+ print(f"Predictions saved to {config_data['output_path']}")
365
+
366
+ print(output_df.head())
367
+ # Make output directory if it doesn't exist
368
+ output_dir = Path(config_data['output_path']).parent
369
+ output_dir.mkdir(parents=True, exist_ok=True)
370
+ output_df.to_csv(config_data['output_path'], index=False)
371
+ print(f"Predictions saved to {config_data['output_path']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
 
374
  if __name__ == '__main__':
forecasting/inference/inference_config.yaml DELETED
@@ -1,45 +0,0 @@
1
- # Base directories - change these to switch datasets
2
- base_data_dir: "/mnt/data/ML-READY/" # Change this line for different datasets
3
- output_path: "${base_data_dir}/output/baseline-model-more-complex-STEREO.csv"
4
- weight_path: "${base_data_dir}/baseline-model/"
5
- mc:
6
- active: "false"
7
- runs: 5
8
- # Enable or disable MC Dropout
9
- # Data paths (automatically constructed from base directories)
10
- Stereo: "false"
11
- Stereo_data:
12
- stereo_img_dir: "/mnt/data/ML-Ready-mixed/STEREO_processed"
13
- sxr_dir: "/mnt/data/ML-Ready-mixed/ML-Ready-mixed/SXR"
14
- sxr_norm_path: "/mnt/data/ML-READY/SXR/normalized_sxr.npy"
15
- SolO: "false"
16
- SolO_data:
17
- solo_img_dir: "/mnt/data/ML-Ready_clean/SolO/SolO/ML-Ready-SolO"
18
- sxr_dir: "${base_data_dir}/SXR"
19
- sxr_norm_path: "${base_data_dir}/SolO/SXR/normalized_sxr.npy"
20
-
21
- model: "hybrid" # Options: "vit", "hybrid"
22
- wavelengths: [171, 193, 211, 304] # AIA wavelengths in Angstroms
23
-
24
- # Model parameters
25
- model_params:
26
- input_size: 512
27
- patch_size: 16
28
- batch_size: 100
29
- no_weights: false # Set to true to skip saving attention weights
30
-
31
- megsai:
32
- cnn_model: "updated" # Must match the training config
33
- cnn_dp: 0.2
34
-
35
- data:
36
- aia_dir:
37
- "${base_data_dir}/AIA"
38
- sxr_dir:
39
- "${base_data_dir}/SXR"
40
- sxr_norm_path:
41
- "/mnt/data/ML-READY/SXR/normalized_sxr.npy"
42
- checkpoint_path:
43
- "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-claude-suggested-weights-epoch=30-val_total_loss=0.0385.ckpt"
44
-
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
forecasting/inference/inference_on_patch_config.yaml DELETED
@@ -1,32 +0,0 @@
1
- base_data_dir: "/mnt/data/COMBINED/" # Change this line for different datasets
2
- output_path: "${base_data_dir}/output/rs.csv"
3
- weight_path: "${base_data_dir}/rs_weights/"
4
- flux_path: "${base_data_dir}/rs_flux/"
5
- mc:
6
- active: "false"
7
- runs: 5
8
- # Enable or disable MC Dropout
9
- # Data paths (automatically constructed from base directories)
10
- Stereo: "false"
11
- Stereo_data:
12
- stereo_img_dir: "/mnt/data/ML-Ready-mixed/STEREO_processed"
13
- sxr_dir: "/mnt/data/ML-Ready-mixed/ML-Ready-mixed/SXR"
14
- sxr_norm_path: "/mnt/data/ML-Ready-mixed/ML-Ready-mixed/SXR/normalized_sxr.npy"
15
- SolO: "false"
16
- SolO_data:
17
- solo_img_dir: "/mnt/data/ML-Ready_clean/SolO/SolO/ML-Ready-SolO"
18
- sxr_dir: "${base_data_dir}/SXR"
19
- sxr_norm_path: "${base_data_dir}/SolO/SXR/normalized_sxr.npy"
20
- model: "vit" # Options: "cnn", "vit", "ViT Custom"
21
- wavelengths: [94,131,171, 193, 211, 304] # AIA wavelengths in Angstroms
22
- data:
23
- aia_dir:
24
- "${base_data_dir}/AIA-SPLIT"
25
- sxr_dir:
26
- "${base_data_dir}/SXR-SPLIT"
27
- sxr_norm_path:
28
- "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
29
- checkpoint_path:
30
- "/mnt/data/COMBINED/new-checkpoint/vit-patch-model-2d-embeddings-reduced-sensitivity-epoch=42-val_total_loss=0.0393.ckpt"
31
-
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
forecasting/inference/patch_analysis_config.yaml DELETED
@@ -1,42 +0,0 @@
1
- # Analysis configuration
2
- base_data_dir: "/mnt/data/COMBINED"
3
- output_path: "${base_data_dir}/output/patch.csv"
4
- aia_path: "${base_data_dir}/AIA-SPLIT/train/"
5
- weight_path: "${base_data_dir}/patch_weights/"
6
- flux_path: "${base_data_dir}/patch_flux/"
7
- attention_path: "${base_data_dir}/patch_attention/"
8
-
9
- data:
10
- aia_dir:
11
- "${base_data_dir}/AIA-SPLIT"
12
- sxr_dir:
13
- "${base_data_dir}/SXR-SPLIT"
14
- sxr_norm_path:
15
- "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
16
-
17
-
18
-
19
- analysis:
20
- # Time period selection for testing analysis
21
- time_period:
22
- start_time: "2023-08-05 00:00:00" # Start time for analysis
23
- end_time: "2023-08-06 00:00:00" # End time for analysis
24
- # Set to null to analyze all available data
25
- # start_time: null
26
- # end_time: null
27
-
28
- # Flare detection parameters
29
- flare_detection:
30
- threshold_percentile: 97.0
31
- min_patches: 2
32
- max_patches: 50
33
- simultaneous_flare_threshold: 0.000005 # Threshold for detecting simultaneous flares
34
-
35
- # Output configuration
36
- output:
37
- output_dir: "${base_data_dir}/flux_analysis_output"
38
- create_visualizations: true
39
- max_visualizations: 100
40
- visualization_threshold: 0.00005 # Only save figures for predictions above this threshold (5e-5)
41
- create_movie: true
42
- movie_fps: 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
forecasting/models/vit_patch_model_local.py CHANGED
@@ -200,8 +200,7 @@ class VisionTransformerLocal(nn.Module):
200
  B, T, _ = x.shape
201
  x = self.input_layer(x)
202
 
203
- # Add CLS token and positional encoding
204
- #x = x + self.pos_embedding[:, : T + 1]
205
  x = self._add_2d_positional_encoding(x)
206
 
207
  # Apply Transformer blocks
@@ -237,7 +236,7 @@ class VisionTransformerLocal(nn.Module):
237
  def _add_2d_positional_encoding(self, x):
238
  """Add learned 2D positional encoding to patch embeddings"""
239
  B, T, embed_dim = x.shape
240
- num_patches = T # Exclude CLS token
241
 
242
  # Reshape patches to 2D grid: [B, grid_h, grid_w, embed_dim]
243
  patch_embeddings = x.reshape(B, self.grid_h, self.grid_w, embed_dim)
@@ -247,9 +246,9 @@ class VisionTransformerLocal(nn.Module):
247
  patch_embeddings = patch_embeddings + self.pos_embedding_2d
248
 
249
  # Reshape back to sequence format: [B, num_patches, embed_dim]
250
- patch_embeddings = patch_embeddings.reshape(B, num_patches, embed_dim)
251
 
252
- return patch_embeddings
253
 
254
  def forward_for_callback(self, x, return_attention=True):
255
  """Forward method compatible with AttentionMapCallback"""
@@ -329,7 +328,10 @@ class LocalAttentionBlock(nn.Module):
329
  num_patches = self.num_patches # 32x32 patches
330
  grid_size = int(math.sqrt(num_patches))
331
 
 
332
  mask = torch.zeros(num_patches, num_patches)
 
 
333
  for i in range(num_patches):
334
  row_i, col_i = i // grid_size, i % grid_size
335
  for j in range(num_patches):
 
200
  B, T, _ = x.shape
201
  x = self.input_layer(x)
202
 
203
+ # Add positional encoding (no CLS token for local attention)
 
204
  x = self._add_2d_positional_encoding(x)
205
 
206
  # Apply Transformer blocks
 
236
  def _add_2d_positional_encoding(self, x):
237
  """Add learned 2D positional encoding to patch embeddings"""
238
  B, T, embed_dim = x.shape
239
+ num_patches = T # All tokens are patches (no CLS token)
240
 
241
  # Reshape patches to 2D grid: [B, grid_h, grid_w, embed_dim]
242
  patch_embeddings = x.reshape(B, self.grid_h, self.grid_w, embed_dim)
 
246
  patch_embeddings = patch_embeddings + self.pos_embedding_2d
247
 
248
  # Reshape back to sequence format: [B, num_patches, embed_dim]
249
+ x = patch_embeddings.reshape(B, num_patches, embed_dim)
250
 
251
+ return x
252
 
253
  def forward_for_callback(self, x, return_attention=True):
254
  """Forward method compatible with AttentionMapCallback"""
 
328
  num_patches = self.num_patches # 32x32 patches
329
  grid_size = int(math.sqrt(num_patches))
330
 
331
+ # Create mask for patches only: [num_patches, num_patches]
332
  mask = torch.zeros(num_patches, num_patches)
333
+
334
+ # Patches can only attend to nearby patches
335
  for i in range(num_patches):
336
  row_i, col_i = i // grid_size, i % grid_size
337
  for j in range(num_patches):