# Source: Hugging Face Space "my3", file app.py -- uploaded by AkashKumarave,
# commit f2f8f12 ("Update app.py", verified), 5.45 kB.
import cv2
import torch
import numpy as np
import gradio as gr
from diffusers import StableDiffusionPipeline # Use SD 2.1 instead of SDXL
from insightface.app import FaceAnalysis
from huggingface_hub import hf_hub_download
import os
import logging
import time
# --- Runtime configuration ---------------------------------------------------
# Route INFO-and-above records through the root logger; every message in this
# script is emitted via the module-level `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Allow network access ("0" = offline mode disabled).
# NOTE(review): this runs after `huggingface_hub` has already been imported;
# confirm the library still honours the value when set this late.
os.environ.update({"HF_HUB_OFFLINE": "0"})

# Everything runs on the CPU in full precision.
device, dtype = "cpu", torch.float32

# Local cache directory passed to the Hub download and pipeline load calls.
cache_dir = "./cache"
os.makedirs(cache_dir, exist_ok=True)
# Load face encoder
# Builds the InsightFace analyser that generate_image() later calls to detect
# faces. Any failure is logged and re-raised, so the script stops here rather
# than continuing without a face encoder.
logger.info("Starting InsightFace initialization...")
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    face_app.prepare(ctx_id=0, det_size=(480, 480))
except Exception as err:
    logger.error(f"Failed to load InsightFace model: {err}")
    raise
else:
    logger.info("InsightFace model loaded successfully.")
# Download function with retry logic
def download_file(repo_id, filename, local_dir, max_retries=3, retry_delay=5):
    """Return a local path to ``filename``, downloading it from the Hub if needed.

    If ``local_dir/filename`` already exists, that path is returned without
    any network access. Otherwise ``hf_hub_download`` is called up to
    ``max_retries`` times, sleeping ``retry_delay`` seconds between attempts.

    Args:
        repo_id: Hub repository to download from.
        filename: Name of the file inside the repository.
        local_dir: Directory checked for an existing copy and passed to
            ``hf_hub_download`` as the download target.
        max_retries: Maximum number of download attempts; must be >= 1.
        retry_delay: Seconds to wait after a failed attempt before retrying.

    Returns:
        The existing ``local_dir/filename`` path, or the path returned by
        ``hf_hub_download``.

    Raises:
        ValueError: If a download is required and ``max_retries`` is < 1
            (previously this case silently returned ``None``).
        RuntimeError: If every attempt fails; chained to the last error.
    """
    file_path = os.path.join(local_dir, filename)
    if os.path.exists(file_path):
        logger.info(f"Using cached file at {file_path}")
        return file_path

    # Without at least one attempt the loop below would never run and the
    # caller would receive None instead of a path.
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(max_retries):
        logger.info(f"Attempt {attempt + 1}/{max_retries}: Downloading {filename} from {repo_id} to {local_dir}...")
        try:
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=local_dir,
                cache_dir=cache_dir,
                local_files_only=False,
            )
        except Exception as e:
            logger.error(f"Download attempt {attempt + 1} failed: {e}")
            if attempt == max_retries - 1:
                # Keep the underlying error attached as __cause__.
                raise RuntimeError(
                    f"Failed to download {filename} after {max_retries} attempts: {e}"
                ) from e
            logger.info(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            logger.info(f"Downloaded to {downloaded_path}")
            return downloaded_path
# Define paths
# The IP-Adapter weights live directly in the working directory.
ip_adapter_path = "./"
os.makedirs(ip_adapter_path, exist_ok=True)

# Download IP-Adapter weights with retries
logger.info("Starting weights download...")
ip_adapter_weights = download_file(
    repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
    filename="ipa-faceid-plus.bin",
    local_dir=ip_adapter_path,
)
# Load the pipeline with SD 2.1
# Up to `max_retries` attempts are made, 5 seconds apart. When the last attempt
# fails, a RuntimeError chained to the underlying exception is logged and
# raised, which aborts the script.
# NOTE(review): variant="fp16" is requested while `dtype` is float32 -- confirm
# this pairing is intended.
logger.info("Loading Stable Diffusion 2.1 base model...")
max_retries = 3
for attempt in range(max_retries):
    try:
        logger.info(f"Attempt {attempt + 1}/{max_retries}: Loading SD 2.1 model...")
        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            torch_dtype=dtype,
            safety_checker=None,
            local_files_only=False,
            cache_dir=cache_dir,
            variant="fp16",
            use_safetensors=True,
        )
    except Exception as e:
        logger.error(f"Load attempt {attempt + 1} failed: {e}")
        if attempt == max_retries - 1:
            failure = RuntimeError(f"Failed to load SD 2.1 model after {max_retries} attempts: {e}")
            logger.error(f"Failed to load SD 2.1 base model: {failure}")
            # Chain the cause so the original traceback is not lost.
            raise failure from e
        logger.info("Retrying in 5 seconds...")
        time.sleep(5)
    else:
        logger.info("SD 2.1 base model loaded successfully.")
        break
# Load IP-Adapter
# The weights file downloaded earlier is read from `ip_adapter_path`; a failure
# is logged and re-raised so the script does not continue without the adapter.
# NOTE(review): the weights come from a Kolors FaceID-Plus repository while the
# pipeline is SD 2.1, and `subfolder=None` is passed -- confirm both are
# accepted by the installed diffusers version.
logger.info(f"Loading IP-Adapter from {ip_adapter_weights}...")
try:
    pipe.load_ip_adapter(
        ip_adapter_path,
        weight_name="ipa-faceid-plus.bin",
        subfolder=None,
    )
except Exception as err:
    logger.error(f"Failed to load IP-Adapter: {err}")
    raise
else:
    logger.info("IP-Adapter loaded successfully.")

# Move pipeline to CPU
logger.info("Moving pipeline to CPU...")
pipe.to(device)
logger.info("Pipeline moved to CPU.")
def generate_image(uploaded_image, prompt):
    """Generate an image from ``prompt`` conditioned on a face in ``uploaded_image``.

    Args:
        uploaded_image: RGB image convertible with ``np.array`` (the Gradio
            input is declared with ``type="pil"``), or ``None`` when nothing
            was uploaded.
        prompt: Text prompt forwarded to the diffusion pipeline.

    Returns:
        A ``(status, image)`` tuple. On success ``image`` is the first image
        produced by the pipeline; when no image was uploaded, no face was
        detected, or generation raised, ``image`` is ``None`` and ``status``
        explains why. Exceptions are reported through ``status``, not raised.
    """
    logger.info("Starting image generation...")

    # Report a missing upload explicitly instead of letting the colour
    # conversion below fail with an opaque error.
    if uploaded_image is None:
        logger.warning("No image uploaded.")
        return "Please upload a reference image.", None

    try:
        # The face detector is fed a BGR array, so convert from RGB first.
        img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
        faces = face_app.get(img)
        if not faces:
            logger.warning("No face detected in uploaded image.")
            return "No face detected!", None

        # As before, the last detected face is the one used.
        face_emb = torch.from_numpy(np.asarray(faces[-1]["embedding"]))

        # The pipeline takes precomputed IP-Adapter embeddings through
        # `ip_adapter_image_embeds`; the previous `image_embeds` keyword is not
        # one of its inputs. Because guidance_scale > 1 enables classifier-free
        # guidance, an all-zero negative embedding is placed first and the face
        # embedding second along the leading axis.
        # NOTE(review): layout follows the diffusers FaceID example; the "Plus"
        # weights may additionally need CLIP image embeddings -- confirm.
        id_embeds = face_emb.reshape(1, 1, 1, -1)
        id_embeds = torch.cat([torch.zeros_like(id_embeds), id_embeds])
        id_embeds = id_embeds.to(device=device, dtype=dtype)

        logger.info(f"Generating image with prompt: {prompt}")
        image = pipe(
            prompt=prompt,
            ip_adapter_image_embeds=[id_embeds],
            num_inference_steps=10,
            guidance_scale=7.5,
            height=256,
            width=256,
        ).images[0]
        logger.info("Image generated successfully.")
        return "Image generated successfully!", image
    except Exception as e:
        # UI boundary: keep the traceback in the log and show the message to
        # the user instead of crashing the request.
        logger.exception(f"Generation failed: {e}")
        return f"Generation failed: {e}", None
# Gradio interface
# The two inputs map positionally onto generate_image(uploaded_image, prompt)
# and the two outputs onto its (status, image) return value.
_reference_image_input = gr.Image(type="pil", label="Upload Reference Image")
_prompt_input = gr.Textbox(
    label="Enter Prompt",
    placeholder="e.g., A photorealistic astronaut in space",
)
_status_output = gr.Textbox(label="Status")
_image_output = gr.Image(label="Generated Image")

interface = gr.Interface(
    fn=generate_image,
    inputs=[_reference_image_input, _prompt_input],
    outputs=[_status_output, _image_output],
    title="Face Reference Image Generator",
    description="Upload an image with a face and generate a new image.",
)

logger.info("Launching Gradio interface...")
interface.launch()