Spaces:
Runtime error
Runtime error
import os

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
# Run on the GPU when one is available; otherwise fall back to CPU so the app
# can still start on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Location of the SAM ViT-H checkpoint. Set the SAM_CHECKPOINT environment
# variable when the file lives somewhere other than the default path.
sam_checkpoint = os.environ.get(
    "SAM_CHECKPOINT", "/home/jupyter/diffusers/examples/sam_vit_h_4b8939.pth"
)
model_type = "vit_h"
if not os.path.isfile(sam_checkpoint):
    raise FileNotFoundError(
        f"SAM checkpoint not found at {sam_checkpoint!r}; "
        "set the SAM_CHECKPOINT environment variable to its location."
    )

# Build the SAM model for this backbone from the registry, loading the weights
# from the checkpoint, then move it to the selected device.
model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
model.to(device)
# Predictor used by the UI to turn clicked points into segmentation masks.
predictor = SamPredictor(model)
# Stable Diffusion 2 inpainting pipeline. Half precision is used only on CUDA;
# float16 inference is poorly supported on CPU, so use float32 there.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if str(device).startswith("cuda") else torch.float32,
)
pipe = pipe.to(device)
with gr.Blocks() as demo:
    # Prompt points clicked on the current input image, as [x, y] pairs.
    # gr.State keeps one copy per browser session, so points from other users
    # or from a previously uploaded image never leak into a new mask.
    points_state = gr.State([])

    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask")
        output_img = gr.Image(label="Output")
    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")
    with gr.Row():
        submit = gr.Button("Submit")

    def reset_points():
        """Forget the stored points and the stale mask after the input changes."""
        return [], None

    def generate_mask(image, points, evt: gr.SelectData):
        """Predict a SAM mask from every point clicked so far on *image*.

        Each click adds a foreground point (label 1). Returns the mask as a
        PIL image together with the updated point list for the session state.
        """
        if image is None:
            raise gr.Error("Upload an input image first.")
        # Gradio reports the clicked pixel as [x, y], the (x, y) order that
        # SAM's point_coords expects. Build a new list rather than mutating state.
        points = points + [evt.index]
        # The predictor is a single shared object, so it is re-primed with
        # this session's image on every click (recomputes the embedding).
        predictor.set_image(image)
        input_points = np.array(points)
        input_labels = np.ones(input_points.shape[0])
        masks, _, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=False,
        )
        # masks is (N, H, W); with multimask_output=False only masks[0] exists.
        return Image.fromarray(masks[0, :, :]), points

    def inpaint(image, mask, prompt):
        """Repaint the masked region of *image* according to *prompt*.

        The image and mask are resized to 512x512 for the model, and the result
        is scaled back to the input's size so non-square inputs are not distorted.
        """
        if image is None:
            raise gr.Error("Upload an input image first.")
        if mask is None:
            raise gr.Error("Click on the input image to create a mask first.")
        image = Image.fromarray(image)
        mask = Image.fromarray(mask)
        original_size = image.size
        image = image.resize((512, 512))
        # NEAREST keeps the mask hard-edged instead of blending border pixels.
        mask = mask.resize((512, 512), Image.NEAREST)
        output = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
        ).images[0]
        return output.resize(original_size)

    # A new (or cleared) input image invalidates previously clicked points.
    input_img.change(reset_points, inputs=None, outputs=[points_state, mask_img])
    input_img.select(
        generate_mask,
        inputs=[input_img, points_state],
        outputs=[mask_img, points_state],
    )
    submit.click(inpaint, inputs=[input_img, mask_img, prompt_text], outputs=[output_img])
if __name__ == "__main__":
    # Enable the request queue. Diffusion calls are long-running, and mask
    # generation mutates one global SAM predictor (set_image, then predict),
    # so events should not interleave on it. The queue's default concurrency
    # limit is 1.
    demo.queue().launch()