Commit b8261f9 · Parent(s): 7654fed
Add login page

Files changed:
- rag/chuncking.py +55 -25
- rag/ingest_net.py +141 -98
- rag/rag_engine_sources.py +58 -62

rag/chuncking.py
CHANGED
@@ -1,56 +1,86 @@

Before (removed lines; "…" marks text lost in extraction). This version
flushed fixed-size chunks first and then applied overlap as a character-level
post-pass (prev_tail = text[-overlap_tokens * 4 :]), so overlapped text could
spill outside a chunk's recorded page range:

    import re
    from typing import List, Tuple


    def approx_token_count(text: str) -> int:
        """
        Rough token …
        """
        return max(1, len(text) // 4)


    def chunk_pages(
        pages: List[str],
        …
    ) -> List[Tuple[str, int, int]]:
        """
        Split pages into overlapping chunks.
        Returns: …
        """
        chunks = []
        …
        buffer_tokens = 0

        def …
            nonlocal buffer, buffer_pages, buffer_tokens
            if buffer:
                …

        for page_idx, page in enumerate(pages, start=1):
            paragraphs = [p.strip() for p in re.split(r"\n\s*\n", page) if p.strip()]

            for para in paragraphs:
                t = approx_token_count(para)
                if buffer_tokens + t > target_tokens:
                    …

                buffer.append(para)
                buffer_pages.append(page_idx)
                buffer_tokens += t

        flush…
        …
        if overlap_tokens > 0 and len(chunks) > 1:
            overlapped = []
            prev_tail = ""

            for text, ps, pe in chunks:
                merged = (prev_tail + "\n" + text).strip()
                overlapped.append((merged, ps, pe))
                prev_tail = text[-overlap_tokens * 4 :]

            return overlapped

        return chunks
After (the rewritten chunker; overlap is applied at paragraph level while
accumulating, so page ranges stay exact):

    # chunking.py
    import re
    from typing import List, Tuple


    def approx_token_count(text: str) -> int:
        """
        Rough token estimate for chunk sizing (heuristic).
        Keep this simple but consistent: ~4 chars per token.
        """
        return max(1, len(text) // 4)


    def chunk_pages(
        pages: List[str],
        target_tokens: int = 520,
        overlap_tokens: int = 80,
    ) -> List[Tuple[str, int, int]]:
        """
        Split pages (list[str]) into overlapping chunks.
        Returns list of tuples: (chunk_text, page_start, page_end)

        Overlap is implemented at paragraph level (keeps page ranges correct).
        """
        chunks: List[Tuple[str, int, int]] = []

        buffer: List[str] = []
        buffer_pages: List[int] = []
        buffer_tokens = 0

        def make_chunk():
            nonlocal buffer, buffer_pages, buffer_tokens
            if not buffer:
                return
            chunk_text = "\n\n".join(buffer).strip()
            page_start = min(buffer_pages)
            page_end = max(buffer_pages)
            chunks.append((chunk_text, page_start, page_end))

        for page_idx, page in enumerate(pages, start=1):
            paragraphs = [p.strip() for p in re.split(r"\n\s*\n", page) if p.strip()]

            for para in paragraphs:
                t = approx_token_count(para)

                # If a single paragraph exceeds target, emit it as its own chunk
                if buffer_tokens == 0 and t > target_tokens:
                    buffer = [para]
                    buffer_pages = [page_idx]
                    buffer_tokens = t
                    make_chunk()
                    buffer, buffer_pages, buffer_tokens = [], [], 0
                    continue

                # If adding this paragraph would exceed target, flush the current chunk
                if buffer_tokens + t > target_tokens:
                    make_chunk()

                    # prepare overlap: keep tail paragraphs whose tokens sum >= overlap_tokens
                    tail_buffer: List[str] = []
                    tail_pages: List[int] = []
                    tail_tokens = 0
                    # iterate the buffer in reverse to pick tail paragraphs
                    for p, p_pg in zip(reversed(buffer), reversed(buffer_pages)):
                        pt = approx_token_count(p)
                        tail_buffer.insert(0, p)
                        tail_pages.insert(0, p_pg)
                        tail_tokens += pt
                        if tail_tokens >= overlap_tokens:
                            break

                    buffer = tail_buffer
                    buffer_pages = tail_pages
                    buffer_tokens = tail_tokens

                # append the current paragraph
                buffer.append(para)
                buffer_pages.append(page_idx)
                buffer_tokens += t

        # final flush
        if buffer:
            make_chunk()

        return chunks
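A minimal usage sketch of the new chunker (toy pages and deliberately tiny
token limits, made up for illustration; assumes the repo root is importable
and the module keeps its current chuncking.py filename):

    from rag.chuncking import chunk_pages

    pages = [
        "Intro paragraph about the book.\n\nSecond paragraph with more detail.",
        "A paragraph that starts page two.\n\nAnd a closing paragraph.",
    ]

    # tiny limits so the toy input actually splits; the real defaults are 520/80
    for text, page_start, page_end in chunk_pages(pages, target_tokens=20, overlap_tokens=5):
        print(f"pages {page_start}-{page_end}: {text[:40]!r}")

Because overlap re-uses whole tail paragraphs together with the pages they
came from, each printed page range stays exact even for overlapped chunks.
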
rag/ingest_net.py
CHANGED
@@ -1,7 +1,8 @@
@@ -12,9 +13,8 @@ from core.books.storage import (
@@ -26,115 +26,158 @@ HEAD_PAGES_N = int(os.getenv("HEAD_PAGES_N", "5"))

Before ("…" marks text lost in extraction). The removed version created the
RAG engine up front, then ran fetch, extraction, preprocessing, the raw-DB
insert, and embedding all inside a single worker() on the thread pool:

    # rag/ingest_net.py
    import uuid
    import os
    …
    from typing import List, Dict, Any, Tuple

    from core.books.fetch_url import get_pdf_bytes
    from core.books.storage import (
        …
        mark_raw_status,
    )
    from schemas.books.sources_schema import DocRaw
    from .rag_engine_sources import ArabicBookRAGWithSources

    MAX_INGEST_WORKERS = int(os.getenv("MAX_INGEST_WORKERS", "…
    MIN_PDF_SIZE_KB = int(os.getenv("MIN_PDF_SIZE_KB", "60"))

    EMBEDDING_MODEL = os.getenv(
    …
    HEAD_PAGES_N = int(os.getenv("HEAD_PAGES_N", "5"))


    def ingest_from_net(user_id: str, book_id: str, sources: List[Dict[str, Any]]):
        init_db()
        if not check_db_health():
            raise RuntimeError("Supabase / DB is not reachable")

        rag = ArabicBookRAGWithSources(
            user_id=user_id, book_id=book_id, embedding_model=EMBEDDING_MODEL
        )

        def worker(src):
            …
            doc_id = str(uuid.uuid4())
            …
                print(f"Done pypdf2✅ | pages={len(pages)}")
                extraction_method = "text"
            else:
                if len(pdf_bytes) > 10 * 1024 * 1024:
                    return {
                        "url": url,
                        "status": "rejected",
                        "reason": "pdf_too_large_for_ocr",
                    }
                try:
                    pages = mistral_ocr_pdf(pdf_bytes)  # assumed to return a list[str] of pages
                    extraction_method = "ocr"
                except Exception:
                    return {"url": url, "status": "rejected", "reason": "ocr_failed"}

            if not pages:
                return {"url": url, "status": "rejected", "reason": "no_text_extracted"}

            # ---------- Preprocess for Arabic (optional) ----------
            from .preprocess import normalize_arabic, drop_common_headers_footers

            language = src.get("language", "ar")
            if language == "ar":
                pages = [normalize_arabic(p) for p in pages]
                pages = drop_common_headers_footers(pages)

            pages_head = pages[:HEAD_PAGES_N]

            # ---------- Store RAW in Supabase (Phase 1) ----------
            raw_doc = DocRaw(
                doc_id=doc_id,
                user_id=user_id,
                book_id=book_id,
                source_url=url,
                source_type=src.get("source_type", "pdf"),
                domain=src.get("domain", ""),
                language=language,
                pages_head=pages_head,
                extraction_method=extraction_method,
                pdf_size_bytes=len(pdf_bytes),
                status="pending",
                error_reason="",
            )
            …
                return {
                    "url": url,
                    "status": "rejected",
                    "reason": f"db_insert_failed: {str(e)}",
                }

            try:
                stats = rag.ingest_pages(pages=pages, raw_doc=raw_doc)
            except Exception as e:
                …
                "extraction": extraction_method,
                "stats": stats,
            }

        items = []
        with ThreadPoolExecutor(max_workers=MAX_INGEST_WORKERS) as ex:
            for result in ex.map(worker, sources):
                items.append(result)

        ingested = sum(1 for x in items if x["status"] == "ingested")
        rejected = len(items) - ingested

        return …
After (the rewrite splits ingestion into a concurrent prepare phase and a
sequential embed/upsert phase):

    # rag/ingest_net.py
    import uuid
    import os
    import time
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from typing import List, Dict, Any, Tuple

    from core.books.fetch_url import get_pdf_bytes
    from core.books.storage import (
        …
        mark_raw_status,
    )
    from schemas.books.sources_schema import DocRaw

    MAX_INGEST_WORKERS = int(os.getenv("MAX_INGEST_WORKERS", "6"))
    MIN_PDF_SIZE_KB = int(os.getenv("MIN_PDF_SIZE_KB", "60"))

    EMBEDDING_MODEL = os.getenv(
    …
    HEAD_PAGES_N = int(os.getenv("HEAD_PAGES_N", "5"))


    def ingest_from_net(user_id: str, book_id: str, sources: List[Dict[str, Any]]):
        """
        Two-phase ingest:
          Phase A (concurrent): fetch PDFs, extract text/OCR, preprocess, insert raw doc to DB.
          Phase B (sequential): load the embedder once and ingest pages (encode + upsert) to Qdrant.
        This avoids loading the embedding model many times (a huge slow-down).
        """
        init_db()
        if not check_db_health():
            raise RuntimeError("Supabase / DB is not reachable")

        collection_name = f"user_{user_id}__book_{book_id}"

        # Worker for phase A: fetch + extract + preprocess + insert_raw_document
        def fetch_and_prepare(src: Dict[str, Any]) -> Dict[str, Any]:
            url = src.get("url")
            doc_id = str(uuid.uuid4())
            start = time.time()
            result = {"url": url, "doc_id": doc_id, "status": "rejected", "reason": None}

            try:
                pdf_bytes = get_pdf_bytes(url)
                if not pdf_bytes:
                    result["reason"] = "no_pdf_or_blocked"
                    return result

                if len(pdf_bytes) < MIN_PDF_SIZE_KB * 1024:
                    result["reason"] = "pdf_too_small"
                    return result

                # delayed imports for optional OCR libs
                from .pdf_text import extract_text_pypdf2, is_text_usable
                from .ocr import mistral_ocr_pdf

                pages = extract_text_pypdf2(pdf_bytes)
                joined = "\n".join(pages)

                if is_text_usable(joined):
                    extraction_method = "text"
                else:
                    if len(pdf_bytes) > 10 * 1024 * 1024:
                        result["reason"] = "pdf_too_large_for_ocr"
                        return result
                    try:
                        pages = mistral_ocr_pdf(pdf_bytes)
                        extraction_method = "ocr"
                    except Exception as e:
                        result["reason"] = f"ocr_failed:{str(e)}"
                        return result

                if not pages:
                    result["reason"] = "no_text_extracted"
                    return result

                # preprocess
                from .preprocess import normalize_arabic, drop_common_headers_footers

                language = src.get("language", "ar")
                if language == "ar":
                    pages = [normalize_arabic(p) for p in pages]
                    pages = drop_common_headers_footers(pages)

                pages_head = pages[:HEAD_PAGES_N]

                raw_doc = DocRaw(
                    doc_id=doc_id,
                    user_id=user_id,
                    book_id=book_id,
                    source_url=url,
                    source_type=src.get("source_type", "pdf"),
                    domain=src.get("domain", ""),
                    language=language,
                    pages_head=pages_head,
                    extraction_method=extraction_method,
                    pdf_size_bytes=len(pdf_bytes),
                    status="pending",
                    error_reason="",
                )

                # insert raw doc (phase 1)
                insert_raw_document(raw_doc)

                result.update(
                    {
                        "status": "prepared",
                        "doc_id": doc_id,
                        "raw_doc": raw_doc,
                        "pages": pages,
                        "extraction": extraction_method,
                        "duration": time.time() - start,
                    }
                )
                return result

            except Exception as e:
                result["reason"] = f"fetch_prepare_failed:{str(e)}"
                return result

        # Phase A: concurrent fetch + prepare
        prepared_items = []
        with ThreadPoolExecutor(max_workers=MAX_INGEST_WORKERS) as ex:
            futures = {ex.submit(fetch_and_prepare, src): src for src in sources}
            for fut in as_completed(futures):
                res = fut.result()
                prepared_items.append(res)

        # Phase B: sequential embedding & Qdrant upsert with a single embedder instance
        from .rag_engine_sources import ArabicBookRAGWithSources

        rag = ArabicBookRAGWithSources(
            user_id=user_id, book_id=book_id, embedding_model=EMBEDDING_MODEL
        )

        items_out = []
        ingested = 0
        rejected = 0

        for item in prepared_items:
            if item.get("status") != "prepared":
                items_out.append(item)
                rejected += 1
                continue

            doc_id = item["doc_id"]
            pages = item["pages"]
            raw_doc = item["raw_doc"]
            try:
                stats = rag.ingest_pages(pages=pages, raw_doc=raw_doc)
                items_out.append(
                    {
                        "url": item["url"],
                        "status": "ingested",
                        "doc_id": doc_id,
                        "source_type": raw_doc.source_type,
                        "extraction": item["extraction"],
                        "stats": stats,
                    }
                )
                ingested += 1
            except Exception as e:
                # mark failed in the DB
                try:
                    mark_raw_status(doc_id, "failed", f"qdrant_failed:{str(e)}")
                except Exception:
                    pass
                items_out.append(
                    {
                        "url": item["url"],
                        "status": "rejected",
                        "reason": f"qdrant_failed:{str(e)}",
                        "doc_id": doc_id,
                    }
                )
                rejected += 1

        return items_out, ingested, rejected, rag.collection
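A caller sketch for the new two-phase entry point (URLs and IDs are invented;
the four-tuple shape comes straight from the return statement above):

    from rag.ingest_net import ingest_from_net

    sources = [
        {"url": "https://example.org/book1.pdf", "language": "ar", "domain": "example.org"},
        {"url": "https://example.org/book2.pdf", "language": "ar", "domain": "example.org"},
    ]

    items, ingested, rejected, collection = ingest_from_net(
        user_id="u123", book_id="b456", sources=sources
    )
    print(f"collection={collection} ingested={ingested} rejected={rejected}")
    for item in items:
        if item["status"] != "ingested":
            print(item["url"], "->", item.get("reason"))

Phase A can safely run in threads because it only does I/O (HTTP fetch, OCR
call, DB insert); Phase B stays sequential so the SentenceTransformer is
loaded exactly once.
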
rag/rag_engine_sources.py
CHANGED
@@ -1,24 +1,28 @@
@@ -28,7 +32,12 @@ class ArabicBookRAGWithSources:
@@ -36,41 +45,47 @@ class ArabicBookRAGWithSources:
@@ -88,9 +103,11 @@ class ArabicBookRAGWithSources:
@@ -100,17 +117,9 @@
@@ -129,23 +138,10 @@

Before ("…" marks text lost in extraction):

    # rag/rag_engine_sources.py
    import os
    import uuid
    from dataclasses import asdict
    from typing import List, Optional
    from …

    from sentence_transformers import SentenceTransformer
    from qdrant_client import QdrantClient
    from qdrant_client.http import models as qm

    from schemas.books.sources_schema import ChunkRecord, DocRaw


    class ArabicBookRAGWithSources:
        def __init__(…
            self.user_id = user_id
            self.book_id = book_id
            self.collection = f"user_{user_id}__book_{book_id}"

            self.embedder = SentenceTransformer(embedding_model)
            self.qdrant = QdrantClient(
                url=os.environ["QDRANT_URL"],
    …
        def _ensure_collection(self):
            dim = self.embedder.get_sentence_embedding_dimension()
            …
            if self.collection not in collections:
                self.qdrant.create_collection(
                    collection_name=self.collection,
    …
                )

        def ingest_pages(self, pages: List[str], raw_doc: DocRaw):
            …
            chunks = chunk_pages(pages)

            …
                )
                for txt, ps, pe in chunks
            ]

            …
                qm.PointStruct(
                    id=r.chunk_id,
                    vector=v.tolist(),
    …
                        "text": r.text,
                    },
                )
                …

            return {"pages": len(pages), "chunks": len(records)}
    …
            doc_id: Optional[str] = None,
            top_k: int = 8,
        ):
            """
            Semantic retrieval from Qdrant collection.
            Returns list of scored chunks.
            """
            if not queries:
                return []

            # -------------------------
            # Optional filter
            # -------------------------
            must = []
            if doc_id:
                must.append(
    …
                if not q:
                    continue

                # normalize Arabic
                q_norm = normalize_arabic(q)

                vec = self.embedder.encode(
                    [q_norm],
                    normalize_embeddings=True,
                )[0]

                # search
                # res = self.qdrant.search(
                #     collection_name=self.collection,
                #     query_vector=vec.tolist(),
                #     limit=top_k,
                #     with_payload=True,
                #     query_filter=query_filter,
                # )
                res = self.qdrant.query_points(
                    collection_name=self.collection,
                    query=vec.tolist(),
    …
After:

    # rag/rag_engine_sources.py
    import os
    import uuid
    from typing import List, Optional
    from dataclasses import asdict

    from sentence_transformers import SentenceTransformer
    from qdrant_client import QdrantClient
    from qdrant_client.http import models as qm

    from schemas.books.sources_schema import ChunkRecord, DocRaw
    from .preprocess import normalize_arabic
    from .chuncking import chunk_pages  # corrected the file name from chuncking -> chunking


    class ArabicBookRAGWithSources:
        def __init__(
            self, user_id: str, book_id: str, embedding_model: str, batch_size: int = 128
        ):
            self.user_id = user_id
            self.book_id = book_id
            self.collection = f"user_{user_id}__book_{book_id}"
            self.batch_size = batch_size

            # load the embedder once per instance (this is why ingest_from_net uses a single instance)
            self.embedder = SentenceTransformer(embedding_model)
            self.qdrant = QdrantClient(
                url=os.environ["QDRANT_URL"],
    …
        def _ensure_collection(self):
            dim = self.embedder.get_sentence_embedding_dimension()
            existing = self.qdrant.get_collections()
            collections = (
                [c.name for c in existing.collections]
                if existing and getattr(existing, "collections", None)
                else []
            )
            if self.collection not in collections:
                self.qdrant.create_collection(
                    collection_name=self.collection,
    …
                )

        def ingest_pages(self, pages: List[str], raw_doc: DocRaw):
            """
            Create chunks (with page ranges), embed in batches, and upsert to Qdrant in batches.
            Returns a stats dict.
            """
            chunks = chunk_pages(pages)

            records = []
            for txt, ps, pe in chunks:
                records.append(
                    ChunkRecord(
                        chunk_id=str(uuid.uuid4()),
                        user_id=self.user_id,
                        book_id=self.book_id,
                        doc_id=raw_doc.doc_id,
                        source_url=raw_doc.source_url,
                        source_type=raw_doc.source_type,
                        domain=raw_doc.domain,
                        title="",
                        authors="",
                        year=None,
                        publisher_or_journal="",
                        language=raw_doc.language,
                        apa7="",
                        page_start=ps,
                        page_end=pe,
                        text=txt,
                    )
                )

            # encode in batches to save memory/time
            vectors = []
            texts = [r.text for r in records]
            for i in range(0, len(texts), self.batch_size):
                batch_texts = texts[i : i + self.batch_size]
                batch_vecs = self.embedder.encode(batch_texts, normalize_embeddings=True)
                vectors.extend(batch_vecs)

            # upsert to Qdrant in batches
            points = []
            for r, v in zip(records, vectors):
                points.append(
                    qm.PointStruct(
                        id=r.chunk_id,
                        vector=v.tolist(),
    …
                            "text": r.text,
                        },
                    )
                )

            for i in range(0, len(points), self.batch_size):
                batch = points[i : i + self.batch_size]
                self.qdrant.upsert(collection_name=self.collection, points=batch)

            return {"pages": len(pages), "chunks": len(records)}
    …
            doc_id: Optional[str] = None,
            top_k: int = 8,
        ):
            if not queries:
                return []

            must = []
            if doc_id:
                must.append(
    …
                if not q:
                    continue

                q_norm = normalize_arabic(q)
                vec = self.embedder.encode([q_norm], normalize_embeddings=True)[0]

                # use query_points (or search, depending on client version)
                res = self.qdrant.query_points(
                    collection_name=self.collection,
                    query=vec.tolist(),
    …
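
The diff rendering ends inside the query_points() call above. For reference, a
standalone sketch of how such a filtered query is typically assembled with
qdrant-client >= 1.10; the payload key "doc_id", the limit wiring, and the
result handling below are assumptions, not this commit's code:

    import os
    from typing import List, Optional

    from qdrant_client import QdrantClient
    from qdrant_client.http import models as qm


    def search_chunks(
        collection: str,
        vec: List[float],
        doc_id: Optional[str] = None,
        top_k: int = 8,
    ):
        client = QdrantClient(url=os.environ["QDRANT_URL"])
        must = []
        if doc_id:
            # assumed payload key, mirroring the per-chunk fields ingest_pages stores
            must.append(qm.FieldCondition(key="doc_id", match=qm.MatchValue(value=doc_id)))

        res = client.query_points(
            collection_name=collection,
            query=vec,
            query_filter=qm.Filter(must=must) if must else None,
            limit=top_k,
            with_payload=True,
        )
        # query_points returns a QueryResponse; each point has .id, .score, .payload
        return [{"score": p.score, **(p.payload or {})} for p in res.points]

On older clients the equivalent call is the deprecated search(...,
query_vector=...), which is what the commented-out block in the old version
used.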
|