# Source: Hugging Face Space by mr2along — commit 96f6463 ("Update app.py", verified).
import os
# =========================
# FORCE CPU MODE
# =========================
# Hide every CUDA device and blank the CUDA allocator config. This is done
# before ``import torch`` below so torch never sees a GPU.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "",
    "PYTORCH_CUDA_ALLOC_CONF": "",
})
import torch
import sys
import asyncio
import imageio
import tempfile
import numpy as np
import gradio as gr
from typing import Sequence, Mapping, Any, Union
from PIL import Image
from huggingface_hub import hf_hub_download
# =========================
# DOWNLOAD MODELS (ONLY IF NOT EXISTS)
# =========================
def download_if_not_exists(repo, filename, local_dir):
    """Fetch ``filename`` from Hugging Face repo ``repo`` unless a local copy exists.

    Args:
        repo: Hugging Face Hub repository id.
        filename: Path of the file inside the repo. It may contain
            subdirectories (e.g. ``"facerestore_models/GPEN-BFR-512.onnx"``);
            the existence check assumes the download lands at
            ``local_dir/filename``.
        local_dir: Local directory the file is stored under.

    Returns:
        The local path ``os.path.join(local_dir, filename)``, whether the file
        was already present or has just been downloaded.
    """
    path = os.path.join(local_dir, filename)
    if not os.path.exists(path):
        os.makedirs(local_dir, exist_ok=True)
        hf_hub_download(repo_id=repo, filename=filename, local_dir=local_dir)
    return path
# Models the app needs, as (repo_id, filename inside the repo, local_dir).
# Each entry is downloaded only when its local copy is missing.
_REQUIRED_MODELS = [
    ("ezioruan/inswapper_128.onnx", "inswapper_128.onnx", "models/insightface"),
    ("martintomov/comfy", "facerestore_models/GPEN-BFR-512.onnx", "models"),
    ("facefusion/models-3.3.0", "hyperswap_1a_256.onnx", "models/hyperswap"),
    ("facefusion/models-3.3.0", "hyperswap_1b_256.onnx", "models/hyperswap"),
    ("facefusion/models-3.3.0", "hyperswap_1c_256.onnx", "models/hyperswap"),
]
for _repo_id, _model_file, _target_dir in _REQUIRED_MODELS:
    download_if_not_exists(_repo_id, _model_file, _target_dir)
# =========================
# COMFY INIT (KEEP AS-IS)
# =========================
# NOTE(review): ``comfy`` is imported here, before the ComfyUI directory is
# appended to sys.path further down, so this presumably relies on ``comfy``
# already being importable (e.g. the script living in the ComfyUI root) — confirm.
import comfy.model_management
from comfy.model_management import CPUState

# Pin ComfyUI's model management to the CPU, in line with the CUDA device
# hiding done at the top of this file.
comfy.model_management.cpu_state = CPUState.CPU
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return output ``index`` of a ComfyUI node result.

    Results that are directly indexable (e.g. a tuple of outputs) are indexed
    as-is; results that cannot be indexed by position (e.g. a mapping) are
    expected to keep their outputs under a ``"result"`` key.

    Raises:
        IndexError: ``index`` is out of range for a sequence result.
        KeyError: ``obj`` is a mapping without a ``"result"`` entry.
    """
    try:
        return obj[index]
    except (KeyError, TypeError):
        # Only fall back for "not positionally indexable"; a genuine
        # IndexError on a tuple result propagates with its real message.
        return obj["result"][index]
def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
    """Search ``path`` and each of its ancestors for an entry called ``name``.

    Args:
        name: File or directory name to look for.
        path: Directory to start from; defaults to the current working
            directory. Relative paths are resolved against the cwd.

    Returns:
        The joined path of the first match, or ``None`` once the filesystem
        root has been checked without finding ``name``.
    """
    # abspath() is required: os.path.dirname() of a relative path such as "."
    # is "", and os.listdir("") raises FileNotFoundError.
    current = os.path.abspath(os.getcwd() if path is None else path)
    while True:
        try:
            entries = os.listdir(current)
        except PermissionError:
            # An unreadable ancestor cannot contain a visible match; keep climbing.
            entries = ()
        if name in entries:
            return os.path.join(current, name)
        parent = os.path.dirname(current)
        if parent == current:
            return None
        current = parent
def add_comfyui_directory_to_sys_path():
    """Append the nearest ``ComfyUI`` directory (cwd or an ancestor) to ``sys.path``.

    Nothing is changed when no such directory is found, or when the match is
    not a directory.
    """
    candidate = find_path("ComfyUI")
    if not candidate or not os.path.isdir(candidate):
        return
    sys.path.append(candidate)


add_comfyui_directory_to_sys_path()
def add_extra_model_paths():
    """Load ``extra_model_paths.yaml`` if one exists in the cwd or an ancestor.

    ``load_extra_path_config`` is imported from ``main`` when that import
    succeeds, otherwise from ``utils.extra_config``.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        from utils.extra_config import load_extra_path_config

    config_file = find_path("extra_model_paths.yaml")
    if not config_file:
        return
    load_extra_path_config(config_file)


add_extra_model_paths()
def import_custom_nodes() -> None:
    """Run ComfyUI's ``init_extra_nodes`` on a freshly created event loop.

    Presumably this is what registers custom nodes such as ``ReActorFaceSwap``
    in ``nodes.NODE_CLASS_MAPPINGS``, which is read right after this call.
    """
    import execution
    from nodes import init_extra_nodes
    import server

    # Create a new event loop and install it as the current loop; it is not
    # closed here.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # NOTE(review): neither object is kept by this function; presumably they
    # are constructed for their side effects on ComfyUI's global state — confirm.
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    # init_extra_nodes() is an awaitable, so it is driven to completion here.
    loop.run_until_complete(init_extra_nodes())


import_custom_nodes()
from nodes import NODE_CLASS_MAPPINGS

# Module-level node instances used by generate_image(); each class is looked
# up by its registered ComfyUI name and instantiated once.
_LoadImageNode = NODE_CLASS_MAPPINGS["LoadImage"]
loadimage = _LoadImageNode()
_ReActorFaceSwapNode = NODE_CLASS_MAPPINGS["ReActorFaceSwap"]
reactorfaceswap = _ReActorFaceSwapNode()
# =========================
# MAIN FUNCTION
# =========================
def _input_file_path(file_obj):
    """Return the filesystem path of one uploaded file.

    Accepts path strings / ``os.PathLike`` objects as well as objects that
    expose the path as ``.name`` (e.g. tempfile wrappers); which of these the
    UI passes in depends on the Gradio version.
    """
    # ``Path.name`` is only the basename, so PathLike has to be handled first.
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def _file_stem(path):
    """Return the basename of ``path`` without its extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _claim_output_path(stem, used):
    """Return ``output/<stem>.webp``, adding ``_1``, ``_2``, ... if already in ``used``.

    The chosen path is recorded in ``used`` so that later outputs of the same
    batch cannot overwrite it (e.g. ``a.png`` and ``a.jpg`` share stem ``a``).
    """
    candidate = f"output/{stem}.webp"
    counter = 1
    while candidate in used:
        candidate = f"output/{stem}_{counter}.webp"
        counter += 1
    used.add(candidate)
    return candidate


def _run_reactor_swap(target_tensor, source_tensor, *, target_index,
                      swap_model, face_restore_model, restore_strength):
    """Run the ReActor face-swap node once and return its first output image."""
    result = reactorfaceswap.execute(
        enabled=True,
        swap_model=swap_model,
        facedetection="YOLOv5l",
        face_restore_model=face_restore_model,
        face_restore_visibility=restore_strength,
        codeformer_weight=0.5,
        detect_gender_input="no",
        detect_gender_source="no",
        input_faces_index=str(target_index),
        source_faces_index="0",
        console_log_level=1,
        input_image=target_tensor,
        source_image=source_tensor,
    )
    # Output slot 0, then the first image of that output.
    return get_value_at_index(result, 0)[0]


def _to_rgb_image(swapped):
    """Convert one node output image (tensor or array) into an RGB PIL image.

    Values are rescaled by 255 when their maximum is <= 1.0 (0-1 floats),
    otherwise treated as 0-255; the result is clipped and cast to uint8.
    """
    if isinstance(swapped, torch.Tensor):
        swapped = swapped.detach().cpu().float().numpy()
    # ``.numpy()`` can share memory with the node's output tensor, so scale
    # out of place rather than with ``*=``.
    array = np.asarray(swapped, dtype=np.float32)
    if array.max() <= 1.0:
        array = array * 255.0
    return Image.fromarray(np.clip(array, 0, 255).astype(np.uint8)).convert("RGB")


def _read_gif_frames(path):
    """Decode every GIF frame as an RGB array, along with its duration (default 100)."""
    frames, durations = [], []
    # The context manager closes the reader even if decoding a frame fails.
    with imageio.get_reader(path) as reader:
        for i, frame in enumerate(reader):
            frames.append(np.array(Image.fromarray(frame).convert("RGB")))
            durations.append(reader.get_meta_data(index=i).get("duration", 100))
    return frames, durations


def _swap_gif_frame(frame, source_tensor, swap_kwargs):
    """Face-swap one RGB frame, loading it through LoadImage via a temp PNG."""
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        Image.fromarray(frame).save(temp_path)
        target_tensor = get_value_at_index(loadimage.load_image(image=temp_path), 0)
        swapped = _run_reactor_swap(target_tensor, source_tensor, **swap_kwargs)
    finally:
        # Remove the temp file even when loading or swapping raises.
        os.remove(temp_path)
    return _to_rgb_image(swapped)


def _swap_gif(target_path, source_tensor, swap_kwargs, output_path):
    """Swap faces in every frame of a GIF and save the result as an animated WebP."""
    frames, durations = _read_gif_frames(target_path)
    if not frames:
        raise ValueError(f"No frames could be read from GIF: {target_path}")
    output_frames = [
        _swap_gif_frame(frame, source_tensor, swap_kwargs) for frame in frames
    ]
    output_frames[0].save(
        output_path,
        save_all=True,
        append_images=output_frames[1:],
        duration=durations,
        loop=0,
        format="WEBP",
        quality=90,
        method=6,
    )


def _swap_still_image(target_path, source_tensor, swap_kwargs, output_path):
    """Swap faces in a single still image and save it as WebP."""
    target_tensor = get_value_at_index(loadimage.load_image(image=target_path), 0)
    swapped = _run_reactor_swap(target_tensor, source_tensor, **swap_kwargs)
    _to_rgb_image(swapped).save(output_path, format="WEBP", quality=90, method=6)


def generate_image(source_files, target_files, target_index,
                   swap_model, face_restore_model, restore_strength):
    """Swap each source face into each target image or GIF.

    Args:
        source_files: Uploaded source-face files (path strings or objects with
            a ``.name`` path).
        target_files: Uploaded target files; names ending in ``.gif``
            (case-insensitive) are processed frame by frame.
        target_index: Passed to ReActor as ``input_faces_index``.
        swap_model: ReActor swap model file name.
        face_restore_model: ReActor face-restore model name (or ``"none"``).
        restore_strength: Passed to ReActor as ``face_restore_visibility``.

    Returns:
        Paths of the WebP files written under ``output/``, one per
        (source, target) pair, or ``[]`` when either input is empty.

    Raises:
        ValueError: A target GIF yields no frames.
    """
    if not source_files or not target_files:
        return []
    os.makedirs("output", exist_ok=True)

    swap_kwargs = {
        "target_index": target_index,
        "swap_model": swap_model,
        "face_restore_model": face_restore_model,
        "restore_strength": restore_strength,
    }
    output_paths = []
    used_paths = set()
    with torch.inference_mode():
        for source in source_files:
            source_path = _input_file_path(source)
            source_tensor = get_value_at_index(loadimage.load_image(image=source_path), 0)
            source_stem = _file_stem(source_path)
            for target in target_files:
                target_path = _input_file_path(target)
                output_path = _claim_output_path(
                    f"{source_stem}_to_{_file_stem(target_path)}", used_paths
                )
                if target_path.lower().endswith(".gif"):
                    _swap_gif(target_path, source_tensor, swap_kwargs, output_path)
                else:
                    _swap_still_image(target_path, source_tensor, swap_kwargs, output_path)
                output_paths.append(output_path)
    return output_paths
# =========================
# GRADIO UI
# =========================
with gr.Blocks() as app:
    source_files = gr.File(label="Source Faces", file_count="multiple", interactive=True)
    target_files = gr.File(label="Target Images / GIFs", file_count="multiple", interactive=True)
    swap_model = gr.Dropdown(
        choices=[
            "inswapper_128.onnx",
            "hyperswap_1a_256.onnx",
            "hyperswap_1b_256.onnx",
            "hyperswap_1c_256.onnx",
        ],
        value="hyperswap_1b_256.onnx",
        label="Swap Model",
    )
    face_restore_model = gr.Dropdown(
        choices=["none", "GPEN-BFR-512.onnx"],
        value="none",
        label="Face Restore Model",
    )
    # Forwarded to ReActor as face_restore_visibility.
    restore_strength = gr.Slider(0, 1, 0.7, step=0.05, label="Face Restore Strength")
    # Forwarded to ReActor as input_faces_index.
    target_index = gr.Dropdown([0, 1, 2, 3, 4], value=0, label="Target Face Index")
    generate_btn = gr.Button("Generate")
    output_files = gr.Files(label="Output WebPs")
    generate_btn.click(
        fn=generate_image,
        inputs=[source_files, target_files, target_index,
                swap_model, face_restore_model, restore_strength],
        outputs=output_files,
    )
app.launch(share=True)