Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
|
@@ -851,6 +851,216 @@ async def multi_face_swap_api(
|
|
| 851 |
except Exception as e:
|
| 852 |
raise HTTPException(status_code=500, detail=str(e))
|
| 853 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 854 |
# --------------------- Mount Gradio ---------------------
|
| 855 |
|
| 856 |
multi_faceswap_app = build_multi_faceswap_gradio()
|
|
|
|
| 851 |
except Exception as e:
|
| 852 |
raise HTTPException(status_code=500, detail=str(e))
|
| 853 |
|
| 854 |
+
@fastapi_app.post("/face-swap-couple", dependencies=[Depends(verify_token)])
|
| 855 |
+
async def face_swap_api(
|
| 856 |
+
image1: UploadFile = File(...),
|
| 857 |
+
image2: Optional[UploadFile] = File(None),
|
| 858 |
+
target_category_id: str = Form(None),
|
| 859 |
+
new_category_id: str = Form(None),
|
| 860 |
+
user_id: Optional[str] = Form(None),
|
| 861 |
+
credentials: HTTPAuthorizationCredentials = Security(security)
|
| 862 |
+
):
|
| 863 |
+
"""
|
| 864 |
+
Production-ready face swap endpoint supporting:
|
| 865 |
+
- Multiple source images (image1 + optional image2)
|
| 866 |
+
- Gender-based pairing
|
| 867 |
+
- Merged faces from multiple sources
|
| 868 |
+
- Mandatory CodeFormer enhancement
|
| 869 |
+
"""
|
| 870 |
+
start_time = datetime.utcnow()
|
| 871 |
+
|
| 872 |
+
try:
|
| 873 |
+
# -----------------------------
|
| 874 |
+
# Validate input
|
| 875 |
+
# -----------------------------
|
| 876 |
+
if target_category_id == "":
|
| 877 |
+
target_category_id = None
|
| 878 |
+
if new_category_id == "":
|
| 879 |
+
new_category_id = None
|
| 880 |
+
if user_id == "":
|
| 881 |
+
user_id = None
|
| 882 |
+
|
| 883 |
+
if target_category_id and new_category_id:
|
| 884 |
+
raise HTTPException(400, "Provide only one of new_category_id or target_category_id.")
|
| 885 |
+
if not target_category_id and not new_category_id:
|
| 886 |
+
raise HTTPException(400, "Either new_category_id or target_category_id is required.")
|
| 887 |
+
|
| 888 |
+
logger.info(f"[FaceSwap] Incoming request → target_category_id={target_category_id}, new_category_id={new_category_id}, user_id={user_id}")
|
| 889 |
+
|
| 890 |
+
# -----------------------------
|
| 891 |
+
# Read source images
|
| 892 |
+
# -----------------------------
|
| 893 |
+
src_images = []
|
| 894 |
+
img1_bytes = await image1.read()
|
| 895 |
+
src1 = cv2.imdecode(np.frombuffer(img1_bytes, np.uint8), cv2.IMREAD_COLOR)
|
| 896 |
+
if src1 is None:
|
| 897 |
+
raise HTTPException(400, "Invalid image1 data")
|
| 898 |
+
src_images.append(cv2.cvtColor(src1, cv2.COLOR_BGR2RGB))
|
| 899 |
+
|
| 900 |
+
if image2:
|
| 901 |
+
img2_bytes = await image2.read()
|
| 902 |
+
src2 = cv2.imdecode(np.frombuffer(img2_bytes, np.uint8), cv2.IMREAD_COLOR)
|
| 903 |
+
if src2 is not None:
|
| 904 |
+
src_images.append(cv2.cvtColor(src2, cv2.COLOR_BGR2RGB))
|
| 905 |
+
|
| 906 |
+
# -----------------------------
|
| 907 |
+
# Determine target image
|
| 908 |
+
# -----------------------------
|
| 909 |
+
target_bytes = None
|
| 910 |
+
|
| 911 |
+
if new_category_id:
|
| 912 |
+
doc = await subcategories_col.find_one({"asset_images._id": ObjectId(new_category_id)})
|
| 913 |
+
if not doc:
|
| 914 |
+
raise HTTPException(404, "Asset image not found")
|
| 915 |
+
asset = next((img for img in doc["asset_images"] if str(img["_id"]) == new_category_id), None)
|
| 916 |
+
if not asset:
|
| 917 |
+
raise HTTPException(404, "Asset image URL not found")
|
| 918 |
+
target_url = asset["url"]
|
| 919 |
+
|
| 920 |
+
elif target_category_id:
|
| 921 |
+
client = get_spaces_client()
|
| 922 |
+
base_prefix = "faceswap/target/"
|
| 923 |
+
resp = client.list_objects_v2(Bucket=DO_SPACES_BUCKET, Prefix=base_prefix, Delimiter="/")
|
| 924 |
+
categories = [p["Prefix"].split("/")[2] for p in resp.get("CommonPrefixes", [])]
|
| 925 |
+
target_url = None
|
| 926 |
+
for category in categories:
|
| 927 |
+
original_prefix = f"faceswap/target/{category}/original/"
|
| 928 |
+
objects = client.list_objects_v2(Bucket=DO_SPACES_BUCKET, Prefix=original_prefix).get("Contents", [])
|
| 929 |
+
original_filenames = sorted([obj["Key"].split("/")[-1] for obj in objects if obj["Key"].endswith(".png")])
|
| 930 |
+
for idx, filename in enumerate(original_filenames, start=1):
|
| 931 |
+
cid = f"{category.lower()}image_{idx}"
|
| 932 |
+
if cid == target_category_id:
|
| 933 |
+
target_url = f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{original_prefix}{filename}"
|
| 934 |
+
break
|
| 935 |
+
if target_url:
|
| 936 |
+
break
|
| 937 |
+
if not target_url:
|
| 938 |
+
raise HTTPException(404, "Target categoryId not found")
|
| 939 |
+
|
| 940 |
+
# -----------------------------
|
| 941 |
+
# Download target image
|
| 942 |
+
# -----------------------------
|
| 943 |
+
async with httpx.AsyncClient(timeout=30.0) as client:
|
| 944 |
+
resp = await client.get(target_url)
|
| 945 |
+
resp.raise_for_status()
|
| 946 |
+
target_bytes = resp.content
|
| 947 |
+
|
| 948 |
+
tgt_bgr = cv2.imdecode(np.frombuffer(target_bytes, np.uint8), cv2.IMREAD_COLOR)
|
| 949 |
+
if tgt_bgr is None:
|
| 950 |
+
raise HTTPException(400, "Invalid target image data")
|
| 951 |
+
target_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
|
| 952 |
+
|
| 953 |
+
# -----------------------------
|
| 954 |
+
# Merge all source faces
|
| 955 |
+
# -----------------------------
|
| 956 |
+
all_src_faces = []
|
| 957 |
+
for img in src_images:
|
| 958 |
+
faces = face_analysis_app.get(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
| 959 |
+
all_src_faces.extend(faces)
|
| 960 |
+
|
| 961 |
+
if not all_src_faces:
|
| 962 |
+
raise HTTPException(400, "No faces detected in source images")
|
| 963 |
+
|
| 964 |
+
tgt_faces = face_analysis_app.get(tgt_bgr)
|
| 965 |
+
if not tgt_faces:
|
| 966 |
+
raise HTTPException(400, "No faces detected in target image")
|
| 967 |
+
|
| 968 |
+
# -----------------------------
|
| 969 |
+
# Gender-based pairing
|
| 970 |
+
# -----------------------------
|
| 971 |
+
def face_sort_key(face):
|
| 972 |
+
x1, y1, x2, y2 = face.bbox
|
| 973 |
+
area = (x2 - x1) * (y2 - y1)
|
| 974 |
+
cx = (x1 + x2) / 2
|
| 975 |
+
return (-area, cx)
|
| 976 |
+
|
| 977 |
+
# Separate by gender
|
| 978 |
+
src_male = sorted([f for f in all_src_faces if f.gender == 1], key=face_sort_key)
|
| 979 |
+
src_female = sorted([f for f in all_src_faces if f.gender == 0], key=face_sort_key)
|
| 980 |
+
tgt_male = sorted([f for f in tgt_faces if f.gender == 1], key=face_sort_key)
|
| 981 |
+
tgt_female = sorted([f for f in tgt_faces if f.gender == 0], key=face_sort_key)
|
| 982 |
+
|
| 983 |
+
pairs = []
|
| 984 |
+
for s, t in zip(src_male, tgt_male):
|
| 985 |
+
pairs.append((s, t))
|
| 986 |
+
for s, t in zip(src_female, tgt_female):
|
| 987 |
+
pairs.append((s, t))
|
| 988 |
+
|
| 989 |
+
# fallback if gender mismatch
|
| 990 |
+
if not pairs:
|
| 991 |
+
src_all = sorted(all_src_faces, key=face_sort_key)
|
| 992 |
+
tgt_all = sorted(tgt_faces, key=face_sort_key)
|
| 993 |
+
pairs = list(zip(src_all, tgt_all))
|
| 994 |
+
|
| 995 |
+
# -----------------------------
|
| 996 |
+
# Perform face swap
|
| 997 |
+
# -----------------------------
|
| 998 |
+
with swap_lock:
|
| 999 |
+
result_img = tgt_bgr.copy()
|
| 1000 |
+
for src_face, _ in pairs:
|
| 1001 |
+
current_faces = sorted(face_analysis_app.get(result_img), key=face_sort_key)
|
| 1002 |
+
candidates = [f for f in current_faces if f.gender == src_face.gender] or current_faces
|
| 1003 |
+
target_face = candidates[0]
|
| 1004 |
+
result_img = swapper.get(result_img, target_face, src_face, paste_back=True)
|
| 1005 |
+
|
| 1006 |
+
result_rgb = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
|
| 1007 |
+
|
| 1008 |
+
# -----------------------------
|
| 1009 |
+
# Mandatory enhancement
|
| 1010 |
+
# -----------------------------
|
| 1011 |
+
enhanced_rgb = mandatory_enhancement(result_rgb)
|
| 1012 |
+
enhanced_bgr = cv2.cvtColor(enhanced_rgb, cv2.COLOR_RGB2BGR)
|
| 1013 |
+
|
| 1014 |
+
# -----------------------------
|
| 1015 |
+
# Save, upload, compress
|
| 1016 |
+
# -----------------------------
|
| 1017 |
+
temp_dir = tempfile.mkdtemp(prefix="faceswap_")
|
| 1018 |
+
final_path = os.path.join(temp_dir, "result.png")
|
| 1019 |
+
cv2.imwrite(final_path, enhanced_bgr)
|
| 1020 |
+
|
| 1021 |
+
with open(final_path, "rb") as f:
|
| 1022 |
+
result_bytes = f.read()
|
| 1023 |
+
|
| 1024 |
+
result_key = f"faceswap/result/{uuid.uuid4().hex}_enhanced.png"
|
| 1025 |
+
result_url = upload_to_spaces(result_bytes, result_key)
|
| 1026 |
+
|
| 1027 |
+
compressed_bytes = compress_image(result_bytes, max_size=(1280, 1280), quality=72)
|
| 1028 |
+
compressed_key = f"faceswap/result/{uuid.uuid4().hex}_enhanced_compressed.jpg"
|
| 1029 |
+
compressed_url = upload_to_spaces(compressed_bytes, compressed_key, content_type="image/jpeg")
|
| 1030 |
+
|
| 1031 |
+
# -----------------------------
|
| 1032 |
+
# Log API usage
|
| 1033 |
+
# -----------------------------
|
| 1034 |
+
end_time = datetime.utcnow()
|
| 1035 |
+
response_time_ms = (end_time - start_time).total_seconds() * 1000
|
| 1036 |
+
if database:
|
| 1037 |
+
await database.api_logs.insert_one({
|
| 1038 |
+
"endpoint": "/face-swap",
|
| 1039 |
+
"status": "success",
|
| 1040 |
+
"response_time_ms": response_time_ms,
|
| 1041 |
+
"timestamp": end_time
|
| 1042 |
+
})
|
| 1043 |
+
|
| 1044 |
+
return {
|
| 1045 |
+
"result_key": result_key,
|
| 1046 |
+
"result_url": result_url,
|
| 1047 |
+
"compressed_url": compressed_url
|
| 1048 |
+
}
|
| 1049 |
+
|
| 1050 |
+
except Exception as e:
|
| 1051 |
+
end_time = datetime.utcnow()
|
| 1052 |
+
response_time_ms = (end_time - start_time).total_seconds() * 1000
|
| 1053 |
+
if database:
|
| 1054 |
+
await database.api_logs.insert_one({
|
| 1055 |
+
"endpoint": "/face-swap",
|
| 1056 |
+
"status": "fail",
|
| 1057 |
+
"response_time_ms": response_time_ms,
|
| 1058 |
+
"timestamp": end_time,
|
| 1059 |
+
"error": str(e)
|
| 1060 |
+
})
|
| 1061 |
+
raise HTTPException(500, f"Face swap failed: {str(e)}")
|
| 1062 |
+
|
| 1063 |
+
|
| 1064 |
# --------------------- Mount Gradio ---------------------
|
| 1065 |
|
| 1066 |
multi_faceswap_app = build_multi_faceswap_gradio()
|