import io
import math
import os
import tempfile

import numpy as np
import onnxruntime as ort
import requests
from flask import Flask, request, jsonify, send_file
from PIL import Image
| |
|
| | |
# Directory that holds the ONNX model files, and the full path of each model.
MODEL_DIR = "model"
MODEL_X2_PATH, MODEL_X4_PATH = (
    os.path.join(MODEL_DIR, "Real-ESRGAN_x2plus.onnx"),
    os.path.join(MODEL_DIR, "Real-ESRGAN-x4plus.onnx"),
)

# Google Drive file IDs handed to download_from_drive for each model.
FILE_ID_X2, FILE_ID_X4 = (
    "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6",
    "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
)

# Longest side, in pixels, an input may have before upscale_core shrinks it.
MAX_DIM = 1024

# The Flask application object that the route decorators register against.
app = Flask(__name__)
| |
|
| | |
def download_from_drive(file_id, dest, timeout=60):
    """Download a shared Google Drive file and store it at ``dest``.

    The file is requested from the Drive ``uc?export=download`` endpoint. If
    the first response sets a cookie whose name starts with
    ``download_warning``, the request is repeated with that cookie's value as
    the ``confirm`` parameter.

    The body is streamed into ``dest + ".part"`` and only renamed to ``dest``
    once the whole download succeeded, so a failed or interrupted transfer
    never leaves a truncated file at ``dest``.

    Args:
        file_id: Google Drive file identifier.
        dest: Destination path. Its parent directory is created if needed.
        timeout: Seconds ``requests`` waits for the server to connect/send
            data before giving up (applied to every request made here).

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-2xx HTTP status.
        OSError: If the destination cannot be written.
    """
    url = "https://drive.google.com/uc?export=download"

    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises.
    dest_dir = os.path.dirname(dest)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    part_path = dest + ".part"

    with requests.Session() as session:
        r = session.get(url, params={"id": file_id}, stream=True, timeout=timeout)

        token = next(
            (v for k, v in r.cookies.items() if k.startswith("download_warning")),
            None,
        )
        if token:
            # Release the first (unread) streamed response before retrying.
            r.close()
            r = session.get(
                url,
                params={"id": file_id, "confirm": token},
                stream=True,
                timeout=timeout,
            )

        try:
            # Never persist an HTTP error page as if it were the payload.
            r.raise_for_status()
            with open(part_path, "wb") as f:
                for chunk in r.iter_content(32768):
                    if chunk:
                        f.write(chunk)
        except BaseException:
            # Drop the partial file, then let the original error propagate.
            try:
                os.remove(part_path)
            except OSError:
                pass
            raise
        finally:
            r.close()

    # Atomic publish: dest either does not exist or is complete.
    os.replace(part_path, dest)
| |
|
# Fetch whichever model files are not on disk yet (x2 first, then x4).
for _file_id, _model_path in (
    (FILE_ID_X2, MODEL_X2_PATH),
    (FILE_ID_X4, MODEL_X4_PATH),
):
    if not os.path.exists(_model_path):
        download_from_drive(_file_id, _model_path)

# One SessionOptions object, shared by both sessions, with 2 threads each for
# intra-op and inter-op work.
opts = ort.SessionOptions()
opts.intra_op_num_threads = opts.inter_op_num_threads = 2

# CPU-only inference sessions, one per model.
sess_x2 = ort.InferenceSession(
    MODEL_X2_PATH, sess_options=opts, providers=["CPUExecutionProvider"]
)
sess_x4 = ort.InferenceSession(
    MODEL_X4_PATH, sess_options=opts, providers=["CPUExecutionProvider"]
)

# Metadata of each model's first input; its name is used to feed tiles.
meta_x2, meta_x4 = sess_x2.get_inputs()[0], sess_x4.get_inputs()[0]

# The last two entries of the 4-element input shape are used as the tile
# height and width.
# NOTE(review): upscale_core does arithmetic on these, so this assumes the
# models declare static integer spatial dimensions -- confirm.
H2, W2 = meta_x2.shape[2:]
H4, W4 = meta_x4.shape[2:]
| |
|
| | |
def run_tile(tile, session, meta):
    """Run one tile through ``session`` and return the result channels-last.

    Args:
        tile: Array indexed (row, column, channel).
        session: Object exposing ``run(output_names, feeds)``.
        meta: Object whose ``name`` attribute is the input name to feed.

    Returns:
        The first batch item of the first output, with its leading axis moved
        to the end.
    """
    # Channels first, then a leading batch axis of length 1.
    channels_first = np.transpose(tile, (2, 0, 1))
    batch = np.expand_dims(channels_first, axis=0)

    outputs = session.run(None, {meta.name: batch})
    first_item = outputs[0][0]

    # Back to (row, column, channel) ordering.
    return np.transpose(first_item, (1, 2, 0))
| |
|
def upscale_core(img: Image.Image, scale: int) -> Image.Image:
    """Upscale ``img`` by tiling it through the x2 or x4 ONNX model.

    The image is converted to RGB, shrunk so its longest side is at most
    ``MAX_DIM``, padded up to a whole number of model-sized tiles, run tile by
    tile through the model, and finally cropped back to ``scale`` times the
    (possibly shrunk) input size.

    Args:
        img: Input image in any Pillow mode.
        scale: ``2`` selects the x2 model; any other value selects the x4
            model.

    Returns:
        An 8-bit RGB image ``scale`` times the size of the (possibly shrunk)
        input.
    """
    if scale == 2:
        H, W, sess, meta, S = H2, W2, sess_x2, meta_x2, 2
    else:
        H, W, sess, meta, S = H4, W4, sess_x4, meta_x4, 4

    # Convert before resizing: Pillow forces NEAREST resampling for mode "1"
    # and "P" images, which would silently defeat the LANCZOS filter below.
    img = img.convert("RGB")

    w, h = img.size
    longest = max(w, h)
    if longest > MAX_DIM:
        r = MAX_DIM / longest
        # Clamp to 1 px so extreme aspect ratios cannot yield a 0-sized axis.
        img = img.resize(
            (max(1, int(w * r)), max(1, int(h * r))), Image.LANCZOS
        )

    arr = np.array(img).astype(np.float32) / 255.0
    h0, w0, _ = arr.shape

    # Number of tiles needed along each axis.
    th = math.ceil(h0 / H)
    tw = math.ceil(w0 / W)

    # Pad bottom/right so the image is an exact multiple of the tile size.
    pad = np.pad(arr, ((0, th * H - h0), (0, tw * W - w0), (0, 0)), mode="reflect")
    out = np.zeros((th * H * S, tw * W * S, 3), dtype=np.float32)

    for i in range(th):
        for j in range(tw):
            tile = pad[i * H:(i + 1) * H, j * W:(j + 1) * W]
            up = run_tile(tile, sess, meta)
            out[i * H * S:(i + 1) * H * S, j * W * S:(j + 1) * W * S] = up

    # Crop away the upscaled padding and clamp to the valid [0, 1] range.
    out = np.clip(out[:h0 * S, :w0 * S], 0, 1)
    # Round to nearest instead of truncating, which would bias pixels darker.
    return Image.fromarray(np.rint(out * 255.0).astype(np.uint8))
| |
|
| | |
@app.route("/", methods=["GET"])
def index():
    """Return a JSON description of the service and how to call it."""
    info = {
        "service": "SpectraGAN Upscaler API",
        "status": "running",
        "usage": "POST /upscale with image + mode=x2|x4|x8",
    }
    return jsonify(info)
| |
|
@app.route("/health", methods=["GET"])
def health():
    """Return a fixed JSON body reporting that the service is up."""
    return jsonify(status="ok")
| |
|
@app.route("/upscale", methods=["POST"])
def upscale():
    """Upscale the uploaded image and return it as a PNG.

    Expects a multipart form with an ``image`` file and an optional ``mode``
    field (default ``"x4"``). ``"x2"`` uses the x2 model; ``"x8"`` runs the
    x4 model and then doubles the result with LANCZOS resampling; any other
    value uses the x4 model.

    Returns:
        The PNG image, or a JSON error with status 400 when the ``image``
        file is missing or cannot be decoded.
    """
    if "image" not in request.files:
        return jsonify({"error": "image file required"}), 400

    mode = request.form.get("mode", "x4")

    try:
        img = Image.open(request.files["image"])
        # Image.open is lazy; decode now so corrupt uploads fail here with a
        # client error rather than later as an unhandled exception.
        img.load()
    except (OSError, Image.DecompressionBombError):
        return jsonify({"error": "invalid image file"}), 400

    if mode == "x2":
        out = upscale_core(img, 2)
    else:
        out = upscale_core(img, 4)
        if mode == "x8":
            # Double the x4 result rather than using 8x the *original* size:
            # upscale_core may have shrunk the input to MAX_DIM, and scaling
            # from the original dimensions would bypass that limit.
            out = out.resize((out.width * 2, out.height * 2), Image.LANCZOS)

    # Serve from memory; a delete=False temp file would leak on every request.
    buf = io.BytesIO()
    out.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
| |
|
# Start Flask's built-in server only when this file is run as a script,
# listening on every network interface at port 7860.
if __name__ == "__main__":
    app.run(
        port=7860,
        host="0.0.0.0",
    )