# Source: Hugging Face Space "bgm", file app.py (author: Munaf1987)
# Last commit: 140729b "Update app.py" (verified)
import functools
import math

import gradio as gr
import torch
import numpy as np
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, pipeline
import spaces
# NOTE(review): the loaders below are cached per device so the multi-GB
# weights are not re-read from disk on every request. On ZeroGPU hardware the
# decorated call may run in a fresh worker, in which case the cache simply
# starts empty (same cost as before) — confirm on the target hardware.
@functools.lru_cache(maxsize=2)
def _load_detector(device):
    """Load the Grounding DINO processor and model once per device."""
    processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
    model = AutoModelForZeroShotObjectDetection.from_pretrained(
        "IDEA-Research/grounding-dino-tiny"
    ).to(device)
    return processor, model


@functools.lru_cache(maxsize=2)
def _load_inpainter(device):
    """Load the Stable Diffusion 2 inpainting pipeline once per device."""
    dtype = torch.float16 if device == "cuda" else torch.float32
    return StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=dtype,
    ).to(device)


def _detect_boxes(image, prompt, device):
    """Return Grounding DINO boxes (x1, y1, x2, y2 in image pixels) for *prompt*."""
    processor, model = _load_detector(device)
    inputs = processor(images=image, text=[[prompt]], return_tensors="pt").to(device)
    outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=0.3,
        text_threshold=0.3,
        target_sizes=[image.size[::-1]],  # PIL size is (w, h); DINO wants (h, w)
    )
    return results[0]["boxes"]


def _boxes_to_mask(boxes, size):
    """Return an "L" mask of *size* that is 255 inside every box and 0 elsewhere."""
    width, height = size
    mask = Image.new("L", size, 0)
    for box in boxes:
        x1, y1, x2, y2 = (float(v) for v in box)
        # Grow outward to whole pixels so the object's edge pixels are covered,
        # and clamp, since detector boxes can spill slightly past the border.
        left = max(0, math.floor(x1))
        top = max(0, math.floor(y1))
        right = min(width, math.ceil(x2))
        bottom = min(height, math.ceil(y2))
        if right > left and bottom > top:
            mask.paste(255, (left, top, right, bottom))
    return mask


def _inpaint(image, mask, device, size=(768, 768)):
    """Fill the white area of *mask* with SD2 and return the result at image.size."""
    pipe = _load_inpainter(device)
    img_resized = image.resize(size, Image.LANCZOS)
    # NEAREST keeps the mask strictly binary; LANCZOS adds ringing values.
    mask_resized = mask.resize(size, Image.NEAREST)
    output = pipe(
        prompt="background",
        image=img_resized,
        mask_image=mask_resized,
        # Explicit, so the pipeline does not fall back to its default size.
        height=size[1],
        width=size[0],
    ).images[0]
    return output.resize(image.size, Image.LANCZOS)


@spaces.GPU
def remove_object_with_text(input_image, prompt):
    """Remove the object(s) described by *prompt* from *input_image*.

    Grounding DINO locates the regions matching the text prompt. The union of
    the detected boxes becomes an inpainting mask, which Stable Diffusion 2
    inpainting fills with generated "background". Only the masked pixels are
    taken from the generated image; everything outside the boxes is copied
    from the original. The output therefore keeps the input's size, and the
    untouched regions keep their original detail.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` component, or
            ``None`` when nothing was uploaded.
        prompt: free-text description of the object to remove.

    Returns:
        ``(image, message)``: the edited image, or the input unchanged when
        nothing was removed, plus a status string for the UI.
    """
    if input_image is None:
        return None, "Please upload an image."
    prompt = (prompt or "").strip()
    if not prompt:
        return input_image, "Please describe the object to remove."

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Both models expect 3-channel input; uploads may be RGBA/L/P.
    image = input_image.convert("RGB")

    # No autograd bookkeeping is needed for pure inference.
    with torch.inference_mode():
        # 1️⃣ Grounding DINO zero-shot detection
        boxes = _detect_boxes(image, prompt, device)
        if len(boxes) == 0:
            return input_image, f"No object found for \"{prompt}\"."
        mask = _boxes_to_mask(boxes, image.size)
        # 2️⃣ Inpainting with Stable Diffusion
        inpainted = _inpaint(image, mask, device)

    # Paste only the inpainted region back onto the original. This avoids the
    # VAE round-trip and 768px resampling degrading the rest of the image.
    final_output = Image.composite(inpainted, image, mask)
    return final_output, "Object removed and image size preserved."
# Gradio UI: an image upload and a text description go in; the edited image
# and a status message come out. The button runs remove_object_with_text.
with gr.Blocks() as demo:
    gr.Markdown("## Object Removal with Text + Original Size Preservation")
    source_image = gr.Image(type="pil")
    object_query = gr.Textbox(
        label="Describe object to remove",
        placeholder="e.g. a cat",
    )
    remove_button = gr.Button("Remove")
    result_image = gr.Image(type="pil")
    status_box = gr.Textbox(interactive=False)
    remove_button.click(
        fn=remove_object_with_text,
        inputs=[source_image, object_query],
        outputs=[result_image, status_box],
    )

demo.launch()