griffingoodwin04 committed
Commit 2a0d39b · 2 Parent(s): 6aa19b45729842

Merge remote-tracking branch 'origin/main' into dev-patch

README.md CHANGED
@@ -1 +1,5 @@
-# 2025-HL-flaring-intelligence
+<<<<<<< HEAD
+# 2025-HL-flaring-intelligence
+=======
+# FOXES
+>>>>>>> origin/main
forecasting/models/__init__.py CHANGED
@@ -1,2 +1 @@
-from .fusion_vit_hybrid import FusionViTHybrid
 
forecasting/models/vit_patch_model_local.py CHANGED
@@ -108,7 +108,7 @@ class ViTLocal(pl.LightningModule):
         if self.global_step % 200 == 0:
             multipliers = self.adaptive_loss.get_current_multipliers()
             for key, value in multipliers.items():
-                self.log(f"adaptive/{key}", value, on_step=True, on_epoch=False)
+                self.log(f"adaptive/{key}", value, on_step=True, on_epoch=False, sync_dist=True)
 
         if mode == "val":
             # Validation: typically only log epoch aggregates
@@ -177,9 +177,8 @@ class VisionTransformerLocal(nn.Module):
         self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
         self.dropout = nn.Dropout(dropout)
 
-        # Parameters/Embeddings
-        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
-        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))
+        # Parameters/Embeddings - using 2D positional encoding for local attention
+        # No CLS token needed for local attention architecture
         self.grid_h = int(math.sqrt(num_patches))
         self.grid_w = int(math.sqrt(num_patches))
         self.pos_embedding_2d = nn.Parameter(torch.randn(1, self.grid_h, self.grid_w, embed_dim))
@@ -289,7 +288,7 @@ class AttentionBlock(nn.Module):
 
 
 class LocalAttentionBlock(nn.Module):
-    def __init__(self, embed_dim, hidden_dim, num_heads, num_patches, dropout=0.0, local_window=3):
+    def __init__(self, embed_dim, hidden_dim, num_heads, num_patches, dropout=0.0, local_window=9):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
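
Editor's note on this hunk: the `local_window` default moving from 3 to 9 matches the run name in localpatch.yaml ("9x9-attention"), and with the CLS token and 1D positional embedding removed each patch keeps only the 2D positional embedding and attends within a neighbourhood of the patch grid. The `sync_dist=True` added to `self.log` simply reduces the adaptive-loss multipliers across DDP ranks once training uses several GPUs. The diff does not show how LocalAttentionBlock builds its mask, so the following is only a minimal sketch assuming a square local_window x local_window neighbourhood; `local_attention_mask` is a hypothetical helper, not code from this repository.

# Hypothetical sketch (not from the repo): boolean mask on the patch grid where patch (i, j)
# may attend only to patches within a local_window x local_window neighbourhood.
import torch

def local_attention_mask(grid_h, grid_w, local_window):
    """Return a [N, N] bool mask (N = grid_h * grid_w); True = attention allowed."""
    rows = torch.arange(grid_h).repeat_interleave(grid_w)  # row index of every patch
    cols = torch.arange(grid_w).repeat(grid_h)             # column index of every patch
    half = local_window // 2
    near_rows = (rows[:, None] - rows[None, :]).abs() <= half
    near_cols = (cols[:, None] - cols[None, :]).abs() <= half
    return near_rows & near_cols

# num_patches: 1024 in localpatch.yaml -> a 32x32 patch grid.
mask = local_attention_mask(32, 32, local_window=9)
print(mask.shape, int(mask[0].sum()))  # torch.Size([1024, 1024]) 25 (a corner patch sees a 5x5 block)
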
forecasting/training/callback.py CHANGED
@@ -90,7 +90,7 @@ class ImagePredictionLogger_SXR(Callback):
 
 
 class AttentionMapCallback(Callback):
-    def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps", patch_size=8):
+    def __init__(self, log_every_n_epochs=1, num_samples=4, save_dir="attention_maps", patch_size=8, use_local_attention=False):
         """
         Callback to visualize attention maps during training.
 
@@ -99,12 +99,14 @@ class AttentionMapCallback:
             num_samples: Number of samples to visualize
             save_dir: Directory to save attention maps
             patch_size: Size of patches used in the model
+            use_local_attention: If True, visualize local attention patterns instead of CLS token attention
         """
         super().__init__()
         self.patch_size = patch_size
         self.log_every_n_epochs = log_every_n_epochs
         self.num_samples = num_samples
         self.save_dir = save_dir
+        self.use_local_attention = use_local_attention
 
     def on_validation_epoch_end(self, trainer, pl_module):
         if trainer.current_epoch % self.log_every_n_epochs == 0:
@@ -125,8 +127,8 @@ class AttentionMapCallback:
             # Move to device
             imgs = imgs[:self.num_samples].to(pl_module.device)
 
-            # Get predictions with attention weights
-            #Dynamically extract attention weights from the model
+            # Get predictions with attention weights and patch contributions
+            patch_flux_raw = None
             try:
                 outputs, attention_weights = pl_module(imgs, return_attention=True)
             except:
@@ -134,7 +136,7 @@ class AttentionMapCallback:
                 if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'forward'):
                     try:
                         print("Using model's forward method")
-                        outputs, attention_weights, _ = pl_module.model(imgs, pl_module.sxr_norm, return_attention=True)
+                        outputs, attention_weights, patch_flux_raw = pl_module.model(imgs, pl_module.sxr_norm, return_attention=True)
                     except:
                         print("Using model's forward method failed")
                         outputs, attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
@@ -149,12 +151,13 @@ class AttentionMapCallback:
                     attention_weights,
                     sample_idx,
                     trainer.current_epoch,
-                    patch_size=self.patch_size
+                    patch_size=self.patch_size,
+                    patch_flux=patch_flux_raw[sample_idx] if patch_flux_raw is not None else None
                 )
                 trainer.logger.experiment.log({"Attention plots": wandb.Image(map)})
                 plt.close(map)
 
-    def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size):
+    def _plot_attention_map(self, image, attention_weights, sample_idx, epoch, patch_size, patch_flux=None):
         """
         Plot attention map for a single image.
 
@@ -164,50 +167,44 @@
             sample_idx: Index of the sample in the batch
             epoch: Current epoch number
             patch_size: Size of patches
+            patch_flux: Optional tensor of patch flux contributions [num_patches]
         """
         # Convert image to numpy and transpose
         img_np = image.cpu().numpy()
         if len(img_np.shape) == 3 and img_np.shape[0] in [1, 3]: # Check if channels first
             img_np = np.transpose(img_np, (1, 2, 0))
 
+        # Calculate grid size
+        H, W = img_np.shape[:2]
+        grid_h, grid_w = H // patch_size, W // patch_size
 
         # Get attention from the last layer
         last_layer_attention = attention_weights[-1] # [B, num_heads, seq_len, seq_len]
-
+
         # Extract attention for this sample
         sample_attention = last_layer_attention[sample_idx] # [num_heads, seq_len, seq_len]
-
+
         # Average across heads
         avg_attention = sample_attention.mean(dim=0) # [seq_len, seq_len]
 
-        # Get attention from CLS token to patches (exclude CLS->CLS)
-        cls_attention = avg_attention[0, 1:].cpu() # [num_patches]
-
-        # Calculate grid size - NOW USING CORRECT DIMENSIONS
-        H, W = img_np.shape[:2] # Now this is correct after transpose
-        grid_h, grid_w = H // patch_size, W // patch_size
+        if self.use_local_attention:
+            # For local attention: visualize attention patterns from center patch
+            # and average attention across all patches
+            center_patch_idx = (grid_h * grid_w) // 2 # Center patch
+            center_attention = avg_attention[center_patch_idx, :].cpu() # [num_patches]
+
+            # Average attention pattern (how much each patch attends to others on average)
+            avg_attention_map = avg_attention.mean(dim=0).cpu() # [num_patches]
+
+            attention_map = avg_attention_map.reshape(grid_h, grid_w)
+            center_map = center_attention.reshape(grid_h, grid_w)
+        else:
+            # For CLS token attention: visualize attention from CLS to patches
+            cls_attention = avg_attention[0, 1:].cpu() # [num_patches]
+            attention_map = cls_attention.reshape(grid_h, grid_w)
+            center_map = None
 
-        # Reshape attention to spatial grid
-        attention_map = cls_attention.reshape(grid_h, grid_w)
-
-        # Create figure with subplots
-        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
-
-        # Plot 1: Original image
-        # if img_np.shape[2] == 1: # Grayscale
-        #     img_display = (img_np[:, :, 0] + 1) / 2
-        #     axes[0].imshow(img_display, cmap='gray')
-        # elif img_np.shape[2] == 3: # RGB
-        #     # Normalize RGB image properly
-        #     img_display = (img_np + 1) / 2 # Assuming images are in [-1, 1] range
-        #     img_display = np.clip(img_display, 0, 1) # Ensure valid range
-        #     axes[0].imshow(img_display)
-        # else: # Multi-channel (6 channels in your case)
-        #     # Option 1: Display first channel as grayscale
-        #     img_display = (img_np[:, :, 0] + 1) / 2
-        #     axes[0].imshow(img_display, cmap='gray')
-
-        # Option 2: Create RGB composite from 3 channels (uncomment if preferred)
+        # Prepare image display
         if len(img_np[0,0,:]) >= 6: # Ensure we have enough channels
             rgb_channels = [0, 2, 4] # Select which channels to use for R, G, B
             img_display = np.stack([(img_np[:, :, i] + 1) / 2 for i in rgb_channels], axis=2)
@@ -216,32 +213,88 @@ class AttentionMapCallback:
             # If not enough channels, use grayscale
             img_display = (img_np[:, :, 0] + 1) / 2
             img_display = np.stack([img_display] * 3, axis=2)
-        axes[0].imshow(img_display)
-        axes[0].set_title(f'Original Image (Epoch {epoch})')
-        axes[0].axis('off')
-
-        # Plot 2: Attention heatmap
-        attention_np = np.log1p(attention_map.numpy())
-        # Resize attention map to match image size
-        attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
-
-        # Create colormap for attention - FIX: Use the scalar values, not RGB
-        im = axes[1].imshow(attention_resized, cmap='hot')
-        axes[1].set_title(f'Attention Map (Sample {sample_idx})')
-        axes[1].axis('off')
-        # FIXED: Create colorbar from the scalar image, not RGB
-        plt.colorbar(im, ax=axes[1])
-
-        # Plot 3: Overlay attention on image
-        #img_display_overlay = (img_np[:, :, 0] + 1) / 2
-        axes[2].imshow(img_display)
-
-        # Overlay attention with proper alpha blending
-        axes[2].imshow(attention_resized, cmap='hot', alpha=0.5)
-        axes[2].set_title(f'Log-Scaled Attention Overlay (Sample {sample_idx})')
-        axes[2].axis('off')
 
-        plt.tight_layout()
+        # Create figure with appropriate number of subplots
+        if self.use_local_attention and patch_flux is not None:
+            # Show: Original, Avg Attention, Center Attention, Patch Flux
+            fig, axes = plt.subplots(1, 4, figsize=(20, 5))
+
+            # Plot 1: Original image
+            axes[0].imshow(img_display)
+            axes[0].set_title(f'Original Image (Epoch {epoch})')
+            axes[0].axis('off')
+
+            # Plot 2: Average attention pattern
+            attention_np = np.log1p(attention_map.numpy())
+            attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
+            im1 = axes[1].imshow(attention_resized, cmap='hot')
+            axes[1].set_title('Avg Attention (All Patches)')
+            axes[1].axis('off')
+            plt.colorbar(im1, ax=axes[1])
+
+            # Plot 3: Center patch attention
+            center_np = np.log1p(center_map.numpy())
+            center_resized = zoom(center_np, (H / grid_h, W / grid_w), order=1)
+            im2 = axes[2].imshow(center_resized, cmap='viridis')
+            axes[2].set_title('Center Patch Attention')
+            axes[2].axis('off')
+            plt.colorbar(im2, ax=axes[2])
+
+            # Plot 4: Patch flux contributions
+            flux_map = patch_flux.cpu().reshape(grid_h, grid_w)
+            flux_np = np.log1p(flux_map.numpy())
+            flux_resized = zoom(flux_np, (H / grid_h, W / grid_w), order=1)
+            im3 = axes[3].imshow(flux_resized, cmap='plasma')
+            axes[3].set_title('Log Patch Flux Contributions')
+            axes[3].axis('off')
+            plt.colorbar(im3, ax=axes[3])
+
+        elif self.use_local_attention:
+            # Show: Original, Avg Attention, Center Attention
+            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+
+            # Plot 1: Original image
+            axes[0].imshow(img_display)
+            axes[0].set_title(f'Original Image (Epoch {epoch})')
+            axes[0].axis('off')
+
+            # Plot 2: Average attention pattern
+            attention_np = np.log1p(attention_map.numpy())
+            attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
+            im1 = axes[1].imshow(attention_resized, cmap='hot')
+            axes[1].set_title('Avg Attention (All Patches)')
+            axes[1].axis('off')
+            plt.colorbar(im1, ax=axes[1])
+
+            # Plot 3: Center patch attention
+            center_np = np.log1p(center_map.numpy())
+            center_resized = zoom(center_np, (H / grid_h, W / grid_w), order=1)
+            im2 = axes[2].imshow(center_resized, cmap='viridis')
+            axes[2].set_title('Center Patch Attention')
+            axes[2].axis('off')
+            plt.colorbar(im2, ax=axes[2])
+        else:
+            # Original CLS token visualization
+            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+
+            # Plot 1: Original image
+            axes[0].imshow(img_display)
+            axes[0].set_title(f'Original Image (Epoch {epoch})')
+            axes[0].axis('off')
+
+            # Plot 2: Attention heatmap
+            attention_np = np.log1p(attention_map.numpy())
+            attention_resized = zoom(attention_np, (H / grid_h, W / grid_w), order=1)
+            im = axes[1].imshow(attention_resized, cmap='hot')
+            axes[1].set_title(f'Attention Map (Sample {sample_idx})')
+            axes[1].axis('off')
+            plt.colorbar(im, ax=axes[1])
+
+            # Plot 3: Overlay attention on image
+            axes[2].imshow(img_display)
+            axes[2].imshow(attention_resized, cmap='hot', alpha=0.5)
+            axes[2].set_title(f'Log-Scaled Attention Overlay (Sample {sample_idx})')
+            axes[2].axis('off')
 
         plt.tight_layout()
         return fig
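
Editor's note on this hunk: because the local model has no CLS token, `seq_len` equals `num_patches`, so both reductions in the new `use_local_attention` branch reshape directly onto the `grid_h x grid_w` patch grid. A shape-only sketch with dummy tensors (random values, grid taken from the 1024-patch config) of what that branch computes:

# Shape-only sketch with made-up data; mirrors the use_local_attention branch above.
import torch

grid_h = grid_w = 32
num_patches = grid_h * grid_w                           # 1024, no CLS token
avg_attention = torch.rand(num_patches, num_patches)    # head-averaged weights for one sample

center_patch_idx = (grid_h * grid_w) // 2               # query patch at the grid centre
center_map = avg_attention[center_patch_idx, :].reshape(grid_h, grid_w)
attention_map = avg_attention.mean(dim=0).reshape(grid_h, grid_w)  # averaged over query patches

print(center_map.shape, attention_map.shape)            # torch.Size([32, 32]) twice
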
forecasting/training/localpatch.yaml CHANGED
@@ -1,14 +1,19 @@
 
 #Base directories - change these to switch datasets
-base_data_dir: "/mnt/data/NO-OVERLAP" # Change this line for different datasets
-base_checkpoint_dir: "/mnt/data/NO-OVERLAP" # Change this line for different datasets
+base_data_dir: "/mnt/data/PAPER_DATA_B" # Change this line for different datasets
+base_checkpoint_dir: "/mnt/data/PAPER_DATA_B" # Change this line for different datasets
 wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
 
 # GPU configuration
-gpu_id: 0 # GPU device ID to use (0, 1, 2, etc.) or -1 for CPU only
+# Options:
+#   - Single GPU: gpu_ids: 0 or gpu_ids: [0]
+#   - Multi GPU: gpu_ids: [0, 1] (uses both GPU 0 and 1)
+#   - All GPUs: gpu_ids: "all" (uses all available GPUs)
+#   - CPU only: gpu_ids: -1
+gpu_ids: "all" # Use both GPUs
 # Model configuration
 selected_model: "ViTLocal" # Options: "hybrid", "vit", "fusion", "vitpatch"
-batch_size: 128
+batch_size: 48
 epochs: 250
 oversample: false
 balance_strategy: "upsample_minority"
@@ -20,9 +25,9 @@ vit_custom:
   num_classes: 1
   patch_size: 16
   num_patches: 1024
-  hidden_dim: 512
+  hidden_dim: 2048
   num_heads: 8
-  num_layers: 6
+  num_layers: 10
   dropout: 0.1
   lr: 0.0001
 
@@ -40,11 +45,11 @@ data:
 
 wandb:
   entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
-  project: Model Testing
+  project: Paper
   job_type: training
   tags:
     - aia
     - sxr
     - regression
-  wb_name: vit-local-patch
+  wb_name: paper-testing-16-patch-deeper-model-9x9-attention
   notes: Regression from AIA images (6 channels) to GOES SXR flux
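
Editor's note on this file: `base_data_dir` and `base_checkpoint_dir` are reused elsewhere in the config via `${...}` references, which train.py expands with `resolve_config_variables` before any paths are read. A minimal sketch of that substitution follows; it is not the project's implementation, and the `data.sxr_dir` key below is hypothetical since the `data:` section is not part of this diff.

# Minimal sketch of ${variable} expansion over a YAML config (illustrative only).
import re
import yaml

def resolve(cfg, root=None):
    """Recursively replace ${key} references with top-level config values."""
    root = cfg if root is None else root
    if isinstance(cfg, dict):
        return {k: resolve(v, root) for k, v in cfg.items()}
    if isinstance(cfg, list):
        return [resolve(v, root) for v in cfg]
    if isinstance(cfg, str):
        return re.sub(r"\$\{(\w+)\}", lambda m: str(root[m.group(1)]), cfg)
    return cfg

raw = yaml.safe_load("""
base_data_dir: "/mnt/data/PAPER_DATA_B"
data:
  sxr_dir: "${base_data_dir}/sxr"
""")
print(resolve(raw)["data"]["sxr_dir"])  # /mnt/data/PAPER_DATA_B/sxr
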
forecasting/training/train.py CHANGED
@@ -13,7 +13,6 @@ import numpy as np
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
-from torch.nn import MSELoss, HuberLoss
 from pathlib import Path
 import sys
 # Add project root to Python path
@@ -21,33 +20,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute()
 sys.path.insert(0, str(PROJECT_ROOT))
 
 from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataModule
-from forecasting.models.vision_transformer_custom import ViT
-from forecasting.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
-from forecasting.models.vit_patch_model import ViT as ViTPatch
-from forecasting.models.vit_patch_model_uncertainty import ViTUncertainty
-from forecasting.models import FusionViTHybrid
-from forecasting.models.CNN_Patch import CNNPatch
+
 from forecasting.models.vit_patch_model_local import ViTLocal
 from callback import ImagePredictionLogger_SXR, AttentionMapCallback
 
 from pytorch_lightning.callbacks import Callback
 
-from forecasting.models.FastSpectralNet import FastViTFlaringModel
-
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-os.environ["NCCL_DEBUG"] = "WARN"
-# Shared memory optimizations
-os.environ["OMP_NUM_THREADS"] = "1" # Limit OpenMP threads
-os.environ["MKL_NUM_THREADS"] = "1" # Limit MKL threads
 
-def print_gpu_memory(stage=""):
-    """Print GPU memory usage for monitoring"""
-    if torch.cuda.is_available():
-        allocated = torch.cuda.memory_allocated() / 1e9
-        reserved = torch.cuda.memory_reserved() / 1e9
-        print(f"GPU Memory {stage} - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
-    else:
-        print(f"No GPU available for memory monitoring {stage}")
 
 def resolve_config_variables(config_dict):
     """Recursively resolve ${variable} references within the config"""
@@ -91,27 +70,6 @@ with open(args.config, 'r') as stream:
 # Resolve variables like ${base_data_dir}
 config_data = resolve_config_variables(config_data)
 
-# GPU Memory Isolation for Multi-GPU Systems
-gpu_id = config_data.get('gpu_id', 0)
-if gpu_id != -1: # Only if using GPU
-    # Set CUDA device visibility to only the specified GPU
-    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
-    print(f"Set CUDA_VISIBLE_DEVICES to GPU {gpu_id}")
-
-    # Clear any existing CUDA cache
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        print(f"Cleared CUDA cache for GPU {gpu_id}")
-
-    # Set memory allocation strategy for better isolation
-    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:16"
-
-    # Disable memory sharing between processes
-    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
-    print(f"GPU Memory Isolation configured for GPU {gpu_id}")
-else:
-    print("Using CPU - no GPU memory isolation needed")
 
 # Debug: Print resolved paths
 print("Resolved paths:")
@@ -120,12 +78,6 @@ print(f"SXR dir: {config_data['data']['sxr_dir']}")
 print(f"Checkpoints dir: {config_data['data']['checkpoints_dir']}")
 
 sxr_norm = np.load(config_data['data']['sxr_norm_path'])
-
-n = 0
-
-torch.manual_seed(config_data['megsai']['seed'])
-np.random.seed(config_data['megsai']['seed'])
-
 training_wavelengths = config_data['wavelengths']
 
 
@@ -145,10 +97,6 @@ data_loader = AIA_GOESDataModule(
     balance_strategy=config_data['balance_strategy'],
 )
 data_loader.setup()
-
-# Monitor memory after data loading
-print_gpu_memory("after data loading")
-
 # Logger
 #wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
 wandb_logger = WandbLogger(
@@ -158,7 +106,7 @@ wandb_logger = WandbLogger(
     tags=config_data['wandb']['tags'],
     name=config_data['wandb']['wb_name'],
    notes=config_data['wandb']['notes'],
-    config=config_data['megsai']
+    config=config_data
 )
 
 # Logging callback
@@ -169,8 +117,8 @@ plot_samples = plot_data # Keep as list of ((aia, sxr), target)
 
 sxr_plot_callback = ImagePredictionLogger_SXR(plot_samples, sxr_norm)
 # Attention map callback - get patch size from config
-patch_size = config_data.get('vit_custom', {}).get('patch_size', 8)
-attention = AttentionMapCallback(patch_size=patch_size)
+patch_size = config_data.get('vit_custom', {}).get('patch_size', 16)
+attention = AttentionMapCallback(patch_size=patch_size, use_local_attention=True)
 
 
 class PTHCheckpointCallback(Callback):
@@ -323,27 +271,63 @@ else:
     raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
 
 # Set device based on config
-gpu_id = config_data.get('gpu_id', 0)
-if gpu_id == -1:
+# Support both old 'gpu_id' and new 'gpu_ids' config keys for backward compatibility
+gpu_config = config_data.get('gpu_ids', config_data.get('gpu_id', 0))
+
+if gpu_config == -1:
+    # CPU only
     accelerator = "cpu"
     devices = 1
+    strategy = "auto"
    print("Using CPU for training")
+elif gpu_config == "all":
+    # Use all available GPUs
+    if torch.cuda.is_available():
+        accelerator = "gpu"
+        devices = -1 # -1 means use all available GPUs
+        num_gpus = torch.cuda.device_count()
+        strategy = "auto"
+        print(f"Using all available GPUs ({num_gpus} GPUs)")
+        if num_gpus > 1:
+            print(f"Multi-GPU training with DDP: Effective batch size = {config_data['batch_size']} x {num_gpus} GPUs = {config_data['batch_size'] * num_gpus}")
+    else:
+        accelerator = "cpu"
+        devices = 1
+        strategy = "auto"
+        print("No GPUs available, falling back to CPU")
+elif isinstance(gpu_config, list):
+    # Multiple specific GPUs
+    if torch.cuda.is_available():
+        accelerator = "gpu"
+        devices = gpu_config
+        strategy = "auto"
+        print(f"Using GPUs: {gpu_config}")
+        if len(gpu_config) > 1:
+            print(f"Multi-GPU training with DDP: Effective batch size = {config_data['batch_size']} x {len(gpu_config)} GPUs = {config_data['batch_size'] * len(gpu_config)}")
+    else:
+        accelerator = "cpu"
+        devices = 1
+        strategy = "auto"
+        print("No GPUs available, falling back to CPU")
 else:
+    # Single GPU (integer)
     if torch.cuda.is_available():
         accelerator = "gpu"
-        # When CUDA_VISIBLE_DEVICES is set, PyTorch Lightning only sees GPU 0
-        devices = [0] # Always use device 0 since we've isolated to specific GPU
-        print(f"Using GPU {gpu_id} for training (mapped to device 0 after CUDA_VISIBLE_DEVICES)")
+        devices = [gpu_config]
+        strategy = "auto"
+        print(f"Using GPU {gpu_config}")
    else:
        accelerator = "cpu"
        devices = 1
-        print(f"GPU {gpu_id} not available, falling back to CPU")
+        strategy = "auto"
+        print(f"GPU {gpu_config} not available, falling back to CPU")
 
 # Trainer
 trainer = Trainer(
     default_root_dir=config_data['data']['checkpoints_dir'],
     accelerator=accelerator,
     devices=devices,
+    strategy=strategy,
     max_epochs=config_data['epochs'],
     callbacks=[attention, checkpoint_callback],
     logger=wandb_logger,
@@ -359,6 +343,5 @@ torch.save({
     'state_dict': model.state_dict()
 }, final_checkpoint_path)
 print(f"Saved final PyTorch checkpoint: {final_checkpoint_path}")
-n += 1
 # Finalize
 wandb.finish()
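
Editor's note on this hunk: the new branching reduces to a small mapping from the `gpu_ids` value in localpatch.yaml to the Trainer's `accelerator`/`devices` arguments. The helper below is a condensed sketch of that mapping, not the script's code; the two-GPU numbers are illustrative.

# Condensed sketch (not the project's code) of how gpu_ids values map to Trainer arguments.
import torch

def select_devices(gpu_config):
    """Return (accelerator, devices) for a gpu_ids value of -1, "all", an int, or a list."""
    if gpu_config == -1 or not torch.cuda.is_available():
        return "cpu", 1
    if gpu_config == "all":
        return "gpu", -1            # Lightning expands devices=-1 to all visible GPUs
    if isinstance(gpu_config, list):
        return "gpu", gpu_config    # explicit list of device indices
    return "gpu", [gpu_config]      # single integer -> one-element list

# On a 2-GPU machine: "all" -> ("gpu", -1); [0, 1] -> ("gpu", [0, 1]); 0 -> ("gpu", [0]); -1 -> ("cpu", 1).
for cfg in ("all", [0, 1], 0, -1):
    print(cfg, "->", select_devices(cfg))

# Under DDP with 2 GPUs and batch_size: 48 per process, the effective batch size is 48 * 2 = 96,
# which is what the script reports as "Multi-GPU training with DDP".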