File size: 5,728 Bytes
29e0144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
"""
Thera MLX — Arbitrary-scale super-resolution on Apple Silicon.

Usage:
    python run.py                         # Launch web UI (default port 5005)
    python run.py --port 8080             # Custom port
    python run.py run input.png out.png --scale 4 --model pro
    python run.py run input.png out.png --scale 2 --model air --ensemble --tiles 2
    python run.py convert --model air     # Download + convert weights
"""

import argparse
import sys
import os

# Make sure this script's directory is on the path so the sibling modules
# (ui, upscale, convert) that the commands below import lazily resolve no
# matter what the current working directory is. abspath() matters: before
# Python 3.9 a script run as "python run.py" has a relative __file__, so
# dirname() returns "" -- and "" on sys.path means "whatever the cwd is at
# import time", not this directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))


def main():
    """Parse the command line and dispatch to the chosen sub-command.

    Sub-commands: ``gui`` (web UI; also used when no sub-command is given),
    ``run`` (upscale a single image) and ``convert`` (download + convert
    weights).

    ``--host``/``--port`` are accepted on the top-level parser (so the
    documented ``python run.py --port 8080`` works) as well as after ``gui``.
    A ``--scale`` that is not a positive finite number, or a non-positive
    ``--size`` value, is rejected with a usage error (exit status 2).
    """
    parser = argparse.ArgumentParser(
        prog="thera-mlx",
        description="Thera MLX — arbitrary-scale super-resolution on Apple Silicon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    # The GUI is the default action, so its options must also work without
    # the "gui" sub-command. These top-level flags own the effective defaults.
    parser.add_argument("--port", type=int, default=5005,
                        help="Web UI port (default: 5005)")
    parser.add_argument("--host", type=str, default="127.0.0.1",
                        help="Web UI host (default: 127.0.0.1)")
    sub = parser.add_subparsers(dest="command")

    # ── gui (default) ────────────────────────────────────────────────────────
    gui_p = sub.add_parser("gui", help="Launch web UI (default)")
    # argparse copies a sub-parser's defaults over the parent namespace, which
    # would silently undo "--port 8080 gui". With SUPPRESS the attribute is
    # only set when the flag is actually given after "gui".
    gui_p.add_argument("--port", type=int, default=argparse.SUPPRESS,
                       help="Port (default: 5005)")
    gui_p.add_argument("--host", type=str, default=argparse.SUPPRESS,
                       help="Host (default: 127.0.0.1)")

    # ── run ──────────────────────────────────────────────────────────────────
    run_p = sub.add_parser("run", help="Upscale a single image from the CLI")
    run_p.add_argument("input",  help="Input image path")
    run_p.add_argument("output", help="Output image path")
    scale_g = run_p.add_mutually_exclusive_group(required=True)
    scale_g.add_argument("--scale", type=float, help="Scale factor  (e.g. 2, 4)")
    scale_g.add_argument("--size",  type=int, nargs=2, metavar=("H", "W"),
                          help="Exact target size  height width")
    run_p.add_argument("--model",    choices=["air", "pro"], default="air",
                        help="Model variant: air (fast) or pro (quality). Default: air")
    run_p.add_argument("--weights",  default=None,
                        help="Path to weights file (optional — uses bundled weights by default)")
    run_p.add_argument("--ensemble", action="store_true",
                        help="Geometric self-ensemble (4 rotations) — higher quality, slower")
    run_p.add_argument("--tiles",    type=int, choices=[2, 3, 4], default=None,
                        help="Tile NxN to reduce RAM usage (2, 3, or 4)")

    # ── convert ──────────────────────────────────────────────────────────────
    conv_p = sub.add_parser("convert", help="Download + convert weights from Flax format")
    conv_p.add_argument("--model",  choices=["air", "pro"], default="air")
    conv_p.add_argument("--input",  default=None,
                         help="Path to model.pkl (downloads from HuggingFace if omitted)")
    conv_p.add_argument("--no-jax", action="store_true",
                         help="Load pickle without JAX installed")

    args = parser.parse_args()

    if args.command == "run":
        # Fail fast with a usage error rather than deep inside the upscaler.
        # The chained comparison is False for NaN too, so NaN is rejected.
        if args.scale is not None and not (0 < args.scale < float("inf")):
            run_p.error(f"--scale must be a positive finite number, got {args.scale}")
        if args.size is not None and min(args.size) <= 0:
            run_p.error("--size H W must both be positive integers")

    # No sub-command -> launch the GUI; args carries the top-level host/port.
    if args.command in (None, "gui"):
        _cmd_gui(args)
    elif args.command == "run":
        _cmd_run(args)
    elif args.command == "convert":
        _cmd_convert(args)


# ── Command implementations ───────────────────────────────────────────────────

def _cmd_gui(args):
    """Serve the web UI from ``ui.app`` on ``args.host``:``args.port``.

    Only ``args.host`` and ``args.port`` are read. The
    ``app.run(host=..., port=..., debug=False, threaded=True)`` call and the
    werkzeug logger tweak suggest ``ui.app`` is a Flask app (confirm in ui.py).
    """
    import logging

    # Hide werkzeug's per-request log lines; only errors are shown.
    logging.getLogger("werkzeug").setLevel(logging.ERROR)

    # Deferred import: the UI module is only needed for this command.
    from ui import app

    # IPv6 literals (e.g. "::1") must be bracketed to form a valid URL.
    shown_host = f"[{args.host}]" if ":" in args.host else args.host
    print(f"\n  Thera MLX  →  http://{shown_host}:{args.port}\n")
    app.run(host=args.host, port=args.port, debug=False, threaded=True)


def _cmd_run(args):
    """Upscale one image by forwarding the parsed ``run`` options to ``upscale_file``."""
    from upscale import upscale_file

    # Map CLI option names onto upscale_file's keyword parameters.
    forwarded = dict(
        scale=args.scale,
        size=args.size,
        model_size=args.model,
        weights_path=args.weights,
        ensemble=args.ensemble,
        tiles=args.tiles,
    )
    upscale_file(args.input, args.output, **forwarded)


def _cmd_convert(args):
    """Convert Flax weights to MLX and save them into ``WEIGHTS_DIR``.

    Reads ``args.model`` ("air"/"pro"), ``args.input`` (path to a model.pkl,
    or None to download it via ``download_model``) and ``args.no_jax``
    (selects the JAX-free pickle loader).

    Output goes to ``WEIGHTS_DIR/weights-<model>.safetensors``. If
    ``save_safetensors`` raises ImportError (safetensors not installed), the
    weights are saved as ``weights-<model>.npz`` instead.
    """
    from pathlib import Path

    from upscale import WEIGHTS_DIR
    from convert import (
        REPO_IDS, download_model, convert_flax_to_mlx,
        load_pickle_with_jax, load_pickle_without_jax,
        save_safetensors, save_npz,
    )

    weights_dir = Path(WEIGHTS_DIR)
    output_path = weights_dir / f"weights-{args.model}.safetensors"

    if args.input is None:
        print(f"Downloading model from HuggingFace ({REPO_IDS[args.model]})...")
        pkl_path = download_model(args.model)
    else:
        pkl_path = args.input

    # NOTE(review): these loaders appear to unpickle the file -- only feed
    # them model.pkl files from a trusted source.
    loader = load_pickle_without_jax if args.no_jax else load_pickle_with_jax
    flax_params = loader(pkl_path)

    print("Converting weights...")
    mlx_weights = convert_flax_to_mlx(flax_params, size=args.model)

    total_params = sum(w.size for w in mlx_weights.values())
    print(f"  Parameters : {total_params:,}")
    print(f"  Weight keys: {len(mlx_weights)}")

    # Ensure the destination exists so saving doesn't fail on a fresh checkout.
    weights_dir.mkdir(parents=True, exist_ok=True)
    try:
        save_safetensors(mlx_weights, str(output_path))
    except ImportError:
        # with_suffix only touches the final suffix, unlike str.replace,
        # which would also rewrite any ".safetensors" earlier in the path.
        output_path = output_path.with_suffix(".npz")
        print("safetensors not installed — saving as .npz instead")
        save_npz(mlx_weights, str(output_path))

    print(f"Saved → {output_path}")


if __name__ == "__main__":
    # Parse the CLI and dispatch only when executed as a script, not on import.
    main()