SvetlanaSS commited on
Commit
73b6937
·
verified ·
1 Parent(s): 65c61a6

Create App.py

Browse files
Files changed (1) hide show
  1. App.py +163 -0
App.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import glob
3
+ from docx import Document
4
+ from sklearn.feature_extraction.text import TfidfVectorizer
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ import torch
7
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
8
+ import numpy as np
9
+
10
+ def is_header(txt):
11
+ if not txt or len(txt) < 35:
12
+ if txt == txt.upper() and not txt.endswith(('.', ':', '?', '!')):
13
+ return True
14
+ if txt.istitle() and len(txt.split()) < 6 and not txt.endswith(('.', ':', '?', '!')):
15
+ return True
16
+ return False
17
+
18
+ def get_blocks_from_docx():
19
+ docx_list = glob.glob("*.docx")
20
+ if not docx_list:
21
+ return [], []
22
+ doc = Document(docx_list[0])
23
+ blocks = []
24
+ normal_blocks = []
25
+ for p in doc.paragraphs:
26
+ txt = p.text.strip()
27
+ if (
28
+ txt
29
+ and not (len(txt) <= 3 and txt.isdigit())
30
+ and len(txt.split()) > 3
31
+ ):
32
+ blocks.append(txt)
33
+ if not is_header(txt) and len(txt) > 25:
34
+ normal_blocks.append(txt)
35
+ for table in doc.tables:
36
+ for row in table.rows:
37
+ row_text = " | ".join(cell.text.strip() for cell in row.cells if cell.text.strip())
38
+ if row_text and len(row_text.split()) > 3 and len(row_text) > 25:
39
+ blocks.append(row_text)
40
+ if not is_header(row_text):
41
+ normal_blocks.append(row_text)
42
+ # remove duplicates
43
+ seen = set(); blocks_clean = []
44
+ for b in blocks:
45
+ if b not in seen:
46
+ blocks_clean.append(b)
47
+ seen.add(b)
48
+ seen = set(); normal_blocks_clean = []
49
+ for b in normal_blocks:
50
+ if b not in seen:
51
+ normal_blocks_clean.append(b)
52
+ seen.add(b)
53
+ return blocks_clean, normal_blocks_clean
54
+
55
+ blocks, normal_blocks = get_blocks_from_docx()
56
+ if not blocks or not normal_blocks:
57
+ blocks = ["База знаний пуста: проверьте содержимое и структуру вашего .docx!"]
58
+ normal_blocks = ["База знаний пуста: проверьте содержимое и структуру вашего .docx!"]
59
+
60
+ vectorizer = TfidfVectorizer(lowercase=True).fit(blocks)
61
+ matrix = vectorizer.transform(blocks)
62
+
63
+ tokenizer = T5Tokenizer.from_pretrained("cointegrated/rut5-base-multitask")
64
+ model = T5ForConditionalGeneration.from_pretrained("cointegrated/rut5-base-multitask")
65
+ model.eval()
66
+ device = 'cpu'
67
+
68
+ def rut5_answer(question, context):
69
+ prompt = f"question: {question} context: {context}"
70
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
71
+ with torch.no_grad():
72
+ output_ids = model.generate(
73
+ input_ids,
74
+ max_length=250, num_beams=4, min_length=40,
75
+ no_repeat_ngram_size=3, do_sample=False
76
+ )
77
+ return tokenizer.decode(output_ids[0], skip_special_tokens=True)
78
+
79
+ def flatten_index(idx):
80
+ # Универсальный способ из всего достать int
81
+ if isinstance(idx, (int, float, np.integer, np.floating)):
82
+ return int(idx)
83
+ if isinstance(idx, (list, tuple, np.ndarray)):
84
+ if len(idx) == 0:
85
+ return 0
86
+ return flatten_index(idx)
87
+ if hasattr(idx, "tolist"):
88
+ item = idx.tolist()
89
+ return flatten_index(item)
90
+ try:
91
+ return int(idx)
92
+ except Exception:
93
+ return 0
94
+
95
+ def ask_chatbot(question):
96
+ question = question.strip()
97
+ if not question:
98
+ return "Пожалуйста, введите вопрос."
99
+ if not normal_blocks or normal_blocks == ["База знаний пуста: проверьте содержимое и структуру вашего .docx!"]:
100
+ return "Ошибка: база знаний пуста. Проверьте .docx и перезапустите Space."
101
+
102
+ user_vec = vectorizer.transform([question.lower()])
103
+ sims = cosine_similarity(user_vec, matrix)[0]
104
+ n_blocks = min(3, len(blocks))
105
+ if n_blocks == 0:
106
+ return "Ошибка: база знаний отсутствует или пуста."
107
+ sorted_idxs = sims.argsort()[-n_blocks:][::-1]
108
+ context_blocks = []
109
+ for idx in sorted_idxs:
110
+ idx_int = flatten_index(idx)
111
+ if isinstance(idx_int, int) and 0 <= idx_int < len(blocks):
112
+ context_blocks.append(blocks[idx_int])
113
+ context = " ".join(context_blocks)
114
+ # Ответ только из абзацев, не заголовков!
115
+ best_normal_block = ""
116
+ max_sim = -1
117
+ for nb in normal_blocks:
118
+ v_nb = vectorizer.transform([nb.lower()])
119
+ sim = cosine_similarity(user_vec, v_nb)[0]
120
+ if sim > max_sim:
121
+ max_sim = sim
122
+ best_normal_block = nb
123
+ if not best_normal_block:
124
+ best_normal_block = context_blocks if context_blocks else ""
125
+ answer = rut5_answer(question, context)
126
+ if len(answer.strip().split()) < 8 or answer.count('.') < 2:
127
+ answer += "\n\n" + best_normal_block
128
+ if is_header(answer):
129
+ answer = best_normal_block
130
+ return answer
131
+
132
+ EXAMPLES = [
133
+ "Как оформить список литературы?",
134
+ "Какие сроки сдачи и защиты ВКР?",
135
+ "Какой процент оригинальности требуется?",
136
+ "Как оформлять формулы?"
137
+ ]
138
+
139
+ with gr.Blocks() as demo:
140
+ gr.Markdown(
141
+ "# Русскоязычный Чат-бот по методичке (AI+документ)\nЗадайте вопрос — получите развернутый ответ на основании вашего документа!"
142
+ )
143
+ question = gr.Textbox(label="Ваш вопрос", lines=2)
144
+ ask_btn = gr.Button("Получить ответ")
145
+ answer = gr.Markdown(label="Ответ", visible=True)
146
+
147
+ def with_spinner(q):
148
+ yield "Чат-бот думает..."
149
+ yield ask_chatbot(q)
150
+
151
+ ask_btn.click(with_spinner, question, answer)
152
+ question.submit(with_spinner, question, answer)
153
+ gr.Markdown("#### Примеры вопросов:")
154
+ gr.Examples(EXAMPLES, inputs=question)
155
+ gr.Markdown("""
156
+ ---
157
+ ### Контакты (укажите свои)
158
+ Преподаватель: ___________________
159
+ Email: ___________________________
160
+ Кафедра: _________________________
161
+ """)
162
+
163
+ demo.launch()