File size: 3,339 Bytes
bb6f301
 
 
 
 
f2c1e17
bb6f301
 
 
 
f2c1e17
bb6f301
 
 
 
 
 
 
f2c1e17
bb6f301
 
 
 
 
 
f2c1e17
bb6f301
 
 
 
 
f2c1e17
bb6f301
 
 
 
f2c1e17
bb6f301
 
 
31e52e7
bb6f301
 
 
 
 
 
 
 
31e52e7
bb6f301
 
 
 
 
 
 
 
 
 
 
 
2423f94
bb6f301
 
 
 
 
f2c1e17
bb6f301
f2c1e17
bb6f301
 
 
 
 
f2c1e17
bb6f301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a057ced
bb6f301
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# ==========================================================
# 🪔 Saree Pattern Extractor (DINOv2 / ViT Patch Clustering)
# Optimized for HF Free CPU Tier
# ==========================================================
import os, zipfile, io, cv2, numpy as np, torch
from PIL import Image
from sklearn.cluster import KMeans
import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import matplotlib.pyplot as plt

# -----------------------------
# 1️⃣  Load DINOv2-small model
# -----------------------------
# Module-level singletons: loaded once at import time so every Gradio request
# reuses the same weights. Falls back to CPU, which keeps this runnable on the
# HF free tier the header advertises.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "facebook/dinov2-small"
processor = AutoImageProcessor.from_pretrained(model_id)
# .eval() disables dropout/batch-norm updates — inference only.
model = AutoModel.from_pretrained(model_id).to(device).eval()

# -----------------------------
# 2️⃣  Helper functions
# -----------------------------
def upscale_label_map(lbl, target_size):
    """Upscale a small 2-D cluster-label map to ``target_size``.

    Parameters
    ----------
    lbl : np.ndarray
        2-D integer array of cluster ids (one entry per ViT patch).
    target_size : tuple[int, int]
        Output size as ``(width, height)`` — this is cv2.resize's ``dsize``
        convention and matches PIL's ``Image.size`` directly.

    Returns
    -------
    np.ndarray
        Array of shape ``(height, width)`` and dtype ``uint8``.
    """
    # INTER_NEAREST keeps labels as exact class ids — no interpolation
    # blending between neighbouring clusters. uint8 is safe here because the
    # UI caps K at 12 (< 256). (Removed an unused `h, w = lbl.shape` unpack.)
    return cv2.resize(lbl.astype(np.uint8), target_size, interpolation=cv2.INTER_NEAREST)

def extract_patterns(image: Image.Image, K: int = 8):
    """Cluster DINOv2 patch embeddings of a saree image into K pattern groups.

    Runs the (downscaled) image through DINOv2, K-means-clusters the per-patch
    token embeddings, builds a colour overlay preview, and exports one crop
    per patch grouped by cluster into a downloadable ZIP.

    Parameters
    ----------
    image : PIL.Image.Image
        Input photo; converted to RGB internally.
    K : int, default 8
        Number of pattern clusters. Coerced to int because gr.Slider can
        deliver a float even with step=1.

    Returns
    -------
    (PIL.Image.Image, str)
        Blended cluster overlay and the path to the ZIP of patch crops.
    """
    K = int(K)

    # Downscale first to keep memory/CPU within free-tier limits.
    image = image.convert("RGB")
    img_small = image.copy()
    img_small.thumbnail((480, 480))

    inputs = processor(images=img_small, return_tensors="pt").to(device)
    with torch.no_grad():
        # hidden states were requested but never used — dropped to save memory.
        outputs = model(**inputs)
        # [1:] drops the CLS token; the rest are per-patch tokens.
        feats = outputs.last_hidden_state.squeeze(0)[1:].cpu().numpy()

    # Patch tokens form a (roughly) square grid; trim any remainder so the
    # reshape below cannot fail if the token count isn't a perfect square.
    grid = int(np.sqrt(feats.shape[0]))
    km = KMeans(n_clusters=K, random_state=0, n_init="auto").fit(feats)
    labels = km.labels_[: grid * grid].reshape(grid, grid)

    # Rebuild the overlay at img_small resolution.
    lbl_map = upscale_label_map(labels, img_small.size)
    colors = plt.cm.tab10(np.linspace(0, 1, K))[:, :3]
    overlay = np.zeros((*lbl_map.shape, 3))
    for k in range(K):
        overlay[lbl_map == k] = colors[k]
    overlay = (overlay * 255).astype(np.uint8)
    blend = cv2.addWeighted(np.array(img_small), 0.6, overlay, 0.4, 0)

    # Export one crop per patch, grouped by cluster. Clear stale files from
    # earlier runs so the ZIP never mixes results of different K values.
    outdir = "patterns"
    os.makedirs(outdir, exist_ok=True)
    for old in os.listdir(outdir):
        os.remove(os.path.join(outdir, old))

    # PIL's .size is (width, height): patch WIDTH derives from image width,
    # patch HEIGHT from image height. (The original unpacked these swapped,
    # which transposed the crop grid on non-square images.)
    w, h = img_small.size
    pw, ph = w // grid, h // grid
    for k in range(K):
        for i, (y, x) in enumerate(np.argwhere(labels == k)):
            y0, y1 = int(y * ph), int((y + 1) * ph)
            x0, x1 = int(x * pw), int((x + 1) * pw)
            img_small.crop((x0, y0, x1, y1)).save(f"{outdir}/cluster{k}_patch{i}.png")

    # Bundle all crops into a single archive for the gr.File output.
    zip_path = "patterns.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for fn in os.listdir(outdir):
            zf.write(os.path.join(outdir, fn), fn)

    return Image.fromarray(blend), zip_path

# -----------------------------
# 3️⃣  Gradio Interface
# -----------------------------
title = "🪔 Saree Pattern Extractor (DINOv2)"
desc = "Upload a saree image → cluster similar motifs & textures → preview overlay → download all patch crops as ZIP."

# Inputs: the saree photo plus a slider choosing how many pattern clusters
# KMeans should produce (3–12, default 8).
_inputs = [
    gr.Image(label="Upload Saree Image", type="pil"),
    gr.Slider(3, 12, value=8, step=1, label="Number of Pattern Clusters"),
]

# Outputs: the colour-coded overlay preview and the ZIP of patch crops.
_outputs = [
    gr.Image(label="Cluster Overlay"),
    gr.File(label="Download ZIP of Crops"),
]

demo = gr.Interface(
    fn=extract_patterns,
    inputs=_inputs,
    outputs=_outputs,
    title=title,
    description=desc,
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()