Files changed (1) hide show
  1. app.py +115 -108
app.py CHANGED
@@ -1,117 +1,124 @@
1
  import gradio as gr
2
- import cv2
3
- import numpy as np
4
- from PIL import Image
5
  import torch
 
 
 
 
6
 
7
# -------------------------------------------------------------
# Optional learned-dewarping backend: try to load a Hugging Face
# model; fall back to the OpenCV-only pipeline when unavailable.
# -------------------------------------------------------------
try:
    from transformers import AutoModel, AutoImageProcessor

    # NOTE(review): this repo id looks like a Space name, not a model
    # repo — confirm it actually hosts weights loadable via AutoModel.
    MODEL_REPO = "richard1231/Document_dewarping_platform"
    processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
    model = AutoModel.from_pretrained(MODEL_REPO)
    model.eval()  # inference mode: disables dropout / BN updates
    USE_HF_MODEL = True
except Exception as e:
    # Report *why* loading failed instead of silently discarding `e`,
    # then continue with the OpenCV-only fallback.
    print(f"⚠️ Hugging Face model not found, using OpenCV-only version. ({e})")
    USE_HF_MODEL = False
20
# -------------------------------------------------------------
# 🔹 Perspective correction (OpenCV fallback)
# -------------------------------------------------------------
def flatten_perspective(input_image: Image.Image) -> Image.Image:
    """Flatten a photographed document via a 4-point perspective warp.

    Detects the largest contour that approximates to a quadrilateral
    and warps it to a front-facing rectangle. Returns the input image
    unchanged whenever a reliable quadrilateral cannot be found.
    """
    img = np.array(input_image.convert("RGB"))
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blur, 50, 150)

    contours, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return input_image
    contour = max(contours, key=cv2.contourArea)

    # Only a clean 4-vertex approximation is usable as a document outline.
    peri = cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, 0.02 * peri, True)
    if len(approx) != 4:
        return input_image

    # Order corners TL, TR, BR, BL: TL minimizes x+y, BR maximizes x+y,
    # TR minimizes y-x, BL maximizes y-x.
    pts = np.float32(approx.reshape(4, 2))
    s = pts.sum(axis=1)
    rect = np.zeros((4, 2), dtype="float32")
    rect[0] = pts[np.argmin(s)]
    rect[2] = pts[np.argmax(s)]
    diff = np.diff(pts, axis=1)
    rect[1] = pts[np.argmin(diff)]
    rect[3] = pts[np.argmax(diff)]

    (tl, tr, br, bl) = rect
    widthA = np.linalg.norm(br - bl)
    widthB = np.linalg.norm(tr - tl)
    heightA = np.linalg.norm(tr - br)
    heightB = np.linalg.norm(tl - bl)
    maxWidth, maxHeight = int(max(widthA, widthB)), int(max(heightA, heightB))

    # Degenerate quad guard: warpPerspective rejects a zero-sized output.
    if maxWidth < 1 or maxHeight < 1:
        return input_image

    dst = np.array([[0, 0],
                    [maxWidth - 1, 0],
                    [maxWidth - 1, maxHeight - 1],
                    [0, maxHeight - 1]], dtype="float32")

    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(img, M, (maxWidth, maxHeight))
    return Image.fromarray(warped)
63
-
64
-
65
# -------------------------------------------------------------
# 🔹 Learned de-warping (Hugging Face model)
# -------------------------------------------------------------
@torch.no_grad()
def flatten_learned(input_image: Image.Image) -> Image.Image:
    """Dewarp via the Hugging Face model, falling back to OpenCV.

    NOTE(review): the code assumes ``last_hidden_state[0]`` is a
    (C, H, W) image tensor, but for generic AutoModel backbones it is
    (seq_len, hidden). Any mismatch or backend failure now degrades
    gracefully to the perspective-based fallback instead of crashing.
    """
    if not USE_HF_MODEL:
        return flatten_perspective(input_image)

    try:
        inputs = processor(images=input_image, return_tensors="pt")
        outputs = model(**inputs)
        out_img = outputs.last_hidden_state[0]
        if out_img.ndim != 3:
            # Not an image-shaped tensor: learned path is unusable here.
            return flatten_perspective(input_image)
        # Min-max normalize to 0-1 (guard the constant-tensor case),
        # then convert to uint8 HWC for PIL.
        rng = out_img.max() - out_img.min()
        out_img = (out_img - out_img.min()) / (rng if rng > 0 else 1.0)
        out_img = (out_img * 255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(out_img)
    except Exception:
        # Best-effort: any model failure falls back to the OpenCV path.
        return flatten_perspective(input_image)
80
-
81
-
82
# -------------------------------------------------------------
# 🔹 Gradio UI
# -------------------------------------------------------------
# Markdown description shown under the page title.
description = """
## 🧾 Auto Image Flattening (Perspective + Learned Dewarping)
Upload a **tilted or curved document/fabric photo**.
- Default: OpenCV 4-point perspective flattening
- Optional: if the **Hugging Face DewarpNet/DocRes model** is available, uses that instead
"""

# UI layout: input/output images side by side, a mode selector,
# and a button that triggers the flattening.
with gr.Blocks() as demo:
    gr.Markdown("# 📄 Auto Image Flattening (OpenCV / Hugging Face)")
    gr.Markdown(description)

    with gr.Row():
        inp = gr.Image(type="pil", label="Upload Image")
        out = gr.Image(type="pil", label="Flattened Output")

    mode = gr.Radio(["Auto (Use HF if available)", "OpenCV Only"], value="Auto (Use HF if available)", label="Mode")

    # Dispatch: honor the user's choice, but force the OpenCV path
    # when the HF model failed to load at startup.
    def process(img, mode):
        if mode == "OpenCV Only" or not USE_HF_MODEL:
            return flatten_perspective(img)
        return flatten_learned(img)

    btn = gr.Button("Flatten Image")
    btn.click(process, inputs=[inp, mode], outputs=out)

    # NOTE(review): these example files must ship alongside app.py —
    # confirm example1.jpg / example2.jpg exist in the repo.
    gr.Examples(
        examples=["example1.jpg", "example2.jpg"],
        inputs=inp,
        examples_per_page=2,
    )

if __name__ == "__main__":
    demo.launch()
 
1
  import gradio as gr
2
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
3
+ from PIL import Image, ImageDraw
 
4
  import torch
5
+ import numpy as np
6
+ from sklearn.cluster import KMeans
7
+ from transformers import AutoImageProcessor, AutoModel
8
+ import cv2
9
 
10
# -----------------------------------------------------
# 1️⃣ Load SAM + DINOv2
# -----------------------------------------------------
# All models share one device; prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): checkpoint=None builds SAM with *random* weights — no
# pretrained checkpoint is loaded, so generated masks will be
# meaningless until a real .pth path is supplied. Confirm intended.
sam = sam_model_registry["vit_b"](checkpoint=None).to(device)
mask_generator = SamAutomaticMaskGenerator(sam)
# DINOv2 backbone is used only to embed region crops for clustering.
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
dinov2 = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
19
# -----------------------------------------------------
# 2️⃣ Utility Functions
# -----------------------------------------------------
def get_embeddings(img):
    """DINOv2 feature embedding for region similarity."""
    batch = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        hidden = dinov2(**batch).last_hidden_state
    tokens = hidden[0].cpu().numpy()
    # Mean over the token axis -> one fixed-size descriptor per region.
    return tokens.mean(axis=0)
29
+
30
def remove_background(image):
    """Simple background removal using SAM largest mask."""
    proposals = mask_generator.generate(image)
    if not proposals:
        return image
    biggest = max(proposals, key=lambda m: m['area'])
    keep = biggest['segmentation']
    # Everything outside the dominant object becomes pure white (in place).
    image[~keep] = 255
    return image
38
+
39
def get_centroid(mask):
    """Return the (x, y) centroid of a boolean mask, or (0, 0) if empty."""
    points = np.argwhere(mask)  # rows of (row, col) index pairs
    if points.size == 0:
        return (0, 0)
    cy, cx = points.mean(axis=0)
    # Truncate to ints so the result is usable as pixel coordinates.
    return int(cx), int(cy)
45
+
46
# -----------------------------------------------------
# 3️⃣ Segmentation Core
# -----------------------------------------------------
def segment_saree(image):
    """Segment a saree photo into Body / Border / Pallu layers.

    Returns (overlay, body_layer, border_layer, pallu_layer) as PIL
    images, or (None, None, None, None) when SAM finds no masks.

    NOTE(review): the cluster→name mapping is arbitrary — KMeans labels
    carry no semantics, so "Body"/"Border"/"Pallu" may be permuted.
    """
    image = np.array(image.convert("RGB"))
    image = remove_background(image)  # whiten everything but the saree
    masks = mask_generator.generate(image)
    if not masks:
        return None, None, None, None

    # One (mask, embedding, area) triple per SAM proposal.
    regions = []
    for m in masks:
        area = m['area']
        mask = m['segmentation']
        region_img = Image.fromarray(np.uint8(image) * mask[..., None])
        emb = get_embeddings(region_img)
        regions.append((mask, emb, area))

    # Cluster regions (up to 3 = body/border/pallu). KMeans requires
    # n_samples >= n_clusters, so clamp when SAM returns few regions —
    # the original hard-coded 3 crashed on 1–2 regions.
    feats = np.array([r[1] for r in regions])
    n_clusters = min(3, len(regions))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(feats)
    labels = kmeans.labels_

    label_names = ["Body", "Border", "Pallu"]
    colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0)]
    seg_color = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)

    # One opaque RGB layer per cluster; converted to RGBA further down.
    layers = [np.zeros_like(image, dtype=np.uint8) for _ in range(3)]
    for i, (mask, _, _) in enumerate(regions):
        seg_color[mask] = colors[labels[i]]
        layers[labels[i]][mask] = image[mask]

    # Overlay the cluster name at each region's centroid.
    seg_img = Image.fromarray(seg_color)
    draw = ImageDraw.Draw(seg_img)
    for i, (mask, _, _) in enumerate(regions):
        x, y = get_centroid(mask)
        draw.text((x, y), label_names[labels[i]], fill=(255, 255, 255))

    # Build transparent RGBA layers. The layers are RGB, so convert with
    # RGB2RGBA — the previous BGR2RGBA swapped red/blue in the output.
    layer_imgs = []
    for layer in layers:
        rgba = cv2.cvtColor(layer, cv2.COLOR_RGB2RGBA)
        # Fully transparent wherever the layer holds no pixels.
        rgba[..., 3] = np.where(np.all(layer == 0, axis=-1), 0, 255)
        layer_imgs.append(Image.fromarray(rgba))

    return seg_img, layer_imgs[0], layer_imgs[1], layer_imgs[2]
95
+
96
# -----------------------------------------------------
# 4️⃣ Gradio Interface
# -----------------------------------------------------
# Markdown description rendered under the app title.
description = """
### 🧶 Saree AI SAM + DINOv2 Smart Segmentation
Upload a **flat or draped saree image**.
The app will:
- ✂️ Remove background
- 🎨 Segment into **Body**, **Border**, **Pallu**
- 🪞 Give you individual transparent PNGs

Ideal for recoloring, catalog creation, or draping models.
"""

# Single-function UI: one input image, four output images
# (labelled overlay + one transparent PNG layer per cluster).
demo = gr.Interface(
    fn=segment_saree,
    inputs=gr.Image(type="pil", label="Upload Saree Image"),
    outputs=[
        gr.Image(type="pil", label="Overlay Mask with Labels"),
        gr.Image(type="pil", label="Body (Transparent)"),
        gr.Image(type="pil", label="Border (Transparent)"),
        gr.Image(type="pil", label="Pallu (Transparent)"),
    ],
    title="🧵 Saree AI — Intelligent Segmentation & Layer Extraction",
    description=description,
)

if __name__ == "__main__":
    demo.launch()