LogicGoInfotechSpaces commited on
Commit
303e4bb
·
verified ·
1 Parent(s): b76f561

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -0
app.py CHANGED
@@ -851,6 +851,216 @@ async def multi_face_swap_api(
851
  except Exception as e:
852
  raise HTTPException(status_code=500, detail=str(e))
853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
854
  # --------------------- Mount Gradio ---------------------
855
 
856
  multi_faceswap_app = build_multi_faceswap_gradio()
 
851
  except Exception as e:
852
  raise HTTPException(status_code=500, detail=str(e))
853
 
854
+ @fastapi_app.post("/face-swap-couple", dependencies=[Depends(verify_token)])
855
+ async def face_swap_api(
856
+ image1: UploadFile = File(...),
857
+ image2: Optional[UploadFile] = File(None),
858
+ target_category_id: str = Form(None),
859
+ new_category_id: str = Form(None),
860
+ user_id: Optional[str] = Form(None),
861
+ credentials: HTTPAuthorizationCredentials = Security(security)
862
+ ):
863
+ """
864
+ Production-ready face swap endpoint supporting:
865
+ - Multiple source images (image1 + optional image2)
866
+ - Gender-based pairing
867
+ - Merged faces from multiple sources
868
+ - Mandatory CodeFormer enhancement
869
+ """
870
+ start_time = datetime.utcnow()
871
+
872
+ try:
873
+ # -----------------------------
874
+ # Validate input
875
+ # -----------------------------
876
+ if target_category_id == "":
877
+ target_category_id = None
878
+ if new_category_id == "":
879
+ new_category_id = None
880
+ if user_id == "":
881
+ user_id = None
882
+
883
+ if target_category_id and new_category_id:
884
+ raise HTTPException(400, "Provide only one of new_category_id or target_category_id.")
885
+ if not target_category_id and not new_category_id:
886
+ raise HTTPException(400, "Either new_category_id or target_category_id is required.")
887
+
888
+ logger.info(f"[FaceSwap] Incoming request → target_category_id={target_category_id}, new_category_id={new_category_id}, user_id={user_id}")
889
+
890
+ # -----------------------------
891
+ # Read source images
892
+ # -----------------------------
893
+ src_images = []
894
+ img1_bytes = await image1.read()
895
+ src1 = cv2.imdecode(np.frombuffer(img1_bytes, np.uint8), cv2.IMREAD_COLOR)
896
+ if src1 is None:
897
+ raise HTTPException(400, "Invalid image1 data")
898
+ src_images.append(cv2.cvtColor(src1, cv2.COLOR_BGR2RGB))
899
+
900
+ if image2:
901
+ img2_bytes = await image2.read()
902
+ src2 = cv2.imdecode(np.frombuffer(img2_bytes, np.uint8), cv2.IMREAD_COLOR)
903
+ if src2 is not None:
904
+ src_images.append(cv2.cvtColor(src2, cv2.COLOR_BGR2RGB))
905
+
906
+ # -----------------------------
907
+ # Determine target image
908
+ # -----------------------------
909
+ target_bytes = None
910
+
911
+ if new_category_id:
912
+ doc = await subcategories_col.find_one({"asset_images._id": ObjectId(new_category_id)})
913
+ if not doc:
914
+ raise HTTPException(404, "Asset image not found")
915
+ asset = next((img for img in doc["asset_images"] if str(img["_id"]) == new_category_id), None)
916
+ if not asset:
917
+ raise HTTPException(404, "Asset image URL not found")
918
+ target_url = asset["url"]
919
+
920
+ elif target_category_id:
921
+ client = get_spaces_client()
922
+ base_prefix = "faceswap/target/"
923
+ resp = client.list_objects_v2(Bucket=DO_SPACES_BUCKET, Prefix=base_prefix, Delimiter="/")
924
+ categories = [p["Prefix"].split("/")[2] for p in resp.get("CommonPrefixes", [])]
925
+ target_url = None
926
+ for category in categories:
927
+ original_prefix = f"faceswap/target/{category}/original/"
928
+ objects = client.list_objects_v2(Bucket=DO_SPACES_BUCKET, Prefix=original_prefix).get("Contents", [])
929
+ original_filenames = sorted([obj["Key"].split("/")[-1] for obj in objects if obj["Key"].endswith(".png")])
930
+ for idx, filename in enumerate(original_filenames, start=1):
931
+ cid = f"{category.lower()}image_{idx}"
932
+ if cid == target_category_id:
933
+ target_url = f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{original_prefix}{filename}"
934
+ break
935
+ if target_url:
936
+ break
937
+ if not target_url:
938
+ raise HTTPException(404, "Target categoryId not found")
939
+
940
+ # -----------------------------
941
+ # Download target image
942
+ # -----------------------------
943
+ async with httpx.AsyncClient(timeout=30.0) as client:
944
+ resp = await client.get(target_url)
945
+ resp.raise_for_status()
946
+ target_bytes = resp.content
947
+
948
+ tgt_bgr = cv2.imdecode(np.frombuffer(target_bytes, np.uint8), cv2.IMREAD_COLOR)
949
+ if tgt_bgr is None:
950
+ raise HTTPException(400, "Invalid target image data")
951
+ target_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
952
+
953
+ # -----------------------------
954
+ # Merge all source faces
955
+ # -----------------------------
956
+ all_src_faces = []
957
+ for img in src_images:
958
+ faces = face_analysis_app.get(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
959
+ all_src_faces.extend(faces)
960
+
961
+ if not all_src_faces:
962
+ raise HTTPException(400, "No faces detected in source images")
963
+
964
+ tgt_faces = face_analysis_app.get(tgt_bgr)
965
+ if not tgt_faces:
966
+ raise HTTPException(400, "No faces detected in target image")
967
+
968
+ # -----------------------------
969
+ # Gender-based pairing
970
+ # -----------------------------
971
+ def face_sort_key(face):
972
+ x1, y1, x2, y2 = face.bbox
973
+ area = (x2 - x1) * (y2 - y1)
974
+ cx = (x1 + x2) / 2
975
+ return (-area, cx)
976
+
977
+ # Separate by gender
978
+ src_male = sorted([f for f in all_src_faces if f.gender == 1], key=face_sort_key)
979
+ src_female = sorted([f for f in all_src_faces if f.gender == 0], key=face_sort_key)
980
+ tgt_male = sorted([f for f in tgt_faces if f.gender == 1], key=face_sort_key)
981
+ tgt_female = sorted([f for f in tgt_faces if f.gender == 0], key=face_sort_key)
982
+
983
+ pairs = []
984
+ for s, t in zip(src_male, tgt_male):
985
+ pairs.append((s, t))
986
+ for s, t in zip(src_female, tgt_female):
987
+ pairs.append((s, t))
988
+
989
+ # fallback if gender mismatch
990
+ if not pairs:
991
+ src_all = sorted(all_src_faces, key=face_sort_key)
992
+ tgt_all = sorted(tgt_faces, key=face_sort_key)
993
+ pairs = list(zip(src_all, tgt_all))
994
+
995
+ # -----------------------------
996
+ # Perform face swap
997
+ # -----------------------------
998
+ with swap_lock:
999
+ result_img = tgt_bgr.copy()
1000
+ for src_face, _ in pairs:
1001
+ current_faces = sorted(face_analysis_app.get(result_img), key=face_sort_key)
1002
+ candidates = [f for f in current_faces if f.gender == src_face.gender] or current_faces
1003
+ target_face = candidates[0]
1004
+ result_img = swapper.get(result_img, target_face, src_face, paste_back=True)
1005
+
1006
+ result_rgb = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
1007
+
1008
+ # -----------------------------
1009
+ # Mandatory enhancement
1010
+ # -----------------------------
1011
+ enhanced_rgb = mandatory_enhancement(result_rgb)
1012
+ enhanced_bgr = cv2.cvtColor(enhanced_rgb, cv2.COLOR_RGB2BGR)
1013
+
1014
+ # -----------------------------
1015
+ # Save, upload, compress
1016
+ # -----------------------------
1017
+ temp_dir = tempfile.mkdtemp(prefix="faceswap_")
1018
+ final_path = os.path.join(temp_dir, "result.png")
1019
+ cv2.imwrite(final_path, enhanced_bgr)
1020
+
1021
+ with open(final_path, "rb") as f:
1022
+ result_bytes = f.read()
1023
+
1024
+ result_key = f"faceswap/result/{uuid.uuid4().hex}_enhanced.png"
1025
+ result_url = upload_to_spaces(result_bytes, result_key)
1026
+
1027
+ compressed_bytes = compress_image(result_bytes, max_size=(1280, 1280), quality=72)
1028
+ compressed_key = f"faceswap/result/{uuid.uuid4().hex}_enhanced_compressed.jpg"
1029
+ compressed_url = upload_to_spaces(compressed_bytes, compressed_key, content_type="image/jpeg")
1030
+
1031
+ # -----------------------------
1032
+ # Log API usage
1033
+ # -----------------------------
1034
+ end_time = datetime.utcnow()
1035
+ response_time_ms = (end_time - start_time).total_seconds() * 1000
1036
+ if database:
1037
+ await database.api_logs.insert_one({
1038
+ "endpoint": "/face-swap",
1039
+ "status": "success",
1040
+ "response_time_ms": response_time_ms,
1041
+ "timestamp": end_time
1042
+ })
1043
+
1044
+ return {
1045
+ "result_key": result_key,
1046
+ "result_url": result_url,
1047
+ "compressed_url": compressed_url
1048
+ }
1049
+
1050
+ except Exception as e:
1051
+ end_time = datetime.utcnow()
1052
+ response_time_ms = (end_time - start_time).total_seconds() * 1000
1053
+ if database:
1054
+ await database.api_logs.insert_one({
1055
+ "endpoint": "/face-swap",
1056
+ "status": "fail",
1057
+ "response_time_ms": response_time_ms,
1058
+ "timestamp": end_time,
1059
+ "error": str(e)
1060
+ })
1061
+ raise HTTPException(500, f"Face swap failed: {str(e)}")
1062
+
1063
+
1064
  # --------------------- Mount Gradio ---------------------
1065
 
1066
  multi_faceswap_app = build_multi_faceswap_gradio()