# Nothingimage / app.py
# Uploaded by Akwbw ("Upload 3 files", commit f13629a, verified)
import io
import math
import os
import tempfile

import numpy as np
import onnxruntime as ort
import requests
from flask import Flask, jsonify, request, send_file
from PIL import Image
# ================= CONFIG =================
# Directory in which the ONNX weights are cached.
MODEL_DIR = "model"

# Cached model files. NOTE(review): the x2 name uses "_" and the x4 name "-";
# this looks accidental, but renaming would orphan already-downloaded files,
# since the presence check and the download target both use these exact paths.
MODEL_X2_PATH, MODEL_X4_PATH = (
    os.path.join(MODEL_DIR, "Real-ESRGAN_x2plus.onnx"),
    os.path.join(MODEL_DIR, "Real-ESRGAN-x4plus.onnx"),
)

# Google Drive file IDs fetched when the corresponding file is missing.
FILE_ID_X2, FILE_ID_X4 = (
    "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6",
    "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
)

# Inputs whose longer side exceeds this many pixels are downscaled first.
MAX_DIM = 1024
# Flask application object; every @app.route handler in this module registers on it.
app = Flask(__name__)
# ================= MODEL DOWNLOAD =================
def download_from_drive(file_id, dest, *, timeout=60, chunk_size=32768):
    """Download a publicly shared Google Drive file to *dest*.

    If the first response carries a ``download_warning*`` cookie (presumably
    Drive's large-file confirmation step), the request is repeated with
    ``confirm=<token>`` to obtain the actual bytes.

    The payload is streamed into a temporary file beside *dest* and moved into
    place only after the download completes. Callers skip the download whenever
    *dest* already exists, so a truncated or bogus file must never end up there.

    Args:
        file_id: Google Drive file ID.
        dest: Destination path; its parent directory is created if needed.
        timeout: Connect/read timeout in seconds for each HTTP request.
        chunk_size: Size in bytes of the streamed chunks written to disk.

    Raises:
        requests.HTTPError: if the final response has an error status.
        requests.RequestException: on network errors or timeouts.
        RuntimeError: if Drive answers with an HTML page instead of the file.
    """
    url = "https://drive.google.com/uc?export=download"
    with requests.Session() as session:
        r = session.get(url, params={"id": file_id}, stream=True, timeout=timeout)
        token = next(
            (v for k, v in r.cookies.items() if k.startswith("download_warning")),
            None,
        )
        if token:
            r.close()  # release the first connection before re-requesting
            r = session.get(
                url,
                params={"id": file_id, "confirm": token},
                stream=True,
                timeout=timeout,
            )

        with r:
            r.raise_for_status()
            # An HTML body here is almost certainly an interstitial or error page,
            # not the binary model; refuse to cache it under the model path.
            if "text/html" in r.headers.get("Content-Type", ""):
                raise RuntimeError(
                    f"Google Drive returned an HTML page instead of file {file_id!r}"
                )

            parent = os.path.dirname(dest)
            if parent:  # os.makedirs("") would raise for a bare filename
                os.makedirs(parent, exist_ok=True)

            fd, tmp_path = tempfile.mkstemp(dir=parent or ".", suffix=".part")
            try:
                with os.fdopen(fd, "wb") as f:
                    for chunk in r.iter_content(chunk_size):
                        if chunk:
                            f.write(chunk)
                # Same directory => same filesystem, so the rename is atomic.
                os.replace(tmp_path, dest)
            except BaseException:
                # Best-effort cleanup of the partial file; re-raise the real error.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
                raise
# Fetch whichever model weights are not yet cached on disk (x2 first, then x4).
for _path, _drive_id in ((MODEL_X2_PATH, FILE_ID_X2), (MODEL_X4_PATH, FILE_ID_X4)):
    if not os.path.exists(_path):
        download_from_drive(_drive_id, _path)
del _path, _drive_id
# ================= ONNX SESSIONS =================
# Both sessions share one options object: 2 intra-op and 2 inter-op threads,
# CPU execution provider only.
opts = ort.SessionOptions()
opts.intra_op_num_threads = 2
opts.inter_op_num_threads = 2
sess_x2 = ort.InferenceSession(MODEL_X2_PATH, opts, providers=["CPUExecutionProvider"])
sess_x4 = ort.InferenceSession(MODEL_X4_PATH, opts, providers=["CPUExecutionProvider"])
meta_x2 = sess_x2.get_inputs()[0]
meta_x4 = sess_x4.get_inputs()[0]

# Tile edge used when a model declares a dynamic (non-integer) spatial dimension.
DEFAULT_TILE = 256


def _tile_hw(meta, default=DEFAULT_TILE):
    """Return the (height, width) tile size for an ONNX input spec.

    The input shape is unpacked as (N, C, H, W). ONNX Runtime reports dynamic
    dimensions as symbolic strings or None instead of ints; those cannot drive
    the tiling arithmetic, so they fall back to *default*. Fixed-size models
    keep their declared size.
    """
    _, _, h, w = meta.shape
    return tuple(d if isinstance(d, int) and d > 0 else default for d in (h, w))


H2, W2 = _tile_hw(meta_x2)
H4, W4 = _tile_hw(meta_x4)
# ================= HELPERS =================
def run_tile(tile, session, meta):
    """Run one channels-last tile through *session* and return a channels-last result.

    The tile is moved to channels-first order and given a leading batch axis of
    size 1. It is fed under the model input's name. The first batch item of the
    first output is then moved back to channels-last order.
    """
    batch = np.expand_dims(np.moveaxis(tile, -1, 0), 0)
    outputs = session.run(None, {meta.name: batch})
    first_item = outputs[0][0]
    return np.moveaxis(first_item, 0, -1)
def upscale_core(img: Image.Image, scale: int):
    """Upscale *img* with the Real-ESRGAN ONNX model selected by *scale*.

    The image is converted to RGB. If its longer side exceeds ``MAX_DIM``, it
    is downscaled with Lanczos so that side is about ``MAX_DIM``. It is then
    run through the model in non-overlapping tiles of the model's input size,
    reflect-padded on the bottom/right. Finally the padding is cropped from
    the result.

    NOTE(review): tiles do not overlap, so seams may be visible at tile
    borders on some images.

    Args:
        img: Source image in any mode PIL can convert to RGB.
        scale: 2 selects the x2 model; any other value selects the x4 model.

    Returns:
        An RGB ``PIL.Image`` sized (possibly downscaled) input size times 2 or 4.
    """
    if scale == 2:
        H, W, sess, meta, S = H2, W2, sess_x2, meta_x2, 2
    else:
        H, W, sess, meta, S = H4, W4, sess_x4, meta_x4, 4

    # Convert first: Pillow resizes "P"/"1" images with NEAREST regardless of
    # the requested filter, so converting before resizing lets LANCZOS apply.
    img = img.convert("RGB")
    w, h = img.size
    if max(w, h) > MAX_DIM:
        r = MAX_DIM / max(w, h)
        # Clamp to >= 1 so an extreme aspect ratio cannot yield a 0-pixel side.
        img = img.resize((max(1, int(w * r)), max(1, int(h * r))), Image.LANCZOS)

    arr = np.array(img).astype(np.float32) / 255.0
    h0, w0, _ = arr.shape

    # Pad up to a whole number of tiles so every tile matches the model input.
    th = math.ceil(h0 / H)
    tw = math.ceil(w0 / W)
    pad = np.pad(arr, ((0, th * H - h0), (0, tw * W - w0), (0, 0)), mode="reflect")

    out = np.zeros((th * H * S, tw * W * S, 3), dtype=np.float32)
    for i in range(th):
        for j in range(tw):
            tile = pad[i * H:(i + 1) * H, j * W:(j + 1) * W]
            # Each model output is expected to be (H*S, W*S, 3) after run_tile.
            out[i * H * S:(i + 1) * H * S, j * W * S:(j + 1) * W * S] = run_tile(tile, sess, meta)

    out = np.clip(out[:h0 * S, :w0 * S], 0, 1)
    # Round instead of truncating; truncation biases every pixel darker.
    return Image.fromarray(np.rint(out * 255).astype(np.uint8))
# ================= ROUTES =================
@app.route("/", methods=["GET"])
def index():
    """Describe the service, report that it is running, and show how to call it."""
    info = dict(
        service="SpectraGAN Upscaler API",
        status="running",
        usage="POST /upscale with image + mode=x2|x4|x8",
    )
    return jsonify(info)
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe; always answers with the JSON object {"status": "ok"}."""
    return jsonify(status="ok")
@app.route("/upscale", methods=["POST"])
def upscale():
    """Upscale the uploaded ``image`` file and return it as a PNG.

    Form fields:
        image: The image file (required).
        mode: "x2", "x4" (default) or "x8". Unrecognised values fall back to
            x4. x8 runs the x4 model and then doubles that result with a
            Lanczos resize.

    Returns:
        The PNG bytes with mimetype ``image/png``. Returns a JSON error with
        status 400 when the file is missing or cannot be decoded as an image.
    """
    if "image" not in request.files:
        return jsonify({"error": "image file required"}), 400
    mode = request.form.get("mode", "x4")

    try:
        img = Image.open(request.files["image"])
        # Image.open is lazy; decode now so corrupt or truncated uploads are
        # reported as a client error rather than failing mid-inference.
        img.load()
    except (OSError, Image.DecompressionBombError):
        return jsonify({"error": "invalid image file"}), 400

    if mode == "x2":
        out = upscale_core(img, 2)
    elif mode == "x8":
        temp = upscale_core(img, 4)
        # Size from the x4 result, not the original upload. If upscale_core
        # clamped the input to MAX_DIM, this keeps x8 consistent with x2/x4 and
        # avoids allocating 8x an arbitrarily large original.
        out = temp.resize((temp.width * 2, temp.height * 2), Image.LANCZOS)
    else:
        out = upscale_core(img, 4)

    # Encode in memory so no temporary file or descriptor outlives the request.
    buf = io.BytesIO()
    out.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
if __name__ == "__main__":
    # Run Flask's built-in server on every interface, port 7860.
    app.run(port=7860, host="0.0.0.0")