"""Streamlit page that guesses a tobacco class from a free-text description.

The description is preprocessed, embedded with a sentence-transformer model and
looked up in three FAISS indexes ('L2', 'Dot', 'Cos').  Each index contributes
the class of its single nearest neighbour and the final answer is a majority
vote between them.
"""
import streamlit as st
from kdnv_preprocess import data_preprocessing
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
from collections import Counter

MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# Index name -> path of the serialized FAISS index.  The insertion order is
# also the voting order: the first entry wins when all indexes disagree.
INDEX_PATHS = {
    'L2': "models/index_l2.faiss",
    'Dot': "models/index_dot.faiss",
    'Cos': "models/index_cosine.faiss",
}


@st.cache_resource
def load_model():
    """Return the sentence-embedding model (cached via ``st.cache_resource``)."""
    return SentenceTransformer(MODEL_NAME)


@st.cache_resource
def load_index():
    """Read every FAISS index listed in ``INDEX_PATHS`` from disk.

    Returns:
        dict: index name ('L2', 'Dot', 'Cos') -> loaded FAISS index.
        Cached via ``st.cache_resource``.
    """
    return {name: faiss.read_index(path) for name, path in INDEX_PATHS.items()}


def _majority_vote(predictions):
    """Return the most frequent label, or the first one if all labels differ.

    ``predictions`` must be a non-empty sequence.
    """
    counts = Counter(predictions)
    if len(counts) == len(predictions):
        # No label occurs twice: fall back to the first (highest priority) vote.
        return predictions[0]
    return counts.most_common(1)[0][0]


def _nearest_label(index, embedding):
    """Return the class label of the nearest neighbour of ``embedding``.

    Returns ``None`` when the neighbour id has no entry in ``class_dict``;
    that includes the ``-1`` id FAISS reports when it finds no neighbour.
    """
    _, neighbour_ids = index.search(embedding, 1)
    return class_dict.get(str(neighbour_ids[0][0]))


model = load_model()
indices = load_index()

# JSON is UTF-8; be explicit so the labels decode the same way on any locale.
with open('class_dict.json', 'r', encoding='utf-8') as file:
    class_dict = json.load(file)

st.header('Кальянный угадыватель')
st.markdown('[для Кобза](https://www.youtube.com/watch?v=dQw4w9WgXcQ)')
st.divider()

with st.form(key='pred'):
    text = st.text_area(label='Введи сюда описание табака')
    button = st.form_submit_button('Узнать предсказание')

if button:
    if not text.strip():
        # Nothing to embed: ask for input instead of predicting on an empty string.
        st.warning('Сначала введи описание табака')
    else:
        cleaned_text = data_preprocessing(text)
        # FAISS expects a 2-D float32 array with one row per query vector.
        prompt_embedding = model.encode(cleaned_text).astype('float32')
        prompt_embedding = prompt_embedding[np.newaxis, :]

        # NOTE(review): the 'Cos' index is queried with the raw embedding; if it
        # is an inner-product index built over L2-normalized vectors the query
        # should be normalized too -- confirm against the index-building code.
        votes = [
            _nearest_label(indices[index_name], prompt_embedding)
            for index_name in INDEX_PATHS
        ]
        predictions = [label for label in votes if label is not None]

        if predictions:
            final_prediction = _majority_vote(predictions)
            st.subheader(f'Я считаю, что это: {final_prediction}')
        else:
            st.error('Не удалось определить табак')