Files changed (1) hide show
  1. app.py +86 -63
app.py CHANGED
@@ -1,18 +1,26 @@
1
  import gradio as gr
2
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
 
 
3
  from PIL import Image, ImageDraw
4
- import torch
5
- import numpy as np
6
  from sklearn.cluster import KMeans
7
- from transformers import AutoImageProcessor, AutoModel
8
- import cv2
9
 
10
  # -----------------------------------------------------
11
- # 1️⃣ Load SAM + DINOv2
12
  # -----------------------------------------------------
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
- sam = sam_model_registry["vit_b"](checkpoint=None).to(device)
 
 
 
 
 
 
15
  mask_generator = SamAutomaticMaskGenerator(sam)
 
 
 
16
  processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
17
  dinov2 = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
18
 
@@ -20,7 +28,7 @@ dinov2 = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
20
  # 2️⃣ Utility Functions
21
  # -----------------------------------------------------
22
  def get_embeddings(img):
23
- """DINOv2 feature embedding for region similarity."""
24
  inputs = processor(images=img, return_tensors="pt").to(device)
25
  with torch.no_grad():
26
  outputs = dinov2(**inputs)
@@ -28,12 +36,12 @@ def get_embeddings(img):
28
  return feat.mean(axis=0)
29
 
30
  def remove_background(image):
31
- """Simple background removal using SAM largest mask."""
32
  masks = mask_generator.generate(image)
33
  if not masks:
34
  return image
35
  main_mask = max(masks, key=lambda x: x['area'])['segmentation']
36
- image[~main_mask] = 255 # white background
37
  return image
38
 
39
  def get_centroid(mask):
@@ -43,68 +51,81 @@ def get_centroid(mask):
43
  y, x = coords.mean(axis=0)
44
  return int(x), int(y)
45
 
 
 
 
 
 
46
  # -----------------------------------------------------
47
- # 3️⃣ Segmentation Core
48
  # -----------------------------------------------------
49
  def segment_saree(image):
50
- image = np.array(image.convert("RGB"))
51
- image = remove_background(image) # background cleanup
52
- masks = mask_generator.generate(image)
53
- if not masks:
54
- return None, None, None, None
55
-
56
- regions = []
57
- for m in masks:
58
- area = m['area']
59
- mask = m['segmentation']
60
- region_img = Image.fromarray(np.uint8(image) * mask[..., None])
61
- emb = get_embeddings(region_img)
62
- regions.append((mask, emb, area))
63
-
64
- # Cluster regions (3 = body/border/pallu)
65
- feats = np.array([r[1] for r in regions])
66
- kmeans = KMeans(n_clusters=3, random_state=42).fit(feats)
67
- labels = kmeans.labels_
68
-
69
- label_names = ["Body", "Border", "Pallu"]
70
- colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0)]
71
- seg_color = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
72
-
73
- # prepare transparent layers
74
- layers = [np.zeros_like(image, dtype=np.uint8) for _ in range(3)]
75
- for i, (mask, _, _) in enumerate(regions):
76
- seg_color[mask] = colors[labels[i]]
77
- layers[labels[i]][mask] = image[mask]
78
-
79
- # overlay label text + legend
80
- seg_img = Image.fromarray(seg_color)
81
- draw = ImageDraw.Draw(seg_img)
82
- for i, (mask, _, _) in enumerate(regions):
83
- x, y = get_centroid(mask)
84
- draw.text((x, y), label_names[labels[i]], fill=(255, 255, 255))
85
-
86
- # create transparent PILs
87
- layer_imgs = [Image.fromarray(cv2.cvtColor(l, cv2.COLOR_BGR2RGBA)) for l in layers]
88
- for l in layer_imgs:
89
- alpha = np.where(np.all(np.array(l)[..., :3] == 0, axis=-1), 0, 255)
90
- arr = np.array(l)
91
- arr[..., 3] = alpha
92
- l.paste(Image.fromarray(arr))
93
-
94
- return seg_img, layer_imgs[0], layer_imgs[1], layer_imgs[2]
 
 
 
 
 
 
 
 
95
 
96
  # -----------------------------------------------------
97
- # 4️⃣ Gradio Interface
98
  # -----------------------------------------------------
99
  description = """
100
- ### 🧶 Saree AI — SAM + DINOv2 Smart Segmentation
101
- Upload a **flat or draped saree image**.
102
- The app will:
103
  - ✂️ Remove background
104
- - 🎨 Segment into **Body**, **Border**, **Pallu**
105
- - 🪞 Give you individual transparent PNGs
 
106
 
107
- Ideal for recoloring, catalog creation, or draping models.
108
  """
109
 
110
  demo = gr.Interface(
@@ -115,9 +136,11 @@ demo = gr.Interface(
115
  gr.Image(type="pil", label="Body (Transparent)"),
116
  gr.Image(type="pil", label="Border (Transparent)"),
117
  gr.Image(type="pil", label="Pallu (Transparent)"),
 
118
  ],
119
- title="🧵 Saree AI — Intelligent Segmentation & Layer Extraction",
120
  description=description,
 
121
  )
122
 
123
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
3
+ from transformers import AutoImageProcessor, AutoModel
4
+ from huggingface_hub import snapshot_download
5
  from PIL import Image, ImageDraw
6
+ import torch, numpy as np, cv2, zipfile, io, os
 
7
  from sklearn.cluster import KMeans
 
 
8
 
9
# -----------------------------------------------------
# 1️⃣ Model Initialization
# -----------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Download SAM checkpoint if missing ---
# FIX: use urllib instead of os.system("wget ..."): wget is absent on many
# hosts (Windows, slim containers) and os.system silently ignores failures,
# which would make the sam_model_registry load below fail confusingly.
SAM_CHECKPOINT = "sam_vit_b_01ec64.pth"
if not os.path.exists(SAM_CHECKPOINT):
    from urllib.request import urlretrieve  # local import: one-time setup path
    urlretrieve(
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        SAM_CHECKPOINT,
    )

# --- Load SAM (ViT-B variant) and its automatic mask generator ---
sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT).to(device)
mask_generator = SamAutomaticMaskGenerator(sam)

# --- Preload DINOv2 weights into the local HF cache, then load them ---
snapshot_download("facebook/dinov2-base")
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
dinov2 = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
26
 
 
28
  # 2️⃣ Utility Functions
29
  # -----------------------------------------------------
30
  def get_embeddings(img):
31
+ """Extract DINOv2 feature embeddings."""
32
  inputs = processor(images=img, return_tensors="pt").to(device)
33
  with torch.no_grad():
34
  outputs = dinov2(**inputs)
 
36
  return feat.mean(axis=0)
37
 
38
def remove_background(image):
    """Blank out everything outside the largest SAM mask.

    The biggest auto-generated segment is assumed to be the saree itself;
    every pixel outside it is painted white IN PLACE.

    Args:
        image: writable H×W×3 uint8 numpy array (RGB).

    Returns:
        The same array, mutated; returned unchanged if SAM finds no masks.
    """
    candidates = mask_generator.generate(image)
    if not candidates:
        return image
    largest = max(candidates, key=lambda m: m['area'])
    image[~largest['segmentation']] = 255  # white out background
    return image
46
 
47
  def get_centroid(mask):
 
51
  y, x = coords.mean(axis=0)
52
  return int(x), int(y)
53
 
54
def make_transparent(img, mask):
    """Attach an alpha channel to ``img`` driven by ``mask``.

    Pixels where ``mask`` is truthy get alpha 255 (opaque), all others 0.

    FIX: the original called ``cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)`` on
    what is actually RGB data. That conversion only appends an opaque alpha
    channel (channel order is untouched), so a plain numpy stack is exactly
    equivalent, drops the misleading BGR naming, and avoids the cv2 call.

    Args:
        img:  H×W×3 uint8 numpy array.
        mask: H×W boolean (or truthy) array.

    Returns:
        H×W×4 uint8 numpy array (original channels + alpha).
    """
    alpha = np.where(mask, 255, 0).astype(np.uint8)
    return np.dstack((img, alpha))
58
+
59
  # -----------------------------------------------------
60
+ # 3️⃣ Main Segmentation Function
61
  # -----------------------------------------------------
62
def segment_saree(image):
    """Segment a saree photo into Body / Border / Pallu layers.

    Pipeline: background removal (largest SAM mask) → SAM region proposals →
    DINOv2 embedding per region → KMeans(3) clustering → colour overlay,
    three transparent PNG layers, and a ZIP of all layers.

    Args:
        image: PIL image from the Gradio input.

    Returns:
        (overlay, body_rgba, border_rgba, pallu_rgba, zip_path); on any
        failure, four dark 512×512 placeholders and ``None`` for the ZIP.
    """
    try:
        image = np.array(image.convert("RGB"))
        image = remove_background(image)
        masks = mask_generator.generate(image)
        if not masks:
            raise ValueError("No masks generated")

        # One (mask, embedding) pair per SAM region.
        regions = []
        for m in masks:
            mask = m["segmentation"]
            region_img = Image.fromarray(np.uint8(image) * mask[..., None])
            emb = get_embeddings(region_img)
            regions.append((mask, emb))

        # KMeans below needs at least n_clusters samples.
        if len(regions) < 3:
            raise ValueError("Insufficient distinct regions")

        features = np.array([r[1] for r in regions])
        labels = KMeans(n_clusters=3, random_state=42).fit(features).labels_

        colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0)]
        names = ["Body", "Border", "Pallu"]
        seg_out = np.zeros_like(image)
        layers = [np.zeros_like(image, dtype=np.uint8) for _ in range(3)]

        for (mask, _), lbl in zip(regions, labels):
            seg_out[mask] = colors[lbl]
            layers[lbl][mask] = image[mask]

        # Label each region at its centroid on the colour overlay.
        seg_img = Image.fromarray(seg_out)
        draw = ImageDraw.Draw(seg_img)
        for (mask, _), lbl in zip(regions, labels):
            x, y = get_centroid(mask)
            draw.text((x, y), names[lbl], fill=(255, 255, 255))

        # Transparent layers: alpha wherever the layer has any non-black pixel.
        transparent_imgs = [
            Image.fromarray(make_transparent(l, l.any(axis=2))) for l in layers
        ]

        # FIX: gr.File expects a file path, not an io.BytesIO — write the
        # archive to a temp file and hand Gradio its path. Fresh archive, so
        # open with "w" rather than append mode.
        import tempfile  # local import: only needed on the success path
        zip_path = os.path.join(tempfile.mkdtemp(), "saree_layers.zip")
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for n, t in zip(names, transparent_imgs):
                bio = io.BytesIO()
                t.save(bio, format="PNG")
                zf.writestr(f"{n}.png", bio.getvalue())

        return (
            seg_img,
            transparent_imgs[0],
            transparent_imgs[1],
            transparent_imgs[2],
            zip_path,
        )
    except Exception as e:
        # UI boundary: never crash Gradio — show dark placeholders instead.
        print("Error:", e)
        blank = Image.new("RGB", (512, 512), color=(30, 30, 30))
        return blank, blank, blank, blank, None
116
 
117
  # -----------------------------------------------------
118
+ # 4️⃣ Gradio UI
119
  # -----------------------------------------------------
120
  description = """
121
+ ### 🧵 Saree AI — Intelligent Segmentation & Layer Export
122
+ Upload a **flat or draped saree image**, and this tool will:
 
123
  - ✂️ Remove background
124
+ - 🧠 Segment into **Body**, **Border**, **Pallu** using SAM + DINOv2
125
+ - 🪞 Provide transparent PNGs
126
+ - 📦 Download all masks as a single ZIP
127
 
128
+ Built for saree recoloring, catalog automation, and AI draping pipelines.
129
  """
130
 
131
  demo = gr.Interface(
 
136
  gr.Image(type="pil", label="Body (Transparent)"),
137
  gr.Image(type="pil", label="Border (Transparent)"),
138
  gr.Image(type="pil", label="Pallu (Transparent)"),
139
+ gr.File(label="📦 Download All (ZIP)"),
140
  ],
141
+ title="🧶 Saree AI — SAM + DINOv2 Smart Segmentation",
142
  description=description,
143
+ allow_flagging="never",
144
  )
145
 
146
  if __name__ == "__main__":