Spaces:
Sleeping
Sleeping
| import os | |
| import math | |
| import tempfile | |
| import numpy as np | |
| import requests | |
| from PIL import Image | |
| from flask import Flask, request, jsonify, send_file | |
| import onnxruntime as ort | |
# ================= CONFIG =================
# Directory where the ONNX weight files are kept locally.
MODEL_DIR = "model"

# Local paths of the 2x and 4x Real-ESRGAN ONNX exports. Note the differing
# separators ("_x2plus" vs "-x4plus"); they are kept verbatim so that files
# already present on disk continue to be found.
MODEL_X2_PATH, MODEL_X4_PATH = (
    os.path.join(MODEL_DIR, filename)
    for filename in ("Real-ESRGAN_x2plus.onnx", "Real-ESRGAN-x4plus.onnx")
)

# Google Drive file ids for the x2 and x4 weight files, in that order.
FILE_ID_X2, FILE_ID_X4 = (
    "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6",
    "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
)

# Upper bound, in pixels, on the longer side of an input image.
MAX_DIM = 1024
# Flask application object; it is started by ``app.run`` in the __main__ guard.
app = Flask(import_name=__name__)
# ================= MODEL DOWNLOAD =================
def download_from_drive(file_id, dest, timeout=60):
    """Download a shared Google Drive file to ``dest``.

    Google Drive may answer the first request by setting a cookie whose name
    starts with ``download_warning`` instead of returning the file. When such
    a cookie is present, its value is sent back as the ``confirm`` parameter
    to obtain the actual content.

    The payload is streamed to ``dest + ".part"`` and renamed to ``dest``
    only once it is complete. An interrupted or failed download therefore
    never leaves a truncated file at ``dest`` that could later be mistaken
    for a finished one.

    Args:
        file_id: Google Drive file id.
        dest: Output path. A missing parent directory is created.
        timeout: Seconds passed to ``requests`` as the connect/read timeout.

    Raises:
        requests.HTTPError: The server answered with an error status.
        RuntimeError: Drive returned an HTML page instead of file bytes.
    """
    url = "https://drive.google.com/uc?export=download"

    # os.makedirs("") raises, so only create a directory when dest has one.
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)

    part_path = dest + ".part"
    with requests.Session() as session:
        response = session.get(url, params={"id": file_id}, stream=True, timeout=timeout)
        token = next(
            (v for k, v in response.cookies.items() if k.startswith("download_warning")),
            None,
        )
        if token:
            response.close()
            response = session.get(
                url,
                params={"id": file_id, "confirm": token},
                stream=True,
                timeout=timeout,
            )

        with response:
            response.raise_for_status()
            # An HTML body means Drive served an interstitial/error page rather
            # than the file; writing it out would produce a corrupt model file.
            if response.headers.get("Content-Type", "").startswith("text/html"):
                raise RuntimeError(
                    f"Google Drive returned an HTML page instead of the file "
                    f"contents for id {file_id!r}"
                )
            try:
                with open(part_path, "wb") as f:
                    for chunk in response.iter_content(32768):
                        if chunk:
                            f.write(chunk)
                # Atomic on the same filesystem: dest is either absent or complete.
                os.replace(part_path, dest)
            except BaseException:
                # Never leave a half-written file behind, then propagate.
                if os.path.exists(part_path):
                    os.remove(part_path)
                raise
# Fetch each weight file that is not already on disk: the x2 model first,
# then the x4 model.
for _path, _drive_id in ((MODEL_X2_PATH, FILE_ID_X2), (MODEL_X4_PATH, FILE_ID_X4)):
    if not os.path.exists(_path):
        download_from_drive(_drive_id, _path)
del _path, _drive_id
# ================= ONNX SESSIONS =================
# Tile edge in pixels, used when a model's input declares a symbolic
# (dynamic) height or width instead of a fixed integer.
DEFAULT_TILE = 256


def _static_hw(meta, fallback=DEFAULT_TILE):
    """Return the (height, width) tile size for an NCHW model input.

    ``NodeArg.shape`` reports fixed dimensions as ints and dynamic ones as
    strings or ``None``. Any non-positive-int spatial dimension is replaced
    by ``fallback`` so that tiling arithmetic always receives integers.
    """
    _, _, h, w = meta.shape
    return (
        h if isinstance(h, int) and h > 0 else fallback,
        w if isinstance(w, int) and w > 0 else fallback,
    )


# Both sessions share one options object limiting ORT to 2 intra-op and
# 2 inter-op threads.
opts = ort.SessionOptions()
opts.intra_op_num_threads = 2
opts.inter_op_num_threads = 2

sess_x2 = ort.InferenceSession(MODEL_X2_PATH, opts, providers=["CPUExecutionProvider"])
sess_x4 = ort.InferenceSession(MODEL_X4_PATH, opts, providers=["CPUExecutionProvider"])

# First (only) input of each model and the tile size it will be fed.
meta_x2 = sess_x2.get_inputs()[0]
meta_x4 = sess_x4.get_inputs()[0]
H2, W2 = _static_hw(meta_x2)
H4, W4 = _static_hw(meta_x4)
# ================= HELPERS =================
def run_tile(tile, session, meta):
    """Run a single image tile through an ONNX Runtime session.

    Args:
        tile: Tile as an (H, W, C) array. It is reordered into a one-item
            (1, C, H, W) batch before inference.
        session: Object exposing ``run(output_names, feeds)``, in this file
            an ``ort.InferenceSession``.
        meta: Input descriptor whose ``.name`` is used as the feed key.

    Returns:
        The first batch item of the session's first output, reordered from
        (C, H, W) back to (H, W, C).
    """
    chw = tile.transpose(2, 0, 1)
    feeds = {meta.name: chw[np.newaxis, ...]}
    outputs = session.run(None, feeds)
    first_item = outputs[0][0]
    return first_item.transpose(1, 2, 0)
| def upscale_core(img: Image.Image, scale: int): | |
| if scale == 2: | |
| H, W, sess, meta, S = H2, W2, sess_x2, meta_x2, 2 | |
| else: | |
| H, W, sess, meta, S = H4, W4, sess_x4, meta_x4, 4 | |
| w, h = img.size | |
| if max(w, h) > MAX_DIM: | |
| r = MAX_DIM / max(w, h) | |
| img = img.resize((int(w*r), int(h*r)), Image.LANCZOS) | |
| arr = np.array(img.convert("RGB")).astype(np.float32) / 255.0 | |
| h0, w0, _ = arr.shape | |
| th = math.ceil(h0 / H) | |
| tw = math.ceil(w0 / W) | |
| pad = np.pad(arr, ((0, th*H-h0), (0, tw*W-w0), (0, 0)), mode="reflect") | |
| out = np.zeros((th*H*S, tw*W*S, 3), dtype=np.float32) | |
| for i in range(th): | |
| for j in range(tw): | |
| tile = pad[i*H:(i+1)*H, j*W:(j+1)*W] | |
| up = run_tile(tile, sess, meta) | |
| out[i*H*S:(i+1)*H*S, j*W*S:(j+1)*W*S] = up | |
| out = np.clip(out[:h0*S, :w0*S], 0, 1) | |
| return Image.fromarray((out * 255).astype(np.uint8)) | |
# ================= ROUTES =================
@app.route("/")
def index():
    """Describe the service and how to call it."""
    return jsonify({
        "service": "SpectraGAN Upscaler API",
        "status": "running",
        "usage": "POST /upscale with image + mode=x2|x4|x8"
    })


@app.route("/health")
def health():
    """Report that the service is up."""
    return jsonify({"status": "ok"})


@app.route("/upscale", methods=["POST"])
def upscale():
    """Upscale an uploaded image and return it as PNG.

    Form fields:
        image: The image file (required).
        mode: "x2", "x4" or "x8"; defaults to "x4". Any other value falls
            back to x4.

    Returns:
        The PNG image, or a JSON error with status 400 when the file is
        missing or cannot be decoded.
    """
    import io

    if "image" not in request.files:
        return jsonify({"error": "image file required"}), 400
    mode = request.form.get("mode", "x4")

    try:
        img = Image.open(request.files["image"])
        # Decode now so corrupt or truncated uploads fail here, as a 400,
        # rather than deep inside inference as a 500.
        img.load()
    except (OSError, Image.DecompressionBombError):
        return jsonify({"error": "invalid image file"}), 400

    if mode == "x2":
        out = upscale_core(img, 2)
    elif mode == "x8":
        x4 = upscale_core(img, 4)
        # Double the x4 result rather than resizing to 8x the *original*
        # size. This respects the MAX_DIM cap applied by upscale_core, and
        # is identical for inputs within the cap.
        out = x4.resize((x4.width * 2, x4.height * 2), Image.LANCZOS)
    else:
        out = upscale_core(img, 4)

    # Serve from memory: no temp file or open handle is left behind per request.
    buf = io.BytesIO()
    out.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
if __name__ == "__main__":
    # Listen on all interfaces. The port defaults to 7860 and can be
    # overridden with the PORT environment variable.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))