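# NOTE: what follows appears to be hosting-page residue, kept as comments so the file parses:
#   BeyzaTopbas's picture / Upload 3 files / 037dad3 verified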
import os
import pickle as pkl
from pathlib import Path
import numpy as np
import streamlit as st
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import GlobalMaxPool2D
from tensorflow.keras.preprocessing import image as kimage
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
# -----------------------------
# Feature extraction
# -----------------------------
def extract_features(image_path: str, model: tf.keras.Model) -> np.ndarray:
    """Embed one image with ``model`` and return the L2-normalized vector.

    The image at ``image_path`` is loaded at 224x224, converted to an array,
    given a leading batch axis and passed through ``preprocess_input`` before
    ``model.predict`` is called. The flattened prediction is divided by its
    norm plus a small epsilon, so an all-zero prediction does not divide by
    zero.
    """
    pixels = kimage.img_to_array(
        kimage.load_img(image_path, target_size=(224, 224))
    )
    batch = preprocess_input(np.expand_dims(pixels, axis=0))
    embedding = model.predict(batch, verbose=0).flatten()
    length = norm(embedding) + 1e-10
    return embedding / length
@st.cache_resource
def load_resnet():
    """Build the frozen feature extractor used for embeddings.

    The extractor is an ImageNet-weighted ResNet50 without its classification
    top (input 224x224x3, marked non-trainable) followed by a global max
    pooling layer. ``st.cache_resource`` is applied so the model object is
    reused instead of rebuilt.
    """
    backbone = ResNet50(
        weights="imagenet",
        include_top=False,
        input_shape=(224, 224, 3),
    )
    backbone.trainable = False
    layers = [backbone, GlobalMaxPool2D()]
    return tf.keras.Sequential(layers)
@st.cache_resource
def load_data(features_path: str, filenames_path: str):
    """Load the embedding array and the filename list from two pickle files.

    Args:
        features_path: Path of the pickle holding the embeddings.
        filenames_path: Path of the pickle holding the image filenames.

    Returns:
        ``(features, filenames)``: ``features`` converted to a float32 NumPy
        array, and ``filenames`` as unpickled. Both have the same number of
        entries along their first axis.

    If the two lengths differ, a warning is shown and BOTH are cut to the
    shorter length. Truncating only the filenames would leave feature rows
    without a filename whenever the filename list is the shorter one, and a
    neighbour index pointing at such a row would raise ``IndexError`` later.
    """
    # NOTE(security): pickle.load can execute arbitrary code from the file.
    # Only point these paths at pickles you created yourself.
    with open(features_path, "rb") as f:
        features = pkl.load(f)
    with open(filenames_path, "rb") as f:
        filenames = pkl.load(f)
    features = np.array(features, dtype="float32")

    n_features = features.shape[0]
    n_filenames = len(filenames)
    if n_filenames != n_features:
        usable = min(n_filenames, n_features)
        st.warning(
            f"Mismatch gevonden: len(filenames)={n_filenames} maar features={n_features}. "
            f"Ik gebruik automatisch de eerste {usable} filenames."
        )
        # Keep the two collections aligned index-for-index.
        filenames = filenames[:usable]
        features = features[:usable]
    return features, filenames
@st.cache_resource
def fit_nn(features: np.ndarray, n_neighbors: int):
    """Return a brute-force Euclidean nearest-neighbour index fitted on ``features``.

    ``n_neighbors`` becomes the estimator's default neighbour count.
    """
    index = NearestNeighbors(
        n_neighbors=n_neighbors,
        algorithm="brute",
        metric="euclidean",
    )
    # fit() returns the estimator itself, so the fitted index is handed back.
    return index.fit(features)
def resolve_path(p: str, images_dir: str) -> str:
    """Return a path for dataset entry ``p``, trying progressively looser matches.

    Lookup order:
      1. ``p`` itself, if it exists.
      2. ``images_dir/p`` - covers pickles that store only basenames.
      3. ``images_dir/<last component of p>`` - covers pickles that store a
         directory prefix that is not valid here; backslashes are treated as
         separators so Windows-style entries also work.

    Args:
        p: Filename or path as stored in the filenames pickle.
        images_dir: Directory that holds the images.

    Returns:
        The first candidate that exists. If none exists, ``images_dir/p`` is
        returned so the caller can test it and report the missing file.
    """
    if os.path.exists(p):
        return p

    joined = os.path.join(images_dir, p)
    if os.path.exists(joined):
        return joined

    # Fall back to the bare file name; the `base` check avoids matching the
    # images directory itself when `p` ends with a separator.
    base = os.path.basename(p.replace("\\", "/"))
    by_name = os.path.join(images_dir, base)
    if base and os.path.exists(by_name):
        return by_name

    return joined
# -----------------------------
# UI
# -----------------------------
st.set_page_config(page_title="Fashion Recommender", layout="wide")
st.title("🧥 Fashion Deep Learning Recommender (mini-set ok)")

# Sidebar: where the data lives and how many recommendations to show.
with st.sidebar:
    st.header("📁 Bestanden")
    images_dir = st.text_input("Images map", value="images")
    features_pkl = st.text_input("Features pickle", value="Images_features.pkl")
    filenames_pkl = st.text_input("Filenames pickle", value="filenames.pkl")
    st.divider()
    st.header("⚙️ Aanbevelingen")
    k = st.slider("Aantal aanbevelingen", 1, 20, 5)
    show_distance = st.checkbox("Toon euclidean afstand", value=False)

# Both pickles must be regular files. is_file() is used instead of exists()
# because an empty text field resolves to the current directory, which
# "exists" but cannot be opened as a pickle.
if not Path(features_pkl).is_file() or not Path(filenames_pkl).is_file():
    st.error(
        "Ik kan je .pkl bestanden niet vinden. "
        "Zet `Images_features.pkl` en `filenames.pkl` in dezelfde map als dit script, "
        "of geef het juiste pad op in de sidebar."
    )
    st.stop()

features, filenames = load_data(features_pkl, filenames_pkl)

# The caption below reads features.shape[1] and the neighbour index needs at
# least one row, so stop with a message instead of a traceback otherwise.
if features.ndim != 2 or features.shape[0] == 0:
    st.error(
        "De features pickle is leeg of heeft niet de vorm (aantal items, embedding dim)."
    )
    st.stop()

st.caption(f"Geladen: **{len(filenames)}** items | embedding dim: **{features.shape[1]}**")

# Ask for one neighbour more than requested, but never more than the dataset
# contains.
n_neighbors = min(k + 1, features.shape[0])
neighbors = fit_nn(features, n_neighbors=n_neighbors)
model = load_resnet()
tab1, tab2 = st.tabs(["Kies uit dataset", "Upload een foto"])
# -----------------------------
# Tab 1: pick an existing item
# -----------------------------
with tab1:
    st.subheader("Kies een item uit de dataset")
    q = st.text_input("Zoek op bestandsnaam (bv. 16871)", value="")

    # Search with the stripped query so stray spaces do not hide every match.
    needle = q.strip()
    if needle:
        matches = [f for f in filenames if needle in os.path.basename(f)]
    else:
        matches = filenames

    if not matches:
        st.info("Geen matches. Probeer een andere zoekterm.")
    else:
        # Cap the dropdown for UI performance.
        selected = st.selectbox("Selecteer", options=matches[:5000])
        selected_path = resolve_path(selected, images_dir)

        if not os.path.exists(selected_path):
            # Report the problem but do NOT call st.stop(): that would halt
            # the whole script and the upload tab would never be rendered.
            st.error(f"Bestand niet gevonden: {selected_path}")
        else:
            colA, colB = st.columns([1, 2], gap="large")
            with colA:
                st.write("**Gekozen item**")
                st.image(selected_path, use_container_width=True)
                st.caption(os.path.basename(selected_path))

            with colB:
                st.write("**Aanbevelingen**")
                query_vec = extract_features(selected_path, model)
                dists, idxs = neighbors.kneighbors([query_vec], n_neighbors=n_neighbors)

                # Compare absolute paths so the chosen item is skipped however
                # its path is spelled; computed once, outside the loop.
                selected_abs = os.path.abspath(selected_path)
                recs = []
                for d, idx in zip(dists[0], idxs[0]):
                    p = resolve_path(filenames[idx], images_dir)
                    if os.path.abspath(p) == selected_abs:
                        continue
                    if os.path.exists(p):
                        recs.append((p, float(d)))
                    if len(recs) >= k:
                        break

                if not recs:
                    st.info("Geen aanbevelingen gevonden. Controleer de images map in de sidebar.")

                # Lay the results out in at most five columns.
                cols = st.columns(min(5, max(1, len(recs))))
                for i, (p, d) in enumerate(recs):
                    with cols[i % len(cols)]:
                        st.image(p, use_container_width=True)
                        st.caption(os.path.basename(p))
                        if show_distance:
                            st.caption(f"afstand: {d:.4f}")
# -----------------------------
# Tab 2: upload an external image
# -----------------------------
with tab2:
    st.subheader("Upload een afbeelding en krijg vergelijkbare items")
    up = st.file_uploader("Upload een JPG/PNG", type=["jpg", "jpeg", "png"])

    if up is not None:
        tmp_dir = Path(".streamlit_tmp")
        tmp_dir.mkdir(exist_ok=True)

        # The uploaded name comes from the client. Keep only its final
        # component so it cannot point outside tmp_dir.
        safe_name = Path(up.name).name
        if safe_name in {"", ".", ".."}:
            safe_name = "upload"
        tmp_path = tmp_dir / safe_name
        tmp_path.write_bytes(up.getvalue())

        colA, colB = st.columns([1, 2], gap="large")
        with colA:
            st.write("**Geüpload**")
            st.image(str(tmp_path), use_container_width=True)
            st.caption(up.name)

        with colB:
            st.write("**Aanbevelingen**")
            query_vec = extract_features(str(tmp_path), model)
            dists, idxs = neighbors.kneighbors([query_vec], n_neighbors=n_neighbors)

            # Keep the nearest items whose image file can be found, up to k.
            recs = []
            for d, idx in zip(dists[0], idxs[0]):
                p = resolve_path(filenames[idx], images_dir)
                if os.path.exists(p):
                    recs.append((p, float(d)))
                if len(recs) >= k:
                    break

            if not recs:
                st.info("Geen aanbevelingen gevonden. Controleer de images map in de sidebar.")

            # Lay the results out in at most five columns.
            cols = st.columns(min(5, max(1, len(recs))))
            for i, (p, d) in enumerate(recs):
                with cols[i % len(cols)]:
                    st.image(p, use_container_width=True)
                    st.caption(os.path.basename(p))
                    if show_distance:
                        st.caption(f"afstand: {d:.4f}")