| import gradio as gr |
| import torch |
| import numpy as np |
| from PIL import Image |
| from diffusers import StableDiffusionInpaintPipeline |
| from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, pipeline |
| import spaces |
|
|
@spaces.GPU
def remove_object_with_text(input_image, prompt):
    """Remove the object described by ``prompt`` from ``input_image``.

    The object is located with Grounding DINO, every detected box is painted
    into a rectangular mask, and Stable Diffusion inpainting fills the masked
    area. Inpainting runs at 768x768 and the result is resized back to the
    size of the uploaded image.

    Args:
        input_image: PIL image to edit, or ``None`` when nothing was uploaded.
        prompt: Free-text description of the object to remove.

    Returns:
        A ``(image, message)`` tuple. When the image or prompt is missing, or
        no box is detected, the input image is returned unchanged together
        with a message explaining why.
    """
    # Guard against clicking "Remove" before an image or description is given;
    # without this the first attribute access below fails on None.
    if input_image is None:
        return None, "Please upload an image first."
    if not prompt or not prompt.strip():
        return input_image, "Please describe the object to remove."
    prompt = prompt.strip()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    original_size = input_image.size

    # NOTE(review): both models are expected to take 3-channel input, so the
    # mode is normalised here in case a caller passes RGBA / L / P images.
    rgb_image = input_image.convert("RGB")

    # --- 1. Text-guided detection -------------------------------------------
    dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
    dino = AutoModelForZeroShotObjectDetection.from_pretrained(
        "IDEA-Research/grounding-dino-tiny"
    ).to(device)

    inputs = dino_processor(
        images=rgb_image, text=[[prompt]], return_tensors="pt"
    ).to(device)
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        outputs = dino(**inputs)

    results = dino_processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=0.3,
        text_threshold=0.3,
        # PIL reports (width, height); the post-processor is given (height, width).
        target_sizes=[rgb_image.size[::-1]],
    )
    boxes = results[0]["boxes"]
    if len(boxes) == 0:
        return input_image, f"No object found for \"{prompt}\"."

    # --- 2. Build the mask: white (255) rectangles over every detection -----
    mask = Image.new("L", rgb_image.size, 0)
    for box in boxes:
        x1, y1, x2, y2 = [int(v) for v in box]
        mask.paste(255, (x1, y1, x2, y2))

    # --- 3. Inpaint ---------------------------------------------------------
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    target_res = (768, 768)
    img_resized = rgb_image.resize(target_res, Image.LANCZOS)
    # The mask is a hard 0/255 selection: nearest-neighbour keeps its edges
    # crisp, whereas LANCZOS would introduce grey/ringing values around them.
    mask_resized = mask.resize(target_res, Image.NEAREST)

    output = pipe(
        prompt="background", image=img_resized, mask_image=mask_resized
    ).images[0]

    # --- 4. Restore the uploaded image's dimensions --------------------------
    final_output = output.resize(original_size, Image.LANCZOS)

    return final_output, "Object removed and image size preserved."
|
|
| |
# Gradio UI: an image and a text description go in, the edited image and a
# status message come out.
with gr.Blocks() as demo:
    gr.Markdown("## Object Removal with Text + Original Size Preservation")

    # Inputs
    inp = gr.Image(type="pil")
    txt = gr.Textbox(label="Describe object to remove", placeholder="e.g. a cat")
    btn = gr.Button("Remove")

    # Outputs
    out = gr.Image(type="pil")
    msg = gr.Textbox(interactive=False)

    # Event listeners are registered while the Blocks context is still open.
    btn.click(
        fn=remove_object_with_text,
        inputs=[inp, txt],
        outputs=[out, msg],
    )

demo.launch()
|
|