nishanth-saka committed on
Commit
eebbfa3
·
verified ·
1 Parent(s): a7405e5
Files changed (1) hide show
  1. app.py +82 -53
app.py CHANGED
@@ -1,66 +1,95 @@
1
- import gradio as gr
2
- import cv2, numpy as np
 
 
 
3
  from PIL import Image
 
 
 
 
4
 
5
# -------------------------------------------------
# 🌈 Color-Preserving Fold Remover (LAB domain)
# -------------------------------------------------
def remove_folds_color_safe(image, intensity=0.5, enhance=False):
    """Flatten fold shadows in a fabric photo without shifting its colors.

    Works entirely on the L (lightness) channel of LAB space so the A/B
    chroma channels — i.e. the fabric's actual colors — are untouched.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Input photo; ``None`` is passed through (Gradio sends None when
        no image is uploaded).
    intensity : float
        0.0–1.0; scales the illumination-blur kernel, so higher values
        flatten broader/stronger folds.
    enhance : bool
        If True, apply an unsharp-mask pass to crisp up the weave texture.

    Returns
    -------
    PIL.Image.Image or None
        The corrected RGB image, or None when no input was given.
    """
    if image is None:
        return None

    # Convert to numpy / LAB (OpenCV's 8-bit LAB keeps L in 0–255).
    img_rgb = np.array(image.convert("RGB"))
    img_lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB)
    L, A, B = cv2.split(img_lab)

    # --- (1) Smooth illumination field ---
    # Large Gaussian blur estimates the slowly-varying lighting/fold
    # shading; kernel grows with `intensity` (61..201 px).
    ksize = int(61 + intensity * 140)  # adaptive blur
    # GaussianBlur requires an odd kernel size.
    ksize = ksize + 1 if ksize % 2 == 0 else ksize
    L_blur = cv2.GaussianBlur(L, (ksize, ksize), 0)

    # --- (2) Lighting flattening (division in luminance only) ---
    # Dividing by the illumination estimate removes shading; rescaling by
    # the mean keeps overall brightness. 1e-6 guards against /0.
    L_float = L.astype(np.float32)
    L_norm = (L_float / (L_blur + 1e-6)) * np.mean(L_blur)
    L_norm = np.clip(L_norm, 0, 255).astype(np.uint8)

    # --- (3) Local contrast restore (CLAHE) ---
    # Flattening washes out local contrast; CLAHE brings it back per tile.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    L_final = clahe.apply(L_norm)

    # --- (4) Optional weave-enhancement ---
    # Classic unsharp mask: subtract a blurred copy (weights 1.3 / -0.3).
    if enhance:
        sharp = cv2.GaussianBlur(L_final, (0, 0), 3)
        L_final = cv2.addWeighted(L_final, 1.3, sharp, -0.3, 0)

    # --- (5) Merge color back and convert to RGB ---
    # Original A/B chroma channels are reused unchanged — color-safe.
    img_lab_eq = cv2.merge([L_final, A, B])
    result = cv2.cvtColor(img_lab_eq, cv2.COLOR_LAB2RGB)

    return Image.fromarray(result)
 
 
 
 
41
 
42
# -------------------------------------------------
# πŸŽ›οΈ Gradio UI
# -------------------------------------------------
# Blocks layout: image + controls on one row, output + button below.
with gr.Blocks(title="πŸͺ„ Saree Fold Remover β€” Color-Safe LAB Edition") as demo:
    gr.Markdown("""
    ## πŸͺ„ Saree Fold Remover β€” *Color-Safe Edition*
    Removes folds & lighting shadows **without changing original colors**.<br>
    Works beautifully on silk, cotton and zari fabrics.<br>
    Adjust *Fold Intensity* for stronger corrections; enable *Enhance Texture* for sharper weave.
    """)

    with gr.Row():
        # type="pil" hands a PIL.Image directly to remove_folds_color_safe.
        inp = gr.Image(label="Upload Saree Image", type="pil")
        intensity = gr.Slider(
            0.0, 1.0, value=0.5, step=0.05,
            label="Fold Intensity",
            info="Higher = stronger flattening"
        )
        enhance = gr.Checkbox(label="✨ Enhance Texture", value=False)

    out = gr.Image(label="Flat, Fold-Free Output")
    run = gr.Button("πŸš€ Remove Folds")
    # Wire the button to the processing function; inputs match its signature.
    run.click(fn=remove_folds_color_safe, inputs=[inp, intensity, enhance], outputs=out)

# Module-level side effect: starts the Gradio server on import/run.
demo.launch()
 
 
1
+ # ==========================================================
2
+ # πŸͺ” Saree Pattern Extractor (DINOv2 / ViT Patch Clustering)
3
+ # Optimized for HF Free CPU Tier
4
+ # ==========================================================
5
+ import os, zipfile, io, cv2, numpy as np, torch
6
  from PIL import Image
7
+ from sklearn.cluster import KMeans
8
+ import gradio as gr
9
+ from transformers import AutoImageProcessor, AutoModel
10
+ import matplotlib.pyplot as plt
11
 
12
# -----------------------------
# 1️⃣ Load DINOv2-small model
# -----------------------------
# Module-level side effect: downloads (on first run) and loads the
# DINOv2-small weights at import time; eval() disables dropout etc.
# CPU fallback keeps this runnable on the HF free tier.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "facebook/dinov2-small"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).to(device).eval()
19
 
20
+ # -----------------------------
21
+ # 2️⃣ Helper functions
22
+ # -----------------------------
23
def upscale_label_map(lbl, target_size):
    """Nearest-neighbor upscale of a small cluster-label grid to image size.

    Parameters
    ----------
    lbl : np.ndarray
        2-D integer label map (one entry per ViT patch). Values must fit
        in uint8 (cluster counts here are <= 12).
    target_size : tuple
        PIL-style ``(width, height)`` — this matches cv2.resize's ``dsize``
        convention, which is also (width, height).

    Returns
    -------
    np.ndarray
        uint8 label map of shape (height, width).
    """
    # Fix: the original unpacked `h, w = lbl.shape` and never used them.
    # INTER_NEAREST keeps labels discrete (no interpolation between IDs).
    return cv2.resize(lbl.astype(np.uint8), target_size, interpolation=cv2.INTER_NEAREST)
26
 
27
def extract_patterns(image: Image.Image, K: int = 8):
    """Cluster DINOv2 patch embeddings of a saree image into K pattern groups.

    Parameters
    ----------
    image : PIL.Image.Image
        Uploaded photo (any mode; converted to RGB).
    K : int
        Number of KMeans clusters. Cast to int defensively — Gradio
        sliders can deliver floats.

    Returns
    -------
    (PIL.Image.Image, str)
        A cluster-overlay preview image and the path to a ZIP containing
        one crop per patch, grouped by cluster.
    """
    K = int(K)

    # Resize to keep under memory limits on the free CPU tier.
    image = image.convert("RGB")
    img_small = image.copy()
    img_small.thumbnail((480, 480))

    # DINOv2 forward pass; token 0 of last_hidden_state is CLS, the
    # remaining tokens are the per-patch embeddings.
    inputs = processor(images=img_small, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    feats = outputs.last_hidden_state.squeeze(0)[1:].cpu().numpy()

    # NOTE(review): assumes the processor yields a square patch grid (the
    # default DINOv2 preprocessing center-crops to a square); a non-square
    # grid would make this reshape fail — confirm processor config.
    grid = int(np.sqrt(feats.shape[0]))
    km = KMeans(n_clusters=K, random_state=0, n_init="auto").fit(feats)
    labels = km.labels_.reshape(grid, grid)

    # Rebuild a colored overlay at the working-image resolution.
    lbl_map = upscale_label_map(labels, img_small.size)
    colors = plt.cm.tab10(np.linspace(0, 1, K))[:, :3]
    overlay = np.zeros((*lbl_map.shape, 3))
    for k in range(K):
        overlay[lbl_map == k] = colors[k]
    overlay = (overlay * 255).astype(np.uint8)
    blend = cv2.addWeighted(np.array(img_small), 0.6, overlay, 0.4, 0)

    # Patch export. Clear stale crops first so a rerun with a different K
    # (or a different image) doesn't leak old files into the ZIP.
    outdir = "patterns"
    os.makedirs(outdir, exist_ok=True)
    for old in os.listdir(outdir):
        os.remove(os.path.join(outdir, old))

    # BUGFIX: PIL .size is (width, height); the original unpacked it as
    # (ph, pw), swapping patch height/width for non-square images.
    pw, ph = (np.array(img_small.size) // grid).tolist()
    for k in range(K):
        coords = np.argwhere(labels == k)
        for i, (y, x) in enumerate(coords):
            y0, y1 = int(y * ph), int((y + 1) * ph)
            x0, x1 = int(x * pw), int((x + 1) * pw)
            patch = img_small.crop((x0, y0, x1, y1))
            patch.save(f"{outdir}/cluster{k}_patch{i}.png")

    # Zip all crops into a single downloadable archive.
    zip_path = "patterns.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for fn in os.listdir(outdir):
            zf.write(os.path.join(outdir, fn), fn)

    return Image.fromarray(blend), zip_path
 
 
 
 
 
 
 
 
 
71
 
72
# -----------------------------
# 3️⃣ Gradio Interface
# -----------------------------
title = "πŸͺ” Saree Pattern Extractor (DINOv2)"
desc = "Upload a saree image β†’ cluster similar motifs & textures β†’ preview overlay β†’ download all patch crops as ZIP."

# Declare widgets up front so the Interface call reads as pure wiring.
_input_widgets = [
    gr.Image(label="Upload Saree Image", type="pil"),
    gr.Slider(3, 12, value=8, step=1, label="Number of Pattern Clusters"),
]
_output_widgets = [
    gr.Image(label="Cluster Overlay"),
    gr.File(label="Download ZIP of Crops"),
]

demo = gr.Interface(
    fn=extract_patterns,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title=title,
    description=desc,
    allow_flagging="never",
    cache_examples=False,
)

# Only start the server when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()