griffingoodwin04 committed on
Commit
0b74556
·
1 Parent(s): 43b5310

refactor model configuration and callbacks; update data paths and loss functions

Browse files
flaring/MEGS_AI_baseline/callback.py CHANGED
@@ -28,7 +28,7 @@ class ImagePredictionLogger_SXR(Callback):
28
  def __init__(self, data_samples, sxr_norm):
29
  super().__init__()
30
  self.data_samples = data_samples
31
- self.val_aia = data_samples[0][0]
32
  self.val_sxr = data_samples[1]
33
  self.sxr_norm = sxr_norm
34
 
@@ -178,7 +178,7 @@ class AttentionMapCallback(Callback):
178
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
179
 
180
  # Plot 1: Original image
181
- axes[0].imshow(img_np[:, :, :3]) # only first 3 channels if more than 3
182
  axes[0].set_title(f'Original Image (Epoch {epoch})')
183
  axes[0].axis('off')
184
 
@@ -189,7 +189,7 @@ class AttentionMapCallback(Callback):
189
  plt.colorbar(im, ax=axes[1])
190
 
191
  # Plot 3: Overlay attention on image
192
- axes[2].imshow(img_np[:, :, :3])
193
 
194
  # Overlay attention as colored patches
195
  max_attention = attention_map.max().numpy()
 
28
  def __init__(self, data_samples, sxr_norm):
29
  super().__init__()
30
  self.data_samples = data_samples
31
+ self.val_aia = data_samples[0]
32
  self.val_sxr = data_samples[1]
33
  self.sxr_norm = sxr_norm
34
 
 
178
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
179
 
180
  # Plot 1: Original image
181
+ axes[0].imshow((img_np[:, :,0]+1)/2)
182
  axes[0].set_title(f'Original Image (Epoch {epoch})')
183
  axes[0].axis('off')
184
 
 
189
  plt.colorbar(im, ax=axes[1])
190
 
191
  # Plot 3: Overlay attention on image
192
+ axes[2].imshow((img_np[:, :,0]+1)/2)
193
 
194
  # Overlay attention as colored patches
195
  max_attention = attention_map.max().numpy()
flaring/MEGS_AI_baseline/config.yaml CHANGED
@@ -1,7 +1,7 @@
1
 
2
  # Base directories - change these to switch datasets
3
- base_data_dir: "/mnt/data/ML-Ready/flares_event_dir" # Change this line for different datasets
4
- base_checkpoint_dir: "/mnt/data/ML-Ready/flares_event_dir" # Change this line for different datasets
5
 
6
  # Model configuration
7
  selected_model: "ViT" # Options: "cnn", "vit",
@@ -20,20 +20,35 @@ model:
20
  epochs:
21
  100
22
  batch_size:
23
- 64
24
 
25
  vit:
26
  embed_dim: 512
27
  num_channels: 6 # AIA has 6 channels
28
  num_classes: 1 # Regression task, predicting SXR flux
29
  patch_size: 16
30
- num_patches: 262144
31
  hidden_dim: 512
32
  num_heads: 4
33
- num_layers: 6
34
- dropout: 0.1
35
  lr: .0001
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Data paths (automatically constructed from base directories)
38
  data:
39
  aia_dir:
 
1
 
2
  # Base directories - change these to switch datasets
3
+ base_data_dir: "/mnt/data/ML-Ready/mixed_data" # Change this line for different datasets
4
+ base_checkpoint_dir: "/mnt/data/ML-Ready/mixed_data" # Change this line for different datasets
5
 
6
  # Model configuration
7
  selected_model: "ViT" # Options: "cnn", "vit",
 
20
  epochs:
21
  100
22
  batch_size:
23
+ 16
24
 
25
  vit:
26
  embed_dim: 512
27
  num_channels: 6 # AIA has 6 channels
28
  num_classes: 1 # Regression task, predicting SXR flux
29
  patch_size: 16
30
+ num_patches: 1024
31
  hidden_dim: 512
32
  num_heads: 4
33
+ num_layers: 4
34
+ dropout: 0.25
35
  lr: .0001
36
 
37
+
38
+ #vit:
39
+ # embed_dim: 512
40
+ # num_channels: 6 # AIA has 6 channels
41
+ # num_classes: 1 # Regression task, predicting SXR flux
42
+ # patch_size: 8
43
+ # num_patches: 4096
44
+ # hidden_dim: 512
45
+ # num_heads: 2
46
+ # num_layers: 3
47
+ # dropout: 0.25
48
+ # lr: .0001
49
+
50
+
51
+
52
  # Data paths (automatically constructed from base directories)
53
  data:
54
  aia_dir:
flaring/MEGS_AI_baseline/models/base_model.py CHANGED
@@ -16,7 +16,7 @@ class BaseModel(LightningModule):
16
  optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
17
  scheduler = {
18
  'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3),
19
- 'monitor': 'valid_loss', # name of the metric to monitor
20
  'interval': 'epoch',
21
  }
22
  return {'optimizer': optimizer, 'lr_scheduler': scheduler}
 
16
  optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
17
  scheduler = {
18
  'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3),
19
+ 'monitor': 'val_loss', # name of the metric to monitor
20
  'interval': 'epoch',
21
  }
22
  return {'optimizer': optimizer, 'lr_scheduler': scheduler}
flaring/MEGS_AI_baseline/models/vision_transformer_custom.py CHANGED
@@ -19,20 +19,27 @@ class ViT(pl.LightningModule):
19
  filtered_kwargs.pop('lr', None)
20
  self.model = VisionTransformer(**filtered_kwargs)
21
 
22
- def forward(self, x, return_attention=False):
23
  return self.model(x, return_attention=return_attention)
24
 
25
  def configure_optimizers(self):
26
- optimizer = optim.AdamW(self.parameters(), lr=self.lr)
27
- lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
28
- return [optimizer], [lr_scheduler]
 
 
 
 
 
 
 
29
 
30
  def _calculate_loss(self, batch, mode="train"):
31
  imgs, sxr = batch
32
  preds = self.model(imgs)
33
 
34
  # Change loss function for regression
35
- loss = F.huber_loss(torch.squeeze(preds), sxr) # or F.l1_loss() or F.huber_loss()
36
 
37
  # Change accuracy to a regression metric
38
  mae = F.l1_loss(torch.squeeze(preds), sxr) # Mean Absolute Error
 
19
  filtered_kwargs.pop('lr', None)
20
  self.model = VisionTransformer(**filtered_kwargs)
21
 
22
+ def forward(self, x, return_attention=True):
23
  return self.model(x, return_attention=return_attention)
24
 
25
  def configure_optimizers(self):
26
+ # optimizer = optim.AdamW(self.parameters(), lr=self.lr)
27
+ # lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
28
+ # return [optimizer], [lr_scheduler]
29
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
30
+ scheduler = {
31
+ 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3),
32
+ 'monitor': 'val_loss', # name of the metric to monitor
33
+ 'interval': 'epoch',
34
+ }
35
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
36
 
37
  def _calculate_loss(self, batch, mode="train"):
38
  imgs, sxr = batch
39
  preds = self.model(imgs)
40
 
41
  # Change loss function for regression
42
+ loss = F.mse_loss(torch.squeeze(preds), sxr) # or F.l1_loss() or F.huber_loss()
43
 
44
  # Change accuracy to a regression metric
45
  mae = F.l1_loss(torch.squeeze(preds), sxr) # Mean Absolute Error
flaring/MEGS_AI_baseline/train.py CHANGED
@@ -203,7 +203,7 @@ trainer = Trainer(
203
  accelerator="gpu" if torch.cuda.is_available() else "cpu",
204
  devices=1,
205
  max_epochs=config_data['model']['epochs'],
206
- callbacks=[sxr_plot_callback, attention, pth_callback],
207
  logger=wandb_logger,
208
  log_every_n_steps=10
209
  )
 
203
  accelerator="gpu" if torch.cuda.is_available() else "cpu",
204
  devices=1,
205
  max_epochs=config_data['model']['epochs'],
206
+ callbacks=[attention, pth_callback],
207
  logger=wandb_logger,
208
  log_every_n_steps=10
209
  )