Spaces:
Sleeping
Sleeping
langchain added:
Browse files
- app/main.py +0 -2
- app/rag/bm25.py +37 -42
- app/rag/ingestion.py +21 -1
app/main.py
CHANGED
|
@@ -12,7 +12,6 @@ from app.api.routes_review import router as review_router
|
|
| 12 |
from app.core.config import settings
|
| 13 |
from app.core.logging import configure_logging
|
| 14 |
from app.db.sqlite import init_db
|
| 15 |
-
from app.rag.bm25 import BM25Index
|
| 16 |
from app.rag.ingestion import DocumentIngestionService
|
| 17 |
from app.rag.qdrant_store import QdrantVectorStore
|
| 18 |
|
|
@@ -46,7 +45,6 @@ def startup() -> None:
|
|
| 46 |
QdrantVectorStore().ensure_collections()
|
| 47 |
if settings.auto_ingest_pdfs_on_startup:
|
| 48 |
DocumentIngestionService().ingest_pdf_directory(settings.document_dir)
|
| 49 |
-
BM25Index.load_or_create().save()
|
| 50 |
|
| 51 |
|
| 52 |
@app.get("/")
|
|
|
|
| 12 |
from app.core.config import settings
|
| 13 |
from app.core.logging import configure_logging
|
| 14 |
from app.db.sqlite import init_db
|
|
|
|
| 15 |
from app.rag.ingestion import DocumentIngestionService
|
| 16 |
from app.rag.qdrant_store import QdrantVectorStore
|
| 17 |
|
|
|
|
| 45 |
QdrantVectorStore().ensure_collections()
|
| 46 |
if settings.auto_ingest_pdfs_on_startup:
|
| 47 |
DocumentIngestionService().ingest_pdf_directory(settings.document_dir)
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
@app.get("/")
|
app/rag/bm25.py
CHANGED
|
@@ -2,74 +2,69 @@ import json
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
-
from
|
|
|
|
| 6 |
|
| 7 |
from app.core.config import settings
|
| 8 |
-
from app.db.sqlite import db
|
| 9 |
from app.rag.text import tokenize
|
| 10 |
|
| 11 |
|
| 12 |
class BM25Index:
|
| 13 |
def __init__(self, docs: list[dict[str, Any]]) -> None:
|
| 14 |
self.docs = docs
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
JOIN documents d ON d.doc_id = c.doc_id
|
| 26 |
-
"""
|
| 27 |
-
).fetchall()
|
| 28 |
-
docs = [
|
| 29 |
-
{
|
| 30 |
-
"id": row["chunk_id"],
|
| 31 |
-
"text": row["text"],
|
| 32 |
-
"source_name": row["source_name"],
|
| 33 |
-
"metadata": json.loads(row["metadata_json"]),
|
| 34 |
-
}
|
| 35 |
-
for row in rows
|
| 36 |
]
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
@classmethod
|
| 40 |
def load_or_create(cls) -> "BM25Index":
|
| 41 |
path = Path(settings.bm25_index_path)
|
| 42 |
if not path.exists():
|
| 43 |
-
return cls
|
| 44 |
try:
|
| 45 |
payload = json.loads(path.read_text(encoding="utf-8"))
|
| 46 |
return cls(payload.get("docs", []))
|
| 47 |
except (OSError, json.JSONDecodeError):
|
| 48 |
-
return cls
|
| 49 |
|
| 50 |
def save(self) -> None:
|
| 51 |
path = Path(settings.bm25_index_path)
|
| 52 |
path.parent.mkdir(parents=True, exist_ok=True)
|
| 53 |
path.write_text(json.dumps({"docs": self.docs}, ensure_ascii=True), encoding="utf-8")
|
| 54 |
|
| 55 |
-
def rebuild(self) -> None:
|
| 56 |
-
fresh = self.
|
| 57 |
self.docs = fresh.docs
|
| 58 |
-
self.
|
| 59 |
-
self.index = fresh.index
|
| 60 |
self.save()
|
| 61 |
|
| 62 |
def search(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
| 63 |
-
if not self.
|
| 64 |
return []
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
+
from langchain_community.retrievers import BM25Retriever
|
| 6 |
+
from langchain_core.documents import Document
|
| 7 |
|
| 8 |
from app.core.config import settings
|
|
|
|
| 9 |
from app.rag.text import tokenize
|
| 10 |
|
| 11 |
|
| 12 |
class BM25Index:
    """Keyword (BM25) index over document chunks, persisted as JSON.

    Each chunk is a dict with an ``"id"`` and a ``"text"`` entry and,
    optionally, a ``"source_name"`` and a ``"metadata"`` dict. The chunks are
    wrapped in LangChain ``Document`` objects and handed to
    ``BM25Retriever``; the raw dicts are kept in ``self.docs`` so the index
    can be written back to ``settings.bm25_index_path``.
    """

    def __init__(self, docs: list[dict[str, Any]]) -> None:
        """Build the retriever from ``docs``.

        Args:
            docs: Chunk dicts. ``"id"`` and ``"text"`` are required (a missing
                key raises ``KeyError``); ``"source_name"`` defaults to
                ``"unknown"`` and ``"metadata"`` to an empty dict.
        """
        self.docs = docs
        documents = [
            Document(
                page_content=doc["text"],
                metadata={
                    # ``or {}`` also covers an explicit ``"metadata": None``,
                    # which ``doc.get("metadata", {})`` would try to unpack.
                    **(doc.get("metadata") or {}),
                    # Set after the unpack so these two keys always win over
                    # same-named keys in the chunk metadata.
                    "id": doc["id"],
                    "source_name": doc.get("source_name", "unknown"),
                },
            )
            for doc in docs
        ]
        # The retriever is only built for a non-empty corpus; ``search``
        # treats ``None`` as "nothing indexed".
        self.retriever = (
            BM25Retriever.from_documents(documents, preprocess_func=tokenize)
            if documents
            else None
        )

    @classmethod
    def load_or_create(cls) -> "BM25Index":
        """Load the index from ``settings.bm25_index_path``.

        Returns:
            An index built from the stored docs, or an empty index when the
            file is missing, unreadable, not valid UTF-8/JSON, or does not
            contain a JSON object.
        """
        path = Path(settings.bm25_index_path)
        if not path.exists():
            return cls([])
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, UnicodeDecodeError, json.JSONDecodeError):
            # Best effort: an unreadable index file is treated as "no index".
            return cls([])
        if not isinstance(payload, dict):
            # Valid JSON of the wrong shape (e.g. a bare list) has no "docs".
            return cls([])
        # ``or []`` also covers a stored ``"docs": null``.
        return cls(payload.get("docs") or [])

    def save(self) -> None:
        """Write ``self.docs`` as JSON to ``settings.bm25_index_path``.

        Missing parent directories are created first.
        """
        path = Path(settings.bm25_index_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps({"docs": self.docs}, ensure_ascii=True), encoding="utf-8")

    def rebuild(self, docs: list[dict[str, Any]] | None = None) -> None:
        """Rebuild the retriever and persist the result.

        Args:
            docs: Replacement corpus. ``None`` (the default) re-indexes the
                current ``self.docs``. An empty list clears the index.
        """
        # Compare against ``None`` explicitly: ``docs or self.docs`` would
        # ignore a deliberately empty list and keep the old corpus.
        fresh = BM25Index(self.docs if docs is None else docs)
        self.docs = fresh.docs
        self.retriever = fresh.retriever
        self.save()

    def search(self, query: str, top_k: int) -> list[dict[str, Any]]:
        """Return up to ``top_k`` chunks matching ``query``.

        Args:
            query: Free-text query passed to the retriever.
            top_k: Maximum number of hits; non-positive values yield ``[]``.

        Returns:
            Hit dicts with ``"id"``, ``"text"``, ``"source_name"``,
            ``"score"`` and ``"metadata"`` keys, in retriever order. An empty
            list when nothing is indexed.
        """
        if top_k <= 0 or not self.retriever or not self.docs:
            return []
        # NOTE: this mutates the shared retriever instance on every call.
        self.retriever.k = top_k
        results = self.retriever.invoke(query)
        hits = []
        for rank, doc in enumerate(results, start=1):
            metadata = dict(doc.metadata)
            hits.append(
                {
                    "id": str(metadata.get("id", f"bm25-{rank}")),
                    "text": doc.page_content,
                    "source_name": str(metadata.get("source_name", "unknown")),
                    # Score is derived from rank position only: 1, 1/2, 1/3, ...
                    "score": 1.0 / rank,
                    "metadata": {**metadata, "retriever": "langchain_bm25"},
                }
            )
        return hits
|
app/rag/ingestion.py
CHANGED
|
@@ -97,6 +97,7 @@ class DocumentIngestionService:
|
|
| 97 |
split_docs = self.splitter.split_documents(docs)
|
| 98 |
chunk_records = []
|
| 99 |
new_chunk_texts = []
|
|
|
|
| 100 |
skipped_chunks = 0
|
| 101 |
|
| 102 |
for index, split_doc in enumerate(split_docs):
|
|
@@ -113,6 +114,14 @@ class DocumentIngestionService:
|
|
| 113 |
"source_name": source_name,
|
| 114 |
"text_hash": text_hash,
|
| 115 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
if self.cache.chunk_exists(text_hash):
|
| 117 |
skipped_chunks += 1
|
| 118 |
self.cache.save_chunk(chunk_id, doc_id, index, chunk, text_hash, chunk_metadata, embedded=False)
|
|
@@ -133,7 +142,7 @@ class DocumentIngestionService:
|
|
| 133 |
)
|
| 134 |
|
| 135 |
self.qdrant.upsert_chunks(points)
|
| 136 |
-
|
| 137 |
|
| 138 |
return {
|
| 139 |
"status": "embedded",
|
|
@@ -141,3 +150,14 @@ class DocumentIngestionService:
|
|
| 141 |
"embedded_chunks": len(points),
|
| 142 |
"skipped_chunks": skipped_chunks,
|
| 143 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
split_docs = self.splitter.split_documents(docs)
|
| 98 |
chunk_records = []
|
| 99 |
new_chunk_texts = []
|
| 100 |
+
bm25_docs = []
|
| 101 |
skipped_chunks = 0
|
| 102 |
|
| 103 |
for index, split_doc in enumerate(split_docs):
|
|
|
|
| 114 |
"source_name": source_name,
|
| 115 |
"text_hash": text_hash,
|
| 116 |
}
|
| 117 |
+
bm25_docs.append(
|
| 118 |
+
{
|
| 119 |
+
"id": chunk_id,
|
| 120 |
+
"text": chunk,
|
| 121 |
+
"source_name": source_name,
|
| 122 |
+
"metadata": chunk_metadata,
|
| 123 |
+
}
|
| 124 |
+
)
|
| 125 |
if self.cache.chunk_exists(text_hash):
|
| 126 |
skipped_chunks += 1
|
| 127 |
self.cache.save_chunk(chunk_id, doc_id, index, chunk, text_hash, chunk_metadata, embedded=False)
|
|
|
|
| 142 |
)
|
| 143 |
|
| 144 |
self.qdrant.upsert_chunks(points)
|
| 145 |
+
self._merge_bm25_docs(bm25_docs)
|
| 146 |
|
| 147 |
return {
|
| 148 |
"status": "embedded",
|
|
|
|
| 150 |
"embedded_chunks": len(points),
|
| 151 |
"skipped_chunks": skipped_chunks,
|
| 152 |
}
|
| 153 |
+
|
| 154 |
+
def _merge_bm25_docs(self, docs: list[dict[str, Any]]) -> None:
    """Merge ``docs`` into the persisted BM25 index and save it.

    Existing and new chunk dicts are keyed by their ``metadata["text_hash"]``,
    falling back to their ``"id"`` when no hash is present. A new doc replaces
    an existing one with the same key; all other existing docs are kept.

    Args:
        docs: Chunk dicts to add to, or overwrite in, the stored index.
    """

    def dedup_key(doc: dict[str, Any]) -> str:
        # ``or {}`` keeps a doc with ``"metadata": None`` from raising
        # AttributeError on the ``.get`` below.
        metadata = doc.get("metadata") or {}
        return str(metadata.get("text_hash") or doc.get("id"))

    current = BM25Index.load_or_create()
    by_hash = {dedup_key(doc): doc for doc in current.docs}
    # Applied second so incoming docs overwrite stored ones with the same key.
    by_hash.update((dedup_key(doc), doc) for doc in docs)
    BM25Index(list(by_hash.values())).save()
|