# ContiAI-v4 / rag / rag_engine_sources.py
# ziadsameh32's picture
# Add login page
# 033e647
# rag/rag_engine_sources.py
import os
import uuid
from typing import List, Optional
from dataclasses import asdict
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.http import models as qm
from schemas.books.sources_schema import ChunkRecord, DocRaw
from .preprocess import normalize_arabic
from .chuncking import chunk_pages # صحّحت اسم الملف من chuncking -> chunking
class ArabicBookRAGWithSources:
    """Retrieval store for one user's book, backed by a single Qdrant collection.

    Each instance owns a sentence-embedding model and a Qdrant client, and
    works against the collection named ``user_<user_id>__book_<book_id>``.
    The collection is created on construction if it does not exist yet.
    """

    def __init__(
        self, user_id: str, book_id: str, embedding_model: str, batch_size: int = 128
    ):
        """Load the embedder, connect to Qdrant and make sure the collection exists.

        Args:
            user_id: Owner identifier; part of the collection name and payload.
            book_id: Book identifier; part of the collection name and payload.
            embedding_model: Name or path handed to ``SentenceTransformer``.
            batch_size: Number of chunks embedded and upserted per round trip.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            KeyError: If ``QDRANT_URL`` or ``QDRANT_API_KEY`` is not set in the
                environment.
        """
        # Fail fast, before the (expensive) model load: a zero step breaks
        # range() and a negative one would silently ingest nothing.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        self.user_id = user_id
        self.book_id = book_id
        self.collection = f"user_{user_id}__book_{book_id}"
        self.batch_size = batch_size

        # Load the embedder once per instance (this is why ingest_from_net
        # uses a single instance).
        self.embedder = SentenceTransformer(embedding_model)
        self.qdrant = QdrantClient(
            url=os.environ["QDRANT_URL"],
            api_key=os.environ["QDRANT_API_KEY"],
        )
        self._ensure_collection()

    def _ensure_collection(self) -> None:
        """Create the collection and its ``doc_id`` keyword index if missing."""
        dim = self.embedder.get_sentence_embedding_dimension()
        listing = self.qdrant.get_collections()
        # Tolerate an empty/None listing the same way the lookup always has.
        known = (
            {c.name for c in listing.collections}
            if listing and getattr(listing, "collections", None)
            else set()
        )
        if self.collection in known:
            return

        self.qdrant.create_collection(
            collection_name=self.collection,
            vectors_config=qm.VectorParams(size=dim, distance=qm.Distance.COSINE),
        )
        # retrieve() and delete_from_qdrant() both filter on doc_id.
        self.qdrant.create_payload_index(
            collection_name=self.collection,
            field_name="doc_id",
            field_schema=qm.PayloadSchemaType.KEYWORD,
        )

    @staticmethod
    def _to_point(record, vector):
        """Build the Qdrant point (id, vector, payload) for one chunk record."""
        return qm.PointStruct(
            id=record.chunk_id,
            vector=vector.tolist(),
            payload={
                "chunk_id": record.chunk_id,
                "user_id": record.user_id,
                "book_id": record.book_id,
                "doc_id": record.doc_id,
                "source_url": record.source_url,
                "source_type": record.source_type,
                "domain": record.domain,
                "language": record.language,
                "page_start": record.page_start,
                "page_end": record.page_end,
                "text": record.text,
            },
        )

    def ingest_pages(self, pages: List[str], raw_doc: "DocRaw") -> dict:
        """Chunk the pages, embed the chunks and upsert them into Qdrant.

        Work is done one batch of ``self.batch_size`` chunks at a time: each
        batch is embedded and upserted before the next one is touched, so only
        one batch of vectors and points is held in memory at once.

        Args:
            pages: Page texts, passed as-is to ``chunk_pages``.
            raw_doc: Source document; its ``doc_id``, ``source_url``,
                ``source_type``, ``domain`` and ``language`` are copied onto
                every chunk.

        Returns:
            ``{"pages": <number of pages>, "chunks": <number of chunks>}``.
        """
        # chunk_pages yields (text, page_start, page_end) triples.
        records = [
            ChunkRecord(
                chunk_id=str(uuid.uuid4()),
                user_id=self.user_id,
                book_id=self.book_id,
                doc_id=raw_doc.doc_id,
                source_url=raw_doc.source_url,
                source_type=raw_doc.source_type,
                domain=raw_doc.domain,
                title="",
                authors="",
                year=None,
                publisher_or_journal="",
                language=raw_doc.language,
                apa7="",
                page_start=page_start,
                page_end=page_end,
                text=text,
            )
            for text, page_start, page_end in chunk_pages(pages)
        ]

        for start in range(0, len(records), self.batch_size):
            batch = records[start : start + self.batch_size]
            vectors = self.embedder.encode(
                [rec.text for rec in batch], normalize_embeddings=True
            )
            points = [self._to_point(rec, vec) for rec, vec in zip(batch, vectors)]
            self.qdrant.upsert(collection_name=self.collection, points=points)

        return {"pages": len(pages), "chunks": len(records)}

    def retrieve(
        self,
        queries: List[str],
        doc_id: Optional[str] = None,
        top_k: int = 8,
    ) -> list:
        """Run a vector search for every non-blank query and pool the hits.

        Args:
            queries: Query strings; each is stripped and blank ones are skipped.
            doc_id: When given, restrict the search to points with this
                ``doc_id`` payload value.
            top_k: Maximum number of hits requested per query.

        Returns:
            The points returned by Qdrant for each query, concatenated in query
            order. Hits are not de-duplicated across queries. Empty when there
            is nothing to search for.
        """
        if not queries:
            return []

        query_filter = None
        if doc_id:
            query_filter = qm.Filter(
                must=[
                    qm.FieldCondition(
                        key="doc_id",
                        match=qm.MatchValue(value=doc_id),
                    )
                ]
            )

        hits = []
        for raw_query in queries:
            query = raw_query.strip()
            if not query:
                continue
            normalized = normalize_arabic(query)
            vector = self.embedder.encode([normalized], normalize_embeddings=True)[0]
            # use query_points (or search depending on client version)
            response = self.qdrant.query_points(
                collection_name=self.collection,
                query=vector.tolist(),
                limit=top_k,
                with_payload=True,
                query_filter=query_filter,
            )
            hits.extend(response.points)
        return hits

    def delete_book_collection(self) -> None:
        """Drop this book's whole Qdrant collection."""
        self.qdrant.delete_collection(self.collection)

    def delete_from_qdrant(self, doc_id: str) -> None:
        """Delete every point whose ``doc_id`` payload matches; best effort.

        Failures are reported on stdout and not re-raised, so a failed delete
        does not interrupt the caller.
        """
        selector = qm.Filter(
            must=[
                qm.FieldCondition(
                    key="doc_id",
                    match=qm.MatchValue(value=doc_id),
                )
            ]
        )
        try:
            self.qdrant.delete(
                collection_name=self.collection,
                points_selector=selector,
            )
            print(f"🗑️ Deleted from Qdrant: {doc_id}")
        except Exception as e:
            print(f"❌ Qdrant delete error {doc_id}: {e}")