mobisoft commited on
Commit
f8a4ad9
·
verified ·
1 Parent(s): 459af64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -83
app.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
  import insightface
7
  import concurrent.futures
8
  import traceback
 
9
 
10
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
11
  from fastapi.responses import HTMLResponse, StreamingResponse
@@ -31,7 +32,7 @@ TASKS = {}
31
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
32
 
33
  # ============================================================
34
- # LOAD MODELS
35
  # ============================================================
36
  face_app = insightface.app.FaceAnalysis(name="buffalo_l")
37
  face_app.prepare(ctx_id=-1, det_size=(640, 640))
@@ -48,7 +49,6 @@ def decode_image(file_bytes):
48
 
49
  if img is None and HEIC_SUPPORTED:
50
  try:
51
- from PIL import Image
52
  import io
53
  pil_img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
54
  img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
@@ -56,7 +56,7 @@ def decode_image(file_bytes):
56
  pass
57
 
58
  if img is None:
59
- raise ValueError("Unsupported image format")
60
 
61
  return img
62
 
@@ -65,9 +65,8 @@ def compress_and_resize(file_bytes):
65
  img = decode_image(file_bytes)
66
 
67
  size_mb = len(file_bytes) / (1024 * 1024)
68
-
69
  if size_mb > MAX_FILE_MB:
70
- img = cv2.resize(img, None, fx=0.6, fy=0.6, interpolation=cv2.INTER_AREA)
71
 
72
  h, w = img.shape[:2]
73
  if max(h, w) > MAX_DIM:
@@ -84,19 +83,37 @@ def enhance(img):
84
 
85
  def cleanup():
86
  now = time.time()
87
- remove = []
88
-
89
- for k, v in TASKS.items():
90
  if "time" in v and now - v["time"] > CLEANUP_TIME:
91
  try:
92
  if "result" in v:
93
  os.remove(v["result"])
94
  except:
95
  pass
96
- remove.append(k)
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- for k in remove:
99
- TASKS.pop(k, None)
 
 
 
 
 
 
 
100
 
101
  # ============================================================
102
  # WORKER
@@ -123,7 +140,7 @@ def run_task(tid, src_bytes, tgt_bytes, filename, face_index):
123
  name = os.path.splitext(filename)[0]
124
  out_path = f"/tmp/{name}_{tid}.png"
125
 
126
- cv2.imwrite(out_path, result, [cv2.IMWRITE_PNG_COMPRESSION, 3])
127
 
128
  TASKS[tid] = {
129
  "status": "done",
@@ -143,7 +160,7 @@ def run_task(tid, src_bytes, tgt_bytes, filename, face_index):
143
  app = FastAPI()
144
 
145
  # ============================================================
146
- # UI PAGE
147
  # ============================================================
148
 
149
  @app.get("/", response_class=HTMLResponse)
@@ -153,106 +170,112 @@ def home():
153
  <html>
154
  <head>
155
  <meta name="viewport" content="width=device-width, initial-scale=1">
156
- <title>Face Swap API Test</title>
157
 
158
  <style>
159
- body{font-family:sans-serif;background:#0f172a;color:white;text-align:center;padding:20px}
160
  .box{border:2px dashed #444;padding:20px;margin:10px;border-radius:10px;cursor:pointer}
161
  img{max-width:200px;margin-top:10px;border-radius:8px}
162
- button{padding:12px 20px;background:#6366f1;color:white;border:none;border-radius:8px}
163
- .progress{height:10px;background:#333;margin-top:10px;border-radius:10px;overflow:hidden;display:none}
164
- .bar{height:100%;width:0;background:#22c55e}
 
 
 
 
165
  </style>
166
 
167
  </head>
168
 
169
  <body>
170
 
171
- <h2>⚡ Face Swap Test UI</h2>
172
 
173
- <div class="box" onclick="src.click()">Upload Source<input type="file" id="src" hidden></div>
174
  <img id="p1">
175
 
176
- <div class="box" onclick="tgt.click()">Upload Target<input type="file" id="tgt" hidden></div>
 
 
177
  <img id="p2">
178
 
179
  <br>
 
180
 
181
- <label>Select Face Index:</label>
182
- <input type="number" id="faceIndex" value="0" min="0">
183
-
184
- <br><br>
185
-
186
- <button onclick="start()">Start Swap</button>
187
-
188
- <div class="progress" id="progress"><div class="bar" id="bar"></div></div>
189
-
190
- <br>
191
 
192
- <img id="out">
193
  <br>
194
- <a id="dl" download="faceswap.png" style="display:none;color:lightgreen">Download PNG</a>
 
195
 
196
  <script>
197
- const src=document.getElementById("src");
198
- const tgt=document.getElementById("tgt");
199
 
200
- src.onchange=()=>p1.src=URL.createObjectURL(src.files[0]);
201
- tgt.onchange=()=>p2.src=URL.createObjectURL(tgt.files[0]);
202
 
203
- function upload(url,fd){
204
- return new Promise((res,rej)=>{
205
- let xhr=new XMLHttpRequest();
206
- xhr.open("POST",url);
207
 
208
- xhr.upload.onprogress=(e)=>{
209
- if(e.lengthComputable){
210
- progress.style.display="block";
211
- bar.style.width=(e.loaded/e.total*100)+"%";
212
- }
213
- };
214
 
215
- xhr.onload=()=>res(JSON.parse(xhr.responseText));
216
- xhr.onerror=rej;
 
 
 
 
 
 
 
 
 
 
217
 
218
- xhr.send(fd);
219
- });
220
- }
 
221
 
222
  async function start(){
223
- if(!src.files[0]||!tgt.files[0]) return alert("Upload both");
224
 
225
- let fd=new FormData();
226
- fd.append("source",src.files[0]);
227
- fd.append("target",tgt.files[0]);
228
- fd.append("face_index",document.getElementById("faceIndex").value);
229
 
230
- let data=await upload("/swap",fd);
 
231
 
232
- poll(data.task_id);
233
  }
234
 
235
  async function poll(id){
236
- let r=await fetch("/status/"+id);
237
- let j=await r.json();
238
-
239
- if(j.status==="done"){
240
- let img=await fetch("/result/"+id);
241
- let blob=await img.blob();
242
- let url=URL.createObjectURL(blob);
243
-
244
- out.src=url;
245
- dl.href=url;
246
- dl.style.display="block";
247
-
248
- bar.style.width="100%";
249
- }
250
- else if(j.status==="failed"){
251
- alert(j.error);
252
- }
253
- else{
254
- setTimeout(()=>poll(id),800);
255
- }
 
256
  }
257
  </script>
258
 
@@ -261,7 +284,17 @@ async function poll(id){
261
  """
262
 
263
  # ============================================================
264
- # API
 
 
 
 
 
 
 
 
 
 
265
  # ============================================================
266
 
267
  @app.post("/swap")
@@ -285,6 +318,9 @@ async def swap(
285
 
286
  return {"task_id": tid}
287
 
 
 
 
288
 
289
  @app.get("/status/{tid}")
290
  def status(tid: str):
@@ -293,6 +329,9 @@ def status(tid: str):
293
  raise HTTPException(404)
294
  return TASKS[tid]
295
 
 
 
 
296
 
297
  @app.get("/result/{tid}")
298
  def result(tid: str):
@@ -304,7 +343,5 @@ def result(tid: str):
304
  return StreamingResponse(
305
  open(task["result"], "rb"),
306
  media_type="image/png",
307
- headers={
308
- "Content-Disposition": f'attachment; filename="{task["filename"]}"'
309
- }
310
  )
 
6
  import insightface
7
  import concurrent.futures
8
  import traceback
9
+ import base64
10
 
11
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
12
  from fastapi.responses import HTMLResponse, StreamingResponse
 
32
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
33
 
34
  # ============================================================
35
+ # MODELS
36
  # ============================================================
37
  face_app = insightface.app.FaceAnalysis(name="buffalo_l")
38
  face_app.prepare(ctx_id=-1, det_size=(640, 640))
 
49
 
50
  if img is None and HEIC_SUPPORTED:
51
  try:
 
52
  import io
53
  pil_img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
54
  img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
 
56
  pass
57
 
58
  if img is None:
59
+ raise ValueError("Unsupported image")
60
 
61
  return img
62
 
 
65
  img = decode_image(file_bytes)
66
 
67
  size_mb = len(file_bytes) / (1024 * 1024)
 
68
  if size_mb > MAX_FILE_MB:
69
+ img = cv2.resize(img, None, fx=0.6, fy=0.6)
70
 
71
  h, w = img.shape[:2]
72
  if max(h, w) > MAX_DIM:
 
83
 
84
  def cleanup():
85
  now = time.time()
86
+ for k in list(TASKS.keys()):
87
+ v = TASKS[k]
 
88
  if "time" in v and now - v["time"] > CLEANUP_TIME:
89
  try:
90
  if "result" in v:
91
  os.remove(v["result"])
92
  except:
93
  pass
94
+ TASKS.pop(k, None)
95
+
96
+ # ============================================================
97
+ # FACE DETECTION API
98
+ # ============================================================
99
+
100
+ def extract_faces(img):
101
+ faces = face_app.get(img)
102
+ results = []
103
+
104
+ for i, f in enumerate(faces):
105
+ x1, y1, x2, y2 = map(int, f.bbox)
106
+ crop = img[y1:y2, x1:x2]
107
 
108
+ _, buf = cv2.imencode(".jpg", crop)
109
+ b64 = base64.b64encode(buf).decode()
110
+
111
+ results.append({
112
+ "index": i,
113
+ "image": f"data:image/jpeg;base64,{b64}"
114
+ })
115
+
116
+ return results
117
 
118
  # ============================================================
119
  # WORKER
 
140
  name = os.path.splitext(filename)[0]
141
  out_path = f"/tmp/{name}_{tid}.png"
142
 
143
+ cv2.imwrite(out_path, result)
144
 
145
  TASKS[tid] = {
146
  "status": "done",
 
160
  app = FastAPI()
161
 
162
  # ============================================================
163
+ # UI
164
  # ============================================================
165
 
166
  @app.get("/", response_class=HTMLResponse)
 
170
  <html>
171
  <head>
172
  <meta name="viewport" content="width=device-width, initial-scale=1">
173
+ <title>AI Face Swap</title>
174
 
175
  <style>
176
+ body{font-family:sans-serif;background:#0f172a;color:white;text-align:center}
177
  .box{border:2px dashed #444;padding:20px;margin:10px;border-radius:10px;cursor:pointer}
178
  img{max-width:200px;margin-top:10px;border-radius:8px}
179
+ .faces{display:flex;gap:10px;justify-content:center;margin-top:10px}
180
+ .face{border:2px solid transparent;cursor:pointer}
181
+ .face img{width:60px;height:60px}
182
+ .face.active{border-color:#22c55e}
183
+ .loader{display:none;margin-top:10px}
184
+ .spin{border:4px solid #333;border-top:4px solid #22c55e;border-radius:50%;width:30px;height:30px;animation:spin 1s linear infinite}
185
+ @keyframes spin{100%{transform:rotate(360deg)}}
186
  </style>
187
 
188
  </head>
189
 
190
  <body>
191
 
192
+ <h2>⚡ AI Face Swap</h2>
193
 
194
+ <input type="file" id="src"><br>
195
  <img id="p1">
196
 
197
+ <div id="faces" class="faces"></div>
198
+
199
+ <input type="file" id="tgt"><br>
200
  <img id="p2">
201
 
202
  <br>
203
+ <button onclick="start()">Swap</button>
204
 
205
+ <div class="loader" id="loader">
206
+ <div class="spin"></div>
207
+ </div>
 
 
 
 
 
 
 
208
 
 
209
  <br>
210
+ <img id="out"><br>
211
+ <a id="dl" download="faceswap.png" style="display:none;color:lightgreen">Download</a>
212
 
213
  <script>
214
+ let selectedFace=0;
 
215
 
216
+ src.onchange=async()=>{
217
+ p1.src=URL.createObjectURL(src.files[0]);
218
 
219
+ let fd=new FormData();
220
+ fd.append("file",src.files[0]);
 
 
221
 
222
+ let res=await fetch("/detect-faces",{method:"POST",body:fd});
223
+ let data=await res.json();
 
 
 
 
224
 
225
+ faces.innerHTML="";
226
+ data.faces.forEach(f=>{
227
+ let d=document.createElement("div");
228
+ d.className="face";
229
+ d.innerHTML="<img src='"+f.image+"'>";
230
+ d.onclick=()=>{
231
+ document.querySelectorAll(".face").forEach(x=>x.classList.remove("active"));
232
+ d.classList.add("active");
233
+ selectedFace=f.index;
234
+ };
235
+ faces.appendChild(d);
236
+ });
237
 
238
+ if(faces.firstChild) faces.firstChild.classList.add("active");
239
+ };
240
+
241
+ tgt.onchange=()=>p2.src=URL.createObjectURL(tgt.files[0]);
242
 
243
  async function start(){
244
+ loader.style.display="block";
245
 
246
+ let fd=new FormData();
247
+ fd.append("source",src.files[0]);
248
+ fd.append("target",tgt.files[0]);
249
+ fd.append("face_index",selectedFace);
250
 
251
+ let r=await fetch("/swap",{method:"POST",body:fd});
252
+ let j=await r.json();
253
 
254
+ poll(j.task_id);
255
  }
256
 
257
  async function poll(id){
258
+ let r=await fetch("/status/"+id);
259
+ let j=await r.json();
260
+
261
+ if(j.status==="done"){
262
+ let img=await fetch("/result/"+id);
263
+ let blob=await img.blob();
264
+ let url=URL.createObjectURL(blob);
265
+
266
+ out.src=url;
267
+ dl.href=url;
268
+ dl.style.display="block";
269
+
270
+ loader.style.display="none";
271
+ }
272
+ else if(j.status==="failed"){
273
+ alert(j.error);
274
+ loader.style.display="none";
275
+ }
276
+ else{
277
+ setTimeout(()=>poll(id),800);
278
+ }
279
  }
280
  </script>
281
 
 
284
  """
285
 
286
  # ============================================================
287
+ # DETECT FACES
288
+ # ============================================================
289
+
290
+ @app.post("/detect-faces")
291
+ async def detect_faces(file: UploadFile = File(...)):
292
+ img = compress_and_resize(await file.read())
293
+ faces = extract_faces(img)
294
+ return {"faces": faces}
295
+
296
+ # ============================================================
297
+ # SWAP
298
  # ============================================================
299
 
300
  @app.post("/swap")
 
318
 
319
  return {"task_id": tid}
320
 
321
+ # ============================================================
322
+ # STATUS
323
+ # ============================================================
324
 
325
  @app.get("/status/{tid}")
326
  def status(tid: str):
 
329
  raise HTTPException(404)
330
  return TASKS[tid]
331
 
332
+ # ============================================================
333
+ # RESULT
334
+ # ============================================================
335
 
336
  @app.get("/result/{tid}")
337
  def result(tid: str):
 
343
  return StreamingResponse(
344
  open(task["result"], "rb"),
345
  media_type="image/png",
346
+ headers={"Content-Disposition": f'attachment; filename="{task["filename"]}"'}
 
 
347
  )