BeyzaTopbas commited on
Commit
781d900
·
verified ·
1 Parent(s): eff919e

Upload 3 files

Browse files
Files changed (3) hide show
  1. Images_features.pkl +3 -0
  2. app.py +202 -0
  3. filenames.pkl +3 -0
Images_features.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26ca746c033196caee8255e2379aedef75d2e1b517b9290efbe40cdb46121dae
3
+ size 246929
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle as pkl
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import streamlit as st
7
+
8
+ import tensorflow as tf
9
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
10
+ from tensorflow.keras.layers import GlobalMaxPool2D
11
+ from tensorflow.keras.preprocessing import image as kimage
12
+
13
+ from sklearn.neighbors import NearestNeighbors
14
+ from numpy.linalg import norm
15
+
16
+
17
+ # -----------------------------
18
+ # Feature extraction
19
+ # -----------------------------
20
+ def extract_features(image_path: str, model: tf.keras.Model) -> np.ndarray:
21
+ """Return L2-normalized 2048-d embedding from an image path."""
22
+ img = kimage.load_img(image_path, target_size=(224, 224))
23
+ arr = kimage.img_to_array(img)
24
+ arr = np.expand_dims(arr, axis=0)
25
+ arr = preprocess_input(arr)
26
+ feat = model.predict(arr, verbose=0).flatten()
27
+ return feat / (norm(feat) + 1e-10)
28
+
29
+
30
+ @st.cache_resource
31
+ def load_resnet():
32
+ base = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
33
+ base.trainable = False
34
+ return tf.keras.Sequential([base, GlobalMaxPool2D()])
35
+
36
+
37
+ @st.cache_resource
38
+ def load_data(features_path: str, filenames_path: str):
39
+ with open(features_path, "rb") as f:
40
+ features = pkl.load(f)
41
+ with open(filenames_path, "rb") as f:
42
+ filenames = pkl.load(f)
43
+
44
+ features = np.array(features, dtype="float32")
45
+
46
+ # ✅ Auto-fix mismatch: pak evenveel filenames als features
47
+ if len(filenames) != features.shape[0]:
48
+ st.warning(
49
+ f"Mismatch gevonden: len(filenames)={len(filenames)} maar features={features.shape[0]}. "
50
+ f"Ik gebruik automatisch de eerste {features.shape[0]} filenames."
51
+ )
52
+ filenames = filenames[: features.shape[0]]
53
+
54
+ return features, filenames
55
+
56
+
57
+ @st.cache_resource
58
+ def fit_nn(features: np.ndarray, n_neighbors: int):
59
+ nn = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute", metric="euclidean")
60
+ nn.fit(features)
61
+ return nn
62
+
63
+
64
+ def resolve_path(p: str, images_dir: str) -> str:
65
+ """Open the image even if filenames.pkl contains only basenames."""
66
+ if os.path.exists(p):
67
+ return p
68
+ return os.path.join(images_dir, p)
69
+
70
+
71
+ # -----------------------------
72
+ # UI
73
+ # -----------------------------
74
+ st.set_page_config(page_title="Fashion Recommender", layout="wide")
75
+ st.title("🧥 Fashion Deep Learning Recommender (mini-set ok)")
76
+
77
+ with st.sidebar:
78
+ st.header("📁 Bestanden")
79
+ images_dir = st.text_input("Images map", value="images")
80
+ features_pkl = st.text_input("Features pickle", value="Images_features.pkl")
81
+ filenames_pkl = st.text_input("Filenames pickle", value="filenames.pkl")
82
+
83
+ st.divider()
84
+ st.header("⚙️ Aanbevelingen")
85
+ k = st.slider("Aantal aanbevelingen", 1, 20, 5)
86
+ show_distance = st.checkbox("Toon euclidean afstand", value=False)
87
+
88
+ if not Path(features_pkl).exists() or not Path(filenames_pkl).exists():
89
+ st.error(
90
+ "Ik kan je .pkl bestanden niet vinden. "
91
+ "Zet `Images_features.pkl` en `filenames.pkl` in dezelfde map als dit script, "
92
+ "of geef het juiste pad op in de sidebar."
93
+ )
94
+ st.stop()
95
+
96
+ features, filenames = load_data(features_pkl, filenames_pkl)
97
+
98
+ st.caption(f"Geladen: **{len(filenames)}** items | embedding dim: **{features.shape[1]}**")
99
+
100
+ # pas n_neighbors aan zodat het nooit groter is dan dataset
101
+ n_neighbors = min(k + 1, features.shape[0])
102
+ neighbors = fit_nn(features, n_neighbors=n_neighbors)
103
+ model = load_resnet()
104
+
105
+ tab1, tab2 = st.tabs(["Kies uit dataset", "Upload een foto"])
106
+
107
+ # -----------------------------
108
+ # Tab 1: pick an existing item
109
+ # -----------------------------
110
+ with tab1:
111
+ st.subheader("Kies een item uit de dataset")
112
+
113
+ q = st.text_input("Zoek op bestandsnaam (bv. 16871)", value="")
114
+ if q.strip():
115
+ matches = [f for f in filenames if q in os.path.basename(f)]
116
+ else:
117
+ matches = filenames
118
+
119
+ if not matches:
120
+ st.info("Geen matches. Probeer een andere zoekterm.")
121
+ else:
122
+ # cap the dropdown for UI performance
123
+ selected = st.selectbox("Selecteer", options=matches[:5000])
124
+ selected_path = resolve_path(selected, images_dir)
125
+
126
+ if not os.path.exists(selected_path):
127
+ st.error(f"Bestand niet gevonden: {selected_path}")
128
+ st.stop()
129
+
130
+ colA, colB = st.columns([1, 2], gap="large")
131
+ with colA:
132
+ st.write("**Gekozen item**")
133
+ st.image(selected_path, use_container_width=True)
134
+ st.caption(os.path.basename(selected_path))
135
+
136
+ with colB:
137
+ st.write("**Aanbevelingen**")
138
+ query_vec = extract_features(selected_path, model)
139
+ dists, idxs = neighbors.kneighbors([query_vec], n_neighbors=n_neighbors)
140
+
141
+ recs = []
142
+ for d, idx in zip(dists[0], idxs[0]):
143
+ p = resolve_path(filenames[idx], images_dir)
144
+
145
+ # skip itself if it’s the same file
146
+ try:
147
+ if os.path.abspath(p) == os.path.abspath(selected_path):
148
+ continue
149
+ except Exception:
150
+ pass
151
+
152
+ if os.path.exists(p):
153
+ recs.append((p, float(d)))
154
+ if len(recs) >= k:
155
+ break
156
+
157
+ cols = st.columns(min(5, max(1, len(recs))))
158
+ for i, (p, d) in enumerate(recs):
159
+ with cols[i % len(cols)]:
160
+ st.image(p, use_container_width=True)
161
+ st.caption(os.path.basename(p))
162
+ if show_distance:
163
+ st.caption(f"afstand: {d:.4f}")
164
+
165
+ # -----------------------------
166
+ # Tab 2: upload an external image
167
+ # -----------------------------
168
+ with tab2:
169
+ st.subheader("Upload een afbeelding en krijg vergelijkbare items")
170
+ up = st.file_uploader("Upload een JPG/PNG", type=["jpg", "jpeg", "png"])
171
+ if up is not None:
172
+ tmp_dir = Path(".streamlit_tmp")
173
+ tmp_dir.mkdir(exist_ok=True)
174
+ tmp_path = tmp_dir / up.name
175
+ tmp_path.write_bytes(up.getvalue())
176
+
177
+ colA, colB = st.columns([1, 2], gap="large")
178
+ with colA:
179
+ st.write("**Geüpload**")
180
+ st.image(str(tmp_path), use_container_width=True)
181
+ st.caption(up.name)
182
+
183
+ with colB:
184
+ st.write("**Aanbevelingen**")
185
+ query_vec = extract_features(str(tmp_path), model)
186
+ dists, idxs = neighbors.kneighbors([query_vec], n_neighbors=n_neighbors)
187
+
188
+ recs = []
189
+ for d, idx in zip(dists[0], idxs[0]):
190
+ p = resolve_path(filenames[idx], images_dir)
191
+ if os.path.exists(p):
192
+ recs.append((p, float(d)))
193
+ if len(recs) >= k:
194
+ break
195
+
196
+ cols = st.columns(min(5, max(1, len(recs))))
197
+ for i, (p, d) in enumerate(recs):
198
+ with cols[i % len(cols)]:
199
+ st.image(p, use_container_width=True)
200
+ st.caption(os.path.basename(p))
201
+ if show_distance:
202
+ st.caption(f"afstand: {d:.4f}")
filenames.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7076602ce29da942074fab43bd8af025e5f0f7fe3ec5b7ed0166bd15f8d212b
3
+ size 837463