# thera-mlx / ui.py
# (The following Hugging Face file-page header lines were commented out
# because they are not valid Python: "mlmPenguin's picture",
# "Add source code", "29e0144 verified".)
#!/usr/bin/env python3
"""
Web UI for Thera MLX super-resolution.
Supports single image, batch, and video upscaling.
Usage:
python3 ui.py
python3 ui.py --port 8080
"""
import argparse
import glob
import io
import json
import os
import shutil
import subprocess
import tempfile
import threading
import time
import uuid
import zipfile
import mlx.core as mx
import numpy as np
from flask import Flask, request, jsonify, send_file, Response
from PIL import Image
from model import Thera
def _find_ffmpeg():
"""Find ffmpeg binary — system PATH first, then imageio_ffmpeg fallback."""
path = shutil.which("ffmpeg")
if path:
return path
try:
import imageio_ffmpeg
return imageio_ffmpeg.get_ffmpeg_exe()
except ImportError:
return None
def _find_ffprobe():
"""Find ffprobe binary — system PATH first, then imageio_ffmpeg fallback."""
path = shutil.which("ffprobe")
if path:
return path
try:
import imageio_ffmpeg
ff = imageio_ffmpeg.get_ffmpeg_exe()
# ffprobe is in the same directory
probe = os.path.join(os.path.dirname(ff),
ff.replace("ffmpeg", "ffprobe").split("/")[-1])
if os.path.exists(probe):
return probe
# some builds bundle it as ffprobe next to ffmpeg
probe2 = ff.replace("ffmpeg", "ffprobe")
if os.path.exists(probe2):
return probe2
except ImportError:
pass
return None
# Resolve the ffmpeg/ffprobe binaries once at import time. Either may be
# None when the tool is unavailable; the video endpoints check before use.
FFMPEG = _find_ffmpeg()
FFPROBE = _find_ffprobe()
# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------
app = Flask(__name__)
# Scratch directory for uploaded videos; lives for the process lifetime
# (not explicitly cleaned up on shutdown).
UPLOAD_DIR = tempfile.mkdtemp(prefix="thera_ui_")
# Job tracking for async video processing: maps short job-id -> mutable
# status dict that the background worker thread updates in place.
_jobs = {}
_jobs_lock = threading.Lock()
# ---------------------------------------------------------------------------
# Model cache
# ---------------------------------------------------------------------------
_model_cache = {}
# Video worker threads call get_model concurrently with web requests;
# serialize loads so the same model is never loaded twice.
_model_cache_lock = threading.Lock()


def get_model(size):
    """Return a cached Thera model for *size* ("air"/"pro"), loading on demand.

    Raises FileNotFoundError with a conversion hint when the weights file
    is missing.
    """
    with _model_cache_lock:
        if size not in _model_cache:
            from upscale import load_weights
            weights_dir = os.path.join(os.path.dirname(__file__), "weights")
            weights_path = os.path.join(
                weights_dir, f"weights-{size}.safetensors")
            # Validate before constructing the model so a missing weights
            # file does not pay a wasted model-allocation cost.
            if not os.path.exists(weights_path):
                raise FileNotFoundError(
                    f"Weights not found: {weights_path}\n"
                    f"Run: python3 convert.py --model {size}")
            model = Thera(size=size)
            model = load_weights(model, weights_path)
            # Force weight materialization now instead of on first request.
            mx.eval(model.parameters())
            _model_cache[size] = model
        return _model_cache[size]
# ---------------------------------------------------------------------------
# Core upscale
# ---------------------------------------------------------------------------
def upscale_image(img_np, scale, model_size, ensemble, tiles=None):
    """Upscale an HxWx3 image array by *scale* and return the result array.

    When *tiles* > 1 the image is processed in a tiles x tiles grid
    (bounding peak memory); otherwise the whole frame goes through the
    model at once.
    """
    model = get_model(model_size)
    src_h, src_w = img_np.shape[:2]
    out_h, out_w = round(src_h * scale), round(src_w * scale)
    normalized = img_np.astype(np.float32) / 255.0
    if tiles and tiles > 1:
        from upscale import upscale_tiled
        return upscale_tiled(model, normalized, out_h, out_w, tiles,
                             ensemble=ensemble)
    result = model.upscale(mx.array(normalized), out_h, out_w,
                           ensemble=ensemble)
    mx.eval(result)
    return np.array(result)
# ---------------------------------------------------------------------------
# API routes
# ---------------------------------------------------------------------------
@app.route("/api/upscale", methods=["POST"])
def api_upscale():
    """Upscale one uploaded image; responds with a PNG plus an X-Info header."""
    upload = request.files.get("image")
    if not upload:
        return jsonify(error="No image uploaded"), 400
    scale = float(request.form.get("scale", 2.0))
    model_size = request.form.get("model", "air")
    ensemble = request.form.get("ensemble", "false") == "true"
    tiles_str = request.form.get("tiles", "1")
    tiles = None
    if tiles_str.isdigit() and int(tiles_str) > 1:
        tiles = int(tiles_str)
    img = np.array(Image.open(upload.stream).convert("RGB"))
    start = time.perf_counter()
    result = upscale_image(img, scale, model_size, ensemble, tiles=tiles)
    elapsed = time.perf_counter() - start
    buf = io.BytesIO()
    Image.fromarray(result).save(buf, format="PNG")
    buf.seek(0)
    src_h, src_w = img.shape[:2]
    dst_h, dst_w = result.shape[:2]
    resp = send_file(buf, mimetype="image/png", download_name="upscaled.png")
    # Dimensions/timing travel in a response header so the body stays
    # a plain PNG stream.
    resp.headers["X-Info"] = json.dumps({
        "src": f"{src_w}x{src_h}", "dst": f"{dst_w}x{dst_h}",
        "scale": scale, "model": model_size, "time": round(elapsed, 1)
    })
    return resp
@app.route("/api/batch", methods=["POST"])
def api_batch():
    """Upscale several uploaded images and return them as one ZIP archive."""
    uploads = request.files.getlist("images")
    if not uploads:
        return jsonify(error="No images uploaded"), 400
    scale = float(request.form.get("scale", 2.0))
    model_size = request.form.get("model", "air")
    ensemble = request.form.get("ensemble", "false") == "true"
    archive = io.BytesIO()
    start = time.perf_counter()
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
        for upload in uploads:
            img = np.array(Image.open(upload.stream).convert("RGB"))
            result = upscale_image(img, scale, model_size, ensemble)
            png = io.BytesIO()
            Image.fromarray(result).save(png, format="PNG")
            stem = os.path.splitext(upload.filename)[0]
            zf.writestr(stem + f"_thera_{scale}x.png", png.getvalue())
    elapsed = time.perf_counter() - start
    archive.seek(0)
    resp = send_file(archive, mimetype="application/zip",
                     download_name="thera_batch.zip")
    resp.headers["X-Info"] = json.dumps({
        "count": len(uploads), "scale": scale,
        "model": model_size, "time": round(elapsed, 1)
    })
    return resp
# ---------------------------------------------------------------------------
# Video processing (async with progress)
# ---------------------------------------------------------------------------
def get_video_info(path):
    """Return (fps, width, height, duration_seconds) for the video at *path*.

    Uses ffprobe's JSON output when available, otherwise parses the
    metadata banner that ``ffmpeg -i`` prints to stderr.

    Raises ValueError when ffprobe reports no video stream.
    """
    if FFPROBE:
        cmd = [
            FFPROBE, "-v", "quiet", "-print_format", "json",
            "-show_streams", "-show_format", path
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        info = json.loads(result.stdout)
        stream = next(
            (s for s in info["streams"] if s["codec_type"] == "video"), None)
        if stream is None:
            # Explicit error instead of an opaque StopIteration.
            raise ValueError(f"No video stream in {path}")
        # r_frame_rate is a rational like "30000/1001".
        num, den = stream["r_frame_rate"].split("/")
        fps = float(num) / float(den)
        w, h = int(stream["width"]), int(stream["height"])
        duration = float(info["format"].get("duration", 0))
        return fps, w, h, duration
    else:
        # Fallback: parse ffmpeg -i stderr
        import re
        result = subprocess.run(
            [FFMPEG, "-i", path], capture_output=True, text=True)
        stderr = result.stderr
        # Parse "Duration: 00:00:10.00" — fractional seconds are optional,
        # so accept "00:00:10" as well.
        dur_m = re.search(r"Duration:\s*(\d+):(\d+):(\d+(?:\.\d+)?)", stderr)
        duration = 0.0
        if dur_m:
            duration = int(dur_m[1]) * 3600 + int(dur_m[2]) * 60 + float(dur_m[3])
        # First "WxH" token in the banner is assumed to be the resolution.
        vid_m = re.search(r"(\d{2,5})x(\d{2,5})", stderr)
        w, h = (int(vid_m[1]), int(vid_m[2])) if vid_m else (0, 0)
        fps_m = re.search(r"(\d+(?:\.\d+)?)\s*fps", stderr)
        fps = float(fps_m[1]) if fps_m else 30.0
        return fps, w, h, duration
def has_audio(path):
    """Return True when the media file at *path* contains an audio stream."""
    if FFPROBE:
        probe = subprocess.run(
            [FFPROBE, "-v", "quiet", "-select_streams", "a",
             "-show_entries", "stream=codec_type", path],
            capture_output=True, text=True)
        return "audio" in probe.stdout
    # Without ffprobe, look for an "Audio:" stream line in ffmpeg's banner.
    result = subprocess.run(
        [FFMPEG, "-i", path], capture_output=True, text=True)
    return "Audio:" in result.stderr
def video_worker(job_id, video_path, scale, model_size):
    """Background worker for one video upscale job.

    Pipeline: probe metadata -> extract frames with ffmpeg -> upscale
    each frame with the Thera model -> re-encode to H.264 MP4, muxing
    the source audio back in when present. Progress is published by
    mutating the shared job dict in ``_jobs``; ``status`` moves through
    analyzing / extracting / upscaling / encoding / done, or "error".
    """
    job = _jobs[job_id]
    try:
        job["status"] = "analyzing"
        # duration is currently unused; kept for the tuple unpack.
        fps, src_w, src_h, duration = get_video_info(video_path)
        # Target output dimensions reported to the progress endpoint.
        tw, th = round(src_w * scale), round(src_h * scale)
        job["src"] = f"{src_w}x{src_h}"
        job["dst"] = f"{tw}x{th}"
        # NOTE(review): tmpdir is only recorded on success; on error the
        # extracted frames are left behind — confirm cleanup policy.
        tmpdir = tempfile.mkdtemp(prefix="thera_vid_")
        frames_dir = os.path.join(tmpdir, "frames")
        upscaled_dir = os.path.join(tmpdir, "upscaled")
        os.makedirs(frames_dir)
        os.makedirs(upscaled_dir)
        # Extract every source frame ("-vsync 0": no duplication/dropping).
        job["status"] = "extracting"
        subprocess.run([
            FFMPEG, "-i", video_path, "-vsync", "0",
            os.path.join(frames_dir, "frame_%06d.png")
        ], capture_output=True, check=True)
        frame_files = sorted(glob.glob(os.path.join(frames_dir, "frame_*.png")))
        total = len(frame_files)
        job["total_frames"] = total
        # Upscale frames one at a time.
        job["status"] = "upscaling"
        model = get_model(model_size)
        t0 = time.perf_counter()
        for i, frame_path in enumerate(frame_files):
            img = np.array(Image.open(frame_path).convert("RGB"))
            source = mx.array(img.astype(np.float32) / 255.0)
            fh, fw = img.shape[:2]
            result = model.upscale(source, round(fh * scale), round(fw * scale))
            mx.eval(result)
            out_path = os.path.join(upscaled_dir, os.path.basename(frame_path))
            Image.fromarray(np.array(result)).save(out_path)
            # Progress bookkeeping: ETA from the mean per-frame time so far.
            elapsed = time.perf_counter() - t0
            eta = (elapsed / (i + 1)) * (total - i - 1) if i > 0 else 0
            job["current_frame"] = i + 1
            job["eta"] = round(eta)
            # "fps" here is processing throughput (frames/sec), not the
            # video's frame rate.
            job["fps"] = round((i + 1) / elapsed, 1) if elapsed > 0 else 0
        # Re-encode the upscaled frames at the original frame rate.
        job["status"] = "encoding"
        output_path = os.path.join(tmpdir, "upscaled.mp4")
        ffmpeg_cmd = [
            FFMPEG, "-y",
            "-framerate", str(fps),
            "-i", os.path.join(upscaled_dir, "frame_%06d.png"),
        ]
        # Mux the original audio track back in when the source has one.
        audio = has_audio(video_path)
        if audio:
            ffmpeg_cmd += ["-i", video_path, "-map", "0:v", "-map", "1:a",
                           "-shortest"]
        ffmpeg_cmd += [
            "-c:v", "libx264", "-preset", "medium", "-crf", "18",
            "-pix_fmt", "yuv420p",
        ]
        if audio:
            ffmpeg_cmd += ["-c:a", "aac", "-b:a", "192k"]
        ffmpeg_cmd.append(output_path)
        subprocess.run(ffmpeg_cmd, capture_output=True, check=True)
        # Measured from the start of upscaling; extraction is not included.
        total_time = time.perf_counter() - t0
        job["status"] = "done"
        job["output_path"] = output_path
        job["time"] = round(total_time, 1)
        job["tmpdir"] = tmpdir
    except Exception as e:
        # Surface any failure to the polling endpoint.
        job["status"] = "error"
        job["error"] = str(e)
@app.route("/api/video/start", methods=["POST"])
def api_video_start():
    """Accept a video upload, spawn an async upscale job, return its id."""
    upload = request.files.get("video")
    if not upload:
        return jsonify(error="No video uploaded"), 400
    if not FFMPEG:
        return jsonify(error="ffmpeg not found. Install with: pip3 install imageio[ffmpeg]"), 400
    scale = float(request.form.get("scale", 2.0))
    model_size = request.form.get("model", "air")
    # Short id is enough for a local single-user tool.
    job_id = str(uuid.uuid4())[:8]
    video_path = os.path.join(UPLOAD_DIR, f"{job_id}_input.mp4")
    upload.save(video_path)
    with _jobs_lock:
        _jobs[job_id] = {
            "status": "queued", "current_frame": 0, "total_frames": 0,
            "eta": 0, "fps": 0, "scale": scale, "model": model_size,
        }
    worker = threading.Thread(target=video_worker,
                              args=(job_id, video_path, scale, model_size),
                              daemon=True)
    worker.start()
    return jsonify(job_id=job_id)
@app.route("/api/video/progress/<job_id>")
def api_video_progress(job_id):
    """Return the job's status dict, minus server-local filesystem paths."""
    job = _jobs.get(job_id)
    if job is None:
        return jsonify(error="Unknown job"), 404
    hidden = ("output_path", "tmpdir")
    return jsonify({k: v for k, v in job.items() if k not in hidden})
@app.route("/api/video/download/<job_id>")
def api_video_download(job_id):
    """Stream the finished MP4 for a completed job."""
    job = _jobs.get(job_id)
    if not job or job["status"] != "done":
        return jsonify(error="Not ready"), 400
    return send_file(job["output_path"],
                     mimetype="video/mp4",
                     download_name="thera_upscaled.mp4")
# ---------------------------------------------------------------------------
# Frontend
# ---------------------------------------------------------------------------
@app.route("/")
def index():
    # Serve the single-page UI (HTML/CSS/JS inlined in HTML_PAGE below;
    # defined after this function but bound by call time).
    return HTML_PAGE
HTML_PAGE = r"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Thera MLX</title>
<style>
:root {
--bg: #0f0f0f;
--surface: #1a1a1a;
--surface2: #242424;
--border: #333;
--text: #e8e8e8;
--text2: #999;
--accent: #6c63ff;
--accent-hover: #7c74ff;
--green: #4caf50;
--radius: 12px;
}
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'SF Pro Text', system-ui, sans-serif;
background: var(--bg);
color: var(--text);
min-height: 100vh;
}
.container { max-width: 960px; margin: 0 auto; padding: 24px 20px; }
header {
text-align: center;
padding: 32px 0 24px;
}
header h1 {
font-size: 28px;
font-weight: 700;
letter-spacing: -0.5px;
}
header h1 span { color: var(--accent); }
header p { color: var(--text2); margin-top: 4px; font-size: 14px; }
/* Tabs */
.tabs {
display: flex;
gap: 4px;
background: var(--surface);
border-radius: var(--radius);
padding: 4px;
margin-bottom: 20px;
}
.tab {
flex: 1;
padding: 10px 16px;
border: none;
background: none;
color: var(--text2);
font-size: 14px;
font-weight: 500;
cursor: pointer;
border-radius: 8px;
transition: all 0.2s;
}
.tab:hover { color: var(--text); }
.tab.active { background: var(--accent); color: #fff; }
.tab-content { display: none; }
.tab-content.active { display: block; }
/* Controls */
.controls {
background: var(--surface);
border-radius: var(--radius);
padding: 20px;
margin-bottom: 16px;
}
.control-row {
display: flex;
gap: 16px;
align-items: center;
flex-wrap: wrap;
}
.control-group {
display: flex;
flex-direction: column;
gap: 6px;
}
.control-group label {
font-size: 12px;
font-weight: 600;
color: var(--text2);
text-transform: uppercase;
letter-spacing: 0.5px;
}
.control-group select,
.control-group input[type="range"] {
background: var(--surface2);
border: 1px solid var(--border);
color: var(--text);
border-radius: 8px;
padding: 8px 12px;
font-size: 14px;
}
.control-group select { min-width: 100px; cursor: pointer; }
.control-group input[type="range"] { width: 160px; accent-color: var(--accent); }
.scale-display {
font-size: 20px;
font-weight: 700;
color: var(--accent);
min-width: 40px;
text-align: center;
}
.checkbox-label {
display: flex;
align-items: center;
gap: 8px;
font-size: 14px;
cursor: pointer;
color: var(--text2);
}
.checkbox-label input { accent-color: var(--accent); }
/* Drop zone */
.dropzone {
border: 2px dashed var(--border);
border-radius: var(--radius);
padding: 48px 24px;
text-align: center;
cursor: pointer;
transition: all 0.2s;
background: var(--surface);
margin-bottom: 16px;
}
.dropzone:hover, .dropzone.dragover {
border-color: var(--accent);
background: rgba(108, 99, 255, 0.05);
}
.dropzone.has-file {
padding: 16px;
border-style: solid;
border-color: var(--green);
}
.dropzone-icon { font-size: 36px; margin-bottom: 8px; }
.dropzone-text { color: var(--text2); font-size: 14px; }
.dropzone-text strong { color: var(--text); }
/* Preview panels */
.preview-area {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 16px;
margin-bottom: 16px;
}
.preview-panel {
background: var(--surface);
border-radius: var(--radius);
overflow: hidden;
min-height: 200px;
position: relative;
}
.preview-panel .panel-label {
position: absolute;
top: 8px;
left: 12px;
font-size: 11px;
font-weight: 600;
text-transform: uppercase;
color: var(--text2);
background: rgba(0,0,0,0.6);
padding: 3px 8px;
border-radius: 4px;
z-index: 1;
}
.preview-panel img, .preview-panel video {
width: 100%;
height: 100%;
object-fit: contain;
display: block;
}
/* Compare slider */
.compare-container {
display: none;
position: relative;
background: var(--surface);
border-radius: var(--radius);
overflow: hidden;
margin-bottom: 16px;
cursor: col-resize;
user-select: none;
}
.compare-container.active { display: block; }
.compare-container img {
width: 100%;
display: block;
}
.compare-overlay {
position: absolute;
top: 0;
left: 0;
height: 100%;
overflow: hidden;
}
.compare-overlay img {
position: absolute;
top: 0;
left: 0;
height: 100%;
width: auto;
max-width: none;
}
.compare-line {
position: absolute;
top: 0;
width: 2px;
height: 100%;
background: #fff;
pointer-events: none;
z-index: 2;
box-shadow: 0 0 6px rgba(0,0,0,0.5);
}
.compare-line::after {
content: '';
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
width: 32px;
height: 32px;
background: #fff;
border-radius: 50%;
box-shadow: 0 2px 8px rgba(0,0,0,0.3);
}
.compare-line::before {
content: '';
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
width: 16px;
height: 16px;
background: var(--accent);
border-radius: 50%;
z-index: 1;
}
.compare-label {
position: absolute;
top: 8px;
font-size: 11px;
font-weight: 600;
text-transform: uppercase;
color: var(--text2);
background: rgba(0,0,0,0.6);
padding: 3px 8px;
border-radius: 4px;
z-index: 3;
pointer-events: none;
}
.compare-label-before { left: 12px; }
.compare-label-after { right: 12px; }
/* Buttons */
.btn {
padding: 12px 32px;
border: none;
border-radius: 8px;
font-size: 15px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
display: inline-flex;
align-items: center;
gap: 8px;
}
.btn-primary {
background: var(--accent);
color: #fff;
}
.btn-primary:hover { background: var(--accent-hover); }
.btn-primary:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.btn-secondary {
background: var(--surface2);
color: var(--text);
border: 1px solid var(--border);
}
.btn-secondary:hover { background: var(--border); }
.action-row {
display: flex;
gap: 12px;
align-items: center;
flex-wrap: wrap;
}
/* Progress */
.progress-container {
display: none;
margin-bottom: 16px;
}
.progress-container.active { display: block; }
.progress-bar-bg {
background: var(--surface2);
border-radius: 8px;
height: 8px;
overflow: hidden;
margin-bottom: 8px;
}
.progress-bar {
height: 100%;
background: var(--accent);
border-radius: 8px;
transition: width 0.3s;
width: 0%;
}
.progress-text {
font-size: 13px;
color: var(--text2);
font-family: 'SF Mono', monospace;
}
/* Info */
.info-bar {
background: var(--surface);
border-radius: 8px;
padding: 10px 16px;
font-family: 'SF Mono', monospace;
font-size: 13px;
color: var(--text2);
display: none;
}
.info-bar.active { display: block; }
/* Batch gallery */
.gallery {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(180px, 1fr));
gap: 12px;
margin-bottom: 16px;
}
.gallery-item {
background: var(--surface);
border-radius: 8px;
overflow: hidden;
cursor: pointer;
transition: transform 0.2s;
}
.gallery-item:hover { transform: scale(1.02); }
.gallery-item img {
width: 100%;
aspect-ratio: 1;
object-fit: cover;
}
.gallery-item .name {
padding: 6px 10px;
font-size: 11px;
color: var(--text2);
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
/* File list */
.file-list {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 8px;
}
.file-chip {
background: var(--surface2);
border-radius: 6px;
padding: 4px 10px;
font-size: 12px;
color: var(--text2);
display: flex;
align-items: center;
gap: 6px;
}
.file-chip .remove {
cursor: pointer;
color: #f44;
font-weight: bold;
}
/* Spinner */
.spinner {
display: inline-block;
width: 18px;
height: 18px;
border: 2px solid rgba(255,255,255,0.3);
border-top-color: #fff;
border-radius: 50%;
animation: spin 0.8s linear infinite;
}
@keyframes spin { to { transform: rotate(360deg); } }
/* Toast */
.toast {
position: fixed;
bottom: 24px;
right: 24px;
background: var(--surface2);
border: 1px solid var(--border);
border-radius: 8px;
padding: 12px 20px;
font-size: 14px;
z-index: 100;
display: none;
animation: slideUp 0.3s;
}
@keyframes slideUp {
from { transform: translateY(20px); opacity: 0; }
to { transform: translateY(0); opacity: 1; }
}
@media (max-width: 640px) {
.preview-area { grid-template-columns: 1fr; }
.control-row { flex-direction: column; align-items: stretch; }
}
</style>
</head>
<body>
<div class="container">
<header>
<h1>Thera <span>MLX</span></h1>
<p>Arbitrary-scale super-resolution on Apple Silicon</p>
</header>
<div class="tabs">
<button class="tab active" onclick="switchTab('single')">Image</button>
<button class="tab" onclick="switchTab('batch')">Batch</button>
<button class="tab" onclick="switchTab('video')">Video</button>
</div>
<!-- ==================== SINGLE IMAGE ==================== -->
<div id="tab-single" class="tab-content active">
<div class="controls">
<div class="control-row">
<div class="control-group">
<label>Scale</label>
<div style="display:flex;align-items:center;gap:8px">
<input type="range" id="s-scale" min="1" max="8" step="0.5" value="2"
oninput="document.getElementById('s-scale-val').textContent=this.value+'x'">
<span class="scale-display" id="s-scale-val">2x</span>
</div>
</div>
<div class="control-group">
<label>Model</label>
<select id="s-model">
<option value="air">Air (fast)</option>
<option value="pro">Pro (quality)</option>
</select>
</div>
<label class="checkbox-label">
<input type="checkbox" id="s-ensemble"> Ensemble
</label>
<div class="control-group">
<label>Tiles</label>
<select id="s-tiles">
<option value="1">Off</option>
<option value="2">2x2</option>
<option value="3">3x3</option>
<option value="4">4x4</option>
</select>
</div>
</div>
</div>
<div class="dropzone" id="s-dropzone"
onclick="document.getElementById('s-file').click()">
<div class="dropzone-icon">🖼</div>
<div class="dropzone-text">
<strong>Drop image here</strong> or click to browse
</div>
<input type="file" id="s-file" accept="image/*" hidden
onchange="handleSingleFile(this.files[0])">
</div>
<div class="preview-area" id="s-preview" style="display:none">
<div class="preview-panel">
<span class="panel-label">Input</span>
<img id="s-input-img">
</div>
<div class="preview-panel">
<span class="panel-label">Output</span>
<img id="s-output-img">
</div>
</div>
<div class="compare-container" id="s-compare">
<span class="compare-label compare-label-before">Before</span>
<span class="compare-label compare-label-after">After</span>
<img id="s-compare-after" style="width:100%;display:block">
<div class="compare-overlay" id="s-compare-overlay">
<img id="s-compare-before">
</div>
<div class="compare-line" id="s-compare-line"></div>
</div>
<div class="progress-container" id="s-progress">
<div class="progress-bar-bg"><div class="progress-bar" id="s-pbar"></div></div>
<div class="progress-text" id="s-ptext">Processing...</div>
</div>
<div class="action-row">
<button class="btn btn-primary" id="s-btn" onclick="doSingleUpscale()" disabled>
Upscale
</button>
<button class="btn btn-secondary" id="s-compare-btn" style="display:none"
onclick="toggleCompare()">
Compare
</button>
<button class="btn btn-secondary" id="s-download" style="display:none"
onclick="downloadSingle()">
Download PNG
</button>
</div>
<div class="info-bar" id="s-info" style="margin-top:12px"></div>
</div>
<!-- ==================== BATCH ==================== -->
<div id="tab-batch" class="tab-content">
<div class="controls">
<div class="control-row">
<div class="control-group">
<label>Scale</label>
<div style="display:flex;align-items:center;gap:8px">
<input type="range" id="b-scale" min="1" max="8" step="0.5" value="2"
oninput="document.getElementById('b-scale-val').textContent=this.value+'x'">
<span class="scale-display" id="b-scale-val">2x</span>
</div>
</div>
<div class="control-group">
<label>Model</label>
<select id="b-model">
<option value="air">Air (fast)</option>
<option value="pro">Pro (quality)</option>
</select>
</div>
<label class="checkbox-label">
<input type="checkbox" id="b-ensemble"> Ensemble
</label>
</div>
</div>
<div class="dropzone" id="b-dropzone"
onclick="document.getElementById('b-file').click()">
<div class="dropzone-icon">📁</div>
<div class="dropzone-text">
<strong>Drop images here</strong> or click to browse (multiple)
</div>
<input type="file" id="b-file" accept="image/*" multiple hidden
onchange="handleBatchFiles(this.files)">
<div class="file-list" id="b-filelist"></div>
</div>
<div class="progress-container" id="b-progress">
<div class="progress-bar-bg"><div class="progress-bar" id="b-pbar"></div></div>
<div class="progress-text" id="b-ptext">Processing...</div>
</div>
<div class="action-row">
<button class="btn btn-primary" id="b-btn" onclick="doBatchUpscale()" disabled>
Upscale All
</button>
<button class="btn btn-secondary" id="b-download" style="display:none"
onclick="downloadBatch()">
Download ZIP
</button>
</div>
<div class="info-bar" id="b-info" style="margin-top:12px"></div>
</div>
<!-- ==================== VIDEO ==================== -->
<div id="tab-video" class="tab-content">
<div class="controls">
<div class="control-row">
<div class="control-group">
<label>Scale</label>
<div style="display:flex;align-items:center;gap:8px">
<input type="range" id="v-scale" min="1" max="4" step="0.5" value="2"
oninput="document.getElementById('v-scale-val').textContent=this.value+'x'">
<span class="scale-display" id="v-scale-val">2x</span>
</div>
</div>
<div class="control-group">
<label>Model</label>
<select id="v-model">
<option value="air">Air (fast)</option>
<option value="pro">Pro (quality)</option>
</select>
</div>
</div>
</div>
<div class="dropzone" id="v-dropzone"
onclick="document.getElementById('v-file').click()">
<div class="dropzone-icon">🎬</div>
<div class="dropzone-text">
<strong>Drop video here</strong> or click to browse
</div>
<input type="file" id="v-file" accept="video/*" hidden
onchange="handleVideoFile(this.files[0])">
</div>
<div class="preview-area" id="v-preview" style="display:none">
<div class="preview-panel">
<span class="panel-label">Input</span>
<video id="v-input-vid" controls muted></video>
</div>
<div class="preview-panel">
<span class="panel-label">Output</span>
<video id="v-output-vid" controls></video>
</div>
</div>
<div class="progress-container" id="v-progress">
<div class="progress-bar-bg"><div class="progress-bar" id="v-pbar"></div></div>
<div class="progress-text" id="v-ptext">Processing...</div>
</div>
<div class="action-row">
<button class="btn btn-primary" id="v-btn" onclick="doVideoUpscale()" disabled>
Upscale Video
</button>
<button class="btn btn-secondary" id="v-download" style="display:none"
onclick="downloadVideo()">
Download MP4
</button>
</div>
<div class="info-bar" id="v-info" style="margin-top:12px"></div>
</div>
</div>
<div class="toast" id="toast"></div>
<script>
// --- Tab switching ---
function switchTab(tab) {
document.querySelectorAll('.tab').forEach((el, i) => {
const tabs = ['single','batch','video'];
el.classList.toggle('active', tabs[i] === tab);
});
document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
document.getElementById('tab-' + tab).classList.add('active');
}
// --- Toast ---
function toast(msg, duration=3000) {
const el = document.getElementById('toast');
el.textContent = msg;
el.style.display = 'block';
setTimeout(() => el.style.display = 'none', duration);
}
// --- Drag & drop ---
document.querySelectorAll('.dropzone').forEach(dz => {
dz.addEventListener('dragover', e => { e.preventDefault(); dz.classList.add('dragover'); });
dz.addEventListener('dragleave', () => dz.classList.remove('dragover'));
dz.addEventListener('drop', e => {
e.preventDefault();
dz.classList.remove('dragover');
const input = dz.querySelector('input[type="file"]');
if (input.multiple) {
handleBatchFiles(e.dataTransfer.files);
} else if (input.accept.includes('video')) {
handleVideoFile(e.dataTransfer.files[0]);
} else {
handleSingleFile(e.dataTransfer.files[0]);
}
});
});
// ==================== SINGLE ====================
let singleFile = null;
let singleBlob = null;
let singleInputUrl = null;
let compareActive = false;
function handleSingleFile(file) {
if (!file) return;
singleFile = file;
const dz = document.getElementById('s-dropzone');
dz.classList.add('has-file');
dz.querySelector('.dropzone-text').innerHTML = '<strong>' + file.name + '</strong>';
singleInputUrl = URL.createObjectURL(file);
document.getElementById('s-input-img').src = singleInputUrl;
document.getElementById('s-output-img').src = '';
document.getElementById('s-preview').style.display = 'grid';
document.getElementById('s-btn').disabled = false;
document.getElementById('s-download').style.display = 'none';
document.getElementById('s-compare-btn').style.display = 'none';
document.getElementById('s-info').classList.remove('active');
hideCompare();
}
async function doSingleUpscale() {
if (!singleFile) return;
const btn = document.getElementById('s-btn');
btn.disabled = true;
btn.innerHTML = '<span class="spinner"></span> Processing';
hideCompare();
const prog = document.getElementById('s-progress');
prog.classList.add('active');
document.getElementById('s-pbar').style.width = '60%';
const tiles = document.getElementById('s-tiles').value;
const tileLabel = tiles > 1 ? ` (${tiles}x${tiles} tiles)` : '';
document.getElementById('s-ptext').textContent = 'Upscaling' + tileLabel + '...';
const fd = new FormData();
fd.append('image', singleFile);
fd.append('scale', document.getElementById('s-scale').value);
fd.append('model', document.getElementById('s-model').value);
fd.append('ensemble', document.getElementById('s-ensemble').checked);
fd.append('tiles', tiles);
try {
const resp = await fetch('/api/upscale', { method: 'POST', body: fd });
if (!resp.ok) { const e = await resp.json(); throw new Error(e.error); }
const blob = await resp.blob();
singleBlob = blob;
const url = URL.createObjectURL(blob);
document.getElementById('s-output-img').src = url;
const info = JSON.parse(resp.headers.get('X-Info') || '{}');
document.getElementById('s-info').textContent =
`${info.src} \u2192 ${info.dst} (${info.scale}x) | ${info.model} | ${info.time}s`;
document.getElementById('s-info').classList.add('active');
document.getElementById('s-download').style.display = 'inline-flex';
document.getElementById('s-compare-btn').style.display = 'inline-flex';
document.getElementById('s-pbar').style.width = '100%';
toast('Upscale complete!');
} catch (e) {
toast('Error: ' + e.message, 5000);
} finally {
btn.disabled = false;
btn.innerHTML = 'Upscale';
prog.classList.remove('active');
}
}
function downloadSingle() {
if (!singleBlob) return;
const a = document.createElement('a');
a.href = URL.createObjectURL(singleBlob);
const name = singleFile.name.replace(/\.[^.]+$/, '') + '_thera.png';
a.download = name;
a.click();
}
// --- Compare split-view ---
function toggleCompare() {
if (compareActive) {
hideCompare();
} else {
showCompare();
}
}
function showCompare() {
if (!singleBlob || !singleInputUrl) return;
compareActive = true;
const outputUrl = URL.createObjectURL(singleBlob);
const container = document.getElementById('s-compare');
const afterImg = document.getElementById('s-compare-after');
const beforeImg = document.getElementById('s-compare-before');
// Use the upscaled version to determine the display size,
// then scale the input up to match via CSS so it's a fair pixel comparison
afterImg.src = outputUrl;
beforeImg.src = singleInputUrl;
afterImg.onload = () => {
// Set before image width to match container
beforeImg.style.width = container.offsetWidth + 'px';
updateCompareSlider(0.5);
};
document.getElementById('s-preview').style.display = 'none';
container.classList.add('active');
document.getElementById('s-compare-btn').textContent = 'Side by Side';
}
function hideCompare() {
compareActive = false;
document.getElementById('s-compare').classList.remove('active');
if (singleFile) {
document.getElementById('s-preview').style.display = 'grid';
}
document.getElementById('s-compare-btn').textContent = 'Compare';
}
function updateCompareSlider(ratio) {
const container = document.getElementById('s-compare');
const overlay = document.getElementById('s-compare-overlay');
const line = document.getElementById('s-compare-line');
const w = container.offsetWidth;
const pos = Math.max(0, Math.min(w, w * ratio));
overlay.style.width = pos + 'px';
line.style.left = pos + 'px';
}
// Compare drag handling
(function() {
const container = document.getElementById('s-compare');
let dragging = false;
function onMove(clientX) {
const rect = container.getBoundingClientRect();
const ratio = (clientX - rect.left) / rect.width;
updateCompareSlider(Math.max(0, Math.min(1, ratio)));
}
container.addEventListener('mousedown', e => { dragging = true; onMove(e.clientX); });
window.addEventListener('mousemove', e => { if (dragging) onMove(e.clientX); });
window.addEventListener('mouseup', () => { dragging = false; });
container.addEventListener('touchstart', e => { dragging = true; onMove(e.touches[0].clientX); }, {passive: true});
window.addEventListener('touchmove', e => { if (dragging) onMove(e.touches[0].clientX); }, {passive: true});
window.addEventListener('touchend', () => { dragging = false; });
})();
// ==================== BATCH ====================
let batchFiles = [];
let batchBlob = null;
function handleBatchFiles(files) {
batchFiles = Array.from(files);
const dz = document.getElementById('b-dropzone');
dz.classList.add('has-file');
const list = document.getElementById('b-filelist');
list.innerHTML = batchFiles.map((f, i) =>
`<span class="file-chip">${f.name}
<span class="remove" onclick="event.stopPropagation();removeBatchFile(${i})">×</span>
</span>`
).join('');
dz.querySelector('.dropzone-text').innerHTML =
`<strong>${batchFiles.length} image${batchFiles.length>1?'s':''}</strong> selected`;
document.getElementById('b-btn').disabled = false;
}
// Drop one file from the batch queue and refresh the chip list,
// resetting the dropzone to its empty state when nothing remains.
function removeBatchFile(idx) {
  batchFiles.splice(idx, 1);
  if (batchFiles.length > 0) {
    handleBatchFiles(batchFiles);
    return;
  }
  const dz = document.getElementById('b-dropzone');
  dz.classList.remove('has-file');
  dz.querySelector('.dropzone-text').innerHTML =
    '<strong>Drop images here</strong> or click to browse (multiple)';
  document.getElementById('b-filelist').innerHTML = '';
  document.getElementById('b-btn').disabled = true;
}
// Upload the queued files to /api/batch with the chosen scale/model,
// show coarse progress, and keep the returned zip in batchBlob for the
// download button. Metadata comes back in the X-Info response header.
async function doBatchUpscale() {
  if (!batchFiles.length) return;
  const btn = document.getElementById('b-btn');
  btn.disabled = true;
  btn.innerHTML = '<span class="spinner"></span> Processing';
  const prog = document.getElementById('b-progress');
  prog.classList.add('active');
  document.getElementById('b-pbar').style.width = '30%';
  document.getElementById('b-ptext').textContent =
    `Upscaling ${batchFiles.length} images...`;
  const fd = new FormData();
  batchFiles.forEach(f => fd.append('images', f));
  fd.append('scale', document.getElementById('b-scale').value);
  fd.append('model', document.getElementById('b-model').value);
  fd.append('ensemble', document.getElementById('b-ensemble').checked);
  try {
    const resp = await fetch('/api/batch', { method: 'POST', body: fd });
    if (!resp.ok) {
      // Server errors are normally JSON ({error: ...}); fall back to the
      // HTTP status when the body is not parseable (e.g. a proxy's HTML
      // error page) instead of surfacing a JSON parse exception.
      let msg = 'HTTP ' + resp.status;
      try { msg = (await resp.json()).error || msg; } catch (_) {}
      throw new Error(msg);
    }
    batchBlob = await resp.blob();
    const info = JSON.parse(resp.headers.get('X-Info') || '{}');
    document.getElementById('b-info').textContent =
      `${info.count} images | ${info.scale}x | ${info.model} | ${info.time}s`;
    document.getElementById('b-info').classList.add('active');
    document.getElementById('b-download').style.display = 'inline-flex';
    document.getElementById('b-pbar').style.width = '100%';
    toast(`Batch complete — ${info.count} images`);
  } catch (e) {
    toast('Error: ' + e.message, 5000);
  } finally {
    btn.disabled = false;
    btn.innerHTML = 'Upscale All';
    prog.classList.remove('active');
  }
}
// Trigger a browser download of the last batch result zip. The object
// URL is revoked afterwards — without that, every download leaks a
// reference to the (potentially large) blob for the page's lifetime.
function downloadBatch() {
  if (!batchBlob) return;
  const url = URL.createObjectURL(batchBlob);
  const a = document.createElement('a');
  a.href = url;
  a.download = 'thera_batch.zip';
  a.click();
  // Revoke on a delay so the click has started the download first.
  setTimeout(() => URL.revokeObjectURL(url), 1000);
}
// ==================== VIDEO ====================
let videoFile = null;        // currently selected input video
let videoJobId = null;       // server-side job id for the active upscale
let videoPreviewUrl = null;  // object URL backing the input <video> preview

// Load the chosen file into the input preview player and reset all
// output/progress state for a fresh run.
function handleVideoFile(file) {
  if (!file) return;
  videoFile = file;
  const dz = document.getElementById('v-dropzone');
  dz.classList.add('has-file');
  // Use textContent for the file name: it is user-controlled and must
  // not be interpreted as HTML.
  const label = dz.querySelector('.dropzone-text');
  label.innerHTML = '<strong></strong>';
  label.querySelector('strong').textContent = file.name;
  // Release the previous preview URL before creating a new one;
  // otherwise each re-selection leaks the prior video blob.
  if (videoPreviewUrl) URL.revokeObjectURL(videoPreviewUrl);
  videoPreviewUrl = URL.createObjectURL(file);
  document.getElementById('v-input-vid').src = videoPreviewUrl;
  document.getElementById('v-output-vid').src = '';
  document.getElementById('v-preview').style.display = 'grid';
  document.getElementById('v-btn').disabled = false;
  document.getElementById('v-download').style.display = 'none';
  document.getElementById('v-info').classList.remove('active');
}
// Upload the video to /api/video/start to launch a server-side upscale
// job, then begin polling /api/video/progress for status.
async function doVideoUpscale() {
  if (!videoFile) return;
  const btn = document.getElementById('v-btn');
  btn.disabled = true;
  btn.innerHTML = '<span class="spinner"></span> Processing';
  const prog = document.getElementById('v-progress');
  prog.classList.add('active');
  document.getElementById('v-pbar').style.width = '2%';
  document.getElementById('v-ptext').textContent = 'Uploading...';
  const fd = new FormData();
  fd.append('video', videoFile);
  fd.append('scale', document.getElementById('v-scale').value);
  fd.append('model', document.getElementById('v-model').value);
  try {
    const resp = await fetch('/api/video/start', { method: 'POST', body: fd });
    if (!resp.ok) {
      // Error bodies are normally JSON ({error: ...}); fall back to the
      // HTTP status when the body cannot be parsed, rather than letting
      // the JSON parse failure mask the real error.
      let msg = 'HTTP ' + resp.status;
      try { msg = (await resp.json()).error || msg; } catch (_) {}
      throw new Error(msg);
    }
    const data = await resp.json();
    videoJobId = data.job_id;
    pollVideoProgress();
  } catch (e) {
    toast('Error: ' + e.message, 5000);
    btn.disabled = false;
    btn.innerHTML = 'Upscale Video';
    prog.classList.remove('active');
  }
}
// Poll /api/video/progress for the active job and mirror its state in
// the progress UI. Status mapping:
//   extracting -> fixed 5% bar
//   upscaling  -> 5..90% scaled by current_frame/total_frames, plus fps/ETA
//   encoding   -> fixed 92% bar
//   done       -> load result into output player, enable download, stop
//   error      -> show the server's error message, stop
// Non-terminal states re-poll after 1 s; fetch/parse failures retry on a
// slower 2 s interval instead of aborting the loop.
async function pollVideoProgress() {
if (!videoJobId) return;
try {
const resp = await fetch('/api/video/progress/' + videoJobId);
const job = await resp.json();
const pbar = document.getElementById('v-pbar');
const ptext = document.getElementById('v-ptext');
if (job.status === 'extracting') {
pbar.style.width = '5%';
ptext.textContent = 'Extracting frames...';
} else if (job.status === 'upscaling') {
// Map frame progress into the 5-90% band; Math.max guards against a
// divide-by-zero before total_frames is known.
const pct = 5 + 85 * (job.current_frame / Math.max(job.total_frames, 1));
pbar.style.width = pct + '%';
ptext.textContent =
`Frame ${job.current_frame}/${job.total_frames} | ` +
`${job.fps} fps | ETA ${job.eta}s`;
} else if (job.status === 'encoding') {
pbar.style.width = '92%';
ptext.textContent = 'Encoding video...';
} else if (job.status === 'done') {
pbar.style.width = '100%';
ptext.textContent = 'Complete!';
// Stream the finished video straight from the server route.
document.getElementById('v-output-vid').src =
'/api/video/download/' + videoJobId;
document.getElementById('v-info').textContent =
`${job.src} → ${job.dst} (${job.scale}x) | ` +
`${job.total_frames} frames | ${job.time}s | ${job.model}`;
document.getElementById('v-info').classList.add('active');
document.getElementById('v-download').style.display = 'inline-flex';
document.getElementById('v-btn').disabled = false;
document.getElementById('v-btn').innerHTML = 'Upscale Video';
toast('Video upscale complete!');
return; // terminal state: stop polling
} else if (job.status === 'error') {
toast('Error: ' + job.error, 5000);
document.getElementById('v-btn').disabled = false;
document.getElementById('v-btn').innerHTML = 'Upscale Video';
document.getElementById('v-progress').classList.remove('active');
return; // terminal state: stop polling
}
setTimeout(pollVideoProgress, 1000);
} catch (e) {
// Transient network/JSON failure: retry more slowly.
setTimeout(pollVideoProgress, 2000);
}
}
// Download the finished video for the current job via the server route.
function downloadVideo() {
  if (!videoJobId) return;
  const link = document.createElement('a');
  link.download = 'thera_upscaled.mp4';
  link.href = '/api/video/download/' + videoJobId;
  link.click();
}
</script>
</body>
</html>
"""
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
import logging
logging.getLogger("werkzeug").setLevel(logging.ERROR)
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=5005)
parser.add_argument("--host", type=str, default="127.0.0.1")
args = parser.parse_args()
print(f"Thera MLX → http://{args.host}:{args.port}")
app.run(host=args.host, port=args.port, debug=False, threaded=True)