griffingoodwin04 committed on
Commit 7e470a6 · 1 Parent(s): d7a4a12

Experimenting With Fusion Model

forecasting/inference/evaluation.py CHANGED
@@ -941,10 +941,10 @@ class SolarFlareEvaluator:
 
 if __name__ == "__main__":
     # Example paths - replace with your actual paths
-    vit_csv = "/mnt/data/ML-READY/output/final_epoch_patch.csv"
+    vit_csv = "/mnt/data/ML-READY/output/patch.csv"
     baseline_results_csv = ""
     aia_data = "/mnt/data/ML-READY/AIA/test/"
-    weights_directory = "/mnt/data/ML-READY/final_epoch_patch_weights_final"
+    weights_directory = "/mnt/data/ML-READY/patch_weights"
 
     # Sample timestamps - Fixed the datetime generation
     start_time = datetime(2023, 8, 5, 20,30,00)
forecasting/inference/inference_on_patch.py CHANGED
@@ -15,6 +15,7 @@ from torch.utils.data import DataLoader
 from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataset
 import forecasting.models as models
 from forecasting.models.vit_patch_model import ViT
+from forecasting.models import FusionViTHybrid
 from forecasting.models.linear_and_hybrid import HybridIrradianceModel  # Add your hybrid model import
 from forecasting.training.callback import unnormalize_sxr
 import yaml
@@ -28,7 +29,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 def has_attention_weights(model):
     """Check if model supports attention weights"""
-    return hasattr(model, 'attention') or isinstance(model, ViT)
+    return hasattr(model, 'attention') or isinstance(model, ViT) or isinstance(model, FusionViTHybrid)
 
 
 def save_batch_flux_contributions(batch_flux_contributions, batch_idx, batch_size, times, flux_path, sxr_norm=None):
@@ -201,7 +202,7 @@ def load_model_from_config(config_data):
            model_class = getattr(models, model_type)
            model = model_class.load_from_checkpoint(checkpoint_path)
        except AttributeError:
-            raise ValueError(f"Unknown model type: {model_type}. Available types: ViT, HybridIrradianceModel")
+            raise ValueError(f"Unknown model type: {model_type}. Available types include: ViT, HybridIrradianceModel, FusionViTHybrid")
    else:
        # Regular PyTorch checkpoint
        state = torch.load(checkpoint_path, map_location=device, weights_only=False)
@@ -243,7 +244,7 @@ def main():
    parser.add_argument('-config', type=str, default='config.yaml', required=True, help='Path to config YAML.')
    parser.add_argument('-input_size', type=int, default=512, help='Input size for the model')
    parser.add_argument('-patch_size', type=int, default=16, help='Patch size for the model')
-    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for inference')
+    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
    parser.add_argument('--no_weights', action='store_true', help='Skip saving attention weights to speed up')
    args = parser.parse_args()
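
Reviewer note: the getattr(models, model_type) dispatch in load_model_from_config only resolves classes exported from the forecasting.models package, which is why this commit also adds the import to forecasting/models/__init__.py below. A minimal sketch of that dispatch pattern, using the checkpoint path from this commit's inference config:

    import forecasting.models as models

    model_type = "FusionViTHybrid"  # read from the YAML config in the real script
    checkpoint_path = "/mnt/data/ML-READY/new-checkpoint/vit-16-MSE-deeper-epoch=51-val_total_loss=0.1064.ckpt"
    try:
        # Resolves via the export in forecasting/models/__init__.py
        model_class = getattr(models, model_type)
        model = model_class.load_from_checkpoint(checkpoint_path)
    except AttributeError:
        raise ValueError(f"Unknown model type: {model_type}")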
forecasting/inference/inference_on_patch_config.yaml CHANGED
@@ -1,6 +1,6 @@
 base_data_dir: "/mnt/data/ML-READY/" # Change this line for different datasets
-output_path: "${base_data_dir}/output/final_epoch_patch.csv"
-weight_path: "${base_data_dir}/final_epoch_patch_weights_final/"
+output_path: "${base_data_dir}/output/patch.csv"
+weight_path: "${base_data_dir}/patch_weights/"
 flux_path: "${base_data_dir}/patch_flux/"
 mc:
   active: "false"
@@ -27,5 +27,5 @@ data:
   sxr_norm_path:
     "/mnt/data/ML-READY/SXR/normalized_sxr.npy"
   checkpoint_path:
-    "/mnt/data/ML-READY/new-checkpoint/vit-16-higher-weight-lower-decay-epoch=288-val_total_loss=0.0385.ckpt"
+    "/mnt/data/ML-READY/new-checkpoint/vit-16-MSE-deeper-epoch=51-val_total_loss=0.1064.ckpt"
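
Reviewer note: the ${base_data_dir} references above are not plain YAML; yaml.safe_load would return them as literal strings, so the config loader has to interpolate them. A minimal sketch with OmegaConf, which implements exactly this ${...} syntax (assuming the repo's loader does something equivalent):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("forecasting/inference/inference_on_patch_config.yaml")
    # Interpolations resolve on access; note the double slash, since
    # base_data_dir already ends with "/":
    print(cfg.output_path)  # /mnt/data/ML-READY//output/patch.csv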
forecasting/models/__init__.py CHANGED
@@ -0,0 +1,2 @@
+from .fusion_vit_hybrid import FusionViTHybrid
+
forecasting/models/fusion_vit_hybrid.py ADDED
@@ -0,0 +1,228 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from .vit_patch_model import VisionTransformer, SXRRegressionDynamicLoss, normalize_sxr, unnormalize_sxr
+from .linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
+
+
+class FusionViTHybrid(pl.LightningModule):
+    """End-to-end fused model: ViT for spatial patches + Linear/Hybrid for scalar.
+
+    - ViT branch outputs per-patch raw flux and a ViT global (sum of patches).
+    - Scalar branch (Linear or Hybrid) outputs a global scalar.
+    - A learnable gate blends the two globals; the spatial map uses ViT's
+      distribution but is calibrated to the fused global prediction.
+
+    Forward returns a 4-tuple compatible with existing inference utils:
+    (global_fused, attention_weights, fused_patch_flux, global_fused)
+    """
+
+    def __init__(
+        self,
+        vit_kwargs: dict,
+        scalar_branch: str,
+        scalar_kwargs: dict,
+        sxr_norm,
+        lr: float = 1e-4,
+        lambda_vit_to_target: float = 0.3,
+        lambda_scalar_to_target: float = 0.1,
+        use_attention: bool = True,
+        learnable_gate: bool = True,
+        gate_init_bias: float = 5.0,
+        weight_decay: float = 1e-5,
+        cosine_restart_T0: int = 50,
+        cosine_restart_Tmult: int = 2,
+        cosine_eta_min: float = 1e-7,
+    ):
+        super().__init__()
+
+        # Save hyperparameters needed for checkpointing
+        self.save_hyperparameters(ignore=["sxr_norm"])  # sxr_norm is a tensor/array
+
+        # Branches: filter unsupported keys for VisionTransformer
+        filtered_vit_kwargs = dict(vit_kwargs)
+        filtered_vit_kwargs.pop('lr', None)
+        filtered_vit_kwargs.pop('num_classes', None)
+        self.vit = VisionTransformer(**filtered_vit_kwargs)
+
+        if scalar_branch.lower() in ["linear", "linearirradiancemodel"]:
+            self.scalar = LinearIrradianceModel(
+                d_input=scalar_kwargs.get("d_input"),
+                d_output=scalar_kwargs.get("d_output"),
+                loss_func=scalar_kwargs.get("loss_func", nn.HuberLoss()),
+                lr=scalar_kwargs.get("lr", lr),
+            )
+        elif scalar_branch.lower() in ["hybrid", "hybridirradiancemodel"]:
+            self.scalar = HybridIrradianceModel(
+                d_input=scalar_kwargs.get("d_input"),
+                d_output=scalar_kwargs.get("d_output"),
+                cnn_model=scalar_kwargs.get("cnn_model", "updated"),
+                ln_model=scalar_kwargs.get("ln_model", True),
+                ln_params=scalar_kwargs.get("ln_params", None),
+                lr=scalar_kwargs.get("lr", lr),
+                cnn_dp=scalar_kwargs.get("cnn_dp", 0.75),
+                loss_func=scalar_kwargs.get("loss_func", nn.HuberLoss()),
+            )
+        else:
+            raise ValueError(f"Unknown scalar_branch: {scalar_branch}")
+
+        # Loss and normalization
+        self.sxr_norm = sxr_norm
+        self.adaptive_loss = SXRRegressionDynamicLoss(window_size=1500)
+
+        # Gate: learnable scalar in [0, 1] blending scalar vs. ViT global
+        self.learnable_gate = learnable_gate
+        if learnable_gate:
+            self.gate_logit = nn.Parameter(torch.tensor(gate_init_bias, dtype=torch.float32))
+        else:
+            self.register_buffer("gate_logit", torch.tensor(gate_init_bias, dtype=torch.float32))
+
+        # Optimizer params
+        self.lr = lr
+        self.weight_decay = weight_decay
+        self.cosine_restart_T0 = cosine_restart_T0
+        self.cosine_restart_Tmult = cosine_restart_Tmult
+        self.cosine_eta_min = cosine_eta_min
+
+        # Auxiliary loss weights
+        self.lambda_vit_to_target = lambda_vit_to_target
+        self.lambda_scalar_to_target = lambda_scalar_to_target
+
+        # Whether to compute/return attention
+        self.use_attention = use_attention
+
+    def forward(self, x, return_attention: bool = True):
+        # ViT branch: returns a different number of values based on return_attention
+        vit_out = self.vit(x, self.sxr_norm, return_attention=(self.use_attention and return_attention))
+
+        if self.use_attention and return_attention and len(vit_out) == 3:
+            global_vit_raw, attention_weights, patch_flux_raw = vit_out
+        else:
+            global_vit_raw, patch_flux_raw = vit_out
+            attention_weights = None
+
+        # Scalar branch expects (B, H, W, C)
+        global_scalar_raw = self.scalar(x)
+        # Ensure positivity for SXR-like targets
+        global_scalar_raw = F.softplus(global_scalar_raw)
+
+        # Shapes: ensure tensors are shaped [B, 1]
+        if global_vit_raw.dim() == 1:
+            global_vit_raw = global_vit_raw.unsqueeze(-1)
+        if global_scalar_raw.dim() == 1:
+            global_scalar_raw = global_scalar_raw.unsqueeze(-1)
+
+        # Patch weights from the ViT distribution
+        weights = patch_flux_raw / global_vit_raw.clamp(min=1e-15)
+
+        # Blend globals via sigmoid(gate_logit)
+        gate = torch.sigmoid(self.gate_logit)
+        global_fused = gate * global_scalar_raw + (1.0 - gate) * global_vit_raw
+        # Avoid zeros/negatives before log normalization downstream
+        global_fused = global_fused.clamp(min=1e-15)
+
+        # Calibrated patch flux using the fused global
+        fused_patch_flux = global_fused * weights
+
+        # Match inference API: (pred, attn, patch_flux, total_from_patches)
+        return global_fused, attention_weights, fused_patch_flux, global_fused
+
+    def forward_for_callback(self, x, return_attention: bool = True):
+        """Forward method compatible with AttentionMapCallback."""
+        global_fused, attention_weights, fused_patch_flux, _ = self.forward(x, return_attention)
+        # The callback's fallback path only needs the attention weights
+        return attention_weights
+
+    def _calc_losses(self, imgs, sxr):
+        # Forward
+        global_fused, attention_weights, fused_patch_flux, _ = self(imgs, return_attention=True)
+
+        # Main adaptive loss on the fused global
+        raw_preds_squeezed = torch.squeeze(global_fused)
+        sxr_un = unnormalize_sxr(sxr, self.sxr_norm)
+        norm_preds_squeezed = normalize_sxr(raw_preds_squeezed, self.sxr_norm)
+        main_loss, weights_adapt = self.adaptive_loss.calculate_loss(
+            norm_preds_squeezed, sxr, sxr_un, raw_preds_squeezed
+        )
+
+        # Auxiliary consistency losses on the ViT and scalar heads individually.
+        # The ViT head is re-run without attention (and without grad) to save memory.
+        with torch.no_grad():
+            vit_out = self.vit(imgs, self.sxr_norm, return_attention=False)
+            global_vit_raw = vit_out[0]
+        if global_vit_raw.dim() > 1:
+            global_vit_raw = torch.squeeze(global_vit_raw)
+        global_vit_raw = global_vit_raw.clamp(min=1e-15)
+        vit_norm = normalize_sxr(global_vit_raw, self.sxr_norm)
+        loss_vit = F.huber_loss(vit_norm, sxr)
+
+        global_scalar_raw = self.scalar(imgs)
+        global_scalar_raw = F.softplus(global_scalar_raw)
+        if global_scalar_raw.dim() > 1:
+            global_scalar_raw = torch.squeeze(global_scalar_raw)
+        global_scalar_raw = global_scalar_raw.clamp(min=1e-15)
+        scalar_norm = normalize_sxr(global_scalar_raw, self.sxr_norm)
+        loss_scalar = F.huber_loss(scalar_norm, sxr)
+
+        total_loss = main_loss \
+            + self.lambda_vit_to_target * loss_vit \
+            + self.lambda_scalar_to_target * loss_scalar
+
+        return total_loss, {
+            "main_loss": main_loss.detach(),
+            "loss_vit": loss_vit.detach(),
+            "loss_scalar": loss_scalar.detach(),
+        }
+
+    def training_step(self, batch, batch_idx):
+        imgs, sxr = batch
+        total_loss, logs = self._calc_losses(imgs, sxr)
+
+        # Logs
+        self.log("train_main_loss", logs["main_loss"], on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+        self.log("train_vit_loss", logs["loss_vit"], on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
+        self.log("train_scalar_loss", logs["loss_scalar"], on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
+
+        # Learning rate
+        current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
+        self.log('learning_rate', current_lr, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
+
+        return total_loss
+
+    def validation_step(self, batch, batch_idx):
+        imgs, sxr = batch
+        total_loss, logs = self._calc_losses(imgs, sxr)
+        self.log("val_main_loss", logs["main_loss"], on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+        self.log("val_total_loss", total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+        return total_loss
+
+    def test_step(self, batch, batch_idx):
+        imgs, sxr = batch
+        total_loss, _ = self._calc_losses(imgs, sxr)
+        self.log("test_total_loss", total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+        return total_loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            self.parameters(),
+            lr=self.lr,
+            weight_decay=self.weight_decay,
+        )
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer,
+            T_0=self.cosine_restart_T0,
+            T_mult=self.cosine_restart_Tmult,
+            eta_min=self.cosine_eta_min,
+        )
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': {
+                'scheduler': scheduler,
+                'interval': 'epoch',
+                'frequency': 1,
+                'name': 'learning_rate'
+            }
+        }
+
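
Reviewer note on the gate initialization: with gate_init_bias=5.0, sigmoid(5.0) ≈ 0.993, so training starts with the fused global coming almost entirely from the scalar branch, and the ViT branch is blended in only as the logit is learned downward. A minimal sketch of the blending arithmetic (the flux values here are hypothetical):

    import torch

    gate_logit = torch.tensor(5.0)
    gate = torch.sigmoid(gate_logit)             # ~0.9933
    global_scalar_raw = torch.tensor([2.0e-6])   # hypothetical scalar-branch flux
    global_vit_raw = torch.tensor([8.0e-6])      # hypothetical ViT-branch flux
    global_fused = gate * global_scalar_raw + (1.0 - gate) * global_vit_raw
    print(global_fused.item())                   # ~2.04e-06, dominated by the scalar branch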
forecasting/models/vit_patch_model.py CHANGED
@@ -306,8 +306,8 @@ class SXRRegressionDynamicLoss:
            'x_class': 20.0
        }
    def calculate_loss(self, preds_squeezed, sxr, sxr_un, preds_squeezed_un):
-        #base_loss = F.huber_loss(preds_squeezed, sxr, delta=1.0, reduction='none')
-        base_loss = F.mse_loss(preds_squeezed, sxr, reduction='none')
+        base_loss = F.huber_loss(preds_squeezed, sxr, delta=1.0, reduction='none')
+        #base_loss = F.mse_loss(preds_squeezed, sxr, reduction='none')
        weights = self._get_adaptive_weights(sxr_un, preds_squeezed_un, base_loss)
        self._update_tracking(sxr_un, preds_squeezed_un, base_loss)
        weighted_loss = base_loss * weights
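
Reviewer note: this change swaps the base loss back from MSE to Huber (delta=1.0), which bounds the penalty on large residuals; errors beyond delta grow linearly instead of quadratically, so rare high-flux outliers dominate the loss less. A minimal illustration with hypothetical values:

    import torch
    import torch.nn.functional as F

    preds = torch.tensor([0.1, 0.2, 5.0])   # last prediction is far off
    target = torch.tensor([0.1, 0.1, 0.5])
    mse = F.mse_loss(preds, target, reduction='none')
    huber = F.huber_loss(preds, target, delta=1.0, reduction='none')
    print(mse)    # [0.0, 0.01, 20.25]
    print(huber)  # [0.0, 0.005, 4.0] -- large residual grows linearly, not quadratically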
forecasting/training/callback.py CHANGED
@@ -124,7 +124,11 @@ class AttentionMapCallback(Callback):
        imgs = imgs[:self.num_samples].to(pl_module.device)
 
        # Get predictions with attention weights
-        outputs, attention_weights, _ = pl_module(imgs, return_attention=True)
+        # Dynamically extract attention weights from the model
+        try:
+            outputs, attention_weights, _ = pl_module(imgs, return_attention=True)
+        except ValueError:
+            attention_weights = pl_module.forward_for_callback(imgs, return_attention=True)
 
        # Visualize attention for each sample
        for sample_idx in range(min(self.num_samples, imgs.size(0))):
@@ -134,7 +138,7 @@ class AttentionMapCallback(Callback):
                attention_weights,
                sample_idx,
                trainer.current_epoch,
-                pl_module.model.patch_size
+                patch_size=16
            )
            trainer.logger.experiment.log({"Attention plots": wandb.Image(map)})
            plt.close(map)
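
Reviewer note: the try/except dispatch above works because FusionViTHybrid.forward returns four values, so the three-way unpack raises ValueError and the callback falls through to forward_for_callback (hence the narrowed except clause, rather than the original bare except). A minimal reproduction with plain tuples, no real model involved:

    outputs4 = (1.0, None, 2.0, 1.0)  # shape of FusionViTHybrid's return
    try:
        pred, attn, _ = outputs4      # 3-way unpack of a 4-tuple raises ValueError
    except ValueError:
        attn = outputs4[1]            # fallback, analogous to forward_for_callback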
forecasting/training/config.yaml CHANGED
@@ -4,8 +4,8 @@ base_data_dir: "/mnt/data/ML-READY" # Change this line for different datasets
 base_checkpoint_dir: "/mnt/data/ML-READY" # Change this line for different datasets
 wavelengths: [94, 131, 171, 193, 211, 304] # AIA wavelengths in Angstroms
 # Model configuration
-selected_model: "ViT Patch" # Options: "cnn", "vit",
-batch_size: 80
+selected_model: "FusionViTHybrid" # Options: "cnn", "vit",
+batch_size: 16
 epochs: 500
 oversample: false
 balance_strategy: "upsample_minority"
@@ -23,12 +23,27 @@ vit_custom:
   num_classes: 1
   patch_size: 16
   num_patches: 1024
-  hidden_dim: 1024
+  hidden_dim: 512
   num_heads: 8
   num_layers: 6
   dropout: 0.1
   lr: 0.0001
 
+
+fusion:
+  scalar_branch: "hybrid" # or "linear"
+  lr: 0.0001
+  lambda_vit_to_target: 0.3
+  lambda_scalar_to_target: 0.1
+  learnable_gate: true
+  gate_init_bias: 5.0
+  scalar_kwargs:
+    d_input: 6
+    d_output: 1
+    cnn_model: "updated"
+    cnn_dp: 0.75
+
+
 # Data paths (automatically constructed from base directories)
 data:
   aia_dir:
@@ -48,5 +63,5 @@ wandb:
     - aia
     - sxr
     - regression
-  wb_name: vit-16-MSE-deeper
+  wb_name: vit-fused-model
   notes: Regression from AIA images (6 channels) to GOES SXR flux
forecasting/training/train.py CHANGED
@@ -22,6 +22,7 @@ from forecasting.data_loaders.SDOAIA_dataloader import AIA_GOESDataModule
 from forecasting.models.vision_transformer_custom import ViT
 from forecasting.models.linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
 from forecasting.models.vit_patch_model import ViT as ViTPatch
+from forecasting.models import FusionViTHybrid
 from callback import ImagePredictionLogger_SXR, AttentionMapCallback
 from pytorch_lightning.callbacks import Callback
 
@@ -204,14 +205,39 @@ elif config_data['selected_model'] == 'hybrid':
 elif config_data['selected_model'] == 'ViT':
     model = ViT(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)
 
-elif config_data['selected_model'] == 'ViT Patch':
+elif config_data['selected_model'] == 'ViTPatch':
     model = ViTPatch(model_kwargs=config_data['vit_custom'], sxr_norm = sxr_norm)
 
+elif config_data['selected_model'] == 'FusionViTHybrid':
+    # Expect a 'fusion' section in the YAML
+    fusion_cfg = config_data.get('fusion', {})
+    scalar_branch = fusion_cfg.get('scalar_branch', 'hybrid')
+    scalar_kwargs = fusion_cfg.get('scalar_kwargs', {
+        'd_input': len(config_data['wavelengths']),
+        'd_output': 1,
+        'cnn_model': config_data.get('megsai', {}).get('cnn_model', 'updated'),
+        'cnn_dp': config_data.get('megsai', {}).get('cnn_dp', 0.75),
+        'lr': fusion_cfg.get('lr', config_data.get('megsai', {}).get('lr', 1e-4)),
+    })
+    vit_kwargs = config_data.get('vit_custom', {})
+
+    model = FusionViTHybrid(
+        vit_kwargs=vit_kwargs,
+        scalar_branch=scalar_branch,
+        scalar_kwargs=scalar_kwargs,
+        sxr_norm=sxr_norm,
+        lr=fusion_cfg.get('lr', 1e-4),
+        lambda_vit_to_target=fusion_cfg.get('lambda_vit_to_target', 0.3),
+        lambda_scalar_to_target=fusion_cfg.get('lambda_scalar_to_target', 0.1),
+        learnable_gate=fusion_cfg.get('learnable_gate', True),
+        gate_init_bias=fusion_cfg.get('gate_init_bias', 5.0),
+    )
+
 else:
     raise NotImplementedError(f"Architecture {config_data['selected_model']} not supported.")
 
 # Trainer
-if config_data['selected_model'] == 'ViT' or config_data['selected_model'] == 'ViT Patch':
+if config_data['selected_model'] == 'ViT' or config_data['selected_model'] == 'ViTPatch' or config_data['selected_model'] == 'FusionViTHybrid':
     trainer = Trainer(
         default_root_dir=config_data['data']['checkpoints_dir'],
         accelerator="gpu" if torch.cuda.is_available() else "cpu",
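
Reviewer note: the trainer condition above matches the renamed 'ViTPatch' string; the model branch rename is exactly where a stale literal ('ViT Patch') can slip through unnoticed. A membership test is harder to break on renames; a sketch of the equivalent condition (not part of the commit):

    ATTENTION_MODELS = {'ViT', 'ViTPatch', 'FusionViTHybrid'}
    if config_data['selected_model'] in ATTENTION_MODELS:
        ...  # build the attention-aware Trainer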