"""Visual similarity search over pre-computed fashion image embeddings.

A MobileNetV2 backbone (ImageNet weights, classification head removed) turns
a query image into a flat feature vector; a brute-force nearest-neighbour
index over the stored embeddings then returns the closest catalogue items.
"""

import pickle

import numpy as np
import pandas as pd
import requests
import streamlit as st
import tensorflow as tf
from PIL import Image
from sklearn.neighbors import NearestNeighbors

# Spatial size the backbone is built for. PIL's ``resize`` takes
# (width, height) while the resulting array / model input is
# (height, width, channels), hence the two orderings used below.
_IMG_HEIGHT = 80
_IMG_WIDTH = 60


@st.cache_resource
def feature_extractor() -> tf.keras.Sequential:
    """Build the image embedding model.

    Returns:
        A ``Sequential`` model made of MobileNetV2 (``include_top=False``,
        ImageNet weights, input shape ``[80, 60, 3]``) followed by a
        ``Flatten`` layer, so each image maps to a single 1-D feature vector.
        Decorated with ``st.cache_resource`` so Streamlit reuses the model.
    """
    backbone = tf.keras.applications.MobileNetV2(
        weights='imagenet',
        include_top=False,
        input_shape=[_IMG_HEIGHT, _IMG_WIDTH, 3],
    )
    return tf.keras.Sequential([
        backbone,
        tf.keras.layers.Flatten(),
    ])


@st.cache_data
def load_resource(resource_path):
    """Unpickle and return a file from the local ``./Embeddings`` directory.

    Args:
        resource_path: File name relative to ``./Embeddings/``.

    Returns:
        Whatever object was pickled into the file.
    """
    # NOTE(security): pickle can execute arbitrary code while loading. Only
    # point this at files produced by this project, never at user uploads.
    with open(f'./Embeddings/{resource_path}', 'rb') as fp:
        return pickle.load(fp)


class FashionSearch:
    """Nearest-neighbour lookup of catalogue images similar to a query image."""

    def __init__(self) -> None:
        """Load the stored embeddings, their ids, the link map and the model."""
        self.embeddings = load_resource(resource_path='image_embeddings.pkl')
        # Kept as an ndarray so it can be fancy-indexed with neighbour indices.
        self.name = np.array(load_resource(resource_path='image_ids.pkl'))
        self.image_link = load_resource(resource_path='name_link_map.pkl')
        self.feature_extractor = feature_extractor()

    def KNN(self, metric: str = 'minkowski') -> NearestNeighbors:
        """Fit a brute-force nearest-neighbour index on the stored embeddings.

        Args:
            metric: Distance metric name forwarded to ``NearestNeighbors``.

        Returns:
            The fitted ``NearestNeighbors`` estimator.
        """
        # The requested metric is forwarded; previously it was silently
        # ignored in favour of a hard-coded 'minkowski'.
        knn = NearestNeighbors(n_neighbors=10, algorithm='brute', metric=metric)
        knn.fit(self.embeddings)
        return knn

    def image_feature_extraction(self, img: Image.Image):
        """Compute the feature vector of a single image.

        Args:
            img: Query image; any PIL mode is accepted.

        Returns:
            The output of the feature extractor for a batch containing only
            this image.
        """
        # Force three channels so grayscale / palette / RGBA inputs match the
        # model's expected (80, 60, 3) input instead of failing on shape.
        rgb = img.convert('RGB').resize((_IMG_WIDTH, _IMG_HEIGHT))
        pixels = tf.keras.applications.mobilenet_v2.preprocess_input(np.array(rgb))
        # Add the leading batch dimension the model expects.
        batch = np.expand_dims(pixels, axis=0)
        return self.feature_extractor(batch)

    def find_k_neighbors(
        self,
        sample_img: Image.Image,
        metric: str = 'minkowski',
        k: int = 16,
    ) -> list[str]:
        """Return the ids of the stored images closest to ``sample_img``.

        Args:
            sample_img: Query image.
            metric: Distance metric name used for the neighbour search.
            k: Maximum number of neighbours to return (default 16, the value
                previously hard-coded here).

        Returns:
            Image ids as strings, in the order returned by ``kneighbors``.
        """
        knn = self.KNN(metric=metric)
        # Explicit ndarray conversion: scikit-learn works on NumPy arrays.
        features = np.asarray(self.image_feature_extraction(img=sample_img))
        # Asking for more neighbours than fitted samples raises in
        # scikit-learn, so cap the request at the index size.
        n_neighbors = min(k, knn.n_samples_fit_)
        indices = knn.kneighbors(
            X=features, n_neighbors=n_neighbors, return_distance=False
        )
        return [str(image_id) for image_id in self.name[indices.ravel()]]