File size: 7,274 Bytes
037dad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
import pickle as pkl
import tempfile
from pathlib import Path

import numpy as np
import streamlit as st

import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import GlobalMaxPool2D
from tensorflow.keras.preprocessing import image as kimage

from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm


# -----------------------------
# Feature extraction
# -----------------------------
def extract_features(image_path: str, model: tf.keras.Model) -> np.ndarray:
    """Embed a single image file as a unit-length feature vector.

    The image is resized to 224x224, preprocessed with the ResNet50
    ``preprocess_input`` and passed through *model* as a batch of one. The
    output is flattened (2048-d for the model built by ``load_resnet``) and
    divided by its L2 norm. A tiny epsilon keeps an all-zero output from
    dividing by zero.
    """
    pil_img = kimage.load_img(image_path, target_size=(224, 224))
    batch = preprocess_input(kimage.img_to_array(pil_img)[np.newaxis, ...])
    embedding = model.predict(batch, verbose=0).flatten()
    magnitude = norm(embedding) + 1e-10
    return embedding / magnitude


@st.cache_resource
def load_resnet():
    """Build the frozen image-embedding model.

    This is an ImageNet-pretrained ResNet50 without its classification head,
    followed by global max pooling over the spatial grid. Wrapped in
    ``st.cache_resource`` so the weights are loaded once and reused across
    reruns instead of on every widget interaction.
    """
    backbone = ResNet50(include_top=False, weights="imagenet", input_shape=(224, 224, 3))
    backbone.trainable = False  # inference only; weights are never updated

    model = tf.keras.Sequential()
    model.add(backbone)
    model.add(GlobalMaxPool2D())
    return model


@st.cache_resource
def load_data(features_path: str, filenames_path: str):
    """Load precomputed image embeddings and their matching filenames.

    Args:
        features_path: pickle containing the embeddings (one row per image).
        filenames_path: pickle containing the image filenames, in the same
            order as the embeddings.

    Returns:
        ``(features, filenames)``: ``features`` as a float32 array and
        ``filenames`` as a list. Both are guaranteed to have the same length,
        so every row index returned by the nearest-neighbour search is a valid
        index into ``filenames``.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code. Both paths come
    from sidebar text inputs, so only point them at pickles you created
    yourself.
    """
    with open(features_path, "rb") as f:
        features = pkl.load(f)
    with open(filenames_path, "rb") as f:
        filenames = pkl.load(f)

    features = np.array(features, dtype="float32")
    # A plain list keeps later truthiness checks (`if not matches`) and
    # slicing unambiguous even if the pickle held a numpy array.
    filenames = list(filenames)

    n_features = features.shape[0]
    n_files = len(filenames)
    if n_files != n_features:
        # Truncate BOTH sides to the shared prefix. Trimming only `filenames`
        # (as before) left the k-NN index able to return rows that have no
        # filename when `filenames` was the shorter one -> IndexError.
        n_common = min(n_files, n_features)
        st.warning(
            f"Mismatch gevonden: len(filenames)={n_files} maar features={n_features}. "
            f"Ik gebruik automatisch alleen de eerste {n_common} items van beide."
        )
        features = features[:n_common]
        filenames = filenames[:n_common]

    return features, filenames


@st.cache_resource
def fit_nn(features: np.ndarray, n_neighbors: int):
    """Return a brute-force Euclidean nearest-neighbour index fitted on *features*.

    Cached with ``st.cache_resource`` keyed on the arguments, so the index is
    only rebuilt when the data or ``n_neighbors`` changes.
    """
    index = NearestNeighbors(
        n_neighbors=n_neighbors,
        metric="euclidean",
        algorithm="brute",
    )
    # fit() returns the estimator itself.
    return index.fit(features)


def resolve_path(p: str, images_dir: str) -> str:
    """Map an entry from filenames.pkl to an image path on this machine.

    Candidates are tried in order, and the first one that exists is returned:

    1. ``p`` as given (absolute, or relative to the working directory);
    2. ``p`` joined onto ``images_dir`` (the pickle holds bare basenames);
    3. just the file name of ``p`` inside ``images_dir``. This covers pickles
       built on another machine or OS, e.g. ``"images\\1163.jpg"`` or
       ``"/home/me/data/images/1163.jpg"``. Both ``/`` and ``\\`` are treated
       as separators for this step.

    If nothing exists, candidate 2 is returned (unchanged from before) so the
    callers' ``os.path.exists`` checks and error messages behave as they did.
    """
    if os.path.exists(p):
        return p

    joined = os.path.join(images_dir, p)
    if os.path.exists(joined):
        return joined

    # str() first: on a Path object `.replace` would mean "rename file".
    file_name = os.path.basename(str(p).replace("\\", "/"))
    by_name = os.path.join(images_dir, file_name)
    if os.path.exists(by_name):
        return by_name

    return joined


# -----------------------------
# UI
# -----------------------------
st.set_page_config(page_title="Fashion Recommender", layout="wide")
st.title("🧥 Fashion Deep Learning Recommender (mini-set ok)")

# Sidebar: data locations and recommendation settings.
with st.sidebar:
    st.header("📁 Bestanden")
    images_dir = st.text_input("Images map", value="images")
    features_pkl = st.text_input("Features pickle", value="Images_features.pkl")
    filenames_pkl = st.text_input("Filenames pickle", value="filenames.pkl")

    st.divider()
    st.header("⚙️ Aanbevelingen")
    k = st.slider("Aantal aanbevelingen", 1, 20, 5)
    show_distance = st.checkbox("Toon euclidean afstand", value=False)

# is_file() rather than exists(): a directory path would otherwise pass this
# check and then crash inside open() with an unfriendly traceback.
if not Path(features_pkl).is_file() or not Path(filenames_pkl).is_file():
    st.error(
        "Ik kan je .pkl bestanden niet vinden. "
        "Zet `Images_features.pkl` en `filenames.pkl` in dezelfde map als dit script, "
        "of geef het juiste pad op in de sidebar."
    )
    st.stop()

features, filenames = load_data(features_pkl, filenames_pkl)

# Guard against empty or mis-shaped data. Without this, `features.shape[1]`
# below raises IndexError for a 1-D array, and NearestNeighbors rejects
# n_neighbors=0 for an empty dataset.
if features.ndim != 2 or features.shape[0] == 0 or len(filenames) == 0:
    st.error(
        "De geladen data is leeg of heeft een onverwachte vorm: "
        f"features.shape={features.shape}, len(filenames)={len(filenames)}. "
        "Verwacht een niet-lege 2D-array met één embedding per afbeelding."
    )
    st.stop()

st.caption(f"Geladen: **{len(filenames)}** items | embedding dim: **{features.shape[1]}**")

# Never request more neighbours than the dataset holds. The +1 leaves room to
# drop the query item itself when it comes from the dataset (tab 1).
n_neighbors = min(k + 1, features.shape[0])
neighbors = fit_nn(features, n_neighbors=n_neighbors)
model = load_resnet()

def _basename(path) -> str:
    """Return the file-name part of *path*, splitting on both "/" and "\\".

    filenames.pkl may have been written on another OS, so Windows-style
    separators must also be recognised when running on POSIX.
    """
    return os.path.basename(str(path).replace("\\", "/"))


def _recommend(query_vec: np.ndarray, exclude_path=None) -> list:
    """Return up to ``k`` ``(image_path, distance)`` pairs nearest to *query_vec*.

    Reads the module-level ``neighbors``, ``n_neighbors``, ``filenames``,
    ``images_dir`` and ``k`` defined above. Neighbours whose image file does
    not exist on disk are skipped. So is *exclude_path* (compared as an
    absolute path), which tab 1 uses to drop the query item itself.
    """
    dists, idxs = neighbors.kneighbors([query_vec], n_neighbors=n_neighbors)
    excluded = os.path.abspath(exclude_path) if exclude_path is not None else None

    recs = []
    for dist, idx in zip(dists[0], idxs[0]):
        path = resolve_path(filenames[idx], images_dir)
        if excluded is not None and os.path.abspath(path) == excluded:
            continue
        if os.path.exists(path):
            recs.append((path, float(dist)))
            if len(recs) >= k:
                break
    return recs


def _render_recommendations(recs: list) -> None:
    """Show *recs* in a grid of at most five columns, or a notice when empty."""
    if not recs:
        st.info("Geen aanbevelingen gevonden. Controleer of de 'Images map' klopt.")
        return
    cols = st.columns(min(5, len(recs)))
    for i, (path, dist) in enumerate(recs):
        with cols[i % len(cols)]:
            st.image(path, use_container_width=True)
            st.caption(os.path.basename(path))
            if show_distance:
                st.caption(f"afstand: {dist:.4f}")


tab1, tab2 = st.tabs(["Kies uit dataset", "Upload een foto"])

# -----------------------------
# Tab 1: pick an existing item
# -----------------------------
with tab1:
    st.subheader("Kies een item uit de dataset")

    # Strip before matching, not just for the emptiness test, so stray
    # whitespace does not make every search fail. An empty query matches all.
    q = st.text_input("Zoek op bestandsnaam (bv. 16871)", value="").strip()
    matches = [f for f in filenames if q in _basename(f)]

    if not matches:
        st.info("Geen matches. Probeer een andere zoekterm.")
    else:
        # cap the dropdown for UI performance
        selected = st.selectbox("Selecteer", options=matches[:5000])
        selected_path = resolve_path(selected, images_dir)

        if not os.path.exists(selected_path):
            # Report the problem without st.stop(). Stopping here would also
            # abort rendering of the upload tab below.
            st.error(f"Bestand niet gevonden: {selected_path}")
        else:
            colA, colB = st.columns([1, 2], gap="large")
            with colA:
                st.write("**Gekozen item**")
                st.image(selected_path, use_container_width=True)
                st.caption(os.path.basename(selected_path))

            with colB:
                st.write("**Aanbevelingen**")
                query_vec = extract_features(selected_path, model)
                _render_recommendations(_recommend(query_vec, exclude_path=selected_path))

# -----------------------------
# Tab 2: upload an external image
# -----------------------------
with tab2:
    st.subheader("Upload een afbeelding en krijg vergelijkbare items")
    up = st.file_uploader("Upload een JPG/PNG", type=["jpg", "jpeg", "png"])
    if up is not None:
        data = up.getvalue()

        colA, colB = st.columns([1, 2], gap="large")
        with colA:
            st.write("**Geüpload**")
            # Display straight from memory; no file is needed for this.
            st.image(data, use_container_width=True)
            st.caption(up.name)

        with colB:
            st.write("**Aanbevelingen**")
            # The feature extractor takes a file path, so write the upload to
            # a private, uniquely named temp file and delete it afterwards.
            # This replaces `.streamlit_tmp/<up.name>`, which had three
            # problems: the client-supplied name is untrusted (it may contain
            # "../"), concurrent sessions uploading the same name overwrote
            # each other, and the copies were never cleaned up.
            with tempfile.NamedTemporaryFile(suffix=Path(up.name).suffix, delete=False) as tmp:
                tmp.write(data)
                tmp_path = tmp.name
            try:
                query_vec = extract_features(tmp_path, model)
            finally:
                os.unlink(tmp_path)

            _render_recommendations(_recommend(query_vec))