mobisoft commited on
Commit
60acabe
·
verified ·
1 Parent(s): 26c3131

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -75
app.py CHANGED
@@ -6,116 +6,102 @@ import numpy as np
6
  import insightface
7
  import concurrent.futures
8
  import traceback
9
- import requests
10
 
11
  from fastapi import FastAPI, UploadFile, File, HTTPException
12
  from fastapi.responses import HTMLResponse, StreamingResponse
13
 
 
 
 
 
 
14
  # ============================================================
15
  # CONFIG
16
  # ============================================================
17
  MAX_FILE_MB = 10
18
  MAX_DIM = 640
19
- UPSCALE_FACTOR = 2
20
  MAX_WORKERS = 3
21
  CLEANUP_TIME = 300
22
 
23
  TASKS = {}
24
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
25
 
26
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
27
-
28
  # ============================================================
29
  # LOAD MODELS
30
  # ============================================================
31
  face_app = insightface.app.FaceAnalysis(name="buffalo_l")
32
  face_app.prepare(ctx_id=-1, det_size=(640, 640))
33
 
34
- swapper = insightface.model_zoo.get_model("inswapper_128.onnx", root=BASE_DIR)
35
 
36
  # ============================================================
37
- # GFPGAN LITE (FACE ENHANCER)
38
  # ============================================================
39
- try:
40
- import onnxruntime as ort
41
- ONNX_AVAILABLE = True
42
- except:
43
- ONNX_AVAILABLE = False
44
 
45
- gfpgan_session = None
46
- GFPGAN_URL = "https://huggingface.co/ai-forever/GFPGAN-Lite/resolve/main/gfpgan_lite.onnx"
47
-
48
- def load_gfpgan():
49
- global gfpgan_session
50
- if not ONNX_AVAILABLE:
51
- return
52
 
53
- path = os.path.join(BASE_DIR, "gfpgan_lite.onnx")
 
 
54
 
55
- if not os.path.exists(path):
56
- print("Downloading GFPGAN...")
57
- r = requests.get(GFPGAN_URL)
58
- open(path, "wb").write(r.content)
59
 
60
- gfpgan_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
61
- print("GFPGAN Loaded")
62
 
63
  # ============================================================
64
- # IMAGE HELPERS
65
  # ============================================================
66
 
67
- def decode_image(b):
68
- img = cv2.imdecode(np.frombuffer(b, np.uint8), cv2.IMREAD_COLOR)
69
- if img is None:
70
- raise ValueError("Invalid image")
71
- return img
72
-
73
-
74
- def compress_resize(b):
75
- img = decode_image(b)
76
 
77
- size_mb = len(b) / (1024 * 1024)
78
 
 
79
  if size_mb > MAX_FILE_MB:
80
  img = cv2.resize(img, None, fx=0.6, fy=0.6, interpolation=cv2.INTER_AREA)
81
 
 
82
  h, w = img.shape[:2]
83
  if max(h, w) > MAX_DIM:
84
  scale = MAX_DIM / max(h, w)
85
- img = cv2.resize(img, (int(w * scale), int(h * scale)))
86
 
87
  return img
88
 
89
 
 
 
 
 
90
  def upscale_hd(img):
91
  h, w = img.shape[:2]
92
- img = cv2.resize(img, (w * UPSCALE_FACTOR, h * UPSCALE_FACTOR), interpolation=cv2.INTER_CUBIC)
93
 
94
- # sharpening
95
- kernel = np.array([[0,-1,0],[-1,5,-1],[0,-1,0]])
96
- img = cv2.filter2D(img, -1, kernel)
97
 
98
- return img
99
-
100
-
101
- def enhance_face(img):
102
- if gfpgan_session is None:
103
- return img
104
 
105
- try:
106
- inp = cv2.resize(img, (512, 512))
107
- inp = inp.astype(np.float32) / 255.0
108
- inp = np.transpose(inp, (2, 0, 1))[None]
109
 
110
- out = gfpgan_session.run(None, {"input": inp})[0]
111
- out = np.transpose(out[0], (1, 2, 0))
112
- out = (out * 255).clip(0,255).astype(np.uint8)
113
 
114
- out = cv2.resize(out, (img.shape[1], img.shape[0]))
115
- return out
116
- except:
117
- return img
118
 
 
 
 
119
 
120
  def cleanup():
121
  now = time.time()
@@ -133,16 +119,17 @@ def cleanup():
133
  for k in remove:
134
  TASKS.pop(k, None)
135
 
 
136
  # ============================================================
137
  # WORKER
138
  # ============================================================
139
 
140
- def run_task(tid, src_b, tgt_b):
141
  TASKS[tid]["status"] = "processing"
142
 
143
  try:
144
- src = compress_resize(src_b)
145
- tgt = compress_resize(tgt_b)
146
 
147
  s_faces = face_app.get(src)
148
  t_faces = face_app.get(tgt)
@@ -152,12 +139,11 @@ def run_task(tid, src_b, tgt_b):
152
 
153
  result = swapper.get(tgt, t_faces[0], s_faces[0], paste_back=True)
154
 
155
- # ===== HD UPSCALE =====
156
- result = enhance_face(result)
157
  result = upscale_hd(result)
158
 
159
- out_path = f"/tmp/{tid}.webp"
160
- cv2.imwrite(out_path, result, [cv2.IMWRITE_WEBP_QUALITY, 90])
161
 
162
  TASKS[tid] = {
163
  "status": "done",
@@ -169,15 +155,114 @@ def run_task(tid, src_b, tgt_b):
169
  TASKS[tid] = {"status": "failed", "error": str(e)}
170
  print(traceback.format_exc())
171
 
 
172
  # ============================================================
173
  # FASTAPI
174
  # ============================================================
175
 
176
  app = FastAPI()
177
 
 
 
 
 
178
  @app.get("/", response_class=HTMLResponse)
179
  def home():
180
- return "<h2>FaceSwap Running (HD Enabled)</h2>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  @app.post("/swap")
183
  async def swap(source: UploadFile = File(...), target: UploadFile = File(...)):
@@ -198,8 +283,10 @@ async def swap(source: UploadFile = File(...), target: UploadFile = File(...)):
198
  @app.get("/status/{tid}")
199
  def status(tid: str):
200
  cleanup()
 
201
  if tid not in TASKS:
202
  raise HTTPException(404)
 
203
  return TASKS[tid]
204
 
205
 
@@ -210,12 +297,4 @@ def result(tid: str):
210
  if not task or task["status"] != "done":
211
  raise HTTPException(404)
212
 
213
- return StreamingResponse(open(task["result"], "rb"), media_type="image/webp")
214
-
215
- # ============================================================
216
- # INIT
217
- # ============================================================
218
-
219
- print("Loading models...")
220
- load_gfpgan()
221
- print("Ready 🚀")
 
6
  import insightface
7
  import concurrent.futures
8
  import traceback
 
9
 
10
  from fastapi import FastAPI, UploadFile, File, HTTPException
11
  from fastapi.responses import HTMLResponse, StreamingResponse
12
 
13
+ # HEIC SUPPORT
14
+ from PIL import Image
15
+ import pillow_heif
16
+ pillow_heif.register_heif_opener()
17
+
18
  # ============================================================
19
  # CONFIG
20
  # ============================================================
21
  MAX_FILE_MB = 10
22
  MAX_DIM = 640
23
+ UPSCALE_SIZE = 1024 # HD output
24
  MAX_WORKERS = 3
25
  CLEANUP_TIME = 300
26
 
27
  TASKS = {}
28
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
29
 
 
 
30
  # ============================================================
31
  # LOAD MODELS
32
  # ============================================================
33
  face_app = insightface.app.FaceAnalysis(name="buffalo_l")
34
  face_app.prepare(ctx_id=-1, det_size=(640, 640))
35
 
36
+ swapper = insightface.model_zoo.get_model("inswapper_128.onnx", root=".")
37
 
38
  # ============================================================
39
+ # IMAGE DECODING (ALL FORMATS SUPPORT)
40
  # ============================================================
 
 
 
 
 
41
 
42
+ def read_image(file_bytes):
43
+ try:
44
+ # try OpenCV first
45
+ img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
46
+ if img is not None:
47
+ return img
 
48
 
49
+ # fallback for HEIC / unsupported
50
+ pil_img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
51
+ return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
52
 
53
+ except Exception:
54
+ raise ValueError("Unsupported image format")
 
 
55
 
 
 
56
 
57
  # ============================================================
58
+ # OPTIMIZATION
59
  # ============================================================
60
 
61
+ def compress_and_resize(file_bytes):
62
+ img = read_image(file_bytes)
 
 
 
 
 
 
 
63
 
64
+ size_mb = len(file_bytes) / (1024 * 1024)
65
 
66
+ # compress large images
67
  if size_mb > MAX_FILE_MB:
68
  img = cv2.resize(img, None, fx=0.6, fy=0.6, interpolation=cv2.INTER_AREA)
69
 
70
+ # resize for faster inference
71
  h, w = img.shape[:2]
72
  if max(h, w) > MAX_DIM:
73
  scale = MAX_DIM / max(h, w)
74
+ img = cv2.resize(img, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
75
 
76
  return img
77
 
78
 
79
+ # ============================================================
80
+ # HD UPSCALE (FAST CPU FRIENDLY)
81
+ # ============================================================
82
+
83
  def upscale_hd(img):
84
  h, w = img.shape[:2]
 
85
 
86
+ # upscale to HD target
87
+ scale = UPSCALE_SIZE / max(h, w)
 
88
 
89
+ img = cv2.resize(
90
+ img,
91
+ (int(w * scale), int(h * scale)),
92
+ interpolation=cv2.INTER_CUBIC
93
+ )
 
94
 
95
+ # sharpen for better detail
96
+ blur = cv2.GaussianBlur(img, (0, 0), 1.0)
97
+ img = cv2.addWeighted(img, 1.3, blur, -0.3, 0)
 
98
 
99
+ return img
 
 
100
 
 
 
 
 
101
 
102
+ # ============================================================
103
+ # CLEANUP
104
+ # ============================================================
105
 
106
  def cleanup():
107
  now = time.time()
 
119
  for k in remove:
120
  TASKS.pop(k, None)
121
 
122
+
123
  # ============================================================
124
  # WORKER
125
  # ============================================================
126
 
127
+ def run_task(tid, src_bytes, tgt_bytes):
128
  TASKS[tid]["status"] = "processing"
129
 
130
  try:
131
+ src = compress_and_resize(src_bytes)
132
+ tgt = compress_and_resize(tgt_bytes)
133
 
134
  s_faces = face_app.get(src)
135
  t_faces = face_app.get(tgt)
 
139
 
140
  result = swapper.get(tgt, t_faces[0], s_faces[0], paste_back=True)
141
 
142
+ # HD UPSCALE
 
143
  result = upscale_hd(result)
144
 
145
+ out_path = f"/tmp/{tid}.jpg"
146
+ cv2.imwrite(out_path, result, [cv2.IMWRITE_JPEG_QUALITY, 95])
147
 
148
  TASKS[tid] = {
149
  "status": "done",
 
155
  TASKS[tid] = {"status": "failed", "error": str(e)}
156
  print(traceback.format_exc())
157
 
158
+
159
  # ============================================================
160
  # FASTAPI
161
  # ============================================================
162
 
163
  app = FastAPI()
164
 
165
+ # ============================================================
166
+ # UI
167
+ # ============================================================
168
+
169
  @app.get("/", response_class=HTMLResponse)
170
  def home():
171
+ return """
172
+ <!DOCTYPE html>
173
+ <html>
174
+ <head>
175
+ <meta name="viewport" content="width=device-width, initial-scale=1">
176
+ <title>HD Face Swap</title>
177
+
178
+ <style>
179
+ body{background:#0f172a;color:white;text-align:center;font-family:sans-serif}
180
+ .container{max-width:900px;margin:auto;padding:20px}
181
+ img{width:100%;max-height:260px;object-fit:contain;border-radius:10px}
182
+ button{padding:12px 18px;margin:10px;background:#6366f1;color:white;border:none;border-radius:8px}
183
+ .download{display:none;background:#10b981}
184
+ </style>
185
+ </head>
186
+
187
+ <body>
188
+
189
+ <div class="container">
190
+
191
+ <h2>🔥 HD Face Swap (iOS Ready)</h2>
192
+
193
+ <input type="file" id="src"><br><br>
194
+ <input type="file" id="tgt"><br><br>
195
+
196
+ <img id="p1"><br>
197
+ <img id="p2"><br>
198
+
199
+ <button onclick="start()">Upload & Swap</button>
200
+
201
+ <p id="status"></p>
202
+
203
+ <img id="out"><br>
204
+ <a id="dl" class="download" download="faceswap_hd.jpg">Download HD</a>
205
+
206
+ </div>
207
+
208
+ <script>
209
+ const src = document.getElementById("src");
210
+ const tgt = document.getElementById("tgt");
211
+ const dl = document.getElementById("dl");
212
+
213
+ src.onchange = ()=> p1.src = URL.createObjectURL(src.files[0]);
214
+ tgt.onchange = ()=> p2.src = URL.createObjectURL(tgt.files[0]);
215
+
216
+ async function start(){
217
+ if(!src.files[0] || !tgt.files[0]){
218
+ alert("Upload both images");
219
+ return;
220
+ }
221
+
222
+ let fd = new FormData();
223
+ fd.append("source", src.files[0]);
224
+ fd.append("target", tgt.files[0]);
225
+
226
+ status.innerText = "Processing...";
227
+
228
+ let r = await fetch("/swap", {method:"POST", body:fd});
229
+ let j = await r.json();
230
+
231
+ poll(j.task_id);
232
+ }
233
+
234
+ async function poll(id){
235
+ let r = await fetch("/status/"+id);
236
+ let j = await r.json();
237
+
238
+ status.innerText = j.status;
239
+
240
+ if(j.status==="done"){
241
+ let img = await fetch("/result/"+id);
242
+ let blob = await img.blob();
243
+
244
+ let url = URL.createObjectURL(blob);
245
+ out.src = url;
246
+
247
+ dl.href = url;
248
+ dl.style.display = "inline-block";
249
+
250
+ status.innerText = "Done ✅";
251
+ } else if(j.status==="failed"){
252
+ status.innerText = j.error;
253
+ } else {
254
+ setTimeout(()=>poll(id), 1000);
255
+ }
256
+ }
257
+ </script>
258
+
259
+ </body>
260
+ </html>
261
+ """
262
+
263
+ # ============================================================
264
+ # API (iOS FRIENDLY)
265
+ # ============================================================
266
 
267
  @app.post("/swap")
268
  async def swap(source: UploadFile = File(...), target: UploadFile = File(...)):
 
283
  @app.get("/status/{tid}")
284
  def status(tid: str):
285
  cleanup()
286
+
287
  if tid not in TASKS:
288
  raise HTTPException(404)
289
+
290
  return TASKS[tid]
291
 
292
 
 
297
  if not task or task["status"] != "done":
298
  raise HTTPException(404)
299
 
300
+ return StreamingResponse(open(task["result"], "rb"), media_type="image/jpeg")