Spaces:
Runtime error
Runtime error
"""Transolver: Transformer Solver with Physics Attention.

Multi-head reshapes use einops, which is clearer than a manual
reshape + transpose.

Adapted from the PhysicsNeMo (NVIDIA Modulus) Transolver implementation of
"Transolver: A Fast Transformer Solver for PDEs on General Geometries",
Haixu Wu, Huakun Luo, Haowen Wang et al. — ICML 2024.
arXiv: https://arxiv.org/abs/2402.02366
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| from einops import rearrange | |
| from core.device import DEVICE | |
# ── Physics Attention (1-D structured grid) ──────────────────────────────────
class PhysicsAttn1d(nn.Module):
    """Physics Attention over 1-D structured grids.

    Each of the N input points is softly assigned (per head) to S "physics
    slices". The points are averaged into S slice tokens using those weights.
    Multi-head self-attention then runs among the S tokens, and the result is
    broadcast back to the N points with the same assignment weights.

    Parameters
    ----------
    dim : int
        Embedding dimension (must be divisible by n_head).
    n_head : int
        Number of attention heads.
    slice_num : int
        Number of physics slices S (analogous to tokens in ViT).

    Raises
    ------
    ValueError
        If ``dim`` is not divisible by ``n_head``.
    """

    # Added to each slice's total assignment weight before dividing, so a
    # slice that receives (almost) no weight does not divide by zero.
    slice_eps: float = 1e-5

    def __init__(self, dim: int, n_head: int = 4, slice_num: int = 32):
        super().__init__()
        if dim % n_head != 0:
            raise ValueError("dim must be divisible by n_head")
        self.dim = dim
        self.n_head = n_head
        self.slice_num = slice_num
        self.head_dim = dim // n_head
        self.scale = self.head_dim ** -0.5
        # Slice assignment projection: each point → S slice logits per head.
        self.to_slice = nn.Linear(dim, n_head * slice_num, bias=False)
        # QKV projections. They are bias-free, so projecting per point and
        # then averaging equals averaging the points and then projecting.
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim)
        self.to(DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : [B, N, D]

        Returns
        -------
        out : [B, N, D]

        Raises
        ------
        ValueError
            If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(f"expected x of shape [B, N, D], got {tuple(x.shape)}")
        H = self.n_head

        # ── Slice assignment: softmax over the S slices, per head ─────────
        logits = rearrange(self.to_slice(x), 'b n (h s) -> b h n s', h=H)
        A = F.softmax(logits, dim=-1)                      # [B, H, N, S]

        # ── Aggregate N grid points → S slice tokens ──────────────────────
        q_grid = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h=H)
        k_grid = rearrange(self.to_k(x), 'b n (h d) -> b h n d', h=H)
        v_grid = rearrange(self.to_v(x), 'b n (h d) -> b h n d', h=H)
        At = A.transpose(-2, -1)                           # [B, H, S, N]
        # Rows of A sum to 1 over S, not over N, so A^T @ X is a weighted
        # *sum* whose size grows with N. Dividing by each slice's total weight
        # turns it into a weighted *average*. This keeps the S×S attention
        # logits from saturating on large grids.
        slice_norm = At.sum(dim=-1, keepdim=True) + self.slice_eps  # [B, H, S, 1]
        q_s = torch.matmul(At, q_grid) / slice_norm        # [B, H, S, d]
        k_s = torch.matmul(At, k_grid) / slice_norm
        v_s = torch.matmul(At, v_grid) / slice_norm

        # ── Self-attention in slice space ──────────────────────────────────
        dots = torch.matmul(q_s, k_s.transpose(-2, -1)) * self.scale  # [B, H, S, S]
        attn = F.softmax(dots, dim=-1)
        out_s = torch.matmul(attn, v_s)                    # [B, H, S, d]

        # ── Broadcast back: S slice tokens → N grid points ────────────────
        out_grid = torch.matmul(A, out_s)                  # [B, H, N, d]
        out = rearrange(out_grid, 'b h n d -> b n (h d)')
        return self.out_proj(out)
# ── Transolver Block ─────────────────────────────────────────────────────────
class TransolverBlock1d(nn.Module):
    """One Transolver layer: PhysicsAttn + FFN, each as a pre-norm residual.

    Parameters
    ----------
    dim : int
        Embedding dimension.
    n_head : int
        Number of attention heads (forwarded to PhysicsAttn1d).
    slice_num : int
        Number of physics slices (forwarded to PhysicsAttn1d).
    mlp_ratio : float
        FFN hidden width as a multiple of ``dim``.
    dropout : float
        Dropout probability applied to the output of each residual branch
        (attention and FFN) before it is added back; ``0.0`` disables it.
    """

    def __init__(self, dim: int, n_head: int = 4, slice_num: int = 32,
                 mlp_ratio: float = 2.0, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = PhysicsAttn1d(dim, n_head, slice_num)
        self.norm2 = nn.LayerNorm(dim)
        hidden_mlp = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_mlp),
            nn.GELU(),
            nn.Linear(hidden_mlp, dim),
        )
        # Kept outside `ffn` so the Sequential's indices, and therefore the
        # state_dict keys of existing checkpoints, stay unchanged.
        self.drop = nn.Dropout(dropout)
        self.to(DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and FFN residual updates; output shape equals input."""
        x = x + self.drop(self.attn(self.norm1(x)))
        x = x + self.drop(self.ffn(self.norm2(x)))
        return x
# ── Transolver1d ─────────────────────────────────────────────────────────────
class Transolver1d(nn.Module):
    """Transolver for 1-D structured grids.

    Maps an input field ``u0`` of shape [B, N] to an output field [B, N].
    Each point is featurised as (u0, x), with x evenly spaced on [0, 1].

    Parameters
    ----------
    n_modes : int
        Not used by this model; accepted only for constructor compatibility
        (NOTE(review): presumably with FNO-style models — confirm).
    hidden_dim : int
        Embedding width D (must be divisible by ``n_head``).
    n_layers : int
        Number of Transolver blocks.
    in_ch : int
        Width of the lifted input. ``forward`` always builds two channels
        (u0, grid), so this must be 2.
    n_head, slice_num, mlp_ratio
        Forwarded to every ``TransolverBlock1d``.
    """

    def __init__(self, n_modes: int = 16, hidden_dim: int = 64,
                 n_layers: int = 4, in_ch: int = 2,
                 n_head: int = 4, slice_num: int = 32,
                 mlp_ratio: float = 2.0):
        super().__init__()
        self.lift = nn.Linear(in_ch, hidden_dim)
        self.blocks = nn.ModuleList([
            TransolverBlock1d(hidden_dim, n_head, slice_num, mlp_ratio)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self.proj1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.proj2 = nn.Linear(hidden_dim // 2, 1)
        self.to(DEVICE)

    def forward(self, u0: torch.Tensor) -> torch.Tensor:
        """Predict the output field for ``u0`` [B, N]; returns [B, N]."""
        B, N = u0.shape
        # Build the coordinates on u0's device rather than the global DEVICE,
        # so the model still works after .to()/.cpu() or in DataParallel
        # replicas. Match a floating u0's dtype so half/double models get a
        # uniformly typed input.
        dtype = u0.dtype if u0.is_floating_point() else None
        grid = torch.linspace(0.0, 1.0, N, device=u0.device, dtype=dtype)
        grid = grid.unsqueeze(0).expand(B, N)
        x = torch.stack([u0, grid], dim=-1)   # [B, N, 2]
        x = self.lift(x)                      # [B, N, D]
        for blk in self.blocks:
            x = blk(x)
        x = F.gelu(self.proj1(self.norm(x)))
        return self.proj2(x)[:, :, 0]         # [B, N]
# ── Physics Attention (2-D structured grid) ──────────────────────────────────
class PhysicsAttn2d(nn.Module):
    """Physics Attention on a 2-D structured grid of shape [B, N1, N2, D].

    The grid is flattened into N1*N2 tokens and processed by a PhysicsAttn1d.
    The result is then reshaped back to the original grid layout.
    """

    def __init__(self, dim: int, n_head: int = 4, slice_num: int = 32):
        super().__init__()
        self.inner = PhysicsAttn1d(dim, n_head, slice_num)
        self.to(DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply physics attention over all grid points; output matches ``x``'s shape."""
        batch, rows, cols, width = x.shape
        tokens = x.reshape(batch, rows * cols, width)   # [B, N1*N2, D]
        mixed = self.inner(tokens)
        grid_shape = (batch, rows, cols, width)
        return mixed.reshape(grid_shape)
class TransolverBlock2d(nn.Module):
    """One 2-D Transolver layer: PhysicsAttn2d + FFN, each as a pre-norm residual.

    Operates on grids of shape [B, N1, N2, D]. LayerNorm and the FFN act on
    the last (channel) dimension.

    Parameters
    ----------
    dim : int
        Embedding dimension.
    n_head : int
        Number of attention heads (forwarded to PhysicsAttn2d).
    slice_num : int
        Number of physics slices (forwarded to PhysicsAttn2d).
    mlp_ratio : float
        FFN hidden width as a multiple of ``dim``.
    dropout : float
        Dropout probability applied to the output of each residual branch
        before it is added back. ``0.0`` (the default) disables it. This
        mirrors the ``dropout`` argument of the 1-D block.
    """

    def __init__(self, dim: int, n_head: int = 4, slice_num: int = 32,
                 mlp_ratio: float = 2.0, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = PhysicsAttn2d(dim, n_head, slice_num)
        self.norm2 = nn.LayerNorm(dim)
        hidden_mlp = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_mlp),
            nn.GELU(),
            nn.Linear(hidden_mlp, dim),
        )
        # Separate from `ffn` so existing state_dict keys are unchanged.
        self.drop = nn.Dropout(dropout)
        self.to(DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and FFN residual updates; output shape equals input."""
        x = x + self.drop(self.attn(self.norm1(x)))
        x = x + self.drop(self.ffn(self.norm2(x)))
        return x
class Transolver2d(nn.Module):
    """Transolver for 2-D structured grids (darcy_2d, ns_2d).

    Maps an input field ``u0`` of shape [B, N1, N2] to an output field of the
    same shape. Each point is featurised as (u0, x1, x2), with both
    coordinates evenly spaced on [0, 1].

    Parameters
    ----------
    n_modes1, n_modes2 : int
        Not used by this model; accepted only for constructor compatibility
        (NOTE(review): presumably with FNO-style models — confirm).
    hidden_dim : int
        Embedding width D (must be divisible by ``n_head``).
    n_layers : int
        Number of Transolver blocks.
    in_ch : int
        Width of the lifted input. ``forward`` always builds three channels
        (u0, grid1, grid2), so this must be 3.
    n_head, slice_num, mlp_ratio
        Forwarded to every ``TransolverBlock2d``.
    """

    def __init__(self, n_modes1: int = 12, n_modes2: int = 12,
                 hidden_dim: int = 64, n_layers: int = 4,
                 in_ch: int = 3, n_head: int = 4, slice_num: int = 64,
                 mlp_ratio: float = 2.0):
        super().__init__()
        self.lift = nn.Linear(in_ch, hidden_dim)
        self.blocks = nn.ModuleList([
            TransolverBlock2d(hidden_dim, n_head, slice_num, mlp_ratio)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self.proj1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.proj2 = nn.Linear(hidden_dim // 2, 1)
        self.to(DEVICE)

    def forward(self, u0: torch.Tensor) -> torch.Tensor:
        """Predict the output field for ``u0`` [B, N1, N2]; returns [B, N1, N2]."""
        B, N1, N2 = u0.shape
        # Coordinates follow u0's device (not the global DEVICE) so the model
        # works after .to()/.cpu() or in DataParallel replicas. They also
        # follow a floating u0's dtype so half/double models see one dtype.
        dtype = u0.dtype if u0.is_floating_point() else None
        coord1 = torch.linspace(0.0, 1.0, N1, device=u0.device, dtype=dtype)
        coord2 = torch.linspace(0.0, 1.0, N2, device=u0.device, dtype=dtype)
        grid1 = coord1.view(1, N1, 1).expand(B, N1, N2)
        grid2 = coord2.view(1, 1, N2).expand(B, N1, N2)
        x = torch.stack([u0, grid1, grid2], dim=-1)   # [B, N1, N2, 3]
        x = self.lift(x)                              # [B, N1, N2, D]
        for blk in self.blocks:
            x = blk(x)
        x = F.gelu(self.proj1(self.norm(x)))
        return self.proj2(x)[:, :, :, 0]              # [B, N1, N2]