# rag/rag_engine_sources.py
import os
import uuid
from dataclasses import asdict
from typing import List, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models as qm
from sentence_transformers import SentenceTransformer

from schemas.books.sources_schema import ChunkRecord, DocRaw

# NOTE: the module file is still named "chuncking" (sic); the original author
# noted it should be renamed to "chunking" — update this import if/when the
# file is renamed.
from .chuncking import chunk_pages
from .preprocess import normalize_arabic
class ArabicBookRAGWithSources:
    """Per-(user, book) RAG store over Qdrant with source provenance.

    Each instance owns one Qdrant collection named
    ``user_{user_id}__book_{book_id}``. Ingested chunks keep citation
    metadata (doc_id, source_url, page range, ...) in the point payload
    so retrieved passages can be attributed to their source.
    """

    def __init__(
        self, user_id: str, book_id: str, embedding_model: str, batch_size: int = 128
    ):
        """Load the embedder, connect to Qdrant, and ensure the collection exists.

        Args:
            user_id: Owner identifier; part of the collection name.
            book_id: Book identifier; part of the collection name.
            embedding_model: SentenceTransformer model name or local path.
            batch_size: Batch size used for both embedding and upserts.

        Raises:
            KeyError: If QDRANT_URL or QDRANT_API_KEY is not set in the
                environment.
        """
        self.user_id = user_id
        self.book_id = book_id
        self.collection = f"user_{user_id}__book_{book_id}"
        self.batch_size = batch_size
        # Load the embedder once per instance (this is why ingest_from_net
        # reuses a single instance).
        self.embedder = SentenceTransformer(embedding_model)
        self.qdrant = QdrantClient(
            url=os.environ["QDRANT_URL"],
            api_key=os.environ["QDRANT_API_KEY"],
        )
        self._ensure_collection()

    def _ensure_collection(self):
        """Create the collection (cosine distance, embedder dim) if missing.

        Also creates a keyword payload index on ``doc_id`` at creation time
        so per-document filtering in retrieve()/delete stays fast.
        """
        dim = self.embedder.get_sentence_embedding_dimension()
        existing = self.qdrant.get_collections()
        collections = (
            [c.name for c in existing.collections]
            if existing and getattr(existing, "collections", None)
            else []
        )
        if self.collection not in collections:
            self.qdrant.create_collection(
                collection_name=self.collection,
                vectors_config=qm.VectorParams(size=dim, distance=qm.Distance.COSINE),
            )
            self.qdrant.create_payload_index(
                collection_name=self.collection,
                field_name="doc_id",
                field_schema=qm.PayloadSchemaType.KEYWORD,
            )

    def ingest_pages(self, pages: List[str], raw_doc: DocRaw):
        """Chunk pages, embed in batches, and upsert to Qdrant in batches.

        Args:
            pages: Page texts of one source document, in page order.
            raw_doc: Provenance metadata for the document being ingested.

        Returns:
            Stats dict: ``{"pages": <page count>, "chunks": <chunk count>}``.
        """
        records = self._build_records(pages, raw_doc)
        # FIX: embed the *normalized* text so documents live in the same
        # vector space as queries (retrieve() normalizes queries with
        # normalize_arabic before encoding). The payload still stores the
        # original, un-normalized chunk text for display/citation.
        vectors = self._embed_texts([normalize_arabic(r.text) for r in records])
        points = [self._to_point(r, v) for r, v in zip(records, vectors)]
        # Upsert to Qdrant in batches to bound request size.
        for i in range(0, len(points), self.batch_size):
            self.qdrant.upsert(
                collection_name=self.collection,
                points=points[i : i + self.batch_size],
            )
        return {"pages": len(pages), "chunks": len(records)}

    def _build_records(self, pages: List[str], raw_doc: DocRaw) -> List[ChunkRecord]:
        """Turn pages into ChunkRecords carrying page ranges and provenance.

        Bibliographic fields (title, authors, apa7, ...) are left blank here;
        presumably they are filled in by a later enrichment step — confirm
        against the caller.
        """
        return [
            ChunkRecord(
                chunk_id=str(uuid.uuid4()),
                user_id=self.user_id,
                book_id=self.book_id,
                doc_id=raw_doc.doc_id,
                source_url=raw_doc.source_url,
                source_type=raw_doc.source_type,
                domain=raw_doc.domain,
                title="",
                authors="",
                year=None,
                publisher_or_journal="",
                language=raw_doc.language,
                apa7="",
                page_start=ps,
                page_end=pe,
                text=txt,
            )
            for txt, ps, pe in chunk_pages(pages)
        ]

    def _embed_texts(self, texts: List[str]):
        """Encode texts in batches of self.batch_size; one normalized vector per text."""
        vectors = []
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i : i + self.batch_size]
            vectors.extend(self.embedder.encode(batch, normalize_embeddings=True))
        return vectors

    @staticmethod
    def _to_point(record: ChunkRecord, vector) -> qm.PointStruct:
        """Build one Qdrant point; the payload carries citation metadata."""
        return qm.PointStruct(
            id=record.chunk_id,
            vector=vector.tolist(),
            payload={
                "chunk_id": record.chunk_id,
                "user_id": record.user_id,
                "book_id": record.book_id,
                "doc_id": record.doc_id,
                "source_url": record.source_url,
                "source_type": record.source_type,
                "domain": record.domain,
                "language": record.language,
                "page_start": record.page_start,
                "page_end": record.page_end,
                "text": record.text,
            },
        )

    def retrieve(
        self,
        queries: List[str],
        doc_id: Optional[str] = None,
        top_k: int = 8,
    ):
        """Search the collection for each query and return the combined hits.

        Args:
            queries: Natural-language queries; blank entries are skipped.
            doc_id: If given, restrict results to chunks of this document.
            top_k: Maximum hits returned *per query* (results are
                concatenated across queries, so duplicates are possible).

        Returns:
            List of Qdrant scored points (with payloads); empty list if no
            usable queries.
        """
        if not queries:
            return []
        must = []
        if doc_id:
            must.append(
                qm.FieldCondition(
                    key="doc_id",
                    match=qm.MatchValue(value=doc_id),
                )
            )
        query_filter = qm.Filter(must=must) if must else None
        hits = []
        for q in queries:
            q = q.strip()
            if not q:
                continue
            # Normalize the query the same way ingested text is normalized
            # so both sides share one embedding space.
            q_norm = normalize_arabic(q)
            vec = self.embedder.encode([q_norm], normalize_embeddings=True)[0]
            # query_points is the current API; older clients used search().
            res = self.qdrant.query_points(
                collection_name=self.collection,
                query=vec.tolist(),
                limit=top_k,
                with_payload=True,
                query_filter=query_filter,
            ).points
            hits.extend(res)
        return hits

    def delete_book_collection(self):
        """Drop the whole per-book collection (all documents and chunks)."""
        self.qdrant.delete_collection(self.collection)

    def delete_from_qdrant(self, doc_id: str):
        """Best-effort delete of every chunk belonging to one document.

        Errors are logged and swallowed on purpose: a failed cleanup should
        not abort the caller's flow.
        """
        try:
            self.qdrant.delete(
                collection_name=self.collection,
                points_selector=qm.Filter(
                    must=[
                        qm.FieldCondition(
                            key="doc_id",
                            match=qm.MatchValue(value=doc_id),
                        )
                    ]
                ),
            )
            print(f"🗑️ Deleted from Qdrant: {doc_id}")
        except Exception as e:
            print(f"❌ Qdrant delete error {doc_id}: {e}")