rec_project / pages /search.py
Kdnv's picture
app search speed 2.0
7de79a5
import streamlit as st
from sentence_transformers import SentenceTransformer
from app import show_serial
import faiss
import numpy as np
import pandas as pd
@st.cache_resource
def load_model():
    """Create the sentence-embedding model used by this page.

    Decorated with ``st.cache_resource`` so the model object is built once
    and reused on later script reruns instead of being reloaded each time.

    Returns:
        A ``SentenceTransformer`` for the multilingual MPNet paraphrase model.
    """
    model_name = "paraphrase-multilingual-mpnet-base-v2"
    return SentenceTransformer(model_name)
@st.cache_resource
def get_shows():
    """Read the show descriptions from ``data/data.csv``.

    Returns:
        The values of the ``description`` column as a plain Python list,
        in file order.
    """
    frame = pd.read_csv('data/data.csv')
    descriptions = frame['description']
    return descriptions.tolist()
@st.cache_resource
def compute_index(_model, sentences):
    """Embed *sentences* and build three flat FAISS indexes over the vectors.

    Args:
        _model: Encoder object exposing ``encode``. The leading underscore
            tells Streamlit to leave this argument out of the cache key.
        sentences: Texts to embed; they are passed to ``encode`` as given.

    Returns:
        A tuple ``(index_l2, index_dot, index_cosine)``:
        an ``IndexFlatL2`` over the raw embeddings, an ``IndexFlatIP`` over
        the raw embeddings, and an ``IndexFlatIP`` over the row-normalised
        embeddings (inner product of unit vectors, i.e. cosine similarity).
    """
    vectors = _model.encode(sentences).astype('float32')
    dim = vectors.shape[1]
    # Scale every row to unit length for the cosine index.
    unit_vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)

    index_l2 = faiss.IndexFlatL2(dim)
    index_l2.add(vectors)

    index_dot = faiss.IndexFlatIP(dim)
    index_dot.add(vectors)

    index_cosine = faiss.IndexFlatIP(dim)
    index_cosine.add(unit_vectors)

    return index_l2, index_dot, index_cosine
# Build (or fetch from Streamlit's resource cache) everything the search needs.
model = load_model()
sentences = get_shows()
index_l2, index_dot, index_cosine = compute_index(model, sentences)

# Lookup from the metric label chosen in the form to the index implementing it.
indices = dict(L2=index_l2, Dot=index_dot, Cos=index_cosine)

# Disabled alternative: read prebuilt indexes from disk instead of computing
# them from the descriptions at startup.
# indices = {
#     'L2': faiss.read_index("models/index_l2.faiss"),
#     'Dot': faiss.read_index("models/index_dot.faiss"),
#     'Cos': faiss.read_index("models/index_cosine.faiss")
# }
# Search form: free-text query on the left, result count and metric on the right.
# Widgets inside a form only deliver their values when the submit button is pressed.
with st.form(key='submit'):
    cols = st.columns(2)
    with cols[0]:
        prompt = st.text_area('Ваш запрос')
    with cols[1]:
        amount = st.slider(label='Количество рекомендаций', min_value=1, max_value=10, value=5)
        metrics = st.radio(label='Метрика', options=('L2', 'Dot', 'Cos'), horizontal=True)
    submit = st.form_submit_button('Искать', use_container_width=True)

if submit:
    # Encode the query and add a leading axis: FAISS searches a 2-D batch of
    # float32 vectors, here a batch of one.
    prompt_embedding = model.encode(prompt).astype('float32')
    prompt_embedding = prompt_embedding[np.newaxis, :]
    if metrics == 'Cos':
        # The cosine index holds unit-length vectors, so the query must be
        # normalised too for the inner product to equal cosine similarity.
        prompt_embedding = prompt_embedding / np.linalg.norm(prompt_embedding, axis=1, keepdims=True)

    distances, indices_result = indices[metrics].search(prompt_embedding, amount)
    sorted_indices = indices_result[0].tolist()
    sorted_distances = distances[0].tolist()

    for idx, score in zip(sorted_indices, sorted_distances):
        if idx < 0:
            # FAISS fills unused result slots with -1 when the index contains
            # fewer than `amount` vectors; those are not real rows, so do not
            # hand them to show_serial.
            continue
        show_serial(idx, (score, metrics))