griffingoodwin04 committed on
Commit
5eaed60
·
1 Parent(s): 71f2b7e

Rename localpatch.yaml to train_config.yaml and update notes; modify learning rate scheduler parameters and rename attention mask creation method

Browse files
forecasting/models/vit_patch_model_local.py CHANGED
@@ -51,9 +51,9 @@ class ViTLocal(pl.LightningModule):
51
 
52
  scheduler = CosineAnnealingWarmRestarts(
53
  optimizer,
54
- T_0=250, # Restart every 20 epochs
55
- T_mult=2, # Double the cycle length after each restart
56
- eta_min=1e-7 # Minimum learning rate
57
  )
58
 
59
  return {
@@ -66,8 +66,6 @@ class ViTLocal(pl.LightningModule):
66
  }
67
  }
68
 
69
- # M/X Class Flare Detection Optimized Weights
70
-
71
  def _calculate_loss(self, batch, mode="train"):
72
  imgs, sxr = batch
73
  raw_preds, raw_patch_contributions = self.model(imgs, self.sxr_norm)
@@ -304,9 +302,9 @@ class InvertedAttentionBlock(nn.Module):
304
  )
305
 
306
  # Pre-compute attention mask for local interactions
307
- self.register_buffer('attention_mask', self._create_local_attention_mask())
308
 
309
- def _create_local_attention_mask(self):
310
  """Create attention mask for local interactions only"""
311
  # This creates a mask where only distant patches can attend to each other
312
 
 
51
 
52
  scheduler = CosineAnnealingWarmRestarts(
53
  optimizer,
54
+ T_0=150,
55
+ T_mult=2,
56
+ eta_min=1e-7
57
  )
58
 
59
  return {
 
66
  }
67
  }
68
 
 
 
69
  def _calculate_loss(self, batch, mode="train"):
70
  imgs, sxr = batch
71
  raw_preds, raw_patch_contributions = self.model(imgs, self.sxr_norm)
 
302
  )
303
 
304
  # Pre-compute attention mask for local interactions
305
+ self.register_buffer('attention_mask', self._create_inverted_attention_mask())
306
 
307
+ def _create_inverted_attention_mask(self):
308
  """Create attention mask for local interactions only"""
309
  # This creates a mask where only distant patches can attend to each other
310
 
forecasting/training/{localpatch.yaml → train_config.yaml} RENAMED
@@ -52,4 +52,4 @@ wandb:
52
  - sxr
53
  - regression
54
  wb_name: paper-8-patch-4ch
55
- notes: Regression from AIA images (4 channels) to GOES SXR flux
 
52
  - sxr
53
  - regression
54
  wb_name: paper-8-patch-4ch
55
+ notes: Regression from AIA images to SXR images using ViTLocal model with 8x8 patches