"""Gradio web UI for IP-Adapter-FaceID-Plus face-conditioned image generation.

Extracts a face embedding from an input photo with InsightFace, loads a
Stable Diffusion checkpoint wrapped in IPAdapterFaceIDPlus, and generates
images whose subject matches the input face.  Optionally exposes the app
through an ngrok tunnel.
"""

import argparse
import os
import random
import threading
import time

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline
from insightface.app import FaceAnalysis
from insightface.utils import face_align
from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus
from pyngrok import ngrok

# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true", help="Enable Gradio share option")
parser.add_argument("--num_images", type=int, default=1, help="Number of images to generate")
parser.add_argument("--cache_limit", type=int, default=1, help="Limit for model cache")
parser.add_argument("--ngrok_token", type=str, default=None, help="ngrok authtoken for tunneling")
args = parser.parse_args()

# Add new model names here
static_model_names = [
    "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    "stablediffusionapi/rev-animated-v122-eol",
    "Lykon/DreamShaper",
    "stablediffusionapi/toonyou",
    "stablediffusionapi/real-cartoon-3d",
    "KBlueLeaf/kohaku-v2.1",
    "nitrosocke/Ghibli-Diffusion",
    "Linaqruf/anything-v3.0",
    "jinaai/flat-2d-animerge",
    "stablediffusionapi/realcartoon3d",
    "stablediffusionapi/disney-pixar-cartoon",
    "stablediffusionapi/pastel-mix-stylized-anime",
    "stablediffusionapi/anything-v5",
    "SG161222/Realistic_Vision_V2.0",
    "SG161222/Realistic_Vision_V4.0_noVAE",
    "SG161222/Realistic_Vision_V5.1_noVAE",
    r"C:\Users\King\Downloads\New folder\3D Animation Diffusion",
]

# Cache of loaded IPAdapterFaceIDPlus models, keyed by model name.
# Evicted FIFO once it reaches max_cache_size.
model_cache = {}
max_cache_size = args.cache_limit

# Shared InsightFace detector, created lazily on first request.  Building it
# per-request (as the original code did) reloads the detection models on
# every button click, which is wasteful.
_face_app = None


def _get_face_app():
    """Return the shared FaceAnalysis instance, creating it on first use."""
    global _face_app
    if _face_app is None:
        _face_app = FaceAnalysis(
            name="buffalo_l",
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        _face_app.prepare(ctx_id=0, det_size=(640, 640))
    return _face_app


def load_model(model_name):
    """Load the IP-Adapter pipeline for ``model_name``, using the cache.

    Parameters
    ----------
    model_name : str
        HuggingFace repo id or local path of a SD 1.5 checkpoint.

    Returns
    -------
    IPAdapterFaceIDPlus
        The adapter wrapping the loaded pipeline, ready to generate.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    # FIFO eviction: drop the oldest cached model once the cache is full.
    if len(model_cache) >= max_cache_size:
        model_cache.pop(next(iter(model_cache)))

    device = "cuda"  # fp16 weights below require a CUDA device
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    # Replace the checkpoint's VAE with the ft-MSE VAE (several listed models
    # ship without a VAE, per their "noVAE" names).
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)

    # Load the base pipeline for the selected model name.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    ).to(device)

    image_encoder_path = "h94/IP-Adapter/models/image_encoder"
    ip_ckpt = "adapters/ip-adapter-faceid-plusv2_sd15.bin"
    ip_model = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_ckpt, device)

    model_cache[model_name] = ip_model
    return ip_model


def generate_image(input_image, positive_prompt, negative_prompt, width, height,
                   model_name, num_inference_steps, seed, randomize_seed,
                   num_images, batch_size, enable_shortcut, s_scale):
    """Generate face-conditioned images and save them under ``outputs/``.

    Parameters
    ----------
    input_image : PIL.Image.Image
        Photo containing the face to condition on.
    positive_prompt, negative_prompt : str
        Text prompts forwarded to the pipeline.
    width, height : number
        Output resolution (gr.Number delivers floats; cast to int here).
    model_name : str
        Checkpoint to load via :func:`load_model`.
    num_inference_steps, seed, num_images, batch_size : number
        Sampling controls (cast to int here).
    randomize_seed : bool
        When True, a fresh random seed is drawn for every image.
    enable_shortcut : bool
        Forwarded as ``shortcut`` to the adapter.
    s_scale : float
        Face-structure scale factor forwarded to the adapter.

    Returns
    -------
    tuple[list[str], str, int]
        Saved image paths, a status string, and the last seed used.

    Raises
    ------
    ValueError
        If no face is detected in ``input_image``.
    """
    # gr.Number components deliver floats; the pipeline needs integers.
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)
    num_images = int(num_images)
    batch_size = int(batch_size)

    saved_images = []

    # Load and prepare the model.
    ip_model = load_model(model_name)

    # InsightFace expects a BGR numpy array.
    input_image = input_image.convert("RGB")
    input_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    faces = _get_face_app().get(input_image)
    if not faces:
        raise ValueError("No faces found in the image.")

    # Use the first detected face for both the identity embedding and the
    # aligned 224x224 crop that FaceID-Plus consumes.
    faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
    face_image = face_align.norm_crop(input_image, landmark=faces[0].kps, image_size=224)

    outputs_dir = "outputs"
    os.makedirs(outputs_dir, exist_ok=True)  # exist_ok avoids a check/create race

    for image_index in range(num_images):
        # Re-roll the seed when requested — and always after the first image,
        # so repeated generations in one run differ from each other.
        if randomize_seed or image_index > 0:
            seed = random.randint(0, 2**32 - 1)

        generated_images = ip_model.generate(
            prompt=positive_prompt,
            negative_prompt=negative_prompt,
            faceid_embeds=faceid_embeds,
            face_image=face_image,
            num_samples=batch_size,
            shortcut=enable_shortcut,
            s_scale=s_scale,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

        # Save the generated images; the filename index is derived from the
        # current directory listing so successive runs do not overwrite.
        for i, img in enumerate(generated_images, start=1):
            image_path = os.path.join(
                outputs_dir, f"generated_{len(os.listdir(outputs_dir)) + i}.png"
            )
            img.save(image_path)
            saved_images.append(image_path)

    return saved_images, f"Saved images: {', '.join(saved_images)}", seed


# ---------------------------------------------------------------------------
# Gradio interface, using the static list of models
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("Developed by SECourses - only distributed on https://www.patreon.com/posts/95759342")
    with gr.Row():
        input_image = gr.Image(type="pil")
        generate_btn = gr.Button("Generate")
    with gr.Row():
        width = gr.Number(value=512, label="Width")
        height = gr.Number(value=768, label="Height")
    with gr.Row():
        num_inference_steps = gr.Number(value=30, label="Number of Inference Steps", step=1, minimum=10, maximum=100)
        seed = gr.Number(value=2023, label="Seed")
        randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
    with gr.Row():
        num_images = gr.Number(value=args.num_images, label="Number of Images to Generate", step=1, minimum=1)
        batch_size = gr.Number(value=1, label="Batch Size", step=1)
    with gr.Row():
        enable_shortcut = gr.Checkbox(value=True, label="Enable Shortcut")
        s_scale = gr.Number(value=1.0, label="Scale Factor (s_scale)", step=0.1, minimum=0.5, maximum=4.0)
    with gr.Row():
        positive_prompt = gr.Textbox(label="Positive Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
    with gr.Row():
        model_selector = gr.Dropdown(label="Select Model", choices=static_model_names, value=static_model_names[0])
    with gr.Column():
        output_gallery = gr.Gallery(label="Generated Images")
        output_text = gr.Textbox(label="Output Info")
        display_seed = gr.Textbox(label="Used Seed", interactive=False)
    generate_btn.click(
        generate_image,
        inputs=[input_image, positive_prompt, negative_prompt, width, height,
                model_selector, num_inference_steps, seed, randomize_seed,
                num_images, batch_size, enable_shortcut, s_scale],
        outputs=[output_gallery, output_text, display_seed],
    )


def start_ngrok():
    """Open an ngrok tunnel to the local Gradio port after a short delay."""
    time.sleep(10)  # Delay so Gradio binds its port before the tunnel opens
    ngrok.set_auth_token(args.ngrok_token)
    # pyngrok's connect() takes the local address as its first argument;
    # the original `port=7860` keyword is not a valid parameter.
    public_url = ngrok.connect(7860)  # Adjust to your Gradio app's port
    print(f"ngrok tunnel started at {public_url}")


if __name__ == "__main__":
    if args.ngrok_token:
        # Start ngrok in a daemon thread with a delay.
        ngrok_thread = threading.Thread(target=start_ngrok, daemon=True)
        ngrok_thread.start()
    # Launch the Gradio app.
    demo.launch(share=args.share, inbrowser=True)