#!/usr/bin/env python3
"""
Thera MLX — Arbitrary-scale super-resolution on Apple Silicon.
Usage:
python run.py # Launch web UI (default port 5005)
python run.py --port 8080 # Custom port
python run.py run input.png out.png --scale 4 --model pro
python run.py run input.png out.png --scale 2 --model air --ensemble --tiles 2
python run.py convert --model air # Download + convert weights
"""
import argparse
import sys
import os
# Put the directory containing this file at the front of sys.path so the
# sibling modules imported inside the command functions (ui, upscale, convert)
# are found first, ahead of any same-named modules elsewhere on the path.
sys.path.insert(0, os.path.dirname(__file__))
def main(argv=None):
    """Parse the command line and dispatch to the selected sub-command.

    Three sub-commands are defined: ``gui`` (also used when no sub-command
    is given), ``run`` and ``convert``.

    ``--port`` and ``--host`` are accepted both before the sub-command and
    after ``gui``, so the documented ``python run.py --port 8080`` form works
    as well as ``python run.py gui --port 8080``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: Raised by argparse for ``--help`` or invalid arguments.
    """
    parser = argparse.ArgumentParser(
        prog="thera-mlx",
        description="Thera MLX β arbitrary-scale super-resolution on Apple Silicon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    # The GUI is the default command, so its options must also be valid when
    # no sub-command is named. They are parsed (and ignored) for the other
    # sub-commands.
    parser.add_argument("--port", type=int, default=5005,
                        help="Web UI port (default: 5005)")
    parser.add_argument("--host", type=str, default="127.0.0.1",
                        help="Web UI host (default: 127.0.0.1)")
    sub = parser.add_subparsers(dest="command")

    # ── gui (default) ──
    gui_p = sub.add_parser("gui", help="Launch web UI (default)")
    # argparse copies every attribute of a sub-parser's namespace onto the
    # parent namespace, so a concrete default here would clobber a value given
    # as `--port N gui`. SUPPRESS leaves the attribute unset unless the flag
    # is actually passed after `gui`.
    gui_p.add_argument("--port", type=int, default=argparse.SUPPRESS,
                       help="Port (default: 5005)")
    gui_p.add_argument("--host", type=str, default=argparse.SUPPRESS)

    # ── run ──
    run_p = sub.add_parser("run", help="Upscale a single image from the CLI")
    run_p.add_argument("input", help="Input image path")
    run_p.add_argument("output", help="Output image path")
    # Exactly one of --scale / --size must be supplied.
    scale_g = run_p.add_mutually_exclusive_group(required=True)
    scale_g.add_argument("--scale", type=float, help="Scale factor (e.g. 2, 4)")
    scale_g.add_argument("--size", type=int, nargs=2, metavar=("H", "W"),
                         help="Exact target size height width")
    run_p.add_argument("--model", choices=["air", "pro"], default="air",
                       help="Model variant: air (fast) or pro (quality). Default: air")
    run_p.add_argument("--weights", default=None,
                       help="Path to weights file (optional β uses bundled weights by default)")
    run_p.add_argument("--ensemble", action="store_true",
                       help="Geometric self-ensemble (4 rotations) β higher quality, slower")
    run_p.add_argument("--tiles", type=int, choices=[2, 3, 4], default=None,
                       help="Tile NxN to reduce RAM usage (2, 3, or 4)")

    # ── convert ──
    conv_p = sub.add_parser("convert", help="Download + convert weights from Flax format")
    conv_p.add_argument("--model", choices=["air", "pro"], default="air")
    conv_p.add_argument("--input", default=None,
                        help="Path to model.pkl (downloads from HuggingFace if omitted)")
    conv_p.add_argument("--no-jax", action="store_true",
                        help="Load pickle without JAX installed")

    args = parser.parse_args(argv)

    # No sub-command means "launch the GUI"; host/port always exist on `args`
    # thanks to the top-level options above.
    if args.command is None or args.command == "gui":
        _cmd_gui(args)
    elif args.command == "run":
        _cmd_run(args)
    elif args.command == "convert":
        _cmd_convert(args)
# ── Command implementations ──────────────────────────────────────────────────
def _cmd_gui(args):
    """Start the web UI, serving on ``args.host`` / ``args.port``."""
    import logging

    # Raise the "werkzeug" logger's threshold to ERROR before the app is loaded.
    werkzeug_log = logging.getLogger("werkzeug")
    werkzeug_log.setLevel(logging.ERROR)

    from ui import app

    host = args.host
    port = args.port
    print(f"\n Thera MLX β http://{host}:{port}\n")
    app.run(host=host, port=port, debug=False, threaded=True)
def _cmd_run(args):
    """Forward the parsed ``run`` arguments to ``upscale.upscale_file``."""
    from upscale import upscale_file

    source, destination = args.input, args.output
    # Each keyword maps one-to-one onto a parsed CLI option.
    options = {
        "scale": args.scale,
        "size": args.size,
        "model_size": args.model,
        "weights_path": args.weights,
        "ensemble": args.ensemble,
        "tiles": args.tiles,
    }
    upscale_file(source, destination, **options)
def _cmd_convert(args):
    """Convert Flax weights to MLX weights and save them under ``WEIGHTS_DIR``.

    Reads ``args.model``, ``args.input`` and ``args.no_jax``. When
    ``args.input`` is ``None`` the pickle is obtained via ``download_model``;
    otherwise the given path is used. The converted weights are written to
    ``weights-<model>.safetensors``; if ``save_safetensors`` raises
    ``ImportError`` they are written to the same location with an ``.npz``
    suffix instead.
    """
    from pathlib import Path
    from upscale import WEIGHTS_DIR
    from convert import (
        REPO_IDS, download_model, convert_flax_to_mlx,
        load_pickle_with_jax, load_pickle_without_jax,
        save_safetensors, save_npz,
    )

    output_path = str(WEIGHTS_DIR / f"weights-{args.model}.safetensors")

    if args.input is None:
        repo = REPO_IDS[args.model]
        print(f"Downloading model from HuggingFace ({repo})...")
        pkl_path = download_model(args.model)
    else:
        pkl_path = args.input

    # SECURITY NOTE(review): the loaders below presumably unpickle the file;
    # unpickling can execute arbitrary code, so only pass --input files from
    # a trusted source. Confirm against the loaders in convert.py.
    load_pickle = load_pickle_without_jax if args.no_jax else load_pickle_with_jax
    flax_params = load_pickle(pkl_path)

    print("Converting weights...")
    mlx_weights = convert_flax_to_mlx(flax_params, size=args.model)
    total_params = sum(w.size for w in mlx_weights.values())
    print(f" Parameters : {total_params:,}")
    print(f" Weight keys: {len(mlx_weights)}")

    # output_path always ends in ".safetensors" (built above), so safetensors
    # is tried first and .npz is used only as the fallback.
    try:
        save_safetensors(mlx_weights, output_path)
    except ImportError:
        # Swap only the file suffix; str.replace would also rewrite any
        # ".safetensors" occurring in a directory name.
        output_path = str(Path(output_path).with_suffix(".npz"))
        print("safetensors not installed β saving as .npz instead")
        save_npz(mlx_weights, output_path)
    print(f"Saved β {output_path}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|