# Folds / app.py — Hugging Face Space by nishanth-saka
# Commit bb6f301 (verified): "PATTERN (#7)"
# ==========================================================
# 🪔 Saree Pattern Extractor (DINOv2 / ViT Patch Clustering)
# Optimized for HF Free CPU Tier
# ==========================================================
import os, zipfile, io, cv2, numpy as np, torch
from PIL import Image
from sklearn.cluster import KMeans
import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import matplotlib.pyplot as plt
# -----------------------------
# 1️⃣ Load DINOv2-small model
# -----------------------------
# Prefer GPU when present; on the HF free tier this resolves to "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "facebook/dinov2-small"
# Processor handles resize/normalization of input images; the model is put
# in eval mode and used only as a frozen patch-feature extractor.
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).to(device).eval()
# -----------------------------
# 2️⃣ Helper functions
# -----------------------------
def upscale_label_map(lbl, target_size):
    """Upscale a small 2-D cluster-label map to the full image resolution.

    Args:
        lbl: 2-D integer array of cluster labels (small values; cast to
            uint8 before resizing).
        target_size: (width, height) tuple, as produced by PIL's `.size` —
            cv2.resize takes its `dsize` argument in (width, height) order.

    Returns:
        uint8 array of shape (height, width). INTER_NEAREST replicates
        labels instead of interpolating, so no spurious label values
        appear at cluster boundaries.
    """
    # (The previous version unpacked `h, w = lbl.shape` but never used them.)
    return cv2.resize(lbl.astype(np.uint8), target_size, interpolation=cv2.INTER_NEAREST)
def extract_patterns(image: Image.Image, K: int = 8):
    """Cluster DINOv2 patch embeddings of a saree image into K pattern groups.

    Args:
        image: Uploaded image (any mode; converted to RGB).
        K: Number of KMeans clusters / pattern groups.

    Returns:
        (overlay, zip_path): a PIL image blending the colorized cluster map
        over the downscaled input, and the path to a ZIP of per-cluster
        patch crops.
    """
    K = int(K)  # Gradio sliders can deliver floats
    image = image.convert("RGB")
    # Downscale to keep memory/CPU within free-tier limits.
    img_small = image.copy()
    img_small.thumbnail((480, 480))
    inputs = processor(images=img_small, return_tensors="pt").to(device)
    with torch.no_grad():
        # hidden states are not needed; last_hidden_state suffices.
        outputs = model(**inputs)
    # Drop the CLS token; the remaining tokens are the patch embeddings.
    # NOTE(review): assumes no register tokens, so the token count is a
    # perfect square — holds for facebook/dinov2-small's default config.
    feats = outputs.last_hidden_state.squeeze(0)[1:].cpu().numpy()
    grid = int(np.sqrt(feats.shape[0]))
    km = KMeans(n_clusters=K, random_state=0, n_init="auto").fit(feats)
    labels = km.labels_.reshape(grid, grid)

    # Rebuild a colorized overlay at the thumbnail's resolution.
    lbl_map = upscale_label_map(labels, img_small.size)
    colors = plt.cm.tab10(np.linspace(0, 1, K))[:, :3]
    overlay = np.zeros((*lbl_map.shape, 3))
    for k in range(K):
        overlay[lbl_map == k] = colors[k]
    overlay = (overlay * 255).astype(np.uint8)
    blend = cv2.addWeighted(np.array(img_small), 0.6, overlay, 0.4, 0)

    # Export one crop per patch, grouped by cluster. Clear stale files from
    # previous runs first, otherwise old crops leak into the new ZIP.
    outdir = "patterns"
    os.makedirs(outdir, exist_ok=True)
    for fn in os.listdir(outdir):
        os.remove(os.path.join(outdir, fn))
    # PIL `.size` is (width, height): derive the patch width and height on
    # the correct axes. (The previous `ph, pw = size // grid` swapped them,
    # which was only accidentally correct for square thumbnails.)
    w, h = img_small.size
    pw, ph = w // grid, h // grid
    for k in range(K):
        coords = np.argwhere(labels == k)
        for i, (y, x) in enumerate(coords):
            y0, y1 = int(y * ph), int((y + 1) * ph)
            x0, x1 = int(x * pw), int((x + 1) * pw)
            patch = img_small.crop((x0, y0, x1, y1))
            patch.save(f"{outdir}/cluster{k}_patch{i}.png")

    # Bundle all crops into a single downloadable archive.
    zip_path = "patterns.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for fn in os.listdir(outdir):
            zf.write(os.path.join(outdir, fn), fn)
    return Image.fromarray(blend), zip_path
# -----------------------------
# 3️⃣ Gradio Interface
# -----------------------------
title = "🪔 Saree Pattern Extractor (DINOv2)"
desc = "Upload a saree image → cluster similar motifs & textures → preview overlay → download all patch crops as ZIP."

# Assemble UI components up front so the Interface call stays readable.
_inputs = [
    gr.Image(label="Upload Saree Image", type="pil"),
    gr.Slider(3, 12, value=8, step=1, label="Number of Pattern Clusters"),
]
_outputs = [
    gr.Image(label="Cluster Overlay"),
    gr.File(label="Download ZIP of Crops"),
]

demo = gr.Interface(
    fn=extract_patterns,
    inputs=_inputs,
    outputs=_outputs,
    title=title,
    description=desc,
    # NOTE(review): `allow_flagging` is deprecated in Gradio 4.x in favor of
    # `flagging_mode` — confirm against the Space's pinned Gradio version.
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()