griffingoodwin04 committed on
Commit
e89f383
·
1 Parent(s): 9592dff

Implement Vision Transformer model (custom) and update configuration for model selection + changed data loader outputs (it was partially redundant)

Browse files
flaring/MEGS_AI_baseline/SDOAIA_dataloader.py CHANGED
@@ -94,7 +94,7 @@ class AIA_GOESDataset(torch.utils.data.Dataset):
94
  if self.sxr_transform:
95
  sxr_val = self.sxr_transform(sxr_val)
96
 
97
- return (aia_img, torch.tensor(sxr_val, dtype=torch.float32)), torch.tensor(sxr_val, dtype=torch.float32)
98
 
99
  class AIA_GOESDataModule(LightningDataModule):
100
  """PyTorch Lightning DataModule for AIA and SXR data."""
 
94
  if self.sxr_transform:
95
  sxr_val = self.sxr_transform(sxr_val)
96
 
97
+ return aia_img, torch.tensor(sxr_val, dtype=torch.float32)
98
 
99
  class AIA_GOESDataModule(LightningDataModule):
100
  """PyTorch Lightning DataModule for AIA and SXR data."""
flaring/MEGS_AI_baseline/callback.py CHANGED
@@ -33,7 +33,7 @@ class ImagePredictionLogger_SXR(Callback):
33
  true_sxr = []
34
  pred_sxr = []
35
  # print(self.val_samples)
36
- for (aia, _), target in self.data_samples:
37
  #device = torch.device("cuda:0")
38
  aia = aia.to(pl_module.device).unsqueeze(0)
39
  # Get prediction
 
33
  true_sxr = []
34
  pred_sxr = []
35
  # print(self.val_samples)
36
+ for aia, target in self.data_samples:
37
  #device = torch.device("cuda:0")
38
  aia = aia.to(pl_module.device).unsqueeze(0)
39
  # Get prediction
flaring/MEGS_AI_baseline/config.yaml CHANGED
@@ -4,9 +4,11 @@ base_data_dir: "/mnt/data/ML-Ready/flares_event_dir" # Change this line for dif
4
  base_checkpoint_dir: "/mnt/data/ML-Ready/flares_event_dir" # Change this line for different datasets
5
 
6
  # Model configuration
 
 
7
  model:
8
  architecture:
9
- "hybrid"
10
  seed:
11
  42
12
  lr:
@@ -20,6 +22,18 @@ model:
20
  batch_size:
21
  64
22
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Data paths (automatically constructed from base directories)
24
  data:
25
  aia_dir:
@@ -33,11 +47,11 @@ data:
33
 
34
  wandb:
35
  entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
36
- project: MEGS-AI Basline Models
37
  job_type: training
38
  tags:
39
  - aia
40
  - sxr
41
  - regression
42
- wb_name: flaring-baseline-lr-scheduler
43
  notes: Regression from AIA images (6 channels) to GOES SXR flux
 
4
  base_checkpoint_dir: "/mnt/data/ML-Ready/flares_event_dir" # Change this line for different datasets
5
 
6
  # Model configuration
7
+ selected_model: "ViT" # Options: "linear", "hybrid", "ViT"
8
+
9
  model:
10
  architecture:
11
+ "cnn"
12
  seed:
13
  42
14
  lr:
 
22
  batch_size:
23
  64
24
 
25
+ vit:
26
+ embed_dim: 512
27
+ num_channels: 6 # AIA has 6 channels
28
+ num_classes: 1 # Regression task, predicting SXR flux
29
+ patch_size: 16
30
+ num_patches: 262144
31
+ hidden_dim: 512
32
+ num_heads: 8
33
+ num_layers: 6
34
+ dropout: 0.1
35
+ lr: .00001
36
+
37
  # Data paths (automatically constructed from base directories)
38
  data:
39
  aia_dir:
 
47
 
48
  wandb:
49
  entity: jayantbiradar619-university-of-arizona # Use your exact W&B username
50
+ project: MEGS-AI ViT Testing Griffin
51
  job_type: training
52
  tags:
53
  - aia
54
  - sxr
55
  - regression
56
+ wb_name: flaring-vit-lr-scheduler
57
  notes: Regression from AIA images (6 channels) to GOES SXR flux
flaring/MEGS_AI_baseline/models/base_model.py CHANGED
@@ -36,7 +36,7 @@ class BaseModel(LightningModule):
36
  return loss
37
 
38
  def test_step(self, batch, batch_idx):
39
- (x, sxr), target = batch
40
  pred = self(x)
41
  loss = self.loss_func(torch.squeeze(pred), target)
42
  self.log('test_loss', loss)
 
36
  return loss
37
 
38
  def test_step(self, batch, batch_idx):
39
+ x, target = batch
40
  pred = self(x)
41
  loss = self.loss_func(torch.squeeze(pred), target)
42
  self.log('test_loss', loss)
flaring/MEGS_AI_baseline/models/vision_transformer_custom.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.optim as optim
6
+ import torch.utils.data as data
7
+ import torchvision
8
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
9
+ from torchvision import transforms
10
+ import pytorch_lightning as pl
11
+
12
+
13
class ViT(pl.LightningModule):
    """Lightning wrapper around the custom VisionTransformer for SXR regression.

    Args:
        model_kwargs: dict of VisionTransformer constructor arguments plus an
            ``lr`` entry used by the optimizer (stripped before being forwarded
            to the backbone).
    """

    def __init__(self, model_kwargs):
        super().__init__()
        self.lr = model_kwargs['lr']
        # Persist all constructor args into the checkpoint / logger.
        self.save_hyperparameters()
        # 'lr' belongs to the optimizer, not the backbone constructor.
        backbone_kwargs = {k: v for k, v in model_kwargs.items() if k != 'lr'}
        self.model = VisionTransformer(**backbone_kwargs)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr)
        # Drop the LR by 10x at epochs 100 and 150.
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], gamma=0.1
        )
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        """Shared train/val/test step: Huber loss + MAE against the SXR target."""
        imgs, sxr = batch
        preds = self.model(imgs)

        # squeeze(-1), NOT a full squeeze: torch.squeeze() would also collapse
        # a batch dimension of size 1, turning preds into a 0-d tensor that
        # silently broadcasts against the (1,)-shaped target.
        preds = preds.squeeze(-1)

        # Huber loss is robust to the occasional extreme flare-flux value.
        loss = F.huber_loss(preds, sxr)
        # Mean Absolute Error as an interpretable regression metric.
        mae = F.l1_loss(preds, sxr)

        self.log(f"{mode}_loss", loss)
        self.log(f"{mode}_mae", mae)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")
55
+
56
+
57
class VisionTransformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        num_classes,
        patch_size,
        num_patches,
        dropout=0.0,
    ):
        """Custom Vision Transformer.

        Args:
            embed_dim: Width of the token embeddings fed to the Transformer.
            hidden_dim: Width of the feed-forward hidden layer inside each block.
            num_channels: Number of input image channels (6 for AIA).
            num_heads: Attention heads per Multi-Head Attention block.
            num_layers: Number of stacked Transformer encoder blocks.
            num_classes: Output dimensionality (1 for SXR regression).
            patch_size: Side length, in pixels, of each square patch.
            num_patches: Upper bound on patches per image (sizes the
                positional-embedding table).
            dropout: Dropout rate for the input encoding and feed-forward nets.
        """
        super().__init__()

        self.patch_size = patch_size

        # Patch embedding: one linear projection of each flattened patch.
        self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
        blocks = [
            AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ]
        self.transformer = nn.Sequential(*blocks)
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Learnable CLS token and positional embeddings (CLS slot + patches).
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def forward(self, x):
        # Tokenize the image into flat patch vectors, then project them.
        patches = img_to_patch(x, self.patch_size)
        batch, seq_len, _ = patches.shape
        tokens = self.input_layer(patches)

        # Prepend the CLS token and add positional encodings.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens = tokens + self.pos_embedding[:, : seq_len + 1]

        # Run the Transformer; nn.MultiheadAttention defaults to
        # sequence-first layout, hence the transpose to (T+1, B, D).
        tokens = self.dropout(tokens)
        tokens = self.transformer(tokens.transpose(0, 1))

        # Predict from the CLS position only.
        return self.mlp_head(tokens[0])
123
+
124
class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """Pre-norm Transformer encoder block: MHA and an MLP, each residual.

        Args:
            embed_dim: Dimensionality of the input and attention features.
            hidden_dim: Feed-forward hidden width (typically 2-4x embed_dim).
            num_heads: Number of attention heads.
            dropout: Dropout rate inside the feed-forward network.
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention over the pre-normalized input, residual-added.
        normed = self.layer_norm_1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        # Feed-forward sub-layer, also pre-norm + residual.
        return x + self.linear(self.layer_norm_2(x))
154
+
155
def img_to_patch(x, patch_size, flatten_channels=True):
    """Split a channels-LAST image batch into flattened patches.

    Note: despite the usual ViT convention, this expects channels-last input
    [B, H, W, C] — the first line permutes to [B, C, H, W] — matching the
    dataloader's output layout.

    Args:
        x: Tensor of shape [B, H, W, C].
        patch_size: Pixels per patch side; must evenly divide both H and W.
        flatten_channels: If True, return flat feature vectors
            [B, H'*W', C*p*p]; otherwise patch grids [B, H'*W', C, p, p].

    Returns:
        Patch tensor, ordered row-major over the patch grid.

    Raises:
        ValueError: if H or W is not divisible by patch_size.
    """
    # The reshape below assumes [B, C, H, W]; convert from channels-last.
    x = x.permute(0, 3, 1, 2)
    B, C, H, W = x.shape
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"Image size ({H}x{W}) must be divisible by patch_size={patch_size}"
        )
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]
    return x
flaring/MEGS_AI_baseline/train.py CHANGED
@@ -14,6 +14,7 @@ from pytorch_lightning.loggers import WandbLogger
14
  from pytorch_lightning.callbacks import ModelCheckpoint
15
  from torch.nn import MSELoss
16
  from SDOAIA_dataloader import AIA_GOESDataModule
 
17
  from models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
18
  from callback import ImagePredictionLogger_SXR
19
  from pytorch_lightning.callbacks import Callback
@@ -166,14 +167,14 @@ pth_callback = PTHCheckpointCallback(
166
  )
167
 
168
  # Model
169
- if config_data['model']['architecture'] == 'linear':
170
  model = LinearIrradianceModel(
171
  d_input=6,
172
  d_output=1,
173
  lr= config_data['model']['lr'],
174
  loss_func=MSELoss()
175
  )
176
- elif config_data['model']['architecture'] == 'hybrid':
177
  model = HybridIrradianceModel(
178
  d_input=6,
179
  d_output=1,
@@ -182,8 +183,16 @@ elif config_data['model']['architecture'] == 'hybrid':
182
  cnn_dp=config_data['model']['cnn_dp'],
183
  lr=config_data['model']['lr'],
184
  )
 
 
 
 
 
 
 
 
185
  else:
186
- raise NotImplementedError(f"Architecture {config_data['model']['architecture']} not supported.")
187
 
188
  # Trainer
189
  trainer = Trainer(
 
14
  from pytorch_lightning.callbacks import ModelCheckpoint
15
  from torch.nn import MSELoss
16
  from SDOAIA_dataloader import AIA_GOESDataModule
17
+ from models.vision_transformer_custom import ViT
18
  from models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
19
  from callback import ImagePredictionLogger_SXR
20
  from pytorch_lightning.callbacks import Callback
 
167
  )
168
 
169
  # Model
170
+ if config_data['selected_model'] == 'linear':
171
  model = LinearIrradianceModel(
172
  d_input=6,
173
  d_output=1,
174
  lr= config_data['model']['lr'],
175
  loss_func=MSELoss()
176
  )
177
+ elif config_data['selected_model'] == 'hybrid':
178
  model = HybridIrradianceModel(
179
  d_input=6,
180
  d_output=1,
 
183
  cnn_dp=config_data['model']['cnn_dp'],
184
  lr=config_data['model']['lr'],
185
  )
186
+ elif config_data['selected_model'] == 'ViT':
187
+ print("Using ViT")
188
+ # model = ViT(embed_dim=config_data['vit']['embed_dim'], hidden_dim=config_data['vit']['hidden_dim'],
189
+ # num_channels=config_data['vit']['num_channels'],num_heads=config_data['vit']['num_heads'],
190
+ # num_layers=config_data['vit']['num_layers'], num_classes=config_data['vit']['num_classes'],
191
+ # patch_size=config_data['vit']['patch_size'], num_patches=config_data['vit']['num_patches'],
192
+ # dropout=config_data['vit']['dropout'], lr=config_data['vit']['lr'])
193
+ model = ViT(model_kwargs=config_data['vit'])
194
  else:
195
+ raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
196
 
197
  # Trainer
198
  trainer = Trainer(