arabeh committed on
Commit
28bf80d
·
1 Parent(s): 3f489e4

added the time dependent deeponet model

Browse files
models/.gitkeep ADDED
File without changes
models/base.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+
4
+
5
class BaseLightningModule(pl.LightningModule):
    """Shared training/validation plumbing for the DeepONet models.

    Subclasses are expected to provide:
      * ``self.model``      -- the wrapped ``nn.Module`` invoked in the steps,
      * ``self.deriv_calc`` -- derivative calculator (only needed when
        ``hparams.use_derivative_loss`` is set),
      * hparams: ``lr``, ``use_derivative_loss``, ``height``, ``width``,
        ``domain_length_y``, ``use_zero_bc``.
    """

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def _masked_mse(self, y_hat, y_true, sdf):
        """Mean squared error restricted to fluid points (sdf > 0).

        Args:
            y_hat, y_true: predictions/targets of shape [B, 1, P, C].
            sdf: signed distance field, [B, 1, H, W] with P == H * W.
        """
        # Shape the mask as [B, 1, P, 1] so it broadcasts per sample against
        # [B, 1, P, C]. (The previous [B, P, 1] layout right-aligned against
        # y_hat to produce [B, B, P, C], pairing every sample with every
        # other sample's mask whenever B > 1.)
        mask = (sdf > 0).flatten(1).unsqueeze(1).unsqueeze(-1)
        se = ((y_hat - y_true) ** 2) * mask
        # clamp_min avoids a NaN loss if the batch contains no fluid cells.
        return se.sum() / mask.sum().clamp_min(1)

    def _shared_step(self, batch):
        """Run the model on a batch and return the configured loss."""
        (branch, re, coords, sdf), tgt = batch
        y_hat = self.model((branch, re, coords, sdf))
        if self.hparams.use_derivative_loss:
            return self._derivative_loss(y_hat, tgt, sdf)
        return self._masked_mse(y_hat, tgt, sdf)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('val_loss', loss)
        return loss

    def _derivative_loss(self, y_hat, y_true, sdf):
        """Masked derivative-mismatch loss plus a near-wall boundary term.

        y_hat/y_true arrive as [B, 1, P, C]; they are reshaped to image
        layout [B, C, H, W] before the FEM derivative evaluation.
        """
        # --- reshape [B,1,p,C] -> [B,C,H,W] ---
        B, _, p, C = y_hat.shape
        H, W = self.hparams.height, self.hparams.width
        yh = y_hat.squeeze(1).permute(0, 2, 1).reshape(B, C, H, W)
        yt = y_true.squeeze(1).permute(0, 2, 1).reshape(B, C, H, W)

        deriv_hat = self.deriv_calc(yh)
        deriv_true = self.deriv_calc(yt)
        fluid_mask = (sdf > 0)  # [B,1,H,W]
        delta = self.hparams.domain_length_y / H  # grid spacing (y direction)
        loss = 0.0
        # Derivative tensors come out at resolution (H-1)x(W-1) so crop the
        # fluid_mask to match before broadcasting against [B,ngp,1,H-1,W-1]:
        dm = fluid_mask[:, :, :-1, :-1].unsqueeze(1)  # -> [B,1,1,H-1,W-1]
        for key in ('u_x', 'u_y', 'v_x', 'v_y'):
            diff = deriv_hat[key] - deriv_true[key]  # [B,ngp,1,H-1,W-1]
            # apply mask before averaging; delta scales the term like a
            # quadrature weight
            deriv_loss = delta * (diff.pow(2) * dm).sum() / dm.sum()
            self.log(f"deriv_loss/{key}", deriv_loss, on_step=False, on_epoch=True)
            loss = loss + deriv_loss

        # Cells within one grid spacing of the solid boundary.
        inner = (sdf > 0) & (sdf <= delta)  # [B,1,H,W]
        if inner.any().item():
            u_hat = yh[:, 0:1]  # [B,1,H,W]
            v_hat = yh[:, 1:2]
            if self.hparams.use_zero_bc:
                # Penalize any velocity at the wall (no-slip); the factor of
                # 1000 weights the BC term against the derivative terms.
                bc_loss = 1000 * (u_hat[inner].pow(2) + v_hat[inner].pow(2)).mean()
            else:
                # Match the target velocity on the near-wall band instead.
                u_true = yt[:, 0:1]
                v_true = yt[:, 1:2]
                u_target = u_true[inner]
                v_target = v_true[inner]
                bc_loss = ((u_hat[inner] - u_target).pow(2) + (v_hat[inner] - v_target).pow(2)).mean()

            self.log("boundary_bc_loss", bc_loss, on_step=False, on_epoch=True)
            loss = loss + bc_loss

        return loss
models/deriv_calc.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import warnings
5
+ from itertools import product
6
+ from typing import Dict
7
+
8
+
9
def gauss_pt_eval(tensor: torch.Tensor, kernels: nn.ParameterList, stride: int = 1) -> torch.Tensor:
    """Evaluate every Gauss-point kernel on every channel of ``tensor``.

    Args:
        tensor:  input of shape [B, C, H, W].
        kernels: one [1, 1, kh, kw] convolution kernel per Gauss point.
        stride:  convolution stride.

    Returns:
        Tensor of shape [B, len(kernels), C, H', W'] where H'/W' is the
        convolution output size.

    Raises:
        ValueError: if ``kernels`` is empty.
    """
    if not kernels:
        raise ValueError("No Gauss kernels provided.")
    B, C = tensor.shape[0], tensor.shape[1]
    device = tensor.device
    # Fold the channel dim into the batch dim so each kernel needs a single
    # conv2d call instead of a Python loop over channels (and no throwaway
    # probe convolution just to discover the output shape).
    flat = tensor.reshape(B * C, 1, *tensor.shape[2:])

    results = []
    for k in kernels:
        out_k = F.conv2d(flat, k.to(device), stride=stride)
        # [B*C, 1, H', W'] -> [B, 1, C, H', W']
        results.append(out_k.reshape(B, 1, C, *out_k.shape[2:]))

    out = torch.cat(results, dim=1)
    # Defensive shape check kept from the original implementation.
    expected = (B, len(kernels), C) + tuple(results[0].shape[3:])
    if out.shape != expected:
        warnings.warn(f"Shape mismatch in gauss_pt_eval: {out.shape} != {expected}")
    return out
32
+
33
+
34
class FEM2D(nn.Module):
    """2D FEM derivative operator.

    Builds one 2x2 convolution kernel per Gauss point for each of the two
    first derivatives (d/dx, d/dy) of a bilinear element, then evaluates
    them via ``gauss_pt_eval``.
    """

    def __init__(
        self,
        height: int,
        width: int,
        domain_length_x: float,
        domain_length_y: float,
        device: torch.device
    ):
        super().__init__()
        self.height, self.width = height, width
        self.device = device
        # Two-point Gauss quadrature abscissae on [-1, 1].
        self.gpx = [-0.57735, 0.57735]
        self.kernels_dx = nn.ParameterList()
        self.kernels_dy = nn.ParameterList()
        self._build_kernels(domain_length_x, domain_length_y)

    def _build_kernels(self, Lx: float, Ly: float):
        """Populate kernels_dx / kernels_dy with one [1,1,2,2] kernel per Gauss point."""
        # Element sizes of the structured grid.
        hx = Lx / (self.width - 1)
        hy = Ly / (self.height - 1)

        # Linear shape functions on the reference element [-1, 1] and their
        # (constant) derivatives.
        def shape(x):
            return [0.5 * (1 - x), 0.5 * (1 + x)]

        dshape = [-0.5, 0.5]

        for gx, gy in product(self.gpx, repeat=2):
            nx, ny = shape(gx), shape(gy)
            dx = torch.zeros(2, 2, device=self.device)
            dy = torch.zeros(2, 2, device=self.device)
            for i in range(2):
                for j in range(2):
                    # d/dx: x-derivative of the shape fn, times shape fn in y.
                    dx[j, i] = dshape[i] * (2 / hx) * ny[j]
                    # d/dy: shape fn in x, times y-derivative of the shape fn.
                    dy[j, i] = nx[i] * (dshape[j] * (2 / hy))
            # Stored as frozen conv kernels of shape [1, 1, 2, 2].
            self.kernels_dx.append(nn.Parameter(dx.view(1, 1, 2, 2), requires_grad=False))
            self.kernels_dy.append(nn.Parameter(dy.view(1, 1, 2, 2), requires_grad=False))

    def eval_derivative_x(self, tensor: torch.Tensor) -> torch.Tensor:
        """Evaluate d/dx of ``tensor`` at every Gauss point."""
        return gauss_pt_eval(tensor, self.kernels_dx)

    def eval_derivative_y(self, tensor: torch.Tensor) -> torch.Tensor:
        """Evaluate d/dy of ``tensor`` at every Gauss point."""
        return gauss_pt_eval(tensor, self.kernels_dy)
78
+
79
+
80
class DerivativeCalculator(nn.Module):
    """First spatial derivatives for the 'u' and 'v' velocity channels.

    Thin wrapper around FEM2D mapping channel 0 -> 'u' and channel 1 -> 'v'.
    """

    def __init__(
        self,
        height: int,
        width: int,
        domain_length_x: float,
        domain_length_y: float,
        device: torch.device,
        channels: int = 2  # number of channels: 2 for (u,v)
    ):
        super().__init__()
        self.channels = channels
        self.fem = FEM2D(height, width, domain_length_x, domain_length_y, device)

    def calculate_first_derivatives(self, y_spatial: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute d/dx and d/dy for each named velocity component.

        Args:
            y_spatial: [B, C, H, W] tensor where C == channels.

        Returns:
            Dict with keys 'u_x', 'u_y' (and 'v_x', 'v_y' when channels >= 2).

        NOTE(review): components beyond the first two are silently ignored
        (only 'u' and 'v' names exist) — confirm that is intended.
        """
        derivatives = {}
        for idx, label in zip(range(self.channels), ('u', 'v')):
            component = y_spatial[:, idx:idx + 1]
            derivatives[f'{label}_x'] = self.fem.eval_derivative_x(component)
            derivatives[f'{label}_y'] = self.fem.eval_derivative_y(component)
        return derivatives

    # Alias so the module can be called directly.
    forward = calculate_first_derivatives
models/geometric_deeponet/.gitkeep ADDED
File without changes
models/geometric_deeponet/geometric_deeponet.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from models.base import BaseLightningModule
3
+ from models.geometric_deeponet.network import GeoDeepONetTime as _GeoDeepONetTime
4
+ from models.deriv_calc import DerivativeCalculator
5
+
6
class GeometricDeepONetTime(BaseLightningModule):
    """Lightning wrapper around the time-dependent geometric DeepONet."""

    # hparams keys forwarded verbatim to the underlying network.
    _NET_KEYS = (
        'height', 'width', 'num_input_timesteps', 'input_channels_loc',
        'modes', 'branch_stage1_layers', 'trunk_stage1_layers',
        'branch_stage2_layers', 'trunk_stage2_layers',
        'cnn_c1', 'cnn_c2', 'cnn_c3', 'cnn_fc1', 'cnn_fc2',
    )

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        # The network predicts one channel fewer than output_channels.
        effective_channels = self.hparams.output_channels - 1

        # Built on CPU; gauss_pt_eval moves the kernels to the input's
        # device on the fly.
        self.deriv_calc = DerivativeCalculator(
            height=self.hparams.height,
            width=self.hparams.width,
            domain_length_x=self.hparams.domain_length_x,
            domain_length_y=self.hparams.domain_length_y,
            device=torch.device('cpu'),
            channels=effective_channels,
        )

        # Build network args from the saved hyperparameters.
        net_args = {key: getattr(self.hparams, key) for key in self._NET_KEYS}
        net_args['effective_output_channels'] = effective_channels
        self.model = _GeoDeepONetTime(**net_args)

    def forward(self, inputs: tuple):
        """LightningModule entry point for inference.

        Simply calls through to the underlying torch.nn.Module.
        """
        return self.model(inputs)
models/geometric_deeponet/network.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ import warnings
6
+ from typing import List
7
+
8
class ConvBlock(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU, the basic encoder unit."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # Bias is omitted because BatchNorm supplies the affine shift.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)
17
+
18
class InceptionStyleCNNEncoder(nn.Module):
    """Three parallel conv branches (1x1 / 3x3 / 5x5) fused into a flat vector.

    Each branch downsamples 3x via max-pooling; the fusion stage downsamples
    twice more, so the input spatial size is reduced by a factor of 32.
    """

    def __init__(self, input_channels: int, c1: int, c2: int, c3: int, fc1: int, fc2: int):
        super().__init__()
        # Branch 1: pure 1x1 convolutions.
        self.b1_conv1 = ConvBlock(input_channels, c1, 1)
        self.b1_pool1 = nn.MaxPool2d(2, 2)
        self.b1_conv2 = ConvBlock(c1, c2, 1)
        self.b1_pool2 = nn.MaxPool2d(2, 2)
        self.b1_conv3 = ConvBlock(c2, c3, 1)
        self.b1_pool3 = nn.MaxPool2d(2, 2)
        # Branch 2: 1x1 bottleneck followed by a 3x3 conv at each scale.
        self.b2_conv1a = ConvBlock(input_channels, c1, 1)
        self.b2_conv1b = ConvBlock(c1, c1, 3, padding=1)
        self.b2_pool1 = nn.MaxPool2d(2, 2)
        self.b2_conv2a = ConvBlock(c1, c2, 1)
        self.b2_conv2b = ConvBlock(c2, c2, 3, padding=1)
        self.b2_pool2 = nn.MaxPool2d(2, 2)
        self.b2_conv3a = ConvBlock(c2, c3, 1)
        self.b2_conv3b = ConvBlock(c3, c3, 3, padding=1)
        self.b2_pool3 = nn.MaxPool2d(2, 2)
        # Branch 3: 1x1 bottleneck followed by a 5x5 conv at each scale.
        self.b3_conv1a = ConvBlock(input_channels, c1, 1)
        self.b3_conv1b = ConvBlock(c1, c1, 5, padding=2)
        self.b3_pool1 = nn.MaxPool2d(2, 2)
        self.b3_conv2a = ConvBlock(c1, c2, 1)
        self.b3_conv2b = ConvBlock(c2, c2, 5, padding=2)
        self.b3_pool2 = nn.MaxPool2d(2, 2)
        self.b3_conv3a = ConvBlock(c2, c3, 1)
        self.b3_conv3b = ConvBlock(c3, c3, 5, padding=2)
        self.b3_pool3 = nn.MaxPool2d(2, 2)
        # Fusion: concatenate the branch outputs, then reduce twice more.
        concat_channels = 3 * c3
        self.fusion_conv1 = ConvBlock(concat_channels, fc1, 1)
        self.fusion_pool1 = nn.MaxPool2d(2, 2)
        self.fusion_conv2 = ConvBlock(fc1, fc2, 1)
        self.fusion_pool2 = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.final_cnn_channels = fc2

    def forward(self, x):
        # Branch 1.
        b1 = self.b1_pool1(self.b1_conv1(x))
        b1 = self.b1_pool2(self.b1_conv2(b1))
        b1 = self.b1_pool3(self.b1_conv3(b1))
        # Branch 2.
        b2 = self.b2_pool1(self.b2_conv1b(self.b2_conv1a(x)))
        b2 = self.b2_pool2(self.b2_conv2b(self.b2_conv2a(b2)))
        b2 = self.b2_pool3(self.b2_conv3b(self.b2_conv3a(b2)))
        # Branch 3.
        b3 = self.b3_pool1(self.b3_conv1b(self.b3_conv1a(x)))
        b3 = self.b3_pool2(self.b3_conv2b(self.b3_conv2a(b3)))
        b3 = self.b3_pool3(self.b3_conv3b(self.b3_conv3a(b3)))
        # Fuse the three branches and flatten to [B, features].
        fused = torch.cat((b1, b2, b3), dim=1)
        fused = self.fusion_pool1(self.fusion_conv1(fused))
        fused = self.fusion_pool2(self.fusion_conv2(fused))
        return self.flatten(fused)
64
+
65
class LinearMLP(nn.Module):
    """Fully connected stack with ``nonlin`` between hidden layers.

    ``dims`` lists the layer widths; the final Linear has no activation.
    """

    def __init__(self, dims: List[int], nonlin):
        super().__init__()
        modules = []
        last = len(dims) - 2  # index of the final Linear layer
        for idx, (d_in, d_out) in enumerate(zip(dims, dims[1:])):
            modules.append(nn.Linear(d_in, d_out))
            if idx != last:
                modules.append(nonlin())
        self.mlp = nn.Sequential(*modules)

    def forward(self, x):
        return self.mlp(x)
76
+
77
class torchSine(nn.Module):
    """Elementwise sine, usable as an activation module."""

    def forward(self, x):
        return torch.sin(x)
79
+
80
class GeoDeepONetTime(nn.Module):
    """Two-stage geometric DeepONet over (history, coords, sdf) inputs.

    Branch: CNN encoder over the stacked input timesteps -> global modes.
    Trunk: per-point MLP over (x, y, sdf) -> local modes.
    The two are merged multiplicatively, refined by the stage-2 MLPs, and
    contracted over modes into a per-point, per-channel prediction.
    """

    def __init__(
        self, height: int, width: int, num_input_timesteps: int,
        input_channels_loc: int, effective_output_channels: int,
        modes: int,
        branch_stage1_layers: List[int], trunk_stage1_layers: List[int],
        branch_stage2_layers: List[int], trunk_stage2_layers: List[int],
        cnn_c1: int, cnn_c2: int, cnn_c3: int, cnn_fc1: int, cnn_fc2: int
    ):
        super().__init__()
        if input_channels_loc != 2:
            warnings.warn("GeoDeepONetTime expects input_channels_loc=2 (x,y). SDF will be added.")

        self.input_channels_loc_base = input_channels_loc
        # +1 because the SDF is concatenated onto the (x, y) coordinates.
        self.input_channels_loc_effective = input_channels_loc + 1
        self.effective_output_channels = effective_output_channels
        self.modes = modes
        self.height = height
        self.width = width
        self.num_points = height * width

        # --- Branch: CNN over all input timesteps stacked channel-wise ---
        channels_per_step = self.effective_output_channels
        cnn_in_ch = num_input_timesteps * channels_per_step
        self.cnn_encoder = InceptionStyleCNNEncoder(cnn_in_ch, cnn_c1, cnn_c2, cnn_c3, cnn_fc1, cnn_fc2)
        # Probe with a dummy input to discover the flattened feature size.
        with torch.no_grad():
            probe = self.cnn_encoder(torch.zeros(1, cnn_in_ch, height, width))
        cnn_flat = probe.shape[1]
        self.branch_stage_1 = LinearMLP([cnn_flat] + branch_stage1_layers + [modes], nn.ReLU)
        self.branch_stage_2 = LinearMLP([modes] + branch_stage2_layers + [modes * effective_output_channels], nn.ReLU)

        # --- Trunk: pointwise MLPs over (x, y, sdf) ---
        self.trunk_stage_1 = LinearMLP([self.input_channels_loc_effective] + trunk_stage1_layers + [modes], nn.ReLU)
        # Stage 2 uses sine activations (periodic trunk features).
        self.trunk_stage_2 = LinearMLP([modes] + trunk_stage2_layers + [modes * effective_output_channels], torchSine)

        # --- Learned scalar output bias ---
        self.b = nn.Parameter(torch.tensor(0.0))

    def forward(self, inputs: tuple):
        x1, _, coords, sdf = inputs[:4]
        batch = x1.shape[0]

        # --- Branch ---
        features = self.cnn_encoder(x1)
        glob = self.branch_stage_1(features)          # [b, modes]

        # --- Trunk ---
        # coords [b, 2, h, w] -> [b, h*w, 2]; sdf [b, 1, h, w] -> [b, h*w, 1]
        pts = coords.flatten(2).transpose(1, 2)
        sdf_pts = sdf.flatten(2).transpose(1, 2)
        trunk_in = torch.cat((pts, sdf_pts), dim=-1)  # [b, h*w, 3]
        local = self.trunk_stage_1(trunk_in)          # [b, h*w, modes]

        # --- Merge & Stage 2 ---
        merged = glob.unsqueeze(1) * local            # [b, h*w, modes]
        avg = merged.mean(dim=1)                      # [b, modes]

        out_b = self.branch_stage_2(avg)              # [b, modes*eff_out]
        out_t = self.trunk_stage_2(merged)            # [b, h*w, modes*eff_out]

        # Reshape for the mode/channel contraction.
        n_pts = out_t.shape[1]
        b_r = out_b.view(batch, self.modes, self.effective_output_channels)
        t_r = out_t.view(batch, n_pts, self.modes, self.effective_output_channels)

        # Contract over modes and add the learned scalar bias.
        sol_flat = torch.einsum('bmc,bpmc->bpc', b_r, t_r) + self.b

        # Final shape [b, 1, p, c].
        return sol_flat.unsqueeze(1)