Update app.py
Browse files
app.py
CHANGED
|
@@ -4,15 +4,14 @@ import numpy as np
|
|
| 4 |
from PIL import Image
|
| 5 |
from diffusers import StableDiffusionInpaintPipeline
|
| 6 |
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, pipeline
|
| 7 |
-
from transformers import SamProcessor, SamModel, pipeline as sam_pipeline
|
| 8 |
import spaces
|
| 9 |
|
| 10 |
@spaces.GPU
|
| 11 |
def remove_object_with_text(input_image, prompt):
|
| 12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 13 |
|
| 14 |
# 1️⃣ Grounding DINO zero-shot detection
|
| 15 |
-
dino_id = "pengxian/grounding-dino"
|
| 16 |
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
|
| 17 |
dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)
|
| 18 |
|
|
@@ -32,35 +31,27 @@ def remove_object_with_text(input_image, prompt):
|
|
| 32 |
x1, y1, x2, y2 = [int(v) for v in box]
|
| 33 |
mask.paste(255, (x1, y1, x2, y2))
|
| 34 |
|
| 35 |
-
# 2️⃣ Refine the box mask with SAM (Segment Anything)
|
| 36 |
-
sam_pipe = sam_pipeline("mask-generation", model="facebook/sam-vit-huge", device=0 if device=="cuda" else -1)
|
| 37 |
-
sam_out = sam_pipe(image=input_image, points_per_batch=256)
|
| 38 |
-
# Combine SAM masks that overlap the boxes
|
| 39 |
-
final_mask = Image.new("L", input_image.size, 0)
|
| 40 |
-
for m in sam_out["masks"]:
|
| 41 |
-
arr = np.array(m)
|
| 42 |
-
# apply only in box regions
|
| 43 |
-
for box in boxes:
|
| 44 |
-
x1,y1,x2,y2 = [int(v) for v in box]
|
| 45 |
-
sub = arr[y1:y2, x1:x2]
|
| 46 |
-
if sub.sum() > 1000:
|
| 47 |
-
final_mask.paste(Image.fromarray((arr*255).astype("uint8")), (0,0), Image.fromarray((arr*255).astype("uint8")))
|
| 48 |
-
|
| 49 |
-
# 3️⃣ Inpainting with Stable Diffusion
|
| 50 |
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
| 51 |
"stabilityai/stable-diffusion-2-inpainting",
|
| 52 |
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
| 53 |
).to(device)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
|
| 58 |
output = pipe(prompt="background", image=img_resized, mask_image=mask_resized).images[0]
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# Gradio UI
|
| 62 |
with gr.Blocks() as demo:
|
| 63 |
-
gr.Markdown("##
|
| 64 |
inp = gr.Image(type="pil")
|
| 65 |
txt = gr.Textbox(label="Describe object to remove", placeholder="e.g. a cat")
|
| 66 |
btn = gr.Button("Remove")
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
from diffusers import StableDiffusionInpaintPipeline
|
| 6 |
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, pipeline
|
|
|
|
| 7 |
import spaces
|
| 8 |
|
| 9 |
@spaces.GPU
|
| 10 |
def remove_object_with_text(input_image, prompt):
|
| 11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
+
original_size = input_image.size
|
| 13 |
|
| 14 |
# 1️⃣ Grounding DINO zero-shot detection
|
|
|
|
| 15 |
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
|
| 16 |
dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)
|
| 17 |
|
|
|
|
| 31 |
x1, y1, x2, y2 = [int(v) for v in box]
|
| 32 |
mask.paste(255, (x1, y1, x2, y2))
|
| 33 |
|
| 34 |
+
# 2️⃣ Inpainting with Stable Diffusion (high-res)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
| 36 |
"stabilityai/stable-diffusion-2-inpainting",
|
| 37 |
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
| 38 |
).to(device)
|
| 39 |
|
| 40 |
+
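# Prepare for inpainting: resize image and mask to one fixed square resolution (512x512 or 768x768); non-square inputs are stretched here and un-stretched by the final resize back to the original size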
|
| 41 |
+
target_res = (768, 768) # Higher resolution = better quality
|
| 42 |
+
img_resized = input_image.resize(target_res, Image.LANCZOS)
|
| 43 |
+
mask_resized = mask.resize(target_res, Image.LANCZOS)
|
| 44 |
|
| 45 |
output = pipe(prompt="background", image=img_resized, mask_image=mask_resized).images[0]
|
| 46 |
+
|
| 47 |
+
# Resize back to original image size
|
| 48 |
+
final_output = output.resize(original_size, Image.LANCZOS)
|
| 49 |
+
|
| 50 |
+
return final_output, "Object removed and image size preserved."
|
| 51 |
|
| 52 |
# Gradio UI
|
| 53 |
with gr.Blocks() as demo:
|
| 54 |
+
gr.Markdown("## Object Removal with Text + Original Size Preservation")
|
| 55 |
inp = gr.Image(type="pil")
|
| 56 |
txt = gr.Textbox(label="Describe object to remove", placeholder="e.g. a cat")
|
| 57 |
btn = gr.Button("Remove")
|