Kdnv commited on
Commit
360fc74
·
1 Parent(s): 8edef47
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from kdnv_preprocess import data_preprocessing
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import numpy as np
6
+ import json
7
+ from collections import Counter
8
+
9
@st.cache_resource
def load_model():
    """Load and cache the multilingual sentence-embedding model."""
    model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    return SentenceTransformer(model_name)
12
+
13
@st.cache_resource
def load_index():
    """Load and cache the three FAISS indices (L2, dot-product, cosine)."""
    index_files = {
        'L2': "models/index_l2.faiss",
        'Dot': "models/index_dot.faiss",
        'Cos': "models/index_cosine.faiss",
    }
    return {name: faiss.read_index(path) for name, path in index_files.items()}
21
+
22
model = load_model()
indices = load_index()

# Mapping from class index (as a string key) to a human-readable flavour name.
# Explicit encoding: the JSON file is UTF-8 regardless of the OS default locale.
with open('class_dict.json', 'r', encoding='utf-8') as file:
    class_dict = json.load(file)

st.header('Кальянный угадыватель')
st.caption('для Кобза')
st.divider()

with st.form(key='pred'):
    text = st.text_area(label='Введи сюда описание табака')
    button = st.form_submit_button('Узнать предсказание')

if button:
    # Clean/lemmatize the prompt, then embed it as a single-row float32 matrix
    # (FAISS expects a 2-D array of queries).
    text = data_preprocessing(text)
    prompt_embedding = model.encode(text).astype('float32')
    prompt_embedding = prompt_embedding[np.newaxis, :]

    # Query each index for its single nearest neighbour and map the hit
    # to a class label. Order matters: 'L2' first (used as the tie-breaker).
    predictions = []
    for metric in ('L2', 'Dot', 'Cos'):
        _, result = indices[metric].search(prompt_embedding, 1)
        predictions.append(class_dict[str(result[0][0])])

    # Majority vote across the three similarity metrics.
    prediction_counts = Counter(predictions)
    final_prediction = prediction_counts.most_common(1)[0][0]

    # If all three metrics disagree (no majority), fall back to the L2 answer.
    if len(prediction_counts) == len(predictions):
        final_prediction = predictions[0]

    st.subheader(f'Я считаю, что это: {final_prediction}')
class_dict.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"0": "Cereal", "1": "Ginger", "2": "Melons", "3": "Raspberry", "4": "Tobacco", "5": "Currant", "6": "Tropic mix", "7": "Mixed berries", "8": "Blueberry", "9": "Tea", "10": "Passionfruit", "11": "Strawberry", "12": "Mixed fruits", "13": "Fruit/berry desserts", "14": "Aloe", "15": "Cucumber", "16": "Mixed citrus", "17": "Plum", "18": "Orange", "19": "Mango", "20": "Fruit/berry yogurt", "21": "Rice", "22": "Chocolate", "23": "Grapefruit", "24": "Sweets", "25": "Gum", "26": "Coctail", "27": "Cloudberry", "28": "Lemon", "29": "Calamansi", "30": "Milk/Cream", "31": "Cola", "32": "Gin", "33": "Violet", "34": "Pomegranate", "35": "Grass", "36": "Prickly Pear", "37": "Lavander", "38": "Whiskey", "39": "Honey", "40": "Wild strawberry", "41": "Feijoa", "42": "Cranberry", "43": "Blackberry", "44": "Pineapple", "45": "Papaya", "46": "Melon", "47": "Pear", "48": "Banana", "49": "Apple", "50": "Nectarine", "51": "Grape", "52": "Cream", "53": "Waffles", "54": "Fruit/berry coctail", "55": "Kiwi", "56": "Coffee", "57": "Energy drink", "58": "Lemongrass", "59": "Mint", "60": "Cinnamon", "61": "Cookies", "62": "Estragon", "63": "Hazelnut", "64": "Lime", "65": "Basil", "66": "Elderberry", "67": "Bergamot", "68": "Coconut", "69": "Cherry", "70": "Mixed florals", "71": "Ice Cream", "72": "Cactus", "73": "Peach", "74": "Guava", "75": "Fir", "76": "Pistachio", "77": "Gooseberry", "78": "Salbei", "79": "Lychee", "80": "Cooling", "81": "Wildberry", "82": "Jackfruit", "83": "Watermelon", "84": "Chocomint", "85": "Almond", "86": "Cake", "87": "Liqueur", "88": "Mandarin", "89": "Baikal", "90": "Soursop", "91": "Root beer", "92": "Vanilla", "93": "Corn", "94": "Curry", "95": "Orange flower", "96": "Wine", "97": "Beer", "98": "Piones", "99": "Yudzu", "100": "Saffron", "101": "Wood", "102": "Anise", "103": "Rum", "104": "Maple", "105": "Marula", "106": "Quince", "107": "Sea Buckthorn"}
kdnv_preprocess.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import inspect
from collections import namedtuple

# Compatibility shim for pymorphy2: inspect.getargspec was removed in
# Python 3.11, but pymorphy2 still calls it. Recreate it on top of
# getfullargspec, returning an ArgSpec-like namedtuple so that both
# tuple unpacking and attribute access (.args, .defaults) keep working,
# matching the original API's return type.
if not hasattr(inspect, 'getargspec'):
    ArgSpec = namedtuple('ArgSpec', ['args', 'varargs', 'keywords', 'defaults'])

    def getargspec(func):
        """Drop-in replacement for the removed inspect.getargspec."""
        full = inspect.getfullargspec(func)
        return ArgSpec(full.args, full.varargs, full.varkw, full.defaults)

    inspect.getargspec = getargspec
9
+
10
import re
import string
import subprocess
import sys

import numpy as np
import torch
from torch import Tensor
import pymorphy2
from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')

import spacy

# Load the Russian spacy pipeline (tagger only — parser/NER are not needed),
# downloading the model first if it is not installed yet.
try:
    nlp = spacy.load("ru_core_news_sm", disable=["parser", "ner"])
except OSError:
    # sys.executable guarantees we invoke the same interpreter that is
    # running this module, not whatever "python" happens to be on PATH;
    # check=True fails fast with a clear error if the download breaks.
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "ru_core_news_sm"],
        check=True,
    )
    nlp = spacy.load("ru_core_news_sm", disable=["parser", "ner"])

# Russian stop-word list used by data_preprocessing().
stop_words = set(stopwords.words('russian'))

# pymorphy2 morphological analyzer used for lemmatization.
morph = pymorphy2.MorphAnalyzer()
38
+
39
def data_preprocessing(text: str) -> str:
    """Normalize raw Russian text for embedding.

    Lowercases, strips HTML tags, removes punctuation and digits, drops
    stop words, and lemmatizes every remaining token with pymorphy2.
    """
    # Lowercase and strip HTML tags in one pass.
    cleaned = re.sub(r'<.*?>', '', text.lower())

    # Replace newlines and non-breaking spaces with regular spaces.
    cleaned = cleaned.replace('\n', ' ').replace('\xa0', ' ')

    # Drop punctuation and digit characters in a single sweep.
    cleaned = ''.join(
        c for c in cleaned if c not in string.punctuation and not c.isdigit()
    )

    # Tokenize with spacy, skip stop words/digit tokens, lemmatize the rest.
    lemmas = []
    for token in nlp(cleaned):
        if token.text in stop_words or token.is_digit:
            continue
        lemmas.append(morph.parse(token.text)[0].normal_form)

    return ' '.join(lemmas)
57
+
58
def get_words_by_freq(sorted_words: list[tuple[str, int]], n: int = 10) -> list:
    """Return only the (word, count) pairs whose count is strictly above *n*."""
    return [pair for pair in sorted_words if pair[1] > n]
60
+
61
def padding(review_int: list, seq_len: int) -> np.ndarray:
    """Make left-sided padding for input lists of tokens.

    Args:
        review_int (list): list of token-id sequences (one list per review)
        seq_len (int): target length; sequences longer than seq_len are
            trimmed to their first seq_len tokens, shorter ones are
            left-padded with zeros

    Returns:
        np.ndarray: int array of shape (len(review_int), seq_len)
    """
    features = np.zeros((len(review_int), seq_len), dtype=int)
    for i, review in enumerate(review_int):
        # Left padding: place the (possibly trimmed) sequence flush against
        # the right edge of the row; the left part stays zero.
        offset = max(seq_len - len(review), 0)
        features[i, offset:] = review[:seq_len]
    return features
81
+
82
def preprocess_single_string(
        input_string: str,
        seq_len: int,
        vocab_to_int: dict,
        verbose: bool = False
) -> Tensor:
    """Run the full preprocessing pipeline on a single string.

    Args:
        input_string (str): raw input text
        seq_len (int): target sequence length; the encoded sequence is
            trimmed or left-padded with zeros to exactly this length
        vocab_to_int (dict): word corpus mapping {'word': int index};
            out-of-vocabulary words are silently skipped
        verbose (bool): if True, print every word missing from the vocabulary

    Returns:
        Tensor: 1-D tensor of length seq_len with the encoded sequence
    """
    preprocessed_string = data_preprocessing(input_string)
    result_list = []
    for word in preprocessed_string.split():
        if word in vocab_to_int:
            result_list.append(vocab_to_int[word])
        elif verbose:
            # Same message shape the original KeyError-based code printed.
            print(f"'{word}': not in dictionary!")
    result_padded = padding([result_list], seq_len)[0]

    return Tensor(result_padded)
models/index_cosine.faiss ADDED
Binary file (332 kB). View file
 
models/index_dot.faiss ADDED
Binary file (332 kB). View file
 
models/index_l2.faiss ADDED
Binary file (332 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ faiss-cpu
2
+ nltk
3
+ numpy
4
+ pymorphy2
5
+ sentence_transformers
6
+ spacy
7
+ streamlit
8
+ torch