LogicGoInfotechSpaces commited on
Commit
2f674a3
·
verified ·
1 Parent(s): ceb3a1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +386 -19
app.py CHANGED
@@ -851,6 +851,215 @@ async def multi_face_swap_api(
851
  except Exception as e:
852
  raise HTTPException(status_code=500, detail=str(e))
853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
854
  @fastapi_app.post("/face-swap-couple", dependencies=[Depends(verify_token)])
855
  async def face_swap_api(
856
  image1: UploadFile = File(...),
@@ -904,51 +1113,208 @@ async def face_swap_api(
904
  src_images.append(cv2.cvtColor(src2, cv2.COLOR_BGR2RGB))
905
 
906
  # -----------------------------
907
- # Determine target image
908
  # -----------------------------
909
- target_bytes = None
910
-
911
  if new_category_id:
912
- doc = await subcategories_col.find_one({"asset_images._id": ObjectId(new_category_id)})
 
 
 
913
  if not doc:
914
- raise HTTPException(404, "Asset image not found")
915
- asset = next((img for img in doc["asset_images"] if str(img["_id"]) == new_category_id), None)
 
 
 
 
 
916
  if not asset:
917
  raise HTTPException(404, "Asset image URL not found")
 
918
  target_url = asset["url"]
 
919
 
920
- elif target_category_id:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
921
  client = get_spaces_client()
922
  base_prefix = "faceswap/target/"
923
- resp = client.list_objects_v2(Bucket=DO_SPACES_BUCKET, Prefix=base_prefix, Delimiter="/")
 
 
 
924
  categories = [p["Prefix"].split("/")[2] for p in resp.get("CommonPrefixes", [])]
925
- target_url = None
926
  for category in categories:
927
  original_prefix = f"faceswap/target/{category}/original/"
928
- objects = client.list_objects_v2(Bucket=DO_SPACES_BUCKET, Prefix=original_prefix).get("Contents", [])
929
- original_filenames = sorted([obj["Key"].split("/")[-1] for obj in objects if obj["Key"].endswith(".png")])
 
 
 
 
 
 
 
 
 
 
 
 
 
930
  for idx, filename in enumerate(original_filenames, start=1):
931
  cid = f"{category.lower()}image_{idx}"
932
  if cid == target_category_id:
933
  target_url = f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{original_prefix}{filename}"
934
  break
 
935
  if target_url:
936
  break
 
937
  if not target_url:
938
  raise HTTPException(404, "Target categoryId not found")
939
 
940
- # -----------------------------
941
- # Download target image
942
- # -----------------------------
943
  async with httpx.AsyncClient(timeout=30.0) as client:
944
- resp = await client.get(target_url)
945
- resp.raise_for_status()
946
- target_bytes = resp.content
947
 
948
- tgt_bgr = cv2.imdecode(np.frombuffer(target_bytes, np.uint8), cv2.IMREAD_COLOR)
949
  if tgt_bgr is None:
950
  raise HTTPException(400, "Invalid target image data")
951
- target_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
952
 
953
  # -----------------------------
954
  # Merge all source faces
@@ -1062,6 +1428,7 @@ async def face_swap_api(
1062
 
1063
 
1064
 
 
1065
  # --------------------- Mount Gradio ---------------------
1066
 
1067
  multi_faceswap_app = build_multi_faceswap_gradio()
 
851
  except Exception as e:
852
  raise HTTPException(status_code=500, detail=str(e))
853
 
854
+ # @fastapi_app.post("/face-swap-couple", dependencies=[Depends(verify_token)])
855
+ # async def face_swap_api(
856
+ # image1: UploadFile = File(...),
857
+ # image2: Optional[UploadFile] = File(None),
858
+ # target_category_id: str = Form(None),
859
+ # new_category_id: str = Form(None),
860
+ # user_id: Optional[str] = Form(None),
861
+ # credentials: HTTPAuthorizationCredentials = Security(security)
862
+ # ):
863
+ # """
864
+ # Production-ready face swap endpoint supporting:
865
+ # - Multiple source images (image1 + optional image2)
866
+ # - Gender-based pairing
867
+ # - Merged faces from multiple sources
868
+ # - Mandatory CodeFormer enhancement
869
+ # """
870
+ # start_time = datetime.utcnow()
871
+
872
+ # try:
873
+ # # -----------------------------
874
+ # # Validate input
875
+ # # -----------------------------
876
+ # if target_category_id == "":
877
+ # target_category_id = None
878
+ # if new_category_id == "":
879
+ # new_category_id = None
880
+ # if user_id == "":
881
+ # user_id = None
882
+
883
+ # if target_category_id and new_category_id:
884
+ # raise HTTPException(400, "Provide only one of new_category_id or target_category_id.")
885
+ # if not target_category_id and not new_category_id:
886
+ # raise HTTPException(400, "Either new_category_id or target_category_id is required.")
887
+
888
+ # logger.info(f"[FaceSwap] Incoming request → target_category_id={target_category_id}, new_category_id={new_category_id}, user_id={user_id}")
889
+
890
+ # # -----------------------------
891
+ # # Read source images
892
+ # # -----------------------------
893
+ # src_images = []
894
+ # img1_bytes = await image1.read()
895
+ # src1 = cv2.imdecode(np.frombuffer(img1_bytes, np.uint8), cv2.IMREAD_COLOR)
896
+ # if src1 is None:
897
+ # raise HTTPException(400, "Invalid image1 data")
898
+ # src_images.append(cv2.cvtColor(src1, cv2.COLOR_BGR2RGB))
899
+
900
+ # if image2:
901
+ # img2_bytes = await image2.read()
902
+ # src2 = cv2.imdecode(np.frombuffer(img2_bytes, np.uint8), cv2.IMREAD_COLOR)
903
+ # if src2 is not None:
904
+ # src_images.append(cv2.cvtColor(src2, cv2.COLOR_BGR2RGB))
905
+
906
+ # # -----------------------------
907
+ # # Determine target image
908
+ # # -----------------------------
909
+ # target_bytes = None
910
+
911
+ # if new_category_id:
912
+ # doc = await subcategories_col.find_one({"asset_images._id": ObjectId(new_category_id)})
913
+ # if not doc:
914
+ # raise HTTPException(404, "Asset image not found")
915
+ # asset = next((img for img in doc["asset_images"] if str(img["_id"]) == new_category_id), None)
916
+ # if not asset:
917
+ # raise HTTPException(404, "Asset image URL not found")
918
+ # target_url = asset["url"]
919
+
920
+ # elif target_category_id:
921
+ # client = get_spaces_client()
922
+ # base_prefix = "faceswap/target/"
923
+ # resp = client.list_objects_v2(Bucket=DO_SPACES_BUCKET, Prefix=base_prefix, Delimiter="/")
924
+ # categories = [p["Prefix"].split("/")[2] for p in resp.get("CommonPrefixes", [])]
925
+ # target_url = None
926
+ # for category in categories:
927
+ # original_prefix = f"faceswap/target/{category}/original/"
928
+ # objects = client.list_objects_v2(Bucket=DO_SPACES_BUCKET, Prefix=original_prefix).get("Contents", [])
929
+ # original_filenames = sorted([obj["Key"].split("/")[-1] for obj in objects if obj["Key"].endswith(".png")])
930
+ # for idx, filename in enumerate(original_filenames, start=1):
931
+ # cid = f"{category.lower()}image_{idx}"
932
+ # if cid == target_category_id:
933
+ # target_url = f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{original_prefix}{filename}"
934
+ # break
935
+ # if target_url:
936
+ # break
937
+ # if not target_url:
938
+ # raise HTTPException(404, "Target categoryId not found")
939
+
940
+ # # -----------------------------
941
+ # # Download target image
942
+ # # -----------------------------
943
+ # async with httpx.AsyncClient(timeout=30.0) as client:
944
+ # resp = await client.get(target_url)
945
+ # resp.raise_for_status()
946
+ # target_bytes = resp.content
947
+
948
+ # tgt_bgr = cv2.imdecode(np.frombuffer(target_bytes, np.uint8), cv2.IMREAD_COLOR)
949
+ # if tgt_bgr is None:
950
+ # raise HTTPException(400, "Invalid target image data")
951
+ # target_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
952
+
953
+ # # -----------------------------
954
+ # # Merge all source faces
955
+ # # -----------------------------
956
+ # all_src_faces = []
957
+ # for img in src_images:
958
+ # faces = face_analysis_app.get(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
959
+ # all_src_faces.extend(faces)
960
+
961
+ # if not all_src_faces:
962
+ # raise HTTPException(400, "No faces detected in source images")
963
+
964
+ # tgt_faces = face_analysis_app.get(tgt_bgr)
965
+ # if not tgt_faces:
966
+ # raise HTTPException(400, "No faces detected in target image")
967
+
968
+ # # -----------------------------
969
+ # # Gender-based pairing
970
+ # # -----------------------------
971
+ # def face_sort_key(face):
972
+ # x1, y1, x2, y2 = face.bbox
973
+ # area = (x2 - x1) * (y2 - y1)
974
+ # cx = (x1 + x2) / 2
975
+ # return (-area, cx)
976
+
977
+ # # Separate by gender
978
+ # src_male = sorted([f for f in all_src_faces if f.gender == 1], key=face_sort_key)
979
+ # src_female = sorted([f for f in all_src_faces if f.gender == 0], key=face_sort_key)
980
+ # tgt_male = sorted([f for f in tgt_faces if f.gender == 1], key=face_sort_key)
981
+ # tgt_female = sorted([f for f in tgt_faces if f.gender == 0], key=face_sort_key)
982
+
983
+ # pairs = []
984
+ # for s, t in zip(src_male, tgt_male):
985
+ # pairs.append((s, t))
986
+ # for s, t in zip(src_female, tgt_female):
987
+ # pairs.append((s, t))
988
+
989
+ # # fallback if gender mismatch
990
+ # if not pairs:
991
+ # src_all = sorted(all_src_faces, key=face_sort_key)
992
+ # tgt_all = sorted(tgt_faces, key=face_sort_key)
993
+ # pairs = list(zip(src_all, tgt_all))
994
+
995
+ # # -----------------------------
996
+ # # Perform face swap
997
+ # # -----------------------------
998
+ # with swap_lock:
999
+ # result_img = tgt_bgr.copy()
1000
+ # for src_face, _ in pairs:
1001
+ # current_faces = sorted(face_analysis_app.get(result_img), key=face_sort_key)
1002
+ # candidates = [f for f in current_faces if f.gender == src_face.gender] or current_faces
1003
+ # target_face = candidates[0]
1004
+ # result_img = swapper.get(result_img, target_face, src_face, paste_back=True)
1005
+
1006
+ # result_rgb = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
1007
+
1008
+ # # -----------------------------
1009
+ # # Mandatory enhancement
1010
+ # # -----------------------------
1011
+ # enhanced_rgb = mandatory_enhancement(result_rgb)
1012
+ # enhanced_bgr = cv2.cvtColor(enhanced_rgb, cv2.COLOR_RGB2BGR)
1013
+
1014
+ # # -----------------------------
1015
+ # # Save, upload, compress
1016
+ # # -----------------------------
1017
+ # temp_dir = tempfile.mkdtemp(prefix="faceswap_")
1018
+ # final_path = os.path.join(temp_dir, "result.png")
1019
+ # cv2.imwrite(final_path, enhanced_bgr)
1020
+
1021
+ # with open(final_path, "rb") as f:
1022
+ # result_bytes = f.read()
1023
+
1024
+ # result_key = f"faceswap/result/{uuid.uuid4().hex}_enhanced.png"
1025
+ # result_url = upload_to_spaces(result_bytes, result_key)
1026
+
1027
+ # compressed_bytes = compress_image(result_bytes, max_size=(1280, 1280), quality=72)
1028
+ # compressed_key = f"faceswap/result/{uuid.uuid4().hex}_enhanced_compressed.jpg"
1029
+ # compressed_url = upload_to_spaces(compressed_bytes, compressed_key, content_type="image/jpeg")
1030
+
1031
+ # # -----------------------------
1032
+ # # Log API usage
1033
+ # # -----------------------------
1034
+ # end_time = datetime.utcnow()
1035
+ # response_time_ms = (end_time - start_time).total_seconds() * 1000
1036
+ # if database is not None:
1037
+ # await database.api_logs.insert_one({
1038
+ # "endpoint": "/face-swap",
1039
+ # "status": "success",
1040
+ # "response_time_ms": response_time_ms,
1041
+ # "timestamp": end_time
1042
+ # })
1043
+
1044
+ # return {
1045
+ # "result_key": result_key,
1046
+ # "result_url": result_url,
1047
+ # "compressed_url": compressed_url
1048
+ # }
1049
+
1050
+ # except Exception as e:
1051
+ # end_time = datetime.utcnow()
1052
+ # response_time_ms = (end_time - start_time).total_seconds() * 1000
1053
+ # if database is not None:
1054
+ # await database.api_logs.insert_one({
1055
+ # "endpoint": "/face-swap",
1056
+ # "status": "fail",
1057
+ # "response_time_ms": response_time_ms,
1058
+ # "timestamp": end_time,
1059
+ # "error": str(e)
1060
+ # })
1061
+ # raise HTTPException(500, f"Face swap failed: {str(e)}")
1062
+
1063
  @fastapi_app.post("/face-swap-couple", dependencies=[Depends(verify_token)])
1064
  async def face_swap_api(
1065
  image1: UploadFile = File(...),
 
1113
  src_images.append(cv2.cvtColor(src2, cv2.COLOR_BGR2RGB))
1114
 
1115
  # -----------------------------
1116
+ # Resolve target image
1117
  # -----------------------------
1118
+ target_url = None
 
1119
  if new_category_id:
1120
+ doc = await subcategories_col.find_one({
1121
+ "asset_images._id": ObjectId(new_category_id)
1122
+ })
1123
+
1124
  if not doc:
1125
+ raise HTTPException(404, "Asset image not found in database")
1126
+
1127
+ asset = next(
1128
+ (img for img in doc["asset_images"] if str(img["_id"]) == new_category_id),
1129
+ None
1130
+ )
1131
+
1132
  if not asset:
1133
  raise HTTPException(404, "Asset image URL not found")
1134
+
1135
  target_url = asset["url"]
1136
+ subcategory_oid = doc["_id"]
1137
 
1138
+ if user_id:
1139
+ try:
1140
+ user_id_clean = user_id.strip()
1141
+ if not user_id_clean:
1142
+ raise ValueError("user_id cannot be empty")
1143
+ try:
1144
+ user_oid = ObjectId(user_id_clean)
1145
+ except (InvalidId, ValueError):
1146
+ logger.error(f"Invalid user_id format: {user_id_clean}")
1147
+ raise ValueError(f"Invalid user_id format: {user_id_clean}")
1148
+
1149
+ now = datetime.utcnow()
1150
+
1151
+ # Step 1: ensure root document exists
1152
+ await media_clicks_col.update_one(
1153
+ {"userId": user_oid},
1154
+ {
1155
+ "$setOnInsert": {
1156
+ "userId": user_oid,
1157
+ "createdAt": now,
1158
+ "ai_edit_complete": 0,
1159
+ "ai_edit_daily_count": []
1160
+ }
1161
+ },
1162
+ upsert=True
1163
+ )
1164
+
1165
+ # Step 2: handle daily usage (binary, no duplicates)
1166
+ doc = await media_clicks_col.find_one(
1167
+ {"userId": user_oid},
1168
+ {"ai_edit_daily_count": 1}
1169
+ )
1170
+
1171
+ daily_entries = doc.get("ai_edit_daily_count", []) if doc else []
1172
+
1173
+ today_date = datetime(now.year, now.month, now.day)
1174
+
1175
+ daily_map = {}
1176
+ for entry in daily_entries:
1177
+ d = entry["date"]
1178
+ if isinstance(d, datetime):
1179
+ d = datetime(d.year, d.month, d.day)
1180
+ daily_map[d] = entry["count"]
1181
+
1182
+ last_date = max(daily_map.keys()) if daily_map else None
1183
+
1184
+ if last_date != today_date:
1185
+ daily_map[today_date] = 1
1186
+
1187
+ final_daily_entries = [
1188
+ {"date": d, "count": daily_map[d]}
1189
+ for d in sorted(daily_map.keys())
1190
+ ]
1191
+
1192
+ final_daily_entries = final_daily_entries[-32:]
1193
+
1194
+ await media_clicks_col.update_one(
1195
+ {"userId": user_oid},
1196
+ {
1197
+ "$set": {
1198
+ "ai_edit_daily_count": final_daily_entries,
1199
+ "updatedAt": now
1200
+ }
1201
+ }
1202
+ )
1203
+
1204
+ # Step 3: try updating existing subCategory
1205
+ update_result = await media_clicks_col.update_one(
1206
+ {
1207
+ "userId": user_oid,
1208
+ "subCategories.subCategoryId": subcategory_oid
1209
+ },
1210
+ {
1211
+ "$inc": {
1212
+ "subCategories.$.click_count": 1,
1213
+ "ai_edit_complete": 1
1214
+ },
1215
+ "$set": {
1216
+ "subCategories.$.lastClickedAt": now,
1217
+ "ai_edit_last_date": now,
1218
+ "updatedAt": now
1219
+ }
1220
+ }
1221
+ )
1222
+
1223
+ # Step 4: push subCategory if missing
1224
+ if update_result.matched_count == 0:
1225
+ await media_clicks_col.update_one(
1226
+ {"userId": user_oid},
1227
+ {
1228
+ "$inc": {
1229
+ "ai_edit_complete": 1
1230
+ },
1231
+ "$set": {
1232
+ "ai_edit_last_date": now,
1233
+ "updatedAt": now
1234
+ },
1235
+ "$push": {
1236
+ "subCategories": {
1237
+ "subCategoryId": subcategory_oid,
1238
+ "click_count": 1,
1239
+ "lastClickedAt": now
1240
+ }
1241
+ }
1242
+ }
1243
+ )
1244
+
1245
+ # Step 5: sort subCategories by lastClickedAt (ascending)
1246
+ user_doc = await media_clicks_col.find_one({"userId": user_oid})
1247
+ if user_doc and "subCategories" in user_doc:
1248
+ subcategories = user_doc["subCategories"]
1249
+ subcategories_sorted = sorted(
1250
+ subcategories,
1251
+ key=lambda x: x.get("lastClickedAt") if x.get("lastClickedAt") is not None else datetime.min
1252
+ )
1253
+ await media_clicks_col.update_one(
1254
+ {"userId": user_oid},
1255
+ {
1256
+ "$set": {
1257
+ "subCategories": subcategories_sorted,
1258
+ "updatedAt": now
1259
+ }
1260
+ }
1261
+ )
1262
+
1263
+ logger.info(
1264
+ "[MEDIA_CLICK] user=%s subCategory=%s ai_edit_complete++ daily_tracked",
1265
+ user_id,
1266
+ str(subcategory_oid)
1267
+ )
1268
+
1269
+ except Exception as media_err:
1270
+ logger.error(f"MEDIA_CLICK ERROR: {media_err}")
1271
+
1272
+ if target_category_id:
1273
  client = get_spaces_client()
1274
  base_prefix = "faceswap/target/"
1275
+ resp = client.list_objects_v2(
1276
+ Bucket=DO_SPACES_BUCKET, Prefix=base_prefix, Delimiter="/"
1277
+ )
1278
+
1279
  categories = [p["Prefix"].split("/")[2] for p in resp.get("CommonPrefixes", [])]
1280
+
1281
  for category in categories:
1282
  original_prefix = f"faceswap/target/{category}/original/"
1283
+ thumb_prefix = f"faceswap/target/{category}/thumb/"
1284
+
1285
+ original_objects = client.list_objects_v2(
1286
+ Bucket=DO_SPACES_BUCKET, Prefix=original_prefix
1287
+ ).get("Contents", [])
1288
+
1289
+ thumb_objects = client.list_objects_v2(
1290
+ Bucket=DO_SPACES_BUCKET, Prefix=thumb_prefix
1291
+ ).get("Contents", [])
1292
+
1293
+ original_filenames = sorted([
1294
+ obj["Key"].split("/")[-1] for obj in original_objects
1295
+ if obj["Key"].split("/")[-1].endswith(".png")
1296
+ ])
1297
+
1298
  for idx, filename in enumerate(original_filenames, start=1):
1299
  cid = f"{category.lower()}image_{idx}"
1300
  if cid == target_category_id:
1301
  target_url = f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{original_prefix}{filename}"
1302
  break
1303
+
1304
  if target_url:
1305
  break
1306
+
1307
  if not target_url:
1308
  raise HTTPException(404, "Target categoryId not found")
1309
 
 
 
 
1310
  async with httpx.AsyncClient(timeout=30.0) as client:
1311
+ response = await client.get(target_url)
1312
+ response.raise_for_status()
1313
+ tgt_bytes = response.content
1314
 
1315
+ tgt_bgr = cv2.imdecode(np.frombuffer(tgt_bytes, np.uint8), cv2.IMREAD_COLOR)
1316
  if tgt_bgr is None:
1317
  raise HTTPException(400, "Invalid target image data")
 
1318
 
1319
  # -----------------------------
1320
  # Merge all source faces
 
1428
 
1429
 
1430
 
1431
+
1432
  # --------------------- Mount Gradio ---------------------
1433
 
1434
  multi_faceswap_app = build_multi_faceswap_gradio()