DolAr1610 committed on
Commit
a5c9fa3
·
1 Parent(s): cde7325

add new logic

Browse files
ingestion/ingest_text.py CHANGED
@@ -3,6 +3,7 @@ from tqdm import tqdm
3
  from db.text_db import init_chroma, add_document_text
4
  from embeddings.text_embedder import get_text_embedding
5
  from ingestion.config import JSON_PATH
 
6
 
7
 
8
  def chunk_text(text, chunk_size=400, overlap=50):
@@ -23,6 +24,8 @@ def ingest_texts():
23
  articles = json.load(f)
24
  print(f"Found {len(articles)} articles.")
25
 
 
 
26
  for article in tqdm(articles, desc="Indexing texts"):
27
  full_text = f"{article.get('title', '')}\n{article.get('description', '')}\n{article.get('content', '')}"
28
 
@@ -38,12 +41,21 @@ def ingest_texts():
38
  chunks = chunk_text(full_text, chunk_size=400, overlap=50)
39
 
40
  for i, chunk in enumerate(chunks):
 
 
 
 
 
 
 
 
 
41
  emb = get_text_embedding(chunk)
42
  if emb:
43
- doc_id = f"{metadata['source_url']}#chunk{i}" if metadata[
44
- "source_url"] else f"{metadata['title']}#chunk{i}"
45
  add_document_text(vectordb, doc_id, emb, chunk, metadata)
46
  else:
47
  print(f"Failed to embed chunk {i} of {metadata['title']}")
48
 
49
- print("Done indexing texts.")
 
 
 
3
  from db.text_db import init_chroma, add_document_text
4
  from embeddings.text_embedder import get_text_embedding
5
  from ingestion.config import JSON_PATH
6
+ from search.bm_25_index import build_and_save_bm25
7
 
8
 
9
  def chunk_text(text, chunk_size=400, overlap=50):
 
24
  articles = json.load(f)
25
  print(f"Found {len(articles)} articles.")
26
 
27
+ bm25_chunks = []
28
+
29
  for article in tqdm(articles, desc="Indexing texts"):
30
  full_text = f"{article.get('title', '')}\n{article.get('description', '')}\n{article.get('content', '')}"
31
 
 
41
  chunks = chunk_text(full_text, chunk_size=400, overlap=50)
42
 
43
  for i, chunk in enumerate(chunks):
44
+ doc_id = f"{metadata['source_url']}#chunk{i}" if metadata["source_url"] else f"{metadata['title']}#chunk{i}"
45
+
46
+
47
+ bm25_chunks.append({
48
+ "chunk_id": doc_id,
49
+ "chunk_text": chunk,
50
+ "metadata": metadata,
51
+ })
52
+
53
  emb = get_text_embedding(chunk)
54
  if emb:
 
 
55
  add_document_text(vectordb, doc_id, emb, chunk, metadata)
56
  else:
57
  print(f"Failed to embed chunk {i} of {metadata['title']}")
58
 
59
+ build_and_save_bm25(bm25_chunks)
60
+
61
+ print("Done indexing texts.")
llm.py CHANGED
@@ -7,22 +7,43 @@ OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
7
 
8
 
9
  def generate_response(question, retrieved_docs, model="meta-llama/llama-3-8b-instruct"):
10
- context = "\n\n".join(
11
- f"Title: {doc.get('title', 'N/A')}\n"
12
- f"Description: {doc.get('description', 'N/A')}\n"
13
- f"Content: {doc.get('content', 'N/A')}\n"
14
- for doc in retrieved_docs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  )
16
 
17
- prompt = (
18
- "You are a polite assistant who provides clear and detailed answers based solely on the information from The Batch articles.\n\n"
19
- "Rules:\n"
20
- "- Answer only using the knowledge from The Batch articles.\n"
21
- "- Do not mention other sources or questions; provide only accurate, detailed, and understandable answers.\n"
22
- "- If the information is present in the context, give a clear answer.\n"
23
- "- If the information is missing, respond with: 'Sorry, I could not find the answer in the provided context.'\n"
24
- "- Do not guess, fabricate information, or go beyond the given context.\n\n"
25
- f"Context for the answer:\n{context}"
26
  )
27
 
28
  headers = {
@@ -33,13 +54,18 @@ def generate_response(question, retrieved_docs, model="meta-llama/llama-3-8b-ins
33
  data = {
34
  "model": model,
35
  "messages": [
36
- {"role": "system", "content": prompt},
37
- {"role": "user", "content": question}
38
  ],
39
- "temperature": 0.3
40
  }
41
 
42
- response = requests.post("https://openrouter.ai/api/v1/chat/completions", headers=headers, json=data)
 
 
 
 
 
43
 
44
  if response.status_code == 200:
45
  return response.json()['choices'][0]['message']['content'].strip()
 
7
 
8
 
9
  def generate_response(question, retrieved_docs, model="meta-llama/llama-3-8b-instruct"):
10
+ # 1) Number the sources so the model can cite them as [1], [2], ...
11
+ sources_lines = []
12
+ for i, doc in enumerate(retrieved_docs, start=1):
13
+ title = doc.get("title", "N/A")
14
+ desc = doc.get("description", "")
15
+ content = doc.get("content", "")
16
+
17
+ # cap the content length a bit (so we don't blow up the context window)
18
+ content = (content or "")[:2000]
19
+
20
+ sources_lines.append(
21
+ f"[{i}] Title: {title}\n"
22
+ f"Description: {desc}\n"
23
+ f"Content: {content}\n"
24
+ )
25
+
26
+ sources = "\n\n".join(sources_lines).strip()
27
+
28
+ system_prompt = (
29
+ "You are a Retrieval-Augmented Question Answering assistant for The Batch articles.\n"
30
+ "Answer the user ONLY using the SOURCES provided.\n\n"
31
+ "Hard rules:\n"
32
+ "1) Use ONLY facts that appear in the SOURCES. Do NOT use outside knowledge.\n"
33
+ "2) Every factual claim MUST have a citation like [1] or [2].\n"
34
+ " - If a sentence contains multiple facts from different sources, cite all relevant sources: [1][3].\n"
35
+ "3) If the SOURCES do not contain enough information to answer, say:\n"
36
+ " \"Sorry, I could not find the answer in the provided sources.\" (and do not add citations)\n"
37
+ "4) Do not invent titles, dates, links, or quotes.\n"
38
+ "5) Keep the answer concise and clear.\n\n"
39
+ "Output format:\n"
40
+ "Answer: <your answer with citations>\n"
41
+ "Used sources: <list of source numbers you actually cited, e.g. [1], [3]>\n"
42
  )
43
 
44
+ user_prompt = (
45
+ f"SOURCES:\n{sources}\n\n"
46
+ f"QUESTION:\n{question}\n"
 
 
 
 
 
 
47
  )
48
 
49
  headers = {
 
54
  data = {
55
  "model": model,
56
  "messages": [
57
+ {"role": "system", "content": system_prompt},
58
+ {"role": "user", "content": user_prompt}
59
  ],
60
+ "temperature": 0.2
61
  }
62
 
63
+ response = requests.post(
64
+ "https://openrouter.ai/api/v1/chat/completions",
65
+ headers=headers,
66
+ json=data,
67
+ timeout=60
68
+ )
69
 
70
  if response.status_code == 200:
71
  return response.json()['choices'][0]['message']['content'].strip()
main.py CHANGED
@@ -1,8 +1,18 @@
1
  import streamlit as st
2
- from search.search_classical import classical_search
 
3
  from search.search_best_pair import best_pair_search
4
  from llm import generate_response
5
 
 
 
 
 
 
 
 
 
 
6
  st.set_page_config(page_title="🔍 Multimodal Search The Batch")
7
  st.image("data/the-batch-logo.webp", width=300)
8
  st.title("Multimodal Assistant")
@@ -10,9 +20,24 @@ st.title("Multimodal Assistant")
10
  mode = st.selectbox("🔎 Select the search mode:", ["Classical RAG", "Multimodal RAG"])
11
  query = st.text_input("📝 Enter the text query:")
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  if query:
14
  if mode == "Classical RAG":
15
- results = classical_search(query, k=3)
 
16
  else:
17
  results = best_pair_search(query, k=3)
18
 
@@ -35,16 +60,86 @@ if query:
35
  st.markdown(f"[🔗 Read the full article →]({meta['source_url']})")
36
  st.markdown("---")
37
 
38
- if st.button("🧠 Generate a response to a query"):
39
- docs = [
40
- {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  "title": meta.get("title", ""),
42
  "description": meta.get("description", ""),
43
- "content": meta.get("content", "")
44
- }
45
- for meta in results
46
- ]
 
47
 
48
  response = generate_response(query, docs)
49
  st.markdown("### 🤖 Generated Response:")
50
- st.success(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+
3
+ from search.search_classical import classical_search, classical_retrieve_chunks
4
  from search.search_best_pair import best_pair_search
5
  from llm import generate_response
6
 
7
+
8
def pick_mode(label: str) -> str:
    """Map a UI retriever label to the internal search-mode key.

    Labels starting with "Semantic" map to "semantic", labels starting
    with "Keyword" map to "bm25"; anything else falls back to "hybrid".
    """
    prefix_to_mode = (
        ("Semantic", "semantic"),
        ("Keyword", "bm25"),
    )
    for prefix, mode in prefix_to_mode:
        if label.startswith(prefix):
            return mode
    return "hybrid"
14
+
15
+
16
  st.set_page_config(page_title="🔍 Multimodal Search The Batch")
17
  st.image("data/the-batch-logo.webp", width=300)
18
  st.title("Multimodal Assistant")
 
20
  mode = st.selectbox("🔎 Select the search mode:", ["Classical RAG", "Multimodal RAG"])
21
  query = st.text_input("📝 Enter the text query:")
22
 
23
+ # --- Classical controls ---
24
+ classical_retriever = "Semantic (embeddings)"
25
+ use_reranker = True
26
+
27
+ if mode == "Classical RAG":
28
+ classical_retriever = st.radio(
29
+ "🧩 Classical retrieval:",
30
+ ["Semantic (embeddings)", "Keyword (BM25)", "Hybrid (BM25 + Semantic)"],
31
+ horizontal=True
32
+ )
33
+ use_reranker = st.checkbox("✨ Use reranker (cross-encoder)", value=True)
34
+
35
+ # --- Preview results ---
36
+ results = []
37
  if query:
38
  if mode == "Classical RAG":
39
+ search_mode = pick_mode(classical_retriever)
40
+ results = classical_search(query, k=3, mode=search_mode)
41
  else:
42
  results = best_pair_search(query, k=3)
43
 
 
60
  st.markdown(f"[🔗 Read the full article →]({meta['source_url']})")
61
  st.markdown("---")
62
 
63
+ # --- Generate answer ---
64
+ if query and st.button("🧠 Generate a response to a query"):
65
+
66
+ if mode == "Classical RAG":
67
+ search_mode = pick_mode(classical_retriever)
68
+
69
+ chunks = classical_retrieve_chunks(
70
+ query=query,
71
+ mode=search_mode,
72
+ fetch_k=50,
73
+ rerank_k=5,
74
+ use_reranker=use_reranker
75
+ )
76
+
77
+ docs = []
78
+ for idx, c in enumerate(chunks, start=1):
79
+ meta = c.get("metadata", {})
80
+ docs.append({
81
+ "id": idx,
82
  "title": meta.get("title", ""),
83
  "description": meta.get("description", ""),
84
+ "source_url": meta.get("source_url", ""),
85
+ "content": c.get("chunk_text", ""),
86
+ "retriever": c.get("retriever", ""),
87
+ "rerank_score": c.get("rerank_score", None),
88
+ })
89
 
90
  response = generate_response(query, docs)
91
  st.markdown("### 🤖 Generated Response:")
92
+ st.success(response)
93
+
94
+ st.markdown("### 📌 Sources")
95
+ for d in docs:
96
+ st.markdown(f"**[{d['id']}] {d.get('title','')}**")
97
+ if d.get("source_url"):
98
+ st.markdown(d["source_url"])
99
+ st.write((d.get("content") or "")[:450] + "...")
100
+ if d.get("retriever"):
101
+ st.caption(f"retriever: {d['retriever']}")
102
+ if d.get("rerank_score") is not None:
103
+ st.caption(f"rerank_score: {d['rerank_score']:.4f}")
104
+ st.markdown("---")
105
+
106
+ else:
107
+ # ✅ Multimodal mode:
108
+ # Preview stays multimodal (best_pair_search),
109
+ # but the ANSWER is generated from TEXT chunks (hybrid) for reliable QA + citations.
110
+ chunks = classical_retrieve_chunks(
111
+ query=query,
112
+ mode="hybrid",
113
+ fetch_k=50,
114
+ rerank_k=5,
115
+ use_reranker=True
116
+ )
117
+
118
+ docs = []
119
+ for idx, c in enumerate(chunks, start=1):
120
+ meta = c.get("metadata", {})
121
+ docs.append({
122
+ "id": idx,
123
+ "title": meta.get("title", ""),
124
+ "description": meta.get("description", ""),
125
+ "source_url": meta.get("source_url", ""),
126
+ "content": c.get("chunk_text", ""),
127
+ "retriever": c.get("retriever", ""),
128
+ "rerank_score": c.get("rerank_score", None),
129
+ })
130
+
131
+ response = generate_response(query, docs)
132
+ st.markdown("### 🤖 Generated Response:")
133
+ st.success(response)
134
+
135
+ st.markdown("### 📌 Sources (text chunks)")
136
+ for d in docs:
137
+ st.markdown(f"**[{d['id']}] {d.get('title','')}**")
138
+ if d.get("source_url"):
139
+ st.markdown(d["source_url"])
140
+ st.write((d.get("content") or "")[:450] + "...")
141
+ if d.get("retriever"):
142
+ st.caption(f"retriever: {d['retriever']}")
143
+ if d.get("rerank_score") is not None:
144
+ st.caption(f"rerank_score: {d['rerank_score']:.4f}")
145
+ st.markdown("---")
requirements.txt CHANGED
@@ -3,7 +3,7 @@ langchain
3
  sentence-transformers
4
  transformers
5
  torch
6
- chromadb==0.4.22
7
  nltk
8
  requests
9
  tqdm
@@ -13,4 +13,5 @@ selenium
13
  webdriver-manager
14
  langchain-community
15
  emoji
16
- numpy==1.26.4
 
 
3
  sentence-transformers
4
  transformers
5
  torch
6
+ chromadb==1.3.6
7
  nltk
8
  requests
9
  tqdm
 
13
  webdriver-manager
14
  langchain-community
15
  emoji
16
+ numpy==1.26.4
17
+ rank-bm25
search/bm_25_index.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import re
4
+ from rank_bm25 import BM25Okapi
5
+
6
+ BM25_PATH = "bm_25_index/bm25.pkl"
7
+
8
def tokenize(text: str):
    """Lowercase *text* and return its alphanumeric tokens.

    A None or empty input yields an empty token list.
    """
    normalized = text.lower() if text else ""
    return re.findall(r"[a-z0-9]+", normalized)
11
+
12
def build_and_save_bm25(chunks: list[dict], path: str = BM25_PATH) -> None:
    """Build a BM25 index over *chunks* and pickle it to *path*.

    Args:
        chunks: list of dicts shaped like
            {"chunk_id": str, "chunk_text": str, "metadata": dict}.
        path: destination pickle file; parent directories are created
            as needed.

    Raises:
        ValueError: if *chunks* is empty — BM25Okapi cannot index an
            empty corpus and would otherwise fail with an opaque
            ZeroDivisionError.
    """
    if not chunks:
        raise ValueError("cannot build a BM25 index from an empty chunk list")

    # os.path.dirname() returns "" for a bare filename, and makedirs("")
    # raises — only create a directory when the path actually has one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    corpus_tokens = [tokenize(c["chunk_text"]) for c in chunks]
    bm25 = BM25Okapi(corpus_tokens)

    # Keep the raw chunks next to the index so a search can return
    # text + metadata without a second lookup.
    payload = {
        "bm25": bm25,
        "chunks": chunks,
    }

    with open(path, "wb") as f:
        pickle.dump(payload, f)
search/bm_25_search.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import re
3
+
4
+ BM25_PATH = "bm_25_index/bm25.pkl"
5
+
6
def tokenize(text: str):
    """Split *text* into lowercase alphanumeric tokens ([] for None/empty)."""
    lowered = (text or "").lower()
    return re.findall(r"[a-z0-9]+", lowered)
8
+
9
# Process-level cache: the pickled index is immutable at query time, so
# there is no reason to re-read it from disk on every search.
_bm25_payload_cache = None


def _load_bm25_payload(path: str = BM25_PATH) -> dict:
    """Load and memoize the pickled BM25 payload ({"bm25", "chunks"})."""
    global _bm25_payload_cache
    if _bm25_payload_cache is None:
        with open(path, "rb") as f:
            _bm25_payload_cache = pickle.load(f)
    return _bm25_payload_cache


def bm25_search(query: str, k: int = 50):
    """Return the top-*k* chunks for *query* ranked by BM25 score.

    Args:
        query: free-text query; tokenized the same way as the corpus.
        k: maximum number of results to return.

    Returns:
        list[dict]: [{"chunk_id", "chunk_text", "metadata", "score"}, ...]
        sorted by descending BM25 score.
    """
    payload = _load_bm25_payload()
    bm25 = payload["bm25"]
    chunks = payload["chunks"]

    scores = bm25.get_scores(tokenize(query))
    # sorted() is stable, so equal scores keep their corpus order.
    top_idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

    return [{
        "chunk_id": chunks[i]["chunk_id"],
        "chunk_text": chunks[i]["chunk_text"],
        "metadata": chunks[i]["metadata"],
        "score": float(scores[i]),
    } for i in top_idx]
search/reranker.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import CrossEncoder
2
+
3
+ # a lightweight, widely used cross-encoder model for reranking
4
+ _MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
5
+ _ce = None
6
+
7
def rerank(query: str, chunks: list[dict], top_k: int = 5) -> list[dict]:
    """Re-score *chunks* against *query* with a cross-encoder.

    chunks: [{ "chunk_id":..., "chunk_text":..., "metadata":..., ... }, ...]
    Returns the *top_k* best chunks, each annotated with "rerank_score".

    NOTE: mutates the caller's list — scores are written into the dicts
    and the list is sorted in place — matching the original contract.
    """
    global _ce
    # Lazily instantiate the model so importing this module stays cheap.
    if _ce is None:
        _ce = CrossEncoder(_MODEL_NAME)

    query_passage_pairs = [
        (query, chunk.get("chunk_text", "")) for chunk in chunks
    ]
    predictions = _ce.predict(query_passage_pairs)

    for chunk, score in zip(chunks, predictions):
        chunk["rerank_score"] = float(score)

    chunks.sort(key=lambda chunk: chunk.get("rerank_score", 0.0), reverse=True)
    return chunks[:top_k]
search/search_classical.py CHANGED
@@ -1,18 +1,85 @@
1
  from db.text_db import init_chroma
2
  from embeddings.text_embedder import get_text_embedding
 
 
 
3
 
4
 
5
- def classical_search(query, k=5):
6
- db = init_chroma()
7
- emb = get_text_embedding(query)
8
- results = db.similarity_search_by_vector(emb, k=k)
9
- articles = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  seen = set()
 
 
 
 
 
 
 
11
 
12
- for r in results:
13
- meta = r.metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  aid = meta.get("source_url") or meta.get("title")
15
- if aid not in seen:
16
  seen.add(aid)
17
  articles.append({
18
  "title": meta.get("title"),
@@ -22,5 +89,4 @@ def classical_search(query, k=5):
22
  "content": meta.get("content"),
23
  "source_url": meta.get("source_url"),
24
  })
25
-
26
- return articles
 
1
  from db.text_db import init_chroma
2
  from embeddings.text_embedder import get_text_embedding
3
+ from search.bm_25_search import bm25_search
4
+ from search.reranker import rerank
5
+ import re
6
 
7
 
8
def classical_retrieve_chunks(
    query: str,
    mode: str = "semantic",
    fetch_k: int = 50,
    rerank_k: int = 5,
    use_reranker: bool = True,
    year_filter: int | None = None
) -> list[dict]:
    """Retrieve text chunks for *query* via semantic, BM25, or hybrid search.

    Args:
        query: user query string.
        mode: "semantic" (Chroma embeddings), "bm25" (keyword index), or
            "hybrid" (union of both, deduplicated).
        fetch_k: number of candidates pulled from each retriever.
        rerank_k: number of chunks returned after (optional) reranking.
        use_reranker: when True, re-score candidates with the cross-encoder.
        year_filter: reserved for future metadata filtering.

    Returns:
        [{chunk_id, chunk_text, metadata, retriever, score?, rerank_score?}, ...]

    Raises:
        NotImplementedError: if *year_filter* is given — it was previously
            accepted but silently ignored; fail loudly instead.
    """
    if year_filter is not None:
        raise NotImplementedError("year_filter is not implemented yet")

    chunks = []

    if mode in ("semantic", "hybrid"):
        db = init_chroma()
        emb = get_text_embedding(query)
        dense_res = db.similarity_search_by_vector(emb, k=fetch_k)

        for i, r in enumerate(dense_res):
            chunks.append({
                "chunk_id": f"semantic_{i}",
                "chunk_text": r.page_content,
                "metadata": r.metadata,
                "retriever": "semantic",
                "score": None,  # Chroma's raw similarity is not surfaced here
            })

    if mode in ("bm25", "hybrid"):
        bm25_res = bm25_search(query, k=fetch_k)
        for r in bm25_res:
            r["retriever"] = "bm25"
            chunks.append(r)

    # Deduplicate near-identical chunks: the first 200 characters of the
    # text serve as a cheap fingerprint (semantic and BM25 hits of the
    # same chunk carry different ids, so text is the common key).
    seen = set()
    unique_chunks = []
    for c in chunks:
        key = c["chunk_text"][:200]
        if key not in seen:
            seen.add(key)
            unique_chunks.append(c)
    chunks = unique_chunks

    if use_reranker and chunks:
        chunks = rerank(query, chunks, top_k=rerank_k)
    else:
        chunks = chunks[:rerank_k]

    return chunks
61
+
62
+
63
+ def classical_search(query, k=5, mode="semantic"):
64
+ """
65
+ Article-level results for the UI (same shape as before).
66
+ """
67
+ chunks = classical_retrieve_chunks(
68
+ query=query,
69
+ mode=mode,
70
+ fetch_k=max(50, k * 20),
71
+ rerank_k=max(10, k * 5),
72
+ use_reranker=False, # no reranker needed for the article list
73
+ year_filter=None
74
+ )
75
+
76
+ # deduplicate by article
77
+ articles = []
78
+ seen = set()
79
+ for c in chunks:
80
+ meta = c["metadata"]
81
  aid = meta.get("source_url") or meta.get("title")
82
+ if aid and aid not in seen:
83
  seen.add(aid)
84
  articles.append({
85
  "title": meta.get("title"),

89
  "content": meta.get("content"),
90
  "source_url": meta.get("source_url"),
91
  })
92
+ return articles[:k]