Akwbw commited on
Commit
f13629a
·
verified ·
1 Parent(s): 4f917c7

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile.txt +12 -0
  2. app.py +133 -0
  3. requirements.txt +5 -0
Dockerfile.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ EXPOSE 7860
11
+
12
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import tempfile
4
+ import numpy as np
5
+ import requests
6
+ from PIL import Image
7
+ from flask import Flask, request, jsonify, send_file
8
+ import onnxruntime as ort
9
+
10
+ # ================= CONFIG =================
11
+ MODEL_DIR = "model"
12
+ MODEL_X2_PATH = os.path.join(MODEL_DIR, "Real-ESRGAN_x2plus.onnx")
13
+ MODEL_X4_PATH = os.path.join(MODEL_DIR, "Real-ESRGAN-x4plus.onnx")
14
+
15
+ FILE_ID_X2 = "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6"
16
+ FILE_ID_X4 = "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6"
17
+
18
+ MAX_DIM = 1024
19
+
20
+ app = Flask(__name__)
21
+
22
+ # ================= MODEL DOWNLOAD =================
23
+ def download_from_drive(file_id, dest):
24
+ url = "https://drive.google.com/uc?export=download"
25
+ session = requests.Session()
26
+ r = session.get(url, params={"id": file_id}, stream=True)
27
+
28
+ token = None
29
+ for k, v in r.cookies.items():
30
+ if k.startswith("download_warning"):
31
+ token = v
32
+ break
33
+
34
+ if token:
35
+ r = session.get(url, params={"id": file_id, "confirm": token}, stream=True)
36
+
37
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
38
+ with open(dest, "wb") as f:
39
+ for chunk in r.iter_content(32768):
40
+ if chunk:
41
+ f.write(chunk)
42
+
43
+ if not os.path.exists(MODEL_X2_PATH):
44
+ download_from_drive(FILE_ID_X2, MODEL_X2_PATH)
45
+
46
+ if not os.path.exists(MODEL_X4_PATH):
47
+ download_from_drive(FILE_ID_X4, MODEL_X4_PATH)
48
+
49
+ # ================= ONNX SESSIONS =================
50
+ opts = ort.SessionOptions()
51
+ opts.intra_op_num_threads = 2
52
+ opts.inter_op_num_threads = 2
53
+
54
+ sess_x2 = ort.InferenceSession(MODEL_X2_PATH, opts, providers=["CPUExecutionProvider"])
55
+ sess_x4 = ort.InferenceSession(MODEL_X4_PATH, opts, providers=["CPUExecutionProvider"])
56
+
57
+ meta_x2 = sess_x2.get_inputs()[0]
58
+ meta_x4 = sess_x4.get_inputs()[0]
59
+
60
+ _, _, H2, W2 = meta_x2.shape
61
+ _, _, H4, W4 = meta_x4.shape
62
+
63
+ # ================= HELPERS =================
64
+ def run_tile(tile, session, meta):
65
+ inp = np.transpose(tile, (2, 0, 1))[None, ...]
66
+ out = session.run(None, {meta.name: inp})[0][0]
67
+ return np.transpose(out, (1, 2, 0))
68
+
69
+ def upscale_core(img: Image.Image, scale: int):
70
+ if scale == 2:
71
+ H, W, sess, meta, S = H2, W2, sess_x2, meta_x2, 2
72
+ else:
73
+ H, W, sess, meta, S = H4, W4, sess_x4, meta_x4, 4
74
+
75
+ w, h = img.size
76
+ if max(w, h) > MAX_DIM:
77
+ r = MAX_DIM / max(w, h)
78
+ img = img.resize((int(w*r), int(h*r)), Image.LANCZOS)
79
+
80
+ arr = np.array(img.convert("RGB")).astype(np.float32) / 255.0
81
+ h0, w0, _ = arr.shape
82
+
83
+ th = math.ceil(h0 / H)
84
+ tw = math.ceil(w0 / W)
85
+
86
+ pad = np.pad(arr, ((0, th*H-h0), (0, tw*W-w0), (0, 0)), mode="reflect")
87
+ out = np.zeros((th*H*S, tw*W*S, 3), dtype=np.float32)
88
+
89
+ for i in range(th):
90
+ for j in range(tw):
91
+ tile = pad[i*H:(i+1)*H, j*W:(j+1)*W]
92
+ up = run_tile(tile, sess, meta)
93
+ out[i*H*S:(i+1)*H*S, j*W*S:(j+1)*W*S] = up
94
+
95
+ out = np.clip(out[:h0*S, :w0*S], 0, 1)
96
+ return Image.fromarray((out * 255).astype(np.uint8))
97
+
98
+ # ================= ROUTES =================
99
+ @app.route("/", methods=["GET"])
100
+ def index():
101
+ return jsonify({
102
+ "service": "SpectraGAN Upscaler API",
103
+ "status": "running",
104
+ "usage": "POST /upscale with image + mode=x2|x4|x8"
105
+ })
106
+
107
+ @app.route("/health", methods=["GET"])
108
+ def health():
109
+ return jsonify({"status": "ok"})
110
+
111
+ @app.route("/upscale", methods=["POST"])
112
+ def upscale():
113
+ if "image" not in request.files:
114
+ return jsonify({"error": "image file required"}), 400
115
+
116
+ mode = request.form.get("mode", "x4")
117
+ img = Image.open(request.files["image"])
118
+
119
+ if mode == "x2":
120
+ out = upscale_core(img, 2)
121
+ elif mode == "x8":
122
+ temp = upscale_core(img, 4)
123
+ out = temp.resize((img.width * 8, img.height * 8), Image.LANCZOS)
124
+ else:
125
+ out = upscale_core(img, 4)
126
+
127
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
128
+ out.save(tmp.name)
129
+
130
+ return send_file(tmp.name, mimetype="image/png")
131
+
132
+ if __name__ == "__main__":
133
+ app.run(host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ flask
2
+ onnxruntime
3
+ numpy
4
+ Pillow
5
+ requests