griffingoodwin04 committed on
Commit
3720287
·
1 Parent(s): 5493436

added MEGS_AI baseline

Browse files
flaring/MEGS_AI_baseline/SDOAIA_dataloader.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader, Subset
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from scipy.ndimage import zoom
6
+ import torchvision.transforms as T
7
+ from pytorch_lightning import LightningDataModule
8
+ import glob
9
+ import os
10
+
11
class AIA_GOESDataset(torch.utils.data.Dataset):
    """Paired AIA-image / GOES-SXR dataset for irradiance regression.

    Pairs are matched by filename stem (timestamp): for each
    ``<timestamp>.npy`` cube in ``aia_dir`` there must be a matching
    ``<timestamp>.npy`` scalar in ``sxr_dir``.  Each item is returned as
    ``((image_HWC, sxr), sxr)`` where the SXR value doubles as auxiliary
    input and regression target.
    """

    def __init__(self, aia_dir, sxr_dir, transform=None, sxr_transform=None, target_size=(512, 512)):
        self.aia_dir = Path(aia_dir).resolve()
        self.sxr_dir = Path(sxr_dir).resolve()
        self.transform = transform
        self.sxr_transform = sxr_transform
        self.target_size = target_size
        self.samples = []

        # Fail fast on bad paths rather than on the first __getitem__.
        if not self.aia_dir.is_dir():
            raise FileNotFoundError(f"AIA directory not found: {self.aia_dir}")
        if not self.sxr_dir.is_dir():
            raise FileNotFoundError(f"SXR directory not found: {self.sxr_dir}")

        # Keep only timestamps for which both modalities exist.
        for name in sorted(glob.glob(str(self.aia_dir / "*.npy"))):
            stamp = Path(name).stem
            if (self.sxr_dir / f"{stamp}.npy").exists():
                self.samples.append(stamp)

        if len(self.samples) == 0:
            raise ValueError("No valid sample pairs found")

    def __len__(self):
        """Number of matched AIA/SXR pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one pair; returns ``((image, sxr_tensor), target_tensor)``."""
        stamp = self.samples[idx]

        # AIA cube is expected channel-first: (6, H, W).
        cube = np.load(self.aia_dir / f"{stamp}.npy")
        if cube.shape[0] != 6:
            raise ValueError(f"AIA image has {cube.shape[0]} channels, expected 6")

        # Spline-resample spatial dims to target_size when they differ.
        height, width = cube.shape[1], cube.shape[2]
        if (height, width) != self.target_size:
            cube = zoom(cube, (1,
                               self.target_size[0] / height,
                               self.target_size[1] / width))

        image = torch.tensor(cube, dtype=torch.float32)  # (6, H, W)

        # Transforms operate on the channel-first tensor.
        if self.transform:
            image = self.transform(image)

        # Model consumes channel-last input: (H, W, 6).
        image = image.permute(1, 2, 0)

        # SXR file holds a single scalar value.
        raw = np.load(self.sxr_dir / f"{stamp}.npy")
        if raw.size != 1:
            raise ValueError(f"SXR value has size {raw.size}, expected scalar")
        value = float(np.atleast_1d(raw).flatten()[0])
        if self.sxr_transform:
            value = self.sxr_transform(value)

        sxr_tensor = torch.tensor(value, dtype=torch.float32)
        target = torch.tensor(value, dtype=torch.float32)
        return (image, sxr_tensor), target
78
+
79
class AIA_GOESDataModule(LightningDataModule):
    """PyTorch Lightning DataModule for AIA and SXR data.

    Splits the matched AIA/SXR pairs into train/val/test subsets and serves
    them with the appropriate transforms.  The SXR target is normalized as
    ``(log10(x + 1e-8) - sxr_norm[0]) / sxr_norm[1]``.

    Args:
        aia_dir / sxr_dir: directories of paired ``.npy`` files.
        sxr_norm: (mean, std) of the log10 SXR distribution.
        batch_size, num_workers: DataLoader settings.
        train_transforms / val_transforms: image transforms per split.
        val_split / test_split: fractions of the dataset held out.
        seed: RNG seed for the split permutation (see note in ``setup``).
    """

    def __init__(self, aia_dir, sxr_dir, sxr_norm, batch_size=16, num_workers=4,
                 train_transforms=None, val_transforms=None, val_split=0.2, test_split=0.1,
                 seed=42):
        super().__init__()
        self.aia_dir = aia_dir
        self.sxr_dir = sxr_dir
        self.sxr_norm = sxr_norm
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed

    def _sxr_transform(self):
        # log10 + z-score normalization of the scalar SXR value.
        return T.Lambda(lambda x: (np.log10(x + 1e-8) - self.sxr_norm[0]) / self.sxr_norm[1])

    def _make_dataset(self, transform):
        # Single construction point instead of four near-identical copies.
        return AIA_GOESDataset(
            aia_dir=self.aia_dir,
            sxr_dir=self.sxr_dir,
            transform=transform,
            sxr_transform=self._sxr_transform(),
            target_size=(512, 512)
        )

    def setup(self, stage=None):
        # Base instance used only to count available sample pairs.
        total_size = len(self._make_dataset(None))
        test_size = int(self.test_split * total_size)
        val_size = int(self.val_split * total_size)
        train_size = total_size - val_size - test_size

        # BUGFIX: the permutation must be seeded.  Lightning may call setup()
        # once per stage (fit/test); an unseeded shuffle produced different
        # splits on each call, leaking training samples into validation/test.
        indices = np.random.default_rng(self.seed).permutation(total_size)
        train_idx = indices[:train_size]
        val_idx = indices[train_size:train_size + val_size]
        test_idx = indices[train_size + val_size:]

        # Separate dataset instances so each split gets its own transform.
        self.train_ds = Subset(self._make_dataset(self.train_transforms), train_idx)
        self.valid_ds = Subset(self._make_dataset(self.val_transforms), val_idx)
        self.test_ds = Subset(self._make_dataset(self.val_transforms), test_idx)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_ds, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)
flaring/MEGS_AI_baseline/__init__.py ADDED
File without changes
flaring/MEGS_AI_baseline/base_model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from pytorch_lightning import LightningModule
5
+
6
class BaseModel(LightningModule):
    """Lightning wrapper shared by all irradiance models.

    Wraps an arbitrary ``model``, optimizes with Adam, and computes the
    loss in physical units: ``eve_norm`` is a (mean, std) pair used to undo
    the normalization applied by the dataloader before the loss is taken.
    """

    def __init__(self, model, eve_norm, loss_func, lr):
        super().__init__()
        self.model = model
        self.eve_norm = eve_norm  # Used for SXR normalization (mean, std)
        self.loss_func = loss_func
        self.lr = lr

    def forward(self, x, sxr=None):
        # sxr is accepted for API symmetry; the base model ignores it.
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _denormalize(self, t):
        # Invert the (value - mean) / std normalization.
        return t * self.eve_norm[1] + self.eve_norm[0]

    def _shared_step(self, batch, log_name):
        # Single implementation for train/valid/test instead of three
        # verbatim copies that could drift apart.
        (x, sxr), target = batch
        pred = self._denormalize(self(x, sxr))
        target = self._denormalize(target)
        loss = self.loss_func(pred, target)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'valid_loss')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test_loss')
flaring/MEGS_AI_baseline/chopped_alexnet.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from torch import nn
3
+ from torch.nn import HuberLoss
4
+ from irradiance.models.base_model import BaseModel
5
+
6
+
7
class ChoppedAlexnet(BaseModel):
    """Truncated AlexNet-style CNN regressor.

    Keeps the first ``numLayers`` convolutional stages, global-average-pools
    the feature maps, and maps them to ``d_output`` values with a
    dropout + linear head.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), numLayers=3, dropout=0, lr=1e-4):
        # BUGFIX: the original assigned nn.Module attributes (self.loss_func,
        # self.features, self.pool) BEFORE super().__init__(), which raises
        # "cannot assign module before Module.__init__() call" in PyTorch.
        # Build everything locally first, init the base class, then attach.
        # It also hard-coded HuberLoss(), silently discarding the loss_func
        # argument; the passed loss_func is now honored.
        layers, channelSize = self.getLayers(numLayers, d_input)
        head = nn.Sequential(nn.Dropout(p=dropout),
                             nn.Linear(channelSize, d_output))
        super().__init__(model=head, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

        self.numLayers = numLayers
        self.n_channels = d_input
        self.outSize = d_output
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Kaiming init for convs, identity-style init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def getLayers(self, numLayers, n_channels):
        """Return (layers, out_channels) for the first numLayers conv stages."""
        layers = [nn.Conv2d(n_channels, 64, kernel_size=11, stride=4, padding=2), nn.BatchNorm2d(64), nn.ReLU(inplace=True), ]
        if numLayers == 1:
            return (layers, 64)
        layers += [nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.BatchNorm2d(192), nn.ReLU(inplace=True), ]
        if numLayers == 2:
            return (layers, 192)
        layers += [nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384), nn.ReLU(inplace=True)]
        if numLayers == 3:
            return (layers, 384)
        # Any other value falls through to the deepest (4-stage) variant.
        layers += [nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.BatchNorm2d(256)]
        return (layers, 256)

    def forward(self, x):
        """x: (B, C, H, W) -> (B, d_output)."""
        x = self.features(x)
        x = self.pool(x).view(x.size(0), -1)  # global average pool -> (B, C)
        x = self.model(x)
        return x
flaring/MEGS_AI_baseline/efficientnet.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ from torch import nn
3
+ from torch.nn import HuberLoss
4
+ from irradiance.models.base_model import BaseModel
5
+
6
+
7
class EfficientnetIrradiance(BaseModel):
    """EfficientNet backbone adapted for irradiance regression.

    Loads a pretrained torchvision EfficientNet, swaps the stem conv to
    accept ``d_input`` channels, and replaces the classifier with a
    dropout + linear regression head of size ``d_output``.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), model='efficientnet_b0', dp=0.75, lr=1e-4):
        # BUGFIX: the original elif chain silently fell through on an unknown
        # model name, leaving `model` as a str and crashing later with a
        # confusing AttributeError.  Validate and dispatch via getattr instead.
        supported = {f'efficientnet_b{i}' for i in range(8)}
        if model not in supported:
            raise ValueError(f"Unsupported model name: {model!r}; expected one of {sorted(supported)}")
        net = getattr(torchvision.models, model)(pretrained=True)

        # Replace the stem conv so the network accepts d_input channels.
        conv1_out = net.features[0][0].out_channels
        net.features[0][0] = nn.Conv2d(d_input, conv1_out, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Swap the classifier for a dropout + linear regression head.
        lin_in = net.classifier[1].in_features
        # consider adding average pool of full image(s)
        net.classifier = nn.Sequential(nn.Dropout(p=dp, inplace=True),
                                       nn.Linear(in_features=lin_in, out_features=d_output, bias=True))

        # Apply the same dropout probability to every dropout in the backbone.
        # TODO: other dropout values?
        for m in net.modules():
            if m.__class__.__name__.startswith('Dropout'):
                m.p = dp

        super().__init__(model=net, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

    def forward(self, x):
        """x: (B, d_input, H, W) -> (B, d_output)."""
        return self.model(x)
flaring/MEGS_AI_baseline/kan_success.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Li, Ziyao
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from typing import *
20
+ from torch.nn import HuberLoss
21
+ from irradiance.models.base_model import BaseModel
22
+
23
+
24
class SplineLinear(nn.Linear):
    """Bias-free linear layer with near-zero truncated-normal init.

    Mixes RBF spline basis activations; ``init_scale`` is the std of the
    weight initialization.
    """

    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        # init_scale must be stashed first: nn.Linear.__init__ invokes
        # reset_parameters(), which reads it.
        self.init_scale = init_scale
        super().__init__(in_features, out_features, bias=False, **kw)

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)
31
+
32
+
33
class RadialBasisFunction(nn.Module):
    """Gaussian radial-basis expansion over a fixed 1-D grid.

    Each scalar input x is mapped to ``num_grids`` activations
    ``exp(-((x - g_i) / h) ** 2)``, where g_i are evenly spaced centres in
    [grid_min, grid_max] and h is the shared bandwidth (``denominator``).
    """

    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        denominator: float = None,  # larger denominators lead to smoother basis
    ):
        super().__init__()
        self.grid_min = grid_min
        self.grid_max = grid_max
        self.num_grids = num_grids
        centres = torch.linspace(grid_min, grid_max, num_grids)
        # Frozen Parameter: follows .to(device)/.state_dict() but is never trained.
        self.grid = torch.nn.Parameter(centres, requires_grad=False)
        # Default bandwidth equals the spacing between adjacent centres.
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        # Broadcasts: output shape is x.shape + (num_grids,).
        scaled = (x[..., None] - self.grid) / self.denominator
        return torch.exp(-scaled ** 2)
51
+
52
class FastKANLayer(nn.Module):
    """One FastKAN layer: RBF spline features mixed linearly, plus an
    optional SiLU "base" residual branch and optional input layer-norm."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        use_layernorm: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        if use_layernorm:
            assert input_dim > 1, "Do not use layernorms on 1D inputs. Set `use_layernorm=False`."
            self.layernorm = nn.LayerNorm(input_dim)
        else:
            self.layernorm = None
        self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_linear = nn.Linear(input_dim, output_dim)

    def forward(self, x, use_layernorm=True):
        # Normalize the input into the RBF grid range when enabled.
        normed = self.layernorm(x) if (self.layernorm is not None and use_layernorm) else x
        basis = self.rbf(normed)
        # Flatten the trailing (input_dim, num_grids) axes for the linear mix.
        out = self.spline_linear(basis.view(*basis.shape[:-2], -1))
        if self.use_base_update:
            # Residual branch operates on the raw (un-normalized) input.
            out = out + self.base_linear(self.base_activation(x))
        return out

    def plot_curve(
        self,
        input_index: int,
        output_index: int,
        num_pts: int = 1000,
        num_extrapolate_bins: int = 2
    ):
        '''Return (x, y) samples of one learned curve in this FastKANLayer.
        input_index: the selected index of the input, in [0, input_dim) .
        output_index: the selected index of the output, in [0, output_dim) .
        num_pts: num of points sampled for the curve.
        num_extrapolate_bins (N_e): num of bins extrapolating from the given grids. The curve
        will be calculate in the range of [grid_min - h * N_e, grid_max + h * N_e].
        '''
        ng = self.rbf.num_grids
        h = self.rbf.denominator
        assert input_index < self.input_dim
        assert output_index < self.output_dim
        # Slice out the num_grids spline weights tied to this (input, output) pair.
        w = self.spline_linear.weight[output_index, input_index * ng:(input_index + 1) * ng]
        xs = torch.linspace(
            self.rbf.grid_min - num_extrapolate_bins * h,
            self.rbf.grid_max + num_extrapolate_bins * h,
            num_pts
        )
        with torch.no_grad():
            ys = (w * self.rbf(xs.to(w.dtype))).sum(-1)
        return xs, ys
+ return x, y
119
+
120
+
121
class FastKANIrradiance(BaseModel):
    """FastKAN regressor operating on per-channel image statistics.

    Reduces each input channel to its spatial mean (and optionally std),
    then passes the resulting feature vector through a stack of
    FastKANLayers defined by ``layers_hidden``.
    """

    def __init__(
        self,
        eve_norm,
        layers_hidden: List[int],
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
        loss_func = HuberLoss(),
        lr=1e-4,
        use_std=False
    ) -> None:
        super().__init__(model=None, eve_norm=eve_norm, loss_func=loss_func, lr=lr)
        self.use_std = use_std
        # BUGFIX: copy before widening.  The original mutated the caller's
        # layers_hidden list in place (layers_hidden[0] *= 2), silently
        # corrupting shared configs and double-widening on re-instantiation.
        dims = list(layers_hidden)
        if use_std:
            dims[0] = dims[0] * 2
        self.layers = nn.ModuleList([
            FastKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(dims[:-1], dims[1:])
        ])

    def forward(self, x):
        """x: (batch, channels, H, W) image stack -> (batch, dims[-1])."""
        # Per-channel spatial statistics feed the 1-D KAN stack.
        # (Also fixes the odd `torch.torch.mean` spelling and skips the std
        # computation entirely when use_std is False.)
        features = torch.mean(x, dim=(2, 3))
        if self.use_std:
            features = torch.cat((features, torch.std(x, dim=(2, 3))), dim=1)
        for layer in self.layers:
            features = layer(features)
        return features
163
+
164
+
165
class AttentionWithFastKANTransform(nn.Module):
    """Multi-head attention whose q/k/v/output (and gate) projections are
    FastKANLayers instead of plain linear maps.

    Shape legend in the comments below: ``*`` = leading batch dims,
    ``q``/``k`` = query/key lengths, ``h`` = num_heads, ``c`` = head_dim.
    """

    def __init__(
        self,
        q_dim: int,
        k_dim: int,
        v_dim: int,
        head_dim: int,
        num_heads: int,
        gating: bool = True,
    ):
        super(AttentionWithFastKANTransform, self).__init__()

        self.num_heads = num_heads
        total_dim = head_dim * self.num_heads
        self.gating = gating
        # KAN-based projections replace the usual nn.Linear q/k/v/o maps.
        self.linear_q = FastKANLayer(q_dim, total_dim)
        self.linear_k = FastKANLayer(k_dim, total_dim)
        self.linear_v = FastKANLayer(v_dim, total_dim)
        self.linear_o = FastKANLayer(total_dim, q_dim)
        self.linear_g = None
        if self.gating:
            self.linear_g = FastKANLayer(q_dim, total_dim)
        # precompute the 1/sqrt(head_dim)
        self.norm = head_dim**-0.5

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bias: torch.Tensor = None,      # additive attention bias
    ) -> torch.Tensor:
        """Scaled dot-product attention over KAN projections.

        Returns a tensor with the query's leading shape and q_dim features.
        """
        # Queries scaled by 1/sqrt(head_dim); unsqueeze so q and k broadcast.
        wq = self.linear_q(q).view(*q.shape[:-1], 1, self.num_heads, -1) * self.norm     # *q1hc
        wk = self.linear_k(k).view(*k.shape[:-2], 1, k.shape[-2], self.num_heads, -1)    # *1khc
        # Dot product via broadcast-multiply + sum; softmax over the key axis.
        att = (wq * wk).sum(-1).softmax(-2)     # *qkh
        del wq, wk
        # NOTE(review): the bias is added AFTER the softmax, so `att` is no
        # longer a normalized distribution; additive attention biases are
        # conventionally applied to the logits before softmax — confirm intent.
        if bias is not None:
            att = att + bias[..., None]

        wv = self.linear_v(v).view(*v.shape[:-2],1, v.shape[-2], self.num_heads, -1)     # *1khc
        # Weighted sum of values over the key axis.
        o = (att[..., None] * wv).sum(-3)       # *qhc
        del att, wv

        o = o.view(*o.shape[:-2], -1)       # *q(hc)

        if self.linear_g is not None:
            # gating, use raw query input
            g = self.linear_g(q)
            o = torch.sigmoid(g) * o

        # merge heads
        o = self.linear_o(o)
        return o
flaring/MEGS_AI_baseline/linear_and_hybrid.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import HuberLoss
4
+ from models.base_model import BaseModel
5
+
6
class LinearIrradianceModel(BaseModel):
    """Linear regression from per-channel image statistics.

    Each of the ``d_input`` channels contributes its spatial mean and std,
    giving ``2 * d_input`` features that a single nn.Linear maps to
    ``d_output`` values.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), lr=1e-2):
        self.n_channels = d_input
        self.outSize = d_output
        model = nn.Linear(2 * self.n_channels, self.outSize)
        super().__init__(model=model, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

    def forward(self, x, sxr=None, **kwargs):
        """x: (batch, H, W, C) channel-last images -> (batch, d_output).

        Raises ValueError if the input is not 4-D or the channel count is wrong.
        """
        # BUGFIX: removed the three debug print() calls that fired on every
        # forward pass (console spam and a per-batch slowdown in training).

        # Dataloaders may deliver (aia_img, sxr_val); keep only the image.
        if isinstance(x, (list, tuple)):
            x = x[0]

        if len(x.shape) != 4:
            raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
        if x.shape[-1] != self.n_channels:
            raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")

        # (B, H, W, C) -> (B, C, H, W) so stats reduce over the spatial axes.
        x = x.permute(0, 3, 1, 2)
        mean_irradiance = torch.mean(x, dim=(2, 3))  # (B, C)
        std_irradiance = torch.std(x, dim=(2, 3))    # (B, C)

        input_features = torch.cat((mean_irradiance, std_irradiance), dim=1)  # (B, 2C)
        if input_features.shape[1] != 2 * self.n_channels:
            raise ValueError(f"Expected {2 * self.n_channels} features, got {input_features.shape[1]}")

        return self.model(input_features)
45
+
46
class HybridIrradianceModel(BaseModel):
    """Linear + CNN ensemble: output = linear(x) + cnn_lambda * cnn(x).

    Either branch may be disabled; ``set_train_mode`` switches which branch
    is trained and how strongly the CNN contributes.
    """

    def __init__(self, d_input, d_output, eve_norm, cnn_model='resnet', ln_model=True, ln_params=None, lr=1e-4, cnn_dp=0.75, loss_func=HuberLoss()):
        super().__init__(model=None, eve_norm=eve_norm, loss_func=loss_func, lr=lr)
        self.n_channels = d_input
        self.outSize = d_output
        self.ln_params = ln_params
        self.ln_model = None
        if ln_model:
            self.ln_model = LinearIrradianceModel(d_input, d_output, eve_norm, loss_func=loss_func, lr=lr)
        # Optionally warm-start the linear head from pre-fit parameters.
        if self.ln_params is not None and self.ln_model is not None:
            self.ln_model.model.weight = nn.Parameter(self.ln_params['weight'])
            self.ln_model.model.bias = nn.Parameter(self.ln_params['bias'])
        self.cnn_model = None
        self.cnn_lambda = 1.
        if cnn_model == 'resnet':
            self.cnn_model = nn.Sequential(
                nn.Conv2d(d_input, 64, kernel_size=7, stride=2, padding=3),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                nn.Linear(64, d_output),
                nn.Dropout(cnn_dp)
            )
        # BUGFIX: guard the isinstance before .startswith — passing
        # cnn_model=None (a supported way to run linear-only) crashed here
        # with AttributeError: 'NoneType' object has no attribute 'startswith'.
        elif isinstance(cnn_model, str) and cnn_model.startswith('efficientnet'):
            raise NotImplementedError("EfficientNet requires timm; replace with custom CNN or install timm")
        if self.ln_model is None and self.cnn_model is None:
            raise ValueError('Please pass at least one model.')

    def forward(self, x, sxr=None, **kwargs):
        """x: (batch, H, W, C) channel-last images -> (batch, d_output)."""
        # BUGFIX: removed the per-batch debug print() call.

        # Dataloaders may deliver (aia_img, sxr_val); keep only the image.
        if isinstance(x, (list, tuple)):
            x = x[0]

        if len(x.shape) != 4:
            raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
        if x.shape[-1] != self.n_channels:
            raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")

        # CNN wants channel-first; the linear branch keeps the original layout.
        x_cnn = x.permute(0, 3, 1, 2)

        if self.ln_model is not None and self.cnn_model is not None:
            return self.ln_model(x) + self.cnn_lambda * self.cnn_model(x_cnn)
        elif self.ln_model is not None:
            return self.ln_model(x)
        elif self.cnn_model is not None:
            return self.cnn_model(x_cnn)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def set_train_mode(self, mode):
        """Select which branch trains: 'linear', 'cnn', or 'both'.

        Also sets cnn_lambda (0 to mute the CNN, 0.01 as a small correction).
        """
        if mode == 'linear':
            self.cnn_lambda = 0
            if self.cnn_model: self.cnn_model.eval()
            if self.ln_model: self.ln_model.train()
        elif mode == 'cnn':
            self.cnn_lambda = 0.01
            if self.cnn_model: self.cnn_model.train()
            if self.ln_model: self.ln_model.eval()
        elif mode == 'both':
            self.cnn_lambda = 0.01
            if self.cnn_model: self.cnn_model.train()
            if self.ln_model: self.ln_model.train()
        else:
            raise NotImplementedError(f'Mode not supported: {mode}')
flaring/MEGS_AI_baseline/models/base_model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from pytorch_lightning import LightningModule
5
+
6
class BaseModel(LightningModule):
    """Lightning wrapper shared by all irradiance models.

    Wraps an arbitrary ``model``, optimizes with Adam, and computes the
    loss in physical units: ``eve_norm`` is a (mean, std) pair used to undo
    the normalization applied by the dataloader before the loss is taken.
    """

    def __init__(self, model, eve_norm, loss_func, lr):
        super().__init__()
        self.model = model
        self.eve_norm = eve_norm  # Used for SXR normalization (mean, std)
        self.loss_func = loss_func
        self.lr = lr

    def forward(self, x, sxr=None):
        # sxr is accepted for API symmetry; the base model ignores it.
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _denormalize(self, t):
        # Invert the (value - mean) / std normalization.
        return t * self.eve_norm[1] + self.eve_norm[0]

    def _shared_step(self, batch, log_name):
        # Single implementation for train/valid/test instead of three
        # verbatim copies that could drift apart.
        (x, sxr), target = batch
        pred = self._denormalize(self(x, sxr))
        target = self._denormalize(target)
        loss = self.loss_func(pred, target)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'valid_loss')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test_loss')
flaring/MEGS_AI_baseline/models/chopped_alexnet.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from torch import nn
3
+ from torch.nn import HuberLoss
4
+ from models.base_model import BaseModel
5
+
6
+
7
class ChoppedAlexnet(BaseModel):
    """Truncated AlexNet-style CNN regressor.

    Keeps the first ``numLayers`` convolutional stages, global-average-pools
    the feature maps, and maps them to ``d_output`` values with a
    dropout + linear head.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), numLayers=3, dropout=0, lr=1e-4):
        # BUGFIX: the original assigned nn.Module attributes (self.loss_func,
        # self.features, self.pool) BEFORE super().__init__(), which raises
        # "cannot assign module before Module.__init__() call" in PyTorch.
        # Build everything locally first, init the base class, then attach.
        # It also hard-coded HuberLoss(), silently discarding the loss_func
        # argument; the passed loss_func is now honored.
        layers, channelSize = self.getLayers(numLayers, d_input)
        head = nn.Sequential(nn.Dropout(p=dropout),
                             nn.Linear(channelSize, d_output))
        super().__init__(model=head, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

        self.numLayers = numLayers
        self.n_channels = d_input
        self.outSize = d_output
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Kaiming init for convs, identity-style init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def getLayers(self, numLayers, n_channels):
        """Return (layers, out_channels) for the first numLayers conv stages."""
        layers = [nn.Conv2d(n_channels, 64, kernel_size=11, stride=4, padding=2), nn.BatchNorm2d(64), nn.ReLU(inplace=True), ]
        if numLayers == 1:
            return (layers, 64)
        layers += [nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.BatchNorm2d(192), nn.ReLU(inplace=True), ]
        if numLayers == 2:
            return (layers, 192)
        layers += [nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384), nn.ReLU(inplace=True)]
        if numLayers == 3:
            return (layers, 384)
        # Any other value falls through to the deepest (4-stage) variant.
        layers += [nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.BatchNorm2d(256)]
        return (layers, 256)

    def forward(self, x):
        """x: (B, C, H, W) -> (B, d_output)."""
        x = self.features(x)
        x = self.pool(x).view(x.size(0), -1)  # global average pool -> (B, C)
        x = self.model(x)
        return x
flaring/MEGS_AI_baseline/models/efficientnet.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ from torch import nn
3
+ from torch.nn import HuberLoss
4
+ from models.base_model import BaseModel
5
+
6
+
7
class EfficientnetIrradiance(BaseModel):
    """EfficientNet backbone adapted for irradiance regression.

    Loads a pretrained torchvision EfficientNet, swaps the stem conv to
    accept ``d_input`` channels, and replaces the classifier with a
    dropout + linear regression head of size ``d_output``.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), model='efficientnet_b0', dp=0.75, lr=1e-4):
        # BUGFIX: the original elif chain silently fell through on an unknown
        # model name, leaving `model` as a str and crashing later with a
        # confusing AttributeError.  Validate and dispatch via getattr instead.
        supported = {f'efficientnet_b{i}' for i in range(8)}
        if model not in supported:
            raise ValueError(f"Unsupported model name: {model!r}; expected one of {sorted(supported)}")
        net = getattr(torchvision.models, model)(pretrained=True)

        # Replace the stem conv so the network accepts d_input channels.
        conv1_out = net.features[0][0].out_channels
        net.features[0][0] = nn.Conv2d(d_input, conv1_out, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Swap the classifier for a dropout + linear regression head.
        lin_in = net.classifier[1].in_features
        # consider adding average pool of full image(s)
        net.classifier = nn.Sequential(nn.Dropout(p=dp, inplace=True),
                                       nn.Linear(in_features=lin_in, out_features=d_output, bias=True))

        # Apply the same dropout probability to every dropout in the backbone.
        # TODO: other dropout values?
        for m in net.modules():
            if m.__class__.__name__.startswith('Dropout'):
                m.p = dp

        super().__init__(model=net, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

    def forward(self, x):
        """x: (B, d_input, H, W) -> (B, d_output)."""
        return self.model(x)
+ return x
flaring/MEGS_AI_baseline/models/kan_success.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Li, Ziyao
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from typing import *
20
+ from torch.nn import HuberLoss
21
+ from irradiance.models.base_model import BaseModel
22
+
23
+
24
class SplineLinear(nn.Linear):
    """Bias-free linear layer with truncated-normal weight initialisation.

    ``init_scale`` is the std of the truncated normal used by
    ``reset_parameters`` (which ``nn.Linear.__init__`` invokes).
    """

    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        # Must be assigned before super().__init__, which calls reset_parameters().
        self.init_scale = init_scale
        super().__init__(in_features, out_features, bias=False, **kw)

    def reset_parameters(self) -> None:
        """Re-initialise weights from a truncated normal with std ``init_scale``."""
        std = self.init_scale
        nn.init.trunc_normal_(self.weight, mean=0, std=std)
31
+
32
+
33
class RadialBasisFunction(nn.Module):
    """Gaussian radial-basis features over a fixed 1-D grid.

    Maps input ``x`` of shape ``(...)`` to ``(..., num_grids)`` where feature
    ``i`` is ``exp(-((x - g_i) / denominator) ** 2)`` for grid point ``g_i``.

    Args:
        grid_min: left end of the grid.
        grid_max: right end of the grid.
        num_grids: number of grid points (basis functions).
        denominator: bandwidth; larger denominators lead to smoother basis.
            Defaults to the grid spacing. (Annotation fixed: the original
            declared ``float = None``.)
    """

    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        denominator: Optional[float] = None,
    ):
        super().__init__()
        self.grid_min = grid_min
        self.grid_max = grid_max
        self.num_grids = num_grids
        grid = torch.linspace(grid_min, grid_max, num_grids)
        # Fixed (non-trainable) grid; kept as a Parameter so it follows
        # .to()/.cuda() moves and lands in the state dict as before.
        self.grid = torch.nn.Parameter(grid, requires_grad=False)
        # Default bandwidth: one grid spacing.
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        """Return Gaussian RBF activations, shape ``x.shape + (num_grids,)``."""
        return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)
51
+
52
class FastKANLayer(nn.Module):
    """One FastKAN layer: RBF spline features plus an optional "base" path.

    The input is (optionally layer-normalised and) expanded into Gaussian
    radial-basis features, which a bias-free linear map (``SplineLinear``)
    projects to ``output_dim``. When ``use_base_update`` is set, a standard
    ``Linear(base_activation(x))`` term is added to the result.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        use_layernorm: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layernorm = None
        if use_layernorm:
            # LayerNorm keeps inputs roughly inside the RBF grid range.
            assert input_dim > 1, "Do not use layernorms on 1D inputs. Set `use_layernorm=False`."
            self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        # Each input dimension contributes num_grids RBF features.
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_linear = nn.Linear(input_dim, output_dim)

    def forward(self, x, use_layernorm=True):
        """Map ``(..., input_dim)`` to ``(..., output_dim)``.

        ``use_layernorm=False`` bypasses the layernorm even if one was built.
        """
        if self.layernorm is not None and use_layernorm:
            spline_basis = self.rbf(self.layernorm(x))
        else:
            spline_basis = self.rbf(x)
        # Flatten the trailing (input_dim, num_grids) axes for the linear map.
        ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1))
        if self.use_base_update:
            base = self.base_linear(self.base_activation(x))
            ret = ret + base
        return ret

    def plot_curve(
        self,
        input_index: int,
        output_index: int,
        num_pts: int = 1000,
        num_extrapolate_bins: int = 2
    ):
        '''Return the learned 1-D curve for one (input, output) pair.

        input_index: the selected index of the input, in [0, input_dim).
        output_index: the selected index of the output, in [0, output_dim).
        num_pts: number of points sampled for the curve.
        num_extrapolate_bins (N_e): number of bins extrapolating from the given
            grids. The curve is evaluated over
            [grid_min - h * N_e, grid_max + h * N_e].
        '''
        ng = self.rbf.num_grids
        h = self.rbf.denominator
        assert input_index < self.input_dim
        assert output_index < self.output_dim
        # Slice out the spline weights belonging to this (input, output) pair.
        w = self.spline_linear.weight[
            output_index, input_index * ng : (input_index + 1) * ng
        ]   # shape: (num_grids,)
        x = torch.linspace(
            self.rbf.grid_min - num_extrapolate_bins * h,
            self.rbf.grid_max + num_extrapolate_bins * h,
            num_pts
        )   # shape: (num_pts,)
        with torch.no_grad():
            # rbf(x) is (num_pts, num_grids); weighted sum gives the curve.
            y = (w * self.rbf(x.to(w.dtype))).sum(-1)
        return x, y
119
+
120
+
121
class FastKANIrradiance(BaseModel):
    """FastKAN regressor over per-channel image statistics.

    The forward pass reduces each image to per-channel spatial means (and,
    when ``use_std`` is set, stds), then passes the resulting 1-D feature
    vector through a stack of ``FastKANLayer``s.

    Args:
        eve_norm: normalization passed through to ``BaseModel``.
        layers_hidden: layer widths, e.g. ``[n_channels, 64, n_outputs]``;
            when ``use_std`` is True the effective first width is doubled.
        grid_min, grid_max, num_grids: RBF grid configuration per layer.
        use_base_update: add the Linear(activation(x)) base term per layer.
        base_activation: activation of the base term.
        spline_weight_init_scale: init std of the spline weights.
        loss_func: training loss; defaults to Huber.
        lr: learning rate.
        use_std: also feed per-channel spatial std as input features.
    """

    def __init__(
        self,
        eve_norm,
        layers_hidden: List[int],
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
        loss_func = HuberLoss(),
        lr=1e-4,
        use_std=False
    ) -> None:
        super().__init__(model=None, eve_norm=eve_norm, loss_func=loss_func, lr=lr)
        self.use_std = use_std
        # Copy before widening: the original mutated the caller's list in place.
        layers_hidden = list(layers_hidden)
        if use_std:
            layers_hidden[0] = layers_hidden[0] * 2
        self.layers = nn.ModuleList([
            FastKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        """Reduce image batches to channel statistics and run the KAN stack.

        Assumes x is (batch, channels, H, W) — reductions are over dims 2, 3.
        """
        # `torch.torch.mean` in the original only worked by accident
        # (torch.torch is torch); call torch.mean/torch.std directly.
        mean_irradiance = torch.mean(x, dim=(2, 3))
        std_irradiance = torch.std(x, dim=(2, 3))
        if self.use_std:
            x = torch.cat((mean_irradiance, std_irradiance), dim=1)
        else:
            x = mean_irradiance
        for layer in self.layers:
            x = layer(x)
        return x
163
+
164
+
165
class AttentionWithFastKANTransform(nn.Module):
    """Multi-head attention whose q/k/v/output projections are FastKAN layers.

    Optionally applies a sigmoid gate (computed from the raw query) to the
    attention output before the final projection.
    """

    def __init__(
        self,
        q_dim: int,
        k_dim: int,
        v_dim: int,
        head_dim: int,
        num_heads: int,
        gating: bool = True,
    ):
        super(AttentionWithFastKANTransform, self).__init__()

        self.num_heads = num_heads
        total_dim = head_dim * self.num_heads
        self.gating = gating
        self.linear_q = FastKANLayer(q_dim, total_dim)
        self.linear_k = FastKANLayer(k_dim, total_dim)
        self.linear_v = FastKANLayer(v_dim, total_dim)
        self.linear_o = FastKANLayer(total_dim, q_dim)
        self.linear_g = None
        if self.gating:
            self.linear_g = FastKANLayer(q_dim, total_dim)
        # precompute the 1/sqrt(head_dim) scaling factor
        self.norm = head_dim**-0.5

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bias: torch.Tensor = None,      # additive attention bias
    ) -> torch.Tensor:
        """Scaled dot-product attention; shape comments use q/k/h/c axes."""

        wq = self.linear_q(q).view(*q.shape[:-1], 1, self.num_heads, -1) * self.norm     # *q1hc
        wk = self.linear_k(k).view(*k.shape[:-2], 1, k.shape[-2], self.num_heads, -1)    # *1khc
        att = (wq * wk).sum(-1).softmax(-2)     # *qkh
        del wq, wk
        # NOTE(review): the bias is added AFTER the softmax, so `att` is no
        # longer a normalized distribution when bias is given; additive
        # attention biases are conventionally applied to the logits before
        # softmax — confirm this is intended.
        if bias is not None:
            att = att + bias[..., None]

        wv = self.linear_v(v).view(*v.shape[:-2],1, v.shape[-2], self.num_heads, -1)     # *1khc
        o = (att[..., None] * wv).sum(-3)       # *qhc
        del att, wv

        o = o.view(*o.shape[:-2], -1)       # *q(hc)

        if self.linear_g is not None:
            # gating, use raw query input
            g = self.linear_g(q)
            o = torch.sigmoid(g) * o

        # merge heads
        o = self.linear_o(o)
        return o
flaring/MEGS_AI_baseline/models/linear_and_hybrid.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import HuberLoss
4
+ from models.base_model import BaseModel
5
+
6
class LinearIrradianceModel(BaseModel):
    """Linear regression of SXR irradiance from per-channel image statistics.

    For each of the ``d_input`` channels the spatial mean and std are
    computed, giving ``2 * d_input`` features fed to one linear layer.

    Args:
        d_input: number of AIA channels (last axis of the input images).
        d_output: number of regression targets.
        eve_norm: normalization passed through to ``BaseModel``.
        loss_func: training loss; defaults to Huber.
        lr: learning rate.
    """

    def __init__(self, d_input, d_output, eve_norm, loss_func=HuberLoss(), lr=1e-2):
        self.n_channels = d_input
        self.outSize = d_output
        model = nn.Linear(2 * self.n_channels, self.outSize)
        super().__init__(model=model, eve_norm=eve_norm, loss_func=loss_func, lr=lr)

    def forward(self, x, sxr=None, **kwargs):
        """Predict irradiance from a (batch, H, W, C) image tensor.

        ``x`` may also be a ``(aia_img, sxr_val)`` tuple, in which case only
        the image tensor is used.

        Raises:
            ValueError: if the tensor is not 4-D or has the wrong channel count.
        """
        # If x is a tuple (aia_img, sxr_val), extract the AIA image tensor.
        if isinstance(x, (list, tuple)):
            x = x[0]

        # Debug prints removed: they ran on every batch and flooded stdout.
        # Expect x shape: (batch_size, H, W, C).
        if len(x.shape) != 4:
            raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
        if x.shape[-1] != self.n_channels:
            raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")

        # Channels-last -> channels-first, then reduce over the spatial dims.
        x = x.permute(0, 3, 1, 2)
        mean_irradiance = torch.mean(x, dim=(2, 3))  # (batch_size, n_channels)
        std_irradiance = torch.std(x, dim=(2, 3))    # (batch_size, n_channels)

        input_features = torch.cat((mean_irradiance, std_irradiance), dim=1)  # (batch_size, 2 * n_channels)
        if input_features.shape[1] != 2 * self.n_channels:
            raise ValueError(f"Expected {2 * self.n_channels} features, got {input_features.shape[1]}")

        return self.model(input_features)
45
+
46
class HybridIrradianceModel(BaseModel):
    """Hybrid linear + CNN irradiance model.

    Combines the statistics-based ``LinearIrradianceModel`` with a small
    CNN branch; the CNN contribution is scaled by ``cnn_lambda`` (see
    ``set_train_mode``).

    Args:
        d_input: number of AIA channels.
        d_output: number of regression targets.
        eve_norm: normalization passed through to ``BaseModel``.
        cnn_model: ``'resnet'`` for the built-in CNN; ``'efficientnet*'``
            is not supported (requires timm).
        ln_model: whether to include the linear branch.
        ln_params: optional dict with pre-trained 'weight'/'bias' for the
            linear branch.
        lr: learning rate.
        cnn_dp: dropout probability of the CNN head.
        loss_func: training loss; defaults to Huber.

    Raises:
        NotImplementedError: for efficientnet backbones.
        ValueError: if neither branch is enabled.
    """

    def __init__(self, d_input, d_output, eve_norm, cnn_model='resnet', ln_model=True, ln_params=None, lr=1e-4, cnn_dp=0.75, loss_func=HuberLoss()):
        super().__init__(model=None, eve_norm=eve_norm, loss_func=loss_func, lr=lr)
        self.n_channels = d_input
        self.outSize = d_output
        self.ln_params = ln_params
        self.ln_model = None
        if ln_model:
            self.ln_model = LinearIrradianceModel(d_input, d_output, eve_norm, loss_func=loss_func, lr=lr)
        # Optionally warm-start the linear branch from supplied parameters.
        if self.ln_params is not None and self.ln_model is not None:
            self.ln_model.model.weight = nn.Parameter(self.ln_params['weight'])
            self.ln_model.model.bias = nn.Parameter(self.ln_params['bias'])
        self.cnn_model = None
        self.cnn_lambda = 1.
        if cnn_model == 'resnet':
            self.cnn_model = nn.Sequential(
                nn.Conv2d(d_input, 64, kernel_size=7, stride=2, padding=3),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                nn.Linear(64, d_output),
                nn.Dropout(cnn_dp)
            )
        # isinstance guard: the original called .startswith on any non-'resnet'
        # value and crashed with AttributeError when cnn_model was None.
        elif isinstance(cnn_model, str) and cnn_model.startswith('efficientnet'):
            raise NotImplementedError("EfficientNet requires timm; replace with custom CNN or install timm")
        if self.ln_model is None and self.cnn_model is None:
            raise ValueError('Please pass at least one model.')

    def forward(self, x, sxr=None, **kwargs):
        """Predict irradiance from a (batch, H, W, C) tensor (or tuple).

        Raises:
            ValueError: if the tensor is not 4-D or has the wrong channel count.
        """
        # If x is a tuple (aia_img, sxr_val), extract the AIA image tensor.
        if isinstance(x, (list, tuple)):
            x = x[0]

        # Debug prints removed: they ran on every batch and flooded stdout.
        # Expect x shape: (batch_size, H, W, C).
        if len(x.shape) != 4:
            raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
        if x.shape[-1] != self.n_channels:
            raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")

        # Channels-first view for the CNN branch.
        x_cnn = x.permute(0, 3, 1, 2)

        if self.ln_model is not None and self.cnn_model is not None:
            # Linear branch consumes the original channels-last layout.
            return self.ln_model(x) + self.cnn_lambda * self.cnn_model(x_cnn)
        elif self.ln_model is not None:
            return self.ln_model(x)
        elif self.cnn_model is not None:
            return self.cnn_model(x_cnn)

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def set_train_mode(self, mode):
        """Select which branches train: 'linear', 'cnn', or 'both'.

        Also sets ``cnn_lambda``: 0 for linear-only, 0.01 otherwise.

        Raises:
            NotImplementedError: for any other mode string.
        """
        if mode == 'linear':
            self.cnn_lambda = 0
            if self.cnn_model: self.cnn_model.eval()
            if self.ln_model: self.ln_model.train()
        elif mode == 'cnn':
            self.cnn_lambda = 0.01
            if self.cnn_model: self.cnn_model.train()
            if self.ln_model: self.ln_model.eval()
        elif mode == 'both':
            self.cnn_lambda = 0.01
            if self.cnn_model: self.cnn_model.train()
            if self.ln_model: self.ln_model.train()
        else:
            raise NotImplementedError(f'Mode not supported: {mode}')
flaring/MEGS_AI_baseline/sxr_normalization.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ from pathlib import Path
4
+ import glob
5
+ import os
6
+
7
def compute_sxr_norm(sxr_dir):
    """
    Compute mean and standard deviation of log10-transformed SXR values.

    Args:
        sxr_dir (str): Path to directory containing SXR .npy files.

    Returns:
        tuple: (mean, std) of log10(SXR + 1e-8) values.

    Raises:
        FileNotFoundError: if ``sxr_dir`` is not a directory.
        ValueError: if no .npy files or no valid values are found.
    """
    sxr_dir = Path(sxr_dir).resolve()
    print(f"Checking SXR directory: {sxr_dir}")
    if not sxr_dir.is_dir():
        raise FileNotFoundError(f"SXR directory does not exist or is not a directory: {sxr_dir}")

    # Use glob for case-insensitive matching
    sxr_files = sorted(glob.glob(os.path.join(sxr_dir, "*.npy")))
    print(f"Found {len(sxr_files)} SXR files in {sxr_dir}")
    if len(sxr_files) == 0:
        # Message fixed: the original said "*_sxr.npy", which did not match
        # the glob pattern actually used above.
        print("No files matching '*.npy' found. Listing directory contents:")
        print(os.listdir(sxr_dir)[:10])  # Show first 10 entries
        raise ValueError(f"No SXR files found in {sxr_dir}")

    sxr_values = []
    for f in sxr_files:
        try:
            sxr = np.load(f)
            # Files may hold scalars or arrays; take the first element.
            sxr = np.atleast_1d(sxr).flatten()[0]
            if not np.isfinite(sxr) or sxr < 0:
                print(f"Skipping invalid SXR value in {f}: {sxr}")
                continue
            sxr_values.append(np.log10(sxr + 1e-8))
        except Exception as e:
            # Best-effort: a corrupt file should not abort the whole pass.
            print(f"Failed to load SXR file {f}: {e}")
            continue

    sxr_values = np.array(sxr_values)
    if len(sxr_values) == 0:
        raise ValueError(f"No valid SXR values found in {sxr_dir}. All files failed to load or contained invalid data.")

    mean = np.mean(sxr_values)
    std = np.std(sxr_values)
    print(f"Computed SXR normalization: mean={mean}, std={std}")
    return mean, std
51
+
52
if __name__ == "__main__":
    # Update this path to your real data SXR directory.
    sxr_dir = "/mnt/data/ML-Ready-Data-No-Intensity-Cut/GOES-18-SXR-B/"
    sxr_norm = compute_sxr_norm(sxr_dir)
    # Use one variable for both the save and the log line: the original
    # saved to sxr_norm2.npy but printed sxr_norm.npy.
    out_path = "/home/jayantbiradar619/sxr_norm2.npy"
    np.save(out_path, sxr_norm)
    print(f"Saved SXR normalization to {out_path}")
flaring/MEGS_AI_baseline/train.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import argparse
3
+ import os
4
+ import yaml
5
+ import itertools
6
+ import wandb
7
+ import torch
8
+ import numpy as np
9
+ from pathlib import Path
10
+ import torchvision.transforms as transforms
11
+ from pytorch_lightning import Trainer
12
+ from pytorch_lightning.loggers import WandbLogger
13
+ from pytorch_lightning.callbacks import ModelCheckpoint, Callback
14
+ from torch.nn import HuberLoss
15
+ from SDOAIA_dataloader import AIA_GOESDataModule
16
+ from linear_and_hybrid import LinearIrradianceModel, HybridIrradianceModel
17
+
18
# SXR Prediction Logger
class SXRPredictionLogger(Callback):
    """Lightning callback that logs predictions vs. targets to wandb for a
    fixed set of validation samples at the end of every validation epoch.
    """

    def __init__(self, val_samples):
        # val_samples: list of ((aia, sxr), target) tuples from the
        # validation dataset.
        super().__init__()
        self.val_samples = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        # val_samples is a list of ((aia, sxr), target)
        for (aia, sxr), target in self.val_samples:
            aia, sxr, target = aia.to(pl_module.device), sxr.to(pl_module.device), target.to(pl_module.device)
            # NOTE(review): no explicit torch.no_grad() here; this relies on
            # Lightning disabling grad around validation hooks — confirm,
            # otherwise pred.cpu().numpy() would fail on a grad-tracking tensor.
            pred = pl_module(aia.unsqueeze(0)) # Add batch dimension
            trainer.logger.experiment.log({
                "val_pred_sxr": pred.cpu().numpy(),
                "val_target_sxr": target.cpu().numpy()
            })
33
+
34
# Compute SXR normalization
def compute_sxr_norm(sxr_dir):
    """Return (mean, std) of log10(SXR + 1e-8) over all .npy files in ``sxr_dir``.

    Non-finite and negative values are skipped, matching the behaviour of
    ``sxr_normalization.compute_sxr_norm`` (the original version let a single
    NaN file poison both statistics).

    Raises:
        ValueError: if no valid SXR values are found.
    """
    sxr_values = []
    for f in Path(sxr_dir).glob("*.npy"):
        sxr = np.load(f)
        # Files may hold scalars or arrays; take the first element.
        sxr = np.atleast_1d(sxr).flatten()[0]
        if not np.isfinite(sxr) or sxr < 0:
            continue
        sxr_values.append(np.log10(sxr + 1e-8))
    sxr_values = np.array(sxr_values)
    if len(sxr_values) == 0:
        raise ValueError(f"No SXR files found in {sxr_dir}")
    return np.mean(sxr_values), np.std(sxr_values)
45
+
46
# Parser
parser = argparse.ArgumentParser()
parser.add_argument('-checkpoint_dir', type=str, required=True, help='Directory to save checkpoints.')
parser.add_argument('-model', type=str, default='config.yaml', help='Path to model config YAML.')
parser.add_argument('-aia_dir', type=str, required=True, help='Path to AIA .npy files.')
parser.add_argument('-sxr_dir', type=str, required=True, help='Path to SXR .npy files.')
parser.add_argument('-sxr_norm', type=str, help='Path to SXR normalization (mean, std).')
parser.add_argument('-instrument', type=str, default='AIA_6', help='Instrument (e.g., AIA_6 for 6 wavelengths).')
args = parser.parse_args()

# Load config
with open(args.model, 'r') as stream:
    config_data = yaml.load(stream, Loader=yaml.SafeLoader)

# Cartesian product over all hyper-parameter lists under config_data['model']:
# one training run per combination.
dic_values = [i for i in config_data['model'].values()]
combined_parameters = list(itertools.product(*dic_values))

# Paths and normalization
checkpoint_dir = args.checkpoint_dir
aia_dir = args.aia_dir
sxr_dir = args.sxr_dir
if args.sxr_norm:
    # Pre-computed (mean, std) of log10 SXR values.
    sxr_norm = np.load(args.sxr_norm)
else:
    sxr_norm = compute_sxr_norm(sxr_dir)
instrument = args.instrument

# Transforms
# Per-sample min-max normalisation, with flips/rotations on the train set only.
train_transforms = transforms.Compose([
    transforms.Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + 1e-8)), # Remove clone/detach
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
])
val_transforms = transforms.Compose([
    transforms.Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + 1e-8)), # Remove clone/detach
])
82
+
83
# Training loop
# One full train/test run per hyper-parameter combination from the config grid.
n = 0
for parameter_set in combined_parameters:
    # Rebuild a flat {param_name: value} dict for this combination.
    run_config = {key: item for key, item in zip(config_data['model'].keys(), parameter_set)}
    torch.manual_seed(run_config['seed'])
    np.random.seed(run_config['seed'])

    # DataModule
    data_loader = AIA_GOESDataModule(
        aia_dir=aia_dir,
        sxr_dir=sxr_dir,
        sxr_norm=sxr_norm,
        batch_size=16,
        num_workers=os.cpu_count() // 2,
        train_transforms=train_transforms,
        val_transforms=val_transforms,
        val_split=0.2,
        test_split=0.1
    )
    data_loader.setup()

    # Logger
    # NOTE(review): with a single parameter set every run reuses the name
    # "aia_sxr_model" — confirm this is intended.
    wb_name = f"{instrument}_{n}" if len(combined_parameters) > 1 else "aia_sxr_model"
    wandb_logger = WandbLogger(
        entity=config_data['wandb']['entity'],
        project=config_data['wandb']['project'],
        job_type=config_data['wandb']['job_type'],
        tags=config_data['wandb']['tags'],
        name=wb_name,
        notes=config_data['wandb']['notes'],
        config=run_config
    )

    # Logging callback: log predictions for a few evenly spaced validation samples.
    total_n_valid = len(data_loader.valid_ds)
    plot_data = [data_loader.valid_ds[i] for i in range(0, total_n_valid, max(1, total_n_valid // 4))]
    plot_samples = plot_data # Keep as list of ((aia, sxr), target)
    sxr_callback = SXRPredictionLogger(plot_samples)

    # Checkpoint callback: keep only the best model by validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        monitor='valid_loss',
        mode='min',
        save_top_k=1,
        filename=f"{wb_name}-{{epoch:02d}}-{{valid_loss:.4f}}"
    )

    # Model: d_input=6 matches the AIA_6 default (6 wavelength channels).
    if run_config['architecture'] == 'linear':
        model = LinearIrradianceModel(
            d_input=6,
            d_output=1,
            eve_norm=sxr_norm,
            lr=run_config.get('lr', 1e-2),
            loss_func=HuberLoss()
        )
    elif run_config['architecture'] == 'hybrid':
        model = HybridIrradianceModel(
            d_input=6,
            d_output=1,
            eve_norm=sxr_norm,
            cnn_model=run_config['cnn_model'],
            ln_model=True,
            cnn_dp=run_config.get('cnn_dp', 0.75),
            lr=run_config.get('lr', 1e-4)
        )
    else:
        raise NotImplementedError(f"Architecture {run_config['architecture']} not supported.")

    # Trainer
    trainer = Trainer(
        default_root_dir=checkpoint_dir,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        max_epochs=run_config.get('epochs', 10),
        callbacks=[sxr_callback, checkpoint_callback],
        logger=wandb_logger,
        log_every_n_steps=10
    )

    # Train
    trainer.fit(model, data_loader)

    # Save checkpoint
    # NOTE(review): save_dictionary aliases run_config (no copy) and pickles
    # the whole model object via torch.save — loading later requires the same
    # class definitions on the import path.
    save_dictionary = run_config
    save_dictionary['model'] = model
    save_dictionary['instrument'] = instrument
    full_checkpoint_path = os.path.join(checkpoint_dir, f"{wb_name}_{n}.ckpt")
    torch.save(save_dictionary, full_checkpoint_path)

    # Test
    trainer.test(model, dataloaders=data_loader.test_dataloader())

    # Finalize
    wandb.finish()
    n += 1
flaring/__init__.py ADDED
File without changes
flaring/cut_off_aia.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+
4
+
5
+ aia = os.listdir("/mnt/data/ML-Ready-Data-No-Intensity-Cut/AIA-Data")
6
+
7
+
8
+ target_dates = ["2023-07-11","2023-07-15","2023-07-16", "2023-07-18" "2023-07-20", "2023-07-26", "2023-07-30", "2023-08-01", "2023-08-02", "2023-08-07", ]
9
+
10
+ aia_dict = {}
11
+ aia_dict[0] = []
12
+ aia_dict[1] = []
13
+ aia_dict[2] = []
14
+ aia_dict[3] = []
15
+ aia_dict[4] = []
16
+ aia_dict[5] = []
17
+
18
+ count = 0
19
+ for i, file in enumerate(aia):
20
+ if file.split("T")[0] in target_dates:
21
+ aia_data = np.load("/mnt/data/ML-Ready-Data-No-Intensity-Cut/AIA-Data/"+file)
22
+ aia_dict[0].append(aia_data[0].flatten())
23
+ aia_dict[1].append(aia_data[1].flatten())
24
+ aia_dict[2].append(aia_data[2].flatten())
25
+ aia_dict[3].append(aia_data[3].flatten())
26
+ aia_dict[4].append(aia_data[4].flatten())
27
+ aia_dict[5].append(aia_data[5].flatten())
28
+ count = count + 1
29
+ print("Flares: " + str(count) + "\n")
30
+ print(f"\nProcessed {i+1}/{len(aia)} files", end='\r')
31
+
32
+ def percentile(data, perc):
33
+ return np.percentile(data, perc)
34
+
35
+ percentile_dict = {0:[percentile(aia_dict[0], 95), percentile(aia_dict[0], 99.5)],1: [percentile(aia_dict[1], 95), percentile(aia_dict[1], 99.5)], 2: [percentile(aia_dict[2], 95), percentile(aia_dict[2], 99.5)], 3: [percentile(aia_dict[3], 95), percentile(aia_dict[3], 99.5)], 4: [percentile(aia_dict[4], 95), percentile(aia_dict[4], 99.5)], 5: [percentile(aia_dict[5], 95), percentile(aia_dict[5], 99.5)]}
36
+
37
+ print(percentile_dict)
38
+ #{0: [np.float32(5.0747647), np.float32(16.560747)], 1: [np.float32(24.491392), np.float32(75.84181)], 2: [np.float32(607.3201), np.float32(1536.1443)], 3: [np.float32(1021.83466), np.float32(2288.1)], 4: [np.float32(480.13672), np.float32(1163.9178)], 5: [np.float32(144.44502), np.float32(401.82352)]}