antimoda1 commited on
Commit
7a668f2
·
1 Parent(s): ee60fb3

update logic

Browse files
Files changed (3) hide show
  1. _1_get_documents.py +18 -9
  2. app.py +22 -137
  3. retrieval.py +1 -1
_1_get_documents.py CHANGED
@@ -8,16 +8,25 @@ def get_text(inst):
8
  if isinstance(inst, dict):
9
  return get_text(inst['text'])
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def load_and_process_data() -> list[dict]:
13
  """Загрузка и предобработка данных из JSON файлов"""
14
- all_messages = []
15
- doc_names = os.listdir('texts')
16
- txt_paths = ['texts/'+file for file in doc_names]
17
- for file_path in txt_paths:
18
- with open(file_path, 'r', encoding='utf-8-sig') as f:
19
- text = f.read()
20
- assert text
21
- all_messages.append(text)
22
 
23
- return all_messages, [x[:-3] for x in doc_names] # убираем расширение .md из имен документов
 
8
  if isinstance(inst, dict):
9
  return get_text(inst['text'])
10
 
11
+
12
+ def process_file(file_path):
13
+ with open(file_path, 'r', encoding='utf-8-sig') as f:
14
+ text = f.read()
15
+ assert text
16
+ return str(file_path).split('.')[-1], text
17
+
18
+ def process_folder_recursive(folder_path):
19
+ all_messages = []
20
+ for file in os.listdir(folder_path):
21
+ file_path = os.path.join(folder_path, file)
22
+ if os.path.isfile(file_path):
23
+ all_messages.append(process_file(file_path))
24
+ else:
25
+ all_messages += process_folder_recursive(file_path)
26
+ return all_messages
27
 
28
  def load_and_process_data() -> list[dict]:
29
  """Загрузка и предобработка данных из JSON файлов"""
30
+ all_messages = process_folder_recursive('texts')
 
 
 
 
 
 
 
31
 
32
+ return [x[0] for x in all_messages], [x[1][:-3] for x in all_messages] # возвращаем расширения и тексты документов
app.py CHANGED
@@ -1,8 +1,5 @@
 
1
  import gradio as gr
2
- import numpy as np
3
- import plotly.express as px
4
- import plotly.graph_objects as go
5
- import pandas as pd
6
  from generation import wrap_prompt
7
  from llm import get_llm_answer
8
  from retrieval import Retrieval
@@ -14,90 +11,6 @@ vocabulary = parse_vocabulary('vocabulary/vocabulary.md')
14
  retrieval = Retrieval()
15
 
16
 
17
- def create_heatmap(scores, chunk_ids, top_k_indices=None):
18
- """Создает heatmap релевантности документов по чанкам"""
19
- if len(scores) == 0:
20
- return go.Figure()
21
-
22
- # Группируем чанки по документам
23
- docs_chunks = {}
24
- chunk_to_doc_map = {}
25
-
26
- for idx, (chunk_id, score) in enumerate(zip(chunk_ids, scores)):
27
- doc_id = retrieval.docs_metadata[chunk_id]
28
- doc_name = retrieval.docs_names[doc_id]
29
- chunk_to_doc_map[chunk_id] = doc_name
30
-
31
- if doc_name not in docs_chunks:
32
- docs_chunks[doc_name] = []
33
-
34
- # Сохраняем информацию о чанке
35
- docs_chunks[doc_name].append({
36
- 'absolute_idx': idx,
37
- 'chunk_id': chunk_id,
38
- 'score': score,
39
- 'in_top_k': top_k_indices is not None and idx in top_k_indices
40
- })
41
-
42
- if not docs_chunks:
43
- return go.Figure()
44
-
45
- # Сортируем чанки внутри каждого документа по chunk_id
46
- for doc_name in docs_chunks:
47
- docs_chunks[doc_name].sort(key=lambda x: x['chunk_id'])
48
-
49
- # Создаем DataFrame для heatmap с относительными номерами чанков
50
- df_data = []
51
- for doc_name, chunks in docs_chunks.items():
52
- for relative_idx, chunk_info in enumerate(chunks):
53
- df_data.append({
54
- 'Документ': doc_name,
55
- 'Чанк (внутри документа)': f'Чанк {relative_idx + 1}',
56
- 'Релевантность': chunk_info['score'],
57
- 'Абсолютный ID': chunk_info['chunk_id'],
58
- 'В top-k': chunk_info['in_top_k']
59
- })
60
-
61
- df = pd.DataFrame(df_data)
62
-
63
- # Создаем heatmap
64
- fig = px.density_heatmap(
65
- df,
66
- x='Чанк (внутри документа)',
67
- y='Документ',
68
- z='Релевантность',
69
- title='Heatmap релевантности (по документам, с относительными номерами чанков)',
70
- color_continuous_scale='Viridis',
71
- labels={'Релевантность': 'Score'}
72
- )
73
-
74
- # Добавляем обводку для top-k чанков
75
- top_k_df = df[df['В top-k'] == True]
76
- if not top_k_df.empty:
77
- fig.add_trace(go.Scatter(
78
- x=top_k_df['Чанк (внутри документа)'],
79
- y=top_k_df['Документ'],
80
- mode='markers',
81
- marker=dict(
82
- symbol='circle-open',
83
- size=20,
84
- line=dict(color='red', width=2),
85
- color='rgba(0,0,0,0)'
86
- ),
87
- name='Top-k чанки',
88
- showlegend=True
89
- ))
90
-
91
- fig.update_layout(
92
- xaxis={'side': 'bottom', 'tickangle': -45},
93
- height=max(400, len(docs_chunks) * 30), # Адаптивная высота
94
- width=800,
95
- xaxis_title="Номер чанка в документе",
96
- yaxis_title="Документ"
97
- )
98
-
99
- return fig
100
-
101
  def perform_search(query, top_k, year_from, year_to):
102
  """Этап 1: Поиск и возврат результатов с фильтром по датам"""
103
 
@@ -125,6 +38,25 @@ def perform_search(query, top_k, year_from, year_to):
125
 
126
  # Выполняем поиск BM25
127
  scores = retrieval.bm25_search(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  # Получаем индексы чанков
130
  chunk_ids = list(range(len(scores)))
@@ -147,29 +79,6 @@ def perform_search(query, top_k, year_from, year_to):
147
 
148
  return scores, chunk_ids, top_k_indices, status
149
 
150
- def filter_chunks_by_documents(top_k_indices, all_scores, selected_docs):
151
- """Фильтрует чанки по выбранным документам"""
152
- if len(top_k_indices)==0 or len(all_scores)==0:
153
- return []
154
-
155
- filtered_indices = []
156
- for idx in top_k_indices:
157
- if idx >= len(retrieval.docs_metadata):
158
- continue
159
-
160
- doc_id = retrieval.docs_metadata[idx]
161
- doc_name = retrieval.docs_names[doc_id] if doc_id < len(retrieval.docs_names) else "Неизвестный документ"
162
-
163
- # Если документы выбраны, проверяем наличие в списке
164
- if selected_docs and len(selected_docs) > 0:
165
- if doc_name in selected_docs:
166
- filtered_indices.append(idx)
167
- else:
168
- # Если ничего не выбрано, показываем все
169
- filtered_indices.append(idx)
170
-
171
- return filtered_indices
172
-
173
  def format_selected_chunks(selected_indices):
174
  """Форматирует выбранные чанки в единый текст для вывода и LLM
175
 
@@ -352,8 +261,7 @@ def ask_llm(query, filtered_indices_state):
352
 
353
  # Создаем интерфейс Gradio
354
  with gr.Blocks(title="RAG Application", theme=gr.themes.Soft()) as iface:
355
- gr.Markdown("# RAG Application для исторических документов")
356
- gr.Markdown("## Двухэтапная работа с документами")
357
 
358
  # Строка 1: поиск и фильтр по датам
359
  with gr.Row():
@@ -406,14 +314,6 @@ with gr.Blocks(title="RAG Application", theme=gr.themes.Soft()) as iface:
406
  )
407
 
408
  with gr.Row():
409
- with gr.Column(scale=1):
410
- # Фильтр ПОСЛЕ поиска для документов
411
- docs_after = gr.CheckboxGroup(
412
- choices=retrieval.docs_names,
413
- label="Фильтр по документам",
414
- info="Выберите документы (если ничего не выбрано - показываются все)"
415
- )
416
-
417
  with gr.Column(scale=2):
418
  # Большое текстовое поле для результатов retrieval
419
  retrieval_results = gr.Textbox(
@@ -454,27 +354,12 @@ with gr.Blocks(title="RAG Application", theme=gr.themes.Soft()) as iface:
454
  fn=perform_search,
455
  inputs=[search_query_input, top_k_slider, year_from_input, year_to_input],
456
  outputs=[all_scores_state, all_chunk_ids_state, top_k_indices_state, search_status]
457
- ).then(
458
- fn=filter_chunks_by_documents,
459
- inputs=[top_k_indices_state, all_scores_state, docs_after],
460
- outputs=[filtered_indices_state]
461
  ).then(
462
  fn=format_retrieval_results,
463
  inputs=[filtered_indices_state, top_k_slider],
464
  outputs=[retrieval_results]
465
  )
466
-
467
- # Обработчик изменения фильтра документов
468
- docs_after.change(
469
- fn=filter_chunks_by_documents,
470
- inputs=[top_k_indices_state, all_scores_state, docs_after],
471
- outputs=[filtered_indices_state]
472
- ).then(
473
- fn=format_retrieval_results,
474
- inputs=[filtered_indices_state, top_k_slider],
475
- outputs=[retrieval_results]
476
- )
477
-
478
  # Обработчик изменения слайдера top_k
479
  top_k_slider.change(
480
  fn=format_retrieval_results,
 
1
+ import re
2
  import gradio as gr
 
 
 
 
3
  from generation import wrap_prompt
4
  from llm import get_llm_answer
5
  from retrieval import Retrieval
 
11
  retrieval = Retrieval()
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def perform_search(query, top_k, year_from, year_to):
15
  """Этап 1: Поиск и возврат результатов с фильтром по датам"""
16
 
 
38
 
39
  # Выполняем поиск BM25
40
  scores = retrieval.bm25_search(query)
41
+ scores = list(scores) # Преобразуем в список если это ndarray
42
+
43
+ # Повышаем scores для документов, названия которых содержат паттерн маршрута из query
44
+ # Паттерн: [АМт]\d{2} (буква А, М или Т + две цифры, например А10, М30, Т2)
45
+ pattern = r'[АМт]\d{2}'
46
+ matches = re.findall(pattern, query)
47
+
48
+ if matches:
49
+ max_score = max(scores) if scores else 0
50
+ boost_score = max_score + 1 # Максимальный score + 1
51
+
52
+ for match in set(matches): # Используем set чтобы избежать дубликатов
53
+ # Ищем документы, которые содержат этот паттерн в названии
54
+ for doc_id, doc_name in enumerate(retrieval.docs_names):
55
+ if match in doc_name:
56
+ # Повышаем scores всех чанков из этого документа
57
+ for chunk_id, chunk_doc_id in enumerate(retrieval.docs_metadata):
58
+ if chunk_doc_id == doc_id:
59
+ scores[chunk_id] = boost_score
60
 
61
  # Получаем индексы чанков
62
  chunk_ids = list(range(len(scores)))
 
79
 
80
  return scores, chunk_ids, top_k_indices, status
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def format_selected_chunks(selected_indices):
83
  """Форматирует выбранные чанки в единый текст для вывода и LLM
84
 
 
261
 
262
  # Создаем интерфейс Gradio
263
  with gr.Blocks(title="RAG Application", theme=gr.themes.Soft()) as iface:
264
+ gr.Markdown("# Справочник по общественного истории транспорта Рязани")
 
265
 
266
  # Строка 1: поиск и фильтр по датам
267
  with gr.Row():
 
314
  )
315
 
316
  with gr.Row():
 
 
 
 
 
 
 
 
317
  with gr.Column(scale=2):
318
  # Большое текстовое поле для результатов retrieval
319
  retrieval_results = gr.Textbox(
 
354
  fn=perform_search,
355
  inputs=[search_query_input, top_k_slider, year_from_input, year_to_input],
356
  outputs=[all_scores_state, all_chunk_ids_state, top_k_indices_state, search_status]
 
 
 
 
357
  ).then(
358
  fn=format_retrieval_results,
359
  inputs=[filtered_indices_state, top_k_slider],
360
  outputs=[retrieval_results]
361
  )
362
+
 
 
 
 
 
 
 
 
 
 
 
363
  # Обработчик изменения слайдера top_k
364
  top_k_slider.change(
365
  fn=format_retrieval_results,
retrieval.py CHANGED
@@ -7,7 +7,7 @@ import warnings
7
  warnings.filterwarnings('ignore')
8
 
9
  from _1_get_documents import load_and_process_data
10
- from _2_splitting import parse_year_metadata, years_overlap, YEAR_OLD, YEAR_NEW
11
  from lemmatizer import RussianLemmatizer
12
  # from _3_chunking import RussianEmbedder
13
 
 
7
  warnings.filterwarnings('ignore')
8
 
9
  from _1_get_documents import load_and_process_data
10
+ from _2_splitting import parse_year_metadata, years_overlap
11
  from lemmatizer import RussianLemmatizer
12
  # from _3_chunking import RussianEmbedder
13