thera-mlx / upscale.py
mlmPenguin's picture
Add source code
29e0144 verified
#!/usr/bin/env python3
"""Core upscale logic for Thera MLX."""
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from model import Thera
# Directory holding the converted weight files; it sits next to this module.
WEIGHTS_DIR = Path(__file__).parent.joinpath("weights")
def load_weights(model, weights_path):
    """Load converted weights into the MLX model.

    The file format is chosen from the path's extension:
    ``.safetensors`` is read with ``safetensors.numpy.load_file`` and
    ``.npz`` with ``numpy.load``.  Every array is converted to an
    ``mx.array`` and handed to ``model.load_weights`` as a list of
    ``(name, array)`` pairs.

    Args:
        model: Object exposing ``load_weights(list_of_pairs)``.
        weights_path: Path (str or path-like) to the weights file.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ValueError: If the path ends with neither ``.safetensors`` nor
            ``.npz``.
    """
    weights_path = str(weights_path)
    if weights_path.endswith('.safetensors'):
        # Imported lazily so the dependency is only needed for this format.
        from safetensors.numpy import load_file
        raw = load_file(weights_path)
        weights = {k: mx.array(v) for k, v in raw.items()}
    elif weights_path.endswith('.npz'):
        # NpzFile keeps the archive open; the context manager closes it once
        # every array has been converted.  Iterating ``raw.files`` reads each
        # array exactly once.
        with np.load(weights_path) as raw:
            weights = {k: mx.array(raw[k]) for k in raw.files}
    else:
        raise ValueError(f"Unknown weight format: {weights_path}")
    model.load_weights(list(weights.items()))
    return model
def get_weights_path(model_size):
    """Return the default ``.safetensors`` path for a model variant.

    The file is looked up inside ``WEIGHTS_DIR`` under the name
    ``weights-<model_size>.safetensors``; existence is not checked.
    """
    filename = f"weights-{model_size}.safetensors"
    return WEIGHTS_DIR / filename
def _tile_span(index, tiles, extent, overlap):
    """Return the clamped [start, stop) source range of one tile along an axis."""
    start = round(index * extent / tiles)
    stop = round((index + 1) * extent / tiles)
    # Only interior edges are extended; the image border gets no overlap.
    if index > 0:
        start -= overlap
    if index < tiles - 1:
        stop += overlap
    return max(0, start), min(extent, stop)


def _feather_profile(length, lead, trail):
    """Return 1-D float32 blend weights with linear ramps of lead/trail samples."""
    profile = np.ones(length, dtype=np.float32)
    if lead > 0:
        ramp = np.linspace(0, 1, min(lead, length), dtype=np.float32)
        profile[:len(ramp)] = ramp
    if trail > 0:
        ramp = np.linspace(1, 0, min(trail, length), dtype=np.float32)
        if len(ramp):
            # minimum() keeps the leading ramp intact when both ramps overlap
            # on a very short tile.
            profile[-len(ramp):] = np.minimum(profile[-len(ramp):], ramp)
    return profile


def upscale_tiled(model, source_np, target_h, target_w, tiles, ensemble=False):
    """Upscale an image using NxN tiles to reduce peak RAM.

    Splits the source image into tiles with overlap, upscales each tile
    individually, then blends them back together using linear feathering.
    Progress is printed after each tile.

    Args:
        model: Loaded Thera model; must expose
            ``upscale(mx_array, height, width, ensemble=...)``.
        source_np: numpy array (H, W, 3) float32 in [0, 1].
        target_h: Target height.
        target_w: Target width.
        tiles: Number of tiles per axis (2, 3, or 4); must be >= 1.
        ensemble: Use geometric self-ensemble.

    Returns:
        numpy uint8 array (target_h, target_w, 3).

    Raises:
        ValueError: If ``tiles`` is smaller than 1.
    """
    if tiles < 1:
        raise ValueError(f"tiles must be >= 1, got {tiles}")

    h, w = source_np.shape[:2]
    scale_h = target_h / h
    scale_w = target_w / w

    # Overlap in source pixels (10% of tile size, minimum 8px)
    overlap_h = max(8, int(h / tiles * 0.1))
    overlap_w = max(8, int(w / tiles * 0.1))

    # Feather lengths in target pixels; identical for every tile.
    feather_h = round(overlap_h * scale_h)
    feather_w = round(overlap_w * scale_w)

    # Build output canvas (float32 for blending)
    output = np.zeros((target_h, target_w, 3), dtype=np.float32)
    weight_map = np.zeros((target_h, target_w, 1), dtype=np.float32)

    total_tiles = tiles * tiles
    done = 0

    for row in range(tiles):
        sy0, sy1 = _tile_span(row, tiles, h, overlap_h)
        # Round both edges and take the difference so that a tile's target
        # size always matches its target placement.  Rounding the origin and
        # the size independently can leave the last row/column uncovered.
        ty0 = round(sy0 * scale_h)
        th = round(sy1 * scale_h) - ty0

        for col in range(tiles):
            sx0, sx1 = _tile_span(col, tiles, w, overlap_w)
            tx0 = round(sx0 * scale_w)
            tw = round(sx1 * scale_w) - tx0

            # Upscale tile
            tile_src = source_np[sy0:sy1, sx0:sx1]
            result = model.upscale(mx.array(tile_src), th, tw, ensemble=ensemble)
            mx.eval(result)
            # NOTE(review): the division assumes the model returns 0..255
            # values (presumably uint8) - confirm against Thera.upscale.
            tile_out = np.array(result).astype(np.float32) / 255.0

            # Place by the shape actually returned, clamped to output bounds
            ty1 = min(ty0 + tile_out.shape[0], target_h)
            tx1 = min(tx0 + tile_out.shape[1], target_w)
            tile_out = tile_out[:ty1 - ty0, :tx1 - tx0]

            # Linear feather weight for blending overlaps; edges that lie on
            # the image border are not feathered.
            fh, fw = tile_out.shape[:2]
            wy = _feather_profile(
                fh,
                feather_h if row > 0 else 0,
                feather_h if row < tiles - 1 else 0,
            )
            wx = _feather_profile(
                fw,
                feather_w if col > 0 else 0,
                feather_w if col < tiles - 1 else 0,
            )
            w3d = (wy[:, None] * wx[None, :])[:, :, None]  # (fh, fw, 1)

            output[ty0:ty1, tx0:tx1] += tile_out * w3d
            weight_map[ty0:ty1, tx0:tx1] += w3d

            done += 1
            print(f" tile {done}/{total_tiles}")

    # Normalize by accumulated weight (floored to avoid division by zero)
    weight_map = np.maximum(weight_map, 1e-8)
    output = (output / weight_map * 255 + 0.5).clip(0, 255).astype(np.uint8)
    return output
def upscale_file(input_path, output_path, scale=None, size=None,
                 model_size='air', weights_path=None, ensemble=False,
                 tiles=None):
    """Upscale a single image file and write the result to ``output_path``.

    Args:
        input_path: Image file to read; it is converted to RGB.
        output_path: Destination passed to ``PIL.Image.save``.
        scale: Scale factor applied to both axes. Takes precedence over
            ``size`` when both are given.
        size: ``(target_h, target_w)`` pair, used when ``scale`` is None.
        model_size: Model variant name passed to ``Thera`` and used to
            resolve the default weights file.
        weights_path: Explicit weights file; defaults to
            ``get_weights_path(model_size)``.
        ensemble: Forwarded to the model's ``upscale`` call.
        tiles: Tiles per axis; values greater than 1 select tiled upscaling.

    Raises:
        ValueError: If neither ``scale`` nor ``size`` is given.
    """
    # The context manager releases the source file handle as soon as the
    # pixel data has been copied into the numpy array.
    with Image.open(input_path) as img:
        source = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0
    h, w = source.shape[:2]

    if scale is not None:
        target_h = round(h * scale)
        target_w = round(w * scale)
    elif size is not None:
        target_h, target_w = size
    else:
        raise ValueError("Must specify either scale or size")

    # Reported factor is taken from the vertical axis only.
    scale_actual = target_h / h

    if weights_path is None:
        weights_path = get_weights_path(model_size)

    model = Thera(size=model_size)
    model = load_weights(model, weights_path)
    # Force parameter evaluation before timing starts.
    mx.eval(model.parameters())

    use_tiles = bool(tiles and tiles > 1)

    t0 = time.perf_counter()
    if use_tiles:
        print(f"Tiled upscale: {tiles}x{tiles} ({tiles*tiles} tiles)")
        result_np = upscale_tiled(model, source, target_h, target_w,
                                  tiles, ensemble=ensemble)
    else:
        result = model.upscale(mx.array(source), target_h, target_w, ensemble=ensemble)
        mx.eval(result)
        result_np = np.array(result)
    # Saving stays inside the timed region so the reported time still
    # includes writing the file.
    Image.fromarray(result_np).save(output_path)
    elapsed = time.perf_counter() - t0

    suffix = " (ensemble)" if ensemble else ""
    tile_info = f" [{tiles}x{tiles} tiles]" if use_tiles else ""
    print(f"[{model_size}]{suffix}{tile_info} {w}x{h} -> {target_w}x{target_h} ({scale_actual:.4g}x) {elapsed:.1f}s -> {output_path}")