| import math |
|
|
| import jax |
| from flax.core import unfreeze, freeze |
| import jax.numpy as jnp |
| import flax.linen as nn |
| from jaxtyping import Array, ArrayLike, PyTree |
|
|
| from .edsr import EDSR |
| from .rdn import RDN |
| from .hyper import Hypernetwork |
| from .tail import build_tail |
| from .init import uniform_between, linear_up |
| from utils import make_grid, interpolate_grid, repeat_vmap |
|
|
|
|
class Thermal(nn.Module):
    """Sinusoidal activation damped by a heat-diffusion decay term.

    Computes sin(w0_scale * x + phase) * exp(-(w0_scale * norm)^2 * k * t),
    where `phase` is a learnable per-feature offset.
    """
    w0_scale: float = 1.

    @nn.compact
    def __call__(self, x: ArrayLike, t, norm, k) -> Array:
        # Learnable phase offset, one per feature (uniform init with scale 0.5).
        phase = self.param('phase', nn.initializers.uniform(.5), x.shape[-1:])
        oscillation = jnp.sin(self.w0_scale * x + phase)
        # Exponential damping grows with frequency norm, diffusivity k, and time t.
        damping = jnp.exp(-(self.w0_scale * norm) ** 2 * k * t)
        return oscillation * damping
|
|
|
|
class TheraField(nn.Module):
    """A small SIREN-style neural field with a thermal activation layer.

    Coordinates are projected onto shared frequency `components`, passed
    through the time-damped sinusoidal `Thermal` layer, and linearly mapped
    to `dim_out` channels.
    """
    dim_hidden: int
    dim_out: int
    w0: float = 1.
    c: float = 6.

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
        # Project input coordinates onto the shared frequency components.
        projected = x @ components

        # Per-component frequency magnitudes control the thermal decay rate.
        freq_norms = jnp.linalg.norm(components, axis=-2)
        hidden = Thermal(self.w0)(projected, t, freq_norms, k)

        # SIREN-style uniform bound for the (bias-free) output layer weights.
        bound = math.sqrt(self.c / self.dim_hidden) / self.w0
        init_fn = uniform_between(-bound, bound)
        return nn.Dense(self.dim_out, kernel_init=init_fn, use_bias=False)(hidden)
|
|
|
|
class Thera:
    """Thera super-resolution model.

    Couples a hypernetwork (backbone encoder + tail) with a grid of small
    neural fields (``TheraField``): the hypernetwork predicts per-location
    field parameters from the source image, and each field is then queried
    at continuous spatial/temporal coordinates. Not a flax ``nn.Module``
    itself; it orchestrates ``init``/``apply`` of its submodules.
    """

    def __init__(
        self,
        hidden_dim: int,
        out_dim: int,
        backbone: nn.Module,
        tail: nn.Module,
        k_init: float | None = None,
        components_init_scale: float | None = None
    ):
        self.hidden_dim = hidden_dim
        self.k_init = k_init
        self.components_init_scale = components_init_scale

        # Shared field template; its parameters are produced per spatial
        # location by the hypernetwork rather than trained directly.
        self.field = TheraField(hidden_dim, out_dim)

        # Initialize one dummy field to discover the parameter PyTree
        # structure and leaf shapes the hypernetwork must output.
        sample_params = self.field.init(jax.random.PRNGKey(0),
            jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
        sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
        param_shapes = [p.shape for p in sample_params_flat]

        self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)

    def init(self, key, sample_source) -> PyTree:
        """Initialize all trainable parameters.

        Runs the hypernetwork's own init with dummy coordinates, then adds
        the two globally shared parameters: the diffusivity scalar ``k``
        and the (2, hidden_dim) frequency ``components`` matrix.
        """
        keys = jax.random.split(key, 2)
        sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
        params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))

        # Globally shared parameters, stored alongside the hypernet params.
        params['params']['k'] = jnp.array(self.k_init)
        params['params']['components'] = \
            linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))

        return freeze(params)

    def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
        """
        Performs a forward pass through the hypernetwork to obtain an encoding.
        """
        return self.hypernet.apply(
            params, source, method=self.hypernet.get_encoding, **kwargs)

    def apply_decoder(
        self,
        params: PyTree,
        encoding: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward prediction through a grid of HxW Thera fields,
        informed by `encoding`, at spatial and temporal coordinates
        `coords` and `t`, respectively.
        args:
            params: Field parameters, shape (B, H, W, N)
            encoding: Encoding tensor, shape (B, H, W, C)
            coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
            t: Temporal coordinates, shape (B, 1)
            return_jac: If True, also return the Jacobian of the field
                output w.r.t. the spatial coordinates (evaluated at t=0).
        """
        # Per-coordinate field parameters, interpolated from the hypernet output.
        phi_params: PyTree = self.hypernet.apply(
            params, encoding, coords, method=self.hypernet.get_params_at_coords)

        # Express each query coordinate relative to its nearest source-grid
        # position, rescaled so one source pixel spans one unit per axis.
        source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
        source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
        interp_coords = interpolate_grid(coords, source_coords)
        rel_coords = (coords - interp_coords)
        rel_coords = rel_coords.at[..., 0].set(rel_coords[..., 0] * encoding.shape[-3])
        rel_coords = rel_coords.at[..., 1].set(rel_coords[..., 1] * encoding.shape[-2])

        # Three nested vmaps (presumably batch + two spatial dims — see
        # repeat_vmap); t, k, and components are broadcast except where
        # in_axes maps them.
        in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
        apply_field = repeat_vmap(self.field.apply, in_axes)
        out = apply_field(phi_params, rel_coords, t, params['params']['k'],
            params['params']['components'])

        if return_jac:
            # Jacobian w.r.t. the relative coordinates (argnums=1), at t=0.
            apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
            jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), params['params']['k'],
                params['params']['components'])
            return out, jac

        return out

    def apply(
        self,
        params: PyTree,
        source: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False,
        **kwargs
    ) -> Array:
        """
        Performs a forward pass through the Thera model.
        """
        # Encode once, then decode at the requested coordinates.
        encoding = self.apply_encoder(params, source, **kwargs)
        out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
        return out
|
|
|
|
| def build_thera( |
| out_dim: int, |
| backbone: str, |
| size: str, |
| k_init: float = None, |
| components_init_scale: float = None |
| ): |
| """ |
| Convenience function for building the three Thera sizes described in the paper. |
| """ |
| hidden_dim = 32 if size == 'air' else 512 |
|
|
| if backbone == 'edsr-baseline': |
| backbone_module = EDSR(None, num_blocks=16, num_feats=64) |
| elif backbone == 'rdn': |
| backbone_module = RDN() |
| else: |
| raise NotImplementedError(backbone) |
|
|
| tail_module = build_tail(size) |
|
|
| return Thera(hidden_dim, out_dim, backbone_module, tail_module, k_init, components_init_scale) |
|
|