Commit ·
3720287
1
Parent(s): 5493436
added MEGS-AI baseline
Browse files- flaring/MEGS_AI_baseline/SDOAIA_dataloader.py +153 -0
- flaring/MEGS_AI_baseline/__init__.py +0 -0
- flaring/MEGS_AI_baseline/base_model.py +45 -0
- flaring/MEGS_AI_baseline/chopped_alexnet.py +50 -0
- flaring/MEGS_AI_baseline/efficientnet.py +45 -0
- flaring/MEGS_AI_baseline/kan_success.py +219 -0
- flaring/MEGS_AI_baseline/linear_and_hybrid.py +119 -0
- flaring/MEGS_AI_baseline/models/base_model.py +45 -0
- flaring/MEGS_AI_baseline/models/chopped_alexnet.py +50 -0
- flaring/MEGS_AI_baseline/models/efficientnet.py +45 -0
- flaring/MEGS_AI_baseline/models/kan_success.py +219 -0
- flaring/MEGS_AI_baseline/models/linear_and_hybrid.py +119 -0
- flaring/MEGS_AI_baseline/sxr_normalization.py +57 -0
- flaring/MEGS_AI_baseline/train.py +179 -0
- flaring/__init__.py +0 -0
- flaring/cut_off_aia.py +38 -0
flaring/MEGS_AI_baseline/SDOAIA_dataloader.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader, Subset
|
| 3 |
+
import numpy as np
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from scipy.ndimage import zoom
|
| 6 |
+
import torchvision.transforms as T
|
| 7 |
+
from pytorch_lightning import LightningDataModule
|
| 8 |
+
import glob
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
class AIA_GOESDataset(torch.utils.data.Dataset):
    """Dataset pairing SDO/AIA image stacks with scalar GOES SXR targets.

    Samples are matched by filename stem: ``<aia_dir>/<ts>.npy`` is included
    only when a corresponding ``<sxr_dir>/<ts>.npy`` exists.
    """

    def __init__(self, aia_dir, sxr_dir, transform=None, sxr_transform=None, target_size=(512, 512)):
        """
        Args:
            aia_dir: directory of AIA ``.npy`` files, each shaped (6, H, W).
            sxr_dir: directory of scalar SXR ``.npy`` files with matching stems.
            transform: optional callable applied to the (C, H, W) image tensor.
            sxr_transform: optional callable applied to the scalar SXR value.
            target_size: spatial size (H, W) images are resized to.

        Raises:
            FileNotFoundError: if either directory does not exist.
            ValueError: if no matching (AIA, SXR) pairs are found.
        """
        self.aia_dir = Path(aia_dir).resolve()
        self.sxr_dir = Path(sxr_dir).resolve()
        self.transform = transform
        self.sxr_transform = sxr_transform
        # Normalize to a tuple: the shape check in __getitem__ compares against
        # ndarray.shape (a tuple), so a list target_size would never compare
        # equal and every image would be needlessly resampled.
        self.target_size = tuple(target_size)
        self.samples = []

        if not self.aia_dir.is_dir():
            raise FileNotFoundError(f"AIA directory not found: {self.aia_dir}")
        if not self.sxr_dir.is_dir():
            raise FileNotFoundError(f"SXR directory not found: {self.sxr_dir}")

        # Keep only timestamps present in both directories (sorted for a
        # deterministic sample order).
        for f in sorted(self.aia_dir.glob("*.npy")):
            if (self.sxr_dir / f.name).exists():
                self.samples.append(f.stem)

        if not self.samples:
            raise ValueError("No valid sample pairs found")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``((aia_img, sxr), sxr)`` — the SXR value doubles as an
        auxiliary model input and as the regression target."""
        timestamp = self.samples[idx]
        aia_path = self.aia_dir / f"{timestamp}.npy"
        sxr_path = self.sxr_dir / f"{timestamp}.npy"

        # Load AIA image as (6, H, W).
        aia_img = np.load(aia_path)
        if aia_img.shape[0] != 6:
            raise ValueError(f"AIA image has {aia_img.shape[0]} channels, expected 6")

        # Spline-resample spatial dims to target_size if needed (channels kept).
        if tuple(aia_img.shape[1:3]) != self.target_size:
            aia_img = zoom(aia_img, (1,
                                     self.target_size[0] / aia_img.shape[1],
                                     self.target_size[1] / aia_img.shape[2]))

        aia_img = torch.tensor(aia_img, dtype=torch.float32)  # (6, H, W)

        # Transforms expect channel-first (C, H, W) tensors.
        if self.transform:
            aia_img = self.transform(aia_img)

        # Model consumes channel-last input: (H, W, 6).
        aia_img = aia_img.permute(1, 2, 0)

        # Load the scalar SXR value.
        sxr_val = np.load(sxr_path)
        if sxr_val.size != 1:
            raise ValueError(f"SXR value has size {sxr_val.size}, expected scalar")
        sxr_val = float(np.atleast_1d(sxr_val).flatten()[0])
        if self.sxr_transform:
            sxr_val = self.sxr_transform(sxr_val)

        sxr_tensor = torch.tensor(sxr_val, dtype=torch.float32)
        return (aia_img, sxr_tensor), sxr_tensor
|
| 78 |
+
|
| 79 |
+
class AIA_GOESDataModule(LightningDataModule):
    """PyTorch Lightning DataModule for AIA images and GOES SXR targets.

    Splits the matched sample set into train/val/test subsets with a seeded
    permutation so repeated ``setup()`` calls yield identical splits.
    """

    def __init__(self, aia_dir, sxr_dir, sxr_norm, batch_size=16, num_workers=4,
                 train_transforms=None, val_transforms=None, val_split=0.2, test_split=0.1,
                 seed=42):
        """
        Args:
            aia_dir: directory of AIA ``.npy`` image stacks.
            sxr_dir: directory of scalar SXR ``.npy`` files.
            sxr_norm: (mean, std) applied to log10(SXR) for target normalization.
            batch_size: samples per batch for all loaders.
            num_workers: DataLoader worker processes.
            train_transforms: transform applied to training images.
            val_transforms: transform applied to validation/test images.
            val_split: fraction of samples used for validation.
            test_split: fraction of samples used for testing.
            seed: RNG seed for the split permutation. Without a fixed seed each
                ``setup()`` call would draw a fresh split and leak training
                samples into validation/test across stages and runs.
        """
        super().__init__()
        self.aia_dir = aia_dir
        self.sxr_dir = sxr_dir
        self.sxr_norm = sxr_norm
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed

    def _sxr_transform(self):
        # Log-transform (with epsilon to guard against log10(0)) and
        # standardize with the provided (mean, std).
        mean, std = self.sxr_norm[0], self.sxr_norm[1]
        return T.Lambda(lambda x: (np.log10(x + 1e-8) - mean) / std)

    def _make_dataset(self, transform):
        # Single construction point for all three splits — only the image
        # transform differs between them.
        return AIA_GOESDataset(
            aia_dir=self.aia_dir,
            sxr_dir=self.sxr_dir,
            transform=transform,
            sxr_transform=self._sxr_transform(),
            target_size=(512, 512),
        )

    def setup(self, stage=None):
        """Build train/valid/test subsets over a shared seeded permutation."""
        total_size = len(self._make_dataset(None))
        test_size = int(self.test_split * total_size)
        val_size = int(self.val_split * total_size)
        train_size = total_size - val_size - test_size

        # Seeded permutation => deterministic, non-overlapping splits.
        indices = np.random.default_rng(self.seed).permutation(total_size)
        train_idx = indices[:train_size]
        val_idx = indices[train_size:train_size + val_size]
        test_idx = indices[train_size + val_size:]

        self.train_ds = Subset(self._make_dataset(self.train_transforms), train_idx)
        self.valid_ds = Subset(self._make_dataset(self.val_transforms), val_idx)
        # Test split reuses the (deterministic) validation transforms.
        self.test_ds = Subset(self._make_dataset(self.val_transforms), test_idx)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_ds, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)
|
flaring/MEGS_AI_baseline/__init__.py
ADDED
|
File without changes
|
flaring/MEGS_AI_baseline/base_model.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from pytorch_lightning import LightningModule
|
| 5 |
+
|
| 6 |
+
class BaseModel(LightningModule):
    """Lightning wrapper for SXR regression models.

    Holds the wrapped model, the (mean, std) normalization used on targets,
    the loss function and the learning rate, and implements the shared
    train/validation/test step logic.
    """

    def __init__(self, model, eve_norm, loss_func, lr):
        """
        Args:
            model: the underlying ``nn.Module`` (may be None for subclasses
                that define their own submodules).
            eve_norm: (mean, std) used to denormalize SXR values.
            loss_func: loss callable, e.g. ``HuberLoss()``.
            lr: Adam learning rate.
        """
        super().__init__()
        self.model = model
        self.eve_norm = eve_norm  # (mean, std) for SXR denormalization
        self.loss_func = loss_func
        self.lr = lr

    def forward(self, x, sxr=None):
        # sxr is accepted for subclasses that consume it; ignored here.
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _step(self, batch, log_name):
        # Shared implementation for the three Lightning hooks: predict,
        # denormalize prediction and target, then compute and log the loss.
        (x, sxr), target = batch
        pred = self(x, sxr)
        pred = pred * self.eve_norm[1] + self.eve_norm[0]
        target = target * self.eve_norm[1] + self.eve_norm[0]
        loss = self.loss_func(pred, target)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        return self._step(batch, 'valid_loss')

    def test_step(self, batch, batch_idx):
        return self._step(batch, 'test_loss')
|
flaring/MEGS_AI_baseline/chopped_alexnet.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import HuberLoss
|
| 4 |
+
from irradiance.models.base_model import BaseModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ChoppedAlexnet(BaseModel):
    """AlexNet-style feature extractor truncated after ``numLayers`` conv
    stages, followed by global average pooling and a dropout + linear head."""

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), numLayers=3, dropout=0, lr=1e-4):
        """
        Args:
            d_input: number of input image channels.
            d_output: number of regression outputs.
            eve_norm: (mean, std) passed to BaseModel for denormalization.
            loss_func: loss callable (previously ignored — a fresh HuberLoss
                silently overrode it; now honored).
            numLayers: how many AlexNet conv stages to keep (1-4).
            dropout: dropout probability before the linear head.
            lr: learning rate.
        """
        # Build all submodules locally first: nn.Module attributes must not be
        # assigned before Module.__init__ runs (the original assigned
        # self.features before super().__init__(), which raises
        # "cannot assign module before Module.__init__() call").
        layers, channelSize = self.getLayers(numLayers, d_input)
        head = nn.Sequential(nn.Dropout(p=dropout),
                             nn.Linear(channelSize, d_output))
        super().__init__(model=head, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

        self.numLayers = numLayers
        self.n_channels = d_input
        self.outSize = d_output
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # He initialization for convs; unit-gain BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def getLayers(self, numLayers, n_channels):
        """Return (layer list, output feature size) for the truncated stack.

        Uses no instance state, so it is safe to call before ``__init__``
        completes.
        """
        layers = [nn.Conv2d(n_channels, 64, kernel_size=11, stride=4, padding=2), nn.BatchNorm2d(64), nn.ReLU(inplace=True), ]
        if numLayers == 1:
            return (layers, 64)
        layers += [nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.BatchNorm2d(192), nn.ReLU(inplace=True), ]
        if numLayers == 2:
            return (layers, 192)
        layers += [nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384), nn.ReLU(inplace=True)]
        if numLayers == 3:
            return (layers, 384)

        layers += [nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.BatchNorm2d(256)]
        return (layers, 256)

    def forward(self, x):
        # (B, C, H, W) -> conv features -> global average pool -> linear head.
        x = self.features(x)
        x = self.pool(x).view(x.size(0), -1)
        x = self.model(x)
        return x
|
flaring/MEGS_AI_baseline/efficientnet.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import HuberLoss
|
| 4 |
+
from irradiance.models.base_model import BaseModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EfficientnetIrradiance(BaseModel):
    """EfficientNet backbone adapted for irradiance regression.

    The stem convolution is replaced to accept ``d_input`` channels and the
    1000-class classifier is replaced by a dropout + linear regression head.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), model='efficientnet_b0', dp=0.75, lr=1e-4):
        """
        Args:
            d_input: number of input image channels.
            d_output: number of regression outputs.
            eve_norm: (mean, std) passed to BaseModel for denormalization.
            loss_func: loss callable.
            model: one of 'efficientnet_b0' … 'efficientnet_b7'.
            dp: dropout probability applied to ALL dropout layers in the net.
            lr: learning rate.

        Raises:
            ValueError: for an unrecognized ``model`` name (the original elif
                chain silently fell through and crashed later on the string).
        """
        backbones = {
            'efficientnet_b0': torchvision.models.efficientnet_b0,
            'efficientnet_b1': torchvision.models.efficientnet_b1,
            'efficientnet_b2': torchvision.models.efficientnet_b2,
            'efficientnet_b3': torchvision.models.efficientnet_b3,
            'efficientnet_b4': torchvision.models.efficientnet_b4,
            'efficientnet_b5': torchvision.models.efficientnet_b5,
            'efficientnet_b6': torchvision.models.efficientnet_b6,
            'efficientnet_b7': torchvision.models.efficientnet_b7,
        }
        if model not in backbones:
            raise ValueError(f"Unsupported model: {model!r}; expected one of {sorted(backbones)}")
        model = backbones[model](pretrained=True)

        # Swap the stem conv so the network accepts d_input channels.
        conv1_out = model.features[0][0].out_channels
        model.features[0][0] = nn.Conv2d(d_input, conv1_out, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Replace the classifier with a d_output regression head.
        # Consider adding an average pool of the full image(s) as extra input.
        lin_in = model.classifier[1].in_features
        model.classifier = nn.Sequential(nn.Dropout(p=dp, inplace=True),
                                         nn.Linear(in_features=lin_in, out_features=d_output, bias=True))

        # Raise every internal dropout to dp as well.
        # TODO: other dropout values?
        for m in model.modules():
            if m.__class__.__name__.startswith('Dropout'):
                m.p = dp

        super().__init__(model=model, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

    def forward(self, x):
        return self.model(x)
|
flaring/MEGS_AI_baseline/kan_success.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Li, Ziyao
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from typing import *
|
| 20 |
+
from torch.nn import HuberLoss
|
| 21 |
+
from irradiance.models.base_model import BaseModel
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SplineLinear(nn.Linear):
|
| 25 |
+
def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
|
| 26 |
+
self.init_scale = init_scale
|
| 27 |
+
super().__init__(in_features, out_features, bias=False, **kw)
|
| 28 |
+
|
| 29 |
+
def reset_parameters(self) -> None:
|
| 30 |
+
nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class RadialBasisFunction(nn.Module):
    """Gaussian radial-basis featurizer.

    Maps each scalar input to ``num_grids`` activations, one per evenly
    spaced center in ``[grid_min, grid_max]``.
    """

    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        denominator: float = None,  # larger denominators lead to smoother basis
    ):
        super().__init__()
        self.grid_min = grid_min
        self.grid_max = grid_max
        self.num_grids = num_grids
        # Frozen (non-trainable) parameter holding the RBF centers.
        centers = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = torch.nn.Parameter(centers, requires_grad=False)
        # Default bandwidth equals the grid spacing.
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        # (..., 1) minus (num_grids,) broadcasts to (..., num_grids).
        scaled = (x[..., None] - self.grid) / self.denominator
        return torch.exp(-(scaled ** 2))
|
| 51 |
+
|
| 52 |
+
class FastKANLayer(nn.Module):
    """One FastKAN layer: RBF spline features mixed by a linear map, plus an
    optional SiLU "base" linear path added to the spline output."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        use_layernorm: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        """
        Args:
            input_dim: size of the last input dimension.
            output_dim: size of the produced last dimension.
            grid_min / grid_max / num_grids: RBF center configuration.
            use_base_update: add a parallel Linear(activation(x)) path.
            use_layernorm: normalize x before the RBF (needs input_dim > 1).
            base_activation: activation for the base path (default SiLU).
            spline_weight_init_scale: init std of the spline mixing weights.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layernorm = None
        if use_layernorm:
            # LayerNorm over a single feature would zero it out.
            assert input_dim > 1, "Do not use layernorms on 1D inputs. Set `use_layernorm=False`."
            self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        # Mixes the flattened (input_dim * num_grids) spline features.
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_linear = nn.Linear(input_dim, output_dim)

    def forward(self, x, use_layernorm=True):
        """Apply the layer; ``use_layernorm=False`` bypasses normalization
        even when a LayerNorm module exists."""
        if self.layernorm is not None and use_layernorm:
            spline_basis = self.rbf(self.layernorm(x))
        else:
            spline_basis = self.rbf(x)
        # Flatten the trailing (input_dim, num_grids) dims before mixing.
        ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1))
        if self.use_base_update:
            # Residual-style base path: Linear(activation(x)).
            base = self.base_linear(self.base_activation(x))
            ret = ret + base
        return ret

    def plot_curve(
        self,
        input_index: int,
        output_index: int,
        num_pts: int = 1000,
        num_extrapolate_bins: int = 2
    ):
        '''this function returns the learned curves in a FastKANLayer.
        input_index: the selected index of the input, in [0, input_dim) .
        output_index: the selected index of the output, in [0, output_dim) .
        num_pts: num of points sampled for the curve.
        num_extrapolate_bins (N_e): num of bins extrapolating from the given grids. The curve
            will be calculate in the range of [grid_min - h * N_e, grid_max + h * N_e].
        '''
        ng = self.rbf.num_grids
        h = self.rbf.denominator
        assert input_index < self.input_dim
        assert output_index < self.output_dim
        # Slice the spline weights belonging to this (input, output) pair.
        w = self.spline_linear.weight[
            output_index, input_index * ng : (input_index + 1) * ng
        ]   # num_grids,
        x = torch.linspace(
            self.rbf.grid_min - num_extrapolate_bins * h,
            self.rbf.grid_max + num_extrapolate_bins * h,
            num_pts
        )   # num_pts, num_grids
        with torch.no_grad():
            # Weighted sum of the RBF activations along the grid axis.
            y = (w * self.rbf(x.to(w.dtype))).sum(-1)
        return x, y
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class FastKANIrradiance(BaseModel):
    """FastKAN applied to per-channel image statistics.

    Each (B, C, H, W) image stack is reduced to per-channel means (and
    optionally stds), which feed a stack of FastKAN layers.
    """

    def __init__(
        self,
        eve_norm,
        layers_hidden: List[int],
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
        loss_func = HuberLoss(),
        lr=1e-4,
        use_std=False
    ) -> None:
        """
        Args:
            eve_norm: (mean, std) passed to BaseModel for denormalization.
            layers_hidden: layer widths; the first entry is the input width
                (doubled internally when ``use_std`` is True).
            use_std: concatenate channel stds to channel means as input.
        """
        super().__init__(model=None, eve_norm=eve_norm, loss_func=loss_func, lr=lr)
        self.use_std = use_std
        # Copy before widening the first layer: the original mutated the
        # caller's list in place, corrupting it for reuse.
        layers_hidden = list(layers_hidden)
        if use_std:
            layers_hidden[0] = layers_hidden[0] * 2
        self.layers = nn.ModuleList([
            FastKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        # Reduce images to per-channel statistics as 1D KAN input.
        # (torch.torch.mean in the original was an accidental double attribute.)
        mean_irradiance = torch.mean(x, dim=(2, 3))
        std_irradiance = torch.std(x, dim=(2, 3))
        if self.use_std:
            x = torch.cat((mean_irradiance, std_irradiance), dim=1)
        else:
            x = mean_irradiance
        for layer in self.layers:
            x = layer(x)
        return x
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class AttentionWithFastKANTransform(nn.Module):
    """Multi-head attention whose q/k/v/output (and optional gate) projections
    are FastKAN layers instead of plain linear maps."""

    def __init__(
        self,
        q_dim: int,
        k_dim: int,
        v_dim: int,
        head_dim: int,
        num_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            q_dim / k_dim / v_dim: feature sizes of query, key, value inputs.
            head_dim: per-head feature size.
            num_heads: number of attention heads.
            gating: sigmoid-gate the attention output with a projection of q.
        """
        super(AttentionWithFastKANTransform, self).__init__()

        self.num_heads = num_heads
        total_dim = head_dim * self.num_heads
        self.gating = gating
        self.linear_q = FastKANLayer(q_dim, total_dim)
        self.linear_k = FastKANLayer(k_dim, total_dim)
        self.linear_v = FastKANLayer(v_dim, total_dim)
        self.linear_o = FastKANLayer(total_dim, q_dim)
        self.linear_g = None
        if self.gating:
            self.linear_g = FastKANLayer(q_dim, total_dim)
        # precompute the 1/sqrt(head_dim)
        self.norm = head_dim**-0.5

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bias: torch.Tensor = None,      # additive attention bias
    ) -> torch.Tensor:

        # Scale queries by 1/sqrt(head_dim); shapes commented as *qkh etc.
        wq = self.linear_q(q).view(*q.shape[:-1], 1, self.num_heads, -1) * self.norm     # *q1hc
        wk = self.linear_k(k).view(*k.shape[:-2], 1, k.shape[-2], self.num_heads, -1)    # *1khc
        att = (wq * wk).sum(-1).softmax(-2)     # *qkh
        del wq, wk
        # NOTE(review): the bias is added AFTER the softmax here; an additive
        # attention bias is conventionally applied to the logits before
        # softmax — confirm this ordering is intentional.
        if bias is not None:
            att = att + bias[..., None]

        wv = self.linear_v(v).view(*v.shape[:-2],1, v.shape[-2], self.num_heads, -1)     # *1khc
        o = (att[..., None] * wv).sum(-3)        # *qhc
        del att, wv

        o = o.view(*o.shape[:-2], -1)           # *q(hc)

        if self.linear_g is not None:
            # gating, use raw query input
            g = self.linear_g(q)
            o = torch.sigmoid(g) * o

        # merge heads
        o = self.linear_o(o)
        return o
|
flaring/MEGS_AI_baseline/linear_and_hybrid.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import HuberLoss
|
| 4 |
+
from models.base_model import BaseModel
|
| 5 |
+
|
| 6 |
+
class LinearIrradianceModel(BaseModel):
    """Linear regression over per-channel image statistics.

    Predicts ``d_output`` values from the concatenated channel means and
    stds of a channel-last AIA image stack.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), lr=1e-2):
        # Build the head from locals and only assign instance attributes after
        # Module.__init__ has run (via super()).
        model = nn.Linear(2 * d_input, d_output)
        super().__init__(model=model, eve_norm=eve_norm, loss_func=loss_func, lr=lr)
        self.n_channels = d_input
        self.outSize = d_output

    def forward(self, x, sxr=None, **kwargs):
        """Regress from channel statistics; ``sxr`` is accepted but unused."""
        # A dataloader may yield (aia_img, sxr_val); keep only the image.
        if isinstance(x, (list, tuple)):
            x = x[0]

        # Expect channel-last input: (batch, H, W, C).
        if len(x.shape) != 4:
            raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
        if x.shape[-1] != self.n_channels:
            raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")

        # (B, H, W, C) -> (B, C, H, W), then reduce over the spatial dims.
        # (Debug prints removed — they fired on every forward pass.)
        x = x.permute(0, 3, 1, 2)
        mean_irradiance = torch.mean(x, dim=(2, 3))   # (B, C)
        std_irradiance = torch.std(x, dim=(2, 3))     # (B, C)

        input_features = torch.cat((mean_irradiance, std_irradiance), dim=1)  # (B, 2C)
        return self.model(input_features)
|
| 45 |
+
|
| 46 |
+
class HybridIrradianceModel(BaseModel):
|
| 47 |
+
def __init__(self, d_input, d_output, eve_norm, cnn_model='resnet', ln_model=True, ln_params=None, lr=1e-4, cnn_dp=0.75, loss_func=HuberLoss()):
|
| 48 |
+
super().__init__(model=None, eve_norm=eve_norm, loss_func=loss_func, lr=lr)
|
| 49 |
+
self.n_channels = d_input
|
| 50 |
+
self.outSize = d_output
|
| 51 |
+
self.ln_params = ln_params
|
| 52 |
+
self.ln_model = None
|
| 53 |
+
if ln_model:
|
| 54 |
+
self.ln_model = LinearIrradianceModel(d_input, d_output, eve_norm, loss_func=loss_func, lr=lr)
|
| 55 |
+
if self.ln_params is not None and self.ln_model is not None:
|
| 56 |
+
self.ln_model.model.weight = nn.Parameter(self.ln_params['weight'])
|
| 57 |
+
self.ln_model.model.bias = nn.Parameter(self.ln_params['bias'])
|
| 58 |
+
self.cnn_model = None
|
| 59 |
+
self.cnn_lambda = 1.
|
| 60 |
+
if cnn_model == 'resnet':
|
| 61 |
+
self.cnn_model = nn.Sequential(
|
| 62 |
+
nn.Conv2d(d_input, 64, kernel_size=7, stride=2, padding=3),
|
| 63 |
+
nn.ReLU(),
|
| 64 |
+
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
|
| 65 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
| 66 |
+
nn.ReLU(),
|
| 67 |
+
nn.AdaptiveAvgPool2d((1, 1)),
|
| 68 |
+
nn.Flatten(),
|
| 69 |
+
nn.Linear(64, d_output),
|
| 70 |
+
nn.Dropout(cnn_dp)
|
| 71 |
+
)
|
| 72 |
+
elif cnn_model.startswith('efficientnet'):
|
| 73 |
+
raise NotImplementedError("EfficientNet requires timm; replace with custom CNN or install timm")
|
| 74 |
+
if self.ln_model is None and self.cnn_model is None:
|
| 75 |
+
raise ValueError('Please pass at least one model.')
|
| 76 |
+
|
| 77 |
+
def forward(self, x, sxr=None, **kwargs):
|
| 78 |
+
# If x is a tuple (aia_img, sxr_val), extract the AIA image tensor
|
| 79 |
+
if isinstance(x, (list, tuple)):
|
| 80 |
+
x = x[0]
|
| 81 |
+
|
| 82 |
+
# Debug: Print input shape
|
| 83 |
+
print(f"Input shape to HybridIrradianceModel.forward: {x.shape}")
|
| 84 |
+
|
| 85 |
+
# Expect x shape: (batch_size, H, W, C)
|
| 86 |
+
if len(x.shape) != 4:
|
| 87 |
+
raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
|
| 88 |
+
if x.shape[-1] != self.n_channels:
|
| 89 |
+
raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")
|
| 90 |
+
|
| 91 |
+
# Convert to (batch_size, C, H, W) for CNN
|
| 92 |
+
x_cnn = x.permute(0, 3, 1, 2)
|
| 93 |
+
|
| 94 |
+
if self.ln_model is not None and self.cnn_model is not None:
|
| 95 |
+
# For linear model, keep original (B,H,W,C) format
|
| 96 |
+
return self.ln_model(x) + self.cnn_lambda * self.cnn_model(x_cnn)
|
| 97 |
+
elif self.ln_model is not None:
|
| 98 |
+
return self.ln_model(x)
|
| 99 |
+
elif self.cnn_model is not None:
|
| 100 |
+
return self.cnn_model(x_cnn)
|
| 101 |
+
|
| 102 |
+
def configure_optimizers(self):
|
| 103 |
+
return torch.optim.Adam(self.parameters(), lr=self.lr)
|
| 104 |
+
|
| 105 |
+
def set_train_mode(self, mode):
    """Select which sub-models are being trained.

    Args:
        mode: 'linear' trains only the linear branch and disables the
            CNN contribution (cnn_lambda=0); 'cnn' trains only the CNN
            branch; 'both' trains both. In 'cnn' and 'both' the CNN
            term is weighted by 0.01.

    Raises:
        NotImplementedError: For any other mode string.
    """
    if mode not in ('linear', 'cnn', 'both'):
        raise NotImplementedError(f'Mode not supported: {mode}')
    train_cnn = mode in ('cnn', 'both')
    train_linear = mode in ('linear', 'both')
    # The CNN contribution is switched off entirely in pure linear mode.
    self.cnn_lambda = 0.01 if train_cnn else 0
    # Module.train(bool) is equivalent to train()/eval().
    if self.cnn_model:
        self.cnn_model.train(train_cnn)
    if self.ln_model:
        self.ln_model.train(train_linear)
|
flaring/MEGS_AI_baseline/models/base_model.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from pytorch_lightning import LightningModule
|
| 5 |
+
|
| 6 |
+
class BaseModel(LightningModule):
    """Shared Lightning wrapper for the irradiance models.

    Holds the wrapped torch model, the (mean, std) pair used to
    de-normalize predictions and targets before the loss, the loss
    function, and the learning rate.
    """

    def __init__(self, model, eve_norm, loss_func, lr):
        super().__init__()
        self.model = model
        self.eve_norm = eve_norm  # Used for SXR normalization (mean, std)
        self.loss_func = loss_func
        self.lr = lr

    def forward(self, x, sxr=None):
        # `sxr` is accepted for interface compatibility but unused here.
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _step(self, batch, log_name):
        """Shared train/val/test logic: predict, de-normalize, log the loss.

        The same denormalize-then-loss sequence was previously duplicated
        across the three step methods; it is factored out here so they
        cannot drift apart.
        """
        (x, sxr), target = batch
        pred = self(x, sxr)
        # De-normalize to physical units: value * std + mean.
        pred = pred * self.eve_norm[1] + self.eve_norm[0]
        target = target * self.eve_norm[1] + self.eve_norm[0]
        loss = self.loss_func(pred, target)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        return self._step(batch, 'valid_loss')

    def test_step(self, batch, batch_idx):
        return self._step(batch, 'test_loss')
|
flaring/MEGS_AI_baseline/models/chopped_alexnet.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import HuberLoss
|
| 4 |
+
from models.base_model import BaseModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ChoppedAlexnet(BaseModel):
    """AlexNet-style CNN truncated after ``numLayers`` conv stages.

    Features are global-average-pooled and passed through a
    dropout + linear head producing ``d_output`` irradiance values.
    """

    # def __init__(self, numlayers, n_channels, outSize, dropout):
    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), numLayers=3, dropout=0, lr=1e-4):
        # BUG FIX: the original assigned nn.Module attributes
        # (self.features, self.pool, self.loss_func) and iterated
        # self.modules() BEFORE super().__init__(), which raises
        # AttributeError in nn.Module. Build everything locally, call
        # super() first, then attach sub-modules.
        layers, channelSize = self.getLayers(numLayers, d_input)
        model = nn.Sequential(nn.Dropout(p=dropout),
                              nn.Linear(channelSize, d_output))
        super().__init__(model=model, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

        self.numLayers = numLayers
        self.n_channels = d_input
        self.outSize = d_output
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Kaiming init for convs, unit-gain init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def getLayers(self, numLayers, n_channels):
        """Returns a list of layers + the feature size coming out."""
        layers = [nn.Conv2d(n_channels, 64, kernel_size=11, stride=4, padding=2), nn.BatchNorm2d(64), nn.ReLU(inplace=True)]
        if numLayers == 1:
            return (layers, 64)
        layers += [nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.BatchNorm2d(192), nn.ReLU(inplace=True)]
        if numLayers == 2:
            return (layers, 192)
        layers += [nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384), nn.ReLU(inplace=True)]
        if numLayers == 3:
            return (layers, 384)
        layers += [nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.BatchNorm2d(256)]
        return (layers, 256)

    def forward(self, x):
        """Trunk -> global average pool -> flatten -> dropout + linear head."""
        x = self.features(x)
        x = self.pool(x).view(x.size(0), -1)
        x = self.model(x)
        return x
|
flaring/MEGS_AI_baseline/models/efficientnet.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import HuberLoss
|
| 4 |
+
from models.base_model import BaseModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EfficientnetIrradiance(BaseModel):
    """EfficientNet backbone adapted to multi-channel AIA input.

    The stem convolution is replaced so the network accepts ``d_input``
    channels, and the classifier head is replaced by a dropout + linear
    layer producing ``d_output`` values. Every dropout probability in
    the network is set to ``dp``.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), model='efficientnet_b0', dp=0.75, lr=1e-4):
        # Validate and resolve the variant by name instead of a long
        # elif chain. The original chain silently fell through for an
        # unknown name and later crashed with AttributeError on a str;
        # fail fast with a clear error instead.
        supported = {f'efficientnet_b{i}' for i in range(8)}
        if model not in supported:
            raise ValueError(f"Unsupported model '{model}'; expected one of {sorted(supported)}")
        model = getattr(torchvision.models, model)(pretrained=True)

        # Replace the stem conv so the network accepts d_input channels.
        conv1_out = model.features[0][0].out_channels
        model.features[0][0] = nn.Conv2d(d_input, conv1_out, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        lin_in = model.classifier[1].in_features
        # consider adding average pool of full image(s)
        model.classifier = nn.Sequential(nn.Dropout(p=dp, inplace=True),
                                         nn.Linear(in_features=lin_in, out_features=d_output, bias=True))
        # Set all dropout probabilities to dp.
        # TODO: other dropout values?
        for m in model.modules():
            if m.__class__.__name__.startswith('Dropout'):
                m.p = dp

        super().__init__(model=model, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

    def forward(self, x):
        x = self.model(x)
        return x
|
flaring/MEGS_AI_baseline/models/kan_success.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Li, Ziyao
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from typing import *
|
| 20 |
+
from torch.nn import HuberLoss
|
| 21 |
+
from irradiance.models.base_model import BaseModel
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SplineLinear(nn.Linear):
    """Bias-free linear layer with truncated-normal weight initialization."""

    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        # init_scale MUST be set before super().__init__(): nn.Linear's
        # constructor calls reset_parameters(), which reads it.
        self.init_scale = init_scale
        super().__init__(in_features, out_features, bias=False, **kw)

    def reset_parameters(self) -> None:
        """Initialize weights ~ TruncNormal(mean=0, std=init_scale), overriding nn.Linear's default."""
        nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class RadialBasisFunction(nn.Module):
    """Gaussian radial basis expansion over a fixed 1-D grid.

    Maps an input of shape (...,) to (..., num_grids) by evaluating a
    Gaussian bump centred at each grid point.
    """

    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        denominator: float = None,  # larger denominators lead to smoother basis
    ):
        super().__init__()
        self.grid_min = grid_min
        self.grid_max = grid_max
        self.num_grids = num_grids
        centres = torch.linspace(grid_min, grid_max, num_grids)
        # Stored as a non-trainable Parameter so the grid follows
        # state_dict saves and .to(device) moves.
        self.grid = torch.nn.Parameter(centres, requires_grad=False)
        # Default width equals the grid spacing.
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        scaled = (x[..., None] - self.grid) / self.denominator
        return torch.exp(-(scaled ** 2))
|
| 51 |
+
|
| 52 |
+
class FastKANLayer(nn.Module):
    """One FastKAN layer: RBF feature expansion followed by a linear map.

    Optionally layer-normalizes the input before the RBF expansion and
    adds a SiLU-activated residual "base" linear path.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        use_layernorm: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layernorm = None
        if use_layernorm:
            assert input_dim > 1, "Do not use layernorms on 1D inputs. Set `use_layernorm=False`."
            self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        # Each of the input_dim features expands to num_grids RBF values,
        # hence the flattened input width input_dim * num_grids.
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_linear = nn.Linear(input_dim, output_dim)

    def forward(self, x, use_layernorm=True):
        """Apply RBF expansion + spline linear (+ optional base path) to x.

        use_layernorm only has an effect when the layer was built with a
        layernorm; it allows callers to bypass it per-call.
        """
        if self.layernorm is not None and use_layernorm:
            spline_basis = self.rbf(self.layernorm(x))
        else:
            spline_basis = self.rbf(x)
        # Flatten the trailing (input_dim, num_grids) pair into one axis.
        ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1))
        if self.use_base_update:
            base = self.base_linear(self.base_activation(x))
            ret = ret + base
        return ret

    def plot_curve(
        self,
        input_index: int,
        output_index: int,
        num_pts: int = 1000,
        num_extrapolate_bins: int = 2
    ):
        '''this function returns the learned curves in a FastKANLayer.
        input_index: the selected index of the input, in [0, input_dim) .
        output_index: the selected index of the output, in [0, output_dim) .
        num_pts: num of points sampled for the curve.
        num_extrapolate_bins (N_e): num of bins extrapolating from the given grids. The curve
            will be calculate in the range of [grid_min - h * N_e, grid_max + h * N_e].
        '''
        ng = self.rbf.num_grids
        h = self.rbf.denominator
        assert input_index < self.input_dim
        assert output_index < self.output_dim
        # Slice out the spline weights connecting this input's RBF block
        # to the selected output.
        w = self.spline_linear.weight[
            output_index, input_index * ng : (input_index + 1) * ng
        ]   # num_grids,
        x = torch.linspace(
            self.rbf.grid_min - num_extrapolate_bins * h,
            self.rbf.grid_max + num_extrapolate_bins * h,
            num_pts
        )   # num_pts, num_grids
        with torch.no_grad():
            y = (w * self.rbf(x.to(w.dtype))).sum(-1)
        return x, y
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class FastKANIrradiance(BaseModel):
    """FastKAN regressor over per-channel image statistics.

    Each (B, C, H, W) image batch is reduced to channel-wise means (and
    optionally stds), and the resulting 1-D feature vector is passed
    through a stack of FastKAN layers.
    """

    def __init__(
        self,
        eve_norm,
        layers_hidden: List[int],
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
        loss_func = HuberLoss(),
        lr=1e-4,
        use_std=False
    ) -> None:
        super().__init__(model=None, eve_norm=eve_norm, loss_func=loss_func, lr=lr)
        self.use_std = use_std
        # BUG FIX: work on a copy — the original mutated the caller's
        # list in place (layers_hidden[0] *= 2), corrupting reused configs.
        layers_hidden = list(layers_hidden)
        if use_std:
            # Mean and std are concatenated, doubling the input width.
            layers_hidden[0] = layers_hidden[0] * 2
        self.layers = nn.ModuleList([
            FastKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        # Reduce each channel's spatial dimensions to mean/std scalars.
        # (Fixed accidental torch.torch.* accesses.)
        mean_irradiance = torch.mean(x, dim=(2, 3))
        std_irradiance = torch.std(x, dim=(2, 3))
        if self.use_std:
            x = torch.cat((mean_irradiance, std_irradiance), dim=1)
        else:
            x = mean_irradiance
        for layer in self.layers:
            x = layer(x)
        return x
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class AttentionWithFastKANTransform(nn.Module):
    """Multi-head attention whose q/k/v/output projections are FastKAN layers.

    Optionally gates the attention output with a sigmoid computed from
    the raw query input.
    """

    def __init__(
        self,
        q_dim: int,
        k_dim: int,
        v_dim: int,
        head_dim: int,
        num_heads: int,
        gating: bool = True,
    ):
        super(AttentionWithFastKANTransform, self).__init__()

        self.num_heads = num_heads
        total_dim = head_dim * self.num_heads
        self.gating = gating
        self.linear_q = FastKANLayer(q_dim, total_dim)
        self.linear_k = FastKANLayer(k_dim, total_dim)
        self.linear_v = FastKANLayer(v_dim, total_dim)
        self.linear_o = FastKANLayer(total_dim, q_dim)
        self.linear_g = None
        if self.gating:
            self.linear_g = FastKANLayer(q_dim, total_dim)
        # precompute the 1/sqrt(head_dim)
        self.norm = head_dim**-0.5

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bias: torch.Tensor = None,      # additive attention bias
    ) -> torch.Tensor:

        # Shape comments use: q = query len, k = key len, h = heads, c = head dim.
        wq = self.linear_q(q).view(*q.shape[:-1], 1, self.num_heads, -1) * self.norm     # *q1hc
        wk = self.linear_k(k).view(*k.shape[:-2], 1, k.shape[-2], self.num_heads, -1)    # *1khc
        att = (wq * wk).sum(-1).softmax(-2)     # *qkh
        del wq, wk
        # NOTE(review): the bias is added AFTER the softmax and the result
        # is not re-normalized; an "additive attention bias" is normally
        # applied to the logits before softmax — confirm this is intended.
        if bias is not None:
            att = att + bias[..., None]

        wv = self.linear_v(v).view(*v.shape[:-2],1, v.shape[-2], self.num_heads, -1)     # *1khc
        o = (att[..., None] * wv).sum(-3)       # *qhc
        del att, wv

        o = o.view(*o.shape[:-2], -1)       # *q(hc)

        if self.linear_g is not None:
            # gating, use raw query input
            g = self.linear_g(q)
            o = torch.sigmoid(g) * o

        # merge heads
        o = self.linear_o(o)
        return o
|
flaring/MEGS_AI_baseline/models/linear_and_hybrid.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import HuberLoss
|
| 4 |
+
from models.base_model import BaseModel
|
| 5 |
+
|
| 6 |
+
class LinearIrradianceModel(BaseModel):
    """Linear regression from per-channel image statistics to irradiance.

    Each AIA image is reduced to channel-wise mean and std; a single
    linear layer maps the resulting 2*C feature vector to d_output values.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), lr=1e-2):
        self.n_channels = d_input
        self.outSize = d_output
        model = nn.Linear(2 * self.n_channels, self.outSize)
        super().__init__(model=model, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

    def forward(self, x, sxr=None, **kwargs):
        """Predict irradiance from a (batch, H, W, C) image batch.

        Leftover debug prints were removed from this hot path; the
        redundant feature-width re-check (guaranteed by construction)
        was dropped as well.

        Raises:
            ValueError: If the input is not 4-D or its channel count
                does not match ``self.n_channels``.
        """
        # Accept (aia_img, sxr_val) tuples from the dataloader.
        if isinstance(x, (list, tuple)):
            x = x[0]

        # Expect channels-last input: (batch, H, W, C).
        if len(x.shape) != 4:
            raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
        if x.shape[-1] != self.n_channels:
            raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")

        # Channels-first for the spatial reduction: (batch, C, H, W).
        x = x.permute(0, 3, 1, 2)

        # Per-channel statistics over the spatial dimensions.
        mean_irradiance = torch.mean(x, dim=(2, 3))  # (batch, C)
        std_irradiance = torch.std(x, dim=(2, 3))    # (batch, C)

        input_features = torch.cat((mean_irradiance, std_irradiance), dim=1)  # (batch, 2*C)
        return self.model(input_features)
|
| 45 |
+
|
| 46 |
+
class HybridIrradianceModel(BaseModel):
    """Hybrid of a statistics-based linear model and a small CNN.

    Prediction is ln_model(x) + cnn_lambda * cnn_model(x); either branch
    may be disabled. ``set_train_mode`` toggles which branch trains.
    """

    def __init__(self, d_input, d_output, eve_norm, cnn_model='resnet', ln_model=True, ln_params=None, lr=1e-4, cnn_dp=0.75, loss_func=HuberLoss()):
        super().__init__(model=None, eve_norm=eve_norm, loss_func=loss_func, lr=lr)
        self.n_channels = d_input
        self.outSize = d_output
        self.ln_params = ln_params

        self.ln_model = None
        if ln_model:
            self.ln_model = LinearIrradianceModel(d_input, d_output, eve_norm, loss_func=loss_func, lr=lr)
            # Optionally warm-start the linear layer from precomputed params.
            if self.ln_params is not None:
                self.ln_model.model.weight = nn.Parameter(self.ln_params['weight'])
                self.ln_model.model.bias = nn.Parameter(self.ln_params['bias'])

        self.cnn_model = None
        self.cnn_lambda = 1.
        # BUG FIX: guard against cnn_model=None — the original called
        # cnn_model.startswith(...) unconditionally and crashed with
        # AttributeError when the CNN branch was disabled via None.
        if cnn_model == 'resnet':
            self.cnn_model = nn.Sequential(
                nn.Conv2d(d_input, 64, kernel_size=7, stride=2, padding=3),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                nn.Linear(64, d_output),
                nn.Dropout(cnn_dp)
            )
        elif cnn_model is not None and cnn_model.startswith('efficientnet'):
            raise NotImplementedError("EfficientNet requires timm; replace with custom CNN or install timm")

        if self.ln_model is None and self.cnn_model is None:
            raise ValueError('Please pass at least one model.')

    def forward(self, x, sxr=None, **kwargs):
        """Predict irradiance; see class docstring for branch combination.

        Accepts either a (batch, H, W, C) tensor or a (aia_img, sxr_val)
        tuple from the dataloader. The leftover debug print was removed.

        Raises:
            ValueError: If the input is not 4-D or its channel count
                does not match ``self.n_channels``.
        """
        if isinstance(x, (list, tuple)):
            x = x[0]

        # Expect channels-last input: (batch, H, W, C).
        if len(x.shape) != 4:
            raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
        if x.shape[-1] != self.n_channels:
            raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")

        # The CNN branch expects channels-first: (batch, C, H, W).
        x_cnn = x.permute(0, 3, 1, 2)

        if self.ln_model is not None and self.cnn_model is not None:
            # The linear model consumes the original channels-last layout.
            return self.ln_model(x) + self.cnn_lambda * self.cnn_model(x_cnn)
        elif self.ln_model is not None:
            return self.ln_model(x)
        elif self.cnn_model is not None:
            return self.cnn_model(x_cnn)

    def configure_optimizers(self):
        """Adam over all trainable parameters of both branches."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def set_train_mode(self, mode):
        """Select which branches train.

        'linear' disables the CNN term (cnn_lambda=0); 'cnn' and 'both'
        enable it with weight 0.01.

        Raises:
            NotImplementedError: For any other mode string.
        """
        if mode == 'linear':
            self.cnn_lambda = 0
            if self.cnn_model: self.cnn_model.eval()
            if self.ln_model: self.ln_model.train()
        elif mode == 'cnn':
            self.cnn_lambda = 0.01
            if self.cnn_model: self.cnn_model.train()
            if self.ln_model: self.ln_model.eval()
        elif mode == 'both':
            self.cnn_lambda = 0.01
            if self.cnn_model: self.cnn_model.train()
            if self.ln_model: self.ln_model.train()
        else:
            raise NotImplementedError(f'Mode not supported: {mode}')
|
flaring/MEGS_AI_baseline/sxr_normalization.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import glob
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
def compute_sxr_norm(sxr_dir):
    """
    Compute mean and standard deviation of log10-transformed SXR values.

    Args:
        sxr_dir (str): Path to directory containing SXR .npy files.

    Returns:
        tuple: (mean, std) of log10(SXR + 1e-8) values.

    Raises:
        FileNotFoundError: If `sxr_dir` is not an existing directory.
        ValueError: If no .npy files are found, or none contain valid data.
    """
    sxr_dir = Path(sxr_dir).resolve()
    print(f"Checking SXR directory: {sxr_dir}")
    if not sxr_dir.is_dir():
        raise FileNotFoundError(f"SXR directory does not exist or is not a directory: {sxr_dir}")

    sxr_files = sorted(glob.glob(os.path.join(sxr_dir, "*.npy")))
    print(f"Found {len(sxr_files)} SXR files in {sxr_dir}")
    if len(sxr_files) == 0:
        # BUG FIX: the message previously claimed a '*_sxr.npy' pattern
        # although the glob above actually matches '*.npy'.
        print("No files matching '*.npy' found. Listing directory contents:")
        print(os.listdir(sxr_dir)[:10])  # Show first 10 entries
        raise ValueError(f"No SXR files found in {sxr_dir}")

    sxr_values = []
    for f in sxr_files:
        try:
            sxr = np.load(f)
            sxr = np.atleast_1d(sxr).flatten()[0]
            # Reject NaN/inf and negative fluxes before the log transform.
            if not np.isfinite(sxr) or sxr < 0:
                print(f"Skipping invalid SXR value in {f}: {sxr}")
                continue
            sxr_values.append(np.log10(sxr + 1e-8))
        except Exception as e:
            print(f"Failed to load SXR file {f}: {e}")
            continue

    sxr_values = np.array(sxr_values)
    if len(sxr_values) == 0:
        raise ValueError(f"No valid SXR values found in {sxr_dir}. All files failed to load or contained invalid data.")

    mean = np.mean(sxr_values)
    std = np.std(sxr_values)
    print(f"Computed SXR normalization: mean={mean}, std={std}")
    return mean, std
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
    # Update these paths for your environment.
    sxr_dir = "/mnt/data/ML-Ready-Data-No-Intensity-Cut/GOES-18-SXR-B/"  # Replace with actual path
    # BUG FIX: a single output-path variable keeps the log message
    # consistent with the file actually written (the original saved to
    # sxr_norm2.npy but printed sxr_norm.npy).
    out_path = "/home/jayantbiradar619/sxr_norm2.npy"
    sxr_norm = compute_sxr_norm(sxr_dir)
    np.save(out_path, sxr_norm)
    print(f"Saved SXR normalization to {out_path}")
|
flaring/MEGS_AI_baseline/train.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import yaml
|
| 5 |
+
import itertools
|
| 6 |
+
import wandb
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
from pytorch_lightning import Trainer
|
| 12 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 13 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
|
| 14 |
+
from torch.nn import HuberLoss
|
| 15 |
+
from SDOAIA_dataloader import AIA_GOESDataModule
|
| 16 |
+
from linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
|
| 17 |
+
|
| 18 |
+
# SXR Prediction Logger
|
| 19 |
+
class SXRPredictionLogger(Callback):
    """Lightning callback that logs per-sample SXR predictions and targets
    to the experiment logger (wandb) at the end of every validation epoch."""

    def __init__(self, val_samples):
        super().__init__()
        # val_samples: list of ((aia, sxr), target) tuples kept for logging.
        self.val_samples = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        # val_samples is a list of ((aia, sxr), target)
        for (aia, sxr), target in self.val_samples:
            aia, sxr, target = aia.to(pl_module.device), sxr.to(pl_module.device), target.to(pl_module.device)
            # NOTE(review): sxr is moved to the device but never passed to
            # the model below — confirm whether pl_module should receive it.
            pred = pl_module(aia.unsqueeze(0)) # Add batch dimension
            trainer.logger.experiment.log({
                "val_pred_sxr": pred.cpu().numpy(),
                "val_target_sxr": target.cpu().numpy()
            })
|
| 33 |
+
|
| 34 |
+
# Compute SXR normalization
|
| 35 |
+
def compute_sxr_norm(sxr_dir):
    """Compute log10-space normalization statistics for the SXR targets.

    Loads every ``*.npy`` file in *sxr_dir*, takes the first scalar of each
    array, and returns ``(mean, std)`` of ``log10(value + 1e-8)`` over all
    files (the epsilon guards against log of zero).

    Raises:
        ValueError: if *sxr_dir* contains no ``.npy`` files.
    """
    log_values = [
        np.log10(np.atleast_1d(np.load(path)).flatten()[0] + 1e-8)
        for path in Path(sxr_dir).glob("*.npy")
    ]
    if not log_values:
        raise ValueError(f"No SXR files found in {sxr_dir}")
    log_values = np.array(log_values)
    return log_values.mean(), log_values.std()
|
| 45 |
+
|
| 46 |
+
# Parser
|
| 47 |
+
# Command-line interface: table-driven argument registration keeps each
# flag's spec on one line and in declaration order.
_CLI_ARGS = [
    ('-checkpoint_dir', dict(type=str, required=True, help='Directory to save checkpoints.')),
    ('-model', dict(type=str, default='config.yaml', help='Path to model config YAML.')),
    ('-aia_dir', dict(type=str, required=True, help='Path to AIA .npy files.')),
    ('-sxr_dir', dict(type=str, required=True, help='Path to SXR .npy files.')),
    ('-sxr_norm', dict(type=str, help='Path to SXR normalization (mean, std).')),
    ('-instrument', dict(type=str, default='AIA_6', help='Instrument (e.g., AIA_6 for 6 wavelengths).')),
]
parser = argparse.ArgumentParser()
for _flag, _options in _CLI_ARGS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()
|
| 55 |
+
|
| 56 |
+
# Load config
|
| 57 |
+
# Load the hyper-parameter search configuration from YAML.
with open(args.model, 'r') as stream:
    config_data = yaml.safe_load(stream)

# Cartesian product over every value list in the 'model' section:
# one tuple per hyper-parameter combination to train.
dic_values = list(config_data['model'].values())
combined_parameters = list(itertools.product(*dic_values))
|
| 62 |
+
|
| 63 |
+
# Paths and normalization
|
| 64 |
+
# Resolve data locations from the CLI arguments.
checkpoint_dir = args.checkpoint_dir
aia_dir = args.aia_dir
sxr_dir = args.sxr_dir
instrument = args.instrument
# SXR normalization: use the precomputed (mean, std) file when given,
# otherwise derive the statistics directly from the SXR data directory.
sxr_norm = np.load(args.sxr_norm) if args.sxr_norm else compute_sxr_norm(sxr_dir)
|
| 72 |
+
|
| 73 |
+
# Transforms
|
| 74 |
+
def _min_max_scale(x):
    """Scale a tensor to [0, 1] per sample; epsilon guards a constant image."""
    return (x - x.min()) / (x.max() - x.min() + 1e-8)

# Training pipeline: per-sample min-max scaling plus light geometric
# augmentation (random flips and small rotations).
train_transforms = transforms.Compose([
    transforms.Lambda(_min_max_scale),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
])
# Validation pipeline: scaling only, no augmentation.
val_transforms = transforms.Compose([
    transforms.Lambda(_min_max_scale),
])
|
| 82 |
+
|
| 83 |
+
# Training loop
|
| 84 |
+
# Hyper-parameter sweep: train, checkpoint, and test one model per
# parameter combination from the config's Cartesian product.
n = 0
for parameter_set in combined_parameters:
    # Re-assemble this combination into a {param_name: value} dict.
    run_config = {key: item for key, item in zip(config_data['model'].keys(), parameter_set)}
    # Seed both frameworks so each run is reproducible.
    torch.manual_seed(run_config['seed'])
    np.random.seed(run_config['seed'])

    # DataModule pairing AIA image stacks with GOES SXR targets.
    data_loader = AIA_GOESDataModule(
        aia_dir=aia_dir,
        sxr_dir=sxr_dir,
        sxr_norm=sxr_norm,
        batch_size=16,
        num_workers=os.cpu_count() // 2,  # NOTE(review): os.cpu_count() can return None — confirm
        train_transforms=train_transforms,
        val_transforms=val_transforms,
        val_split=0.2,
        test_split=0.1
    )
    data_loader.setup()

    # Logger: one wandb run per combination (suffixed with n) when sweeping,
    # a fixed name for a single run.
    wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
    wandb_logger = WandbLogger(
        entity=config_data['wandb']['entity'],
        project=config_data['wandb']['project'],
        job_type=config_data['wandb']['job_type'],
        tags=config_data['wandb']['tags'],
        name=wb_name,
        notes=config_data['wandb']['notes'],
        config=run_config
    )

    # Prediction-logging callback: take ~4 samples spread evenly across the
    # validation set.
    total_n_valid = len(data_loader.valid_ds)
    plot_data = [data_loader.valid_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
    plot_samples = plot_data  # Keep as list of ((aia, sxr), target)
    sxr_callback = SXRPredictionLogger(plot_samples)

    # Checkpoint callback: keep only the best model by validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        monitor='valid_loss',
        mode='min',
        save_top_k=1,
        filename=f"{wb_name}-{{epoch:02d}}-{{valid_loss:.4f}}"
    )

    # Model selection: 6-channel AIA input -> 1 SXR output.
    if run_config['architecture'] == 'linear':
        model = LinearIrradianceModel(
            d_input=6,
            d_output=1,
            eve_norm=sxr_norm,
            lr=run_config.get('lr', 1e-2),
            loss_func=HuberLoss()
        )
    elif run_config['architecture'] == 'hybrid':
        model = HybridIrradianceModel(
            d_input=6,
            d_output=1,
            eve_norm=sxr_norm,
            cnn_model=run_config['cnn_model'],
            ln_model=True,
            cnn_dp=run_config.get('cnn_dp', 0.75),
            lr=run_config.get('lr', 1e-4)
        )
    else:
        raise NotImplementedError(f"Architecture {run_config['architecture']} not supported.")

    # Trainer: single device, GPU when available.
    trainer = Trainer(
        default_root_dir=checkpoint_dir,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        max_epochs=run_config.get('epochs', 10),
        callbacks=[sxr_callback, checkpoint_callback],
        logger=wandb_logger,
        log_every_n_steps=10
    )

    # Train this combination.
    trainer.fit(model, data_loader)

    # Save a checkpoint with the run configuration bundled in.
    # NOTE(review): this aliases (and then mutates) run_config rather than
    # copying it — wb_name already embeds n, so the filename repeats it.
    save_dictionary = run_config
    save_dictionary['model'] = model
    save_dictionary['instrument'] = instrument
    full_checkpoint_path = os.path.join(checkpoint_dir, f"{wb_name}_{n}.ckpt")
    torch.save(save_dictionary, full_checkpoint_path)

    # Evaluate on the held-out test split.
    trainer.test(model, dataloaders=data_loader.test_dataloader())

    # Close the wandb run before starting the next combination.
    wandb.finish()
    n += 1
|
flaring/__init__.py
ADDED
|
File without changes
|
flaring/cut_off_aia.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Directory of ML-ready AIA cutouts (one .npy stack of 6 channels per file).
AIA_DATA_DIR = "/mnt/data/ML-Ready-Data-No-Intensity-Cut/AIA-Data"
aia = os.listdir(AIA_DATA_DIR)

# Dates of the flaring events to include.
# BUG FIX: the original list was missing the comma between "2023-07-18" and
# "2023-07-20", so Python silently concatenated them into the single
# never-matching string "2023-07-182023-07-20", dropping both dates.
target_dates = [
    "2023-07-11", "2023-07-15", "2023-07-16", "2023-07-18", "2023-07-20",
    "2023-07-26", "2023-07-30", "2023-08-01", "2023-08-02", "2023-08-07",
]

# One list of flattened pixel arrays per AIA channel (0-5).
aia_dict = {channel: [] for channel in range(6)}

count = 0
for i, file in enumerate(aia):
    # Filenames begin with an ISO timestamp, e.g. "2023-07-11T...".
    if file.split("T")[0] in target_dates:
        aia_data = np.load(os.path.join(AIA_DATA_DIR, file))
        # Collect every channel's flattened pixels for percentile statistics.
        for channel in range(6):
            aia_dict[channel].append(aia_data[channel].flatten())
        count = count + 1
        print("Flares: " + str(count) + "\n")
    print(f"\nProcessed {i+1}/{len(aia)} files", end='\r')
|
| 31 |
+
|
| 32 |
+
def percentile(data, perc):
    """Return the *perc*-th percentile of *data* (thin numpy wrapper)."""
    value = np.percentile(data, perc)
    return value
|
| 34 |
+
|
| 35 |
+
# 95th and 99.5th percentile per AIA channel — used downstream as
# intensity cutoffs.
percentile_dict = {
    channel: [percentile(aia_dict[channel], 95), percentile(aia_dict[channel], 99.5)]
    for channel in range(6)
}

print(percentile_dict)
# Reference output from a previous run:
# {0: [np.float32(5.0747647), np.float32(16.560747)], 1: [np.float32(24.491392), np.float32(75.84181)], 2: [np.float32(607.3201), np.float32(1536.1443)], 3: [np.float32(1021.83466), np.float32(2288.1)], 4: [np.float32(480.13672), np.float32(1163.9178)], 5: [np.float32(144.44502), np.float32(401.82352)]}
|