| |
"""
Thera MLX — Arbitrary-scale super-resolution on Apple Silicon.

Usage:
    python run.py                       # Launch web UI (default port 5005)
    python run.py --port 8080           # Custom port
    python run.py run input.png out.png --scale 4 --model pro
    python run.py run input.png out.png --scale 2 --model air --ensemble --tiles 2
    python run.py convert --model air   # Download + convert weights
"""
|
|
| import argparse |
| import sys |
| import os |
|
|
| |
# Prepend this script's directory to the import path so the sibling modules
# that are imported lazily inside the command handlers (ui, upscale, convert)
# resolve regardless of the current working directory. abspath() matters when
# __file__ is relative: os.path.dirname("run.py") is "", which Python treats
# as "whatever the cwd is at import time" rather than the script's directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
def main():
    """Parse ``sys.argv`` and dispatch to the selected sub-command.

    Sub-commands:
        gui      launch the web UI (also the default when no command is given)
        run      upscale a single image from the command line
        convert  download and convert Flax weights to MLX format

    ``--host``/``--port`` are accepted both before the sub-command and after
    ``gui``, so ``python run.py --port 8080`` launches the UI on port 8080,
    as the module usage text documents.
    """
    args = _build_parser().parse_args()

    # No sub-command means "launch the GUI". host/port are always present on
    # the namespace because they are top-level options with defaults.
    if args.command is None or args.command == "gui":
        _cmd_gui(args)
    elif args.command == "run":
        _cmd_run(args)
    elif args.command == "convert":
        _cmd_convert(args)


def _build_parser():
    """Build and return the argparse parser for all sub-commands."""
    parser = argparse.ArgumentParser(
        prog="thera-mlx",
        description="Thera MLX — arbitrary-scale super-resolution on Apple Silicon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    # The GUI is the default command, so its options must also be accepted at
    # the top level (e.g. ``python run.py --port 8080``); otherwise argparse
    # rejects them as unrecognized arguments.
    parser.add_argument("--port", type=int, default=5005,
                        help="Web UI port (default: 5005)")
    parser.add_argument("--host", type=str, default="127.0.0.1",
                        help="Web UI host (default: 127.0.0.1)")
    sub = parser.add_subparsers(dest="command")

    # --- gui -------------------------------------------------------------
    gui_p = sub.add_parser("gui", help="Launch web UI (default)")
    # SUPPRESS: the attribute is only set when the option follows "gui", so a
    # value given before the sub-command (or the top-level default) is not
    # overwritten by a sub-parser default.
    gui_p.add_argument("--port", type=int, default=argparse.SUPPRESS,
                       help="Port (default: 5005)")
    gui_p.add_argument("--host", type=str, default=argparse.SUPPRESS,
                       help="Host to bind (default: 127.0.0.1)")

    # --- run -------------------------------------------------------------
    run_p = sub.add_parser("run", help="Upscale a single image from the CLI")
    run_p.add_argument("input", help="Input image path")
    run_p.add_argument("output", help="Output image path")
    # Exactly one of --scale / --size must be given.
    scale_g = run_p.add_mutually_exclusive_group(required=True)
    scale_g.add_argument("--scale", type=float, help="Scale factor (e.g. 2, 4)")
    scale_g.add_argument("--size", type=int, nargs=2, metavar=("H", "W"),
                         help="Exact target size height width")
    run_p.add_argument("--model", choices=["air", "pro"], default="air",
                       help="Model variant: air (fast) or pro (quality). Default: air")
    run_p.add_argument("--weights", default=None,
                       help="Path to weights file (optional — uses bundled weights by default)")
    run_p.add_argument("--ensemble", action="store_true",
                       help="Geometric self-ensemble (4 rotations) — higher quality, slower")
    run_p.add_argument("--tiles", type=int, choices=[2, 3, 4], default=None,
                       help="Tile NxN to reduce RAM usage (2, 3, or 4)")

    # --- convert ---------------------------------------------------------
    conv_p = sub.add_parser("convert", help="Download + convert weights from Flax format")
    conv_p.add_argument("--model", choices=["air", "pro"], default="air")
    conv_p.add_argument("--input", default=None,
                        help="Path to model.pkl (downloads from HuggingFace if omitted)")
    conv_p.add_argument("--no-jax", action="store_true",
                        help="Load pickle without JAX installed")

    return parser
|
|
|
|
| |
|
|
def _cmd_gui(args):
    """Launch the web UI from the ``ui`` module on ``args.host``:``args.port``.

    Args:
        args: parsed namespace; only ``host`` and ``port`` are read.

    Blocks until the server stops. ``app.run`` is called with ``debug=False``
    and ``threaded=True`` (presumably a Flask app, given the werkzeug logger
    below — confirm in ui.py).
    """
    import logging

    # Raise werkzeug's log level to ERROR to silence its per-request logging.
    logging.getLogger("werkzeug").setLevel(logging.ERROR)

    # Imported lazily so the other sub-commands don't pay for the web stack.
    from ui import app

    print(f"\n Thera MLX — http://{args.host}:{args.port}\n")
    app.run(host=args.host, port=args.port, debug=False, threaded=True)
|
|
|
|
def _cmd_run(args):
    """Upscale one image using the options parsed by the ``run`` sub-command.

    Args:
        args: namespace carrying ``input``, ``output``, ``scale``, ``size``,
            ``model``, ``weights``, ``ensemble`` and ``tiles``.
    """
    from upscale import upscale_file

    # CLI option names differ from upscale_file's keyword names for two of
    # them (model -> model_size, weights -> weights_path).
    options = dict(
        scale=args.scale,
        size=args.size,
        model_size=args.model,
        weights_path=args.weights,
        ensemble=args.ensemble,
        tiles=args.tiles,
    )
    upscale_file(args.input, args.output, **options)
|
|
|
|
def _cmd_convert(args):
    """Convert Flax ``model.pkl`` weights to MLX and save them under ``WEIGHTS_DIR``.

    The pickle is downloaded from HuggingFace unless ``args.input`` names a
    local file. Output goes to ``WEIGHTS_DIR/weights-<model>.safetensors``.
    If ``save_safetensors`` raises ``ImportError`` (safetensors missing), the
    weights are written to the same path with an ``.npz`` suffix instead.

    Args:
        args: namespace from the ``convert`` sub-parser (``model``, ``input``,
            ``no_jax``).
    """
    from pathlib import Path

    from upscale import WEIGHTS_DIR
    from convert import (
        REPO_IDS, download_model, convert_flax_to_mlx,
        load_pickle_with_jax, load_pickle_without_jax,
        save_safetensors, save_npz,
    )

    weights_dir = Path(WEIGHTS_DIR)
    # Create the destination up front so a missing directory doesn't make the
    # save fail after a potentially long download + conversion.
    weights_dir.mkdir(parents=True, exist_ok=True)
    output_path = weights_dir / f"weights-{args.model}.safetensors"

    if args.input is None:
        repo = REPO_IDS[args.model]
        print(f"Downloading model from HuggingFace ({repo})...")
        pkl_path = download_model(args.model)
    else:
        pkl_path = args.input

    # NOTE(review): these loaders presumably unpickle the checkpoint, which can
    # execute arbitrary code — only convert files from trusted sources.
    load_pickle = load_pickle_without_jax if args.no_jax else load_pickle_with_jax
    flax_params = load_pickle(pkl_path)

    print("Converting weights...")
    mlx_weights = convert_flax_to_mlx(flax_params, size=args.model)

    total_params = sum(w.size for w in mlx_weights.values())
    print(f" Parameters : {total_params:,}")
    print(f" Weight keys: {len(mlx_weights)}")

    try:
        save_safetensors(mlx_weights, str(output_path))
    except ImportError:
        # safetensors is optional; fall back to .npz. with_suffix() only
        # touches the file extension, unlike str.replace on the whole path.
        output_path = output_path.with_suffix(".npz")
        print("safetensors not installed — saving as .npz instead")
        save_npz(mlx_weights, str(output_path))

    print(f"Saved → {output_path}")
|
|
|
|
# Script entry point: run the CLI only when executed directly (e.g.
# ``python run.py ...``), not when this module is imported.
if __name__ == "__main__":
    main()
|
|