# (HF Spaces page residue — "Spaces: Sleeping / Sleeping" status text
#  captured by the copy-paste; kept as a comment so the file parses.)
# ==========================================================
# 🪔 Saree Pattern Extractor (DINOv2 / ViT Patch Clustering)
# Optimized for HF Free CPU Tier
# ==========================================================
import io
import os
import shutil
import zipfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sklearn.cluster import KMeans
from transformers import AutoImageProcessor, AutoModel
# -----------------------------
# 1. Load DINOv2-small model
# -----------------------------
# Prefer GPU when present; on the HF free CPU tier this resolves to "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "facebook/dinov2-small"
# Processor handles resizing/normalization to the model's expected input.
processor = AutoImageProcessor.from_pretrained(model_id)
# .eval() disables dropout etc. — this app only runs inference.
model = AutoModel.from_pretrained(model_id).to(device).eval()
# -----------------------------
# 2. Helper functions
# -----------------------------
def upscale_label_map(lbl, target_size):
    """Upscale a small 2-D cluster-label map to the display resolution.

    Parameters
    ----------
    lbl : 2-D integer array of cluster ids (assumed to fit in uint8,
        i.e. fewer than 256 clusters — true for the K<=12 slider).
    target_size : (width, height) tuple — cv2.resize's ``dsize`` order,
        which conveniently matches ``PIL.Image.size``.

    Returns
    -------
    np.ndarray of shape (height, width), dtype uint8.
    """
    # INTER_NEAREST keeps labels categorical: interpolating modes would
    # blend cluster ids into meaningless in-between values.
    return cv2.resize(lbl.astype(np.uint8), target_size,
                      interpolation=cv2.INTER_NEAREST)
def extract_patterns(image: Image.Image, K: int = 8):
    """Cluster DINOv2 patch embeddings of a saree image into K motif groups.

    Parameters
    ----------
    image : input PIL image (any mode; converted to RGB).
    K : number of KMeans clusters (Gradio sliders deliver floats, so it
        is coerced to int below).

    Returns
    -------
    (overlay, zip_path) : a PIL image blending the photo with a colored
        cluster map, and the path to a zip of per-patch crops.
    """
    K = int(K)  # KMeans requires an integer cluster count
    image = image.convert("RGB")
    img_small = image.copy()
    img_small.thumbnail((480, 480))  # cap resolution for free-tier memory

    inputs = processor(images=img_small, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    # Drop the CLS token at index 0; keep only the per-patch embeddings.
    feats = outputs.last_hidden_state.squeeze(0)[1:].cpu().numpy()
    # The processor feeds the model a square input, so patches form a
    # square grid of side sqrt(num_patches).
    grid = int(np.sqrt(feats.shape[0]))

    km = KMeans(n_clusters=K, random_state=0, n_init="auto").fit(feats)
    labels = km.labels_.reshape(grid, grid)

    # --- Rebuild a colored overlay at display resolution ---
    lbl_map = upscale_label_map(labels, img_small.size)
    # tab20 gives 20 distinct colors; tab10 would repeat for K > 10
    # (the UI slider allows up to 12 clusters).
    colors = plt.cm.tab20(np.linspace(0, 1, K))[:, :3]
    overlay = np.zeros((*lbl_map.shape, 3))
    for k in range(K):
        overlay[lbl_map == k] = colors[k]
    overlay = (overlay * 255).astype(np.uint8)
    blend = cv2.addWeighted(np.array(img_small), 0.6, overlay, 0.4, 0)

    # --- Export one crop per patch, grouped by cluster ---
    outdir = "patterns"
    # Remove stale crops from previous runs (possibly with a different K)
    # so the zip below contains only this run's output.
    shutil.rmtree(outdir, ignore_errors=True)
    os.makedirs(outdir, exist_ok=True)
    # PIL .size is (width, height): width determines the patch *width*,
    # height the patch *height*. (The original unpacked these the other
    # way round — wrong crops for non-square images.)
    pw, ph = np.array(img_small.size) // grid
    for k in range(K):
        coords = np.argwhere(labels == k)
        for i, (y, x) in enumerate(coords):
            y0, y1 = int(y * ph), int((y + 1) * ph)
            x0, x1 = int(x * pw), int((x + 1) * pw)
            patch = img_small.crop((x0, y0, x1, y1))
            patch.save(f"{outdir}/cluster{k}_patch{i}.png")

    # --- Bundle all crops into one downloadable zip ---
    zip_path = "patterns.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for fn in os.listdir(outdir):
            zf.write(os.path.join(outdir, fn), fn)

    return Image.fromarray(blend), zip_path
# -----------------------------
# 3. Gradio Interface
# -----------------------------
title = "🪔 Saree Pattern Extractor (DINOv2)"
desc = "Upload a saree image → cluster similar motifs & textures → preview overlay → download all patch crops as ZIP."
demo = gr.Interface(
    fn=extract_patterns,
    inputs=[
        gr.Image(label="Upload Saree Image", type="pil"),
        # Slider bounds match the tab20 palette used in extract_patterns.
        gr.Slider(3, 12, value=8, step=1, label="Number of Pattern Clusters"),
    ],
    outputs=[
        gr.Image(label="Cluster Overlay"),
        gr.File(label="Download ZIP of Crops"),
    ],
    title=title,
    description=desc,
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()