# Author: LogicGoInfotechSpaces
# Commit: "Update app.py" (ae2a6cd, verified)
import os
# Set before the numeric/ML libraries below are imported; OpenMP runtimes
# typically read this variable when they initialise.
os.environ["OMP_NUM_THREADS"] = "1"
import logging
import secrets
import shutil
import subprocess
import sys
import threading
import time
import uuid
from datetime import datetime

import cv2
import numpy as np
import requests
from pymongo import MongoClient
import insightface
from insightface.app import FaceAnalysis
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Security, Form, Response
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.concurrency import run_in_threadpool
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr
from gradio import mount_gradio_app
# --------------------- LOGGING ---------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --------------------- TARGET IMAGE BASE URL ---------------------
# Public location of the template ("target") images fetched per request.
TARGET_BASE_URL = "https://halloween-image-generation.onrender.com/garment_templates"
# --------------------- PATHS ---------------------
REPO_ID = "HariLogicgo/face_swap_models"
MODELS_DIR = "./models"
os.makedirs(MODELS_DIR, exist_ok=True)
# --------------------- SECRETS ---------------------
# All credentials come from the environment; none are hard-coded in source.
HF_TOKEN = os.getenv("HF_TOKEN")
API_SECRET_TOKEN = os.getenv("API_SECRET_TOKEN")
# SECURITY: a MongoDB connection string with an embedded password used to be
# committed in this file. Treat that password as leaked and rotate it, then
# provide the new connection string via the MONGO_URL environment variable.
MONGO_URL = os.getenv("MONGO_URL")
if not MONGO_URL:
    raise RuntimeError(
        "MONGO_URL environment variable is not set; it is required for API request logging"
    )
# --------------------- MONGODB ---------------------
mongo_client = MongoClient(MONGO_URL)
mongo_db = mongo_client["halloween_db"]
# The /face-swap handler writes one log document per call into this collection.
api_logs_collection = mongo_db["api_logs"]
# --------------------- DOWNLOAD MODELS ---------------------
def download_models():
    """Fetch the face-swap model weights from the Hugging Face Hub.

    Every file is pulled from ``REPO_ID`` into ``MODELS_DIR`` (keeping the
    repository's ``models/...`` sub-path) and authenticated with ``HF_TOKEN``.

    Returns:
        Local filesystem path of ``inswapper_128.onnx``.
    """
    logger.info("Downloading models...")

    def _fetch(repo_path):
        # Shared keyword arguments for every hub download.
        return hf_hub_download(
            repo_id=REPO_ID,
            filename=repo_path,
            repo_type="model",
            local_dir=MODELS_DIR,
            token=HF_TOKEN,
        )

    swapper_model_path = _fetch("models/inswapper_128.onnx")
    # Files of the "buffalo_l" pack, later loaded by
    # FaceAnalysis(name="buffalo_l", root=MODELS_DIR).
    for onnx_name in (
        "1k3d68.onnx",
        "2d106det.onnx",
        "genderage.onnx",
        "det_10g.onnx",
        "w600k_r50.onnx",
    ):
        _fetch(f"models/buffalo_l/{onnx_name}")
    return swapper_model_path
inswapper_path = download_models()
# --------------------- FACE ANALYSIS ---------------------
# onnxruntime execution providers in priority order: CUDA first, CPU fallback.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
face_analysis_app = FaceAnalysis(
    name="buffalo_l",
    root=MODELS_DIR,
    providers=providers,
)
face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))
# Swap model loaded from the inswapper file downloaded above.
swapper = insightface.model_zoo.get_model(
    inswapper_path,
    providers=providers,
)
# --------------------- CODEFORMER ---------------------
CODEFORMER_PATH = "CodeFormer/inference_codeformer.py"
CODEFORMER_REPO_URL = "https://github.com/sczhou/CodeFormer.git"


def ensure_codeformer():
    """Clone and set up CodeFormer when the ``CodeFormer`` directory is missing.

    Steps, in order: git clone, pip install of CodeFormer's requirements,
    ``basicsr/setup.py develop``, then the pretrained-weight downloads for
    ``facelib`` and ``CodeFormer``. All commands run from the current working
    directory, exactly as before.

    Python and pip are invoked through ``sys.executable`` so packages land in,
    and scripts run with, the interpreter serving this app rather than
    whichever ``python``/``pip`` happens to be first on PATH.

    Raises:
        subprocess.CalledProcessError: if any setup step exits non-zero.

    NOTE(review): only the directory's existence is checked, so a setup that
    failed after the clone step is not retried on the next start.
    """
    if os.path.exists("CodeFormer"):
        return
    python = sys.executable
    steps = [
        ["git", "clone", CODEFORMER_REPO_URL],
        [python, "-m", "pip", "install", "-r", "CodeFormer/requirements.txt"],
        [python, "CodeFormer/basicsr/setup.py", "develop"],
        [python, "CodeFormer/scripts/download_pretrained_models.py", "facelib"],
        [python, "CodeFormer/scripts/download_pretrained_models.py", "CodeFormer"],
    ]
    for cmd in steps:
        # Argument lists (no shell) pass paths through verbatim.
        subprocess.run(cmd, check=True)


ensure_codeformer()
# --------------------- FASTAPI ---------------------
fastapi_app = FastAPI()
# --------------------- AUTH ---------------------
security = HTTPBearer()


def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """FastAPI dependency: accept only requests bearing ``API_SECRET_TOKEN``.

    Args:
        credentials: Bearer credentials parsed by ``HTTPBearer`` from the
            ``Authorization`` header.

    Returns:
        The presented token string.

    Raises:
        HTTPException: 401 when the token does not match, and for every
            request when ``API_SECRET_TOKEN`` is unset or empty.
    """
    presented = credentials.credentials
    # Constant-time comparison so response timing does not reveal how much
    # of the secret a caller guessed correctly. Comparing UTF-8 bytes also
    # avoids compare_digest's TypeError on non-ASCII str input.
    if not API_SECRET_TOKEN or not secrets.compare_digest(
        presented.encode("utf-8"), API_SECRET_TOKEN.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid Token")
    return presented
# --------------------- FACE SWAP LOGIC ---------------------
# Serialises swaps: every call wipes and reuses the same scratch directory.
swap_lock = threading.Lock()


def _run_codeformer(input_path, output_dir):
    """Run CodeFormer on one image file and return the path of its result."""
    cmd = [
        sys.executable,
        CODEFORMER_PATH,
        "-w", "0.7",
        "--input_path", input_path,
        "--output_path", output_dir,
        "--face_upsample",
    ]
    # Argument list (no shell): paths with spaces or shell metacharacters are
    # passed through verbatim, and the app's own interpreter is used.
    subprocess.run(cmd, check=True)
    final_dir = os.path.join(output_dir, "final_results")
    produced = sorted(os.listdir(final_dir))
    if not produced:
        raise RuntimeError("CodeFormer produced no output")
    return os.path.join(final_dir, produced[0])


def face_swap_and_enhance(src_img, tgt_img, temp_dir="/tmp/faceswap"):
    """Swap a face from ``src_img`` onto ``tgt_img``, then enhance with CodeFormer.

    Both inputs are treated as RGB arrays (they are converted RGB->BGR before
    detection). The first detected face of each image is used. Calls are
    serialised by ``swap_lock`` because each one deletes and recreates
    ``temp_dir``.

    Args:
        src_img: Image supplying the face.
        tgt_img: Image whose face is replaced.
        temp_dir: Scratch directory; removed and recreated on every call.

    Returns:
        ``(final_img, final_path, "")`` on success, with ``final_img`` in RGB;
        ``(None, None, message)`` on any failure (no face, CodeFormer error,
        unreadable output, ...). ``final_path`` lies inside ``temp_dir`` and
        is therefore only valid until the next call.
    """
    try:
        with swap_lock:
            shutil.rmtree(temp_dir, ignore_errors=True)
            os.makedirs(temp_dir, exist_ok=True)
            src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
            tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
            src_faces = face_analysis_app.get(src_bgr)
            tgt_faces = face_analysis_app.get(tgt_bgr)
            if not src_faces or not tgt_faces:
                return None, None, "Face not detected"
            swapped = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
            swapped_path = os.path.join(temp_dir, "swap.jpg")
            if not cv2.imwrite(swapped_path, swapped):
                raise RuntimeError(f"Failed to write {swapped_path}")
            final_path = _run_codeformer(swapped_path, temp_dir)
            restored = cv2.imread(final_path)
            if restored is None:
                raise RuntimeError(f"Failed to read CodeFormer output {final_path}")
            final_img = cv2.cvtColor(restored, cv2.COLOR_BGR2RGB)
            return final_img, final_path, ""
    except Exception as e:
        # Errors are reported to the caller as a message; log the traceback
        # here so failures remain diagnosable server-side.
        logger.exception("Face swap failed")
        return None, None, str(e)
# --------------------- TARGET LOAD FROM PUBLIC URL ---------------------
async def load_target_from_url(filename_or_index: str):
    """Download a target template image and return it as an RGB array.

    Args:
        filename_or_index: A bare number ("1" -> "1.png") or a full file
            name, resolved relative to ``TARGET_BASE_URL``.

    Returns:
        The decoded image, converted from BGR to RGB.

    Raises:
        HTTPException: 502 if the HTTP request itself fails (connection
            error, timeout, ...); 404 if the server does not answer 200;
            400 if the payload cannot be decoded as an image.
    """
    if filename_or_index.isdigit():
        filename = f"{filename_or_index}.png"
    else:
        filename = filename_or_index
    url = f"{TARGET_BASE_URL}/{filename}"
    logger.info("Fetching target from: %s", url)
    try:
        # requests is blocking; run it in the threadpool so a slow download
        # (up to the 15 s timeout) does not stall the event loop.
        resp = await run_in_threadpool(requests.get, url, timeout=15)
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail="Failed to fetch target image") from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=404, detail="Target image not found")
    arr = np.frombuffer(resp.content, np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=400, detail="Invalid target image")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# --------------------- API ---------------------
def _log_api_call(log_data):
    """Insert one API log document into MongoDB; errors are logged, not raised."""
    try:
        api_logs_collection.insert_one(log_data)
    except Exception:
        # Request logging is best-effort: a MongoDB problem must not replace
        # a finished response (or the request's real error) with a new 500.
        logger.exception("Failed to write API log to MongoDB")


@fastapi_app.post("/face-swap", dependencies=[Depends(verify_token)])
async def face_swap_api(
    source: UploadFile = File(...),
    target: str = Form(...)
):
    """Swap the face from an uploaded image into a target template.

    Form fields:
        source: Uploaded image containing the face to use.
        target: Template number or file name (see ``load_target_from_url``).

    Returns:
        JSON with ``status``, ``preview_url`` (served from
        ``/garment_output``) and the generated ``filename``.

    Raises:
        HTTPException: 400 for an undecodable source image, 500 when the
            swap or saving fails, plus the errors of ``load_target_from_url``.

    Every call, successful or not, is logged to ``api_logs_collection``.
    """
    start_time = time.perf_counter()
    # Pessimistic default, flipped only right before a successful return, so
    # cancelled requests are not recorded as successes.
    status = "failure"
    error_msg = None
    try:
        src_bytes = await source.read()
        src_img = cv2.imdecode(np.frombuffer(src_bytes, np.uint8), cv2.IMREAD_COLOR)
        if src_img is None:
            raise HTTPException(status_code=400, detail="Invalid source image")
        src_rgb = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
        tgt_rgb = await load_target_from_url(target)
        img, _, err = await run_in_threadpool(face_swap_and_enhance, src_rgb, tgt_rgb)
        if err:
            raise HTTPException(status_code=500, detail=err)
        os.makedirs("garment_output", exist_ok=True)
        output_filename = f"{uuid.uuid4()}.webp"
        output_path = os.path.join("garment_output", output_filename)
        if not cv2.imwrite(output_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
            raise HTTPException(status_code=500, detail="Failed to save output image")
        status = "success"
        return {
            "status": "success",
            "preview_url": f"/garment_output/{output_filename}",
            "filename": output_filename
        }
    except Exception as e:
        error_msg = str(e)
        raise
    finally:
        response_time_ms = round((time.perf_counter() - start_time) * 1000, 2)
        # One timestamp so "date" and "time" cannot straddle midnight.
        now = datetime.utcnow()
        log_data = {
            "api": "/face-swap",
            "status": status,
            "date": now.strftime("%Y-%m-%d"),
            "time": now.strftime("%H:%M:%S"),
            "response_time_ms": response_time_ms,
            "target": target,
            "error": error_msg
        }
        # pymongo is blocking; keep the network round-trip off the event loop.
        await run_in_threadpool(_log_api_call, log_data)
# @fastapi_app.post("/face-swap", dependencies=[Depends(verify_token)])
# async def face_swap_api(
# source: UploadFile = File(...),
# target: str = Form(...)
# ):
# src_bytes = await source.read()
# src_arr = np.frombuffer(src_bytes, np.uint8)
# src_img = cv2.imdecode(src_arr, cv2.IMREAD_COLOR)
# if src_img is None:
# raise HTTPException(status_code=400, detail="Invalid source image")
# src_rgb = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
# tgt_rgb = await load_target_from_url(target)
# img, path, err = await run_in_threadpool(
# face_swap_and_enhance,
# src_rgb,
# tgt_rgb
# )
# if err:
# raise HTTPException(status_code=500, detail=err)
# # ---------------- SAVE OUTPUT TO garment_output ----------------
# os.makedirs("garment_output", exist_ok=True)
# output_uuid = str(uuid.uuid4())
# output_filename = f"{output_uuid}.webp"
# output_path = os.path.join("garment_output", output_filename)
# cv2.imwrite(output_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
# return {
# "status": "success",
# "preview_url": f"/garment_output/{output_filename}",
# "filename": output_filename
# }
# --------------------- GRADIO ---------------------
with gr.Blocks() as demo:
    gr.Markdown("## Face Swap (Target from URL)")
    src = gr.Image(type="numpy", label="Upload Face")
    target_name = gr.Textbox(label="Target Number (e.g. 1 or 10)")
    btn = gr.Button("Swap")
    out = gr.Image()
    msg = gr.Textbox()

    def process(src_img, filename):
        """Gradio callback: swap the uploaded face into the chosen target.

        Returns ``(image_or_None, message)``; problems are reported in the
        message box instead of raising into the UI.
        """
        if src_img is None:
            return None, "Please upload a face image"
        name = (filename or "").strip()
        if not name:
            return None, "Please enter a target number"
        # The UI asks for a number ("1"), so append ".png" unless the user
        # already typed an image extension (previously "1.png" -> "1.png.png").
        if not name.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            name = f"{name}.png"
        try:
            tgt = requests.get(f"{TARGET_BASE_URL}/{name}", timeout=15)
        except requests.RequestException as exc:
            return None, f"Failed to fetch target: {exc}"
        if tgt.status_code != 200:
            return None, "Target image not found"
        tgt_bgr = cv2.imdecode(np.frombuffer(tgt.content, np.uint8), cv2.IMREAD_COLOR)
        if tgt_bgr is None:
            return None, "Invalid target image"
        tgt_img = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
        img, _, err = face_swap_and_enhance(src_img, tgt_img)
        return img, err

    btn.click(process, [src, target_name], [out, msg])
# ---------------- STATIC FILES (garment_output) ----------------
# Generated results are saved here by /face-swap and served back under the
# same prefix used in its "preview_url".
os.makedirs("garment_output", exist_ok=True)
fastapi_app.mount(
    path="/garment_output",
    app=StaticFiles(directory="garment_output"),
    name="garment_output",
)
# --------------------- MOUNT GRADIO ---------------------
# Attach the Gradio UI under /gradio; the returned app replaces fastapi_app.
fastapi_app = mount_gradio_app(app=fastapi_app, blocks=demo, path="/gradio")
# --------------------- HEALTH CHECK ---------------------
@fastapi_app.get("/health")
def health_check():
    """Health endpoint: always reports healthy, stamped with the current UTC time."""
    stamp = datetime.utcnow().isoformat() + "Z"
    return dict(status="healthy", service="face-swap", timestamp=stamp)
@fastapi_app.get("/")
def root():
    """Send visitors of the bare root URL to the Gradio UI mounted at /gradio."""
    target_url = "/gradio"
    return RedirectResponse(url=target_url)
# --------------------- RUN ---------------------
if __name__ == "__main__":
    # Serve the FastAPI app (with Gradio mounted) on all interfaces, port 7860.
    uvicorn.run(app=fastapi_app, host="0.0.0.0", port=7860)
# import os
# os.environ["OMP_NUM_THREADS"] = "1"
# import shutil
# import uuid
# import cv2
# import numpy as np
# import threading
# import subprocess
# import logging
# from datetime import datetime
# import requests
# import insightface
# from insightface.app import FaceAnalysis
# from huggingface_hub import hf_hub_download
# from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Security, Form, Response
# from fastapi.responses import RedirectResponse
# from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# from fastapi.concurrency import run_in_threadpool
# import uvicorn
# import gradio as gr
# from gradio import mount_gradio_app
# # --------------------- LOGGING ---------------------
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # --------------------- TARGET IMAGE BASE URL ---------------------
# TARGET_BASE_URL = "https://halloween-image-generation.onrender.com/garment_templates"
# # --------------------- PATHS ---------------------
# REPO_ID = "HariLogicgo/face_swap_models"
# MODELS_DIR = "./models"
# os.makedirs(MODELS_DIR, exist_ok=True)
# # --------------------- SECRETS ---------------------
# HF_TOKEN = os.getenv("HF_TOKEN")
# API_SECRET_TOKEN = os.getenv("API_SECRET_TOKEN")
# # --------------------- DOWNLOAD MODELS ---------------------
# def download_models():
# logger.info("Downloading models...")
# inswapper_path = hf_hub_download(
# repo_id=REPO_ID,
# filename="models/inswapper_128.onnx",
# repo_type="model",
# local_dir=MODELS_DIR,
# token=HF_TOKEN
# )
# buffalo_files = [
# "1k3d68.onnx",
# "2d106det.onnx",
# "genderage.onnx",
# "det_10g.onnx",
# "w600k_r50.onnx"
# ]
# for f in buffalo_files:
# hf_hub_download(
# repo_id=REPO_ID,
# filename=f"models/buffalo_l/{f}",
# repo_type="model",
# local_dir=MODELS_DIR,
# token=HF_TOKEN
# )
# return inswapper_path
# inswapper_path = download_models()
# # --------------------- FACE ANALYSIS ---------------------
# providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
# face_analysis_app = FaceAnalysis(name="buffalo_l", root=MODELS_DIR, providers=providers)
# face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))
# swapper = insightface.model_zoo.get_model(inswapper_path, providers=providers)
# # --------------------- CODEFORMER ---------------------
# CODEFORMER_PATH = "CodeFormer/inference_codeformer.py"
# def ensure_codeformer():
# if not os.path.exists("CodeFormer"):
# subprocess.run("git clone https://github.com/sczhou/CodeFormer.git", shell=True, check=True)
# subprocess.run("pip install -r CodeFormer/requirements.txt", shell=True, check=True)
# subprocess.run("python CodeFormer/basicsr/setup.py develop", shell=True, check=True)
# subprocess.run("python CodeFormer/scripts/download_pretrained_models.py facelib", shell=True, check=True)
# subprocess.run("python CodeFormer/scripts/download_pretrained_models.py CodeFormer", shell=True, check=True)
# ensure_codeformer()
# # --------------------- FASTAPI ---------------------
# fastapi_app = FastAPI()
# # --------------------- AUTH ---------------------
# security = HTTPBearer()
# def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
# if credentials.credentials != API_SECRET_TOKEN:
# raise HTTPException(status_code=401, detail="Invalid Token")
# return credentials.credentials
# # --------------------- FACE SWAP LOGIC ---------------------
# swap_lock = threading.Lock()
# def face_swap_and_enhance(src_img, tgt_img, temp_dir="/tmp/faceswap"):
# try:
# with swap_lock:
# shutil.rmtree(temp_dir, ignore_errors=True)
# os.makedirs(temp_dir, exist_ok=True)
# src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
# tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
# src_faces = face_analysis_app.get(src_bgr)
# tgt_faces = face_analysis_app.get(tgt_bgr)
# if not src_faces or not tgt_faces:
# return None, None, "Face not detected"
# swapped_path = os.path.join(temp_dir, "swap.jpg")
# swapped = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
# cv2.imwrite(swapped_path, swapped)
# cmd = f"python {CODEFORMER_PATH} -w 0.7 --input_path {swapped_path} --output_path {temp_dir} --face_upsample"
# subprocess.run(cmd, shell=True, check=True)
# final_dir = os.path.join(temp_dir, "final_results")
# final_file = os.listdir(final_dir)[0]
# final_path = os.path.join(final_dir, final_file)
# final_img = cv2.cvtColor(cv2.imread(final_path), cv2.COLOR_BGR2RGB)
# return final_img, final_path, ""
# except Exception as e:
# return None, None, str(e)
# # --------------------- TARGET LOAD FROM PUBLIC URL ---------------------
# async def load_target_from_url(filename_or_index: str):
# if filename_or_index.isdigit():
# filename = f"{filename_or_index}.png"
# else:
# filename = filename_or_index
# url = f"{TARGET_BASE_URL}/{filename}"
# logger.info(f"Fetching target from: {url}")
# resp = requests.get(url, timeout=15)
# if resp.status_code != 200:
# raise HTTPException(status_code=404, detail="Target image not found")
# arr = np.frombuffer(resp.content, np.uint8)
# img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
# if img is None:
# raise HTTPException(status_code=400, detail="Invalid target image")
# return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# # --------------------- API ---------------------
# @fastapi_app.post("/face-swap", dependencies=[Depends(verify_token)])
# async def face_swap_api(
# source: UploadFile = File(...),
# target: str = Form(...)
# ):
# src_bytes = await source.read()
# src_arr = np.frombuffer(src_bytes, np.uint8)
# src_img = cv2.imdecode(src_arr, cv2.IMREAD_COLOR)
# if src_img is None:
# raise HTTPException(status_code=400, detail="Invalid source image")
# src_rgb = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
# tgt_rgb = await load_target_from_url(target)
# img, path, err = await run_in_threadpool(
# face_swap_and_enhance,
# src_rgb,
# tgt_rgb
# )
# if err:
# raise HTTPException(status_code=500, detail=err)
# with open(path, "rb") as f:
# data = f.read()
# return Response(content=data, media_type="image/png")
# # --------------------- GRADIO ---------------------
# with gr.Blocks() as demo:
# gr.Markdown("## Face Swap (Target from URL)")
# src = gr.Image(type="numpy", label="Upload Face")
# target_name = gr.Textbox(label="Target Number (e.g. 1 or 10)")
# btn = gr.Button("Swap")
# out = gr.Image()
# msg = gr.Textbox()
# def process(src_img, filename):
# tgt = requests.get(f"{TARGET_BASE_URL}/{filename}.png")
# arr = np.frombuffer(tgt.content, np.uint8)
# tgt_img = cv2.cvtColor(cv2.imdecode(arr, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
# img, _, err = face_swap_and_enhance(src_img, tgt_img)
# return img, err
# btn.click(process, [src, target_name], [out, msg])
# # --------------------- MOUNT ---------------------
# fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
# @fastapi_app.get("/")
# def root():
# return RedirectResponse("/gradio")
# # --------------------- RUN ---------------------
# if __name__ == "__main__":
# uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)