Commit ·
5eaed60
1
Parent(s): 71f2b7e
Rename localpatch.yaml to train_config.yaml and update notes; modify learning rate scheduler parameters and rename attention mask creation method
Browse files
forecasting/models/vit_patch_model_local.py
CHANGED
|
@@ -51,9 +51,9 @@ class ViTLocal(pl.LightningModule):
|
|
| 51 |
|
| 52 |
scheduler = CosineAnnealingWarmRestarts(
|
| 53 |
optimizer,
|
| 54 |
-
T_0=
|
| 55 |
-
T_mult=2,
|
| 56 |
-
eta_min=1e-7
|
| 57 |
)
|
| 58 |
|
| 59 |
return {
|
|
@@ -66,8 +66,6 @@ class ViTLocal(pl.LightningModule):
|
|
| 66 |
}
|
| 67 |
}
|
| 68 |
|
| 69 |
-
# M/X Class Flare Detection Optimized Weights
|
| 70 |
-
|
| 71 |
def _calculate_loss(self, batch, mode="train"):
|
| 72 |
imgs, sxr = batch
|
| 73 |
raw_preds, raw_patch_contributions = self.model(imgs, self.sxr_norm)
|
|
@@ -304,9 +302,9 @@ class InvertedAttentionBlock(nn.Module):
|
|
| 304 |
)
|
| 305 |
|
| 306 |
# Pre-compute attention mask for local interactions
|
| 307 |
-
self.register_buffer('attention_mask', self.
|
| 308 |
|
| 309 |
-
def
|
| 310 |
"""Create attention mask for local interactions only"""
|
| 311 |
# This creates a mask where only distant patches can attend to each other
|
| 312 |
|
|
|
|
| 51 |
|
| 52 |
scheduler = CosineAnnealingWarmRestarts(
|
| 53 |
optimizer,
|
| 54 |
+
T_0=150,
|
| 55 |
+
T_mult=2,
|
| 56 |
+
eta_min=1e-7
|
| 57 |
)
|
| 58 |
|
| 59 |
return {
|
|
|
|
| 66 |
}
|
| 67 |
}
|
| 68 |
|
|
|
|
|
|
|
| 69 |
def _calculate_loss(self, batch, mode="train"):
|
| 70 |
imgs, sxr = batch
|
| 71 |
raw_preds, raw_patch_contributions = self.model(imgs, self.sxr_norm)
|
|
|
|
| 302 |
)
|
| 303 |
|
| 304 |
# Pre-compute attention mask for local interactions
|
| 305 |
+
self.register_buffer('attention_mask', self._create_inverted_attention_mask())
|
| 306 |
|
| 307 |
+
def _create_inverted_attention_mask(self):
|
| 308 |
"""Create attention mask for local interactions only"""
|
| 309 |
# This creates a mask where only distant patches can attend to each other
|
| 310 |
|
forecasting/training/{localpatch.yaml → train_config.yaml}
RENAMED
|
@@ -52,4 +52,4 @@ wandb:
|
|
| 52 |
- sxr
|
| 53 |
- regression
|
| 54 |
wb_name: paper-8-patch-4ch
|
| 55 |
-
notes: Regression from AIA images
|
|
|
|
| 52 |
- sxr
|
| 53 |
- regression
|
| 54 |
wb_name: paper-8-patch-4ch
|
| 55 |
+
notes: Regression from AIA images to SXR images using ViTLocal model with 8x8 patches
|