Inam65 commited on
Commit
350a741
·
verified ·
1 Parent(s): 7cfa978

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# 1. Load the Model and Processor
# Pick the GPU when one is visible; otherwise fall back to CPU so the
# Space still runs (slowly) on free hardware.
device = "cuda" if torch.cuda.is_available() else "cpu"
# SAM (Segment Anything, ViT-B backbone). Loaded once at import time so
# every request reuses the same weights; first start downloads ~375 MB.
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
# Companion processor: resizes/normalizes images and rescales prompt
# boxes into the model's input resolution.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
11
+
12
def segment_object(image_data):
    """Cut the user-highlighted object out of an image onto a white background.

    Parameters
    ----------
    image_data : dict | None
        Payload from a Gradio ``ImageEditor`` (``type="pil"``): expected keys
        are ``"background"`` (the uploaded PIL image), ``"layers"`` (list of
        RGBA PIL drawing layers) and ``"composite"``.

    Returns
    -------
    PIL.Image.Image | None
        The object composited onto white; the unmodified upload when no
        region was drawn; ``None`` when no image was provided.
    """
    # Guard on the key we actually read: the original code tested for
    # "composite" but then dereferenced "background", which could crash.
    if image_data is None or image_data.get("background") is None:
        return None

    raw_image = image_data["background"].convert("RGB")

    # The user's box/brush strokes arrive as RGBA drawing layers.
    layers = image_data.get("layers", [])
    if not layers:
        return raw_image  # nothing drawn -> return original unchanged

    # Embed the image once; the prompt-conditioned decoder reuses this below.
    inputs = processor(raw_image, return_tensors="pt").to(device)
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

    # Use the alpha channel of the first drawing layer as the marked region.
    mask = np.array(layers[0].split()[-1])

    # Bounding box of every painted pixel.
    coords = np.argwhere(mask > 0)
    if coords.size == 0:
        return raw_image  # layer exists but is fully transparent

    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0)
    # Cast numpy int64 scalars to plain floats; one image, one box ->
    # the processor expects exactly three levels of nesting.
    input_boxes = [[[float(x0), float(y0), float(x1), float(y1)]]]

    # 2. Predict the mask. Do NOT wrap input_boxes in another list here:
    # it is already [[[x0, y0, x1, y1]]] (the original double-wrapped it,
    # producing an invalid 4-level nesting).
    inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)
    # Swap the raw pixels for the cached embeddings to skip re-encoding.
    inputs.pop("pixel_values", None)
    inputs["image_embeddings"] = image_embeddings

    with torch.no_grad():
        outputs = model(**inputs)

    # 3. Upscale the low-res logits back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # SAM returns 3 candidate masks per box; pick the one the model itself
    # scores highest instead of blindly taking index 0.
    best_idx = int(outputs.iou_scores.cpu()[0, 0].argmax())
    best_mask = masks[0][0][best_idx].numpy()

    # 4. Composite onto a pure-white background.
    raw_np = np.array(raw_image)
    white_bg = np.ones_like(raw_np) * 255

    # Broadcast the (H, W) boolean mask over the RGB channels.
    final_img = np.where(best_mask[..., None], raw_np, white_bg)

    return Image.fromarray(final_img.astype("uint8"))
75
+
76
# 3. Create the Gradio Interface
# Blocks layout: editor + result image side by side, a button underneath.
with gr.Blocks() as demo:
    gr.Markdown("# 🖌️ Object Extractor to White Background")
    gr.Markdown("1. Upload an image. 2. Use the **Box** or **Brush** tool to highlight the object. 3. Click Submit.")

    with gr.Row():
        # ImageEditor with type="pil" hands segment_object a dict of PIL
        # images ("background", "layers", "composite"); layers=True enables
        # the drawing layers the handler reads the box/brush marks from.
        input_img = gr.ImageEditor(
            label="Input Image",
            type="pil",
            layers=True,
            canvas_size=(512, 512)
        )
        output_img = gr.Image(label="Extracted Object", type="pil")

    submit_btn = gr.Button("Extract Object")
    # Wire the button to the segmentation handler defined above.
    submit_btn.click(segment_object, inputs=[input_img], outputs=[output_img])

# Start the app server (blocking call; Spaces invokes this on startup).
demo.launch()