# kobz / app.py
# Last commit 5e5a201 by Kdnv: "some update"
import streamlit as st
from kdnv_preprocess import data_preprocessing
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
from collections import Counter
@st.cache_resource
def load_model():
    """Create the multilingual sentence-embedding model used to encode queries.

    Decorated with ``st.cache_resource`` so Streamlit reuses the same model
    object across script reruns instead of rebuilding it on every interaction.

    Returns:
        A ``SentenceTransformer`` for the
        ``paraphrase-multilingual-mpnet-base-v2`` checkpoint.
    """
    checkpoint = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    return SentenceTransformer(checkpoint)
@st.cache_resource
def load_index():
    """Read the three pre-built FAISS indices from the ``models/`` directory.

    Cached with ``st.cache_resource`` so the index files are read once and
    reused across reruns.

    Returns:
        A dict mapping the keys ``'L2'``, ``'Dot'`` and ``'Cos'`` to the
        index loaded from the corresponding ``.faiss`` file.
    """
    # Insertion order (L2, Dot, Cos) is preserved, so files are read in the
    # same order as before.
    index_files = {
        'L2': "models/index_l2.faiss",
        'Dot': "models/index_dot.faiss",
        'Cos': "models/index_cosine.faiss",
    }
    return {name: faiss.read_index(path) for name, path in index_files.items()}
model = load_model()
indices = load_index()

# Maps the string form of a FAISS row id to its class label (see the
# class_dict[str(...)] lookups below). JSON is UTF-8 by specification; pass the
# encoding explicitly so non-ASCII labels decode correctly regardless of the
# platform's default locale encoding (e.g. cp1251/cp1252 on Windows).
with open('class_dict.json', 'r', encoding='utf-8') as file:
    class_dict = json.load(file)
# Page header. The title reads "Hookah guesser"; the markdown line is a link
# (text: "for Kobz") to a YouTube video.
st.header('Кальянный угадыватель')
st.markdown('[для Кобза](https://www.youtube.com/watch?v=dQw4w9WgXcQ)')
st.divider()

# Input form. The text area asks to "enter the tobacco description here";
# the submit button reads "get the prediction". Widget values inside a form
# are only sent when the submit button is pressed.
with st.form(key='pred'):
    text = st.text_area(label='Введи сюда описание табака')
    button = st.form_submit_button('Узнать предсказание')

# Indices that vote on the label, in tie-break priority order: when no label
# gets a majority, the first index's answer (L2) is used.
VOTING_INDICES = ('L2', 'Dot', 'Cos')

if button:
    if not text.strip():
        # An empty description would still be embedded and matched to some
        # nearest neighbour, producing a meaningless "prediction".
        # Warning text: "Enter the tobacco description first".
        st.warning('Сначала введи описание табака')
    else:
        cleaned_text = data_preprocessing(text)
        # The embedding of a single string is treated as a 1-D vector; FAISS
        # search expects a float32 (n_queries, dim) matrix, so add a batch axis.
        prompt_embedding = model.encode(cleaned_text).astype('float32')[np.newaxis, :]

        # Top-1 neighbour id from each index, mapped to its class label.
        predictions = []
        for index_name in VOTING_INDICES:
            _, neighbour_ids = indices[index_name].search(prompt_embedding, 1)
            predictions.append(class_dict[str(neighbour_ids[0][0])])

        # Majority vote; if every index disagrees, fall back to the L2 answer.
        top_label, top_votes = Counter(predictions).most_common(1)[0]
        final_prediction = top_label if top_votes > 1 else predictions[0]
        # "I think it is: <label>"
        st.subheader(f'Я считаю, что это: {final_prediction}')