thera-mlx / run.py
mlmPenguin's picture
Add source code
29e0144 verified
#!/usr/bin/env python3
"""
Thera MLX — Arbitrary-scale super-resolution on Apple Silicon.
Usage:
python run.py # Launch web UI (default port 5005)
python run.py --port 8080 # Custom port
python run.py run input.png out.png --scale 4 --model pro
python run.py run input.png out.png --scale 2 --model air --ensemble --tiles 2
python run.py convert --model air # Download + convert weights
"""
import argparse
import sys
import os
# Put this script's directory first on sys.path so the sibling modules that the
# sub-commands import lazily (ui, upscale, convert) resolve cleanly.
# abspath() matters because those imports happen at call time: a relative or
# empty entry (possible when __file__ is relative) would be resolved against
# whatever the working directory is at that moment.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
def main(argv=None):
    """Parse the command line and dispatch to the selected sub-command.

    With no sub-command, or with ``gui``, the web UI is launched. ``run``
    upscales a single image and ``convert`` converts model weights.

    ``--host`` and ``--port`` are accepted both before the sub-command and
    after ``gui``, so ``python run.py --port 8080`` and
    ``python run.py gui --port 8080`` are equivalent.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Raises:
        SystemExit: Raised by argparse on ``--help`` or on invalid arguments
            (exit status 2 for usage errors).
    """
    parser = argparse.ArgumentParser(
        prog="thera-mlx",
        description="Thera MLX — arbitrary-scale super-resolution on Apple Silicon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    # Top-level copies of the GUI options: the GUI is the default command, so
    # its options must be usable without naming the ``gui`` sub-command.
    parser.add_argument("--port", type=int, default=5005,
                        help="Web UI port (default: 5005)")
    parser.add_argument("--host", type=str, default="127.0.0.1",
                        help="Web UI bind address (default: 127.0.0.1)")
    sub = parser.add_subparsers(dest="command")

    # ── gui (default) ────────────────────────────────────────────────────────
    gui_p = sub.add_parser("gui", help="Launch web UI (default)")
    # SUPPRESS keeps the sub-parser from writing its own default over a value
    # that was given before the sub-command (e.g. ``--port 8080 gui``).
    gui_p.add_argument("--port", type=int, default=argparse.SUPPRESS,
                       help="Port (default: 5005)")
    gui_p.add_argument("--host", type=str, default=argparse.SUPPRESS,
                       help="Bind address (default: 127.0.0.1)")

    # ── run ──────────────────────────────────────────────────────────────────
    run_p = sub.add_parser("run", help="Upscale a single image from the CLI")
    run_p.add_argument("input", help="Input image path")
    run_p.add_argument("output", help="Output image path")
    # Exactly one of --scale / --size must be given.
    scale_g = run_p.add_mutually_exclusive_group(required=True)
    scale_g.add_argument("--scale", type=float, help="Scale factor (e.g. 2, 4)")
    scale_g.add_argument("--size", type=int, nargs=2, metavar=("H", "W"),
                         help="Exact target size height width")
    run_p.add_argument("--model", choices=["air", "pro"], default="air",
                       help="Model variant: air (fast) or pro (quality). Default: air")
    run_p.add_argument("--weights", default=None,
                       help="Path to weights file (optional — uses bundled weights by default)")
    run_p.add_argument("--ensemble", action="store_true",
                       help="Geometric self-ensemble (4 rotations) — higher quality, slower")
    run_p.add_argument("--tiles", type=int, choices=[2, 3, 4], default=None,
                       help="Tile NxN to reduce RAM usage (2, 3, or 4)")

    # ── convert ──────────────────────────────────────────────────────────────
    conv_p = sub.add_parser("convert", help="Download + convert weights from Flax format")
    conv_p.add_argument("--model", choices=["air", "pro"], default="air")
    conv_p.add_argument("--input", default=None,
                        help="Path to model.pkl (downloads from HuggingFace if omitted)")
    conv_p.add_argument("--no-jax", action="store_true",
                        help="Load pickle without JAX installed")

    args = parser.parse_args(argv)

    if args.command == "run":
        # Fail with a usage error instead of handing a meaningless target to
        # the upscaler. ``not x > 0`` also rejects NaN.
        if args.scale is not None and not args.scale > 0:
            parser.error("--scale must be a positive number")
        if args.size is not None and min(args.size) <= 0:
            parser.error("--size values must be positive integers")

    # Default: no subcommand → launch GUI
    if args.command in (None, "gui"):
        _cmd_gui(args)
    elif args.command == "run":
        _cmd_run(args)
    elif args.command == "convert":
        _cmd_convert(args)
# ── Command implementations ───────────────────────────────────────────────────
def _cmd_gui(args):
    """Start the web UI server and block until it exits.

    Args:
        args: Parsed arguments; ``args.host`` and ``args.port`` select the
            address the server binds to.
    """
    import logging

    # Only let werkzeug report errors so the console is not flooded with logs.
    logging.getLogger("werkzeug").setLevel(logging.ERROR)

    # Imported lazily so the other sub-commands do not pay for loading the UI.
    # NOTE(review): ``app`` looks like a Flask application — confirm in ui.py.
    from ui import app

    print(f"\n Thera MLX → http://{args.host}:{args.port}\n")
    app.run(host=args.host, port=args.port, debug=False, threaded=True)
def _cmd_run(args):
    """Upscale one image by forwarding the parsed CLI options to ``upscale_file``.

    Args:
        args: Parsed arguments of the ``run`` sub-command (``input``,
            ``output``, ``scale``, ``size``, ``model``, ``weights``,
            ``ensemble``, ``tiles``).
    """
    # Imported lazily so the upscaler is only loaded when this command runs.
    from upscale import upscale_file

    source, destination = args.input, args.output
    options = {
        "scale": args.scale,
        "size": args.size,
        "model_size": args.model,
        "weights_path": args.weights,
        "ensemble": args.ensemble,
        "tiles": args.tiles,
    }
    upscale_file(source, destination, **options)
def _cmd_convert(args):
    """Convert Flax weights to MLX format and save them under ``WEIGHTS_DIR``.

    The source pickle is ``args.input`` when given; otherwise it is fetched
    with ``download_model``. The result is written as
    ``weights-<model>.safetensors``, falling back to ``.npz`` when saving
    safetensors raises ``ImportError``.

    Args:
        args: Parsed arguments of the ``convert`` sub-command (``model``,
            ``input``, ``no_jax``).
    """
    from pathlib import Path
    from upscale import WEIGHTS_DIR
    from convert import (
        REPO_IDS, download_model, convert_flax_to_mlx,
        load_pickle_with_jax, load_pickle_without_jax,
        save_safetensors, save_npz,
    )

    output_path = Path(WEIGHTS_DIR) / f"weights-{args.model}.safetensors"

    if args.input is None:
        repo = REPO_IDS[args.model]
        print(f"Downloading model from HuggingFace ({repo})...")
        pkl_path = download_model(args.model)
    else:
        pkl_path = args.input

    # NOTE(review): these loaders presumably unpickle the file, and unpickling
    # can execute arbitrary code — only pass model.pkl files from a trusted
    # source via --input.
    load_pickle = load_pickle_without_jax if args.no_jax else load_pickle_with_jax
    flax_params = load_pickle(pkl_path)

    print("Converting weights...")
    mlx_weights = convert_flax_to_mlx(flax_params, size=args.model)
    total_params = sum(w.size for w in mlx_weights.values())
    print(f" Parameters : {total_params:,}")
    print(f" Weight keys: {len(mlx_weights)}")

    # Make sure the destination exists so a fresh checkout can be converted.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        save_safetensors(mlx_weights, str(output_path))
    except ImportError:
        # Swap only the file suffix; a plain str.replace would also rewrite a
        # ".safetensors" occurring earlier in the directory part of the path.
        output_path = output_path.with_suffix(".npz")
        print("safetensors not installed — saving as .npz instead")
        save_npz(mlx_weights, str(output_path))

    print(f"Saved → {output_path}")
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()