# Gradio app: generate a new image conditioned on the face found in an uploaded photo.
# InsightFace extracts a face-identity embedding, which is fed to an IP-Adapter loaded
# on top of Stable Diffusion 2.1. Everything runs on CPU in float32.
import logging
import os
import time

# Allow network access. huggingface_hub reads HF_HUB_OFFLINE into a module constant at
# import time, so this must be set BEFORE huggingface_hub / diffusers are imported.
os.environ["HF_HUB_OFFLINE"] = "0"

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline  # Use SD 2.1 instead of SDXL
from huggingface_hub import hf_hub_download
from insightface.app import FaceAnalysis

# Set up detailed logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set device to CPU
device = "cpu"
dtype = torch.float32

# Define cache directory
cache_dir = "./cache"
os.makedirs(cache_dir, exist_ok=True)

# Load face encoder
logger.info("Starting InsightFace initialization...")
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    face_app.prepare(ctx_id=0, det_size=(480, 480))
    logger.info("InsightFace model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load InsightFace model: {e}")
    raise


def download_file(repo_id, filename, local_dir, max_retries=3):
    """Return a local path to ``filename`` from the Hub repo ``repo_id``.

    If ``local_dir/filename`` already exists it is returned as-is; otherwise the file
    is downloaded with ``hf_hub_download``, retrying up to ``max_retries`` times with a
    5-second pause between failed attempts.

    Args:
        repo_id: Hugging Face Hub repository id.
        filename: File to fetch from the repository.
        local_dir: Directory the file is placed in.
        max_retries: Number of download attempts; must be >= 1.

    Returns:
        The path of the local file.

    Raises:
        ValueError: If ``max_retries`` < 1 (previously this silently returned None).
        RuntimeError: If every download attempt fails (chained to the last error).
    """
    file_path = os.path.join(local_dir, filename)
    if os.path.exists(file_path):
        logger.info(f"Using cached file at {file_path}")
        return file_path

    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(max_retries):
        logger.info(f"Attempt {attempt + 1}/{max_retries}: Downloading {filename} from {repo_id} to {local_dir}...")
        try:
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=local_dir,
                cache_dir=cache_dir,
                local_files_only=False,
            )
        except Exception as e:
            logger.error(f"Download attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                logger.info("Retrying in 5 seconds...")
                time.sleep(5)
            else:
                raise RuntimeError(f"Failed to download {filename} after {max_retries} attempts: {e}") from e
        else:
            logger.info(f"Downloaded to {downloaded_path}")
            return downloaded_path


def _face_area(face):
    """Return the pixel area of an InsightFace detection's (x1, y1, x2, y2) bounding box."""
    x1, y1, x2, y2 = face["bbox"][:4]
    return (x2 - x1) * (y2 - y1)


# Define paths
ip_adapter_path = "./"
os.makedirs(ip_adapter_path, exist_ok=True)

# Download IP-Adapter weights with retries
logger.info("Starting weights download...")
ip_adapter_weights = download_file(
    "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus", "ipa-faceid-plus.bin", ip_adapter_path
)

# Load the pipeline with SD 2.1 (retry transient network failures)
logger.info("Loading Stable Diffusion 2.1 base model...")
max_retries = 3
for attempt in range(max_retries):
    try:
        logger.info(f"Attempt {attempt + 1}/{max_retries}: Loading SD 2.1 model...")
        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            torch_dtype=dtype,
            safety_checker=None,
            local_files_only=False,
            cache_dir=cache_dir,
            # Fetch the smaller fp16 weight files; torch_dtype loads them as float32 for CPU.
            variant="fp16",
            use_safetensors=True,
        )
        logger.info("SD 2.1 base model loaded successfully.")
        break
    except Exception as e:
        logger.error(f"Load attempt {attempt + 1} failed: {e}")
        if attempt < max_retries - 1:
            logger.info("Retrying in 5 seconds...")
            time.sleep(5)
        else:
            raise RuntimeError(f"Failed to load SD 2.1 model after {max_retries} attempts: {e}") from e

# Load IP-Adapter
# NOTE(review): this adapter was published for Kolors, not SD 2.1; its projection sizes
# likely do not match SD 2.1's UNet, and "FaceID-Plus" variants normally also need a
# CLIP face-crop embedding. Confirm compatibility or switch to an adapter trained for
# this base model.
logger.info(f"Loading IP-Adapter from {ip_adapter_weights}...")
try:
    pipe.load_ip_adapter(
        ip_adapter_path,
        subfolder=None,
        weight_name="ipa-faceid-plus.bin",
        # FaceID adapters are driven by InsightFace embeddings; don't try to load a
        # CLIP image encoder from the adapter folder (none is downloaded there).
        image_encoder_folder=None,
    )
    logger.info("IP-Adapter loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load IP-Adapter: {e}")
    raise

# Move pipeline to CPU
logger.info("Moving pipeline to CPU...")
pipe.to(device)
logger.info("Pipeline moved to CPU.")


def generate_image(uploaded_image, prompt):
    """Generate an image guided by the most prominent face in ``uploaded_image``.

    Args:
        uploaded_image: PIL image from the Gradio upload widget, or None when the user
            submitted without uploading.
        prompt: Text prompt for the diffusion model.

    Returns:
        A ``(status_message, image_or_None)`` tuple. Failures are reported through the
        status message instead of being raised, so the UI always gets a response.
    """
    logger.info("Starting image generation...")
    if uploaded_image is None:
        logger.warning("No image uploaded.")
        return "Please upload an image!", None
    try:
        # Force 3-channel RGB (RGBA / greyscale uploads) before swapping to OpenCV's BGR.
        img = cv2.cvtColor(np.array(uploaded_image.convert("RGB")), cv2.COLOR_RGB2BGR)
        faces = face_app.get(img)
        if not faces:
            logger.warning("No face detected in uploaded image.")
            return "No face detected!", None

        # Use the largest face rather than an arbitrary list position.
        face_info = max(faces, key=_face_area)

        # diffusers' FaceID recipe: an L2-normalised identity embedding shaped
        # (batch, num_images, 1, dim), with an all-zeros copy prepended as the
        # unconditional half used by classifier-free guidance (guidance_scale > 1).
        face_emb = np.asarray(face_info["embedding"], dtype=np.float32)
        face_emb = face_emb / np.linalg.norm(face_emb)
        id_embed = torch.from_numpy(face_emb).reshape(1, 1, 1, -1)
        id_embeds = torch.cat([torch.zeros_like(id_embed), id_embed]).to(device=device, dtype=dtype)

        logger.info(f"Generating image with prompt: {prompt}")
        image = pipe(
            prompt=prompt,
            # `image_embeds` is not a StableDiffusionPipeline argument (its **kwargs would
            # swallow it); the IP-Adapter input is `ip_adapter_image_embeds`.
            ip_adapter_image_embeds=[id_embeds],
            num_inference_steps=10,
            guidance_scale=7.5,
            height=256,
            width=256,
        ).images[0]
        logger.info("Image generated successfully.")
        return "Image generated successfully!", image
    except Exception as e:
        # logger.exception keeps the traceback, which the status string alone loses.
        logger.exception(f"Generation failed: {e}")
        return f"Generation failed: {e}", None


# Gradio interface
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Image(type="pil", label="Upload Reference Image"),
        gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space"),
    ],
    outputs=[
        gr.Textbox(label="Status"),
        gr.Image(label="Generated Image"),
    ],
    title="Face Reference Image Generator",
    description="Upload an image with a face and generate a new image.",
)

logger.info("Launching Gradio interface...")
interface.launch()