init
Browse files- app.py +58 -0
- class_dict.json +1 -0
- kdnv_preprocess.py +109 -0
- models/index_cosine.faiss +0 -0
- models/index_dot.faiss +0 -0
- models/index_l2.faiss +0 -0
- requirements.txt +8 -0
app.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from kdnv_preprocess import data_preprocessing
|
| 3 |
+
from sentence_transformers import SentenceTransformer
|
| 4 |
+
import faiss
|
| 5 |
+
import numpy as np
|
| 6 |
+
import json
|
| 7 |
+
from collections import Counter
|
| 8 |
+
|
| 9 |
+
@st.cache_resource
def load_model() -> SentenceTransformer:
    """Load and cache the multilingual sentence-embedding model.

    Cached with st.cache_resource so the model is created only once
    per Streamlit process, not on every rerun.
    """
    model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    return SentenceTransformer(model_name)
|
| 12 |
+
|
| 13 |
+
@st.cache_resource
def load_index() -> dict:
    """Read and cache the three FAISS nearest-neighbour indices.

    Returns a dict keyed by metric name: 'L2', 'Dot' (inner product)
    and 'Cos' (cosine). Cached once per Streamlit process.
    """
    index_paths = {
        'L2': "models/index_l2.faiss",
        'Dot': "models/index_dot.faiss",
        'Cos': "models/index_cosine.faiss",
    }
    return {name: faiss.read_index(path) for name, path in index_paths.items()}
|
| 21 |
+
|
| 22 |
+
# Load heavyweight resources once; both loaders are cached across reruns.
model = load_model()
indices = load_index()

# Mapping from FAISS row id (stored as a string key) to a class label.
# encoding specified explicitly so the app does not depend on the OS
# locale default (would break on e.g. Windows cp1251).
with open('class_dict.json', 'r', encoding='utf-8') as file:
    class_dict = json.load(file)

st.header('Кальянный угадыватель')
st.caption('для Кобза')
st.divider()

with st.form(key='pred'):
    text = st.text_area(label='Введи сюда описание табака')
    button = st.form_submit_button('Узнать предсказание')

if button:
    # Clean/lemmatize the query, then embed it as a (1, dim) float32
    # matrix — the shape FAISS expects for search().
    text = data_preprocessing(text)
    prompt_embedding = model.encode(text).astype('float32')
    prompt_embedding = prompt_embedding[np.newaxis, :]

    # Nearest neighbour under each of the three similarity metrics.
    _, indices_result_l2 = indices['L2'].search(prompt_embedding, 1)
    _, indices_result_dot = indices['Dot'].search(prompt_embedding, 1)
    _, indices_result_cosine = indices['Cos'].search(prompt_embedding, 1)

    pred_l2 = class_dict[str(indices_result_l2[0][0])]
    pred_dot = class_dict[str(indices_result_dot[0][0])]
    pred_cosine = class_dict[str(indices_result_cosine[0][0])]

    # Majority vote over the three metrics; when all three disagree
    # (every label occurs exactly once), fall back to the L2 answer.
    predictions = [pred_l2, pred_dot, pred_cosine]
    prediction_counts = Counter(predictions)
    final_prediction = prediction_counts.most_common(1)[0][0]
    if len(prediction_counts) == len(predictions):
        final_prediction = pred_l2

    st.subheader(f'Я считаю, что это: {final_prediction}')
|
class_dict.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"0": "Cereal", "1": "Ginger", "2": "Melons", "3": "Raspberry", "4": "Tobacco", "5": "Currant", "6": "Tropic mix", "7": "Mixed berries", "8": "Blueberry", "9": "Tea", "10": "Passionfruit", "11": "Strawberry", "12": "Mixed fruits", "13": "Fruit/berry desserts", "14": "Aloe", "15": "Cucumber", "16": "Mixed citrus", "17": "Plum", "18": "Orange", "19": "Mango", "20": "Fruit/berry yogurt", "21": "Rice", "22": "Chocolate", "23": "Grapefruit", "24": "Sweets", "25": "Gum", "26": "Coctail", "27": "Cloudberry", "28": "Lemon", "29": "Calamansi", "30": "Milk/Cream", "31": "Cola", "32": "Gin", "33": "Violet", "34": "Pomegranate", "35": "Grass", "36": "Prickly Pear", "37": "Lavander", "38": "Whiskey", "39": "Honey", "40": "Wild strawberry", "41": "Feijoa", "42": "Cranberry", "43": "Blackberry", "44": "Pineapple", "45": "Papaya", "46": "Melon", "47": "Pear", "48": "Banana", "49": "Apple", "50": "Nectarine", "51": "Grape", "52": "Cream", "53": "Waffles", "54": "Fruit/berry coctail", "55": "Kiwi", "56": "Coffee", "57": "Energy drink", "58": "Lemongrass", "59": "Mint", "60": "Cinnamon", "61": "Cookies", "62": "Estragon", "63": "Hazelnut", "64": "Lime", "65": "Basil", "66": "Elderberry", "67": "Bergamot", "68": "Coconut", "69": "Cherry", "70": "Mixed florals", "71": "Ice Cream", "72": "Cactus", "73": "Peach", "74": "Guava", "75": "Fir", "76": "Pistachio", "77": "Gooseberry", "78": "Salbei", "79": "Lychee", "80": "Cooling", "81": "Wildberry", "82": "Jackfruit", "83": "Watermelon", "84": "Chocomint", "85": "Almond", "86": "Cake", "87": "Liqueur", "88": "Mandarin", "89": "Baikal", "90": "Soursop", "91": "Root beer", "92": "Vanilla", "93": "Corn", "94": "Curry", "95": "Orange flower", "96": "Wine", "97": "Beer", "98": "Piones", "99": "Yudzu", "100": "Saffron", "101": "Wood", "102": "Anise", "103": "Rum", "104": "Maple", "105": "Marula", "106": "Quince", "107": "Sea Buckthorn"}
|
kdnv_preprocess.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect

# Compatibility shim: pymorphy2 still calls inspect.getargspec(), which
# was removed in Python 3.11. Recreate it on top of getfullargspec()
# when it is absent, returning the same 4-tuple layout the old API had.
if not hasattr(inspect, 'getargspec'):
    def getargspec(func):
        full = inspect.getfullargspec(func)
        return full.args, full.varargs, full.varkw, full.defaults

    inspect.getargspec = getargspec
|
| 9 |
+
|
| 10 |
+
import re
import string
import subprocess
import sys

import numpy as np
import torch
from torch import Tensor
import pymorphy2
import nltk
from nltk.corpus import stopwords
import spacy

# Stopword corpus is needed at import time; download is a no-op if cached.
nltk.download('stopwords')

# Load the small Russian spaCy pipeline (parser/NER disabled — only
# tokenization is used here). Install the model on first run if missing.
try:
    nlp = spacy.load("ru_core_news_sm", disable=["parser", "ner"])
except OSError:
    # Use sys.executable so the model is installed into the *current*
    # interpreter's environment, not whatever "python" resolves to on
    # PATH; check=True surfaces a failed download immediately.
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "ru_core_news_sm"],
        check=True,
    )
    nlp = spacy.load("ru_core_news_sm", disable=["parser", "ner"])

# Russian stop-words used by data_preprocessing().
stop_words = set(stopwords.words('russian'))

# pymorphy2 morphological analyzer (lemmatizer) shared by data_preprocessing().
morph = pymorphy2.MorphAnalyzer()
|
| 38 |
+
|
| 39 |
+
def data_preprocessing(text: str) -> str:
    """Normalize a Russian product description for embedding.

    Steps: lowercase, strip HTML tags, drop punctuation and digits,
    then remove stop-words and lemmatize every remaining token.

    Args:
        text (str): raw input string.

    Returns:
        str: space-joined string of lemmatized, stop-word-free tokens.
    """
    text = text.lower()

    # Normalize newlines and non-breaking spaces *before* stripping HTML:
    # the regex uses '.' without DOTALL, so a tag broken across a line
    # break would otherwise survive.
    text = text.replace('\n', ' ').replace('\xa0', ' ')

    # Remove HTML tags.
    text = re.sub(r'<.*?>', '', text)

    # Drop punctuation and digits in a single pass.
    text = ''.join(
        c for c in text if c not in string.punctuation and not c.isdigit()
    )

    # Tokenize with spaCy, drop stop-words, lemmatize with pymorphy2.
    doc = nlp(text)
    return ' '.join(
        morph.parse(token.text)[0].normal_form
        for token in doc
        if token.text not in stop_words and not token.is_digit
    )
|
| 57 |
+
|
| 58 |
+
def get_words_by_freq(sorted_words: list[tuple[str, int]], n: int = 10) -> list:
    """Keep only the (word, count) pairs whose count is strictly above *n*."""
    return [pair for pair in sorted_words if pair[1] > n]
|
| 60 |
+
|
| 61 |
+
def padding(review_int: list, seq_len: int) -> np.array:
    """Left-pad (or trim) each token sequence to exactly *seq_len*.

    Args:
        review_int (list): list of token-id lists.
        seq_len (int): target length; longer sequences keep their first
            seq_len tokens, shorter ones are left-padded with zeros.

    Returns:
        np.array: int array of shape (len(review_int), seq_len).
    """
    features = np.zeros((len(review_int), seq_len), dtype=int)
    for row, tokens in enumerate(review_int):
        if len(tokens) > seq_len:
            # Too long: keep the leading seq_len tokens.
            padded = tokens[:seq_len]
        else:
            # Too short (or exact): prepend zeros up to seq_len.
            padded = [0] * (seq_len - len(tokens)) + list(tokens)
        features[row, :] = np.array(padded)
    return features
|
| 81 |
+
|
| 82 |
+
def preprocess_single_string(
        input_string: str,
        seq_len: int,
        vocab_to_int: dict,
        verbose: bool = False
) -> Tensor:
    """Run the full preprocessing pipeline on one raw string.

    Cleans and lemmatizes the text, maps each word to its integer id
    (silently skipping out-of-vocabulary words), then left-pads the
    sequence to *seq_len*.

    Args:
        input_string (str): raw text to preprocess.
        seq_len (int): target sequence length; longer inputs are trimmed,
            shorter ones are left-padded with zeros.
        vocab_to_int (dict): word corpus mapping {'word': int index}.
        verbose (bool): if True, print every out-of-vocabulary word.

    Returns:
        Tensor: 1-D tensor of length seq_len with the padded token ids.
    """
    preprocessed_string = data_preprocessing(input_string)
    result_list = []
    for word in preprocessed_string.split():
        if word in vocab_to_int:
            result_list.append(vocab_to_int[word])
        elif verbose:
            # Out-of-vocabulary words are dropped; optionally reported.
            print(f"'{word}': not in dictionary!")
    result_padded = padding([result_list], seq_len)[0]
    return Tensor(result_padded)
|
models/index_cosine.faiss
ADDED
|
Binary file (332 kB). View file
|
|
|
models/index_dot.faiss
ADDED
|
Binary file (332 kB). View file
|
|
|
models/index_l2.faiss
ADDED
|
Binary file (332 kB). View file
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
faiss-cpu
|
| 2 |
+
nltk
|
| 3 |
+
numpy
|
| 4 |
+
pymorphy2
|
| 5 |
+
sentence_transformers
|
| 6 |
+
spacy
|
| 7 |
+
streamlit
|
| 8 |
+
torch
|