#!/usr/bin/env python3
"""
Web UI for Thera MLX super-resolution.
Supports single image, batch, and video upscaling.

Usage:
    python3 ui.py
    python3 ui.py --port 8080
"""
import argparse
import glob
import io
import json
import os
import shutil
import subprocess
import tempfile
import threading
import time
import uuid
import zipfile

import mlx.core as mx
import numpy as np
from flask import Flask, request, jsonify, send_file, Response
from PIL import Image

from model import Thera


def _find_ffmpeg():
    """Find ffmpeg binary — system PATH first, then imageio_ffmpeg fallback."""
    path = shutil.which("ffmpeg")
    if path:
        return path
    try:
        import imageio_ffmpeg
        return imageio_ffmpeg.get_ffmpeg_exe()
    except ImportError:
        return None


def _find_ffprobe():
    """Find ffprobe binary — system PATH first, then imageio_ffmpeg fallback."""
    path = shutil.which("ffprobe")
    if path:
        return path
    try:
        import imageio_ffmpeg
        ff = imageio_ffmpeg.get_ffmpeg_exe()
        # ffprobe is in the same directory
        probe = os.path.join(os.path.dirname(ff),
                             ff.replace("ffmpeg", "ffprobe").split("/")[-1])
        if os.path.exists(probe):
            return probe
        # some builds bundle it as ffprobe next to ffmpeg
        probe2 = ff.replace("ffmpeg", "ffprobe")
        if os.path.exists(probe2):
            return probe2
    except ImportError:
        pass
    return None


FFMPEG = _find_ffmpeg()
FFPROBE = _find_ffprobe()

# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------
app = Flask(__name__)
UPLOAD_DIR = tempfile.mkdtemp(prefix="thera_ui_")

# Job tracking for async video processing
_jobs = {}
_jobs_lock = threading.Lock()

# ---------------------------------------------------------------------------
# Model cache
# ---------------------------------------------------------------------------
_model_cache = {}


def get_model(size):
    """Return a cached Thera model for *size*, loading weights on first use.

    Raises FileNotFoundError with a hint to run convert.py when the
    safetensors file for the requested model size is missing.
    """
    if size not in _model_cache:
        from upscale import load_weights
        model = Thera(size=size)
        weights_dir = os.path.join(os.path.dirname(__file__), "weights")
        weights_path = os.path.join(weights_dir, f"weights-{size}.safetensors")
        if not os.path.exists(weights_path):
            raise FileNotFoundError(
                f"Weights not found: {weights_path}\n"
                f"Run: python3 convert.py --model {size}")
        model = load_weights(model, weights_path)
        # Force-evaluate lazily-loaded params so first request isn't slow.
        mx.eval(model.parameters())
        _model_cache[size] = model
    return _model_cache[size]


# ---------------------------------------------------------------------------
# Core upscale
# ---------------------------------------------------------------------------
def upscale_image(img_np, scale, model_size, ensemble, tiles=None):
    """Upscale a uint8 HxWx3 numpy image by *scale*; returns a numpy array.

    When tiles > 1, delegates to upscale_tiled (lower peak memory);
    otherwise runs the whole image through the model in one pass.
    """
    model = get_model(model_size)
    h, w = img_np.shape[:2]
    th, tw = round(h * scale), round(w * scale)
    source_f = img_np.astype(np.float32) / 255.0
    if tiles and tiles > 1:
        from upscale import upscale_tiled
        return upscale_tiled(model, source_f, th, tw, tiles, ensemble=ensemble)
    else:
        source = mx.array(source_f)
        result = model.upscale(source, th, tw, ensemble=ensemble)
        mx.eval(result)
        return np.array(result)


# ---------------------------------------------------------------------------
# API routes
# ---------------------------------------------------------------------------
@app.route("/api/upscale", methods=["POST"])
def api_upscale():
    """Upscale a single uploaded image; returns a PNG with X-Info metadata."""
    f = request.files.get("image")
    if not f:
        return jsonify(error="No image uploaded"), 400
    scale = float(request.form.get("scale", 2.0))
    model_size = request.form.get("model", "air")
    ensemble = request.form.get("ensemble", "false") == "true"
    tiles_str = request.form.get("tiles", "1")
    # Only pass tiles through when it parses as an int > 1.
    tiles = int(tiles_str) if tiles_str.isdigit() and int(tiles_str) > 1 else None

    img = np.array(Image.open(f.stream).convert("RGB"))
    t0 = time.perf_counter()
    result = upscale_image(img, scale, model_size, ensemble, tiles=tiles)
    elapsed = time.perf_counter() - t0

    buf = io.BytesIO()
    Image.fromarray(result).save(buf, format="PNG")
    buf.seek(0)

    h, w = img.shape[:2]
    th, tw = result.shape[:2]
    resp = send_file(buf, mimetype="image/png", download_name="upscaled.png")
    resp.headers["X-Info"] = json.dumps({
        "src": f"{w}x{h}", "dst": f"{tw}x{th}", "scale": scale,
        "model": model_size, "time": round(elapsed, 1)
    })
    return resp


@app.route("/api/batch", methods=["POST"])
def api_batch():
    """Upscale several uploaded images; returns a ZIP of PNGs."""
    files = request.files.getlist("images")
    if not files:
        return jsonify(error="No images uploaded"), 400
    scale = float(request.form.get("scale", 2.0))
    model_size = request.form.get("model", "air")
    ensemble = request.form.get("ensemble", "false") == "true"

    buf = io.BytesIO()
    t0 = time.perf_counter()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
        for f in files:
            img = np.array(Image.open(f.stream).convert("RGB"))
            result = upscale_image(img, scale, model_size, ensemble)
            img_buf = io.BytesIO()
            Image.fromarray(result).save(img_buf, format="PNG")
            name = os.path.splitext(f.filename)[0] + f"_thera_{scale}x.png"
            zf.writestr(name, img_buf.getvalue())
    elapsed = time.perf_counter() - t0
    buf.seek(0)

    resp = send_file(buf, mimetype="application/zip",
                     download_name="thera_batch.zip")
    resp.headers["X-Info"] = json.dumps({
        "count": len(files), "scale": scale, "model": model_size,
        "time": round(elapsed, 1)
    })
    return resp


# ---------------------------------------------------------------------------
# Video processing (async with progress)
# ---------------------------------------------------------------------------
def get_video_info(path):
    """Get video metadata using ffprobe if available, else parse ffmpeg stderr.

    Returns (fps, width, height, duration_seconds).
    """
    if FFPROBE:
        cmd = [
            FFPROBE, "-v", "quiet", "-print_format", "json",
            "-show_streams", "-show_format", path
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        info = json.loads(result.stdout)
        stream = next(s for s in info["streams"] if s["codec_type"] == "video")
        # r_frame_rate is a rational like "30000/1001".
        fps_parts = stream["r_frame_rate"].split("/")
        fps = float(fps_parts[0]) / float(fps_parts[1])
        w, h = int(stream["width"]), int(stream["height"])
        duration = float(info["format"].get("duration", 0))
        return fps, w, h, duration
    else:
        # Fallback: parse ffmpeg -i stderr
        import re
        result = subprocess.run(
            [FFMPEG, "-i", path], capture_output=True, text=True)
        stderr = result.stderr
        # Parse "Duration: 00:00:10.00"
        dur_m = re.search(r"Duration:\s*(\d+):(\d+):(\d+\.\d+)", stderr)
        duration = 0.0
        if dur_m:
            duration = int(dur_m[1]) * 3600 + int(dur_m[2]) * 60 + float(dur_m[3])
        # Parse "1920x1080" and fps
        vid_m = re.search(r"(\d{2,5})x(\d{2,5})", stderr)
        w, h = (int(vid_m[1]), int(vid_m[2])) if vid_m else (0, 0)
        fps_m = re.search(r"(\d+(?:\.\d+)?)\s*fps", stderr)
        fps = float(fps_m[1]) if fps_m else 30.0
        return fps, w, h, duration


def has_audio(path):
    """Check if video has an audio stream."""
    if FFPROBE:
        probe = subprocess.run(
            [FFPROBE, "-v", "quiet", "-select_streams", "a",
             "-show_entries", "stream=codec_type", path],
            capture_output=True, text=True)
        return "audio" in probe.stdout
    else:
        result = subprocess.run(
            [FFMPEG, "-i", path], capture_output=True, text=True)
        return "Audio:" in result.stderr


def video_worker(job_id, video_path, scale, model_size):
    """Background thread: extract frames, upscale each, re-encode with audio.

    Progress is reported by mutating the shared _jobs[job_id] dict, which
    api_video_progress polls.  On failure, status is set to "error".
    """
    job = _jobs[job_id]
    try:
        job["status"] = "analyzing"
        fps, src_w, src_h, duration = get_video_info(video_path)
        tw, th = round(src_w * scale), round(src_h * scale)
        job["src"] = f"{src_w}x{src_h}"
        job["dst"] = f"{tw}x{th}"

        tmpdir = tempfile.mkdtemp(prefix="thera_vid_")
        frames_dir = os.path.join(tmpdir, "frames")
        upscaled_dir = os.path.join(tmpdir, "upscaled")
        os.makedirs(frames_dir)
        os.makedirs(upscaled_dir)

        # Extract frames
        job["status"] = "extracting"
        subprocess.run([
            FFMPEG, "-i", video_path, "-vsync", "0",
            os.path.join(frames_dir, "frame_%06d.png")
        ], capture_output=True, check=True)
        frame_files = sorted(glob.glob(os.path.join(frames_dir, "frame_*.png")))
        total = len(frame_files)
        job["total_frames"] = total

        # Upscale frames
        job["status"] = "upscaling"
        model = get_model(model_size)
        t0 = time.perf_counter()
        for i, frame_path in enumerate(frame_files):
            img = np.array(Image.open(frame_path).convert("RGB"))
            source = mx.array(img.astype(np.float32) / 255.0)
            fh, fw = img.shape[:2]
            result = model.upscale(source, round(fh * scale), round(fw * scale))
            mx.eval(result)
            out_path = os.path.join(upscaled_dir, os.path.basename(frame_path))
            Image.fromarray(np.array(result)).save(out_path)
            elapsed = time.perf_counter() - t0
            eta = (elapsed / (i + 1)) * (total - i - 1) if i > 0 else 0
            job["current_frame"] = i + 1
            job["eta"] = round(eta)
            job["fps"] = round((i + 1) / elapsed, 1) if elapsed > 0 else 0

        # Encode
        job["status"] = "encoding"
        output_path = os.path.join(tmpdir, "upscaled.mp4")
        ffmpeg_cmd = [
            FFMPEG, "-y",
            "-framerate", str(fps),
            "-i", os.path.join(upscaled_dir, "frame_%06d.png"),
        ]
        # Check for audio: mux the original audio track back in if present.
        audio = has_audio(video_path)
        if audio:
            ffmpeg_cmd += ["-i", video_path, "-map", "0:v", "-map", "1:a",
                           "-shortest"]
        ffmpeg_cmd += [
            "-c:v", "libx264", "-preset", "medium", "-crf", "18",
            "-pix_fmt", "yuv420p",
        ]
        if audio:
            ffmpeg_cmd += ["-c:a", "aac", "-b:a", "192k"]
        ffmpeg_cmd.append(output_path)
        subprocess.run(ffmpeg_cmd, capture_output=True, check=True)

        total_time = time.perf_counter() - t0
        job["status"] = "done"
        job["output_path"] = output_path
        job["time"] = round(total_time, 1)
        job["tmpdir"] = tmpdir
    except Exception as e:
        # Top-level worker boundary: surface the failure to the UI poller.
        job["status"] = "error"
        job["error"] = str(e)


@app.route("/api/video/start", methods=["POST"])
def api_video_start():
    """Accept a video upload, start a background upscale job, return its id."""
    f = request.files.get("video")
    if not f:
        return jsonify(error="No video uploaded"), 400
    if not FFMPEG:
        return jsonify(error="ffmpeg not found. "
                             "Install with: pip3 install imageio[ffmpeg]"), 400
    scale = float(request.form.get("scale", 2.0))
    model_size = request.form.get("model", "air")

    job_id = str(uuid.uuid4())[:8]
    video_path = os.path.join(UPLOAD_DIR, f"{job_id}_input.mp4")
    f.save(video_path)

    with _jobs_lock:
        _jobs[job_id] = {
            "status": "queued", "current_frame": 0, "total_frames": 0,
            "eta": 0, "fps": 0, "scale": scale, "model": model_size,
        }
    t = threading.Thread(target=video_worker,
                         args=(job_id, video_path, scale, model_size),
                         daemon=True)
    t.start()
    return jsonify(job_id=job_id)


# BUG FIX: the route was registered without the <job_id> converter, so Flask
# never supplied the view's required argument (TypeError on every request).
@app.route("/api/video/progress/<job_id>")
def api_video_progress(job_id):
    """Return job progress as JSON, hiding server-side filesystem paths."""
    job = _jobs.get(job_id)
    if not job:
        return jsonify(error="Unknown job"), 404
    safe = {k: v for k, v in job.items() if k not in ("output_path", "tmpdir")}
    return jsonify(safe)


# BUG FIX: same missing <job_id> converter as the progress route.
@app.route("/api/video/download/<job_id>")
def api_video_download(job_id):
    """Stream the finished MP4 for a completed job."""
    job = _jobs.get(job_id)
    if not job or job["status"] != "done":
        return jsonify(error="Not ready"), 400
    return send_file(job["output_path"], mimetype="video/mp4",
                     download_name="thera_upscaled.mp4")


# ---------------------------------------------------------------------------
# Frontend
# ---------------------------------------------------------------------------
@app.route("/")
def index():
    return HTML_PAGE


# NOTE(review): the HTML markup of this page appears to have been stripped to
# bare text in this copy of the file (no tags survive) — recover the original
# page from version control; the text below is preserved as found.
HTML_PAGE = r"""
Thera MLX

Thera MLX

Arbitrary-scale super-resolution on Apple Silicon

2x
🖼
Drop image here or click to browse
Before After
Processing...
2x
📁
Drop images here or click to browse (multiple)
Processing...
2x
🎬
Drop video here or click to browse
Processing...
"""

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import logging
    logging.getLogger("werkzeug").setLevel(logging.ERROR)

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=5005)
    parser.add_argument("--host", type=str, default="127.0.0.1")
    args = parser.parse_args()

    print(f"Thera MLX → http://{args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=False, threaded=True)