|
|
import streamlit as st |
|
|
from kdnv_preprocess import data_preprocessing |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import faiss |
|
|
import numpy as np |
|
|
import json |
|
|
from collections import Counter |
|
|
|
|
|
@st.cache_resource
def load_model():
    """Build the sentence-embedding model used to encode user descriptions.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the cached
    instance instead of constructing the model again.

    Returns:
        The ``SentenceTransformer`` for the multilingual MPNet paraphrase
        checkpoint named below.
    """
    encoder = SentenceTransformer(
        "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )
    return encoder
|
|
|
|
|
@st.cache_resource
def load_index():
    """Read the three prebuilt FAISS indexes from the ``models/`` directory.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the cached
    result instead of reading the files again.

    Returns:
        A dict with keys ``'L2'``, ``'Dot'`` and ``'Cos'``, each mapped to the
        index object returned by ``faiss.read_index`` for the matching file.
    """
    # (key, file) pairs, kept in the order the indexes are read.
    index_files = (
        ('L2', "models/index_l2.faiss"),
        ('Dot', "models/index_dot.faiss"),
        ('Cos', "models/index_cosine.faiss"),
    )
    return {metric: faiss.read_index(path) for metric, path in index_files}
|
|
|
|
|
# Shared resources used by the prediction handler further down the script.
model = load_model()
indices = load_index()

# class_dict is indexed with str(<FAISS result id>) to obtain a class label.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's locale encoding (which is not UTF-8 on every system).
with open('class_dict.json', 'r', encoding='utf-8') as file:
    class_dict = json.load(file)
|
|
|
|
|
# Page header: title, a link line, and a horizontal divider.
st.header('Кальянный угадыватель')
st.markdown('[для Кобза](https://www.youtube.com/watch?v=dQw4w9WgXcQ)')
st.divider()

# Input form: a free-text description plus a submit button. `text` holds the
# text area's value and `button` the submit button's return value; both are
# read by the prediction handler below.
with st.form('pred'):
    text = st.text_area('Введи сюда описание табака')
    button = st.form_submit_button(label='Узнать предсказание')
|
|
|
|
|
# Prediction handler: runs only on the script pass where the form was submitted.
if button:
    if not text.strip():
        # An empty description would still be embedded and matched to some
        # arbitrary nearest neighbour, so ask for input instead of guessing.
        st.warning('Сначала введи описание табака')
    else:
        cleaned_text = data_preprocessing(text)

        # Encode the description and add a leading axis so the query is a
        # single-row float32 matrix, which is what is passed to search().
        prompt_embedding = model.encode(cleaned_text).astype('float32')
        prompt_embedding = prompt_embedding[np.newaxis, :]

        # Ask every index for its single nearest neighbour and translate the
        # returned id into a class label through class_dict.
        predictions = {}
        for metric in ('L2', 'Dot', 'Cos'):
            _, neighbour_ids = indices[metric].search(prompt_embedding, 1)
            predictions[metric] = class_dict[str(neighbour_ids[0][0])]

        # Majority vote across the three indexes; when all three disagree
        # (the top label has only one vote) fall back to the L2 prediction.
        top_label, top_votes = Counter(predictions.values()).most_common(1)[0]
        final_prediction = top_label if top_votes > 1 else predictions['L2']

        st.subheader(f'Я считаю, что это: {final_prediction}')
|
|
|