| """MLX implementation of Thera super-resolution models (air/pro variants).""" |
|
|
| import math |
| import mlx.core as mx |
| import mlx.nn as nn |
| import numpy as np |
|
|
|
|
| |
|
|
def make_grid(h, w):
    """Return an (h, w, 2) float32 array of pixel-centre coordinates.

    Both axes span [-0.5, 0.5]; the centre of pixel ``i`` along an axis of
    length ``n`` sits at ``-0.5 + (i + 0.5) / n``. Channel 0 holds the row
    (y) coordinate and channel 1 the column (x) coordinate.
    """
    def _centres(n):
        # Inset by half a pixel so the first/last samples land on pixel centres.
        half_px = 1.0 / (2 * n)
        return np.linspace(-0.5 + half_px, 0.5 - half_px, n, dtype=np.float32)

    rows, cols = np.meshgrid(_centres(h), _centres(w), indexing='ij')
    return np.stack([rows, cols], axis=-1)
|
|
|
|
def interpolate_nearest(coords, grid):
    """
    Sample ``grid`` at ``coords`` with nearest-neighbour lookup.

    Coordinates use the convention of ``make_grid``: [-0.5, 0.5] with the
    centre of pixel ``i`` at ``-0.5 + (i + 0.5) / size``. Each coordinate is
    mapped to a fractional pixel index, rounded, clamped to the grid bounds,
    and the corresponding rows are gathered from a flattened view of ``grid``.

    Args:
        coords: mx.array (B, H, W, 2) coordinates in [-0.5, 0.5]
        grid: mx.array (B, H', W', C) grid to sample from
    Returns:
        mx.array (B, H, W, C)
    """
    batch, src_h, src_w, channels = grid.shape
    _, out_h, out_w, _ = coords.shape

    def _to_index(c, size):
        # c * size + (size - 1) / 2 sends pixel-centre coordinates to integers.
        pos = c * size + (size - 1) / 2.0
        return mx.clip(mx.round(pos).astype(mx.int32), 0, size - 1)

    rows = _to_index(coords[..., 0], src_h)
    cols = _to_index(coords[..., 1], src_w)

    # Shift each batch item into its own slice of the flattened
    # (B * H' * W', C) lookup table.
    image_offsets = mx.arange(batch).reshape(batch, 1, 1) * (src_h * src_w)
    gather_idx = (rows * src_w + cols + image_offsets).reshape(-1)

    table = grid.reshape(-1, channels)
    return table[gather_idx].reshape(batch, out_h, out_w, channels)
|
|
|
|
| |
|
|
class RDBConv(nn.Module):
    """One densely connected layer of a Residual Dense Block.

    Applies conv + ReLU and appends the ``growth_rate`` new feature maps to
    the input along the last (channel) axis, so the output has
    ``in_channels + growth_rate`` channels.
    """

    def __init__(self, in_channels: int, growth_rate: int, kernel_size: int = 3):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size, padding=pad)

    def __call__(self, x):
        new_features = nn.relu(self.conv(x))
        return mx.concatenate([x, new_features], axis=-1)
|
|
|
|
class RDB(nn.Module):
    """Residual Dense Block.

    ``n_conv_layers`` densely connected ``RDBConv`` layers grow the channel
    count from ``g0`` by ``growth_rate`` per layer. A 1x1 "local fusion" conv
    then projects back to ``g0`` channels, and the block input is added as a
    residual.
    """

    def __init__(self, g0: int, growth_rate: int, n_conv_layers: int):
        super().__init__()
        self.convs = [
            RDBConv(g0 + growth_rate * layer, growth_rate)
            for layer in range(n_conv_layers)
        ]
        fused_in = g0 + growth_rate * n_conv_layers
        self.local_fusion = nn.Conv2d(fused_in, g0, kernel_size=1)

    def __call__(self, x):
        h = x
        for layer in self.convs:
            h = layer(h)
        return self.local_fusion(h) + x
|
|
|
|
class RDN(nn.Module):
    """Residual Dense Network backbone (config B).

    Config B uses 16 residual dense blocks, each with 8 conv layers and a
    growth rate of 64. The output of every block is concatenated, fused by a
    1x1 then a 3x3 conv, and added to the first shallow-feature map.
    """

    def __init__(self, n_colors: int = 3, g0: int = 64):
        super().__init__()
        num_blocks, convs_per_block, growth = 16, 8, 64

        # Shallow feature extraction.
        self.sfe1 = nn.Conv2d(n_colors, g0, kernel_size=3, padding=1)
        self.sfe2 = nn.Conv2d(g0, g0, kernel_size=3, padding=1)
        self.rdbs = [RDB(g0, growth, convs_per_block) for _ in range(num_blocks)]
        # Global feature fusion over all block outputs.
        self.gff_1x1 = nn.Conv2d(num_blocks * g0, g0, kernel_size=1)
        self.gff_3x3 = nn.Conv2d(g0, g0, kernel_size=3, padding=1)

    def __call__(self, x):
        shallow = self.sfe1(x)
        h = self.sfe2(shallow)

        block_outputs = []
        for block in self.rdbs:
            h = block(h)
            block_outputs.append(h)

        fused = self.gff_3x3(self.gff_1x1(mx.concatenate(block_outputs, axis=-1)))
        return fused + shallow
|
|
|
|
| |
|
|
class Thera(nn.Module):
    """
    Thera: arbitrary-scale super-resolution using neural heat fields.

    Stages:
    1. Encoder (RDN backbone) produces features at source resolution
    2. Optional refinement tail (identity for air, SwinIR for pro)
    3. Hypernetwork (1x1 conv) predicts per-pixel field parameters
    4. Heat field decoder produces RGB residuals

    Args:
        size: model variant. ``'air'`` uses a 32-unit field; any other value
            uses a 512-unit field, and ``'pro'`` additionally enables the
            SwinIR refinement tail.
    """
    OUT_DIM = 3
    W0 = 1.0
    MEAN = np.array([0.4488, 0.4371, 0.4040], dtype=np.float32)
    VAR = np.array([0.25, 0.25, 0.25], dtype=np.float32)

    def __init__(self, size='air'):
        super().__init__()
        self.size = size
        self.hidden_dim = 32 if size == 'air' else 512

        # Per-pixel field parameters: a (hidden_dim x OUT_DIM) output kernel
        # followed by one phase per hidden unit (see ``decode``).
        n_field_params = self.hidden_dim * self.OUT_DIM + self.hidden_dim

        self.encoder = RDN(n_colors=3, g0=64)

        if size == 'pro':
            from swin_ir import SwinIRTail
            self.refine = SwinIRTail(
                in_channels=64, embed_dim=180,
                depths=(7, 6), num_heads=(6, 6),
                window_size=8, mlp_ratio=2.0, num_feat=64)

        # Hypernetwork: per-pixel 1x1 projection to field parameters.
        self.out_conv = nn.Conv2d(64, n_field_params, kernel_size=1)

        # Globally shared field parameters; the zero initialisation is
        # presumably overwritten by loaded weights.
        self.k = mx.array(0.0)
        self.components = mx.zeros((2, self.hidden_dim))

    def encode(self, source_norm):
        """Run encoder + optional refinement tail."""
        x = self.encoder(source_norm)
        if self.size == 'pro':
            x = self.refine(x)
        return x

    def decode(self, encoding, target_coords, t):
        """Predict RGB residuals at target coordinates.

        Args:
            encoding: (B, Hs, Ws, 64) features from ``encode``.
            target_coords: (B, H, W, 2) query coordinates in [-0.5, 0.5].
            t: per-sample heat-field time; shape (B, 1), or anything that
                reshapes to (B,).
        Returns:
            (B, H, W, OUT_DIM) residuals in normalised colour space.
        """
        # The hypernetwork is a 1x1 conv, i.e. a per-pixel affine map, so it
        # commutes with nearest-neighbour gathering. Applying it at source
        # resolution yields the same per-pixel parameters with scale**2 less
        # work than applying it after upsampling.
        phi = interpolate_nearest(target_coords, self.out_conv(encoding))

        hd = self.hidden_dim
        kernel = phi[..., :hd * self.OUT_DIM].reshape(
            *phi.shape[:-1], hd, self.OUT_DIM)
        phase = phi[..., hd * self.OUT_DIM:]

        # Offset of each query from its nearest source-pixel centre,
        # expressed in units of source pixels.
        Hs, Ws = encoding.shape[1], encoding.shape[2]
        source_grid = mx.array(make_grid(Hs, Ws))
        source_coords = mx.broadcast_to(
            source_grid[None], (encoding.shape[0],) + source_grid.shape)
        nearest_src = interpolate_nearest(target_coords, source_coords)
        rel_coords = target_coords - nearest_src
        rel_coords_scaled = mx.concatenate([
            rel_coords[..., 0:1] * Hs,
            rel_coords[..., 1:2] * Ws,
        ], axis=-1)

        # Thermal activations: sin(W0 * x + phase), damped per hidden unit by
        # exp(-(W0 * |component|)^2 * k * t).
        x = rel_coords_scaled @ self.components
        norm = mx.linalg.norm(self.components, axis=0)
        t_4d = t[:, :, None, None] if t.ndim == 2 else t.reshape(-1, 1, 1, 1)
        decay = mx.exp(-((self.W0 * norm) ** 2) * self.k * t_4d)
        x = mx.sin(self.W0 * x + phase) * decay

        # Per-pixel projection hidden -> OUT_DIM as a batched matmul:
        # (..., 1, hd) @ (..., hd, OUT_DIM). This avoids materialising the
        # (..., hd, OUT_DIM) broadcast product.
        out = mx.squeeze(x[..., None, :] @ kernel, axis=-2)
        return out

    def upscale(self, source, target_h, target_w, ensemble=False):
        """Super-resolve ``source`` to ``(target_h, target_w)``.

        Args:
            source: (H, W, 3) image; presumably float RGB in [0, 1] (confirm
                with callers).
            target_h: output height in pixels.
            target_w: output width in pixels.
            ensemble: if True, average predictions over the four 90-degree
                rotations of the input.
        Returns:
            mx.array of dtype uint8 and shape (target_h, target_w, 3).
        """
        mean = mx.array(self.MEAN)
        var = mx.array(self.VAR)
        std = mx.sqrt(var)

        if not ensemble:
            out = self._upscale_single(source, target_h, target_w, mean, var, std)
            # Quantise with round(), matching the ensemble branch below.
            return mx.round(mx.clip(out, 0.0, 1.0) * 255).astype(mx.uint8)

        source_np = np.array(source)
        outs = []
        for k_rot in range(4):
            # np.rot90 returns a strided view; hand MLX a contiguous buffer.
            src = mx.array(np.ascontiguousarray(np.rot90(source_np, k=k_rot)))
            # Odd quarter-turns swap the spatial axes of the target.
            th, tw = (target_w, target_h) if k_rot % 2 else (target_h, target_w)
            out = self._upscale_single(src, th, tw, mean, var, std)
            mx.eval(out)
            outs.append(np.rot90(np.array(out), k=-k_rot))
        result = np.stack(outs).mean(0).clip(0.0, 1.0)
        return mx.array((result * 255).round().astype(np.uint8))

    def _upscale_single(self, source, target_h, target_w, mean, var, std):
        """Upscale one (H, W, 3) image without clipping or quantisation.

        ``var`` is not used here (``std`` is passed separately); it is kept
        for signature compatibility.
        """
        Hs = source.shape[0]
        # Heat-field time from the vertical scale factor.
        t = mx.array([(target_h / Hs) ** -2], dtype=mx.float32)[None]

        # One target grid serves both the nearest-neighbour base image and
        # the decoder queries.
        coords = mx.array(make_grid(target_h, target_w))[None]
        source_4d = source[None]
        source_up = interpolate_nearest(coords, source_4d)

        encoding = self.encode((source_4d - mean) / std)
        residual = self.decode(encoding, coords, t)

        # Map the residual out of normalised space and add the upsampled base.
        out = residual * std + mean + source_up
        return out[0]
|
|
|
|
| |
# Alias of Thera. Thera defaults to size='air', so TheraRDNAir() builds the
# air variant.
TheraRDNAir = Thera
|
|