# thera-mlx / model.py
# Commit 29e0144 by mlmPenguin: "Add source code" (verified)
"""MLX implementation of Thera super-resolution models (air/pro variants)."""
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
# --- Utility functions ---
def make_grid(h, w):
    """Build a pixel-center coordinate grid over the square [-0.5, 0.5]^2.

    Every sample sits at the center of its pixel, so along an axis of
    length ``n`` the first and last samples are half a pixel
    (``1 / (2 * n)``) inside the borders.

    Args:
        h: Number of rows.
        w: Number of columns.

    Returns:
        np.ndarray of shape (h, w, 2), dtype float32. ``[..., 0]`` holds the
        row (y) coordinate and ``[..., 1]`` the column (x) coordinate.
    """
    def _pixel_centers(n):
        half_pixel = 1.0 / (2 * n)
        return np.linspace(-0.5 + half_pixel, 0.5 - half_pixel, n,
                           dtype=np.float32)

    rows, cols = np.meshgrid(_pixel_centers(h), _pixel_centers(w),
                             indexing='ij')
    return np.stack([rows, cols], axis=-1)
def interpolate_nearest(coords, grid):
    """Return, for each query coordinate, the value of the nearest grid cell.

    Args:
        coords: mx.array (B, H, W, 2) of (y, x) query points in [-0.5, 0.5],
            using the pixel-center convention of ``make_grid``.
        grid: mx.array (B, H', W', C) values to sample from.

    Returns:
        mx.array (B, H, W, C). Cell indices are rounded to the nearest integer
        and clamped to the grid bounds, so out-of-range queries take the edge
        value.
    """
    n_batch, src_h, src_w, n_ch = grid.shape
    _, out_h, out_w, _ = coords.shape

    def _to_index(c, size):
        # Invert the pixel-center mapping: center of cell i -> position i.
        pos = c * size + (size - 1) / 2.0
        return mx.clip(mx.round(pos).astype(mx.int32), 0, size - 1)

    rows = _to_index(coords[..., 0], src_h)
    cols = _to_index(coords[..., 1], src_w)

    # Gather rows of a flattened (B*H'*W', C) table; batch b starts at b*H'*W'.
    batch_start = mx.arange(n_batch).reshape(n_batch, 1, 1) * (src_h * src_w)
    lookup = (rows * src_w + cols + batch_start).reshape(-1)
    table = grid.reshape(-1, n_ch)
    picked = mx.take(table, lookup, axis=0)
    return picked.reshape(n_batch, out_h, out_w, n_ch)
# --- RDN Backbone ---
class RDBConv(nn.Module):
    """One densely connected layer of a Residual Dense Block.

    Applies conv + ReLU and appends the ``growth_rate`` new feature channels
    to the input along the last (channel) axis.
    """

    def __init__(self, in_channels: int, growth_rate: int, kernel_size: int = 3):
        super().__init__()
        # Symmetric padding that keeps H and W unchanged for odd kernel sizes.
        same_pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size,
                              padding=same_pad)

    def __call__(self, x):
        grown = nn.relu(self.conv(x))
        return mx.concatenate([x, grown], axis=-1)
class RDB(nn.Module):
    """Residual Dense Block: dense conv stack, 1x1 local fusion, skip add.

    Args:
        g0: Channels entering and leaving the block.
        growth_rate: Channels each inner ``RDBConv`` appends.
        n_conv_layers: Number of inner ``RDBConv`` layers.
    """

    def __init__(self, g0: int, growth_rate: int, n_conv_layers: int):
        super().__init__()
        # Layer `depth` sees the block input plus all earlier layers' outputs.
        widths = [g0 + depth * growth_rate for depth in range(n_conv_layers)]
        self.convs = [RDBConv(width, growth_rate) for width in widths]
        # Squeeze the fully concatenated features back to g0 channels.
        fused_in = g0 + n_conv_layers * growth_rate
        self.local_fusion = nn.Conv2d(fused_in, g0, kernel_size=1)

    def __call__(self, x):
        features = x
        for layer in self.convs:
            features = layer(features)
        return self.local_fusion(features) + x
class RDN(nn.Module):
    """Residual Dense Network feature extractor in configuration "B".

    Uses 16 residual dense blocks of 8 conv layers each, with growth rate 64.
    Every conv preserves spatial size, so the output has ``g0`` channels at
    the input's resolution.
    """

    def __init__(self, n_colors: int = 3, g0: int = 64):
        super().__init__()
        n_blocks, n_layers, growth = 16, 8, 64  # config B
        # Shallow feature extraction.
        self.sfe1 = nn.Conv2d(n_colors, g0, kernel_size=3, padding=1)
        self.sfe2 = nn.Conv2d(g0, g0, kernel_size=3, padding=1)
        self.rdbs = [RDB(g0, growth, n_layers) for _ in range(n_blocks)]
        # Global feature fusion over the concatenation of every block output.
        self.gff_1x1 = nn.Conv2d(n_blocks * g0, g0, kernel_size=1)
        self.gff_3x3 = nn.Conv2d(g0, g0, kernel_size=3, padding=1)

    def __call__(self, x):
        shallow = self.sfe1(x)
        hidden = self.sfe2(shallow)
        per_block = []
        for block in self.rdbs:
            hidden = block(hidden)
            per_block.append(hidden)
        fused = self.gff_3x3(self.gff_1x1(mx.concatenate(per_block, axis=-1)))
        # Global residual connection back to the first shallow features.
        return fused + shallow
# --- Thera Model ---
class Thera(nn.Module):
    """
    Thera: arbitrary-scale super-resolution using neural heat fields.

    Stages:
    1. Encoder (RDN backbone) produces features at source resolution
    2. Optional refinement tail (identity for air, SwinIR for pro)
    3. Hypernetwork (1x1 conv) predicts per-pixel field parameters
    4. Heat field decoder produces RGB residuals

    Args:
        size: ``'air'`` (hidden_dim 32, no refinement tail) or ``'pro'``
            (hidden_dim 512, SwinIR tail).

    Raises:
        ValueError: if ``size`` is not ``'air'`` or ``'pro'``.
    """
    OUT_DIM = 3
    W0 = 1.0
    # Per-channel statistics used to normalize the source image.
    MEAN = np.array([0.4488, 0.4371, 0.4040], dtype=np.float32)
    VAR = np.array([0.25, 0.25, 0.25], dtype=np.float32)

    def __init__(self, size='air'):
        super().__init__()
        if size not in ('air', 'pro'):
            # Any other string used to silently build a 512-dim model with no
            # tail, which matches neither supported variant.
            raise ValueError(
                f"Unsupported Thera size {size!r}; expected 'air' or 'pro'")
        self.size = size
        self.hidden_dim = 32 if size == 'air' else 512
        # Field params: Dense kernel + Thermal phase (alphabetical order)
        n_field_params = self.hidden_dim * self.OUT_DIM + self.hidden_dim
        self.encoder = RDN(n_colors=3, g0=64)
        # Refinement tail
        if size == 'pro':
            from swin_ir import SwinIRTail
            self.refine = SwinIRTail(
                in_channels=64, embed_dim=180,
                depths=(7, 6), num_heads=(6, 6),
                window_size=8, mlp_ratio=2.0, num_feat=64)
        # For 'air', no refine module (identity)
        self.out_conv = nn.Conv2d(64, n_field_params, kernel_size=1)
        # k scales the thermal decay; components maps 2-D relative coords to
        # hidden_dim frequencies (both used in decode). They are zero here and
        # presumably overwritten when pretrained weights are loaded.
        self.k = mx.array(0.0)
        self.components = mx.zeros((2, self.hidden_dim))

    def encode(self, source_norm):
        """Run encoder + optional refinement tail.

        Args:
            source_norm: (B, H, W, 3) normalized source image.

        Returns:
            Feature map at source resolution, fed to ``decode``.
        """
        x = self.encoder(source_norm)
        if self.size == 'pro':
            x = self.refine(x)
        return x

    def decode(self, encoding, target_coords, t):
        """Predict RGB residuals at target coordinates.

        Every output pixel depends only on its own coordinate, via
        nearest-neighbour lookups and per-pixel math.

        Args:
            encoding: (B, Hs, Ws, 64) features from ``encode``.
            target_coords: (B, H, W, 2) query points in [-0.5, 0.5]
                (pixel-center convention of ``make_grid``).
            t: Heat-field time. Either a 2-D array such as (B, 1), or anything
                reshapeable to (-1, 1, 1, 1).

        Returns:
            (B, H, W, OUT_DIM) residuals in normalized units.
        """
        # Per-pixel field parameters from the nearest source feature vector.
        sampled = interpolate_nearest(target_coords, encoding)
        phi = self.out_conv(sampled)
        hd = self.hidden_dim
        kernel = phi[..., :hd * self.OUT_DIM].reshape(
            *phi.shape[:-1], hd, self.OUT_DIM)
        phase = phi[..., hd * self.OUT_DIM:]

        # Offset of each query from its nearest source pixel center,
        # expressed in source-pixel units.
        Hs, Ws = encoding.shape[1], encoding.shape[2]
        source_grid = mx.array(make_grid(Hs, Ws))
        source_coords = mx.broadcast_to(
            source_grid[None], (encoding.shape[0],) + source_grid.shape)
        nearest_src = interpolate_nearest(target_coords, source_coords)
        rel_coords = target_coords - nearest_src
        rel_coords_scaled = mx.concatenate([
            rel_coords[..., 0:1] * Hs,
            rel_coords[..., 1:2] * Ws,
        ], axis=-1)

        # Thermal activation: sinusoids damped by exp(-(W0*|w|)^2 * k * t).
        x = rel_coords_scaled @ self.components
        norm = mx.linalg.norm(self.components, axis=0)
        t_4d = t[:, :, None, None] if t.ndim == 2 else t.reshape(-1, 1, 1, 1)
        decay = mx.exp(-((self.W0 * norm) ** 2) * self.k * t_4d)
        x = mx.sin(self.W0 * x + phase) * decay
        # Per-pixel linear map hidden_dim -> OUT_DIM with predicted weights.
        out = mx.sum(x[..., None] * kernel, axis=-2)
        return out

    def upscale(self, source, target_h, target_w, ensemble=False,
                patch_size=None):
        """Super-resolve one image to ``(target_h, target_w)``.

        Args:
            source: (H, W, 3) image. Float values are presumably in [0, 1]
                (the result is clipped to that range); TODO confirm.
            target_h: Output height in pixels.
            target_w: Output width in pixels.
            ensemble: If True, average predictions over the four 90-degree
                rotations of the input.
            patch_size: If set, decode the target grid in square tiles of at
                most this many pixels per side to bound peak memory. None
                decodes the whole grid at once.

        Returns:
            mx.array (target_h, target_w, 3), dtype uint8.
        """
        mean = mx.array(self.MEAN)
        var = mx.array(self.VAR)
        std = mx.sqrt(var)
        if ensemble:
            source_np = np.array(source)  # loop-invariant host copy
            outs = []
            for k_rot in range(4):
                src = mx.array(np.rot90(source_np, k=k_rot))
                # Odd quarter-turns swap height and width.
                th = target_w if k_rot % 2 else target_h
                tw = target_h if k_rot % 2 else target_w
                out = self._upscale_single(src, th, tw, mean, var, std,
                                           patch_size=patch_size)
                mx.eval(out)
                # Undo the rotation so the four predictions line up.
                outs.append(np.rot90(np.array(out), k=-k_rot))
            result = np.stack(outs).mean(0).clip(0.0, 1.0)
            return mx.array((result * 255).round().astype(np.uint8))
        out = self._upscale_single(source, target_h, target_w, mean, var, std,
                                   patch_size=patch_size)
        out = mx.clip(out, 0.0, 1.0)
        # Round to nearest, the same quantization as the ensemble branch.
        return mx.round(out * 255).astype(mx.uint8)

    def _upscale_single(self, source, target_h, target_w, mean, var, std,
                        patch_size=None):
        """Encode and decode once; return a float (target_h, target_w, 3) image.

        ``var`` is accepted for call compatibility but unused here, because
        ``std`` already carries ``sqrt(var)``.
        """
        Hs = source.shape[0]
        # Heat-field time from the vertical scale factor. Larger upscaling
        # gives a smaller t, and so weaker damping in decode.
        t = mx.array([(target_h / Hs) ** -2], dtype=mx.float32)[None]
        # Output pixel centers. They feed both the nearest-neighbour base
        # image and the decoder queries.
        coords = mx.array(make_grid(target_h, target_w))[None]
        source_4d = source[None]
        source_up = interpolate_nearest(coords, source_4d)
        source_norm = (source_4d - mean) / std
        encoding = self.encode(source_norm)
        residual = self._decode_tiled(encoding, coords, t, patch_size)
        out = residual * std + mean + source_up
        return out[0]

    def _decode_tiled(self, encoding, coords, t, patch_size):
        """Run ``decode`` on the whole grid, or tile by tile if ``patch_size``.

        ``decode`` treats every target pixel independently, so stitching
        tiles yields the same image while bounding intermediate sizes.

        Raises:
            ValueError: if ``patch_size`` is given but smaller than 1.
        """
        if patch_size is None:
            return self.decode(encoding, coords, t)
        if patch_size < 1:
            raise ValueError(
                f"patch_size must be a positive int, got {patch_size!r}")
        H, W = coords.shape[1], coords.shape[2]
        rows = []
        for h0 in range(0, H, patch_size):
            tiles = []
            for w0 in range(0, W, patch_size):
                tile = self.decode(
                    encoding,
                    coords[:, h0:h0 + patch_size, w0:w0 + patch_size],
                    t)
                # MLX is lazy: evaluate each tile now so its intermediates can
                # be released before the next tile is built.
                mx.eval(tile)
                tiles.append(tile)
            rows.append(mx.concatenate(tiles, axis=2))
        return mx.concatenate(rows, axis=1)
# Backwards compatibility alias for code that still refers to the model by its
# older name. Thera's `size` defaults to 'air', so TheraRDNAir() still builds
# the air variant.
TheraRDNAir = Thera