Spaces:
Running
Running
| import cv2 | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from diffusers import StableDiffusionPipeline # Use SD 2.1 instead of SDXL | |
| from insightface.app import FaceAnalysis | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| import logging | |
| import time | |
# Logging: INFO level so model download / load progress shows up in the Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Explicitly opt out of Hugging Face offline mode so downloads are allowed.
# NOTE(review): huggingface_hub reads HF_HUB_OFFLINE when it is first imported, and both
# `diffusers` and `huggingface_hub` are imported above this line, so this assignment
# probably does not affect the already-loaded library. To be effective it likely needs to
# run before those imports (or be set in the Space's environment) — confirm.
os.environ["HF_HUB_OFFLINE"] = "0"

# Everything runs on the CPU in full precision.
device, dtype = "cpu", torch.float32

# Local Hugging Face cache shared by every download below; created up front.
cache_dir = "./cache"
os.makedirs(name=cache_dir, exist_ok=True)
# Load the InsightFace face analyser (face detection + identity embedding models).
logger.info("Starting InsightFace initialization...")
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    # InsightFace interprets ctx_id < 0 as "run on CPU" and ctx_id >= 0 as a GPU index.
    # Derive it from `device` so the intent matches the CPU-only providers above,
    # instead of nominally requesting GPU 0.
    face_app.prepare(ctx_id=-1 if device == "cpu" else 0, det_size=(480, 480))
    logger.info("InsightFace model loaded successfully.")
except Exception as e:
    # The app cannot work without the face model: log, then abort startup.
    logger.error(f"Failed to load InsightFace model: {e}")
    raise
# Download helper with retry logic
def download_file(repo_id, filename, local_dir, max_retries=3, retry_delay=5):
    """Return a local path to ``filename`` from Hub repo ``repo_id``, downloading it if absent.

    If ``local_dir/filename`` already exists it is returned without contacting the
    Hub. Otherwise ``hf_hub_download`` is attempted up to ``max_retries`` times,
    sleeping ``retry_delay`` seconds between failed attempts. The module-level
    ``cache_dir`` is used as the Hub cache.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. ``"org/name"``.
        filename: Path of the file inside the repository.
        local_dir: Directory the file is placed in.
        max_retries: Total number of download attempts; must be at least 1.
        retry_delay: Seconds to wait between attempts.

    Returns:
        The local path of the file (the path reported by ``hf_hub_download`` when
        a download happened).

    Raises:
        ValueError: If a download is needed and ``max_retries`` is less than 1.
        RuntimeError: If every download attempt fails; the last error is chained
            as ``__cause__``.
    """
    file_path = os.path.join(local_dir, filename)
    if os.path.exists(file_path):
        logger.info(f"Using cached file at {file_path}")
        return file_path

    # Without this check the retry loop would not run and the function would
    # silently return None, which only surfaces later as a confusing error.
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    for attempt in range(max_retries):
        logger.info(f"Attempt {attempt + 1}/{max_retries}: Downloading {filename} from {repo_id} to {local_dir}...")
        try:
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=local_dir,
                cache_dir=cache_dir,
                local_files_only=False,
            )
        except Exception as e:
            logger.error(f"Download attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                raise RuntimeError(
                    f"Failed to download {filename} after {max_retries} attempts: {e}"
                ) from e
        else:
            logger.info(f"Downloaded to {downloaded_path}")
            return downloaded_path
# The IP-Adapter checkpoint is placed in the current working directory.
ip_adapter_path = "./"
os.makedirs(ip_adapter_path, exist_ok=True)  # no-op for "./"; keeps other paths working

# Fetch the checkpoint (or reuse an existing copy); download_file retries on failure.
logger.info("Starting weights download...")
ip_adapter_weights = download_file(
    repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
    filename="ipa-faceid-plus.bin",
    local_dir=ip_adapter_path,
)
# Load the Stable Diffusion 2.1 pipeline, retrying transient (e.g. network) failures.
logger.info("Loading Stable Diffusion 2.1 base model...")
max_retries = 3
retry_delay = 5  # seconds to wait between attempts
for attempt in range(max_retries):
    try:
        logger.info(f"Attempt {attempt + 1}/{max_retries}: Loading SD 2.1 model...")
        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            torch_dtype=dtype,
            safety_checker=None,
            local_files_only=False,
            cache_dir=cache_dir,
            # The fp16 weight files are fetched (presumably to shrink the download) and
            # loaded as `dtype` (float32 here) for CPU inference.
            variant="fp16",
            use_safetensors=True,
        )
    except Exception as e:
        logger.error(f"Load attempt {attempt + 1} failed: {e}")
        if attempt < max_retries - 1:
            logger.info(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            # Final failure: abort startup, keeping the original error as the cause.
            raise RuntimeError(f"Failed to load SD 2.1 model after {max_retries} attempts: {e}") from e
    else:
        logger.info("SD 2.1 base model loaded successfully.")
        break
# Attach the IP-Adapter weights to the pipeline, then place the pipeline on `device`.
# NOTE(review): the weights come from Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus, an adapter
# published for the Kolors base model rather than SD 2.1. Its projection/attention shapes
# likely do not match the SD 2.1 UNet, so this load (or generation) may fail. Confirm
# the adapter/base-model pairing.
logger.info(f"Loading IP-Adapter from {ip_adapter_weights}...")
try:
    pipe.load_ip_adapter(
        ip_adapter_path,
        subfolder=None,
        weight_name="ipa-faceid-plus.bin",
        # FaceID adapters are conditioned on InsightFace embeddings passed via
        # `ip_adapter_image_embeds`, not on a bundled CLIP image encoder. Without this,
        # diffusers tries to load an `image_encoder` folder (built from `subfolder`, which
        # is None here) that does not exist next to the weights.
        image_encoder_folder=None,
    )
    logger.info("IP-Adapter loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load IP-Adapter: {e}")
    raise

# Move pipeline to CPU
logger.info("Moving pipeline to CPU...")
pipe.to(device)
logger.info("Pipeline moved to CPU.")
def generate_image(uploaded_image, prompt):
    """Generate an image from ``prompt``, conditioned on the face in ``uploaded_image``.

    The largest face InsightFace detects in the upload supplies the identity
    embedding that is passed to the pipeline's IP-Adapter.

    Args:
        uploaded_image: Image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submitted without uploading anything.
        prompt: Text prompt for the diffusion pipeline.

    Returns:
        A ``(status_message, image_or_None)`` tuple matching the two Gradio outputs.
        Failures are reported through the status message rather than raised.
    """
    logger.info("Starting image generation...")
    if uploaded_image is None:
        logger.warning("No image uploaded.")
        return "Please upload an image!", None
    try:
        # InsightFace needs a 3-channel BGR array; normalise grayscale/RGBA PIL uploads first.
        if hasattr(uploaded_image, "convert"):
            uploaded_image = uploaded_image.convert("RGB")
        img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
        faces = face_app.get(img)
        if not faces:
            logger.warning("No face detected in uploaded image.")
            return "No face detected!", None

        # Detections are not ordered by size (presumably by detector score), so
        # faces[-1] is not reliably the subject. Pick the largest bbox (x1, y1, x2, y2).
        face_info = max(
            faces,
            key=lambda f: (f["bbox"][2] - f["bbox"][0]) * (f["bbox"][3] - f["bbox"][1]),
        )
        face_emb = torch.from_numpy(np.asarray(face_info["embedding"], dtype=np.float32))

        # diffusers takes one embeddings tensor per loaded IP-Adapter via
        # `ip_adapter_image_embeds`. With classifier-free guidance (guidance_scale > 1) the
        # tensor holds the negative half first, then the positive half; zeros serve as the
        # negative. The (2, 1, 1, dim) layout follows diffusers' FaceID example — TODO
        # confirm for this adapter.
        # NOTE(review): "FaceID-Plus" adapters may additionally require CLIP image
        # embeddings to be set on the projection layer before sampling — verify.
        positive = face_emb.reshape(1, 1, 1, -1)
        id_embeds = torch.cat([torch.zeros_like(positive), positive]).to(device=device, dtype=dtype)

        logger.info(f"Generating image with prompt: {prompt}")
        image = pipe(
            prompt=prompt,
            ip_adapter_image_embeds=[id_embeds],
            num_inference_steps=10,
            guidance_scale=7.5,
            height=256,
            width=256,
        ).images[0]
        logger.info("Image generated successfully.")
        return "Image generated successfully!", image
    except Exception as e:
        # Report to the UI instead of crashing the request; keep the traceback in the logs.
        logger.exception(f"Generation failed: {e}")
        return f"Generation failed: {e}", None
# Gradio UI: a reference photo and a prompt go in; a status line and the generated image come out.
_input_components = [
    gr.Image(type="pil", label="Upload Reference Image"),
    gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space"),
]
_output_components = [
    gr.Textbox(label="Status"),
    gr.Image(label="Generated Image"),
]
interface = gr.Interface(
    fn=generate_image,
    inputs=_input_components,
    outputs=_output_components,
    title="Face Reference Image Generator",
    description="Upload an image with a face and generate a new image.",
)

# Start the web server (blocks while the app is running).
logger.info("Launching Gradio interface...")
interface.launch()