# ==========================================================
# 🪔 Saree Pattern Extractor (DINOv2 / ViT Patch Clustering)
# Optimized for HF Free CPU Tier
# ==========================================================
import io
import os
import shutil
import zipfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sklearn.cluster import KMeans
from transformers import AutoImageProcessor, AutoModel

# -----------------------------
# 1️⃣ Load DINOv2-small model
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "facebook/dinov2-small"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).to(device).eval()


# -----------------------------
# 2️⃣ Helper functions
# -----------------------------
def upscale_label_map(lbl, target_size):
    """Nearest-neighbor upscale of a patch-label grid to a pixel-sized map.

    Args:
        lbl: 2-D integer array of cluster labels (one entry per ViT patch).
        target_size: (width, height) tuple — PIL's ``Image.size`` order,
            which matches cv2.resize's ``dsize`` convention.

    Returns:
        uint8 array of shape (height, width) with labels replicated per pixel.
    """
    # INTER_NEAREST keeps labels categorical (no interpolation between ids).
    return cv2.resize(lbl.astype(np.uint8), target_size,
                      interpolation=cv2.INTER_NEAREST)


def extract_patterns(image: Image.Image, K: int = 8):
    """Cluster DINOv2 patch embeddings of a saree image into K pattern groups.

    Produces a color-coded overlay of the cluster map and exports every
    patch crop, grouped by cluster, into a downloadable ZIP archive.

    Args:
        image: input saree photo (any mode; converted to RGB).
        K: number of K-Means clusters (Gradio may pass this as a float).

    Returns:
        (overlay PIL.Image, path to "patterns.zip").
    """
    K = int(K)  # Gradio sliders can deliver floats; KMeans needs an int.

    # Downscale to keep memory within the free CPU tier limits.
    image = image.convert("RGB")
    img_small = image.copy()
    img_small.thumbnail((480, 480))

    inputs = processor(images=img_small, return_tensors="pt").to(device)
    with torch.no_grad():
        # hidden states are not needed — only the final patch embeddings.
        outputs = model(**inputs)
    # Drop the leading CLS token; keep one embedding per image patch.
    feats = outputs.last_hidden_state.squeeze(0)[1:].cpu().numpy()

    grid = int(np.sqrt(feats.shape[0]))
    # Guard: if the token count is not a perfect square (e.g. register
    # tokens), truncate so the reshape below cannot fail.
    feats = feats[: grid * grid]

    km = KMeans(n_clusters=K, random_state=0, n_init="auto").fit(feats)
    labels = km.labels_.reshape(grid, grid)

    # Rebuild a per-pixel cluster map and blend it over the image.
    lbl_map = upscale_label_map(labels, img_small.size)
    # tab10 has only 10 distinct colors; use tab20 when K exceeds that.
    cmap = plt.cm.tab20 if K > 10 else plt.cm.tab10
    colors = cmap(np.linspace(0, 1, K))[:, :3]
    overlay = np.zeros((*lbl_map.shape, 3))
    for k in range(K):
        overlay[lbl_map == k] = colors[k]
    overlay = (overlay * 255).astype(np.uint8)
    blend = cv2.addWeighted(np.array(img_small), 0.6, overlay, 0.4, 0)

    # Patch export — start from a clean directory so crops from a previous
    # run (possibly with a different K) cannot leak into the new ZIP.
    outdir = "patterns"
    if os.path.isdir(outdir):
        shutil.rmtree(outdir)
    os.makedirs(outdir, exist_ok=True)

    # PIL's Image.size is (width, height); derive the per-patch pixel step
    # for each axis separately. (Bugfix: the original assigned the
    # width-derived step to the y axis and vice versa.)
    w, h = img_small.size
    ph, pw = h // grid, w // grid
    for k in range(K):
        coords = np.argwhere(labels == k)
        for i, (y, x) in enumerate(coords):
            y0, y1 = int(y * ph), int((y + 1) * ph)
            x0, x1 = int(x * pw), int((x + 1) * pw)
            patch = img_small.crop((x0, y0, x1, y1))
            patch.save(f"{outdir}/cluster{k}_patch{i}.png")

    # Zip all exported crops for download.
    zip_path = "patterns.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for fn in os.listdir(outdir):
            zf.write(os.path.join(outdir, fn), fn)

    return Image.fromarray(blend), zip_path


# -----------------------------
# 3️⃣ Gradio Interface
# -----------------------------
title = "🪔 Saree Pattern Extractor (DINOv2)"
desc = ("Upload a saree image → cluster similar motifs & textures → "
        "preview overlay → download all patch crops as ZIP.")

demo = gr.Interface(
    fn=extract_patterns,
    inputs=[
        gr.Image(label="Upload Saree Image", type="pil"),
        gr.Slider(3, 12, value=8, step=1, label="Number of Pattern Clusters"),
    ],
    outputs=[
        gr.Image(label="Cluster Overlay"),
        gr.File(label="Download ZIP of Crops"),
    ],
    title=title,
    description=desc,
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()