import io
import math
import os
import tempfile

import numpy as np
import onnxruntime as ort
import requests
from flask import Flask, jsonify, request, send_file
from PIL import Image

# ================= CONFIG =================
MODEL_DIR = "model"
MODEL_X2_PATH = os.path.join(MODEL_DIR, "Real-ESRGAN_x2plus.onnx")
MODEL_X4_PATH = os.path.join(MODEL_DIR, "Real-ESRGAN-x4plus.onnx")
FILE_ID_X2 = "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6"
FILE_ID_X4 = "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6"
MAX_DIM = 1024  # longest input side (px); larger uploads are downscaled before inference
DEFAULT_TILE = 256  # tile side used when the ONNX model declares dynamic H/W dims
DOWNLOAD_TIMEOUT = 60  # seconds, passed to requests for connect/read

app = Flask(__name__)


# ================= MODEL DOWNLOAD =================
def download_from_drive(file_id, dest):
    """Download a public Google Drive file to *dest*.

    Handles the legacy ``download_warning`` cookie that Drive uses to confirm
    large downloads. The payload is streamed into a temporary file next to
    *dest* and only moved into place once complete, so an interrupted or
    failed download never leaves a truncated model behind (the startup
    ``os.path.exists`` check would otherwise never retry it).

    Raises:
        requests.HTTPError: if Drive answers with a non-2xx status.
        RuntimeError: if Drive answers with an HTML page instead of the file
            (e.g. a confirmation or quota page this helper cannot follow).
    """
    url = "https://drive.google.com/uc?export=download"
    dest_dir = os.path.dirname(dest) or "."
    os.makedirs(dest_dir, exist_ok=True)

    with requests.Session() as session:
        r = session.get(
            url, params={"id": file_id}, stream=True, timeout=DOWNLOAD_TIMEOUT
        )
        token = next(
            (v for k, v in r.cookies.items() if k.startswith("download_warning")),
            None,
        )
        if token:
            r.close()
            r = session.get(
                url,
                params={"id": file_id, "confirm": token},
                stream=True,
                timeout=DOWNLOAD_TIMEOUT,
            )

        with r:
            r.raise_for_status()
            # An HTML body means Drive served an interstitial, not the model;
            # saving it as .onnx would only fail later with a confusing error.
            if r.headers.get("Content-Type", "").startswith("text/html"):
                raise RuntimeError(
                    f"Google Drive returned an HTML page for file id {file_id}; "
                    "the model could not be downloaded"
                )

            tmp = tempfile.NamedTemporaryFile(
                dir=dest_dir, suffix=".part", delete=False
            )
            try:
                with tmp:
                    for chunk in r.iter_content(32768):
                        if chunk:
                            tmp.write(chunk)
                os.replace(tmp.name, dest)  # atomic on the same filesystem
            except BaseException:
                if os.path.exists(tmp.name):
                    os.remove(tmp.name)
                raise


if not os.path.exists(MODEL_X2_PATH):
    download_from_drive(FILE_ID_X2, MODEL_X2_PATH)
if not os.path.exists(MODEL_X4_PATH):
    download_from_drive(FILE_ID_X4, MODEL_X4_PATH)


# ================= ONNX SESSIONS =================
def _static_dim(dim, default=DEFAULT_TILE):
    """Return *dim* if the model fixes it to a positive int, else *default*."""
    # onnxruntime reports symbolic/dynamic dims as str or None, not int.
    return dim if isinstance(dim, int) and dim > 0 else default


opts = ort.SessionOptions()
opts.intra_op_num_threads = 2
opts.inter_op_num_threads = 2

sess_x2 = ort.InferenceSession(MODEL_X2_PATH, opts, providers=["CPUExecutionProvider"])
sess_x4 = ort.InferenceSession(MODEL_X4_PATH, opts, providers=["CPUExecutionProvider"])

meta_x2 = sess_x2.get_inputs()[0]
meta_x4 = sess_x4.get_inputs()[0]

# Tile height/width per model. The input is unpacked as 4 dims and only the
# last two are used as the tile size.
_, _, H2, W2 = meta_x2.shape
_, _, H4, W4 = meta_x4.shape
H2, W2 = _static_dim(H2), _static_dim(W2)
H4, W4 = _static_dim(H4), _static_dim(W4)


# ================= HELPERS =================
def run_tile(tile, session, meta):
    """Run one HWC float tile through *session* and return the HWC output.

    The tile is transposed to CHW and given a leading batch axis before
    inference. The first output's first batch item is transposed back to HWC.
    """
    inp = np.ascontiguousarray(np.transpose(tile, (2, 0, 1))[None, ...])
    out = session.run(None, {meta.name: inp})[0][0]
    return np.transpose(out, (1, 2, 0))


def upscale_core(img: Image.Image, scale: int):
    """Upscale *img* with the x2 model (``scale == 2``) or the x4 model (otherwise).

    The image is converted to RGB. If its longest side exceeds ``MAX_DIM`` it
    is first shrunk to that size. It is then reflect-padded to a whole number
    of model tiles, each tile is inferred independently, and the result is
    cropped back to ``scale`` times the (possibly shrunk) input size.

    NOTE: tiles are processed without overlap, so faint seams can appear at
    tile borders.

    Returns:
        A new RGB ``PIL.Image``.
    """
    if scale == 2:
        H, W, sess, meta, S = H2, W2, sess_x2, meta_x2, 2
    else:
        H, W, sess, meta, S = H4, W4, sess_x4, meta_x4, 4

    # Convert first: Pillow silently uses NEAREST when resizing "P"/"1" images.
    img = img.convert("RGB")
    w, h = img.size
    if max(w, h) > MAX_DIM:
        r = MAX_DIM / max(w, h)
        # max(1, ...) guards extreme aspect ratios from collapsing to 0 px.
        img = img.resize((max(1, int(w * r)), max(1, int(h * r))), Image.LANCZOS)

    arr = np.array(img).astype(np.float32) / 255.0
    h0, w0, _ = arr.shape

    th = math.ceil(h0 / H)
    tw = math.ceil(w0 / W)
    pad = np.pad(arr, ((0, th * H - h0), (0, tw * W - w0), (0, 0)), mode="reflect")
    out = np.zeros((th * H * S, tw * W * S, 3), dtype=np.float32)

    for i in range(th):
        for j in range(tw):
            tile = pad[i * H:(i + 1) * H, j * W:(j + 1) * W]
            out[i * H * S:(i + 1) * H * S, j * W * S:(j + 1) * W * S] = run_tile(
                tile, sess, meta
            )

    out = np.clip(out[: h0 * S, : w0 * S], 0.0, 1.0)
    # Round rather than truncate so values like 0.999 map to 255, not 254.
    return Image.fromarray(np.round(out * 255.0).astype(np.uint8))


# ================= ROUTES =================
@app.route("/", methods=["GET"])
def index():
    """Describe the service and its usage."""
    return jsonify({
        "service": "SpectraGAN Upscaler API",
        "status": "running",
        "usage": "POST /upscale with image + mode=x2|x4|x8"
    })


@app.route("/health", methods=["GET"])
def health():
    """Liveness probe."""
    return jsonify({"status": "ok"})


@app.route("/upscale", methods=["POST"])
def upscale():
    """Upscale the uploaded ``image`` and return it as PNG.

    Form field ``mode`` selects ``x2``, ``x4`` (default; also used for any
    unrecognised value) or ``x8``. The x8 mode runs the x4 model and then
    applies a 2x Lanczos resize.

    Returns 400 if no image is supplied or it cannot be decoded.
    """
    if "image" not in request.files:
        return jsonify({"error": "image file required"}), 400

    mode = request.form.get("mode", "x4").strip().lower()

    try:
        img = Image.open(request.files["image"])
        img.load()  # Image.open is lazy; force decoding so errors surface here
    except (OSError, Image.DecompressionBombError):
        return jsonify({"error": "invalid or unsupported image"}), 400

    if mode == "x2":
        out = upscale_core(img, 2)
    elif mode == "x8":
        # Size from the x4 result, not the raw upload: upscale_core may have
        # clamped the input to MAX_DIM, and sizing from the original would
        # produce an inconsistent and potentially enormous image.
        temp = upscale_core(img, 4)
        out = temp.resize((temp.width * 2, temp.height * 2), Image.LANCZOS)
    else:
        out = upscale_core(img, 4)

    # Serve from memory: nothing to clean up, no leaked temp files or fds.
    buf = io.BytesIO()
    out.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)