"""Streamlit fashion recommender: ResNet50 embeddings + nearest-neighbour search.

Loads precomputed image embeddings and their image file names from two pickle
files, then recommends visually similar items either for an item picked from
the dataset or for an uploaded photo.
"""

import os
import pickle as pkl
import tempfile
from pathlib import Path

import numpy as np
import streamlit as st
import tensorflow as tf
from numpy.linalg import norm
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import GlobalMaxPool2D
from tensorflow.keras.preprocessing import image as kimage


# -----------------------------
# Feature extraction
# -----------------------------
def extract_features(image_path: str, model: tf.keras.Model) -> np.ndarray:
    """Return the L2-normalized embedding of the image at ``image_path``.

    The image is resized to 224x224 and run through ResNet50's
    ``preprocess_input`` before being passed to ``model``. A small epsilon
    keeps the normalization finite for an all-zero embedding.
    """
    img = kimage.load_img(image_path, target_size=(224, 224))
    arr = kimage.img_to_array(img)
    arr = np.expand_dims(arr, axis=0)
    arr = preprocess_input(arr)
    feat = model.predict(arr, verbose=0).flatten()
    return feat / (norm(feat) + 1e-10)


@st.cache_resource
def load_resnet():
    """Build (once per process) a frozen ImageNet ResNet50 with global max pooling."""
    base = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
    base.trainable = False
    return tf.keras.Sequential([base, GlobalMaxPool2D()])


@st.cache_resource
def load_data(features_path: str, filenames_path: str):
    """Load the embedding matrix and the matching list of image file names.

    Args:
        features_path: Pickle holding one embedding per image.
        filenames_path: Pickle holding the image paths, in the same order.

    Returns:
        ``(features, filenames)``: a float32 array and a list. When the two
        pickles have a different number of entries, BOTH are truncated to the
        shorter length so that row ``i`` of ``features`` always has a
        ``filenames[i]`` (otherwise a neighbour index could run past the end
        of ``filenames``).
    """
    # NOTE(security): pickle.load can execute arbitrary code, and these paths
    # come from the sidebar. Only point them at pickle files you created.
    with open(features_path, "rb") as f:
        features = pkl.load(f)
    with open(filenames_path, "rb") as f:
        filenames = list(pkl.load(f))

    features = np.array(features, dtype="float32")
    # A list of (1, D) vectors becomes an (N, 1, D) array; flatten each row so
    # NearestNeighbors receives the 2-D matrix it requires.
    if features.ndim > 2 and features.shape[0] > 0:
        features = features.reshape(features.shape[0], -1)

    # Keep features and filenames aligned by truncating both to the shorter one.
    if features.ndim == 2 and len(filenames) != features.shape[0]:
        n = min(len(filenames), features.shape[0])
        st.warning(
            f"Mismatch gevonden: len(filenames)={len(filenames)} maar features={features.shape[0]}. "
            f"Ik gebruik automatisch de eerste {n} items van beide."
        )
        features = features[:n]
        filenames = filenames[:n]
    return features, filenames


@st.cache_resource
def fit_nn(features: np.ndarray, n_neighbors: int):
    """Fit a brute-force Euclidean nearest-neighbour index on ``features``."""
    nn = NearestNeighbors(n_neighbors=n_neighbors, algorithm="brute", metric="euclidean")
    nn.fit(features)
    return nn


def resolve_path(p: str, images_dir: str) -> str:
    """Map an entry from filenames.pkl to an image path on this machine.

    Candidates, in order: ``p`` as-is, ``images_dir/p``, and
    ``images_dir/<basename of p>``. The last one lets pickles built elsewhere
    (absolute paths, or Windows-style backslash paths) still find the images
    in ``images_dir``. Returns the first candidate that exists, or
    ``images_dir/p`` when none does.
    """
    if os.path.exists(p):
        return p
    joined = os.path.join(images_dir, p)
    if os.path.exists(joined):
        return joined
    # Normalise Windows separators so the basename is found on any OS.
    base = str(p).replace("\\", "/").rsplit("/", 1)[-1]
    by_name = os.path.join(images_dir, base)
    if os.path.exists(by_name):
        return by_name
    return joined


def find_recommendations(query_vec, nn, filenames, *, images_dir, n_neighbors, k, exclude_path=None):
    """Return up to ``k`` ``(image_path, distance)`` pairs closest to ``query_vec``.

    Neighbours whose image file does not exist on disk are skipped, as is
    ``exclude_path`` (used to drop the query item itself when it comes from
    the dataset). Results keep the nearest-first order of ``nn.kneighbors``.
    """
    dists, idxs = nn.kneighbors([query_vec], n_neighbors=n_neighbors)
    excluded = os.path.abspath(exclude_path) if exclude_path else None
    recs = []
    for d, idx in zip(dists[0], idxs[0]):
        p = resolve_path(filenames[idx], images_dir)
        if excluded is not None and os.path.abspath(p) == excluded:
            continue
        if not os.path.exists(p):
            continue
        recs.append((p, float(d)))
        if len(recs) >= k:
            break
    return recs


def render_recommendations(recs, show_distance):
    """Show ``(image_path, distance)`` pairs in a grid of at most 5 columns."""
    if not recs:
        st.info("Geen aanbevelingen gevonden.")
        return
    cols = st.columns(min(5, len(recs)))
    for i, (p, d) in enumerate(recs):
        with cols[i % len(cols)]:
            st.image(p, use_container_width=True)
            st.caption(os.path.basename(p))
            if show_distance:
                st.caption(f"afstand: {d:.4f}")


# -----------------------------
# UI
# -----------------------------
st.set_page_config(page_title="Fashion Recommender", layout="wide")
st.title("🧥 Fashion Deep Learning Recommender (mini-set ok)")

with st.sidebar:
    st.header("📁 Bestanden")
    images_dir = st.text_input("Images map", value="images")
    features_pkl = st.text_input("Features pickle", value="Images_features.pkl")
    filenames_pkl = st.text_input("Filenames pickle", value="filenames.pkl")
    st.divider()
    st.header("⚙️ Aanbevelingen")
    k = st.slider("Aantal aanbevelingen", 1, 20, 5)
    show_distance = st.checkbox("Toon euclidean afstand", value=False)

if not Path(features_pkl).exists() or not Path(filenames_pkl).exists():
    st.error(
        "Ik kan je .pkl bestanden niet vinden. "
        "Zet `Images_features.pkl` en `filenames.pkl` in dezelfde map als dit script, "
        "of geef het juiste pad op in de sidebar."
    )
    st.stop()

features, filenames = load_data(features_pkl, filenames_pkl)

# NearestNeighbors needs a non-empty 2-D matrix, and features.shape[1] below
# assumes one; fail with a readable message instead of a stack trace.
if features.ndim != 2 or features.shape[0] == 0:
    st.error(f"De features-pickle bevat geen bruikbare embeddings (shape={features.shape}).")
    st.stop()

st.caption(f"Geladen: **{len(filenames)}** items | embedding dim: **{features.shape[1]}**")

# Ask for k + 1 neighbours (a dataset query usually finds itself first), but
# never more than the dataset holds.
n_neighbors = min(k + 1, features.shape[0])
neighbors = fit_nn(features, n_neighbors=n_neighbors)
model = load_resnet()

tab1, tab2 = st.tabs(["Kies uit dataset", "Upload een foto"])

# -----------------------------
# Tab 1: pick an existing item
# -----------------------------
with tab1:
    st.subheader("Kies een item uit de dataset")
    query = st.text_input("Zoek op bestandsnaam (bv. 16871)", value="").strip()
    matches = [f for f in filenames if query in os.path.basename(f)] if query else filenames

    if not matches:
        st.info("Geen matches. Probeer een andere zoekterm.")
    else:
        # Cap the dropdown for UI performance.
        selected = st.selectbox("Selecteer", options=matches[:5000])
        selected_path = resolve_path(selected, images_dir)

        # Report the problem but do NOT st.stop(): that would halt the whole
        # script and leave the upload tab empty.
        if not os.path.exists(selected_path):
            st.error(f"Bestand niet gevonden: {selected_path}")
        else:
            col_a, col_b = st.columns([1, 2], gap="large")
            with col_a:
                st.write("**Gekozen item**")
                st.image(selected_path, use_container_width=True)
                st.caption(os.path.basename(selected_path))

            with col_b:
                st.write("**Aanbevelingen**")
                query_vec = extract_features(selected_path, model)
                recs = find_recommendations(
                    query_vec,
                    neighbors,
                    filenames,
                    images_dir=images_dir,
                    n_neighbors=n_neighbors,
                    k=k,
                    exclude_path=selected_path,
                )
                render_recommendations(recs, show_distance)

# -----------------------------
# Tab 2: upload an external image
# -----------------------------
with tab2:
    st.subheader("Upload een afbeelding en krijg vergelijkbare items")
    up = st.file_uploader("Upload een JPG/PNG", type=["jpg", "jpeg", "png"])

    if up is not None:
        col_a, col_b = st.columns([1, 2], gap="large")
        with col_a:
            st.write("**Geüpload**")
            st.image(up.getvalue(), use_container_width=True)
            st.caption(up.name)

        with col_b:
            st.write("**Aanbevelingen**")
            # Keras loads images from a path, so write the upload to a private
            # temp dir that is removed afterwards. Only the extension of the
            # client-supplied name is used, so a crafted name cannot escape the
            # directory, and concurrent sessions cannot overwrite each other.
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_path = Path(tmp_dir) / f"upload{Path(up.name).suffix.lower()}"
                tmp_path.write_bytes(up.getvalue())
                query_vec = extract_features(str(tmp_path), model)
            recs = find_recommendations(
                query_vec,
                neighbors,
                filenames,
                images_dir=images_dir,
                n_neighbors=n_neighbors,
                k=k,
            )
            render_recommendations(recs, show_distance)