File size: 7,951 Bytes
1534eca
56126c6
54a2029
56126c6
 
54a2029
1534eca
54a2029
 
1534eca
ed33c71
ba275bd
1534eca
54a2029
 
 
 
1534eca
 
54a2029
1534eca
 
8408d7d
1534eca
56126c6
54a2029
1cc8096
1534eca
 
 
 
 
 
 
 
1aab843
 
0f2723f
 
56126c6
f923319
1534eca
54a2029
1534eca
 
 
 
 
 
 
54a2029
1534eca
 
 
54a2029
 
 
 
 
 
 
1534eca
54a2029
1534eca
54a2029
 
1534eca
54a2029
56126c6
1534eca
54a2029
 
 
 
1534eca
54a2029
 
1534eca
 
 
 
 
 
 
 
 
54a2029
 
1534eca
 
 
 
 
 
 
 
 
 
54a2029
 
1534eca
54a2029
1534eca
54a2029
1534eca
54a2029
 
1534eca
 
54a2029
1534eca
 
54a2029
 
1534eca
 
 
 
54a2029
1534eca
 
 
 
 
 
54a2029
 
1534eca
56126c6
1534eca
 
 
 
 
 
56126c6
 
54a2029
1534eca
 
54a2029
 
 
1534eca
54a2029
eabf267
1534eca
 
 
 
54a2029
1534eca
54a2029
1534eca
0a20e04
54a2029
 
1534eca
56126c6
1534eca
56126c6
 
1534eca
 
f4370bc
 
 
 
 
1534eca
 
56126c6
1534eca
 
54a2029
0aba9da
 
54a2029
1534eca
 
54a2029
1534eca
 
 
54a2029
0aba9da
1534eca
 
 
 
 
 
 
 
 
 
ffe4ede
8991730
a2c0f65
8991730
 
a2c0f65
8991730
 
a2c0f65
8991730
1534eca
54a2029
56126c6
8991730
1534eca
8991730
 
 
56126c6
1534eca
 
7467be4
b14c280
cf7268b
057a3ab
f24509f
cf7268b
 
 
 
56126c6
 
7467be4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import os, json, joblib
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from sklearn.preprocessing import normalize
from sklearn.neighbors import NearestNeighbors
import gradio as gr
from PIL import Image
import pickle

from skimage.color import rgb2lab, lab2rgb
from skimage.feature import local_binary_pattern, hog
from sklearn.cluster import KMeans

# ---------------- CONFIG ----------------
# All artifacts are resolved relative to ARTIFACTS_DIR.
ARTIFACTS_DIR = "."
FEATURES_PATH = os.path.join(ARTIFACTS_DIR, "features.npy")
PATHS_PATH = os.path.join(ARTIFACTS_DIR, "image_paths.json")
PALETTES_PATH = os.path.join(ARTIFACTS_DIR, "palettes.json")
INDEX_PATH = os.path.join(ARTIFACTS_DIR, "nn_index.joblib")
MODEL_PATH = os.path.join(ARTIFACTS_DIR, "resnet50_multilayer_ssl.pt")
# Pickled KMeans codebook (assigned to the ORB bag-of-words extractor).
KMEANS_PATH = os.path.join(ARTIFACTS_DIR, "kmeans.pkl")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
GDRIVE_FOLDER = "https://drive.google.com/drive/folders/10EXzo27vuTjyG9FXHWWO4J5AhVQhyUp9?usp=sharing"

# ---------------- LOAD ARTIFACTS ----------------
# NOTE(review): joblib.load / pickle.load execute arbitrary code from the
# file being loaded; only ship artifacts produced by a trusted pipeline.
features = np.load(FEATURES_PATH)
# JSON is UTF-8 by spec; be explicit so non-ASCII paths load on any locale.
with open(PATHS_PATH, "r", encoding="utf-8") as f:
    IMG_PATHS = json.load(f)
with open(PALETTES_PATH, "r", encoding="utf-8") as f:
    DATA_PALETTES = json.load(f)
nn_index = joblib.load(INDEX_PATH)
with open(KMEANS_PATH, "rb") as f:
    kmeans = pickle.load(f)
# Kept as a name for compatibility; it is the same fitted codebook, so there
# is no reason to unpickle the file a second time.
fitted_kmeans = kmeans


# ---------------- FEATURE CLASSES ----------------
class AutoColor:
    """Dominant-colour palette extractor.

    Pixels are converted to CIELAB and clustered with k-means. The palette is
    the cluster centres plus the fraction of (sampled) pixels in each cluster.
    """

    def __init__(self, n_colors=5, sample_px=150000, random_state=42):
        # palette size == number of k-means clusters
        self.n_colors = n_colors
        # upper bound on pixels fed to k-means; larger images are subsampled
        self.sample_px = sample_px
        # seed shared by the pixel subsampler and KMeans
        self.random_state = random_state

    def extract(self, arr: np.ndarray):
        """Return ``(centers, props)`` describing the palette of ``arr``.

        ``arr`` is divided by 255 before Lab conversion, so it is presumably a
        uint8 RGB image of shape (H, W, 3) — TODO confirm for other dtypes.
        ``centers`` holds one Lab colour per cluster (shape ``(n_colors, 3)``);
        ``props`` is a float32 vector of per-cluster pixel fractions summing to 1.
        """
        pixels = rgb2lab(arr / 255.0).reshape(-1, 3)
        n_pixels = pixels.shape[0]
        if n_pixels > self.sample_px:
            # Deterministic subsample so the same image always yields the same palette.
            sampler = np.random.RandomState(self.random_state)
            pixels = pixels[sampler.choice(n_pixels, self.sample_px, replace=False)]
        clusterer = KMeans(
            n_clusters=self.n_colors, random_state=self.random_state, n_init=8
        ).fit(pixels)
        sizes = np.bincount(clusterer.labels_, minlength=self.n_colors).astype(np.float32)
        return clusterer.cluster_centers_, sizes / sizes.sum()

    def vectorize(self, centers, props):
        """Pack a palette into one float32 vector: flattened centres, then proportions."""
        return np.concatenate((centers.reshape(-1), props)).astype(np.float32)


class TextureBank:
    """Hand-crafted texture descriptor: multi-scale LBP, Gabor energy and HOG.

    The output layout (LBP histograms, then Gabor mean/std pairs, then HOG)
    must not change, because query vectors are compared against stored ones.
    """

    def __init__(self):
        # (P, R) = (number of sampling points, radius) for each uniform-LBP histogram
        self.lbp_settings = [(8, 1), (8, 2), (16, 3)]
        # 6 orientations x 3 sigmas x 3 wavelengths = 54 fixed 9x9 Gabor kernels,
        # generated orientation-major (theta outer, wavelength inner).
        self.gabor_kernels = [
            cv2.getGaborKernel((9, 9), sigma, theta, lambd, gamma=0.5, psi=0)
            for theta in np.linspace(0, np.pi, 6, endpoint=False)
            for sigma in (1.0, 2.0, 3.0)
            for lambd in (3.0, 6.0, 9.0)
        ]

    def extract(self, arr: np.ndarray):
        """Return the 1-D texture descriptor of an RGB image array.

        The image is converted to grayscale and resized to 512x512, so the
        descriptor length does not depend on the input resolution.
        """
        gray = cv2.resize(
            cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY),
            (512, 512),
            interpolation=cv2.INTER_AREA,
        )

        lbp_hists = []
        for points, radius in self.lbp_settings:
            codes = local_binary_pattern(gray, P=points, R=radius, method="uniform")
            # "uniform" LBP yields codes 0..P+1, hence P+2 bins.
            n_codes = points + 2
            hist, _ = np.histogram(codes, bins=n_codes, range=(0, n_codes), density=True)
            lbp_hists.append(hist.astype(np.float32))

        gabor_stats = []
        for kernel in self.gabor_kernels:
            response = cv2.filter2D(gray, cv2.CV_32F, kernel)
            gabor_stats.append([response.mean(), response.std()])

        hog_vec = hog(
            gray,
            orientations=9,
            pixels_per_cell=(16, 16),
            cells_per_block=(2, 2),
            visualize=False,
            feature_vector=True,
        ).astype(np.float32)

        return np.concatenate(lbp_hists + gabor_stats + [hog_vec], axis=0)


class ORBBoVW:
    """Bag-of-visual-words histogram over ORB descriptors.

    The codebook (``kmeans``) is not built here. It must be assigned after
    construction, before ``transform`` is called.
    """

    def __init__(self, n_words=64):
        self.n_words = n_words
        self.kmeans = None  # visual-word codebook, assigned externally
        self.orb = cv2.ORB_create(nfeatures=3000)

    def _orb_des(self, arr: np.ndarray):
        """Return ORB descriptors of an RGB array, or an empty (0, 32) uint8 array."""
        _, descriptors = self.orb.detectAndCompute(
            cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY), None
        )
        return np.zeros((0, 32), dtype=np.uint8) if descriptors is None else descriptors

    def transform(self, arr: np.ndarray):
        """Return an L2-normalised float32 word histogram of length ``n_words``.

        Images with no detected keypoints map to the all-zero vector.
        """
        descriptors = self._orb_des(arr)
        if not descriptors.shape[0]:
            return np.zeros((self.n_words,), dtype=np.float32)
        words = self.kmeans.predict(descriptors.astype(np.float32))
        counts, _ = np.histogram(words, bins=np.arange(self.n_words + 1))
        hist = counts.astype(np.float32)
        # epsilon guards against division by zero
        hist /= np.linalg.norm(hist) + 1e-8
        return hist


class ResNetMultiLayer(nn.Module):
    """ResNet-50 trunk that emits globally pooled features from several depths.

    The output is the concatenation of the global-average-pooled outputs of
    ``layer2``, ``layer3`` and ``layer4`` (in that order), shape ``(N, C)``.
    For torchvision's ResNet-50, C is presumably 512 + 1024 + 2048 = 3584 —
    verify against the stored features.
    """

    def __init__(self):
        super().__init__()
        # weights=None gives a randomly initialised network; trained weights
        # are loaded afterwards via load_state_dict. The attribute names below
        # are the state_dict keys and must not be renamed.
        base = torchvision.models.resnet50(weights=None)
        self.conv1 = base.conv1
        self.bn1 = base.bn1
        self.relu = base.relu
        self.maxpool = base.maxpool
        self.layer1 = base.layer1
        self.layer2 = base.layer2
        self.layer3 = base.layer3
        self.layer4 = base.layer4
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        h = x
        # Stem and first stage: no pooled output is taken here.
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool, self.layer1):
            h = stage(h)
        # Deeper stages: pool and keep each stage's output.
        pooled = []
        for stage in (self.layer2, self.layer3, self.layer4):
            h = stage(h)
            pooled.append(self.gap(h).flatten(1))
        return torch.cat(pooled, dim=1)


# ---------------- LOAD MODELS ----------------
# CNN feature extractor: architecture from ResNetMultiLayer, trained weights
# from MODEL_PATH.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files
# (consider weights_only=True where the installed torch supports it).
backbone = ResNetMultiLayer().to(DEVICE)
backbone.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
backbone.eval()  # inference mode: BatchNorm uses its stored running statistics

# Hand-crafted feature extractors used alongside the CNN.
autocolor = AutoColor()
texturebank = TextureBank()
bovw = ORBBoVW(n_words=64)
bovw.kmeans = kmeans # codebook from kmeans.pkl; ORB features are currently disabled in extract_single_feature, this only keeps transform() usable

# Query preprocessing: fixed 256x256 resize (aspect ratio not preserved),
# then ImageNet channel mean/std normalisation.
# Presumably must match the preprocessing used to build features.npy — TODO confirm.
TF_INFER = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
])

# ---------------- FEATURE EXTRACTION ----------------
def extract_single_feature(img):
    """Build the fused, L2-normalised query descriptor for one image.

    Args:
        img: a file path (``str``), a PIL image, or an ``np.ndarray``
            (presumably uint8 HxWx3 or HxW — TODO confirm for other dtypes).

    Returns:
        ``np.ndarray`` of shape ``(1, D)``, dtype float32, with unit L2 norm.
        It is the concatenation of three blocks, each L2-normalised and then
        weighted: CNN features (0.50), texture features (0.30) and
        colour-palette features (0.10).
    """
    if isinstance(img, str):
        # The context manager closes the file handle deterministically; convert()
        # returns an independent in-memory copy that outlives the handle.
        with Image.open(img) as src:
            img = src.convert("RGB")
    elif isinstance(img, np.ndarray):
        img = Image.fromarray(img).convert("RGB")
    else:
        img = img.convert("RGB")
    arr = np.array(img)

    # `img` is already an RGB PIL image, so it goes to the transform directly;
    # converting `arr` back with ToPILImage would reproduce the same image.
    x = TF_INFER(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        fcnn = backbone(x).cpu().numpy()
    fcnn = normalize(fcnn, norm="l2") * 0.50

    # NOTE: ORB bag-of-words features (bovw.transform, weight 0.10) are
    # disabled. If the index was built with them, the query is shorter than
    # the indexed vectors; recommend_gradio pads/truncates to compensate.
    # TODO confirm the indexed feature layout.

    ftex = texturebank.extract(arr)[None, :]
    ftex = normalize(ftex, norm="l2") * 0.30

    centers, props = autocolor.extract(arr)
    fcol = autocolor.vectorize(centers, props)[None, :]
    fcol = normalize(fcol, norm="l2") * 0.10

    feats = np.hstack([fcnn, ftex, fcol]).astype(np.float32)
    return normalize(feats, norm="l2")


def adjust_path(colab_path: str):
    """Map a dataset path recorded on Colab to a link under ``GDRIVE_FOLDER``.

    Only the final path component (the file name) is kept. It is appended to
    ``GDRIVE_FOLDER`` with a "/" separator.

    NOTE(review): GDRIVE_FOLDER is a folder *sharing* URL ending in
    "?usp=sharing", so results look like ".../folders/<id>?usp=sharing/<name>".
    That does not appear to be a URL Google Drive resolves to a file. Confirm
    the links open; per-file IDs are probably needed.
    """
    _, file_name = os.path.split(colab_path)
    return "/".join((GDRIVE_FOLDER, file_name))


def recommend_gradio(img, top_k=5):
    """Return Markdown links to the ``top_k`` dataset images most similar to ``img``.

    Args:
        img: the query image. With the ``type="filepath"`` Gradio input this is
            a path, or ``None`` when nothing was uploaded. Anything else accepted
            by ``extract_single_feature`` also works.
        top_k: number of neighbours to return.

    Returns:
        Newline-separated ``[View Image](url)`` Markdown links, nearest first.
        Returns a short prompt when ``img`` is ``None`` and ``""`` when
        ``top_k < 1``.
    """
    if img is None:
        return "Please upload an image."
    if top_k < 1:
        return ""

    qf = np.asarray(extract_single_feature(img)).reshape(1, -1)

    # Align the query with the dimensionality the index was fitted on. Prefer
    # the public sklearn attribute; fall back to the private one for old pickles.
    expected_dim = getattr(nn_index, "n_features_in_", None)
    if expected_dim is None:
        expected_dim = nn_index._fit_X.shape[1]
    # NOTE(review): zero-padding at the END only lines up if the missing block
    # (presumably the disabled ORB features) came last in the indexed vectors.
    # Truncation drops real features. Confirm against the indexing script.
    if qf.shape[1] < expected_dim:
        padding = np.zeros((1, expected_dim - qf.shape[1]), dtype=qf.dtype)
        qf = np.hstack([qf, padding])
    elif qf.shape[1] > expected_dim:
        qf = qf[:, :expected_dim]

    # Request exactly top_k neighbours, bounded by the dataset size. The index's
    # fitted n_neighbors default would otherwise cap the number of results.
    n_neighbors = min(top_k, getattr(nn_index, "n_samples_fit_", len(IMG_PATHS)))
    _, idxs = nn_index.kneighbors(qf, n_neighbors=n_neighbors)
    return "\n".join(
        f"[View Image]({adjust_path(IMG_PATHS[i])})" for i in idxs[0].tolist()
    )


# ---------------- GRADIO APP ----------------
# Single-image query UI. The upload reaches recommend_gradio as a file path,
# and the Markdown list of links it returns is rendered directly.
interface = gr.Interface(
    recommend_gradio,
    gr.Image(type="filepath", label="Upload an Image"),
    gr.Markdown(),
    description="Upload an image and find the most similar images from the dataset.",
    title="Image Similarity Search",
)


if __name__ == "__main__":
    # Start the local Gradio server only when run as a script, not on import.
    interface.launch()