#!/usr/bin/env python3 """ Thera MLX — Arbitrary-scale super-resolution on Apple Silicon. Usage: python run.py # Launch web UI (default port 5005) python run.py --port 8080 # Custom port python run.py run input.png out.png --scale 4 --model pro python run.py run input.png out.png --scale 2 --model air --ensemble --tiles 2 python run.py convert --model air # Download + convert weights """ import argparse import sys import os # Make sure this directory is on the path so all modules import cleanly sys.path.insert(0, os.path.dirname(__file__)) def main(): parser = argparse.ArgumentParser( prog="thera-mlx", description="Thera MLX — arbitrary-scale super-resolution on Apple Silicon", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__, ) sub = parser.add_subparsers(dest="command") # ── gui (default) ──────────────────────────────────────────────────────── gui_p = sub.add_parser("gui", help="Launch web UI (default)") gui_p.add_argument("--port", type=int, default=5005, help="Port (default: 5005)") gui_p.add_argument("--host", type=str, default="127.0.0.1") # ── run ────────────────────────────────────────────────────────────────── run_p = sub.add_parser("run", help="Upscale a single image from the CLI") run_p.add_argument("input", help="Input image path") run_p.add_argument("output", help="Output image path") scale_g = run_p.add_mutually_exclusive_group(required=True) scale_g.add_argument("--scale", type=float, help="Scale factor (e.g. 2, 4)") scale_g.add_argument("--size", type=int, nargs=2, metavar=("H", "W"), help="Exact target size height width") run_p.add_argument("--model", choices=["air", "pro"], default="air", help="Model variant: air (fast) or pro (quality). Default: air") run_p.add_argument("--weights", default=None, help="Path to weights file (optional — uses bundled weights by default)") run_p.add_argument("--ensemble", action="store_true", help="Geometric self-ensemble (4 rotations) — higher quality, slower") run_p.add_argument("--tiles", type=int, choices=[2, 3, 4], default=None, help="Tile NxN to reduce RAM usage (2, 3, or 4)") # ── convert ────────────────────────────────────────────────────────────── conv_p = sub.add_parser("convert", help="Download + convert weights from Flax format") conv_p.add_argument("--model", choices=["air", "pro"], default="air") conv_p.add_argument("--input", default=None, help="Path to model.pkl (downloads from HuggingFace if omitted)") conv_p.add_argument("--no-jax", action="store_true", help="Load pickle without JAX installed") args = parser.parse_args() # Default: no subcommand → launch GUI if args.command is None or args.command == "gui": _cmd_gui(args if args.command else argparse.Namespace(host="127.0.0.1", port=5005)) elif args.command == "run": _cmd_run(args) elif args.command == "convert": _cmd_convert(args) # ── Command implementations ─────────────────────────────────────────────────── def _cmd_gui(args): import logging logging.getLogger("werkzeug").setLevel(logging.ERROR) from ui import app print(f"\n Thera MLX → http://{args.host}:{args.port}\n") app.run(host=args.host, port=args.port, debug=False, threaded=True) def _cmd_run(args): from upscale import upscale_file upscale_file( args.input, args.output, scale=args.scale, size=args.size, model_size=args.model, weights_path=args.weights, ensemble=args.ensemble, tiles=args.tiles, ) def _cmd_convert(args): from pathlib import Path from upscale import WEIGHTS_DIR from convert import ( REPO_IDS, download_model, convert_flax_to_mlx, load_pickle_with_jax, load_pickle_without_jax, save_safetensors, save_npz, ) output_path = str(WEIGHTS_DIR / f"weights-{args.model}.safetensors") if args.input is None: repo = REPO_IDS[args.model] print(f"Downloading model from HuggingFace ({repo})...") pkl_path = download_model(args.model) else: pkl_path = args.input flax_params = (load_pickle_without_jax if args.no_jax else load_pickle_with_jax)(pkl_path) print("Converting weights...") mlx_weights = convert_flax_to_mlx(flax_params, size=args.model) total_params = sum(w.size for w in mlx_weights.values()) print(f" Parameters : {total_params:,}") print(f" Weight keys: {len(mlx_weights)}") if output_path.endswith(".safetensors"): try: save_safetensors(mlx_weights, output_path) except ImportError: output_path = output_path.replace(".safetensors", ".npz") print("safetensors not installed — saving as .npz instead") save_npz(mlx_weights, output_path) else: save_npz(mlx_weights, output_path) print(f"Saved → {output_path}") if __name__ == "__main__": main()