Commit 3fb991b (parent: 648b2a8) — updates to evaluation script and callbacks
Browse files
forecasting/inference/auto_evaluate.py
CHANGED
|
@@ -305,7 +305,7 @@ def main():
|
|
| 305 |
parser.add_argument('-checkpoint_dir', type=str, help='Directory containing checkpoint files')
|
| 306 |
parser.add_argument('-checkpoint_path', type=str, help='Specific checkpoint file path')
|
| 307 |
parser.add_argument('-model_name', type=str, required=True, help='Name for the model (used for output naming)')
|
| 308 |
-
parser.add_argument('-base_data_dir', type=str, default='/mnt/data/…')   [old line truncated in this page capture; full original default not recoverable]
|
| 309 |
parser.add_argument('-skip_inference', action='store_true', help='Skip inference and only run evaluation')
|
| 310 |
parser.add_argument('-skip_evaluation', action='store_true', help='Skip evaluation and only run inference')
|
| 311 |
|
|
|
|
| 305 |
parser.add_argument('-checkpoint_dir', type=str, help='Directory containing checkpoint files')
|
| 306 |
parser.add_argument('-checkpoint_path', type=str, help='Specific checkpoint file path')
|
| 307 |
parser.add_argument('-model_name', type=str, required=True, help='Name for the model (used for output naming)')
|
| 308 |
+
parser.add_argument('-base_data_dir', type=str, default='/mnt/data/', help='Base data directory')
|
| 309 |
parser.add_argument('-skip_inference', action='store_true', help='Skip inference and only run evaluation')
|
| 310 |
parser.add_argument('-skip_evaluation', action='store_true', help='Skip evaluation and only run inference')
|
| 311 |
|
forecasting/inference/inference_template.yaml
CHANGED
|
@@ -57,25 +57,3 @@ data:
|
|
| 57 |
sxr_norm_path: "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
|
| 58 |
checkpoint_path: "PLACEHOLDER_CHECKPOINT_PATH" # Will be replaced by batch script
|
| 59 |
|
| 60 |
-
# MEGSAI parameters (should match training config)
|
| 61 |
-
megsai:
|
| 62 |
-
cnn_model: "updated"
|
| 63 |
-
cnn_dp: 0.2
|
| 64 |
-
weight_decay: 1e-5
|
| 65 |
-
cosine_restart_T0: 50
|
| 66 |
-
cosine_restart_Tmult: 2
|
| 67 |
-
cosine_eta_min: 1e-7
|
| 68 |
-
|
| 69 |
-
# Fusion parameters (if using fusion model)
|
| 70 |
-
fusion:
|
| 71 |
-
scalar_branch: "hybrid"
|
| 72 |
-
lr: 0.0001
|
| 73 |
-
lambda_vit_to_target: 0.3
|
| 74 |
-
lambda_scalar_to_target: 0.1
|
| 75 |
-
learnable_gate: true
|
| 76 |
-
gate_init_bias: 5.0
|
| 77 |
-
scalar_kwargs:
|
| 78 |
-
d_input: 6
|
| 79 |
-
d_output: 1
|
| 80 |
-
cnn_model: "updated"
|
| 81 |
-
cnn_dp: 0.75
|
|
|
|
| 57 |
sxr_norm_path: "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
|
| 58 |
checkpoint_path: "PLACEHOLDER_CHECKPOINT_PATH" # Will be replaced by batch script
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
forecasting/training/callback.py
CHANGED
|
@@ -274,9 +274,57 @@ class AttentionMapCallback(Callback):
|
|
| 274 |
img_display = (img_np[:, :, 0] + 1) / 2
|
| 275 |
img_display = np.stack([img_display] * 3, axis=2)
|
| 276 |
|
| 277 |
-
#
|
| 278 |
-
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
plt.tight_layout()
|
| 282 |
return fig
|
|
|
|
| 274 |
img_display = (img_np[:, :, 0] + 1) / 2
|
| 275 |
img_display = np.stack([img_display] * 3, axis=2)
|
| 276 |
|
| 277 |
+
# Create the figure and subplots
|
| 278 |
+
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
|
| 279 |
+
fig.suptitle(f'Attention Visualization - Epoch {epoch}, Sample {sample_idx}', fontsize=16)
|
| 280 |
+
|
| 281 |
+
# Plot 1: Original image
|
| 282 |
+
axes[0, 0].imshow(img_display)
|
| 283 |
+
axes[0, 0].set_title('Original Image')
|
| 284 |
+
axes[0, 0].axis('off')
|
| 285 |
+
|
| 286 |
+
# Plot 2: Attention map
|
| 287 |
+
im1 = axes[0, 1].imshow(attention_map, cmap='hot', interpolation='nearest')
|
| 288 |
+
axes[0, 1].set_title('Attention Map')
|
| 289 |
+
axes[0, 1].axis('off')
|
| 290 |
+
plt.colorbar(im1, ax=axes[0, 1])
|
| 291 |
+
|
| 292 |
+
# Plot 3: Overlay
|
| 293 |
+
axes[0, 2].imshow(img_display)
|
| 294 |
+
axes[0, 2].imshow(attention_map, cmap='hot', alpha=0.6, interpolation='nearest')
|
| 295 |
+
axes[0, 2].set_title('Attention Overlay')
|
| 296 |
+
axes[0, 2].axis('off')
|
| 297 |
+
|
| 298 |
+
# Plot 4: Center attention (if available)
|
| 299 |
+
if center_map is not None:
|
| 300 |
+
im2 = axes[1, 0].imshow(center_map, cmap='hot', interpolation='nearest')
|
| 301 |
+
axes[1, 0].set_title('Center Patch Attention')
|
| 302 |
+
axes[1, 0].axis('off')
|
| 303 |
+
plt.colorbar(im2, ax=axes[1, 0])
|
| 304 |
+
else:
|
| 305 |
+
axes[1, 0].text(0.5, 0.5, 'Center attention\nnot available',
|
| 306 |
+
ha='center', va='center', transform=axes[1, 0].transAxes)
|
| 307 |
+
axes[1, 0].set_title('Center Patch Attention')
|
| 308 |
+
axes[1, 0].axis('off')
|
| 309 |
+
|
| 310 |
+
# Plot 5: Patch flux (if available)
|
| 311 |
+
if patch_flux is not None:
|
| 312 |
+
patch_flux_np = patch_flux.cpu().numpy().reshape(grid_h, grid_w)
|
| 313 |
+
im3 = axes[1, 1].imshow(patch_flux_np, cmap='viridis', interpolation='nearest')
|
| 314 |
+
axes[1, 1].set_title('Patch Flux')
|
| 315 |
+
axes[1, 1].axis('off')
|
| 316 |
+
plt.colorbar(im3, ax=axes[1, 1])
|
| 317 |
+
else:
|
| 318 |
+
axes[1, 1].text(0.5, 0.5, 'Patch flux\nnot available',
|
| 319 |
+
ha='center', va='center', transform=axes[1, 1].transAxes)
|
| 320 |
+
axes[1, 1].set_title('Patch Flux')
|
| 321 |
+
axes[1, 1].axis('off')
|
| 322 |
+
|
| 323 |
+
# Plot 6: Attention statistics
|
| 324 |
+
axes[1, 2].hist(attention_map.flatten(), bins=50, alpha=0.7)
|
| 325 |
+
axes[1, 2].set_title('Attention Distribution')
|
| 326 |
+
axes[1, 2].set_xlabel('Attention Weight')
|
| 327 |
+
axes[1, 2].set_ylabel('Frequency')
|
| 328 |
|
| 329 |
plt.tight_layout()
|
| 330 |
return fig
|
forecasting/training/localpatch.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
|
| 2 |
#Base directories - change these to switch datasets
|
| 3 |
-
base_data_dir: "/mnt/data/…"  # old value truncated in this page capture; full original path not recoverable
|
| 4 |
-
base_checkpoint_dir: "/mnt/data/…"  # old value truncated in this page capture; full original path not recoverable
|
| 5 |
wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
|
| 6 |
|
| 7 |
# GPU configuration
|
|
@@ -13,8 +13,8 @@ wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
|
|
| 13 |
gpu_ids: "all" # Use both GPUs
|
| 14 |
# Model configuration
|
| 15 |
selected_model: "ViTLocal" # Options: "hybrid", "vit", "fusion", "vitpatch"
|
| 16 |
-
batch_size:
|
| 17 |
-
epochs:
|
| 18 |
oversample: false
|
| 19 |
balance_strategy: "upsample_minority"
|
| 20 |
calculate_base_weights: false # Whether to calculate class-based weights for loss function
|
|
@@ -51,5 +51,5 @@ wandb:
|
|
| 51 |
- aia
|
| 52 |
- sxr
|
| 53 |
- regression
|
| 54 |
-
wb_name: paper-testing-8-patch-335-512-hidden-6-layers-256-embed-dim
|
| 55 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
|
|
|
| 1 |
|
| 2 |
#Base directories - change these to switch datasets
|
| 3 |
+
base_data_dir: "/mnt/data/PAPER_DATA_WITH_335" # Change this line for different datasets
|
| 4 |
+
base_checkpoint_dir: "/mnt/data/PAPER_DATA_WITH_335" # Change this line for different datasets
|
| 5 |
wavelengths: [94, 131, 171, 193, 211, 304, 335] # AIA wavelengths in Angstroms
|
| 6 |
|
| 7 |
# GPU configuration
|
|
|
|
| 13 |
gpu_ids: "all" # Use both GPUs
|
| 14 |
# Model configuration
|
| 15 |
selected_model: "ViTLocal" # Options: "hybrid", "vit", "fusion", "vitpatch"
|
| 16 |
+
batch_size: 6
|
| 17 |
+
epochs: 150
|
| 18 |
oversample: false
|
| 19 |
balance_strategy: "upsample_minority"
|
| 20 |
calculate_base_weights: false # Whether to calculate class-based weights for loss function
|
|
|
|
| 51 |
- aia
|
| 52 |
- sxr
|
| 53 |
- regression
|
| 54 |
+
wb_name: paper-testing-8-patch-335-512-hidden-6-layers-256-embed-dim-updated
|
| 55 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|