nishanth-saka committed on
Commit
f65f5e2
·
verified ·
1 Parent(s): cb790cd
Files changed (1) hide show
  1. app.py +112 -102
app.py CHANGED
@@ -1,107 +1,117 @@
1
- # ==============================================================
2
- # πŸ‘— Saree AI β€” Content-Aware Image Fitting (Smart Bounding Box)
3
- # Hugging Face Space β€” Stable Version (no runtime errors)
4
- # ==============================================================
5
-
6
- import cv2, numpy as np, gradio as gr
7
-
8
# --------------------------------------------------------------
# Core Function
# --------------------------------------------------------------
def content_aware_fit(image, target_size=(512, 512)):
    """Content-aware fit of a saree image that preserves ornate regions.

    Pipeline: Laplacian texture saliency -> binary importance mask ->
    saliency-weighted centroid -> padded bounding-box crop ->
    aspect-ratio-safe resize letterboxed onto a white canvas.

    Args:
        image: PIL image (RGB) to process.
        target_size: (height, width) of the final fitted output.

    Returns:
        Tuple of five RGB numpy arrays:
        (saliency map, importance mask, annotated original, crop, fitted).
    """
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # ----------------------------------------------------------
    # Step 1: Texture-based "saliency" via absolute Laplacian response
    # ----------------------------------------------------------
    gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
    lap = np.absolute(cv2.Laplacian(gray, cv2.CV_64F))
    # BUGFIX: a perfectly flat image has np.max(lap) == 0, which made the
    # original normalisation divide by zero. Fall back to an all-zero map.
    peak = np.max(lap)
    if peak > 0:
        saliency_map = np.uint8(255 * (lap / peak))
    else:
        saliency_map = np.zeros_like(gray, dtype=np.uint8)
    saliency_map = cv2.GaussianBlur(saliency_map, (7, 7), 0)
    saliency_map = cv2.normalize(saliency_map, None, 0, 255, cv2.NORM_MINMAX)
    _, importance_mask = cv2.threshold(saliency_map, 128, 255, cv2.THRESH_BINARY)

    # ----------------------------------------------------------
    # Step 2: Center of mass weighted by saliency (image centre fallback)
    # ----------------------------------------------------------
    M = cv2.moments(importance_mask)
    if M["m00"] != 0:
        cx = int(M["m10"] / M["m00"])
        cy = int(M["m01"] / M["m00"])
    else:
        cx, cy = img_rgb.shape[1] // 2, img_rgb.shape[0] // 2

    # ----------------------------------------------------------
    # Step 3: Smart bounding box + 10% padding
    # ----------------------------------------------------------
    # BUGFIX: the original passed np.column_stack(np.where(mask > 0))
    # (row, col) points straight into cv2.boundingRect, which interprets
    # points as (x, y) — swapping x/y and w/h. Compute the box directly.
    ys, xs = np.where(importance_mask > 0)
    if xs.size > 0:
        x, y = int(xs.min()), int(ys.min())
        w, h = int(xs.max()) - x + 1, int(ys.max()) - y + 1
    else:
        x, y, w, h = 0, 0, img_rgb.shape[1], img_rgb.shape[0]

    pad_x = int(0.1 * w)
    pad_y = int(0.1 * h)
    x1, y1 = max(0, x - pad_x), max(0, y - pad_y)
    x2 = min(img_rgb.shape[1], x + w + pad_x)
    y2 = min(img_rgb.shape[0], y + h + pad_y)
    cropped = img_rgb[y1:y2, x1:x2]

    # ----------------------------------------------------------
    # Step 4: Aspect-ratio-safe fit with white letterbox padding
    # ----------------------------------------------------------
    h, w, _ = cropped.shape
    target_h, target_w = target_size
    scale = min(target_w / w, target_h / h)
    # Guard against a zero-sized resize target on extreme aspect ratios.
    new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
    resized = cv2.resize(cropped, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

    pad_x = (target_w - new_w) // 2
    pad_y = (target_h - new_h) // 2
    fitted = cv2.copyMakeBorder(
        resized, pad_y, target_h - new_h - pad_y,
        pad_x, target_w - new_w - pad_x,
        cv2.BORDER_CONSTANT, value=[255, 255, 255],
    )

    # ----------------------------------------------------------
    # Step 5: Overlay centroid + bounding box (for visualization)
    # ----------------------------------------------------------
    marked = img_rgb.copy()
    cv2.circle(marked, (cx, cy), 8, (255, 0, 0), -1)
    cv2.rectangle(marked, (x1, y1), (x2, y2), (0, 255, 0), 3)

    return (
        cv2.cvtColor(saliency_map, cv2.COLOR_GRAY2RGB),
        cv2.cvtColor(importance_mask, cv2.COLOR_GRAY2RGB),
        marked,
        cropped,
        fitted,
    )
83
 
84
# --------------------------------------------------------------
# Gradio Interface
# --------------------------------------------------------------
# NOTE: title glyphs repaired from mis-encoded (cp1253-mojibake) UTF-8.
demo = gr.Interface(
    fn=content_aware_fit,
    inputs=gr.Image(type="pil", label="Upload Saree Image"),
    outputs=[
        gr.Image(label="Texture / Saliency Map"),
        gr.Image(label="Importance Mask"),
        gr.Image(label="Center + Bounding Box"),
        gr.Image(label="Cropped View"),
        gr.Image(label="Final Content-Aware Fitted Output"),
    ],
    title="👗 Saree AI — Content-Aware Image Fitting",
    description=(
        "Automatically detects ornate or high-detail regions (borders, pallus) "
        "and fits the saree image with smart padding to preserve its design. "
        "Ideal preprocessing step before saree draping or catalog generation."
    ),
    allow_flagging="never",
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
if __name__ == "__main__":
    # Launch the Gradio server when executed as a script.
    demo.launch()
 
1
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import torch

# Optional: use a Hugging Face dewarping model when both the weights and
# the `transformers` library are available; otherwise fall back to the
# OpenCV-only perspective correction below.
try:
    from transformers import AutoModel, AutoImageProcessor

    MODEL_REPO = "richard1231/Document_dewarping_platform"
    processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
    model = AutoModel.from_pretrained(MODEL_REPO)
    model.eval()  # inference mode: freezes dropout / batch-norm behavior
    USE_HF_MODEL = True
except Exception as e:
    # Deliberate best-effort fallback — but surface the failure reason
    # instead of discarding the captured exception (the original bound
    # `e` and never used it).
    print(f"⚠️ Hugging Face model not available ({e}); using OpenCV-only version.")
    USE_HF_MODEL = False
+
19
+
20
# -------------------------------------------------------------
# 🔹 Perspective correction (OpenCV fallback)
# -------------------------------------------------------------
def flatten_perspective(input_image: Image.Image) -> Image.Image:
    """Flatten a photographed page/fabric with a 4-point perspective warp.

    Finds the largest edge contour, approximates it to a quadrilateral,
    orders the corners (tl, tr, br, bl), and warps to an axis-aligned
    rectangle. Returns the input unchanged whenever a usable quad is
    not found.
    """
    img = np.array(input_image.convert("RGB"))
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blur, 50, 150)

    contours, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return input_image
    contour = max(contours, key=cv2.contourArea)

    peri = cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, 0.02 * peri, True)
    if len(approx) != 4:
        return input_image

    # Order corners: smallest/largest x+y give tl/br; smallest/largest
    # y-x give tr/bl (np.diff over a (4, 2) row yields y - x).
    pts = np.float32(approx.reshape(4, 2))
    s = pts.sum(axis=1)
    rect = np.zeros((4, 2), dtype="float32")
    rect[0] = pts[np.argmin(s)]
    rect[2] = pts[np.argmax(s)]
    diff = np.diff(pts, axis=1)
    rect[1] = pts[np.argmin(diff)]
    rect[3] = pts[np.argmax(diff)]

    (tl, tr, br, bl) = rect
    widthA = np.linalg.norm(br - bl)
    widthB = np.linalg.norm(tr - tl)
    heightA = np.linalg.norm(tr - br)
    heightB = np.linalg.norm(tl - bl)
    maxWidth, maxHeight = int(max(widthA, widthB)), int(max(heightA, heightB))
    # BUGFIX: a degenerate (collapsed) quad would pass a zero-sized output
    # to cv2.warpPerspective and raise — bail out to the original image.
    if maxWidth < 2 or maxHeight < 2:
        return input_image

    dst = np.array([[0, 0],
                    [maxWidth - 1, 0],
                    [maxWidth - 1, maxHeight - 1],
                    [0, maxHeight - 1]], dtype="float32")

    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(img, M, (maxWidth, maxHeight))
    return Image.fromarray(warped)
 
 
 
 
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
# -------------------------------------------------------------
# 🔹 Learned de-warping (Hugging Face model)
# -------------------------------------------------------------
@torch.no_grad()
def flatten_learned(input_image: Image.Image) -> Image.Image:
    """De-warp *input_image* using the optional Hugging Face model.

    Falls back to the OpenCV perspective flattening when the model was
    not loaded at import time.
    """
    if not USE_HF_MODEL:
        return flatten_perspective(input_image)

    inputs = processor(images=input_image, return_tensors="pt")
    outputs = model(**inputs)
    # Post-process — many HF models return tensors in 0-1 range.
    # NOTE(review): assumes outputs.last_hidden_state[0] is a (C, H, W)
    # image tensor (permute(1, 2, 0) below relies on it) — verify against
    # the model card.
    out_img = outputs.last_hidden_state[0]
    # BUGFIX: min-max normalisation divided by zero for a constant model
    # output; clamp the denominator to a small epsilon.
    lo, hi = out_img.min(), out_img.max()
    out_img = (out_img - lo) / (hi - lo).clamp_min(1e-8)
    out_img = (out_img * 255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(out_img)
80
+
81
+
82
# -------------------------------------------------------------
# 🔹 Gradio UI
# -------------------------------------------------------------
# NOTE: emoji below were repaired from mis-encoded (mojibake) glyphs.
description = """
## 🧾 Auto Image Flattening (Perspective + Learned Dewarping)
Upload a **tilted or curved document/fabric photo**.
- Default: OpenCV 4-point perspective flattening
- Optional: if the **Hugging Face DewarpNet/DocRes model** is available, uses that instead
"""

with gr.Blocks() as demo:
    gr.Markdown("# 📄 Auto Image Flattening (OpenCV / Hugging Face)")
    gr.Markdown(description)

    with gr.Row():
        inp = gr.Image(type="pil", label="Upload Image")
        out = gr.Image(type="pil", label="Flattened Output")

    mode = gr.Radio(
        ["Auto (Use HF if available)", "OpenCV Only"],
        value="Auto (Use HF if available)",
        label="Mode",
    )

    def process(img, mode):
        """Dispatch to the learned or OpenCV flattener based on UI mode."""
        if mode == "OpenCV Only" or not USE_HF_MODEL:
            return flatten_perspective(img)
        return flatten_learned(img)

    btn = gr.Button("Flatten Image")
    btn.click(process, inputs=[inp, mode], outputs=out)

    # NOTE(review): these example files must exist in the Space repo,
    # otherwise Gradio raises at startup — confirm they are committed.
    gr.Examples(
        examples=["example1.jpg", "example2.jpg"],
        inputs=inp,
        examples_per_page=2,
    )
115
 
116
if __name__ == "__main__":
    # Start the Gradio app only when run directly, not on import.
    demo.launch()