griffingoodwin04 committed on
Commit
1aeb490
·
1 Parent(s): f6b8791

Refactor SXRRegressionDynamicLoss weights and performance multipliers; adjust model configuration for ViTPatch with reduced patch size and increased number of heads; implement GPU memory isolation and monitoring in training script; enhance AttentionMapCallback to accept dynamic patch size.

Browse files
forecasting/models/vit_patch_model.py CHANGED
@@ -341,10 +341,10 @@ class SXRRegressionDynamicLoss:
341
  def _get_base_weights(self):
342
  #Calculate the base weights based on the number of samples in each class within training data
343
  return {
344
- 'quiet': 1.0,
345
- 'c_class': 2.0,
346
- 'm_class': 10.0,
347
- 'x_class': 20.0
348
  }
349
 
350
  def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
@@ -360,16 +360,16 @@ class SXRRegressionDynamicLoss:
360
 
361
  # Get continuous multipliers per class with custom params
362
  quiet_mult = self._get_performance_multiplier(
363
- self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.2, sxrclass='quiet'
364
  )
365
  c_mult = self._get_performance_multiplier(
366
- self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.3, sxrclass='c_class'
367
  )
368
  m_mult = self._get_performance_multiplier(
369
- self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.4, sxrclass='m_class'
370
  )
371
  x_mult = self._get_performance_multiplier(
372
- self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=.5, sxrclass='x_class'
373
  )
374
 
375
  quiet_weight = self.base_weights['quiet'] * quiet_mult
 
341
  def _get_base_weights(self):
342
  #Calculate the base weights based on the number of samples in each class within training data
343
  return {
344
+ 'quiet': 1.2110,
345
+ 'c_class': 1.2110,
346
+ 'm_class': 6.3106,
347
+ 'x_class': 63.4350
348
  }
349
 
350
  def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
 
360
 
361
  # Get continuous multipliers per class with custom params
362
  quiet_mult = self._get_performance_multiplier(
363
+ self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.05, sxrclass='quiet' # Was 0.2
364
  )
365
  c_mult = self._get_performance_multiplier(
366
+ self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.08, sxrclass='c_class' # Was 0.3
367
  )
368
  m_mult = self._get_performance_multiplier(
369
+ self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.1, sxrclass='m_class' # Was 0.4
370
  )
371
  x_mult = self._get_performance_multiplier(
372
+ self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=0.12, sxrclass='x_class' # Was 0.5
373
  )
374
 
375
  quiet_weight = self.base_weights['quiet'] * quiet_mult
forecasting/training/callback.py CHANGED
@@ -90,7 +90,7 @@ class ImagePredictionLogger_SXR(Callback):
90
 
91
 
92
  class AttentionMapCallback(Callback):
93
- def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps"):
94
  """
95
  Callback to visualize attention maps during training.
96
 
@@ -98,8 +98,10 @@ class AttentionMapCallback(Callback):
98
  log_every_n_epochs: How often to log attention maps
99
  num_samples: Number of samples to visualize
100
  save_dir: Directory to save attention maps
 
101
  """
102
  super().__init__()
 
103
  self.log_every_n_epochs = log_every_n_epochs
104
  self.num_samples = num_samples
105
  self.save_dir = save_dir
@@ -142,7 +144,7 @@ class AttentionMapCallback(Callback):
142
  attention_weights,
143
  sample_idx,
144
  trainer.current_epoch,
145
- patch_size=16
146
  )
147
  trainer.logger.experiment.log({"Attention plots": wandb.Image(map)})
148
  plt.close(map)
 
90
 
91
 
92
  class AttentionMapCallback(Callback):
93
+ def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps", patch_size=8):
94
  """
95
  Callback to visualize attention maps during training.
96
 
 
98
  log_every_n_epochs: How often to log attention maps
99
  num_samples: Number of samples to visualize
100
  save_dir: Directory to save attention maps
101
+ patch_size: Size of patches used in the model
102
  """
103
  super().__init__()
104
+ self.patch_size = patch_size
105
  self.log_every_n_epochs = log_every_n_epochs
106
  self.num_samples = num_samples
107
  self.save_dir = save_dir
 
144
  attention_weights,
145
  sample_idx,
146
  trainer.current_epoch,
147
+ patch_size=self.patch_size
148
  )
149
  trainer.logger.experiment.log({"Attention plots": wandb.Image(map)})
150
  plt.close(map)
forecasting/training/config.yaml CHANGED
@@ -25,11 +25,11 @@ vit_custom:
25
  embed_dim: 512
26
  num_channels: 6
27
  num_classes: 1
28
- patch_size: 16
29
- num_patches: 1024
30
  hidden_dim: 512
31
- num_heads: 8
32
- num_layers: 6
33
  dropout: 0.1
34
  lr: 0.0001
35
 
@@ -67,5 +67,5 @@ wandb:
67
  - aia
68
  - sxr
69
  - regression
70
- wb_name: baseline-model-more-complex
71
  notes: Regression from AIA images (6 channels) to GOES SXR flux
 
25
  embed_dim: 512
26
  num_channels: 6
27
  num_classes: 1
28
+ patch_size: 8
29
+ num_patches: 4096
30
  hidden_dim: 512
31
+ num_heads: 12 # Increased from 8
32
+ num_layers: 4 # Reduced from 6
33
  dropout: 0.1
34
  lr: 0.0001
35
 
 
67
  - aia
68
  - sxr
69
  - regression
70
+ wb_name:
71
  notes: Regression from AIA images (6 channels) to GOES SXR flux
forecasting/training/config2.yaml CHANGED
@@ -3,12 +3,16 @@
3
  base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
4
  base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
5
  wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
 
 
 
6
  # Model configuration
7
  selected_model: "ViTPatch" # Options: "hybrid", "vit", "fusion", "vitpatch"
8
  batch_size: 64
9
  epochs: 250
10
  oversample: false
11
  balance_strategy: "upsample_minority"
 
12
 
13
  megsai:
14
  architecture: "cnn"
@@ -67,5 +71,5 @@ wandb:
67
  - aia
68
  - sxr
69
  - regression
70
- wb_name: vit-patch-model-2d-embeddings
71
  notes: Regression from AIA images (6 channels) to GOES SXR flux
 
3
  base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
4
  base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
5
  wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
6
+
7
+ # GPU configuration
8
+ gpu_id: 0 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
9
  # Model configuration
10
  selected_model: "ViTPatch" # Options: "hybrid", "vit", "fusion", "vitpatch"
11
  batch_size: 64
12
  epochs: 250
13
  oversample: false
14
  balance_strategy: "upsample_minority"
15
+ calculate_base_weights: false # Whether to calculate class-based weights for loss function
16
 
17
  megsai:
18
  architecture: "cnn"
 
71
  - aia
72
  - sxr
73
  - regression
74
+ wb_name: vit-patch-model-2d-embeddings-reduced-sensitivity
75
  notes: Regression from AIA images (6 channels) to GOES SXR flux
forecasting/training/config4.yaml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Base directories - change these to switch datasets
2
+ base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
3
+ base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
4
+ wavelengths: [171, 193, 211, 304] # AIA wavelengths in Angstroms
5
+
6
+ # GPU configuration
7
+ gpu_id: 1 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
8
+ # Model configuration
9
+ selected_model: "ViTPatch" # Options: "hybrid", "vit", "fusion", "vitpatch"
10
+ batch_size: 64
11
+ epochs: 250
12
+ oversample: false
13
+ balance_strategy: "upsample_minority"
14
+ calculate_base_weights: false # Whether to calculate class-based weights for loss function
15
+
16
+ megsai:
17
+ architecture: "cnn"
18
+ seed: 42
19
+ lr: 0.0001
20
+ cnn_model: "updated"
21
+ cnn_dp: 0.2
22
+ weight_decay: 1e-5
23
+ cosine_restart_T0: 50
24
+ cosine_restart_Tmult: 2
25
+ cosine_eta_min: 1e-7
26
+
27
+ vit_custom:
28
+ embed_dim: 512
29
+ num_channels: 4
30
+ num_classes: 1
31
+ patch_size: 16
32
+ num_patches: 1024
33
+ hidden_dim: 512
34
+ num_heads: 8
35
+ num_layers: 6
36
+ dropout: 0.1
37
+ lr: 0.0001
38
+
39
+
40
+ fusion:
41
+ scalar_branch: "hybrid" # or "linear"
42
+ lr: 0.0001
43
+ lambda_vit_to_target: 0.3
44
+ lambda_scalar_to_target: 0.1
45
+ learnable_gate: true
46
+ gate_init_bias: 5.0
47
+ scalar_kwargs:
48
+ d_input: 6
49
+ d_output: 1
50
+ cnn_model: "updated"
51
+ cnn_dp: 0.75
52
+
53
+
54
+ # Data paths (automatically constructed from base directories)
55
+ data:
56
+ aia_dir:
57
+ "${base_data_dir}/AIA-SPLIT"
58
+ sxr_dir:
59
+ "${base_data_dir}/SXR-SPLIT"
60
+ sxr_norm_path:
61
+ "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
62
+ checkpoints_dir:
63
+ "${base_checkpoint_dir}/new-checkpoint/"
64
+
65
+ wandb:
66
+ entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
67
+ project: Model Testing
68
+ job_type: training
69
+ tags:
70
+ - aia
71
+ - sxr
72
+ - regression
73
+ wb_name: vit-patch-model-2d-embeddings-reduced-sensitivity-STEREO
74
+ notes: Regression from AIA images (4 channels) to GOES SXR flux
forecasting/training/config5.yaml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #Base directories - change these to switch datasets
3
+ base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
4
+ base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
5
+ wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
6
+
7
+ # GPU configuration
8
+ gpu_id: 2 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
9
+ # Model configuration
10
+ selected_model: "vit" # Options: "hybrid", "vit", "fusion", "vitpatch"
11
+ batch_size: 64
12
+ epochs: 250
13
+ oversample: false
14
+ balance_strategy: "upsample_minority"
15
+ calculate_base_weights: false # Whether to calculate class-based weights for loss function
16
+
17
+ megsai:
18
+ architecture: "cnn"
19
+ seed: 42
20
+ lr: 0.0001
21
+ cnn_model: "updated"
22
+ cnn_dp: 0.2
23
+ weight_decay: 1e-5
24
+ cosine_restart_T0: 50
25
+ cosine_restart_Tmult: 2
26
+ cosine_eta_min: 1e-7
27
+
28
+ vit_custom:
29
+ embed_dim: 512
30
+ num_channels: 6
31
+ num_classes: 1
32
+ patch_size: 16
33
+ num_patches: 1024
34
+ hidden_dim: 512
35
+ num_heads: 8
36
+ num_layers: 6
37
+ dropout: 0.1
38
+ lr: 0.0001
39
+
40
+
41
+ fusion:
42
+ scalar_branch: "hybrid" # or "linear"
43
+ lr: 0.0001
44
+ lambda_vit_to_target: 0.3
45
+ lambda_scalar_to_target: 0.1
46
+ learnable_gate: true
47
+ gate_init_bias: 5.0
48
+ scalar_kwargs:
49
+ d_input: 6
50
+ d_output: 1
51
+ cnn_model: "updated"
52
+ cnn_dp: 0.75
53
+
54
+
55
+ # Data paths (automatically constructed from base directories)
56
+ data:
57
+ aia_dir:
58
+ "${base_data_dir}/AIA-SPLIT"
59
+ sxr_dir:
60
+ "${base_data_dir}/SXR-SPLIT"
61
+ sxr_norm_path:
62
+ "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
63
+ checkpoints_dir:
64
+ "${base_checkpoint_dir}/new-checkpoint/"
65
+
66
+ wandb:
67
+ entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
68
+ project: Model Testing
69
+ job_type: training
70
+ tags:
71
+ - aia
72
+ - sxr
73
+ - regression
74
+ wb_name: vit-patch-model-2d-embeddings-reduced-sensitivity
75
+ notes: Regression from AIA images (6 channels) to GOES SXR flux
forecasting/training/config6.yaml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #Base directories - change these to switch datasets
3
+ base_data_dir: "/mnt/data/COMBINED" # Change this line for different datasets
4
+ base_checkpoint_dir: "/mnt/data/COMBINED" # Change this line for different datasets
5
+ wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
6
+
7
+ # GPU configuration
8
+ gpu_id: 3 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
9
+ # Model configuration
10
+ selected_model: "ViTPatch" # Options: "hybrid", "vit", "fusion", "vitpatch"
11
+ batch_size: 64
12
+ epochs: 250
13
+ oversample: false
14
+ balance_strategy: "upsample_minority"
15
+ calculate_base_weights: false # Whether to calculate class-based weights for loss function
16
+
17
+ megsai:
18
+ architecture: "cnn"
19
+ seed: 42
20
+ lr: 0.0001
21
+ cnn_model: "updated"
22
+ cnn_dp: 0.2
23
+ weight_decay: 1e-5
24
+ cosine_restart_T0: 50
25
+ cosine_restart_Tmult: 2
26
+ cosine_eta_min: 1e-7
27
+
28
+ vit_custom:
29
+ embed_dim: 512
30
+ num_channels: 6
31
+ num_classes: 1
32
+ patch_size: 16
33
+ num_patches: 1024
34
+ hidden_dim: 512
35
+ num_heads: 8
36
+ num_layers: 6
37
+ dropout: 0.1
38
+ lr: 0.001
39
+
40
+
41
+ fusion:
42
+ scalar_branch: "hybrid" # or "linear"
43
+ lr: 0.0001
44
+ lambda_vit_to_target: 0.3
45
+ lambda_scalar_to_target: 0.1
46
+ learnable_gate: true
47
+ gate_init_bias: 5.0
48
+ scalar_kwargs:
49
+ d_input: 6
50
+ d_output: 1
51
+ cnn_model: "updated"
52
+ cnn_dp: 0.75
53
+
54
+
55
+ # Data paths (automatically constructed from base directories)
56
+ data:
57
+ aia_dir:
58
+ "${base_data_dir}/AIA-SPLIT"
59
+ sxr_dir:
60
+ "${base_data_dir}/SXR-SPLIT"
61
+ sxr_norm_path:
62
+ "${base_data_dir}/SXR-SPLIT/normalized_sxr.npy"
63
+ checkpoints_dir:
64
+ "${base_checkpoint_dir}/new-checkpoint/"
65
+
66
+ wandb:
67
+ entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
68
+ project: Model Testing
69
+ job_type: training
70
+ tags:
71
+ - aia
72
+ - sxr
73
+ - regression
74
+ wb_name: vit-patch-model-2d-embeddings-reduced-sensitivity-higher-lr
75
+ notes: Regression from AIA images (6 channels) to GOES SXR flux
forecasting/training/train.py CHANGED
@@ -32,6 +32,18 @@ from forecasting.models.FastSpectralNet import FastViTFlaringModel
32
 
33
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
34
  os.environ["NCCL_DEBUG"] = "WARN"
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def resolve_config_variables(config_dict):
37
  """Recursively resolve ${variable} references within the config"""
@@ -75,11 +87,27 @@ with open(args.config, 'r') as stream:
75
  # Resolve variables like ${base_data_dir}
76
  config_data = resolve_config_variables(config_data)
77
 
78
- # Debug: Print resolved paths
79
- print("Resolved paths:")
80
- print(f"AIA dir: {config_data['data']['aia_dir']}")
81
- print(f"SXR dir: {config_data['data']['sxr_dir']}")
82
- print(f"Checkpoints dir: {config_data['data']['checkpoints_dir']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Debug: Print resolved paths
85
  print("Resolved paths:")
@@ -106,7 +134,7 @@ data_loader = AIA_GOESDataModule(
106
  sxr_val_dir=config_data['data']['sxr_dir']+"/val",
107
  sxr_test_dir=config_data['data']['sxr_dir']+"/test",
108
  batch_size=config_data['batch_size'],
109
- num_workers=os.cpu_count(),
110
  sxr_norm=sxr_norm,
111
  wavelengths=training_wavelengths,
112
  oversample=config_data['oversample'],
@@ -114,6 +142,9 @@ data_loader = AIA_GOESDataModule(
114
  )
115
  data_loader.setup()
116
 
 
 
 
117
  # Logger
118
  #wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
119
  wandb_logger = WandbLogger(
@@ -133,8 +164,9 @@ plot_samples = plot_data # Keep as list of ((aia, sxr), target)
133
  #sxr_callback = SXRPredictionLogger(plot_samples)
134
 
135
  sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
136
- # Attention map callback
137
- attention = AttentionMapCallback()
 
138
 
139
 
140
  class PTHCheckpointCallback(Callback):
@@ -308,7 +340,9 @@ elif config_data['selected_model'] == 'ViT':
308
  model = ViT(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)
309
 
310
  elif config_data['selected_model'] == 'ViTPatch':
311
- model = ViTPatch(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm, base_weights=get_base_weights(data_loader, sxr_norm))
 
 
312
 
313
  elif config_data['selected_model'] == 'FusionViTHybrid':
314
  # Expect a 'fusion' section in YAML
@@ -338,12 +372,32 @@ elif config_data['selected_model'] == 'FusionViTHybrid':
338
  else:
339
  raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  # Trainer
342
  if config_data['selected_model'] == 'ViT' or config_data['selected_model'] == 'ViTPatch' or config_data['selected_model'] == 'FusionViTHybrid':
343
  trainer = Trainer(
344
  default_root_dir=config_data['data']['checkpoints_dir'],
345
- accelerator="gpu" if torch.cuda.is_available() else "cpu",
346
- devices=1,
347
  max_epochs=config_data['epochs'],
348
  callbacks=[attention, checkpoint_callback],
349
  logger=wandb_logger,
@@ -352,8 +406,8 @@ if config_data['selected_model'] == 'ViT' or config_data['selected_model'] == 'V
352
  else:
353
  trainer = Trainer(
354
  default_root_dir=config_data['data']['checkpoints_dir'],
355
- accelerator="gpu" if torch.cuda.is_available() else "cpu",
356
- devices=1,
357
  max_epochs=config_data['epochs'],
358
  callbacks=[checkpoint_callback],
359
  logger=wandb_logger,
 
32
 
33
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
34
  os.environ["NCCL_DEBUG"] = "WARN"
35
+ # Shared memory optimizations
36
+ os.environ["OMP_NUM_THREADS"] = "1" # Limit OpenMP threads
37
+ os.environ["MKL_NUM_THREADS"] = "1" # Limit MKL threads
38
+
39
+ def print_gpu_memory(stage=""):
40
+ """Print GPU memory usage for monitoring"""
41
+ if torch.cuda.is_available():
42
+ allocated = torch.cuda.memory_allocated() / 1e9
43
+ reserved = torch.cuda.memory_reserved() / 1e9
44
+ print(f"GPU Memory {stage} - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
45
+ else:
46
+ print(f"No GPU available for memory monitoring {stage}")
47
 
48
  def resolve_config_variables(config_dict):
49
  """Recursively resolve ${variable} references within the config"""
 
87
  # Resolve variables like ${base_data_dir}
88
  config_data = resolve_config_variables(config_data)
89
 
90
+ # GPU Memory Isolation for Multi-GPU Systems
91
+ gpu_id = config_data.get('gpu_id', 0)
92
+ if gpu_id != -1: # Only if using GPU
93
+ # Set CUDA device visibility to only the specified GPU
94
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
95
+ print(f"Set CUDA_VISIBLE_DEVICES to GPU {gpu_id}")
96
+
97
+ # Clear any existing CUDA cache
98
+ if torch.cuda.is_available():
99
+ torch.cuda.empty_cache()
100
+ print(f"Cleared CUDA cache for GPU {gpu_id}")
101
+
102
+ # Set memory allocation strategy for better isolation
103
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:16"
104
+
105
+ # Disable memory sharing between processes
106
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
107
+
108
+ print(f"GPU Memory Isolation configured for GPU {gpu_id}")
109
+ else:
110
+ print("Using CPU - no GPU memory isolation needed")
111
 
112
  # Debug: Print resolved paths
113
  print("Resolved paths:")
 
134
  sxr_val_dir=config_data['data']['sxr_dir']+"/val",
135
  sxr_test_dir=config_data['data']['sxr_dir']+"/test",
136
  batch_size=config_data['batch_size'],
137
+ num_workers=min(8, os.cpu_count()), # Limit workers to prevent shm issues
138
  sxr_norm=sxr_norm,
139
  wavelengths=training_wavelengths,
140
  oversample=config_data['oversample'],
 
142
  )
143
  data_loader.setup()
144
 
145
+ # Monitor memory after data loading
146
+ print_gpu_memory("after data loading")
147
+
148
  # Logger
149
  #wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
150
  wandb_logger = WandbLogger(
 
164
  #sxr_callback = SXRPredictionLogger(plot_samples)
165
 
166
  sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
167
+ # Attention map callback - get patch size from config
168
+ patch_size = config_data.get('vit_custom', {}).get('patch_size', 8)
169
+ attention = AttentionMapCallback(patch_size=patch_size)
170
 
171
 
172
  class PTHCheckpointCallback(Callback):
 
340
  model = ViT(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)
341
 
342
  elif config_data['selected_model'] == 'ViTPatch':
343
+ # Calculate base weights only if configured to do so
344
+ base_weights = get_base_weights(data_loader, sxr_norm) if config_data.get('calculate_base_weights', True) else None
345
+ model = ViTPatch(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm, base_weights=base_weights)
346
 
347
  elif config_data['selected_model'] == 'FusionViTHybrid':
348
  # Expect a 'fusion' section in YAML
 
372
  else:
373
  raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
374
 
375
+ # Monitor memory after model creation
376
+ print_gpu_memory("after model creation")
377
+
378
+ # Set device based on config
379
+ gpu_id = config_data.get('gpu_id', 0)
380
+ if gpu_id == -1:
381
+ accelerator = "cpu"
382
+ devices = 1
383
+ print("Using CPU for training")
384
+ else:
385
+ if torch.cuda.is_available():
386
+ accelerator = "gpu"
387
+ # When CUDA_VISIBLE_DEVICES is set, PyTorch Lightning only sees GPU 0
388
+ devices = [0] # Always use device 0 since we've isolated to specific GPU
389
+ print(f"Using GPU {gpu_id} for training (mapped to device 0 after CUDA_VISIBLE_DEVICES)")
390
+ else:
391
+ accelerator = "cpu"
392
+ devices = 1
393
+ print(f"GPU {gpu_id} not available, falling back to CPU")
394
+
395
  # Trainer
396
  if config_data['selected_model'] == 'ViT' or config_data['selected_model'] == 'ViTPatch' or config_data['selected_model'] == 'FusionViTHybrid':
397
  trainer = Trainer(
398
  default_root_dir=config_data['data']['checkpoints_dir'],
399
+ accelerator=accelerator,
400
+ devices=devices,
401
  max_epochs=config_data['epochs'],
402
  callbacks=[attention, checkpoint_callback],
403
  logger=wandb_logger,
 
406
  else:
407
  trainer = Trainer(
408
  default_root_dir=config_data['data']['checkpoints_dir'],
409
+ accelerator=accelerator,
410
+ devices=devices,
411
  max_epochs=config_data['epochs'],
412
  callbacks=[checkpoint_callback],
413
  logger=wandb_logger,