File size: 3,987 Bytes
1936a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import io
import math
import os
import tempfile

import numpy as np
import onnxruntime as ort
import requests
from flask import Flask, request, jsonify, send_file
from PIL import Image

# ================= CONFIG =================
# Directory (relative to the current working directory) holding the ONNX weights.
MODEL_DIR = "model"
MODEL_X2_PATH, MODEL_X4_PATH = (
    os.path.join(MODEL_DIR, fname)
    for fname in ("Real-ESRGAN_x2plus.onnx", "Real-ESRGAN-x4plus.onnx")
)

# Google Drive file ids the weights are fetched from when missing locally.
FILE_ID_X2, FILE_ID_X4 = (
    "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6",
    "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
)

# Longest image side (pixels) allowed before an input is downscaled.
MAX_DIM = 1024

# Flask application object; the HTTP routes in this module are registered on it via @app.route.
app = Flask(__name__)

# ================= MODEL DOWNLOAD =================
def download_from_drive(file_id, dest, timeout=60):
    """Download a publicly shared Google Drive file to *dest*.

    If Drive sets a ``download_warning*`` cookie (its large-file
    confirmation step), the request is replayed with that token as
    ``confirm``. The body is streamed into a uniquely named temporary file in
    the destination directory and moved onto *dest* only after it has been
    fully written. An interrupted or failed download therefore never leaves a
    truncated file that a later ``os.path.exists`` check would mistake for a
    valid model.

    Args:
        file_id: Google Drive file id.
        dest: Target path; missing parent directories are created.
        timeout: Connect/read timeout in seconds for each HTTP request.

    Raises:
        requests.HTTPError: Drive answered with an error status.
        RuntimeError: Drive returned an HTML page instead of the file
            (e.g. the file is not public, its quota is exhausted, or the
            confirmation page could not be bypassed).
    """
    url = "https://drive.google.com/uc?export=download"
    with requests.Session() as session:
        r = session.get(url, params={"id": file_id}, stream=True, timeout=timeout)

        token = next(
            (v for k, v in r.cookies.items() if k.startswith("download_warning")),
            None,
        )
        if token:
            r.close()  # release the first (interstitial) response's connection
            r = session.get(
                url,
                params={"id": file_id, "confirm": token},
                stream=True,
                timeout=timeout,
            )

        with r:
            r.raise_for_status()
            # A binary model never comes back as HTML; an HTML body is an
            # error or confirmation page and must not be saved as the model.
            if "text/html" in r.headers.get("Content-Type", ""):
                raise RuntimeError(
                    f"Google Drive returned an HTML page instead of file {file_id!r}; "
                    "it may be private, over its download quota, or behind an "
                    "unsupported confirmation page."
                )

            dest_dir = os.path.dirname(dest) or "."
            os.makedirs(dest_dir, exist_ok=True)

            # Unique temp name in the same directory so os.replace is an
            # atomic rename and concurrent downloaders don't clobber each other.
            fd, tmp_path = tempfile.mkstemp(dir=dest_dir, suffix=".part")
            try:
                with os.fdopen(fd, "wb") as f:
                    for chunk in r.iter_content(32768):
                        if chunk:
                            f.write(chunk)
                os.replace(tmp_path, dest)
            except BaseException:
                # Best-effort cleanup of the partial file; the original error
                # is what matters, so it is always re-raised.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
                raise

# Fetch any missing weights at import time, x2 first and then x4, before the
# ONNX sessions below try to load them.
for _file_id, _model_path in (
    (FILE_ID_X2, MODEL_X2_PATH),
    (FILE_ID_X4, MODEL_X4_PATH),
):
    if not os.path.exists(_model_path):
        download_from_drive(_file_id, _model_path)
del _file_id, _model_path

# ================= ONNX SESSIONS =================
# Tile edge (pixels) used when a model declares a dynamic spatial input axis
# instead of a fixed integer size.
DEFAULT_TILE = 256


def _static_hw(meta, default=DEFAULT_TILE):
    """Return the (height, width) tile size declared by a 4-D ONNX input.

    ONNX Runtime reports dynamic axes as a symbolic name (str) or None rather
    than an int. Such axes fall back to *default* so the tiling arithmetic in
    upscale_core stays integer-valued. Unpacking still requires a rank-4
    input shape.
    """
    _, _, h, w = meta.shape
    return (
        h if isinstance(h, int) and h > 0 else default,
        w if isinstance(w, int) and w > 0 else default,
    )


# Shared CPU-only session options: 2 threads within and across operators.
opts = ort.SessionOptions()
opts.intra_op_num_threads = 2
opts.inter_op_num_threads = 2

sess_x2 = ort.InferenceSession(MODEL_X2_PATH, opts, providers=["CPUExecutionProvider"])
sess_x4 = ort.InferenceSession(MODEL_X4_PATH, opts, providers=["CPUExecutionProvider"])

# Each model's first input; its name is the feed key and its shape fixes the tile size.
meta_x2 = sess_x2.get_inputs()[0]
meta_x4 = sess_x4.get_inputs()[0]

H2, W2 = _static_hw(meta_x2)
H4, W4 = _static_hw(meta_x4)

# ================= HELPERS =================
def run_tile(tile, session, meta):
    """Feed one image tile through an ONNX session and return its output tile.

    The (height, width, channels) tile is reordered to channel-first and
    given a leading batch axis. It is fed under the input name ``meta.name``.
    The first batch element of the session's first output is reordered back
    to channel-last.
    """
    batch = np.expand_dims(np.transpose(tile, (2, 0, 1)), axis=0)
    outputs = session.run(None, {meta.name: batch})
    first = outputs[0][0]
    return np.transpose(first, (1, 2, 0))

def upscale_core(img: Image.Image, scale: int):
    """Upscale *img* with the Real-ESRGAN ONNX model matching *scale*.

    ``scale == 2`` selects the x2 model; any other value selects the x4
    model. If the longer side exceeds MAX_DIM, the image is first downscaled
    (aspect ratio preserved) so that side equals MAX_DIM. The result is
    therefore the model's factor times the *capped* size.

    The image is processed in non-overlapping tiles of the model's input
    size. The bottom/right edges are reflect-padded up to a whole number of
    tiles and the padding is cropped from the output. Because tiles are
    processed independently, faint seams at tile borders are possible.

    Args:
        img: Source image in any Pillow mode; it is converted to RGB
            (an alpha channel, if present, is dropped).
        scale: Requested factor; 2 means x2, anything else means x4.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    if scale == 2:
        H, W, sess, meta, S = H2, W2, sess_x2, meta_x2, 2
    else:
        H, W, sess, meta, S = H4, W4, sess_x4, meta_x4, 4

    # Convert before resizing: Pillow ignores LANCZOS and uses NEAREST for
    # "P" and "1" mode images, so resizing first would degrade them.
    img = img.convert("RGB")

    w, h = img.size
    if max(w, h) > MAX_DIM:
        r = MAX_DIM / max(w, h)
        # round() keeps the long side at exactly MAX_DIM despite float error;
        # max(1, ...) stops extreme aspect ratios collapsing a side to 0 px.
        new_size = (max(1, round(w * r)), max(1, round(h * r)))
        img = img.resize(new_size, Image.LANCZOS)

    arr = np.array(img).astype(np.float32) / 255.0
    h0, w0, _ = arr.shape

    th = math.ceil(h0 / H)
    tw = math.ceil(w0 / W)

    pad = np.pad(arr, ((0, th * H - h0), (0, tw * W - w0), (0, 0)), mode="reflect")
    out = np.zeros((th * H * S, tw * W * S, 3), dtype=np.float32)

    for i in range(th):
        for j in range(tw):
            tile = pad[i * H:(i + 1) * H, j * W:(j + 1) * W]
            up = run_tile(tile, sess, meta)
            out[i * H * S:(i + 1) * H * S, j * W * S:(j + 1) * W * S] = up

    out = np.clip(out[:h0 * S, :w0 * S], 0, 1)
    # Round to nearest rather than truncate; truncation biases every pixel
    # down (e.g. 0.9999 * 255 would become 254 instead of 255).
    return Image.fromarray(np.rint(out * 255).astype(np.uint8))

# ================= ROUTES =================
@app.route("/", methods=["GET"])
def index():
    """Return a JSON description of the service and how to call it."""
    info = dict(
        service="SpectraGAN Upscaler API",
        status="running",
        usage="POST /upscale with image + mode=x2|x4|x8",
    )
    return jsonify(info)

@app.route("/health", methods=["GET"])
def health():
    """Liveness check: always answers with ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return jsonify(payload)

@app.route("/upscale", methods=["POST"])
def upscale():
    """Upscale an uploaded image and return the result as PNG.

    Form fields:
        image: the uploaded image file (required).
        mode: "x2", "x4" or "x8" (default "x4"); any other value is
            treated as "x4".

    Returns:
        The PNG image on success. On failure, a JSON ``{"error": ...}`` body
        with status 400 when the image field is missing or cannot be decoded.
    """
    if "image" not in request.files:
        return jsonify({"error": "image file required"}), 400

    mode = request.form.get("mode", "x4")
    try:
        img = Image.open(request.files["image"])
        # Image.open is lazy; force decoding here so corrupt or truncated
        # uploads become a client error instead of a 500 inside upscale_core.
        img.load()
    except (OSError, Image.DecompressionBombError):
        return jsonify({"error": "invalid image file"}), 400

    if mode == "x2":
        out = upscale_core(img, 2)
    elif mode == "x8":
        temp = upscale_core(img, 4)
        # Double the x4 result itself. upscale_core may have capped the input
        # at MAX_DIM, so sizing from the original image could request a far
        # larger picture than 8x what was actually processed.
        out = temp.resize((temp.width * 2, temp.height * 2), Image.LANCZOS)
    else:
        out = upscale_core(img, 4)

    # Encode in memory so no temporary file or open descriptor outlives the request.
    buf = io.BytesIO()
    out.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")

if __name__ == "__main__":
    # Listen on all interfaces. The PORT environment variable, when set,
    # overrides the default port 7860.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))