# Bikini-Theme / app.py
# LogicGoInfotechSpaces — "Update app.py" (commit 25592c0, verified)
# --------------------- List Images Endpoint ---------------------
import os
os.environ["OMP_NUM_THREADS"] = "1"
import shutil
import uuid
import cv2
import numpy as np
import threading
import subprocess
import logging
from datetime import datetime, timezone,timedelta
import insightface
from insightface.app import FaceAnalysis
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, UploadFile, File, HTTPException, Response, Depends, Security, Form
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from motor.motor_asyncio import AsyncIOMotorClient
import uvicorn
import gradio as gr
from gradio import mount_gradio_app
# DigitalOcean Spaces
import boto3
from botocore.client import Config
from io import BytesIO
from typing import Optional
import requests
import json
from bson import ObjectId
from PIL import Image
import io
# --------------------- Logging ---------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --------------------- Paths -----------------------
REPO_ID = "HariLogicgo/face_swap_models"
BASE_DIR = "./workspace"
MODELS_DIR = "./models"
os.makedirs(MODELS_DIR, exist_ok=True)
# --------------------- Secrets ---------------------
HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face private repo token
# Firebase credentials JSON
FIREBASE_CREDENTIALS_PATH = os.getenv("FIREBASE_CREDENTIALS_PATH")
# --------------------- DigitalOcean Spaces Credentials ---------------------
DO_SPACES_REGION = os.getenv("DO_SPACES_REGION", "blr1")
DO_SPACES_ENDPOINT = os.getenv("DO_SPACES_ENDPOINT", f"https://{DO_SPACES_REGION}.digitaloceanspaces.com")
DO_SPACES_KEY = os.getenv("DO_SPACES_KEY")
DO_SPACES_SECRET = os.getenv("DO_SPACES_SECRET")
DO_SPACES_BUCKET = os.getenv("DO_SPACES_BUCKET")
# --------------------- Firebase Auth ---------------------
import firebase_admin
from firebase_admin import credentials, auth
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# Initialize Firebase exactly once per process (module import time).
if not firebase_admin._apps:
    FIREBASE_CREDENTIALS = os.getenv("FIREBASE_CREDENTIALS_PATH")
    if not FIREBASE_CREDENTIALS:
        raise RuntimeError("❌ FIREBASE_CREDENTIALS_PATH not set in environment variables")
    try:
        # Try parsing as JSON string (credentials inlined into the env var)
        cred_dict = json.loads(FIREBASE_CREDENTIALS)
        cred = credentials.Certificate(cred_dict)
        logger.info("✅ Firebase initialized from JSON string in environment variable")
    except json.JSONDecodeError:
        # Fallback: assume it's a file path to a service-account JSON file
        cred = credentials.Certificate(FIREBASE_CREDENTIALS)
        logger.info("✅ Firebase initialized from JSON file path")
    firebase_admin.initialize_app(cred)
# Bearer-token extractor used by verify_firebase_token below.
security = HTTPBearer()
def verify_firebase_token(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """Verify the Firebase ID token from the Authorization header.

    Returns:
        The email address embedded in the decoded token.

    Raises:
        HTTPException(401): token invalid, expired, or carrying no email.
    """
    try:
        id_token = credentials.credentials
        decoded_token = auth.verify_id_token(id_token)
        user_email = decoded_token.get("email")
        if not user_email:
            raise HTTPException(status_code=401, detail="Firebase token invalid or missing email")
        return user_email
    except HTTPException:
        # Fix: previously the specific "missing email" 401 raised above was
        # swallowed by the generic handler below and its detail message lost.
        raise
    except Exception as e:
        logger.error(f"Firebase auth failed: {e}")
        raise HTTPException(status_code=401, detail="Unauthorized: Invalid Firebase token")
# --------------------- Download Models ---------------------
def download_models():
    """Fetch the face-swap ONNX models from the private HF repo.

    Downloads the inswapper model plus the five buffalo_l pack files into
    MODELS_DIR, then returns the local path of the inswapper model.
    """
    logger.info("Downloading models from private HF repo...")
    swapper_model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename="models/inswapper_128.onnx",
        repo_type="model",
        local_dir=MODELS_DIR,
        token=HF_TOKEN,
    )
    # Detection / landmark / attribute / recognition models of buffalo_l.
    for model_file in (
        "1k3d68.onnx",
        "2d106det.onnx",
        "genderage.onnx",
        "det_10g.onnx",
        "w600k_r50.onnx",
    ):
        hf_hub_download(
            repo_id=REPO_ID,
            filename=f"models/buffalo_l/{model_file}",
            repo_type="model",
            local_dir=MODELS_DIR,
            token=HF_TOKEN,
        )
    logger.info("Models downloaded successfully")
    return swapper_model_path
inswapper_path = download_models()
# --------------------- Face Analysis + Swapper ---------------------
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
face_analysis_app = FaceAnalysis(name="buffalo_l", root=MODELS_DIR, providers=providers)
face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))
swapper = insightface.model_zoo.get_model(inswapper_path, providers=providers)
# --------------------- CodeFormer ---------------------
CODEFORMER_PATH = "CodeFormer/inference_codeformer.py"
def ensure_codeformer():
    """Clone and set up CodeFormer (deps + pretrained weights) on first run.

    Idempotent: does nothing when the CodeFormer directory already exists.
    Raises CalledProcessError (via check=True) if any setup step fails.
    """
    if os.path.exists("CodeFormer"):
        return
    # List-form argv with shell=False: no shell string parsing, no quoting
    # pitfalls, and each failure still aborts the sequence via check=True.
    setup_commands = [
        ["git", "clone", "https://github.com/sczhou/CodeFormer.git"],
        ["pip", "install", "-r", "CodeFormer/requirements.txt"],
        ["python", "CodeFormer/basicsr/setup.py", "develop"],
        ["python", "CodeFormer/scripts/download_pretrained_models.py", "facelib"],
        ["python", "CodeFormer/scripts/download_pretrained_models.py", "CodeFormer"],
    ]
    for cmd in setup_commands:
        subprocess.run(cmd, check=True)
ensure_codeformer()
# --------------------- MongoDB ---------------------
MONGODB_URL = os.getenv("MONGODB_URL")
client = None
database = None
# --------------------- Admin Panel DB (categories + subcategories + media_clicks) ---------------------
ADMIN_MONGO_URL = os.getenv("ADMIN_MONGO_URL")
admin_client = AsyncIOMotorClient(ADMIN_MONGO_URL)
admin_db = admin_client.adminPanel
# Collections
categories_col = admin_db.categories
subcategories_col = admin_db.subcategories
media_clicks_col = admin_db.media_clicks
MAX_COMPRESSED_SIZE = 2 * 1024 * 1024 # 2 MB
def compress_image_bytes(
    image_bytes: bytes,
    max_dim: int = 1280
) -> bytes:
    """Compress image bytes by reducing dimensions + quality.

    Downscales to at most max_dim on the longer side, tries a single
    optimized PNG save, and falls back to JPEG when the PNG exceeds
    MAX_COMPRESSED_SIZE. Best-effort <= 2 MB; the JPEG fallback is not
    guaranteed to hit the limit.

    Fix: the previous quality loop was dead code — PNG ignores the quality
    parameter, so each of up to 9 iterations produced identical bytes. One
    PNG save followed by the JPEG fallback yields the same output.
    """
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # Reduce dimensions in place, preserving aspect ratio.
    img.thumbnail((max_dim, max_dim), Image.LANCZOS)
    # Single lossless attempt (output expects .png).
    buffer = io.BytesIO()
    img.save(buffer, format="PNG", optimize=True)
    if buffer.tell() <= MAX_COMPRESSED_SIZE:
        return buffer.getvalue()
    # ---- FALLBACK TO JPEG (much smaller) ----
    buffer = io.BytesIO()
    img.save(
        buffer,
        format="JPEG",
        quality=70,
        optimize=True,
        progressive=True
    )
    return buffer.getvalue()
# --------------------- FastAPI ---------------------
fastapi_app = FastAPI()
@fastapi_app.on_event("startup")
async def startup_db():
    """Open the MongoDB connection used for API hit logging at app startup."""
    global client, database
    logger.info("Initializing MongoDB for API logs...")
    client = AsyncIOMotorClient(MONGODB_URL)
    database = client.FaceSwap
    logger.info("MongoDB initialized for API logs")
@fastapi_app.on_event("shutdown")
async def shutdown_db():
    """Close the API-log MongoDB connection when the app shuts down."""
    global client
    if client:
        client.close()
        logger.info("MongoDB connection closed")
# --------------------- Logging API Hits ---------------------
async def log_faceswap_hit(user_email: str, status: str, start_time: datetime, end_time: datetime):
    """Persist one /face-swap request record to the api_logs collection.

    Silently does nothing when the database has not been initialized yet
    (e.g. startup hook not run), so logging never blocks a request.
    """
    global database
    if database is None:
        return
    elapsed_ms = (end_time - start_time).total_seconds() * 1000
    log_entry = {
        "user": user_email,
        "endpoint": "/face-swap",
        "status": status,
        "start_time": start_time,
        "end_time": end_time,
        "response_time_ms": elapsed_ms,
    }
    await database.api_logs.insert_one(log_entry)
# --------------------- Media Click Logging Helper ---------------------
# --------------------- Media Click Logging Helper ---------------------
async def log_media_click(user_id: str, category_oid_str: str):
    """
    Logs a click event to the media_clicks collection against the Category.
    ai_edit_daily_count is binary per day (no duplicate dates).
    """
    try:
        user_oid = ObjectId(user_id.strip())
        category_oid = ObjectId(category_oid_str.strip())
        # NOTE(review): naive UTC timestamp (datetime.utcnow) while the rest
        # of the file uses timezone-aware datetime.now(timezone.utc) — confirm
        # stored values compare consistently in Mongo.
        now = datetime.utcnow()
        # Normalize today (UTC midnight)
        today_date = datetime(now.year, now.month, now.day)
        # -------------------------------------------------
        # STEP 1: Ensure root document exists
        # -------------------------------------------------
        # Upsert creates the per-user root doc with zeroed counters; a no-op
        # when the user document already exists.
        await media_clicks_col.update_one(
            {"userId": user_oid},
            {
                "$setOnInsert": {
                    "userId": user_oid,
                    "createdAt": now,
                    "ai_edit_complete": 0,
                    "ai_edit_daily_count": []
                }
            },
            upsert=True
        )
        # -------------------------------------------------
        # STEP 2: FIXED DAILY BINARY TRACKING
        # -------------------------------------------------
        doc = await media_clicks_col.find_one(
            {"userId": user_oid},
            {"ai_edit_daily_count": 1}
        )
        daily_entries = doc.get("ai_edit_daily_count", []) if doc else []
        # Convert to date -> count map (unique by design)
        # NOTE(review): assumes each entry["date"] is stored at exact UTC
        # midnight; a non-midnight timestamp would yield a duplicate entry for
        # the same calendar day — verify on the writer side.
        daily_map = {
            entry["date"]: entry["count"]
            for entry in daily_entries
        }
        # Determine last recorded date
        last_date = max(daily_map.keys()) if daily_map else today_date
        # Fill ALL missing days with 0
        next_day = last_date + timedelta(days=1)
        while next_day < today_date:
            daily_map.setdefault(next_day, 0)
            next_day += timedelta(days=1)
        # Mark today as used (binary)
        daily_map[today_date] = 1
        # Rebuild sorted list (old → new)
        final_daily_entries = [
            {"date": d, "count": daily_map[d]}
            for d in sorted(daily_map.keys())
        ]
        # Keep last 32 days only
        final_daily_entries = final_daily_entries[-32:]
        # Atomic replace (NO $push) — avoids duplicate dates that incremental
        # pushes could introduce under concurrent requests.
        await media_clicks_col.update_one(
            {"userId": user_oid},
            {
                "$set": {
                    "ai_edit_daily_count": final_daily_entries,
                    "ai_edit_last_date": now,
                    "updatedAt": now
                }
            }
        )
        # -------------------------------------------------
        # STEP 3: CATEGORY CLICK LOGIC (DATES CAN REPEAT)
        # -------------------------------------------------
        # Positional "$" updates the matched categories array element.
        update_result = await media_clicks_col.update_one(
            {
                "userId": user_oid,
                "categories.categoryId": category_oid
            },
            {
                "$inc": {
                    "categories.$.click_count": 1,
                    "ai_edit_complete": 1
                },
                "$set": {
                    "categories.$.lastClickedAt": now,
                    "updatedAt": now
                }
            }
        )
        # -------------------------------------------------
        # STEP 4: Push category if missing (order = time)
        # -------------------------------------------------
        # matched_count == 0 means the category was not in the array yet.
        if update_result.matched_count == 0:
            await media_clicks_col.update_one(
                {"userId": user_oid},
                {
                    "$inc": {"ai_edit_complete": 1},
                    "$set": {"updatedAt": now},
                    "$push": {
                        "categories": {
                            "categoryId": category_oid,
                            "click_count": 1,
                            "lastClickedAt": now
                        }
                    }
                }
            )
        logger.info(
            "[MEDIA_CLICK] user=%s category=%s daily_binary_tracked",
            user_id,
            category_oid_str
        )
    except Exception as media_err:
        # Best-effort analytics: failures here must never break the request.
        logger.error("MEDIA_CLICK LOGGING ERROR: %s", media_err)
# async def log_media_click(user_id: str, category_oid_str: str):
# """
# Logs a click event to the media_clicks collection against the Category.
# Safely updates ai_edit_last_date, ai_edit_complete, and ai_edit_daily_count.
# """
# try:
# user_oid = ObjectId(user_id.strip())
# category_oid = ObjectId(category_oid_str.strip())
# now = datetime.utcnow()
# # Normalize dates (UTC midnight)
# today_date = datetime(now.year, now.month, now.day)
# yesterday_date = today_date - timedelta(days=1)
# # Optional safety check
# if not await categories_col.find_one({"_id": category_oid}):
# logger.warning(
# "Category ID %s not found. Skipping media click logging.",
# category_oid_str
# )
# return
# # -------------------------------------------------
# # STEP 1: Ensure user document + root fields exist
# # -------------------------------------------------
# await media_clicks_col.update_one(
# {"userId": user_oid},
# {
# "$setOnInsert": {
# "userId": user_oid,
# "createdAt": now,
# "ai_edit_complete": 0,
# "ai_edit_daily_count": []
# }
# },
# upsert=True
# )
# # -------------------------------------------------
# # STEP 1.5: Handle ai_edit_daily_count
# # -------------------------------------------------
# doc = await media_clicks_col.find_one(
# {"userId": user_oid},
# {"ai_edit_daily_count": 1}
# )
# daily_entries = doc.get("ai_edit_daily_count", []) if doc else []
# daily_updates = []
# if not daily_entries:
# # First-ever usage → only today
# daily_updates.append({"date": today_date, "count": 1})
# else:
# # Build existing date map for lookup
# existing_dates = {entry["date"].date(): entry["count"] for entry in daily_entries}
# last_date_in_db = max(entry["date"].date() for entry in daily_entries)
# # Fill missing days between last recorded day and today-1
# next_day = last_date_in_db + timedelta(days=1)
# while next_day < today_date:
# if next_day not in existing_dates:
# daily_updates.append({"date": next_day, "count": 0})
# next_day += timedelta(days=1)
# # Add today if not already present
# if today_date not in existing_dates:
# daily_updates.append({"date": today_date, "count": 1})
# # Push updates if any
# if daily_updates:
# await media_clicks_col.update_one(
# {"userId": user_oid},
# {"$push": {"ai_edit_daily_count": {"$each": daily_updates}}}
# )
# # Sort oldest → newest and trim to last 32 entries
# doc = await media_clicks_col.find_one({"userId": user_oid}, {"ai_edit_daily_count": 1})
# daily_entries = doc.get("ai_edit_daily_count", []) if doc else []
# daily_entries.sort(key=lambda x: x["date"])
# if len(daily_entries) > 32:
# daily_entries = daily_entries[-32:]
# await media_clicks_col.update_one(
# {"userId": user_oid},
# {"$set": {"ai_edit_daily_count": daily_entries}}
# )
# # -------------------------------------------------
# # STEP 2: Try updating existing category
# # -------------------------------------------------
# update_result = await media_clicks_col.update_one(
# {
# "userId": user_oid,
# "categories.categoryId": category_oid
# },
# {
# "$inc": {
# "categories.$.click_count": 1,
# "ai_edit_complete": 1
# },
# "$set": {
# "categories.$.lastClickedAt": now,
# "ai_edit_last_date": now,
# "updatedAt": now
# }
# }
# )
# # -------------------------------------------------
# # STEP 3: Category not present → push new
# # -------------------------------------------------
# if update_result.matched_count == 0:
# await media_clicks_col.update_one(
# {"userId": user_oid},
# {
# "$inc": {
# "ai_edit_complete": 1
# },
# "$set": {
# "ai_edit_last_date": now,
# "updatedAt": now
# },
# "$push": {
# "categories": {
# "categoryId": category_oid,
# "click_count": 1,
# "lastClickedAt": now
# }
# }
# }
# )
# logger.info(
# "[MEDIA_CLICK] user=%s category=%s ai_edit_complete incremented & daily_tracked",
# user_id,
# category_oid_str
# )
# except Exception as media_err:
# logger.error("MEDIA_CLICK LOGGING ERROR: %s", media_err)
# --------------------- Face Swap Pipeline ---------------------
swap_lock = threading.Lock()
def face_swap_and_enhance(src_img, tgt_img, temp_dir="/tmp/faceswap_work"):
    """Swap the first detected face of src_img onto tgt_img, then enhance.

    Args:
        src_img: RGB numpy array containing the face to transplant.
        tgt_img: RGB numpy array whose first detected face is replaced.
        temp_dir: scratch directory; wiped and recreated on every call.

    Returns:
        (final_rgb_image, final_file_path, "") on success, or
        (None, None, error_message) on any failure.
    """
    try:
        # Serialize: the insightface models and temp_dir are process-wide
        # shared state, so only one swap runs at a time.
        with swap_lock:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            os.makedirs(temp_dir, exist_ok=True)
            src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
            tgt_bgr_full = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
            src_faces = face_analysis_app.get(src_bgr)
            tgt_faces = face_analysis_app.get(tgt_bgr_full)
            if not src_faces or not tgt_faces:
                return None, None, "❌ Face not detected in source or target image"
            # Only the first detected face on each side is used.
            src_face0 = src_faces[0]
            tgt_face0 = tgt_faces[0]
            swapped_bgr_full = swapper.get(tgt_bgr_full, tgt_face0, src_face0)
            if swapped_bgr_full is None:
                return None, None, "❌ Face swap failed"
            swapped_path = os.path.join(temp_dir, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
            cv2.imwrite(swapped_path, swapped_bgr_full)
            # Fix: list-form argv with shell=False instead of an interpolated
            # shell string — robust against spaces/metacharacters in paths.
            cmd = [
                "python", CODEFORMER_PATH,
                "-w", "0.7",
                "--input_path", swapped_path,
                "--output_path", temp_dir,
                "--bg_upsampler", "realesrgan",
                "--face_upsample",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                return None, None, f"❌ CodeFormer failed:\n{result.stderr}"
            # CodeFormer writes enhanced frames into <output>/final_results.
            final_results_dir = os.path.join(temp_dir, "final_results")
            final_files = [f for f in os.listdir(final_results_dir) if f.endswith(".png")]
            if not final_files:
                return None, None, "❌ No enhanced image found"
            final_path = os.path.join(final_results_dir, final_files[0])
            final_img = cv2.cvtColor(cv2.imread(final_path), cv2.COLOR_BGR2RGB)
            return final_img, final_path, ""
    except Exception as e:
        return None, None, f"❌ Error: {str(e)}"
# --------------------- Gradio ---------------------
# Minimal Gradio UI: two image inputs and one button driving the swap pipeline.
with gr.Blocks() as demo:
    gr.Markdown("Face Swap")
    with gr.Row():
        src_input = gr.Image(type="numpy", label="Upload Your Face")
        tgt_input = gr.Image(type="numpy", label="Upload Target Image")
    btn = gr.Button("Swap Face")
    output_img = gr.Image(type="numpy", label="Enhanced Output")
    download = gr.File(label="⬇️ Download Enhanced Image")
    error_box = gr.Textbox(label="Logs / Errors", interactive=False)
    def process(src, tgt):
        # Thin wrapper so Gradio maps the three return values onto
        # (output_img, download, error_box).
        img, path, err = face_swap_and_enhance(src, tgt)
        return img, path, err
    btn.click(process, [src_input, tgt_input], [output_img, download, error_box])
# --------------------- DigitalOcean Spaces Helper ---------------------
def get_spaces_client():
    """Build a boto3 S3 client configured for DigitalOcean Spaces."""
    return boto3.session.Session().client(
        's3',
        region_name=DO_SPACES_REGION,
        endpoint_url=DO_SPACES_ENDPOINT,
        aws_access_key_id=DO_SPACES_KEY,
        aws_secret_access_key=DO_SPACES_SECRET,
        # Spaces requires SigV4 request signing.
        config=Config(signature_version='s3v4'),
    )
def upload_to_spaces(file_bytes, key, content_type="image/png"):
    """Upload bytes to the Spaces bucket under `key`; return its public URL."""
    spaces = get_spaces_client()
    spaces.put_object(
        Bucket=DO_SPACES_BUCKET,
        Key=key,
        Body=file_bytes,
        ContentType=content_type,
        ACL='public-read',  # objects are served publicly by URL
    )
    return f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{key}"
def download_from_spaces(key):
    """Fetch an object from the Spaces bucket and return its raw bytes."""
    response = get_spaces_client().get_object(Bucket=DO_SPACES_BUCKET, Key=key)
    return response['Body'].read()
# --------------------- API Endpoints ---------------------
@fastapi_app.get("/")
def root():
    """Redirect visitors of the bare domain to the Gradio UI."""
    return RedirectResponse(url="/gradio")
@fastapi_app.get("/health")
async def health():
    """Liveness probe endpoint."""
    return {"status": "healthy"}
@fastapi_app.post("/face-swap")
async def face_swap_api(
    source: UploadFile = File(...),
    target_category_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),  # <-- The ID to use for media click logging
    user_id: Optional[str] = Form(None),
    new_subcategory_id: Optional[str] = Form(None),
    user_email: str = Depends(verify_firebase_token)
):
    """Run the full face-swap pipeline for an authenticated user.

    Exactly one of `target_category_id` (legacy DO-Spaces target image) or
    `new_subcategory_id` (subcategory asset image looked up in Mongo) must be
    supplied. When both `user_id` and `category_id` are present, the click is
    logged to media_clicks. Returns public URLs for the enhanced result and a
    compressed copy.

    Raises:
        HTTPException 400 (bad input), 404 (missing target), 500 (pipeline
        failure). Every outcome is recorded via log_faceswap_hit.
    """
    start_time = datetime.now(timezone.utc)
    target_url = None
    try:
        # ---------------------------------------------------------
        # NORMALIZE EMPTY STRINGS — treat "" the same as "not provided"
        # ---------------------------------------------------------
        if target_category_id == "": target_category_id = None
        if new_subcategory_id == "": new_subcategory_id = None
        if category_id == "": category_id = None
        if user_id == "": user_id = None
        # ---------------------------------------------------------
        # REQUEST TRACE LOG
        # ---------------------------------------------------------
        logger.info(
            "FACE_SWAP_REQUEST | user_id=%s | category_id=%s | new_subcategory_id=%s",
            user_id,
            category_id,
            new_subcategory_id
        )
        # Optional warning if nothing useful is provided
        if not user_id and not category_id and not new_subcategory_id:
            logger.warning(
                "FACE_SWAP_REQUEST_MISSING_IDS | user_id=%s | category_id=%s | new_subcategory_id=%s",
                user_id,
                category_id,
                new_subcategory_id
            )
        # ---------------------------------------------------------
        # STRICT XOR VALIDATION — exactly one target selector allowed
        # ---------------------------------------------------------
        target_provided = target_category_id is not None
        new_sub_provided = new_subcategory_id is not None
        if target_provided and new_sub_provided:
            raise HTTPException(
                status_code=400,
                detail="Provide ONLY ONE of: target_category_id OR new_subcategory_id"
            )
        if not target_provided and not new_sub_provided:
            raise HTTPException(
                status_code=400,
                detail="Either target_category_id OR new_subcategory_id is required"
            )
        # ---------------------------------------------------------
        # READ SOURCE IMAGE — archive the raw upload to Spaces
        # ---------------------------------------------------------
        src_bytes = await source.read()
        src_key = f"bikini-theme/source/{uuid.uuid4().hex}_{source.filename}"
        upload_to_spaces(src_bytes, src_key, content_type=source.content_type)
        # ---------------------------------------------------------
        # TARGET IMAGE RETRIEVAL
        # ---------------------------------------------------------
        if target_provided:
            # CASE 1 — Old behavior (use DO Spaces target image)
            target_filename = f"{target_category_id}.png"
            target_url = (
                f"https://{DO_SPACES_BUCKET}.{DO_SPACES_REGION}."
                f"digitaloceanspaces.com/bikini-theme/target/{target_filename}"
            )
        elif new_sub_provided:
            # CASE 2 — New behavior (use subcategory asset image)
            try:
                asset_oid = ObjectId(new_subcategory_id)
            except Exception:  # fix: was a bare `except:` (caught SystemExit etc.)
                raise HTTPException(400, "Invalid new_subcategory_id format.")
            # 1. Find subcategory asset by asset_images._id
            subcat_doc = await subcategories_col.find_one(
                {"asset_images._id": asset_oid},
                {"asset_images.$": 1}  # Only need the asset image URL
            )
            if not subcat_doc or "asset_images" not in subcat_doc:
                await log_faceswap_hit(user_email, "error", start_time, datetime.now(timezone.utc))
                raise HTTPException(
                    status_code=404,
                    detail="Subcategory asset image not found in DB."
                )
            target_url = subcat_doc["asset_images"][0]["url"]
        # ---------------------------------------------------------
        # DOWNLOAD TARGET IMAGE — bounded wait instead of hanging forever
        # ---------------------------------------------------------
        resp = requests.get(target_url, timeout=30)
        if resp.status_code != 200:
            await log_faceswap_hit(user_email, "error", start_time, datetime.now(timezone.utc))
            raise HTTPException(status_code=404, detail=f"Target image not found or download failed: {target_url}")
        tgt_bytes = resp.content
        # ---------------------------------------------------------
        # MEDIA CLICK LOGGING (only when both identifiers were supplied)
        # ---------------------------------------------------------
        if user_id and category_id:
            await log_media_click(user_id, category_id)
        # ---------------------------------------------------------
        # DECODE & FACE SWAP
        # ---------------------------------------------------------
        src_array = np.frombuffer(src_bytes, np.uint8)
        tgt_array = np.frombuffer(tgt_bytes, np.uint8)
        src_bgr = cv2.imdecode(src_array, cv2.IMREAD_COLOR)
        tgt_bgr = cv2.imdecode(tgt_array, cv2.IMREAD_COLOR)
        if src_bgr is None or tgt_bgr is None:
            await log_faceswap_hit(user_email, "error", start_time, datetime.now(timezone.utc))
            raise HTTPException(status_code=400, detail="Invalid image data")
        # The pipeline expects RGB input.
        src_rgb = cv2.cvtColor(src_bgr, cv2.COLOR_BGR2RGB)
        tgt_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
        final_img, final_path, err = face_swap_and_enhance(src_rgb, tgt_rgb)
        if err:
            await log_faceswap_hit(user_email, "error", start_time, datetime.now(timezone.utc))
            raise HTTPException(status_code=500, detail=err)
        # Save final output to DO Spaces
        with open(final_path, "rb") as f:
            result_bytes = f.read()
        result_key = f"bikini-theme/result/{uuid.uuid4().hex}_enhanced.png"
        result_url = upload_to_spaces(result_bytes, result_key, "image/png")
        # ---------------------------------------------------------
        # COMPRESS RESULT IMAGE (best-effort <= 2 MB copy)
        # ---------------------------------------------------------
        compressed_bytes = compress_image_bytes(
            image_bytes=result_bytes,
            max_dim=1280
        )
        compressed_key = result_key.replace("_enhanced.png", "_compressed.png")
        compressed_url = upload_to_spaces(
            compressed_bytes,
            compressed_key,
            "image/png"
        )
        # ---------------------------------------------------------
        # SUCCESS RESPONSE
        # ---------------------------------------------------------
        await log_faceswap_hit(user_email, "success", start_time, datetime.now(timezone.utc))
        return {
            "result_url": result_url,
            "category_id": category_id,
            "user_id": user_id,
            "new_subcategory_id": new_subcategory_id,
            "Compressed_Image_URL": compressed_url
        }
    except HTTPException:
        # Already shaped for the client (and logged where relevant).
        raise
    except Exception as e:
        end_time = datetime.now(timezone.utc)
        try:
            await log_faceswap_hit(user_email, "error", start_time, end_time)
        except Exception as log_exc:
            logger.error("Failed to write log_faceswap_hit: %s", log_exc)
        logger.error(f"Critical /face-swap error: {e}")
        # fix: was an f-string with no placeholders
        raise HTTPException(status_code=500, detail="Face swap failed: Internal server error.")
@fastapi_app.get("/preview/{result_key:path}")
async def preview_result(result_key: str):
    """Stream a stored result image inline from DO Spaces (404 if missing)."""
    try:
        img_bytes = download_from_spaces(result_key)
    except Exception:
        raise HTTPException(status_code=404, detail="Result not found")
    headers = {"Content-Disposition": "inline; filename=result.png"}
    return Response(content=img_bytes, media_type="image/png", headers=headers)
# --------------------- Mount Gradio ---------------------
# Serve the Gradio demo on the same ASGI app under /gradio.
fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
if __name__ == "__main__":
    # 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
# # --------------------- List Images Endpoint ---------------------
# import os
# os.environ["OMP_NUM_THREADS"] = "1"
# import shutil
# import uuid
# import cv2
# import numpy as np
# import threading
# import subprocess
# import logging
# from datetime import datetime, timezone
# import insightface
# from insightface.app import FaceAnalysis
# from huggingface_hub import hf_hub_download
# from fastapi import FastAPI, UploadFile, File, HTTPException, Response, Depends, Security, Form
# from fastapi.responses import RedirectResponse
# from pydantic import BaseModel
# from motor.motor_asyncio import AsyncIOMotorClient
# import uvicorn
# import gradio as gr
# from gradio import mount_gradio_app
# # DigitalOcean Spaces
# import boto3
# from botocore.client import Config
# from io import BytesIO
# from typing import Optional
# import requests
# import json
# from bson import ObjectId
# # --------------------- Logging ---------------------
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # --------------------- Paths -----------------------
# REPO_ID = "HariLogicgo/face_swap_models"
# BASE_DIR = "./workspace"
# MODELS_DIR = "./models"
# os.makedirs(MODELS_DIR, exist_ok=True)
# # --------------------- Secrets ---------------------
# HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face private repo token
# # Firebase credentials JSON
# FIREBASE_CREDENTIALS_PATH = os.getenv("FIREBASE_CREDENTIALS_PATH")
# # --------------------- DigitalOcean Spaces Credentials ---------------------
# DO_SPACES_REGION = os.getenv("DO_SPACES_REGION", "blr1")
# DO_SPACES_ENDPOINT = os.getenv("DO_SPACES_ENDPOINT", f"https://{DO_SPACES_REGION}.digitaloceanspaces.com")
# DO_SPACES_KEY = os.getenv("DO_SPACES_KEY")
# DO_SPACES_SECRET = os.getenv("DO_SPACES_SECRET")
# DO_SPACES_BUCKET = os.getenv("DO_SPACES_BUCKET")
# # --------------------- Firebase Auth ---------------------
# import firebase_admin
# from firebase_admin import credentials, auth
# from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# if not firebase_admin._apps:
# FIREBASE_CREDENTIALS = os.getenv("FIREBASE_CREDENTIALS_PATH")
# if not FIREBASE_CREDENTIALS:
# raise RuntimeError("❌ FIREBASE_CREDENTIALS_PATH not set in environment variables")
# try:
# # Try parsing as JSON string
# cred_dict = json.loads(FIREBASE_CREDENTIALS)
# cred = credentials.Certificate(cred_dict)
# logger.info("✅ Firebase initialized from JSON string in environment variable")
# except json.JSONDecodeError:
# # Fallback: assume it's a file path
# cred = credentials.Certificate(FIREBASE_CREDENTIALS)
# logger.info("✅ Firebase initialized from JSON file path")
# firebase_admin.initialize_app(cred)
# security = HTTPBearer()
# def verify_firebase_token(credentials: HTTPAuthorizationCredentials = Security(security)):
# """Verify Firebase ID token from Authorization header."""
# try:
# id_token = credentials.credentials
# decoded_token = auth.verify_id_token(id_token)
# user_email = decoded_token.get("email")
# if not user_email:
# raise HTTPException(status_code=401, detail="Firebase token invalid or missing email")
# return user_email
# except Exception as e:
# logger.error(f"Firebase auth failed: {e}")
# raise HTTPException(status_code=401, detail="Unauthorized: Invalid Firebase token")
# # --------------------- Download Models ---------------------
# def download_models():
# logger.info("Downloading models from private HF repo...")
# inswapper_path = hf_hub_download(
# repo_id=REPO_ID,
# filename="models/inswapper_128.onnx",
# repo_type="model",
# local_dir=MODELS_DIR,
# token=HF_TOKEN
# )
# buffalo_files = [
# "1k3d68.onnx",
# "2d106det.onnx",
# "genderage.onnx",
# "det_10g.onnx",
# "w600k_r50.onnx"
# ]
# for f in buffalo_files:
# hf_hub_download(
# repo_id=REPO_ID,
# filename=f"models/buffalo_l/{f}",
# repo_type="model",
# local_dir=MODELS_DIR,
# token=HF_TOKEN
# )
# logger.info("Models downloaded successfully")
# return inswapper_path
# inswapper_path = download_models()
# # --------------------- Face Analysis + Swapper ---------------------
# providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
# face_analysis_app = FaceAnalysis(name="buffalo_l", root=MODELS_DIR, providers=providers)
# face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))
# swapper = insightface.model_zoo.get_model(inswapper_path, providers=providers)
# # --------------------- CodeFormer ---------------------
# CODEFORMER_PATH = "CodeFormer/inference_codeformer.py"
# def ensure_codeformer():
# if not os.path.exists("CodeFormer"):
# subprocess.run("git clone https://github.com/sczhou/CodeFormer.git", shell=True, check=True)
# subprocess.run("pip install -r CodeFormer/requirements.txt", shell=True, check=True)
# subprocess.run("python CodeFormer/basicsr/setup.py develop", shell=True, check=True)
# subprocess.run("python CodeFormer/scripts/download_pretrained_models.py facelib", shell=True, check=True)
# subprocess.run("python CodeFormer/scripts/download_pretrained_models.py CodeFormer", shell=True, check=True)
# ensure_codeformer()
# # --------------------- MongoDB ---------------------
# MONGODB_URL = os.getenv("MONGODB_URL")
# client = None
# database = None
# # --------------------- Admin Panel DB (categories + media_clicks) ---------------------
# # --------------------- Admin Panel DB (categories + subcategories + media_clicks) ---------------------
# ADMIN_MONGO_URL = os.getenv("ADMIN_MONGO_URL")
# admin_client = AsyncIOMotorClient(ADMIN_MONGO_URL)
# admin_db = admin_client.adminPanel
# # Collections
# categories_col = admin_db.categories
# subcategories_col = admin_db.subcategories
# media_clicks_col = admin_db.media_clicks
# users_col = admin_db.users # optional, only if needed
# # --------------------- FastAPI ---------------------
# fastapi_app = FastAPI()
# @fastapi_app.on_event("startup")
# async def startup_db():
# global client, database
# logger.info("Initializing MongoDB for API logs...")
# client = AsyncIOMotorClient(MONGODB_URL)
# database = client.FaceSwap
# logger.info("MongoDB initialized for API logs")
# @fastapi_app.on_event("shutdown")
# async def shutdown_db():
# global client
# if client:
# client.close()
# logger.info("MongoDB connection closed")
# # --------------------- Logging API Hits ---------------------
# async def log_faceswap_hit(user_email: str, status: str, start_time: datetime, end_time: datetime):
# global database
# if database is None:
# return
# response_time_ms = (end_time - start_time).total_seconds() * 1000
# await database.api_logs.insert_one({
# "user": user_email,
# "endpoint": "/face-swap",
# "status": status,
# "start_time": start_time,
# "end_time": end_time,
# "response_time_ms": response_time_ms
# })
# # --------------------- Face Swap Pipeline ---------------------
# swap_lock = threading.Lock()
# def face_swap_and_enhance(src_img, tgt_img, temp_dir="/tmp/faceswap_work"):
# try:
# with swap_lock:
# if os.path.exists(temp_dir):
# shutil.rmtree(temp_dir)
# os.makedirs(temp_dir, exist_ok=True)
# src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
# tgt_bgr_full = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
# src_faces = face_analysis_app.get(src_bgr)
# tgt_faces = face_analysis_app.get(tgt_bgr_full)
# if not src_faces or not tgt_faces:
# return None, None, "❌ Face not detected in source or target image"
# src_face0 = src_faces[0]
# tgt_face0 = tgt_faces[0]
# swapped_bgr_full = swapper.get(tgt_bgr_full, tgt_face0, src_face0)
# if swapped_bgr_full is None:
# return None, None, "❌ Face swap failed"
# swapped_path = os.path.join(temp_dir, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
# cv2.imwrite(swapped_path, swapped_bgr_full)
# cmd = f"python {CODEFORMER_PATH} -w 0.7 --input_path {swapped_path} --output_path {temp_dir} --bg_upsampler realesrgan --face_upsample"
# result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
# if result.returncode != 0:
# return None, None, f"❌ CodeFormer failed:\n{result.stderr}"
# final_results_dir = os.path.join(temp_dir, "final_results")
# final_files = [f for f in os.listdir(final_results_dir) if f.endswith(".png")]
# if not final_files:
# return None, None, "❌ No enhanced image found"
# final_path = os.path.join(final_results_dir, final_files[0])
# final_img = cv2.cvtColor(cv2.imread(final_path), cv2.COLOR_BGR2RGB)
# return final_img, final_path, ""
# except Exception as e:
# return None, None, f"❌ Error: {str(e)}"
# # --------------------- Gradio ---------------------
# with gr.Blocks() as demo:
# gr.Markdown("Face Swap")
# with gr.Row():
# src_input = gr.Image(type="numpy", label="Upload Your Face")
# tgt_input = gr.Image(type="numpy", label="Upload Target Image")
# btn = gr.Button("Swap Face")
# output_img = gr.Image(type="numpy", label="Enhanced Output")
# download = gr.File(label="⬇️ Download Enhanced Image")
# error_box = gr.Textbox(label="Logs / Errors", interactive=False)
# def process(src, tgt):
# img, path, err = face_swap_and_enhance(src, tgt)
# return img, path, err
# btn.click(process, [src_input, tgt_input], [output_img, download, error_box])
# # --------------------- DigitalOcean Spaces Helper ---------------------
# def get_spaces_client():
# session = boto3.session.Session()
# client = session.client(
# 's3',
# region_name=DO_SPACES_REGION,
# endpoint_url=DO_SPACES_ENDPOINT,
# aws_access_key_id=DO_SPACES_KEY,
# aws_secret_access_key=DO_SPACES_SECRET,
# config=Config(signature_version='s3v4')
# )
# return client
# def upload_to_spaces(file_bytes, key, content_type="image/png"):
# client = get_spaces_client()
# client.put_object(Bucket=DO_SPACES_BUCKET, Key=key, Body=file_bytes, ContentType=content_type, ACL='public-read')
# return f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{key}"
# def download_from_spaces(key):
# client = get_spaces_client()
# obj = client.get_object(Bucket=DO_SPACES_BUCKET, Key=key)
# return obj['Body'].read()
# # --------------------- API Endpoints ---------------------
# @fastapi_app.get("/")
# def root():
# return RedirectResponse("/gradio")
# @fastapi_app.get("/health")
# async def health():
# return {"status": "healthy"}
# @fastapi_app.post("/face-swap")
# async def face_swap_api(
# source: UploadFile = File(...),
# target_category_id: str = Form(...), # REQUIRED (old behavior preserved)
# NOTE(review): Form(...) makes this field mandatory at the framework level, so the
# XOR branch below that accepts new_subcategory_id alone is reachable only when clients
# send target_category_id="" (the empty-string normalization). Confirm that is intended,
# or declare it Optional[str] = Form(None) so either field can genuinely be omitted.
# category_id: Optional[str] = Form(None),
# user_id: Optional[str] = Form(None),
# new_subcategory_id: Optional[str] = Form(None),
# user_email: str = Depends(verify_firebase_token)
# ):
# start_time = datetime.now(timezone.utc)
# try:
# # ---------------------------------------------------------
# # NORMALIZE EMPTY STRINGS (Android older versions)
# # ---------------------------------------------------------
# if target_category_id == "":
# target_category_id = None
# if new_subcategory_id == "":
# new_subcategory_id = None
# if category_id == "":
# category_id = None
# if user_id == "":
# user_id = None
# # ---------------------------------------------------------
# # STRICT XOR VALIDATION
# # ---------------------------------------------------------
# if target_category_id and new_subcategory_id:
# raise HTTPException(
# status_code=400,
# detail="Provide ONLY ONE of: target_category_id OR new_subcategory_id"
# )
# if not target_category_id and not new_subcategory_id:
# raise HTTPException(
# status_code=400,
# detail="Either target_category_id OR new_subcategory_id is required"
# )
# # ---------------------------------------------------------
# # READ SOURCE IMAGE
# # ---------------------------------------------------------
# src_bytes = await source.read()
# src_key = f"bikini-theme/source/{uuid.uuid4().hex}_{source.filename}"
# upload_to_spaces(src_bytes, src_key, content_type=source.content_type)
# # ---------------------------------------------------------
# # CASE 1 — Old behavior (use DO Spaces target image)
# # ---------------------------------------------------------
# if target_category_id:
# target_filename = f"{target_category_id}.png"
# target_url = (
# f"https://{DO_SPACES_BUCKET}.{DO_SPACES_REGION}."
# f"digitaloceanspaces.com/bikini-theme/target/{target_filename}"
# )
# resp = requests.get(target_url)
# NOTE(review): synchronous requests.get with no timeout= inside an async handler —
# a slow/hung Spaces endpoint blocks the event loop indefinitely. If revived, add a
# timeout (e.g. requests.get(target_url, timeout=10)) or use an async HTTP client.
# if resp.status_code != 200:
# await log_faceswap_hit(user_email, "error", start_time, datetime.now(timezone.utc))
# raise HTTPException(status_code=404, detail=f"Target image not found: {target_url}")
# tgt_bytes = resp.content
# # ---------------------------------------------------------
# # CASE 2 — New behavior (use subcategory asset image)
# # ---------------------------------------------------------
# else:
# # Find subcategory asset by asset_images._id
# asset = await admin_db.subcategories.find_one(
# {"asset_images._id": ObjectId(new_subcategory_id)},
# {"asset_images.$": 1}
# )
# if not asset or "asset_images" not in asset:
# await log_faceswap_hit(user_email, "error", start_time, datetime.now(timezone.utc))
# raise HTTPException(
# status_code=404,
# detail="Subcategory asset image not found"
# )
# # Extract the single matching image URL
# asset_url = asset["asset_images"][0]["url"]
# resp = requests.get(asset_url)
# if resp.status_code != 200:
# raise HTTPException(
# status_code=404,
# detail=f"Failed to download asset image: {asset_url}"
# )
# tgt_bytes = resp.content
# # ---------------------------------------------------------
# # DECODE BOTH IMAGES
# # ---------------------------------------------------------
# src_array = np.frombuffer(src_bytes, np.uint8)
# tgt_array = np.frombuffer(tgt_bytes, np.uint8)
# src_bgr = cv2.imdecode(src_array, cv2.IMREAD_COLOR)
# tgt_bgr = cv2.imdecode(tgt_array, cv2.IMREAD_COLOR)
# if src_bgr is None or tgt_bgr is None:
# await log_faceswap_hit(user_email, "error", start_time, datetime.now(timezone.utc))
# raise HTTPException(status_code=400, detail="Invalid image data")
# src_rgb = cv2.cvtColor(src_bgr, cv2.COLOR_BGR2RGB)
# tgt_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
# # ---------------------------------------------------------
# # FACE SWAP & ENHANCE
# # ---------------------------------------------------------
# final_img, final_path, err = face_swap_and_enhance(src_rgb, tgt_rgb)
# if err:
# await log_faceswap_hit(user_email, "error", start_time, datetime.now(timezone.utc))
# raise HTTPException(status_code=500, detail=err)
# # Save final output to DO Spaces
# with open(final_path, "rb") as f:
# result_bytes = f.read()
# result_key = f"bikini-theme/result/{uuid.uuid4().hex}_enhanced.png"
# result_url = upload_to_spaces(result_bytes, result_key, "image/png")
# await log_faceswap_hit(user_email, "success", start_time, datetime.now(timezone.utc))
# # ---------------------------------------------------------
# # SUCCESS RESPONSE
# # ---------------------------------------------------------
# return {
# "result_url": result_url,
# "category_id": category_id,
# "user_id": user_id,
# "new_subcategory_id": new_subcategory_id
# }
# except Exception as e:
# await log_faceswap_hit(user_email, "error", start_time, datetime.now(timezone.utc))
# raise HTTPException(status_code=500, detail=f"Face swap failed: {str(e)}")
# NOTE(review): this broad except also catches the HTTPExceptions raised above (400/404),
# re-wrapping them as 500 and logging the failure twice. If revived, add
# "except HTTPException: raise" before this handler so client errors keep their status.
# ####------------------------------------OLD CODE------------------------------------####
# # @fastapi_app.post("/face-swap")
# # async def face_swap_api(
# # source: UploadFile = File(...),
# # target_category_id: str = Form(...),
# # user_email: str = Depends(verify_firebase_token)
# # ):
# # # start_time = datetime.utcnow()
# # start_time = datetime.now(timezone.utc)
# # try:
# # src_bytes = await source.read()
# # src_key = f"bikini-theme/source/{uuid.uuid4().hex}_{source.filename}"
# # upload_to_spaces(src_bytes, src_key, content_type=source.content_type)
# # target_filename = f"{target_category_id}.png"
# # target_url = f"https://{DO_SPACES_BUCKET}.{DO_SPACES_REGION}.digitaloceanspaces.com/bikini-theme/target/{target_filename}"
# # resp = requests.get(target_url)
# # if resp.status_code != 200:
# # # end_time = datetime.utcnow()
# # end_time = datetime.now(timezone.utc)
# # await log_faceswap_hit(user_email, status="error", start_time=start_time, end_time=end_time)
# # raise HTTPException(status_code=404, detail=f"Target image not found at {target_url}")
# # tgt_bytes = resp.content
# # src_array = np.frombuffer(src_bytes, np.uint8)
# # tgt_array = np.frombuffer(tgt_bytes, np.uint8)
# # src_bgr = cv2.imdecode(src_array, cv2.IMREAD_COLOR)
# # tgt_bgr = cv2.imdecode(tgt_array, cv2.IMREAD_COLOR)
# # if src_bgr is None or tgt_bgr is None:
# # #end_time = datetime.utcnow()
# # end_time = datetime.now(timezone.utc)
# # await log_faceswap_hit(user_email, status="error", start_time=start_time, end_time=end_time)
# # raise HTTPException(status_code=400, detail="Invalid image data")
# # src_rgb = cv2.cvtColor(src_bgr, cv2.COLOR_BGR2RGB)
# # tgt_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
# # final_img, final_path, err = face_swap_and_enhance(src_rgb, tgt_rgb)
# # if err:
# # end_time = datetime.utcnow()
# # NOTE(review): naive utcnow() here, while start_time is tz-aware (line above uses
# # datetime.now(timezone.utc)) — subtracting the two in log_faceswap_hit would raise
# # TypeError. The surrounding lines were already migrated; this one was missed.
# # await log_faceswap_hit(user_email, status="error", start_time=start_time, end_time=end_time)
# # raise HTTPException(status_code=500, detail=err)
# # with open(final_path, "rb") as f:
# # result_bytes = f.read()
# # result_key = f"bikini-theme/result/{uuid.uuid4().hex}_enhanced.png"
# # result_url = upload_to_spaces(result_bytes, result_key, content_type="image/png")
# # #end_time = datetime.utcnow()
# # end_time = datetime.now(timezone.utc)
# # await log_faceswap_hit(user_email, status="success", start_time=start_time, end_time=end_time)
# # return {"result_url": result_url}
# # except Exception as e:
# # #end_time = datetime.utcnow()
# # end_time = datetime.now(timezone.utc)
# # # Ensure we log the error with timestamps before raising
# # try:
# # await log_faceswap_hit(user_email, status="error", start_time=start_time, end_time=end_time)
# # except Exception as log_exc:
# # logger.error("Failed to write log_faceswap_hit: %s", log_exc)
# # raise HTTPException(status_code=500, detail=f"Face swap failed: {str(e)}")
# @fastapi_app.get("/preview/{result_key:path}")
# async def preview_result(result_key: str):
# try:
# img_bytes = download_from_spaces(result_key)
# except Exception:
# raise HTTPException(status_code=404, detail="Result not found")
# return Response(
# content=img_bytes,
# media_type="image/png",
# headers={"Content-Disposition": "inline; filename=result.png"}
# )
# # --------------------- Mount Gradio ---------------------
# fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
# if __name__ == "__main__":
# uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)