mobisoft commited on
Commit
7500a66
·
verified ·
1 Parent(s): ecd9a03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -139
app.py CHANGED
@@ -11,18 +11,19 @@ from fastapi import FastAPI, UploadFile, File, HTTPException
11
  from fastapi.responses import HTMLResponse, StreamingResponse
12
 
13
  # ============================================================
14
- # CONFIG (OPTIMIZED FOR SPEED + HF SPACES)
15
  # ============================================================
16
  MAX_FILE_MB = 10
17
- MAX_DIM = 640
18
- MAX_WORKERS = 3 # prevent CPU overload
 
19
  CLEANUP_TIME = 300
20
 
21
  TASKS = {}
22
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
23
 
24
  # ============================================================
25
- # LOAD MODELS (ONLY ONCE)
26
  # ============================================================
27
  face_app = insightface.app.FaceAnalysis(name="buffalo_l")
28
  face_app.prepare(ctx_id=-1, det_size=(640, 640))
@@ -30,42 +31,54 @@ face_app.prepare(ctx_id=-1, det_size=(640, 640))
30
  swapper = insightface.model_zoo.get_model("inswapper_128.onnx", root=".")
31
 
32
  # ============================================================
33
- # IMAGE OPTIMIZATION
34
  # ============================================================
35
 
36
  def decode_image(file_bytes):
37
  img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
38
  if img is None:
39
- raise ValueError("Invalid image format")
40
  return img
41
 
42
 
43
- def compress_and_resize(file_bytes):
44
- img = decode_image(file_bytes)
45
 
46
- size_mb = len(file_bytes) / (1024 * 1024)
 
 
47
 
48
- # compress if >10MB
49
- if size_mb > MAX_FILE_MB:
50
- img = cv2.resize(img, None, fx=0.6, fy=0.6, interpolation=cv2.INTER_AREA)
51
 
52
- # resize for speed
 
53
  h, w = img.shape[:2]
54
- if max(h, w) > MAX_DIM:
55
- scale = MAX_DIM / max(h, w)
56
- img = cv2.resize(img, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
 
57
 
58
  return img
59
 
60
 
61
- def enhance(img):
62
- blur = cv2.GaussianBlur(img, (0, 0), 1.2)
63
- return cv2.addWeighted(img, 1.2, blur, -0.2, 0)
 
 
 
 
 
 
 
 
 
 
64
 
65
 
66
  def cleanup():
67
  now = time.time()
68
- delete_keys = []
69
 
70
  for k, v in TASKS.items():
71
  if "time" in v and now - v["time"] > CLEANUP_TIME:
@@ -74,9 +87,9 @@ def cleanup():
74
  os.remove(v["result"])
75
  except:
76
  pass
77
- delete_keys.append(k)
78
 
79
- for k in delete_keys:
80
  TASKS.pop(k, None)
81
 
82
  # ============================================================
@@ -87,20 +100,33 @@ def run_task(tid, src_bytes, tgt_bytes):
87
  TASKS[tid]["status"] = "processing"
88
 
89
  try:
90
- src = compress_and_resize(src_bytes)
91
- tgt = compress_and_resize(tgt_bytes)
 
 
 
 
 
92
 
 
93
  s_faces = face_app.get(src)
94
  t_faces = face_app.get(tgt)
95
 
96
  if not s_faces or not t_faces:
97
  raise ValueError("Face not detected")
98
 
 
99
  result = swapper.get(tgt, t_faces[0], s_faces[0], paste_back=True)
100
- result = enhance(result)
101
 
 
 
 
 
 
 
 
102
  out_path = f"/tmp/{tid}.webp"
103
- cv2.imwrite(out_path, result, [cv2.IMWRITE_WEBP_QUALITY, 85])
104
 
105
  TASKS[tid] = {
106
  "status": "done",
@@ -119,7 +145,7 @@ def run_task(tid, src_bytes, tgt_bytes):
119
  app = FastAPI()
120
 
121
  # ============================================================
122
- # RESPONSIVE UI (UPLOAD + DOWNLOAD BUTTON)
123
  # ============================================================
124
 
125
  @app.get("/", response_class=HTMLResponse)
@@ -129,61 +155,17 @@ def home():
129
  <html>
130
  <head>
131
  <meta name="viewport" content="width=device-width, initial-scale=1">
132
- <title>Fast Face Swap</title>
133
 
134
  <style>
135
- body{
136
- margin:0;
137
- font-family:sans-serif;
138
- background:#0f172a;
139
- color:white;
140
- text-align:center;
141
- }
142
-
143
- .container{
144
- max-width:900px;
145
- margin:auto;
146
- padding:20px;
147
- }
148
-
149
- .grid{
150
- display:grid;
151
- grid-template-columns:1fr 1fr;
152
- gap:15px;
153
- }
154
-
155
- @media(max-width:700px){
156
- .grid{grid-template-columns:1fr;}
157
- }
158
-
159
- .card{
160
- background:#1e293b;
161
- padding:15px;
162
- border-radius:10px;
163
- }
164
-
165
- img{
166
- width:100%;
167
- max-height:250px;
168
- object-fit:contain;
169
- border-radius:10px;
170
- background:black;
171
- }
172
-
173
- button{
174
- padding:12px 18px;
175
- margin:8px;
176
- border:none;
177
- border-radius:8px;
178
- background:#6366f1;
179
- color:white;
180
- cursor:pointer;
181
- }
182
-
183
- .download{
184
- display:none;
185
- background:#10b981;
186
- }
187
  </style>
188
 
189
  </head>
@@ -192,90 +174,77 @@ button{
192
 
193
  <div class="container">
194
 
195
- <h2>⚡ Fast AI Face Swap</h2>
196
 
197
  <div class="grid">
198
-
199
  <div class="card">
200
- <h3>Source</h3>
201
  <input type="file" id="src">
202
  <img id="p1">
203
  </div>
204
 
205
  <div class="card">
206
- <h3>Target</h3>
207
  <input type="file" id="tgt">
208
  <img id="p2">
209
  </div>
210
-
211
  </div>
212
 
213
- <br>
214
-
215
  <button onclick="start()">Upload & Swap</button>
216
 
217
  <p id="status"></p>
218
 
219
  <div class="card">
220
- <h3>Output</h3>
221
  <img id="out">
222
  <br>
223
- <a id="downloadBtn" class="download" download="faceswap.webp">Download</a>
224
  </div>
225
 
226
  </div>
227
 
228
  <script>
229
- const src = document.getElementById("src");
230
- const tgt = document.getElementById("tgt");
231
- const statusEl = document.getElementById("status");
232
- const downloadBtn = document.getElementById("downloadBtn");
233
 
234
- src.onchange = ()=> p1.src = URL.createObjectURL(src.files[0]);
235
- tgt.onchange = ()=> p2.src = URL.createObjectURL(tgt.files[0]);
236
 
237
  async function start(){
238
- if(!src.files[0] || !tgt.files[0]){
239
- alert("Upload both images");
240
- return;
241
- }
242
 
243
- let fd = new FormData();
244
- fd.append("source", src.files[0]);
245
- fd.append("target", tgt.files[0]);
246
 
247
- statusEl.innerText = "Uploading...";
248
 
249
- let r = await fetch("/swap", {method:"POST", body:fd});
250
- let j = await r.json();
251
 
252
- poll(j.task_id);
253
  }
254
 
255
  async function poll(id){
256
- let r = await fetch("/status/"+id);
257
- let j = await r.json();
258
-
259
- statusEl.innerText = j.status;
260
-
261
- if(j.status === "done"){
262
- let img = await fetch("/result/"+id);
263
- let blob = await img.blob();
264
-
265
- let url = URL.createObjectURL(blob);
266
- out.src = url;
267
-
268
- downloadBtn.href = url;
269
- downloadBtn.style.display = "inline-block";
270
-
271
- statusEl.innerText = "Done ✅";
272
- }
273
- else if(j.status === "failed"){
274
- statusEl.innerText = "Error: " + j.error;
275
- }
276
- else{
277
- setTimeout(()=>poll(id), 800);
278
- }
279
  }
280
  </script>
281
 
@@ -284,7 +253,7 @@ async function poll(id){
284
  """
285
 
286
  # ============================================================
287
- # API (UNCHANGED)
288
  # ============================================================
289
 
290
  @app.post("/swap")
@@ -293,12 +262,7 @@ async def swap(source: UploadFile = File(...), target: UploadFile = File(...)):
293
 
294
  TASKS[tid] = {"status": "queued", "time": time.time()}
295
 
296
- executor.submit(
297
- run_task,
298
- tid,
299
- await source.read(),
300
- await target.read()
301
- )
302
 
303
  return {"task_id": tid}
304
 
@@ -306,17 +270,14 @@ async def swap(source: UploadFile = File(...), target: UploadFile = File(...)):
306
  @app.get("/status/{tid}")
307
  def status(tid: str):
308
  cleanup()
309
-
310
  if tid not in TASKS:
311
  raise HTTPException(404)
312
-
313
  return TASKS[tid]
314
 
315
 
316
  @app.get("/result/{tid}")
317
  def result(tid: str):
318
  task = TASKS.get(tid)
319
-
320
  if not task or task["status"] != "done":
321
  raise HTTPException(404)
322
 
 
11
  from fastapi.responses import HTMLResponse, StreamingResponse
12
 
13
  # ============================================================
14
+ # CONFIG
15
  # ============================================================
16
  MAX_FILE_MB = 10
17
+ INPUT_MAX_DIM = 640
18
+ OUTPUT_MAX_DIM = 1024
19
+ MAX_WORKERS = 3
20
  CLEANUP_TIME = 300
21
 
22
  TASKS = {}
23
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
24
 
25
  # ============================================================
26
+ # MODELS
27
  # ============================================================
28
  face_app = insightface.app.FaceAnalysis(name="buffalo_l")
29
  face_app.prepare(ctx_id=-1, det_size=(640, 640))
 
31
  swapper = insightface.model_zoo.get_model("inswapper_128.onnx", root=".")
32
 
33
  # ============================================================
34
+ # IMAGE HELPERS
35
  # ============================================================
36
 
37
  def decode_image(file_bytes):
38
  img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
39
  if img is None:
40
+ raise ValueError("Invalid image")
41
  return img
42
 
43
 
44
+ def compress_resize_input(img):
45
+ h, w = img.shape[:2]
46
 
47
+ if max(h, w) > INPUT_MAX_DIM:
48
+ scale = INPUT_MAX_DIM / max(h, w)
49
+ img = cv2.resize(img, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_LINEAR)
50
 
51
+ return img
 
 
52
 
53
+
54
+ def upscale_hd(img):
55
  h, w = img.shape[:2]
56
+
57
+ if max(h, w) < OUTPUT_MAX_DIM:
58
+ scale = OUTPUT_MAX_DIM / max(h, w)
59
+ img = cv2.resize(img, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_CUBIC)
60
 
61
  return img
62
 
63
 
64
+ def sharpen(img):
65
+ kernel = np.array([[0,-1,0],[-1,5,-1],[0,-1,0]])
66
+ return cv2.filter2D(img, -1, kernel)
67
+
68
+
69
+ def compress_if_needed(file_bytes):
70
+ size_mb = len(file_bytes) / (1024 * 1024)
71
+ img = decode_image(file_bytes)
72
+
73
+ if size_mb > MAX_FILE_MB:
74
+ img = cv2.resize(img, None, fx=0.6, fy=0.6, interpolation=cv2.INTER_AREA)
75
+
76
+ return img
77
 
78
 
79
  def cleanup():
80
  now = time.time()
81
+ remove = []
82
 
83
  for k, v in TASKS.items():
84
  if "time" in v and now - v["time"] > CLEANUP_TIME:
 
87
  os.remove(v["result"])
88
  except:
89
  pass
90
+ remove.append(k)
91
 
92
+ for k in remove:
93
  TASKS.pop(k, None)
94
 
95
  # ============================================================
 
100
  TASKS[tid]["status"] = "processing"
101
 
102
  try:
103
+ # Decode + compress
104
+ src = compress_if_needed(src_bytes)
105
+ tgt = compress_if_needed(tgt_bytes)
106
+
107
+ # Resize input for speed
108
+ src = compress_resize_input(src)
109
+ tgt = compress_resize_input(tgt)
110
 
111
+ # Face detect
112
  s_faces = face_app.get(src)
113
  t_faces = face_app.get(tgt)
114
 
115
  if not s_faces or not t_faces:
116
  raise ValueError("Face not detected")
117
 
118
+ # Face swap
119
  result = swapper.get(tgt, t_faces[0], s_faces[0], paste_back=True)
 
120
 
121
+ # 🔥 HD UPSCALE
122
+ result = upscale_hd(result)
123
+
124
+ # 🔥 SHARPEN
125
+ result = sharpen(result)
126
+
127
+ # Save
128
  out_path = f"/tmp/{tid}.webp"
129
+ cv2.imwrite(out_path, result, [cv2.IMWRITE_WEBP_QUALITY, 90])
130
 
131
  TASKS[tid] = {
132
  "status": "done",
 
145
  app = FastAPI()
146
 
147
  # ============================================================
148
+ # UI (UPLOAD + DOWNLOAD)
149
  # ============================================================
150
 
151
  @app.get("/", response_class=HTMLResponse)
 
155
  <html>
156
  <head>
157
  <meta name="viewport" content="width=device-width, initial-scale=1">
158
+ <title>HD Face Swap</title>
159
 
160
  <style>
161
+ body{background:#0f172a;color:white;text-align:center;font-family:sans-serif}
162
+ .container{max-width:900px;margin:auto;padding:20px}
163
+ .grid{display:grid;grid-template-columns:1fr 1fr;gap:10px}
164
+ @media(max-width:700px){.grid{grid-template-columns:1fr}}
165
+ .card{background:#1e293b;padding:10px;border-radius:10px}
166
+ img{width:100%;max-height:250px;object-fit:contain;border-radius:10px}
167
+ button{padding:12px 18px;margin:10px;border:none;background:#6366f1;color:white;border-radius:8px}
168
+ .download{display:none;background:#10b981}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  </style>
170
 
171
  </head>
 
174
 
175
  <div class="container">
176
 
177
+ <h2>⚡ HD Face Swap</h2>
178
 
179
  <div class="grid">
 
180
  <div class="card">
 
181
  <input type="file" id="src">
182
  <img id="p1">
183
  </div>
184
 
185
  <div class="card">
 
186
  <input type="file" id="tgt">
187
  <img id="p2">
188
  </div>
 
189
  </div>
190
 
 
 
191
  <button onclick="start()">Upload & Swap</button>
192
 
193
  <p id="status"></p>
194
 
195
  <div class="card">
 
196
  <img id="out">
197
  <br>
198
+ <a id="dl" class="download" download="faceswap_hd.webp">Download</a>
199
  </div>
200
 
201
  </div>
202
 
203
  <script>
204
+ const src=document.getElementById("src")
205
+ const tgt=document.getElementById("tgt")
206
+ const st=document.getElementById("status")
207
+ const dl=document.getElementById("dl")
208
 
209
+ src.onchange=()=>p1.src=URL.createObjectURL(src.files[0])
210
+ tgt.onchange=()=>p2.src=URL.createObjectURL(tgt.files[0])
211
 
212
  async function start(){
213
+ if(!src.files[0]||!tgt.files[0]) return alert("upload both")
 
 
 
214
 
215
+ let fd=new FormData()
216
+ fd.append("source",src.files[0])
217
+ fd.append("target",tgt.files[0])
218
 
219
+ st.innerText="Processing..."
220
 
221
+ let r=await fetch("/swap",{method:"POST",body:fd})
222
+ let j=await r.json()
223
 
224
+ poll(j.task_id)
225
  }
226
 
227
  async function poll(id){
228
+ let r=await fetch("/status/"+id)
229
+ let j=await r.json()
230
+
231
+ st.innerText=j.status
232
+
233
+ if(j.status==="done"){
234
+ let img=await fetch("/result/"+id)
235
+ let b=await img.blob()
236
+ let u=URL.createObjectURL(b)
237
+
238
+ out.src=u
239
+ dl.href=u
240
+ dl.style.display="inline-block"
241
+
242
+ st.innerText="Done ✅"
243
+ }else if(j.status==="failed"){
244
+ st.innerText="Error: "+j.error
245
+ }else{
246
+ setTimeout(()=>poll(id),800)
247
+ }
 
 
 
248
  }
249
  </script>
250
 
 
253
  """
254
 
255
  # ============================================================
256
+ # API
257
  # ============================================================
258
 
259
  @app.post("/swap")
 
262
 
263
  TASKS[tid] = {"status": "queued", "time": time.time()}
264
 
265
+ executor.submit(run_task, tid, await source.read(), await target.read())
 
 
 
 
 
266
 
267
  return {"task_id": tid}
268
 
 
270
  @app.get("/status/{tid}")
271
  def status(tid: str):
272
  cleanup()
 
273
  if tid not in TASKS:
274
  raise HTTPException(404)
 
275
  return TASKS[tid]
276
 
277
 
278
  @app.get("/result/{tid}")
279
  def result(tid: str):
280
  task = TASKS.get(tid)
 
281
  if not task or task["status"] != "done":
282
  raise HTTPException(404)
283