import functools

import gradio as gr
import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image, make_image_grid
from PIL import Image
from transformers import pipeline

# Run on the GPU when one is available. Half precision is only used on CUDA:
# many CPU kernels have no float16 implementation.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# The conditioning image built below is a depth map, so the ControlNet must be
# a depth ControlNet for the guidance to be meaningful.
CONTROLNET_ID = "lllyasviel/control_v11f1p_sd15_depth"
# NOTE(review): the runwayml repo may no longer be hosted on the Hub; a mirror is
# published as "stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm and switch if needed.
BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"


# Function to get depth map
def get_depth_map(image, depth_estimator):
    """Estimate depth for *image* and return it as a 3-channel float tensor.

    Args:
        image: PIL image handed to the depth estimator.
        depth_estimator: a transformers ``depth-estimation`` pipeline; the
            ``"depth"`` entry of its output is converted with ``np.array``.

    Returns:
        A float32 tensor in channels-first layout whose three channels are
        copies of the depth channel, divided by 255. Assumes the depth image
        is single-channel 8-bit, i.e. shape (3, H, W) in [0, 1] — TODO confirm
        for non-default depth models.
    """
    depth = np.array(depth_estimator(image)["depth"])
    # Replicate the single depth channel into three (H, W, 3).
    depth = np.concatenate([depth[:, :, None]] * 3, axis=2)
    return (torch.from_numpy(depth).float() / 255.0).permute(2, 0, 1)


@functools.lru_cache(maxsize=1)
def _load_depth_estimator():
    """Create the depth-estimation pipeline once and reuse it for every request."""
    return pipeline("depth-estimation")


@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Build the depth-ControlNet img2img pipeline once and reuse it for every request."""
    controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, torch_dtype=DTYPE, use_safetensors=True)
    pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        BASE_MODEL_ID, controlnet=controlnet, torch_dtype=DTYPE, use_safetensors=True
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    if DEVICE == "cuda":
        # Model CPU offload keeps sub-models on the CPU and moves each to the
        # GPU only while it runs; it needs an accelerator, so CUDA only.
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(DEVICE)
    return pipe


# Main function to process the image and prompt
def process_image_and_prompt(input_image, prompt):
    """Generate a depth-guided variation of *input_image* following *prompt*.

    Args:
        input_image: PIL image from the Gradio upload widget (None if empty).
        prompt: text prompt for Stable Diffusion.

    Returns:
        Tuple ``(rgb_input_image, generated_image)`` for the two output panels.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    # Convert PIL Image to the format expected by the pipeline
    input_image = input_image.convert("RGB")

    depth_map = get_depth_map(input_image, _load_depth_estimator())
    # Add a leading batch dimension and match the pipeline's device/dtype.
    control_image = depth_map.unsqueeze(0).to(device=DEVICE, dtype=DTYPE)

    output = _load_pipeline()(
        prompt,
        image=input_image,
        control_image=control_image,
    ).images[0]

    # With the default output type the pipeline already yields PIL images;
    # only convert when an array was returned.
    output_image = output if isinstance(output, Image.Image) else Image.fromarray(output)
    return input_image, output_image


# Create the Gradio interface
iface = gr.Interface(
    fn=process_image_and_prompt,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Prompt")],
    outputs=[gr.Image(label="Original Image"), gr.Image(label="Generated Image")],
    title="Image and Prompt Processing with Stable Diffusion",
    description="Upload an image and enter a prompt to generate a new image.",
)

# Launch the Gradio app only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()