griffingoodwin04 commited on
Commit
3fb991b
·
1 Parent(s): 648b2a8

Updates to the evaluation script and callbacks

Browse files
forecasting/inference/auto_evaluate.py CHANGED
@@ -305,7 +305,7 @@ def main():
305
  parser.add_argument('-checkpoint_dir', type=str, help='Directory containing checkpoint files')
306
  parser.add_argument('-checkpoint_path', type=str, help='Specific checkpoint file path')
307
  parser.add_argument('-model_name', type=str, required=True, help='Name for the model (used for output naming)')
308
- parser.add_argument('-base_data_dir', type=str, default='/mnt/data/NO-OVERLAP', help='Base data directory')
309
  parser.add_argument('-skip_inference', action='store_true', help='Skip inference and only run evaluation')
310
  parser.add_argument('-skip_evaluation', action='store_true', help='Skip evaluation and only run inference')
311
 
 
305
  parser.add_argument('-checkpoint_dir', type=str, help='Directory containing checkpoint files')
306
  parser.add_argument('-checkpoint_path', type=str, help='Specific checkpoint file path')
307
  parser.add_argument('-model_name', type=str, required=True, help='Name for the model (used for output naming)')
308
+ parser.add_argument('-base_data_dir', type=str, default='/mnt/data/', help='Base data directory')
309
  parser.add_argument('-skip_inference', action='store_true', help='Skip inference and only run evaluation')
310
  parser.add_argument('-skip_evaluation', action='store_true', help='Skip evaluation and only run inference')
311
 
forecasting/inference/inference_template.yaml CHANGED
@@ -57,25 +57,3 @@ data:
57
  sxr_norm_path: "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
58
  checkpoint_path: "PLACEHOLDER_CHECKPOINT_PATH" # Will be replaced by batch script
59
 
60
- # MEGSAI parameters (should match training config)
61
- megsai:
62
- cnn_model: "updated"
63
- cnn_dp: 0.2
64
- weight_decay: 1e-5
65
- cosine_restart_T0: 50
66
- cosine_restart_Tmult: 2
67
- cosine_eta_min: 1e-7
68
-
69
- # Fusion parameters (if using fusion model)
70
- fusion:
71
- scalar_branch: "hybrid"
72
- lr: 0.0001
73
- lambda_vit_to_target: 0.3
74
- lambda_scalar_to_target: 0.1
75
- learnable_gate: true
76
- gate_init_bias: 5.0
77
- scalar_kwargs:
78
- d_input: 6
79
- d_output: 1
80
- cnn_model: "updated"
81
- cnn_dp: 0.75
 
57
  sxr_norm_path: "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
58
  checkpoint_path: "PLACEHOLDER_CHECKPOINT_PATH" # Will be replaced by batch script
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
forecasting/training/callback.py CHANGED
@@ -274,9 +274,57 @@ class AttentionMapCallback(Callback):
274
  img_display = (img_np[:, :, 0] + 1) / 2
275
  img_display = np.stack([img_display] * 3, axis=2)
276
 
277
- # Visualization layout logic (unchanged)
278
- # [The plotting logic remains as-is from the original script]
279
- # Produces multiple subplots showing attention patterns and overlaid maps.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  plt.tight_layout()
282
  return fig
 
274
  img_display = (img_np[:, :, 0] + 1) / 2
275
  img_display = np.stack([img_display] * 3, axis=2)
276
 
277
+ # Create the figure and subplots
278
+ fig, axes = plt.subplots(2, 3, figsize=(15, 10))
279
+ fig.suptitle(f'Attention Visualization - Epoch {epoch}, Sample {sample_idx}', fontsize=16)
280
+
281
+ # Plot 1: Original image
282
+ axes[0, 0].imshow(img_display)
283
+ axes[0, 0].set_title('Original Image')
284
+ axes[0, 0].axis('off')
285
+
286
+ # Plot 2: Attention map
287
+ im1 = axes[0, 1].imshow(attention_map, cmap='hot', interpolation='nearest')
288
+ axes[0, 1].set_title('Attention Map')
289
+ axes[0, 1].axis('off')
290
+ plt.colorbar(im1, ax=axes[0, 1])
291
+
292
+ # Plot 3: Overlay
293
+ axes[0, 2].imshow(img_display)
294
+ axes[0, 2].imshow(attention_map, cmap='hot', alpha=0.6, interpolation='nearest')
295
+ axes[0, 2].set_title('Attention Overlay')
296
+ axes[0, 2].axis('off')
297
+
298
+ # Plot 4: Center attention (if available)
299
+ if center_map is not None:
300
+ im2 = axes[1, 0].imshow(center_map, cmap='hot', interpolation='nearest')
301
+ axes[1, 0].set_title('Center Patch Attention')
302
+ axes[1, 0].axis('off')
303
+ plt.colorbar(im2, ax=axes[1, 0])
304
+ else:
305
+ axes[1, 0].text(0.5, 0.5, 'Center attention\nnot available',
306
+ ha='center', va='center', transform=axes[1, 0].transAxes)
307
+ axes[1, 0].set_title('Center Patch Attention')
308
+ axes[1, 0].axis('off')
309
+
310
+ # Plot 5: Patch flux (if available)
311
+ if patch_flux is not None:
312
+ patch_flux_np = patch_flux.cpu().numpy().reshape(grid_h, grid_w)
313
+ im3 = axes[1, 1].imshow(patch_flux_np, cmap='viridis', interpolation='nearest')
314
+ axes[1, 1].set_title('Patch Flux')
315
+ axes[1, 1].axis('off')
316
+ plt.colorbar(im3, ax=axes[1, 1])
317
+ else:
318
+ axes[1, 1].text(0.5, 0.5, 'Patch flux\nnot available',
319
+ ha='center', va='center', transform=axes[1, 1].transAxes)
320
+ axes[1, 1].set_title('Patch Flux')
321
+ axes[1, 1].axis('off')
322
+
323
+ # Plot 6: Attention statistics
324
+ axes[1, 2].hist(attention_map.flatten(), bins=50, alpha=0.7)
325
+ axes[1, 2].set_title('Attention Distribution')
326
+ axes[1, 2].set_xlabel('Attention Weight')
327
+ axes[1, 2].set_ylabel('Frequency')
328
 
329
  plt.tight_layout()
330
  return fig
forecasting/training/localpatch.yaml CHANGED
@@ -1,7 +1,7 @@
1
 
2
  #Base directories - change these to switch datasets
3
- base_data_dir: "/mnt/data/PAPER_DATA_B" # Change this line for different datasets
4
- base_checkpoint_dir: "/mnt/data/PAPER_DATA_B" # Change this line for different datasets
5
  wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
6
 
7
  # GPU configuration
@@ -13,8 +13,8 @@ wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
13
  gpu_ids: "all" # Use both GPUs
14
  # Model configuration
15
  selected_model: "ViTLocal" # Options: "hybrid", "vit", "fusion", "vitpatch"
16
- batch_size: 4
17
- epochs: 250
18
  oversample: false
19
  balance_strategy: "upsample_minority"
20
  calculate_base_weights: false # Whether to calculate class-based weights for loss function
@@ -51,5 +51,5 @@ wandb:
51
  - aia
52
  - sxr
53
  - regression
54
- wb_name: paper-testing-8-patch-335-512-hidden-6-layers-256-embed-dim
55
  notes: Regression from AIA images (6 channels) to GOES SXR flux
 
1
 
2
  #Base directories - change these to switch datasets
3
+ base_data_dir: "/mnt/data/PAPER_DATA_WITH_335" # Change this line for different datasets
4
+ base_checkpoint_dir: "/mnt/data/PAPER_DATA_WITH_335" # Change this line for different datasets
5
  wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
6
 
7
  # GPU configuration
 
13
  gpu_ids: "all" # Use both GPUs
14
  # Model configuration
15
  selected_model: "ViTLocal" # Options: "hybrid", "vit", "fusion", "vitpatch"
16
+ batch_size: 6
17
+ epochs: 150
18
  oversample: false
19
  balance_strategy: "upsample_minority"
20
  calculate_base_weights: false # Whether to calculate class-based weights for loss function
 
51
  - aia
52
  - sxr
53
  - regression
54
+ wb_name: paper-testing-8-patch-335-512-hidden-6-layers-256-embed-dim-updated
55
  notes: Regression from AIA images (6 channels) to GOES SXR flux