| import os |
| import pickle as pkl |
| from pathlib import Path |
|
|
| import numpy as np |
| import streamlit as st |
|
|
| import tensorflow as tf |
| from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input |
| from tensorflow.keras.layers import GlobalMaxPool2D |
| from tensorflow.keras.preprocessing import image as kimage |
|
|
| from sklearn.neighbors import NearestNeighbors |
| from numpy.linalg import norm |
|
|
|
|
| |
| |
| |
def extract_features(image_path: str, model: tf.keras.Model) -> np.ndarray:
    """Embed one image with ``model`` and return the L2-normalised vector.

    The image is loaded at 224x224, turned into a single-item batch, passed
    through ``preprocess_input`` and then through ``model.predict``. The
    prediction is flattened to 1-D before normalisation.

    Args:
        image_path: Location of the image file to embed.
        model: Keras model used to produce the embedding.

    Returns:
        The flattened prediction divided by its Euclidean norm.
    """
    pil_img = kimage.load_img(image_path, target_size=(224, 224))
    # Add the leading batch axis before applying the ResNet50 preprocessing.
    batch = preprocess_input(np.expand_dims(kimage.img_to_array(pil_img), axis=0))
    embedding = model.predict(batch, verbose=0).flatten()
    # The small epsilon keeps the division defined for an all-zero embedding.
    return embedding / (norm(embedding) + 1e-10)
|
|
|
|
@st.cache_resource
def load_resnet():
    """Build the feature extractor: frozen ResNet50 plus global max pooling.

    The backbone uses ImageNet weights, has no classification head and
    expects 224x224 RGB input. ``st.cache_resource`` lets Streamlit reuse the
    built model instead of constructing it again on every rerun.

    Returns:
        A ``tf.keras.Sequential`` of the backbone followed by
        ``GlobalMaxPool2D``.
    """
    backbone = ResNet50(
        include_top=False,
        weights="imagenet",
        input_shape=(224, 224, 3),
    )
    # Inference only: the backbone weights are never updated.
    backbone.trainable = False
    stages = [backbone, GlobalMaxPool2D()]
    return tf.keras.Sequential(stages)
|
|
|
|
@st.cache_resource
def load_data(features_path: str, filenames_path: str):
    """Load the pickled feature matrix and the matching list of filenames.

    The features are converted to a float32 NumPy array. If the number of
    filenames differs from the number of feature rows, a warning is shown and
    BOTH are truncated to their common length, so that every row index
    returned by a neighbour search is a valid index into ``filenames``.

    Args:
        features_path: Path of the pickle holding the feature vectors.
        filenames_path: Path of the pickle holding the image filenames.

    Returns:
        A ``(features, filenames)`` tuple of equal length.
    """
    # NOTE(security): unpickling runs arbitrary code embedded in the file.
    # Both paths come from sidebar text inputs, so only point this at pickles
    # from a trusted source.
    with open(features_path, "rb") as f:
        features = pkl.load(f)
    with open(filenames_path, "rb") as f:
        filenames = pkl.load(f)

    features = np.array(features, dtype="float32")

    n_features = features.shape[0]
    n_names = len(filenames)
    if n_names != n_features:
        # Truncating only the filenames does nothing when they are the
        # shorter side; cut both down to the length they share.
        n_common = min(n_names, n_features)
        st.warning(
            f"Mismatch gevonden: len(filenames)={n_names} maar features={n_features}. "
            f"Ik gebruik automatisch de eerste {n_common} filenames en features."
        )
        filenames = filenames[:n_common]
        features = features[:n_common]

    return features, filenames
|
|
|
|
@st.cache_resource
def fit_nn(features: np.ndarray, n_neighbors: int):
    """Fit a brute-force Euclidean nearest-neighbour index on ``features``.

    Args:
        features: Feature matrix to index.
        n_neighbors: Default number of neighbours the index returns.

    Returns:
        The fitted ``NearestNeighbors`` estimator.
    """
    index = NearestNeighbors(
        metric="euclidean",
        algorithm="brute",
        n_neighbors=n_neighbors,
    )
    # ``fit`` returns the estimator itself, so it can be returned directly.
    return index.fit(features)
|
|
|
|
def resolve_path(p: str, images_dir: str) -> str:
    """Map an entry from the filenames list to a path that can be opened.

    Candidates are tried in this order:

    1. ``p`` exactly as stored.
    2. ``p`` joined onto ``images_dir`` (covers entries that are basenames).
    3. The final component of ``p`` joined onto ``images_dir`` (covers
       entries stored with a directory prefix such as ``images/1.jpg`` or
       ``images\\1.jpg`` that is not valid from the current location).

    Args:
        p: Stored filename or path.
        images_dir: Directory that holds the image files.

    Returns:
        The first candidate that exists. If none exists, ``p`` joined onto
        ``images_dir`` is returned so callers can report the missing path.
    """
    if os.path.exists(p):
        return p

    joined = os.path.join(images_dir, p)
    if os.path.exists(joined):
        return joined

    # Treat both separator styles as separators, whatever the host OS is, so
    # a list written on Windows still resolves on POSIX and vice versa.
    leaf = p.replace("\\", "/").rsplit("/", 1)[-1]
    if leaf:
        by_leaf = os.path.join(images_dir, leaf)
        if os.path.exists(by_leaf):
            return by_leaf

    return joined
|
|
|
|
| |
| |
| |
st.set_page_config(page_title="Fashion Recommender", layout="wide")
st.title(" Fashion Recommender ")

# Default locations in the sidebar are resolved next to this script.
BASE = Path(__file__).resolve().parent

with st.sidebar:
    st.header("📁 Bestanden")
    images_dir = st.text_input("Images map", value=str(BASE / "images"))
    features_pkl = st.text_input("Features pickle", value=str(BASE / "Images_features.pkl"))
    filenames_pkl = st.text_input("Filenames pickle", value=str(BASE / "filenames.pkl"))

    st.divider()
    st.header("⚙️ Aanbevelingen")
    k = st.slider("Aantal aanbevelingen", 1, 20, 5)
    show_distance = st.checkbox("Toon euclidean afstand", value=False)

# Without both pickles there is nothing to search; explain and halt the run.
if not Path(features_pkl).exists() or not Path(filenames_pkl).exists():
    st.error(
        "Ik kan je .pkl bestanden niet vinden. "
        "Zet `Images_features.pkl` en `filenames.pkl` in dezelfde map als dit script, "
        "of geef het juiste pad op in de sidebar."
    )
    st.stop()

features, filenames = load_data(features_pkl, filenames_pkl)

# An empty or non-2D feature array would fail below with an opaque error
# (IndexError on shape[1], or a neighbour count of 0). Report it clearly.
if features.ndim != 2 or features.shape[0] == 0 or len(filenames) == 0:
    st.error(
        "De geladen features zijn leeg of hebben niet de vorm (items, dimensies). "
        "Controleer de .pkl bestanden."
    )
    st.stop()

st.caption(f"Geladen: **{len(filenames)}** items | embedding dim: **{features.shape[1]}**")

# Ask for one neighbour more than requested: in the dataset tab the selected
# item itself is dropped from the results. Never ask for more than exist.
n_neighbors = min(k + 1, features.shape[0])
neighbors = fit_nn(features, n_neighbors=n_neighbors)
model = load_resnet()

tab1, tab2 = st.tabs(["Kies uit dataset", "Upload een foto"])
|
|
| |
| |
| |
# Tab 1: pick an image that is already in the dataset and show its neighbours.
with tab1:
    st.subheader("Kies een item uit de dataset")

    q = st.text_input("Zoek op bestandsnaam (bv. 16871)", value="")
    # Match on the stripped term so stray leading/trailing whitespace in the
    # search box does not hide every file.
    term = q.strip()
    if term:
        matches = [f for f in filenames if term in os.path.basename(f)]
    else:
        matches = filenames

    if not matches:
        st.info("Geen matches. Probeer een andere zoekterm.")
    else:
        # Only the first 5000 matches are offered in the select box.
        selected = st.selectbox("Selecteer", options=matches[:5000])
        selected_path = resolve_path(selected, images_dir)

        if not os.path.exists(selected_path):
            # Report the problem but do NOT call st.stop(): that would halt
            # the whole script and leave the upload tab empty.
            st.error(f"Bestand niet gevonden: {selected_path}")
        else:
            colA, colB = st.columns([1, 2], gap="large")
            with colA:
                st.write("**Gekozen item**")
                st.image(selected_path, use_container_width=True)
                st.caption(os.path.basename(selected_path))

            with colB:
                st.write("**Aanbevelingen**")
                query_vec = extract_features(selected_path, model)
                dists, idxs = neighbors.kneighbors([query_vec], n_neighbors=n_neighbors)

                # Computed once; used to drop the selected item from its own results.
                selected_abs = os.path.abspath(selected_path)

                recs = []
                for d, idx in zip(dists[0], idxs[0]):
                    p = resolve_path(filenames[idx], images_dir)

                    if os.path.abspath(p) == selected_abs:
                        continue

                    # Entries whose image file is missing are skipped.
                    if os.path.exists(p):
                        recs.append((p, float(d)))
                    if len(recs) >= k:
                        break

                # At most five columns; results wrap around onto the same columns.
                cols = st.columns(min(5, max(1, len(recs))))
                for i, (p, d) in enumerate(recs):
                    with cols[i % len(cols)]:
                        st.image(p, use_container_width=True)
                        st.caption(os.path.basename(p))
                        if show_distance:
                            st.caption(f"afstand: {d:.4f}")
|
|
| |
| |
| |
# Tab 2: upload an image and show the dataset items closest to it.
with tab2:
    st.subheader("Upload een afbeelding en krijg vergelijkbare items")
    up = st.file_uploader("Upload een JPG/PNG", type=["jpg", "jpeg", "png"])
    if up is not None:
        tmp_dir = Path(".streamlit_tmp")
        tmp_dir.mkdir(exist_ok=True)

        # The file name is supplied by the client. Keep only its final path
        # component (for either separator style) so a crafted name such as
        # "../../x.png" cannot make the write land outside tmp_dir.
        safe_name = Path(up.name.replace("\\", "/")).name
        if safe_name in ("", ".", ".."):
            safe_name = "upload"
        tmp_path = tmp_dir / safe_name
        # The feature extractor takes a path, so the upload is written to disk.
        tmp_path.write_bytes(up.getvalue())

        colA, colB = st.columns([1, 2], gap="large")
        with colA:
            st.write("**Geüpload**")
            st.image(str(tmp_path), use_container_width=True)
            st.caption(up.name)

        with colB:
            st.write("**Aanbevelingen**")
            query_vec = extract_features(str(tmp_path), model)
            dists, idxs = neighbors.kneighbors([query_vec], n_neighbors=n_neighbors)

            recs = []
            for d, idx in zip(dists[0], idxs[0]):
                p = resolve_path(filenames[idx], images_dir)
                # Entries whose image file is missing are skipped.
                if os.path.exists(p):
                    recs.append((p, float(d)))
                if len(recs) >= k:
                    break

            # At most five columns; results wrap around onto the same columns.
            cols = st.columns(min(5, max(1, len(recs))))
            for i, (p, d) in enumerate(recs):
                with cols[i % len(cols)]:
                    st.image(p, use_container_width=True)
                    st.caption(os.path.basename(p))
                    if show_distance:
                        st.caption(f"afstand: {d:.4f}")
|
|