# Hugging Face Spaces status: Runtime error
| import streamlit as st | |
| from sentence_transformers import SentenceTransformer | |
| from app import show_serial | |
| import faiss | |
| import numpy as np | |
| import pandas as pd | |
def load_model():
    """Instantiate the sentence-embedding model used for show search.

    Returns:
        A ``SentenceTransformer`` built from the
        ``paraphrase-multilingual-mpnet-base-v2`` checkpoint name.
    """
    checkpoint = "paraphrase-multilingual-mpnet-base-v2"
    return SentenceTransformer(checkpoint)
def get_shows():
    """Read the show descriptions from ``data/data.csv``.

    Returns:
        list[str]: One entry per CSV row, in file order, taken from the
        ``description`` column.

    Raises:
        FileNotFoundError: If ``data/data.csv`` does not exist.
        KeyError: If the CSV has no ``description`` column.
    """
    data = pd.read_csv('data/data.csv')
    # Blank cells are parsed as float NaN. Replace them with empty strings
    # (rather than dropping the rows) so every element is a str and the list
    # position of each description still equals its row number in the CSV.
    descriptions = data['description'].fillna('').astype(str)
    return descriptions.tolist()
def compute_index(_model, sentences):
    """Encode ``sentences`` and build three flat FAISS indices over them.

    Args:
        _model: Object whose ``encode(sentences)`` result supports
            ``.astype`` and has a second axis (``shape[1]``) giving the
            embedding dimension.
        sentences: The texts to encode and index.

    Returns:
        A tuple ``(index_l2, index_dot, index_cosine)``:
        an ``IndexFlatL2`` and an ``IndexFlatIP`` holding the raw embeddings,
        and a second ``IndexFlatIP`` holding the row-normalised embeddings.
    """
    vectors = _model.encode(sentences).astype('float32')
    dim = vectors.shape[1]

    # Squared-L2 and raw inner-product indices share the unmodified vectors.
    l2_index = faiss.IndexFlatL2(dim)
    l2_index.add(vectors)

    dot_index = faiss.IndexFlatIP(dim)
    dot_index.add(vectors)

    # Inner product over unit-length rows is cosine similarity.
    row_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    unit_vectors = vectors / row_norms
    cosine_index = faiss.IndexFlatIP(dim)
    cosine_index.add(unit_vectors)

    return l2_index, dot_index, cosine_index
# Streamlit re-executes this whole script on every widget interaction.
# Loading the model and encoding the corpus are expensive, so both go through
# st.cache_resource and run once per process instead of once per rerun.
model = st.cache_resource(load_model)()
sentences = get_shows()
# compute_index names its first parameter `_model`; Streamlit does not hash
# arguments whose parameter name starts with an underscore, so the cache key
# is derived from `sentences` only.
index_l2, index_dot, index_cosine = st.cache_resource(compute_index)(model, sentences)

# Metric label (as shown in the radio widget) -> FAISS index to query.
indices = {
    'L2': index_l2,
    'Dot': index_dot,
    'Cos': index_cosine,
}
# Alternative kept for reference: load prebuilt indices from disk.
# indices = {
#     'L2': faiss.read_index("models/index_l2.faiss"),
#     'Dot': faiss.read_index("models/index_dot.faiss"),
#     'Cos': faiss.read_index("models/index_cosine.faiss")
# }

# Search form: query text on the left, result count and metric on the right.
with st.form(key='submit'):
    cols = st.columns(2)
    with cols[0]:
        prompt = st.text_area('Ваш запрос')
    with cols[1]:
        amount = st.slider(label='Количество рекомендаций', min_value=1, max_value=10, value=5)
        metrics = st.radio(label='Метрика', options=('L2', 'Dot', 'Cos'), horizontal=True)
    submit = st.form_submit_button('Искать', use_container_width=True)

if submit:
    # encode() on a single string yields one vector; add a leading axis so
    # the query is a 1-row matrix, which is what index.search() takes.
    prompt_embedding = model.encode(prompt).astype('float32')
    prompt_embedding = prompt_embedding[np.newaxis, :]
    if metrics == 'Cos':
        # The cosine index stores unit-length rows, so the query must be
        # normalised the same way.
        prompt_embedding = prompt_embedding / np.linalg.norm(prompt_embedding, axis=1, keepdims=True)

    distances, indices_result = indices[metrics].search(prompt_embedding, amount)
    sorted_indices = indices_result[0].tolist()
    sorted_distances = distances[0].tolist()
    for idx, score in zip(sorted_indices, sorted_distances):
        if idx < 0:
            # FAISS pads the result with label -1 when the index holds fewer
            # than `amount` vectors; there is no show to display for those.
            continue
        show_serial(idx, (score, metrics))