Commit ·
1aeb490
1
Parent(s): f6b8791
Refactor SXRRegressionDynamicLoss weights and performance multipliers; adjust model configuration for ViTPatch with reduced patch size and increased number of heads; implement GPU memory isolation and monitoring in training script; enhance AttentionMapCallback to accept dynamic patch size.
Browse files- forecasting/models/vit_patch_model.py +8 -8
- forecasting/training/callback.py +4 -2
- forecasting/training/config.yaml +5 -5
- forecasting/training/config2.yaml +5 -1
- forecasting/training/config4.yaml +74 -0
- forecasting/training/config5.yaml +75 -0
- forecasting/training/config6.yaml +75 -0
- forecasting/training/train.py +67 -13
forecasting/models/vit_patch_model.py
CHANGED
|
@@ -341,10 +341,10 @@ class SXRRegressionDynamicLoss:
|
|
| 341 |
def _get_base_weights(self):
|
| 342 |
#Calculate the base weights based on the number of samples in each class within training data
|
| 343 |
return {
|
| 344 |
-
'quiet': 1.
|
| 345 |
-
'c_class':
|
| 346 |
-
'm_class':
|
| 347 |
-
'x_class':
|
| 348 |
}
|
| 349 |
|
| 350 |
def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
|
|
@@ -360,16 +360,16 @@ class SXRRegressionDynamicLoss:
|
|
| 360 |
|
| 361 |
# Get continuous multipliers per class with custom params
|
| 362 |
quiet_mult = self._get_performance_multiplier(
|
| 363 |
-
self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.
|
| 364 |
)
|
| 365 |
c_mult = self._get_performance_multiplier(
|
| 366 |
-
self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.
|
| 367 |
)
|
| 368 |
m_mult = self._get_performance_multiplier(
|
| 369 |
-
self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.
|
| 370 |
)
|
| 371 |
x_mult = self._get_performance_multiplier(
|
| 372 |
-
self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=.
|
| 373 |
)
|
| 374 |
|
| 375 |
quiet_weight = self.base_weights['quiet'] * quiet_mult
|
|
|
|
| 341 |
def _get_base_weights(self):
|
| 342 |
#Calculate the base weights based on the number of samples in each class within training data
|
| 343 |
return {
|
| 344 |
+
'quiet': 1.2110,
|
| 345 |
+
'c_class': 1.2110,
|
| 346 |
+
'm_class': 6.3106,
|
| 347 |
+
'x_class': 63.4350
|
| 348 |
}
|
| 349 |
|
| 350 |
def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
|
|
|
|
| 360 |
|
| 361 |
# Get continuous multipliers per class with custom params
|
| 362 |
quiet_mult = self._get_performance_multiplier(
|
| 363 |
+
self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.05, sxrclass='quiet' # Was 0.2
|
| 364 |
)
|
| 365 |
c_mult = self._get_performance_multiplier(
|
| 366 |
+
self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.08, sxrclass='c_class' # Was 0.3
|
| 367 |
)
|
| 368 |
m_mult = self._get_performance_multiplier(
|
| 369 |
+
self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.1, sxrclass='m_class' # Was 0.4
|
| 370 |
)
|
| 371 |
x_mult = self._get_performance_multiplier(
|
| 372 |
+
self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=0.12, sxrclass='x_class' # Was 0.5
|
| 373 |
)
|
| 374 |
|
| 375 |
quiet_weight = self.base_weights['quiet'] * quiet_mult
|
forecasting/training/callback.py
CHANGED
|
@@ -90,7 +90,7 @@ class ImagePredictionLogger_SXR(Callback):
|
|
| 90 |
|
| 91 |
|
| 92 |
class AttentionMapCallback(Callback):
|
| 93 |
-
def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps"):
|
| 94 |
"""
|
| 95 |
Callback to visualize attention maps during training.
|
| 96 |
|
|
@@ -98,8 +98,10 @@ class AttentionMapCallback(Callback):
|
|
| 98 |
log_every_n_epochs: How often to log attention maps
|
| 99 |
num_samples: Number of samples to visualize
|
| 100 |
save_dir: Directory to save attention maps
|
|
|
|
| 101 |
"""
|
| 102 |
super().__init__()
|
|
|
|
| 103 |
self.log_every_n_epochs = log_every_n_epochs
|
| 104 |
self.num_samples = num_samples
|
| 105 |
self.save_dir = save_dir
|
|
@@ -142,7 +144,7 @@ class AttentionMapCallback(Callback):
|
|
| 142 |
attention_weights,
|
| 143 |
sample_idx,
|
| 144 |
trainer.current_epoch,
|
| 145 |
-
patch_size=
|
| 146 |
)
|
| 147 |
trainer.logger.experiment.log({"Attention plots": wandb.Image(map)})
|
| 148 |
plt.close(map)
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
class AttentionMapCallback(Callback):
|
| 93 |
+
def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps", patch_size=8):
|
| 94 |
"""
|
| 95 |
Callback to visualize attention maps during training.
|
| 96 |
|
|
|
|
| 98 |
log_every_n_epochs: How often to log attention maps
|
| 99 |
num_samples: Number of samples to visualize
|
| 100 |
save_dir: Directory to save attention maps
|
| 101 |
+
patch_size: Size of patches used in the model
|
| 102 |
"""
|
| 103 |
super().__init__()
|
| 104 |
+
self.patch_size = patch_size
|
| 105 |
self.log_every_n_epochs = log_every_n_epochs
|
| 106 |
self.num_samples = num_samples
|
| 107 |
self.save_dir = save_dir
|
|
|
|
| 144 |
attention_weights,
|
| 145 |
sample_idx,
|
| 146 |
trainer.current_epoch,
|
| 147 |
+
patch_size=self.patch_size
|
| 148 |
)
|
| 149 |
trainer.logger.experiment.log({"Attention plots": wandb.Image(map)})
|
| 150 |
plt.close(map)
|
forecasting/training/config.yaml
CHANGED
|
@@ -25,11 +25,11 @@ vit_custom:
|
|
| 25 |
embed_dim: 512
|
| 26 |
num_channels: 6
|
| 27 |
num_classes: 1
|
| 28 |
-
patch_size:
|
| 29 |
-
num_patches:
|
| 30 |
hidden_dim: 512
|
| 31 |
-
num_heads: 8
|
| 32 |
-
num_layers: 6
|
| 33 |
dropout: 0.1
|
| 34 |
lr: 0.0001
|
| 35 |
|
|
@@ -67,5 +67,5 @@ wandb:
|
|
| 67 |
- aia
|
| 68 |
- sxr
|
| 69 |
- regression
|
| 70 |
-
wb_name:
|
| 71 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
|
|
|
| 25 |
embed_dim: 512
|
| 26 |
num_channels: 6
|
| 27 |
num_classes: 1
|
| 28 |
+
patch_size: 8
|
| 29 |
+
num_patches: 4096
|
| 30 |
hidden_dim: 512
|
| 31 |
+
num_heads: 12 # Increased from 8
|
| 32 |
+
num_layers: 4 # Reduced from 6
|
| 33 |
dropout: 0.1
|
| 34 |
lr: 0.0001
|
| 35 |
|
|
|
|
| 67 |
- aia
|
| 68 |
- sxr
|
| 69 |
- regression
|
| 70 |
+
wb_name:
|
| 71 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
forecasting/training/config2.yaml
CHANGED
|
@@ -3,12 +3,16 @@
|
|
| 3 |
base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 4 |
base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 5 |
wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
|
|
|
|
|
|
|
|
|
|
| 6 |
# Model configuration
|
| 7 |
selected_model: "ViTPatch" # Options: "hybrid", "vit", "fusion", "vitpatch"
|
| 8 |
batch_size: 64
|
| 9 |
epochs: 250
|
| 10 |
oversample: false
|
| 11 |
balance_strategy: "upsample_minority"
|
|
|
|
| 12 |
|
| 13 |
megsai:
|
| 14 |
architecture: "cnn"
|
|
@@ -67,5 +71,5 @@ wandb:
|
|
| 67 |
- aia
|
| 68 |
- sxr
|
| 69 |
- regression
|
| 70 |
-
wb_name: vit-patch-model-2d-embeddings
|
| 71 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
|
|
|
| 3 |
base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 4 |
base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 5 |
wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
|
| 6 |
+
|
| 7 |
+
# GPU configuration
|
| 8 |
+
gpu_id: 0 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
|
| 9 |
# Model configuration
|
| 10 |
selected_model: "ViTPatch" # Options: "hybrid", "vit", "fusion", "vitpatch"
|
| 11 |
batch_size: 64
|
| 12 |
epochs: 250
|
| 13 |
oversample: false
|
| 14 |
balance_strategy: "upsample_minority"
|
| 15 |
+
calculate_base_weights: false # Whether to calculate class-based weights for loss function
|
| 16 |
|
| 17 |
megsai:
|
| 18 |
architecture: "cnn"
|
|
|
|
| 71 |
- aia
|
| 72 |
- sxr
|
| 73 |
- regression
|
| 74 |
+
wb_name: vit-patch-model-2d-embeddings-reduced-sensitivity
|
| 75 |
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
forecasting/training/config4.yaml
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Base directories - change these to switch datasets
|
| 2 |
+
base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 3 |
+
base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 4 |
+
wavelengths: [171, 193, 211, 304] # AIA wavelengths in Angstroms
|
| 5 |
+
|
| 6 |
+
# GPU configuration
|
| 7 |
+
gpu_id: 1 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
|
| 8 |
+
# Model configuration
|
| 9 |
+
selected_model: "ViTPatch" # Options: "hybrid", "vit", "fusion", "vitpatch"
|
| 10 |
+
batch_size: 64
|
| 11 |
+
epochs: 250
|
| 12 |
+
oversample: false
|
| 13 |
+
balance_strategy: "upsample_minority"
|
| 14 |
+
calculate_base_weights: false # Whether to calculate class-based weights for loss function
|
| 15 |
+
|
| 16 |
+
megsai:
|
| 17 |
+
architecture: "cnn"
|
| 18 |
+
seed: 42
|
| 19 |
+
lr: 0.0001
|
| 20 |
+
cnn_model: "updated"
|
| 21 |
+
cnn_dp: 0.2
|
| 22 |
+
weight_decay: 1e-5
|
| 23 |
+
cosine_restart_T0: 50
|
| 24 |
+
cosine_restart_Tmult: 2
|
| 25 |
+
cosine_eta_min: 1e-7
|
| 26 |
+
|
| 27 |
+
vit_custom:
|
| 28 |
+
embed_dim: 512
|
| 29 |
+
num_channels: 4
|
| 30 |
+
num_classes: 1
|
| 31 |
+
patch_size: 16
|
| 32 |
+
num_patches: 1024
|
| 33 |
+
hidden_dim: 512
|
| 34 |
+
num_heads: 8
|
| 35 |
+
num_layers: 6
|
| 36 |
+
dropout: 0.1
|
| 37 |
+
lr: 0.0001
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
fusion:
|
| 41 |
+
scalar_branch: "hybrid" # or "linear"
|
| 42 |
+
lr: 0.0001
|
| 43 |
+
lambda_vit_to_target: 0.3
|
| 44 |
+
lambda_scalar_to_target: 0.1
|
| 45 |
+
learnable_gate: true
|
| 46 |
+
gate_init_bias: 5.0
|
| 47 |
+
scalar_kwargs:
|
| 48 |
+
d_input: 6
|
| 49 |
+
d_output: 1
|
| 50 |
+
cnn_model: "updated"
|
| 51 |
+
cnn_dp: 0.75
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Data paths (automatically constructed from base directories)
|
| 55 |
+
data:
|
| 56 |
+
aia_dir:
|
| 57 |
+
"${base_data_dir}/AIA-SPLIT"
|
| 58 |
+
sxr_dir:
|
| 59 |
+
"${base_data_dir}/SXR-SPLIT"
|
| 60 |
+
sxr_norm_path:
|
| 61 |
+
"${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
|
| 62 |
+
checkpoints_dir:
|
| 63 |
+
"${base_checkpoint_dir}/new-checkpoint/"
|
| 64 |
+
|
| 65 |
+
wandb:
|
| 66 |
+
entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
|
| 67 |
+
project: Model Testing
|
| 68 |
+
job_type: training
|
| 69 |
+
tags:
|
| 70 |
+
- aia
|
| 71 |
+
- sxr
|
| 72 |
+
- regression
|
| 73 |
+
wb_name: vit-patch-model-2d-embeddings-reduced-sensitivity-STEREO
|
| 74 |
+
notes: Regression from AIA images (4 channels) to GOES SXR flux
|
forecasting/training/config5.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#Base directories - change these to switch datasets
|
| 3 |
+
base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 4 |
+
base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 5 |
+
wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
|
| 6 |
+
|
| 7 |
+
# GPU configuration
|
| 8 |
+
gpu_id: 2 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
|
| 9 |
+
# Model configuration
|
| 10 |
+
selected_model: "vit" # Options: "hybrid", "vit", "fusion", "vitpatch"
|
| 11 |
+
batch_size: 64
|
| 12 |
+
epochs: 250
|
| 13 |
+
oversample: false
|
| 14 |
+
balance_strategy: "upsample_minority"
|
| 15 |
+
calculate_base_weights: false # Whether to calculate class-based weights for loss function
|
| 16 |
+
|
| 17 |
+
megsai:
|
| 18 |
+
architecture: "cnn"
|
| 19 |
+
seed: 42
|
| 20 |
+
lr: 0.0001
|
| 21 |
+
cnn_model: "updated"
|
| 22 |
+
cnn_dp: 0.2
|
| 23 |
+
weight_decay: 1e-5
|
| 24 |
+
cosine_restart_T0: 50
|
| 25 |
+
cosine_restart_Tmult: 2
|
| 26 |
+
cosine_eta_min: 1e-7
|
| 27 |
+
|
| 28 |
+
vit_custom:
|
| 29 |
+
embed_dim: 512
|
| 30 |
+
num_channels: 6
|
| 31 |
+
num_classes: 1
|
| 32 |
+
patch_size: 16
|
| 33 |
+
num_patches: 1024
|
| 34 |
+
hidden_dim: 512
|
| 35 |
+
num_heads: 8
|
| 36 |
+
num_layers: 6
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
lr: 0.0001
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
fusion:
|
| 42 |
+
scalar_branch: "hybrid" # or "linear"
|
| 43 |
+
lr: 0.0001
|
| 44 |
+
lambda_vit_to_target: 0.3
|
| 45 |
+
lambda_scalar_to_target: 0.1
|
| 46 |
+
learnable_gate: true
|
| 47 |
+
gate_init_bias: 5.0
|
| 48 |
+
scalar_kwargs:
|
| 49 |
+
d_input: 6
|
| 50 |
+
d_output: 1
|
| 51 |
+
cnn_model: "updated"
|
| 52 |
+
cnn_dp: 0.75
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Data paths (automatically constructed from base directories)
|
| 56 |
+
data:
|
| 57 |
+
aia_dir:
|
| 58 |
+
"${base_data_dir}/AIA-SPLIT"
|
| 59 |
+
sxr_dir:
|
| 60 |
+
"${base_data_dir}/SXR-SPLIT"
|
| 61 |
+
sxr_norm_path:
|
| 62 |
+
"${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
|
| 63 |
+
checkpoints_dir:
|
| 64 |
+
"${base_checkpoint_dir}/new-checkpoint/"
|
| 65 |
+
|
| 66 |
+
wandb:
|
| 67 |
+
entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
|
| 68 |
+
project: Model Testing
|
| 69 |
+
job_type: training
|
| 70 |
+
tags:
|
| 71 |
+
- aia
|
| 72 |
+
- sxr
|
| 73 |
+
- regression
|
| 74 |
+
wb_name: vit-patch-model-2d-embeddings-reduced-sensitivity
|
| 75 |
+
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
forecasting/training/config6.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#Base directories - change these to switch datasets
|
| 3 |
+
base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 4 |
+
base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
|
| 5 |
+
wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
|
| 6 |
+
|
| 7 |
+
# GPU configuration
|
| 8 |
+
gpu_id: 3 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
|
| 9 |
+
# Model configuration
|
| 10 |
+
selected_model: "ViTPatch" # Options: "hybrid", "vit", "fusion", "vitpatch"
|
| 11 |
+
batch_size: 64
|
| 12 |
+
epochs: 250
|
| 13 |
+
oversample: false
|
| 14 |
+
balance_strategy: "upsample_minority"
|
| 15 |
+
calculate_base_weights: false # Whether to calculate class-based weights for loss function
|
| 16 |
+
|
| 17 |
+
megsai:
|
| 18 |
+
architecture: "cnn"
|
| 19 |
+
seed: 42
|
| 20 |
+
lr: 0.0001
|
| 21 |
+
cnn_model: "updated"
|
| 22 |
+
cnn_dp: 0.2
|
| 23 |
+
weight_decay: 1e-5
|
| 24 |
+
cosine_restart_T0: 50
|
| 25 |
+
cosine_restart_Tmult: 2
|
| 26 |
+
cosine_eta_min: 1e-7
|
| 27 |
+
|
| 28 |
+
vit_custom:
|
| 29 |
+
embed_dim: 512
|
| 30 |
+
num_channels: 6
|
| 31 |
+
num_classes: 1
|
| 32 |
+
patch_size: 16
|
| 33 |
+
num_patches: 1024
|
| 34 |
+
hidden_dim: 512
|
| 35 |
+
num_heads: 8
|
| 36 |
+
num_layers: 6
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
lr: 0.001
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
fusion:
|
| 42 |
+
scalar_branch: "hybrid" # or "linear"
|
| 43 |
+
lr: 0.0001
|
| 44 |
+
lambda_vit_to_target: 0.3
|
| 45 |
+
lambda_scalar_to_target: 0.1
|
| 46 |
+
learnable_gate: true
|
| 47 |
+
gate_init_bias: 5.0
|
| 48 |
+
scalar_kwargs:
|
| 49 |
+
d_input: 6
|
| 50 |
+
d_output: 1
|
| 51 |
+
cnn_model: "updated"
|
| 52 |
+
cnn_dp: 0.75
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Data paths (automatically constructed from base directories)
|
| 56 |
+
data:
|
| 57 |
+
aia_dir:
|
| 58 |
+
"${base_data_dir}/AIA-SPLIT"
|
| 59 |
+
sxr_dir:
|
| 60 |
+
"${base_data_dir}/SXR-SPLIT"
|
| 61 |
+
sxr_norm_path:
|
| 62 |
+
"${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
|
| 63 |
+
checkpoints_dir:
|
| 64 |
+
"${base_checkpoint_dir}/new-checkpoint/"
|
| 65 |
+
|
| 66 |
+
wandb:
|
| 67 |
+
entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
|
| 68 |
+
project: Model Testing
|
| 69 |
+
job_type: training
|
| 70 |
+
tags:
|
| 71 |
+
- aia
|
| 72 |
+
- sxr
|
| 73 |
+
- regression
|
| 74 |
+
wb_name: vit-patch-model-2d-embeddings-reduced-sensitivity-higher-lr
|
| 75 |
+
notes: Regression from AIA images (6 channels) to GOES SXR flux
|
forecasting/training/train.py
CHANGED
|
@@ -32,6 +32,18 @@ from forecasting.models.FastSpectralNet import FastViTFlaringModel
|
|
| 32 |
|
| 33 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 34 |
os.environ["NCCL_DEBUG"] = "WARN"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
def resolve_config_variables(config_dict):
|
| 37 |
"""Recursively resolve ${variable} references within the config"""
|
|
@@ -75,11 +87,27 @@ with open(args.config, 'r') as stream:
|
|
| 75 |
# Resolve variables like ${base_data_dir}
|
| 76 |
config_data = resolve_config_variables(config_data)
|
| 77 |
|
| 78 |
-
#
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
# Debug: Print resolved paths
|
| 85 |
print("Resolved paths:")
|
|
@@ -106,7 +134,7 @@ data_loader = AIA_GOESDataModule(
|
|
| 106 |
sxr_val_dir=config_data['data']['sxr_dir']+"/val",
|
| 107 |
sxr_test_dir=config_data['data']['sxr_dir']+"/test",
|
| 108 |
batch_size=config_data['batch_size'],
|
| 109 |
-
num_workers=os.cpu_count(),
|
| 110 |
sxr_norm=sxr_norm,
|
| 111 |
wavelengths=training_wavelengths,
|
| 112 |
oversample=config_data['oversample'],
|
|
@@ -114,6 +142,9 @@ data_loader = AIA_GOESDataModule(
|
|
| 114 |
)
|
| 115 |
data_loader.setup()
|
| 116 |
|
|
|
|
|
|
|
|
|
|
| 117 |
# Logger
|
| 118 |
#wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
|
| 119 |
wandb_logger = WandbLogger(
|
|
@@ -133,8 +164,9 @@ plot_samples = plot_data # Keep as list of ((aia, sxr), target)
|
|
| 133 |
#sxr_callback = SXRPredictionLogger(plot_samples)
|
| 134 |
|
| 135 |
sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
|
| 136 |
-
# Attention map callback
|
| 137 |
-
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
class PTHCheckpointCallback(Callback):
|
|
@@ -308,7 +340,9 @@ elif config_data['selected_model'] == 'ViT':
|
|
| 308 |
model = ViT(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)
|
| 309 |
|
| 310 |
elif config_data['selected_model'] == 'ViTPatch':
|
| 311 |
-
|
|
|
|
|
|
|
| 312 |
|
| 313 |
elif config_data['selected_model'] == 'FusionViTHybrid':
|
| 314 |
# Expect a 'fusion' section in YAML
|
|
@@ -338,12 +372,32 @@ elif config_data['selected_model'] == 'FusionViTHybrid':
|
|
| 338 |
else:
|
| 339 |
raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
|
| 340 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
# Trainer
|
| 342 |
if config_data['selected_model'] == 'ViT' or config_data['selected_model'] == 'ViTPatch' or config_data['selected_model'] == 'FusionViTHybrid':
|
| 343 |
trainer = Trainer(
|
| 344 |
default_root_dir=config_data['data']['checkpoints_dir'],
|
| 345 |
-
accelerator=
|
| 346 |
-
devices=
|
| 347 |
max_epochs=config_data['epochs'],
|
| 348 |
callbacks=[attention, checkpoint_callback],
|
| 349 |
logger=wandb_logger,
|
|
@@ -352,8 +406,8 @@ if config_data['selected_model'] == 'ViT' or config_data['selected_model'] == 'V
|
|
| 352 |
else:
|
| 353 |
trainer = Trainer(
|
| 354 |
default_root_dir=config_data['data']['checkpoints_dir'],
|
| 355 |
-
accelerator=
|
| 356 |
-
devices=
|
| 357 |
max_epochs=config_data['epochs'],
|
| 358 |
callbacks=[checkpoint_callback],
|
| 359 |
logger=wandb_logger,
|
|
|
|
| 32 |
|
| 33 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 34 |
os.environ["NCCL_DEBUG"] = "WARN"
|
| 35 |
+
# Shared memory optimizations
|
| 36 |
+
os.environ["OMP_NUM_THREADS"] = "1" # Limit OpenMP threads
|
| 37 |
+
os.environ["MKL_NUM_THREADS"] = "1" # Limit MKL threads
|
| 38 |
+
|
| 39 |
+
def print_gpu_memory(stage=""):
|
| 40 |
+
"""Print GPU memory usage for monitoring"""
|
| 41 |
+
if torch.cuda.is_available():
|
| 42 |
+
allocated = torch.cuda.memory_allocated() / 1e9
|
| 43 |
+
reserved = torch.cuda.memory_reserved() / 1e9
|
| 44 |
+
print(f"GPU Memory {stage} - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
|
| 45 |
+
else:
|
| 46 |
+
print(f"No GPU available for memory monitoring {stage}")
|
| 47 |
|
| 48 |
def resolve_config_variables(config_dict):
|
| 49 |
"""Recursively resolve ${variable} references within the config"""
|
|
|
|
| 87 |
# Resolve variables like ${base_data_dir}
|
| 88 |
config_data = resolve_config_variables(config_data)
|
| 89 |
|
| 90 |
+
# GPU Memory Isolation for Multi-GPU Systems
|
| 91 |
+
gpu_id = config_data.get('gpu_id', 0)
|
| 92 |
+
if gpu_id != -1: # Only if using GPU
|
| 93 |
+
# Set CUDA device visibility to only the specified GPU
|
| 94 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
| 95 |
+
print(f"Set CUDA_VISIBLE_DEVICES to GPU {gpu_id}")
|
| 96 |
+
|
| 97 |
+
# Clear any existing CUDA cache
|
| 98 |
+
if torch.cuda.is_available():
|
| 99 |
+
torch.cuda.empty_cache()
|
| 100 |
+
print(f"Cleared CUDA cache for GPU {gpu_id}")
|
| 101 |
+
|
| 102 |
+
# Set memory allocation strategy for better isolation
|
| 103 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:16"
|
| 104 |
+
|
| 105 |
+
# Disable memory sharing between processes
|
| 106 |
+
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
| 107 |
+
|
| 108 |
+
print(f"GPU Memory Isolation configured for GPU {gpu_id}")
|
| 109 |
+
else:
|
| 110 |
+
print("Using CPU - no GPU memory isolation needed")
|
| 111 |
|
| 112 |
# Debug: Print resolved paths
|
| 113 |
print("Resolved paths:")
|
|
|
|
| 134 |
sxr_val_dir=config_data['data']['sxr_dir']+"/val",
|
| 135 |
sxr_test_dir=config_data['data']['sxr_dir']+"/test",
|
| 136 |
batch_size=config_data['batch_size'],
|
| 137 |
+
num_workers=min(8, os.cpu_count()), # Limit workers to prevent shm issues
|
| 138 |
sxr_norm=sxr_norm,
|
| 139 |
wavelengths=training_wavelengths,
|
| 140 |
oversample=config_data['oversample'],
|
|
|
|
| 142 |
)
|
| 143 |
data_loader.setup()
|
| 144 |
|
| 145 |
+
# Monitor memory after data loading
|
| 146 |
+
print_gpu_memory("after data loading")
|
| 147 |
+
|
| 148 |
# Logger
|
| 149 |
#wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
|
| 150 |
wandb_logger = WandbLogger(
|
|
|
|
| 164 |
#sxr_callback = SXRPredictionLogger(plot_samples)
|
| 165 |
|
| 166 |
sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
|
| 167 |
+
# Attention map callback - get patch size from config
|
| 168 |
+
patch_size = config_data.get('vit_custom', {}).get('patch_size', 8)
|
| 169 |
+
attention = AttentionMapCallback(patch_size=patch_size)
|
| 170 |
|
| 171 |
|
| 172 |
class PTHCheckpointCallback(Callback):
|
|
|
|
| 340 |
model = ViT(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)
|
| 341 |
|
| 342 |
elif config_data['selected_model'] == 'ViTPatch':
|
| 343 |
+
# Calculate base weights only if configured to do so
|
| 344 |
+
base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
|
| 345 |
+
model = ViTPatch(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm, base_weights=base_weights)
|
| 346 |
|
| 347 |
elif config_data['selected_model'] == 'FusionViTHybrid':
|
| 348 |
# Expect a 'fusion' section in YAML
|
|
|
|
| 372 |
else:
|
| 373 |
raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
|
| 374 |
|
| 375 |
+
# Monitor memory after model creation
|
| 376 |
+
print_gpu_memory("after model creation")
|
| 377 |
+
|
| 378 |
+
# Set device based on config
|
| 379 |
+
gpu_id = config_data.get('gpu_id', 0)
|
| 380 |
+
if gpu_id == -1:
|
| 381 |
+
accelerator = "cpu"
|
| 382 |
+
devices = 1
|
| 383 |
+
print("Using CPU for training")
|
| 384 |
+
else:
|
| 385 |
+
if torch.cuda.is_available():
|
| 386 |
+
accelerator = "gpu"
|
| 387 |
+
# When CUDA_VISIBLE_DEVICES is set, PyTorch Lightning only sees GPU 0
|
| 388 |
+
devices = [0] # Always use device 0 since we've isolated to specific GPU
|
| 389 |
+
print(f"Using GPU {gpu_id} for training (mapped to device 0 after CUDA_VISIBLE_DEVICES)")
|
| 390 |
+
else:
|
| 391 |
+
accelerator = "cpu"
|
| 392 |
+
devices = 1
|
| 393 |
+
print(f"GPU {gpu_id} not available, falling back to CPU")
|
| 394 |
+
|
| 395 |
# Trainer
|
| 396 |
if config_data['selected_model'] == 'ViT' or config_data['selected_model'] == 'ViTPatch' or config_data['selected_model'] == 'FusionViTHybrid':
|
| 397 |
trainer = Trainer(
|
| 398 |
default_root_dir=config_data['data']['checkpoints_dir'],
|
| 399 |
+
accelerator=accelerator,
|
| 400 |
+
devices=devices,
|
| 401 |
max_epochs=config_data['epochs'],
|
| 402 |
callbacks=[attention, checkpoint_callback],
|
| 403 |
logger=wandb_logger,
|
|
|
|
| 406 |
else:
|
| 407 |
trainer = Trainer(
|
| 408 |
default_root_dir=config_data['data']['checkpoints_dir'],
|
| 409 |
+
accelerator=accelerator,
|
| 410 |
+
devices=devices,
|
| 411 |
max_epochs=config_data['epochs'],
|
| 412 |
callbacks=[checkpoint_callback],
|
| 413 |
logger=wandb_logger,
|