Bjg6742635 commited on
Commit
1cdabc0
·
verified ·
1 Parent(s): 495968d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +156 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,158 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from datasets import load_dataset, concatenate_datasets
5
+ from sklearn.feature_extraction.text import TfidfVectorizer
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
9
+ import spacy
10
+ import nltk
11
+ from nltk.corpus import stopwords
12
+ from nltk.tokenize import word_tokenize
13
+ import re
14
+ from bs4 import BeautifulSoup
15
+
16
+ # === Загрузка и подготовка данных ===
17
+
18
+ @st.cache_resource
19
+ def load_data():
20
+ # Загрузка датасета
21
+ data = load_dataset('Romyx/ru_QA_school_history', split='train')
22
+ df = pd.DataFrame(data)
23
+ df['Pt_question'] = df['question'].apply(preprocess_text)
24
+ df['Pt_answer'] = df['answer'].apply(preprocess_text)
25
+ return df
26
+
27
+ @st.cache_resource
28
+ def load_model_and_tokenizer():
29
+ # Загрузка предобученной модели вопрос-ответа (например, SberQuad)
30
+ model_name = "AlexKay/xlm-roberta-large-qa-multilingual-finedtuned-ru" # замените на нужную модель, например, "bert-base-uncased"
31
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
32
+ model = AutoModelForQuestionAnswering.from_pretrained(model_name)
33
+ return tokenizer, model
34
+
35
+ @st.cache_resource
36
+ def build_vectorizer(_df):
37
+ combined_texts = _df['Pt_question'].tolist() + _df['Pt_answer'].tolist()
38
+ vectorizer = TfidfVectorizer()
39
+ tfidf_matrix = vectorizer.fit_transform(combined_texts)
40
+ return vectorizer, tfidf_matrix
41
+
42
+ # === Предобработка текста ===
43
+
44
+ # Загрузка Spacy модели
45
+ nlp = spacy.load('ru_core_news_lg')
46
+ stop_words = set(stopwords.words('russian'))
47
+
48
+ cache_dict = {}
49
+
50
+ def get_norm_form(word):
51
+ if word in cache_dict:
52
+ return cache_dict[word]
53
+ norm_form = nlp(word)[0].lemma_
54
+ cache_dict[word] = norm_form
55
+ return norm_form
56
+
57
+ def remove_html_tags(text):
58
+ soup = BeautifulSoup(text, 'html.parser')
59
+ return soup.text
60
+
61
+ def preprocess_text(text):
62
+ if pd.isna(text) or text is None:
63
+ return ""
64
+ text = remove_html_tags(text)
65
+ text = text.lower()
66
+
67
+ # Обработка знаков препинания
68
+ text = re.sub(r'([^\w\s-]|_)', r' \1 ', text)
69
+ text = re.sub(r'\s+', ' ', text)
70
+ text = re.sub(r'(\w+)-(\w+)', r'\1 \2', text)
71
+ text = re.sub(r'(\d+)(г|кг|см|м|мм|л|мл)', r'\1 \2', text)
72
+
73
+ # Удаление всего, кроме букв, цифр и пробелов
74
+ text = re.sub(r'[^\w\s]', '', text)
75
+
76
+ tokens = word_tokenize(text)
77
+ tokens = [token for token in tokens if token not in stop_words]
78
+ tokens = [get_norm_form(token) for token in tokens]
79
+
80
+ words_to_remove = {"ответ", "new"}
81
+ tokens = [token for token in tokens if token not in words_to_remove]
82
+
83
+ return ' '.join(tokens)
84
+
85
+ # === Основная функция получения ответа ===
86
+ def get_answer_from_qa_model(user_question, df, vectorizer, tfidf_matrix, model, tokenizer):
87
+ processed = preprocess_text(user_question)
88
+ user_vec = vectorizer.transform([processed])
89
+
90
+ similarities = cosine_similarity(user_vec, tfidf_matrix).flatten()
91
+
92
+ # Проверка, что similarities не пустой
93
+ if len(similarities) == 0:
94
+ return "Тема не входит в программу этих классов."
95
+
96
+ best_match_idx = similarities.argmax()
97
+ best_score = similarities[best_match_idx]
98
+
99
+ if best_score > 0.1:
100
+ # Проверка, что индекс не выходит за границы
101
+ if best_match_idx >= len(df):
102
+ return "Тема не входит в программу этих классов."
103
+
104
+ context = df.iloc[best_match_idx]['answer']
105
+ question = user_question
106
+
107
+ inputs = tokenizer(question, context, return_tensors="pt", truncation=True, padding=True)
108
+
109
+ with torch.no_grad():
110
+ outputs = model(**inputs)
111
+
112
+ start_scores = outputs.start_logits
113
+ end_scores = outputs.end_logits
114
+
115
+ # Проверка на корректность размера логитов
116
+ if len(start_scores.shape) == 2:
117
+ start_idx = torch.argmax(start_scores, dim=1)[0].item()
118
+ end_idx = torch.argmax(end_scores, dim=1)[0].item()
119
+ else:
120
+ start_idx = torch.argmax(start_scores).item()
121
+ end_idx = torch.argmax(end_scores).item()
122
+
123
+ # Проверка, что индексы не выходят за пределы
124
+ seq_len = inputs['input_ids'].shape[1]
125
+ if start_idx >= seq_len or end_idx >= seq_len or start_idx > end_idx:
126
+ return "Ответ не найден."
127
+
128
+ answer = tokenizer.decode(inputs['input_ids'][0][start_idx:end_idx+1], skip_special_tokens=True)
129
+ else:
130
+ answer = "Извините, я не понимаю во��рос."
131
+
132
+ return answer
133
+
134
+ # === Интерфейс Streamlit ===
135
+
136
+ st.title("🤖 ИИ-ассистент по истории (на основе вопрос-ответа)")
137
+
138
+ st.write("Задайте вопрос, и я постараюсь найти на него ответ из базы.")
139
+
140
+ # Загрузка данных и модели
141
+ df = load_data()
142
+ tokenizer, model = load_model_and_tokenizer()
143
+ vectorizer, tfidf_matrix = build_vectorizer(df)
144
+
145
+ # Поле ввода вопроса
146
+ user_input = st.text_input("Введите ваш вопрос:")
147
+
148
+ if st.button("Получить ответ"):
149
+ if user_input.strip():
150
+ with st.spinner("Ищем ответ..."):
151
+ response = get_answer_from_qa_model(
152
+ user_input, df, vectorizer, tfidf_matrix, model, tokenizer
153
+ )
154
+ st.success("Ответ:")
155
+ st.write(response)
156
+ else:
157
+ st.warning("Пожалуйста, введите вопрос.")
158