File size: 1,846 Bytes
9bdc63e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.neighbors import NearestNeighbors

import pickle
import requests
from PIL import Image

import streamlit as st


@st.cache_resource
def feature_extractor() -> tf.keras.Sequential:
    """Build the image embedding network.

    The network is an ImageNet-pretrained MobileNetV2 backbone without its
    classification head, expecting (80, 60, 3) inputs, followed by a Flatten
    layer so each image maps to a single feature vector.

    Decorated with ``st.cache_resource`` so Streamlit reuses the built model
    across reruns instead of rebuilding it and reloading weights.
    """
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=[80, 60, 3],
        include_top=False,
        weights='imagenet',
    )
    return tf.keras.Sequential([backbone, tf.keras.layers.Flatten()])

@st.cache_data
def load_resource(resource_path):
    """Unpickle and return the object stored at ``./Embeddings/<resource_path>``.

    Args:
        resource_path: File name relative to the ``./Embeddings`` directory.

    Returns:
        Whatever object the pickle file contains.

    Note:
        ``pickle.load`` can execute arbitrary code, so only point this at
        trusted, locally produced files.
    """
    pickle_file = f'./Embeddings/{resource_path}'
    with open(pickle_file, 'rb') as stream:
        return pickle.load(stream)




class FashionSearch:
    """Nearest-neighbour image search over precomputed image embeddings.

    On construction, three pickles are loaded from ``./Embeddings`` through
    ``load_resource``, along with the MobileNetV2 feature extractor. A query
    image is embedded with that same extractor and matched against the rows
    of ``self.embeddings``.
    """

    def __init__(self) -> None:
        # Matrix the KNN index is fitted on. find_k_neighbors maps result
        # row i to self.name[i], so both pickles must share the same order.
        self.embeddings = load_resource(resource_path='image_embeddings.pkl')
        # Stored as a numpy array so it can be fancy-indexed by KNN indices.
        self.name = np.array(load_resource(resource_path='image_ids.pkl'))
        # Loaded from name_link_map.pkl; not used inside this class.
        self.image_link = load_resource(resource_path='name_link_map.pkl')
        self.feature_extractor = feature_extractor()

    def KNN(self, metric: str = 'minkowski') -> NearestNeighbors:
        """Return a brute-force NearestNeighbors model fitted on the embeddings.

        Args:
            metric: Distance metric forwarded to scikit-learn
                (e.g. 'minkowski', 'euclidean', 'cosine').

        Returns:
            A fitted ``NearestNeighbors`` whose default ``n_neighbors`` is 10.
        """
        # Forward the caller's metric; brute force supports any sklearn metric.
        knn = NearestNeighbors(n_neighbors=10, algorithm='brute', metric=metric)
        knn.fit(self.embeddings)
        return knn

    def image_feature_extraction(self, img: Image.Image) -> tf.Tensor:
        """Embed a PIL image with the MobileNetV2 feature extractor.

        Args:
            img: Query image in any PIL mode.

        Returns:
            A batch containing one flattened feature vector.
        """
        # Force 3 channels: RGBA, palette, or grayscale uploads would otherwise
        # yield an array that does not match the model's (80, 60, 3) input.
        rgb = img.convert('RGB')
        # PIL sizes are (width, height), so this yields an (80, 60, 3) array.
        sample_img_arr = np.array(rgb.resize((60, 80)))
        sample_img_arr = tf.keras.applications.mobilenet_v2.preprocess_input(sample_img_arr)
        # Add the leading batch dimension expected by the Keras model.
        return self.feature_extractor(sample_img_arr[None, :])

    def find_k_neighbors(
        self,
        sample_img: Image.Image,
        metric: str = 'minkowski',
        k: int = 16,
    ) -> list[str]:
        """Return the ids of the images most similar to ``sample_img``.

        Args:
            sample_img: Query image.
            metric: Distance metric used for the neighbour search.
            k: Number of neighbours to return (default 16). It is capped at
                the number of stored embeddings, because scikit-learn rejects
                larger requests.

        Returns:
            Image ids as strings, ordered nearest first.
        """
        knn = self.KNN(metric=metric)
        features = self.image_feature_extraction(img=sample_img)
        n_neighbors = min(k, len(self.embeddings))
        _, indices = knn.kneighbors(X=features, n_neighbors=n_neighbors)
        # A single query image gives indices of shape (1, n_neighbors).
        return [str(image_id) for image_id in self.name[indices[0]]]