File size: 6,668 Bytes
29e0144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#!/usr/bin/env python3
"""Core upscale logic for Thera MLX."""

import time
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image

from model import Thera

# Directory holding the converted weight files. It is resolved relative to this
# module's own location (via __file__), not the current working directory.
WEIGHTS_DIR = Path(__file__).parent / "weights"


def load_weights(model, weights_path):
    """Load converted weights into the MLX model.

    Args:
        model: MLX module; its ``load_weights`` method is called with a list
            of ``(name, mx.array)`` pairs.
        weights_path: Path (``str`` or ``Path``) to a ``.safetensors`` or
            ``.npz`` weights file.

    Returns:
        The same ``model`` instance, after its weights have been loaded.

    Raises:
        ValueError: If the file extension is neither ``.safetensors`` nor
            ``.npz``.
    """
    weights_path = str(weights_path)
    if weights_path.endswith('.safetensors'):
        # Imported lazily so safetensors is only required for this format.
        from safetensors.numpy import load_file
        raw = load_file(weights_path)
        weights = {k: mx.array(v) for k, v in raw.items()}
    elif weights_path.endswith('.npz'):
        # np.load returns an NpzFile that keeps the archive open and loads
        # arrays lazily on access. The context manager closes the handle,
        # and using ``v`` from items() reads each array exactly once
        # (indexing raw[k] again would decompress it a second time).
        with np.load(weights_path) as raw:
            weights = {k: mx.array(v) for k, v in raw.items()}
    else:
        raise ValueError(f"Unknown weight format: {weights_path}")

    model.load_weights(list(weights.items()))
    return model


def get_weights_path(model_size):
    """Build the default weights location for a model variant.

    The file is expected inside ``WEIGHTS_DIR`` and named
    ``weights-<model_size>.safetensors``. Only the path is built; whether
    the file exists is not checked here.
    """
    file_name = f"weights-{model_size}.safetensors"
    return WEIGHTS_DIR.joinpath(file_name)


def _axis_bounds(index, tiles, size, overlap, target_size):
    """Return ``(src0, src1, dst0, dst1)`` for tile ``index`` along one axis.

    Source bounds are the tile's share of ``size``, widened by ``overlap`` on
    interior sides and clamped to ``[0, size]``. Destination bounds are the
    source bounds scaled to ``target_size``. A tile touching the far edge is
    pinned to ``target_size`` so that rounding can never leave an uncovered
    (zero-weight, hence black) strip at the bottom/right of the canvas.
    """
    scale = target_size / size
    src0 = round(index * size / tiles) - (overlap if index > 0 else 0)
    src1 = round((index + 1) * size / tiles) + (overlap if index < tiles - 1 else 0)
    src0 = max(0, src0)
    src1 = min(size, src1)
    dst0 = round(src0 * scale)
    dst1 = target_size if src1 == size else min(target_size, round(src1 * scale))
    return src0, src1, dst0, dst1


def _feather_ramp(length, ramp_in, ramp_out):
    """Return a 1-D float32 blend weight vector of ``length`` samples.

    Interior weights are 1. The first ``ramp_in`` samples rise linearly
    0 -> 1, and the last ``ramp_out`` samples fall linearly 1 -> 0. Where
    the two ramps meet, the minimum is kept. Ramp lengths are clipped to
    ``length``.
    """
    weights = np.ones(length, dtype=np.float32)
    if ramp_in > 0:
        ramp = np.linspace(0, 1, min(ramp_in, length), dtype=np.float32)
        weights[:len(ramp)] = ramp
    if ramp_out > 0:
        ramp = np.linspace(1, 0, min(ramp_out, length), dtype=np.float32)
        weights[-len(ramp):] = np.minimum(weights[-len(ramp):], ramp)
    return weights


def upscale_tiled(model, source_np, target_h, target_w, tiles, ensemble=False):
    """Upscale an image using NxN tiles to reduce peak RAM.

    Splits the source image into tiles with overlap, upscales each tile
    individually, then blends them back together using linear feathering.
    Tile target bounds are derived from the scaled source bounds, with the
    last row/column pinned to the canvas edge, so every output pixel is
    covered by at least one tile.

    Args:
        model: Loaded Thera model; called as
            ``model.upscale(mx.array(tile), h, w, ensemble=...)``.
        source_np: numpy array (H, W, 3) float32 in [0, 1].
        target_h: Target height.
        target_w: Target width.
        tiles: Number of tiles per axis (2, 3, or 4).
        ensemble: Use geometric self-ensemble.

    Returns:
        numpy uint8 array (target_h, target_w, 3).

    Raises:
        ValueError: If ``tiles`` is less than 1.
    """
    if tiles < 1:
        raise ValueError(f"tiles must be >= 1, got {tiles}")

    h, w = source_np.shape[:2]
    scale_h = target_h / h
    scale_w = target_w / w

    # Overlap in source pixels (10% of tile size, minimum 8px)
    tile_h = h / tiles
    tile_w = w / tiles
    overlap_h = max(8, int(tile_h * 0.1))
    overlap_w = max(8, int(tile_w * 0.1))

    # Feather length in output pixels, applied on each interior tile edge.
    feather_h = round(overlap_h * scale_h)
    feather_w = round(overlap_w * scale_w)

    # Accumulators for the weighted sum and the total weight (float32 for blending)
    output = np.zeros((target_h, target_w, 3), dtype=np.float32)
    weight_map = np.zeros((target_h, target_w, 1), dtype=np.float32)

    total_tiles = tiles * tiles
    done = 0

    for row in range(tiles):
        # Vertical bounds are shared by every tile in this row.
        sy0, sy1, ty0, ty1 = _axis_bounds(row, tiles, h, overlap_h, target_h)
        wy_in = feather_h if row > 0 else 0
        wy_out = feather_h if row < tiles - 1 else 0

        for col in range(tiles):
            sx0, sx1, tx0, tx1 = _axis_bounds(col, tiles, w, overlap_w, target_w)
            th = ty1 - ty0
            tw = tx1 - tx0

            tile_src = source_np[sy0:sy1, sx0:sx1]
            result = model.upscale(mx.array(tile_src), th, tw, ensemble=ensemble)
            mx.eval(result)
            # Crop defensively in case the model returns more than requested.
            tile_out = np.array(result).astype(np.float32)[:th, :tw] / 255.0
            fh, fw = tile_out.shape[:2]

            # Linear feather weight for blending overlaps (outer image edges
            # get no ramp because no neighbouring tile covers them).
            wy = _feather_ramp(fh, wy_in, wy_out)
            wx = _feather_ramp(fw, feather_w if col > 0 else 0,
                               feather_w if col < tiles - 1 else 0)
            w3d = (wy[:, None] * wx[None, :])[:, :, None]  # (fh, fw, 1)

            output[ty0:ty0 + fh, tx0:tx0 + fw] += tile_out * w3d
            weight_map[ty0:ty0 + fh, tx0:tx0 + fw] += w3d

            done += 1
            print(f"  tile {done}/{total_tiles}")

    # Normalize by weight; the floor avoids division by zero.
    weight_map = np.maximum(weight_map, 1e-8)
    return (output / weight_map * 255 + 0.5).clip(0, 255).astype(np.uint8)


def upscale_file(input_path, output_path, scale=None, size=None,
                 model_size='air', weights_path=None, ensemble=False,
                 tiles=None):
    """Upscale a single image file and write the result to ``output_path``.

    Args:
        input_path: Source image path; the image is converted to RGB.
        output_path: Destination path, passed to ``PIL.Image.save``.
        scale: Uniform scale factor; takes precedence over ``size``.
        size: ``(target_h, target_w)`` pair, used when ``scale`` is None.
        model_size: Thera variant; also selects the default weights file.
        weights_path: Explicit weights file; overrides the ``model_size``
            based default.
        ensemble: Use geometric self-ensemble.
        tiles: Tiles per axis; values greater than 1 enable tiled upscaling.

    Raises:
        ValueError: If neither ``scale`` nor ``size`` is given, or the
            resulting target dimensions are not positive.
    """
    # Context manager releases the source file handle deterministically.
    with Image.open(input_path) as img:
        source = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0
    h, w = source.shape[:2]

    if scale is not None:
        target_h = round(h * scale)
        target_w = round(w * scale)
    elif size is not None:
        target_h, target_w = size
    else:
        raise ValueError("Must specify either scale or size")

    # Fail early with a clear message instead of deep inside the model.
    if target_h < 1 or target_w < 1:
        raise ValueError(f"Target size must be positive, got {target_w}x{target_h}")

    scale_actual = target_h / h
    if weights_path is None:
        weights_path = get_weights_path(model_size)

    model = Thera(size=model_size)
    model = load_weights(model, weights_path)
    # Evaluate parameters before starting the timer so loading is not timed.
    mx.eval(model.parameters())

    use_tiles = bool(tiles) and tiles > 1

    t0 = time.perf_counter()

    if use_tiles:
        print(f"Tiled upscale: {tiles}x{tiles} ({tiles*tiles} tiles)")
        result_np = upscale_tiled(model, source, target_h, target_w,
                                  tiles, ensemble=ensemble)
    else:
        result = model.upscale(mx.array(source), target_h, target_w, ensemble=ensemble)
        mx.eval(result)
        result_np = np.array(result)
    Image.fromarray(result_np).save(output_path)

    elapsed = time.perf_counter() - t0

    suffix = " (ensemble)" if ensemble else ""
    tile_info = f" [{tiles}x{tiles} tiles]" if use_tiles else ""
    print(f"[{model_size}]{suffix}{tile_info} {w}x{h} -> {target_w}x{target_h} ({scale_actual:.4g}x)  {elapsed:.1f}s  ->  {output_path}")