| |
| """Core upscale logic for Thera MLX.""" |
|
|
| import time |
| from pathlib import Path |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
| import numpy as np |
| from PIL import Image |
|
|
| from model import Thera |
|
|
# Directory named "weights" located next to this module; default weight files are resolved from it.
WEIGHTS_DIR = Path(__file__).parent.joinpath("weights")
|
|
|
|
def load_weights(model, weights_path):
    """Load converted weights into the MLX model.

    Args:
        model: Object exposing ``load_weights(list_of_(name, array)_pairs)``.
        weights_path: String or path-like pointing at a ``.safetensors`` or
            ``.npz`` file. The format is chosen from the file extension.

    Returns:
        The same ``model`` object, after ``model.load_weights`` was called.

    Raises:
        ValueError: If the extension is neither ``.safetensors`` nor ``.npz``.
    """
    weights_path = str(weights_path)
    if weights_path.endswith('.safetensors'):
        # Imported lazily so the package is only required for this format.
        from safetensors.numpy import load_file
        raw = load_file(weights_path)
        weights = {k: mx.array(v) for k, v in raw.items()}
    elif weights_path.endswith('.npz'):
        # NpzFile reads members on access and keeps the archive open, so
        # read every array exactly once and close the file deterministically.
        with np.load(weights_path) as raw:
            weights = {k: mx.array(raw[k]) for k in raw.files}
    else:
        raise ValueError(f"Unknown weight format: {weights_path}")

    model.load_weights(list(weights.items()))
    return model
|
|
|
|
def get_weights_path(model_size):
    """Return the default safetensors file location for a model variant.

    The file is expected inside ``WEIGHTS_DIR`` and is named
    ``weights-<model_size>.safetensors``. No existence check is performed.
    """
    filename = f"weights-{model_size}.safetensors"
    return WEIGHTS_DIR / filename
|
|
|
|
def _tile_span(index, tiles, extent, overlap):
    """Return the clamped ``(start, stop)`` source range of one tile along an axis."""
    start = round(index * extent / tiles)
    stop = round((index + 1) * extent / tiles)
    # Only interior edges are extended; the outer image border has no neighbor.
    if index > 0:
        start -= overlap
    if index < tiles - 1:
        stop += overlap
    return max(0, start), min(extent, stop)


def _to_target(coord, source_extent, target_extent):
    """Map a source-pixel edge coordinate onto the target pixel grid."""
    return round(coord * target_extent / source_extent)


def _feather_weights(length, lead, trail):
    """Return 1-D float32 blend weights with linear ramps at the two ends.

    ``lead`` / ``trail`` are the ramp lengths (in samples) at the start and
    end; a value of 0 leaves that end at full weight.
    """
    weights = np.ones(length, dtype=np.float32)
    if lead > 0:
        ramp = np.linspace(0, 1, min(lead, length), dtype=np.float32)
        weights[:len(ramp)] = ramp
    if trail > 0:
        ramp = np.linspace(1, 0, min(trail, length), dtype=np.float32)
        if len(ramp):
            # minimum() keeps the leading ramp where both ramps overlap.
            weights[-len(ramp):] = np.minimum(weights[-len(ramp):], ramp)
    return weights


def upscale_tiled(model, source_np, target_h, target_w, tiles, ensemble=False):
    """Upscale an image using NxN tiles to reduce peak RAM.

    Splits the source image into tiles with overlap, upscales each tile
    individually, then blends them back together using linear feathering.
    Accumulated colors are divided by the accumulated weights at the end.

    Args:
        model: Loaded Thera model; must provide
            ``upscale(array, height, width, ensemble=...)``. The result is
            divided by 255 here, so it is assumed to be in the 0-255 range
            (TODO confirm against the model implementation).
        source_np: numpy array (H, W, 3) float32 in [0, 1].
        target_h: Target height.
        target_w: Target width.
        tiles: Number of tiles per axis (positive integer, e.g. 2, 3, or 4).
        ensemble: Use geometric self-ensemble.

    Returns:
        numpy uint8 array (target_h, target_w, 3).

    Raises:
        ValueError: If ``tiles`` is smaller than 1 or the source image is empty.
    """
    if tiles < 1:
        raise ValueError(f"tiles must be >= 1, got {tiles}")
    h, w = source_np.shape[:2]
    if h == 0 or w == 0:
        raise ValueError("source image is empty")
    scale_h = target_h / h
    scale_w = target_w / w

    # Overlap in source pixels: 10% of the nominal tile size, at least 8.
    overlap_h = max(8, int(h / tiles * 0.1))
    overlap_w = max(8, int(w / tiles * 0.1))
    # Feather widths in target pixels are the same for every tile.
    feather_h = round(overlap_h * scale_h)
    feather_w = round(overlap_w * scale_w)

    output = np.zeros((target_h, target_w, 3), dtype=np.float32)
    weight_map = np.zeros((target_h, target_w, 1), dtype=np.float32)

    total_tiles = tiles * tiles
    done = 0

    for row in range(tiles):
        sy0, sy1 = _tile_span(row, tiles, h, overlap_h)
        # Round both tile edges on the target grid and take the difference.
        # Rounding the start and the height independently can leave the last
        # tile one pixel short of the target border (an uncovered, black line).
        ty0 = _to_target(sy0, h, target_h)
        th = _to_target(sy1, h, target_h) - ty0

        for col in range(tiles):
            sx0, sx1 = _tile_span(col, tiles, w, overlap_w)
            tx0 = _to_target(sx0, w, target_w)
            tw = _to_target(sx1, w, target_w) - tx0

            tile_src = source_np[sy0:sy1, sx0:sx1]
            result = model.upscale(mx.array(tile_src), th, tw, ensemble=ensemble)
            # Force evaluation now so only one tile's graph is alive at a time.
            mx.eval(result)
            tile_out = np.array(result).astype(np.float32) / 255.0

            # Trust the actual output shape, but never write past the canvas.
            ty1 = min(ty0 + tile_out.shape[0], target_h)
            tx1 = min(tx0 + tile_out.shape[1], target_w)
            tile_out = tile_out[:ty1 - ty0, :tx1 - tx0]

            fh, fw = tile_out.shape[:2]
            wy = _feather_weights(
                fh,
                feather_h if row > 0 else 0,
                feather_h if row < tiles - 1 else 0,
            )
            wx = _feather_weights(
                fw,
                feather_w if col > 0 else 0,
                feather_w if col < tiles - 1 else 0,
            )
            w3d = (wy[:, None] * wx[None, :])[:, :, None]

            output[ty0:ty1, tx0:tx1] += tile_out * w3d
            weight_map[ty0:ty1, tx0:tx1] += w3d

            done += 1
            print(f" tile {done}/{total_tiles}")

    # Guard against division by zero where no tile contributed any weight.
    weight_map = np.maximum(weight_map, 1e-8)
    output = (output / weight_map * 255 + 0.5).clip(0, 255).astype(np.uint8)
    return output
|
|
|
|
def upscale_file(input_path, output_path, scale=None, size=None,
                 model_size='air', weights_path=None, ensemble=False,
                 tiles=None):
    """Upscale a single image file and write the result to ``output_path``.

    Args:
        input_path: Image file to read; it is converted to RGB.
        output_path: Destination file; the format is chosen by PIL.
        scale: Scale factor applied to both axes. Takes precedence over
            ``size`` when both are given.
        size: ``(target_height, target_width)`` pair, used when ``scale`` is None.
        model_size: Model variant name passed to ``Thera`` and used to
            resolve the default weights file.
        weights_path: Explicit weights file; defaults to
            ``get_weights_path(model_size)``.
        ensemble: Use geometric self-ensemble.
        tiles: Tiles per axis; values greater than 1 enable tiled upscaling.

    Raises:
        ValueError: If neither ``scale`` nor ``size`` is given, or the
            resulting target size is not positive.
    """
    # The context manager releases the file handle; convert() yields an
    # independent, fully loaded image, so the array stays valid afterwards.
    with Image.open(input_path) as img:
        source = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0
    h, w = source.shape[:2]

    if scale is not None:
        target_h = round(h * scale)
        target_w = round(w * scale)
    elif size is not None:
        target_h, target_w = size
    else:
        raise ValueError("Must specify either scale or size")
    # Fail fast on unusable sizes before the (expensive) model load.
    if target_h < 1 or target_w < 1:
        raise ValueError(
            f"Target size must be positive, got {target_w}x{target_h}")

    scale_actual = target_h / h
    if weights_path is None:
        weights_path = get_weights_path(model_size)

    model = Thera(size=model_size)
    model = load_weights(model, weights_path)
    # Materialize the parameters so loading is not counted in the timing below.
    mx.eval(model.parameters())

    use_tiles = bool(tiles) and tiles > 1
    t0 = time.perf_counter()

    if use_tiles:
        print(f"Tiled upscale: {tiles}x{tiles} ({tiles*tiles} tiles)")
        result_np = upscale_tiled(model, source, target_h, target_w,
                                  tiles, ensemble=ensemble)
    else:
        result = model.upscale(mx.array(source), target_h, target_w, ensemble=ensemble)
        mx.eval(result)
        result_np = np.array(result)
    Image.fromarray(result_np).save(output_path)

    elapsed = time.perf_counter() - t0

    suffix = " (ensemble)" if ensemble else ""
    tile_info = f" [{tiles}x{tiles} tiles]" if use_tiles else ""
    print(f"[{model_size}]{suffix}{tile_info} {w}x{h} -> {target_w}x{target_h} ({scale_actual:.4g}x) {elapsed:.1f}s -> {output_path}")
|
|