Shoaib-33 committed on
Commit
56fc66b
·
1 Parent(s): 8058e7e

langchain added:

Browse files
Files changed (3) hide show
  1. app/main.py +0 -2
  2. app/rag/bm25.py +37 -42
  3. app/rag/ingestion.py +21 -1
app/main.py CHANGED
@@ -12,7 +12,6 @@ from app.api.routes_review import router as review_router
12
  from app.core.config import settings
13
  from app.core.logging import configure_logging
14
  from app.db.sqlite import init_db
15
- from app.rag.bm25 import BM25Index
16
  from app.rag.ingestion import DocumentIngestionService
17
  from app.rag.qdrant_store import QdrantVectorStore
18
 
@@ -46,7 +45,6 @@ def startup() -> None:
46
  QdrantVectorStore().ensure_collections()
47
  if settings.auto_ingest_pdfs_on_startup:
48
  DocumentIngestionService().ingest_pdf_directory(settings.document_dir)
49
- BM25Index.load_or_create().save()
50
 
51
 
52
  @app.get("/")
 
12
  from app.core.config import settings
13
  from app.core.logging import configure_logging
14
  from app.db.sqlite import init_db
 
15
  from app.rag.ingestion import DocumentIngestionService
16
  from app.rag.qdrant_store import QdrantVectorStore
17
 
 
45
  QdrantVectorStore().ensure_collections()
46
  if settings.auto_ingest_pdfs_on_startup:
47
  DocumentIngestionService().ingest_pdf_directory(settings.document_dir)
 
48
 
49
 
50
  @app.get("/")
app/rag/bm25.py CHANGED
@@ -2,74 +2,69 @@ import json
2
  from pathlib import Path
3
  from typing import Any
4
 
5
- from rank_bm25 import BM25Okapi
 
6
 
7
  from app.core.config import settings
8
- from app.db.sqlite import db
9
  from app.rag.text import tokenize
10
 
11
 
12
  class BM25Index:
13
  def __init__(self, docs: list[dict[str, Any]]) -> None:
14
  self.docs = docs
15
- self.tokens = [tokenize(d["text"]) for d in docs]
16
- self.index = BM25Okapi(self.tokens) if self.tokens else None
17
-
18
- @classmethod
19
- def from_db(cls) -> "BM25Index":
20
- with db() as conn:
21
- rows = conn.execute(
22
- """
23
- SELECT c.chunk_id, c.text, c.metadata_json, d.source_name
24
- FROM chunks c
25
- JOIN documents d ON d.doc_id = c.doc_id
26
- """
27
- ).fetchall()
28
- docs = [
29
- {
30
- "id": row["chunk_id"],
31
- "text": row["text"],
32
- "source_name": row["source_name"],
33
- "metadata": json.loads(row["metadata_json"]),
34
- }
35
- for row in rows
36
  ]
37
- return cls(docs)
 
 
 
38
 
39
  @classmethod
40
  def load_or_create(cls) -> "BM25Index":
41
  path = Path(settings.bm25_index_path)
42
  if not path.exists():
43
- return cls.from_db()
44
  try:
45
  payload = json.loads(path.read_text(encoding="utf-8"))
46
  return cls(payload.get("docs", []))
47
  except (OSError, json.JSONDecodeError):
48
- return cls.from_db()
49
 
50
  def save(self) -> None:
51
  path = Path(settings.bm25_index_path)
52
  path.parent.mkdir(parents=True, exist_ok=True)
53
  path.write_text(json.dumps({"docs": self.docs}, ensure_ascii=True), encoding="utf-8")
54
 
55
- def rebuild(self) -> None:
56
- fresh = self.from_db()
57
  self.docs = fresh.docs
58
- self.tokens = fresh.tokens
59
- self.index = fresh.index
60
  self.save()
61
 
62
  def search(self, query: str, top_k: int) -> list[dict[str, Any]]:
63
- if not self.index or not self.docs:
64
  return []
65
- scores = self.index.get_scores(tokenize(query))
66
- ranked = sorted(enumerate(scores), key=lambda item: item[1], reverse=True)[:top_k]
67
- return [
68
- {
69
- **self.docs[idx],
70
- "score": float(score),
71
- "metadata": {**self.docs[idx].get("metadata", {}), "retriever": "bm25"},
72
- }
73
- for idx, score in ranked
74
- if score > 0
75
- ]
 
 
 
 
 
2
  from pathlib import Path
3
  from typing import Any
4
 
5
+ from langchain_community.retrievers import BM25Retriever
6
+ from langchain_core.documents import Document
7
 
8
  from app.core.config import settings
 
9
  from app.rag.text import tokenize
10
 
11
 
12
  class BM25Index:
13
  def __init__(self, docs: list[dict[str, Any]]) -> None:
14
  self.docs = docs
15
+ documents = [
16
+ Document(
17
+ page_content=doc["text"],
18
+ metadata={
19
+ **doc.get("metadata", {}),
20
+ "id": doc["id"],
21
+ "source_name": doc.get("source_name", "unknown"),
22
+ },
23
+ )
24
+ for doc in docs
 
 
 
 
 
 
 
 
 
 
 
25
  ]
26
+ self.retriever = BM25Retriever.from_documents(
27
+ documents,
28
+ preprocess_func=tokenize,
29
+ ) if documents else None
30
 
31
  @classmethod
32
  def load_or_create(cls) -> "BM25Index":
33
  path = Path(settings.bm25_index_path)
34
  if not path.exists():
35
+ return cls([])
36
  try:
37
  payload = json.loads(path.read_text(encoding="utf-8"))
38
  return cls(payload.get("docs", []))
39
  except (OSError, json.JSONDecodeError):
40
+ return cls([])
41
 
42
  def save(self) -> None:
43
  path = Path(settings.bm25_index_path)
44
  path.parent.mkdir(parents=True, exist_ok=True)
45
  path.write_text(json.dumps({"docs": self.docs}, ensure_ascii=True), encoding="utf-8")
46
 
47
+ def rebuild(self, docs: list[dict[str, Any]] | None = None) -> None:
48
+ fresh = BM25Index(docs or self.docs)
49
  self.docs = fresh.docs
50
+ self.retriever = fresh.retriever
 
51
  self.save()
52
 
53
  def search(self, query: str, top_k: int) -> list[dict[str, Any]]:
54
+ if not self.retriever or not self.docs:
55
  return []
56
+ self.retriever.k = top_k
57
+ results = self.retriever.invoke(query)
58
+ hits = []
59
+ for rank, doc in enumerate(results, start=1):
60
+ metadata = dict(doc.metadata)
61
+ hits.append(
62
+ {
63
+ "id": str(metadata.get("id", f"bm25-{rank}")),
64
+ "text": doc.page_content,
65
+ "source_name": str(metadata.get("source_name", "unknown")),
66
+ "score": 1.0 / rank,
67
+ "metadata": {**metadata, "retriever": "langchain_bm25"},
68
+ }
69
+ )
70
+ return hits
app/rag/ingestion.py CHANGED
@@ -97,6 +97,7 @@ class DocumentIngestionService:
97
  split_docs = self.splitter.split_documents(docs)
98
  chunk_records = []
99
  new_chunk_texts = []
 
100
  skipped_chunks = 0
101
 
102
  for index, split_doc in enumerate(split_docs):
@@ -113,6 +114,14 @@ class DocumentIngestionService:
113
  "source_name": source_name,
114
  "text_hash": text_hash,
115
  }
 
 
 
 
 
 
 
 
116
  if self.cache.chunk_exists(text_hash):
117
  skipped_chunks += 1
118
  self.cache.save_chunk(chunk_id, doc_id, index, chunk, text_hash, chunk_metadata, embedded=False)
@@ -133,7 +142,7 @@ class DocumentIngestionService:
133
  )
134
 
135
  self.qdrant.upsert_chunks(points)
136
- BM25Index.from_db().save()
137
 
138
  return {
139
  "status": "embedded",
@@ -141,3 +150,14 @@ class DocumentIngestionService:
141
  "embedded_chunks": len(points),
142
  "skipped_chunks": skipped_chunks,
143
  }
 
 
 
 
 
 
 
 
 
 
 
 
97
  split_docs = self.splitter.split_documents(docs)
98
  chunk_records = []
99
  new_chunk_texts = []
100
+ bm25_docs = []
101
  skipped_chunks = 0
102
 
103
  for index, split_doc in enumerate(split_docs):
 
114
  "source_name": source_name,
115
  "text_hash": text_hash,
116
  }
117
+ bm25_docs.append(
118
+ {
119
+ "id": chunk_id,
120
+ "text": chunk,
121
+ "source_name": source_name,
122
+ "metadata": chunk_metadata,
123
+ }
124
+ )
125
  if self.cache.chunk_exists(text_hash):
126
  skipped_chunks += 1
127
  self.cache.save_chunk(chunk_id, doc_id, index, chunk, text_hash, chunk_metadata, embedded=False)
 
142
  )
143
 
144
  self.qdrant.upsert_chunks(points)
145
+ self._merge_bm25_docs(bm25_docs)
146
 
147
  return {
148
  "status": "embedded",
 
150
  "embedded_chunks": len(points),
151
  "skipped_chunks": skipped_chunks,
152
  }
153
+
154
+ def _merge_bm25_docs(self, docs: list[dict[str, Any]]) -> None:
155
+ current = BM25Index.load_or_create()
156
+ by_hash = {
157
+ str(doc.get("metadata", {}).get("text_hash") or doc.get("id")): doc
158
+ for doc in current.docs
159
+ }
160
+ for doc in docs:
161
+ key = str(doc.get("metadata", {}).get("text_hash") or doc.get("id"))
162
+ by_hash[key] = doc
163
+ BM25Index(list(by_hash.values())).save()