Commit 0b74556
Parent(s): 43b5310

refactor model configuration and callbacks; update data paths and loss functions
flaring/MEGS_AI_baseline/callback.py
CHANGED

@@ -28,7 +28,7 @@ class ImagePredictionLogger_SXR(Callback):
     def __init__(self, data_samples, sxr_norm):
         super().__init__()
         self.data_samples = data_samples
-        self.val_aia = data_samples[0]
+        self.val_aia = data_samples[0]
         self.val_sxr = data_samples[1]
         self.sxr_norm = sxr_norm

@@ -178,7 +178,7 @@ class AttentionMapCallback(Callback):
         fig, axes = plt.subplots(1, 3, figsize=(15, 5))

         # Plot 1: Original image
-        axes[0].imshow(img_np[:, :,0])
+        axes[0].imshow((img_np[:, :,0]+1)/2)
         axes[0].set_title(f'Original Image (Epoch {epoch})')
         axes[0].axis('off')

@@ -189,7 +189,7 @@ class AttentionMapCallback(Callback):
         plt.colorbar(im, ax=axes[1])

         # Plot 3: Overlay attention on image
-        axes[2].imshow(img_np[:, :,0])
+        axes[2].imshow((img_np[:, :,0]+1)/2)

         # Overlay attention as colored patches
         max_attention = attention_map.max().numpy()
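Note: the new imshow calls rescale the displayed channel with (x + 1) / 2, which suggests the AIA images arrive normalized to [-1, 1] upstream. A minimal sketch of that display convention, assuming such normalization (img_np, its shape, and the explicit vmin/vmax are illustrative, not from the repo):

import numpy as np
import matplotlib.pyplot as plt

# Illustrative stand-in for a 6-channel AIA sample normalized to [-1, 1].
img_np = np.random.uniform(-1.0, 1.0, size=(512, 512, 6)).astype(np.float32)

fig, ax = plt.subplots()
# (x + 1) / 2 maps [-1, 1] onto [0, 1]; fixing vmin/vmax keeps the display
# scale comparable across epochs instead of letting imshow autoscale.
ax.imshow((img_np[:, :, 0] + 1) / 2, vmin=0, vmax=1)
ax.axis('off')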
flaring/MEGS_AI_baseline/config.yaml
CHANGED

@@ -1,7 +1,7 @@
 
 # Base directories - change these to switch datasets
-base_data_dir: "/mnt/data/ML-Ready/
-base_checkpoint_dir: "/mnt/data/ML-Ready/
+base_data_dir: "/mnt/data/ML-Ready/mixed_data" # Change this line for different datasets
+base_checkpoint_dir: "/mnt/data/ML-Ready/mixed_data" # Change this line for different datasets
 
 # Model configuration
 selected_model: "ViT" # Options: "cnn", "vit",

@@ -20,20 +20,35 @@ model:
   epochs:
     100
   batch_size:
-
+    16
 
   vit:
     embed_dim: 512
     num_channels: 6 # AIA has 6 channels
     num_classes: 1 # Regression task, predicting SXR flux
     patch_size: 16
-    num_patches:
+    num_patches: 1024
     hidden_dim: 512
     num_heads: 4
-    num_layers:
-    dropout: 0.
+    num_layers: 4
+    dropout: 0.25
     lr: .0001
 
+
+  #vit:
+  #  embed_dim: 512
+  #  num_channels: 6 # AIA has 6 channels
+  #  num_classes: 1 # Regression task, predicting SXR flux
+  #  patch_size: 8
+  #  num_patches: 4096
+  #  hidden_dim: 512
+  #  num_heads: 2
+  #  num_layers: 3
+  #  dropout: 0.25
+  #  lr: .0001
+
+
+
 # Data paths (automatically constructed from base directories)
 data:
   aia_dir:
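Note: two things worth checking against this config are that the paths under data: are built from base_data_dir, and that num_patches agrees with patch_size. A minimal sketch, assuming 512x512 inputs and a yaml loader on the train.py side (the 'aia' subfolder name is a placeholder, not taken from this diff):

import os
import yaml

with open('flaring/MEGS_AI_baseline/config.yaml') as f:
    config_data = yaml.safe_load(f)

# Paths are "automatically constructed from base directories"; the exact
# subfolder joined here is hypothetical.
aia_dir = os.path.join(config_data['base_data_dir'], 'aia')

# The new num_patches: 1024 is consistent with 512x512 frames:
vit = config_data['model']['vit']
assert (512 // vit['patch_size']) ** 2 == 1024  # (512 // 16) ** 2
# The commented-out variant with patch_size 8 would give (512 // 8) ** 2 == 4096,
# matching its num_patches: 4096.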
flaring/MEGS_AI_baseline/models/base_model.py
CHANGED

@@ -16,7 +16,7 @@ class BaseModel(LightningModule):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
         scheduler = {
             'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3),
-            'monitor': '
+            'monitor': 'val_loss', # name of the metric to monitor
             'interval': 'epoch',
         }
         return {'optimizer': optimizer, 'lr_scheduler': scheduler}
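Note: with ReduceLROnPlateau, Lightning only steps the scheduler if the monitored key is actually logged, so 'monitor': 'val_loss' assumes the module logs that exact name during validation. A self-contained sketch of the pairing (the toy model and data shapes are mine, not the repo's):

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule

class TinyRegressor(LightningModule):
    def __init__(self, lr=1e-4):
        super().__init__()
        self.lr = lr
        self.net = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.net(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(torch.squeeze(self(x)), y)
        self.log('val_loss', loss)  # must match the 'monitor' key below
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=3),
            'monitor': 'val_loss',
            'interval': 'epoch',
        }
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}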
flaring/MEGS_AI_baseline/models/vision_transformer_custom.py
CHANGED

@@ -19,20 +19,27 @@ class ViT(pl.LightningModule):
        filtered_kwargs.pop('lr', None)
        self.model = VisionTransformer(**filtered_kwargs)

-    def forward(self, x, return_attention=False):
+    def forward(self, x, return_attention=True):
        return self.model(x, return_attention=return_attention)

    def configure_optimizers(self):
-        optimizer = optim.AdamW(self.parameters(), lr=self.lr)
-        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
-        return [optimizer], [lr_scheduler]
+        # optimizer = optim.AdamW(self.parameters(), lr=self.lr)
+        # lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
+        # return [optimizer], [lr_scheduler]
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+        scheduler = {
+            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3),
+            'monitor': 'val_loss', # name of the metric to monitor
+            'interval': 'epoch',
+        }
+        return {'optimizer': optimizer, 'lr_scheduler': scheduler}

    def _calculate_loss(self, batch, mode="train"):
        imgs, sxr = batch
        preds = self.model(imgs)

        # Change loss function for regression
-        loss = F.
+        loss = F.mse_loss(torch.squeeze(preds), sxr) # or F.l1_loss() or F.huber_loss()

        # Change accuracy to a regression metric
        mae = F.l1_loss(torch.squeeze(preds), sxr) # Mean Absolute Error
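Note: the inline comment offers F.l1_loss and F.huber_loss as drop-in alternatives to F.mse_loss; they differ mainly in how hard they punish outliers, which can matter for flare-driven SXR flux spikes. A quick illustration on dummy tensors (shapes chosen to mirror the squeeze in _calculate_loss):

import torch
import torch.nn.functional as F

preds = torch.randn(16, 1)       # model output, shape (batch, 1)
sxr = torch.randn(16)            # targets, shape (batch,)

squeezed = torch.squeeze(preds)  # align shapes, as _calculate_loss does

mse = F.mse_loss(squeezed, sxr)                  # quadratic penalty: sensitive to outliers
mae = F.l1_loss(squeezed, sxr)                   # linear penalty: more robust
huber = F.huber_loss(squeezed, sxr, delta=1.0)   # quadratic near zero, linear in the tails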
flaring/MEGS_AI_baseline/train.py
CHANGED

@@ -203,7 +203,7 @@ trainer = Trainer(
     accelerator="gpu" if torch.cuda.is_available() else "cpu",
     devices=1,
     max_epochs=config_data['model']['epochs'],
-    callbacks=[
+    callbacks=[attention, pth_callback],
     logger=wandb_logger,
     log_every_n_steps=10
 )
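Note: the diff names attention and pth_callback, but their construction sits outside this hunk. A plausible sketch using Lightning's stock ModelCheckpoint for the checkpointing half (an assumption; the repo may define its own callback), with the attention half presumably being the AttentionMapCallback edited above:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Assumption: pth_callback saves weights via ModelCheckpoint; its real
# definition is not visible in this diff.
pth_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)

# 'attention' is presumably the AttentionMapCallback from callback.py; its
# constructor arguments are not shown here, so it is omitted from this sketch.

trainer = Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=100,            # stands in for config_data['model']['epochs']
    callbacks=[pth_callback],  # plus the attention callback, once built
    log_every_n_steps=10,
)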