Munaf1987 committed on
Commit
140729b
·
verified ·
1 Parent(s): fb794c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -21
app.py CHANGED
@@ -4,15 +4,14 @@ import numpy as np
4
  from PIL import Image
5
  from diffusers import StableDiffusionInpaintPipeline
6
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, pipeline
7
- from transformers import SamProcessor, SamModel, pipeline as sam_pipeline
8
  import spaces
9
 
10
  @spaces.GPU
11
  def remove_object_with_text(input_image, prompt):
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
13
 
14
  # 1️⃣ Grounding DINO zero-shot detection
15
- dino_id = "pengxian/grounding-dino"
16
  dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
17
  dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)
18
 
@@ -32,35 +31,27 @@ def remove_object_with_text(input_image, prompt):
32
  x1, y1, x2, y2 = [int(v) for v in box]
33
  mask.paste(255, (x1, y1, x2, y2))
34
 
35
- # 2️⃣ SAM automatic mask refinement
36
- sam_pipe = sam_pipeline("mask-generation", model="facebook/sam-vit-huge", device=0 if device=="cuda" else -1)
37
- sam_out = sam_pipe(image=input_image, points_per_batch=256)
38
- # Combine SAM masks that overlap the boxes
39
- final_mask = Image.new("L", input_image.size, 0)
40
- for m in sam_out["masks"]:
41
- arr = np.array(m)
42
- # apply only in box regions
43
- for box in boxes:
44
- x1,y1,x2,y2 = [int(v) for v in box]
45
- sub = arr[y1:y2, x1:x2]
46
- if sub.sum() > 1000:
47
- final_mask.paste(Image.fromarray((arr*255).astype("uint8")), (0,0), Image.fromarray((arr*255).astype("uint8")))
48
-
49
- # 3️⃣ Inpainting with Stable Diffusion
50
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
51
  "stabilityai/stable-diffusion-2-inpainting",
52
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
53
  ).to(device)
54
 
55
- img_resized = input_image.resize((512,512))
56
- mask_resized = final_mask.resize((512,512))
 
 
57
 
58
  output = pipe(prompt="background", image=img_resized, mask_image=mask_resized).images[0]
59
- return output, "Object removed."
 
 
 
 
60
 
61
  # Gradio UI
62
  with gr.Blocks() as demo:
63
- gr.Markdown("## Text-Based Object Removal + Inpainting")
64
  inp = gr.Image(type="pil")
65
  txt = gr.Textbox(label="Describe object to remove", placeholder="e.g. a cat")
66
  btn = gr.Button("Remove")
 
4
  from PIL import Image
5
  from diffusers import StableDiffusionInpaintPipeline
6
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, pipeline
 
7
  import spaces
8
 
9
  @spaces.GPU
10
  def remove_object_with_text(input_image, prompt):
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ original_size = input_image.size
13
 
14
  # 1️⃣ Grounding DINO zero-shot detection
 
15
  dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
16
  dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)
17
 
 
31
  x1, y1, x2, y2 = [int(v) for v in box]
32
  mask.paste(255, (x1, y1, x2, y2))
33
 
34
+ # 2️⃣ Inpainting with Stable Diffusion (high-res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
36
  "stabilityai/stable-diffusion-2-inpainting",
37
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
38
  ).to(device)
39
 
40
+ # Prepare for inpainting: resize to model's input size (512x512 or 768x768)
41
+ target_res = (768, 768) # Higher resolution = better quality
42
+ img_resized = input_image.resize(target_res, Image.LANCZOS)
43
+ mask_resized = mask.resize(target_res, Image.LANCZOS)
44
 
45
  output = pipe(prompt="background", image=img_resized, mask_image=mask_resized).images[0]
46
+
47
+ # Resize back to original image size
48
+ final_output = output.resize(original_size, Image.LANCZOS)
49
+
50
+ return final_output, "Object removed and image size preserved."
51
 
52
  # Gradio UI
53
  with gr.Blocks() as demo:
54
+ gr.Markdown("## Object Removal with Text + Original Size Preservation")
55
  inp = gr.Image(type="pil")
56
  txt = gr.Textbox(label="Describe object to remove", placeholder="e.g. a cat")
57
  btn = gr.Button("Remove")