nishanth-saka committed on
Commit
c209164
·
verified ·
1 Parent(s): 01c9c63
Files changed (1) hide show
  1. app.py +75 -147
app.py CHANGED
@@ -1,148 +1,76 @@
1
  import gradio as gr
2
- from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
3
- from transformers import AutoImageProcessor, AutoModel
4
- from huggingface_hub import snapshot_download
5
- from PIL import Image, ImageDraw
6
- import torch, numpy as np, cv2, zipfile, io, os, tempfile
7
- from sklearn.cluster import KMeans
8
-
9
- # -----------------------------------------------------
10
- # 1️⃣ Model Initialization
11
- # -----------------------------------------------------
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
-
14
- # --- Download SAM checkpoint if missing ---
15
- if not os.path.exists("sam_vit_b_01ec64.pth"):
16
- os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth")
17
-
18
- # --- Load SAM ---
19
- sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").to(device)
20
- mask_generator = SamAutomaticMaskGenerator(sam)
21
-
22
- # --- Preload DINOv2 ---
23
- snapshot_download("facebook/dinov2-base")
24
- processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
25
- dinov2 = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
26
-
27
- # -----------------------------------------------------
28
- # 2️⃣ Utility Functions
29
- # -----------------------------------------------------
30
- def get_embeddings(img):
31
- """Extract DINOv2 feature embeddings."""
32
- inputs = processor(images=img, return_tensors="pt").to(device)
33
- with torch.no_grad():
34
- outputs = dinov2(**inputs)
35
- feat = outputs.last_hidden_state[0].cpu().numpy()
36
- return feat.mean(axis=0)
37
-
38
- def remove_background(image):
39
- """Use largest SAM mask to isolate saree from background."""
40
- masks = mask_generator.generate(image)
41
- if not masks:
42
- return image
43
- main_mask = max(masks, key=lambda x: x['area'])['segmentation']
44
- image[~main_mask] = 255 # white out background
45
- return image
46
-
47
- def get_centroid(mask):
48
- coords = np.column_stack(np.where(mask))
49
- if len(coords) == 0:
50
- return (0, 0)
51
- y, x = coords.mean(axis=0)
52
- return int(x), int(y)
53
-
54
- def make_transparent(img, mask):
55
- rgba = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
56
- rgba[..., 3] = np.where(mask, 255, 0).astype(np.uint8)
57
- return rgba
58
-
59
- # -----------------------------------------------------
60
- # 3️⃣ Main Segmentation Function
61
- # -----------------------------------------------------
62
- def segment_saree(image):
63
- try:
64
- image = np.array(image.convert("RGB"))
65
- image = remove_background(image)
66
- masks = mask_generator.generate(image)
67
- if not masks:
68
- raise ValueError("No masks generated")
69
-
70
- regions = []
71
- for m in masks:
72
- mask = m["segmentation"]
73
- region_img = Image.fromarray(np.uint8(image) * mask[..., None])
74
- emb = get_embeddings(region_img)
75
- regions.append((mask, emb))
76
-
77
- if len(regions) < 3:
78
- raise ValueError("Insufficient distinct regions")
79
-
80
- features = np.array([r[1] for r in regions])
81
- kmeans = KMeans(n_clusters=3, random_state=42).fit(features)
82
- labels = kmeans.labels_
83
-
84
- colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0)]
85
- names = ["Body", "Border", "Pallu"]
86
- seg_out = np.zeros_like(image)
87
- layers = [np.zeros_like(image, dtype=np.uint8) for _ in range(3)]
88
-
89
- for i, (mask, _) in enumerate(regions):
90
- seg_out[mask] = colors[labels[i]]
91
- layers[labels[i]][mask] = image[mask]
92
-
93
- seg_img = Image.fromarray(seg_out)
94
- draw = ImageDraw.Draw(seg_img)
95
- for (mask, _), lbl in zip(regions, labels):
96
- x, y = get_centroid(mask)
97
- draw.text((x, y), names[lbl], fill=(255, 255, 255))
98
-
99
- # Transparent layers
100
- transparent_imgs = [Image.fromarray(make_transparent(l, l.any(axis=2))) for l in layers]
101
-
102
- # Write ZIP to a temp file (Gradio expects a real path)
103
- tmpdir = tempfile.mkdtemp()
104
- zip_path = os.path.join(tmpdir, "saree_layers.zip")
105
- with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
106
- for n, t in zip(names, transparent_imgs):
107
- tmp_img = os.path.join(tmpdir, f"{n}.png")
108
- t.save(tmp_img)
109
- zf.write(tmp_img, arcname=f"{n}.png")
110
-
111
- return seg_img, transparent_imgs[0], transparent_imgs[1], transparent_imgs[2], zip_path
112
-
113
- except Exception as e:
114
- print("Error:", e)
115
- blank = Image.new("RGB", (512, 512), color=(30, 30, 30))
116
- return blank, blank, blank, blank, None
117
-
118
- # -----------------------------------------------------
119
- # 4️⃣ Gradio UI
120
- # -----------------------------------------------------
121
- description = """
122
- ### 🧶 Saree AI — Intelligent Segmentation & Layer Export
123
- Upload a **flat or draped saree image**, and this tool will:
124
- - ✂️ Remove background
125
- - 🧠 Segment into **Body**, **Border**, **Pallu** using SAM + DINOv2
126
- - 🪞 Provide transparent PNGs
127
- - 📦 Download all masks as a single ZIP
128
-
129
- Built for saree recoloring, catalog automation, and AI draping pipelines.
130
- """
131
-
132
- demo = gr.Interface(
133
- fn=segment_saree,
134
- inputs=gr.Image(type="pil", label="Upload Saree Image"),
135
- outputs=[
136
- gr.Image(type="pil", label="Overlay Mask with Labels"),
137
- gr.Image(type="pil", label="Body (Transparent)"),
138
- gr.Image(type="pil", label="Border (Transparent)"),
139
- gr.Image(type="pil", label="Pallu (Transparent)"),
140
- gr.File(label="📦 Download All (ZIP)"),
141
- ],
142
- title="🧵 Saree AI — SAM + DINOv2 Smart Segmentation",
143
- description=description,
144
- flagging_mode="never",
145
- )
146
-
147
- if __name__ == "__main__":
148
- demo.launch()
 
1
  import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
def flatten_image(img, points):
    """Warp a clicked quadrilateral region of *img* into a flat rectangle.

    Args:
        img: PIL.Image to correct, or None.
        points: list of four (x, y) pairs in order [TL, TR, BR, BL].

    Returns:
        (PIL.Image, None) on success, or (None, error_message) when the
        image is missing, the point count is wrong, or the quad is degenerate.
    """
    if img is None or not points or len(points) != 4:
        return None, "Please click exactly 4 points (TL, TR, BR, BL)."

    image_np = np.array(img)

    # Source quad in the order the UI collects clicks: TL, TR, BR, BL.
    src_pts = np.array(points, dtype=np.float32)

    # Size the output rectangle from the longer of each opposing edge pair
    # so neither side of the quad is shrunk by the warp.
    width_top = np.linalg.norm(src_pts[0] - src_pts[1])
    width_bottom = np.linalg.norm(src_pts[3] - src_pts[2])
    height_left = np.linalg.norm(src_pts[0] - src_pts[3])
    height_right = np.linalg.norm(src_pts[1] - src_pts[2])

    max_width = int(max(width_top, width_bottom))
    max_height = int(max(height_left, height_right))

    # Guard: repeated or collinear clicks produce a 0-sized target, which
    # would make getPerspectiveTransform/warpPerspective fail downstream.
    if max_width < 1 or max_height < 1:
        return None, "Selected points are degenerate; please click 4 distinct corners."

    dst_pts = np.array([
        [0, 0],
        [max_width - 1, 0],
        [max_width - 1, max_height - 1],
        [0, max_height - 1],
    ], dtype=np.float32)

    # Homography mapping the clicked quad onto the axis-aligned rectangle.
    M = cv2.getPerspectiveTransform(src_pts, dst_pts)

    # INTER_CUBIC keeps text/edges reasonably sharp after resampling.
    warped = cv2.warpPerspective(image_np, M, (max_width, max_height), flags=cv2.INTER_CUBIC)

    return Image.fromarray(warped), None
45
+
46
+
47
with gr.Blocks() as demo:
    gr.Markdown("## 📸 Perspective Flatten Tool\nUpload an image, click 4 corners (Top-Left → Top-Right → Bottom-Right → Bottom-Left), then flatten!")

    with gr.Row():
        # NOTE(review): the Gradio 3.x `tool="select"` kwarg was removed in
        # Gradio 4+ and raises TypeError at construction time on modern
        # Gradio. Click coordinates arrive via the `.select` event below,
        # which needs no tool kwarg.
        input_image = gr.Image(label="Upload Image", type="pil")
        output_image = gr.Image(label="Flattened Output")

    # Per-session accumulator of clicked (x, y) corner points.
    coords = gr.State([])

    def collect_points(evt: gr.SelectData, points):
        """Append the clicked pixel to the point list, keeping the last four."""
        # Copy instead of mutating the state list in place; the new list is
        # returned and written back to `coords` by Gradio.
        points = list(points or [])
        points.append(evt.index)  # evt.index is the clicked (x, y) pixel
        if len(points) > 4:
            points = points[-4:]  # keep only the most recent 4 clicks
        return points, f"Selected {len(points)}/4 points: {points}"

    points_output = gr.Textbox(label="Selected Points", interactive=False)

    input_image.select(fn=collect_points, inputs=coords, outputs=[coords, points_output])

    # The on-screen tip promises that re-uploading resets point selection;
    # without this handler stale points from the previous image would be
    # silently reused.
    input_image.upload(fn=lambda: ([], ""), inputs=None, outputs=[coords, points_output])

    flatten_btn = gr.Button("🔄 Flatten Image")

    error_box = gr.Textbox(label="Messages", interactive=False)

    flatten_btn.click(fn=flatten_image, inputs=[input_image, coords], outputs=[output_image, error_box])

    gr.Markdown("Tip: Re-upload image to reset point selection.")

demo.launch()