Spaces:
Running
Running
File size: 4,920 Bytes
"""End-to-end test: PDF ingestion → chunking → embedding → hybrid search → rerank → answer.

Runs the full RAG pipeline directly against the src/ modules, without FastAPI.
Uses local providers (Ollama for the LLM, HuggingFace for embeddings).
"""
import logging
import os
import shutil
import sys
import tempfile
# Ensure project root is on sys.path
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, PROJECT_ROOT)
from src.config import load_settings
from src.models import ChunkStrategy
from src.provider import create_embeddings, create_llm, create_reranker
from src.ingestion.pipeline import IngestionPipeline
from src.retrieval.embedder import Embedder
from src.retrieval.vector_store import VectorStore
from src.retrieval.bm25_search import BM25Search
from src.retrieval.hybrid import HybridRetriever
from src.retrieval.reranker import Reranker
from src.agent.intent_classifier import IntentClassifier
from src.agent.router import QueryRouter
from langchain_core.output_parsers import StrOutputParser
# Module-level logger; its level and format are set inside main() via logging.basicConfig.
logger = logging.getLogger(__name__)
# Directory handed to the ingestion pipeline by main(): <PROJECT_ROOT>/docs.
DOCS_DIR = os.path.join(PROJECT_ROOT, "docs")
# The single (Danish) question that main() sends through the query router.
TEST_QUERY = "Hvad er reglerne for brug af AI på KU?"
def _print_response(response) -> None:
    """Print the routed response (intent, confidence, answer, sources) to stdout.

    ``response`` is whatever ``QueryRouter.route`` returned; only the
    attributes read below are relied upon.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("QUERY:", TEST_QUERY)
    print(rule)
    print(f"\nINTENT: {response.intent.value}")
    print(f"CONFIDENCE: {response.confidence:.3f}")
    print(f"\nANSWER:\n{response.answer}")
    print("\nSOURCES:")
    for rank, hit in enumerate(response.sources, start=1):
        source_path = hit.chunk.metadata.get("source", "unknown")
        page_no = hit.chunk.metadata.get("page_number", "?")
        print(f" [{rank}] {os.path.basename(source_path)} (p.{page_no}) — score: {hit.score:.4f}")
        # Only the first 120 characters of the chunk text are shown as a preview.
        print(f" {hit.chunk.text[:120]}...")
    print(rule)


def main() -> None:
    """Drive one query through ingestion, indexing, retrieval and answering.

    Steps, in order: load settings and configure logging, create the LLM and
    embedding providers, ingest everything under ``DOCS_DIR``, embed and index
    the chunks into a throwaway Qdrant directory plus a BM25 index, wire up
    the hybrid retriever / reranker / intent classifier / router, run
    ``TEST_QUERY`` and print the result.

    Exits with status 1 (via ``sys.exit``) when ingestion yields no chunks.
    The temporary Qdrant directory is removed in all cases.
    """
    # --- Config ---
    settings = load_settings()
    # Fall back to INFO when the configured level name is not a logging attribute.
    log_level = getattr(logging, settings.log_level.upper(), logging.INFO)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    logger.info("=== E2E Test Start ===")
    logger.info(
        "LLM provider: %s | Embedding provider: %s",
        settings.llm_provider,
        settings.embedding_provider,
    )

    # A scratch directory keeps this run's index out of the main Qdrant store.
    scratch_dir = tempfile.mkdtemp(prefix="e2e_qdrant_")
    logger.info("Qdrant temp path: %s", scratch_dir)

    try:
        # --- 1) Providers ---
        logger.info("Creating LLM and embeddings...")
        language_model = create_llm(settings)
        embedding_model = create_embeddings(settings)

        # --- 2) Ingest every PDF under docs/ ---
        logger.info("Ingesting PDFs from %s ...", DOCS_DIR)
        ingestor = IngestionPipeline(
            strategy=ChunkStrategy.RECURSIVE,
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
        )
        doc_chunks = ingestor.ingest_directory(DOCS_DIR)
        chunk_count = len(doc_chunks)
        logger.info("Total chunks created: %d", chunk_count)

        if not doc_chunks:
            logger.error("No chunks produced. Check that docs/ contains valid PDFs.")
            # SystemExit still passes through the ``finally`` below, so cleanup runs.
            sys.exit(1)

        # --- 3) Embed and index ---
        logger.info("Embedding %d chunks...", chunk_count)
        embedder = Embedder(embedding_model)
        chunk_texts = [chunk.text for chunk in doc_chunks]
        chunk_vectors = embedder.embed_batch(chunk_texts)
        logger.info("Embedding complete (dim=%d)", len(chunk_vectors[0]))

        logger.info("Indexing into Qdrant...")
        store = VectorStore(
            path=scratch_dir,
            collection_name="e2e_test",
            dimension=settings.embedding_dimension,
        )
        store.add_chunks(doc_chunks, chunk_vectors)

        logger.info("Building BM25 index...")
        keyword_index = BM25Search()
        keyword_index.index(doc_chunks)

        # --- 4) Retrieval + generation pipeline ---
        hybrid_retriever = HybridRetriever(
            vector_store=store,
            bm25_search=keyword_index,
            embedder=embedder,
            dense_weight=settings.dense_weight,
            bm25_weight=settings.bm25_weight,
        )
        rerank_model = create_reranker(settings.reranker_model)
        reranker = Reranker(model=rerank_model)
        intent_classifier = IntentClassifier(llm=language_model)
        answer_chain = language_model | StrOutputParser()
        router = QueryRouter(
            intent_classifier=intent_classifier,
            hybrid_retriever=hybrid_retriever,
            reranker=reranker,
            llm_chain=answer_chain,
        )

        # --- 5) Run the query and report ---
        logger.info("Query: %s", TEST_QUERY)
        response = router.route(query=TEST_QUERY, top_k=settings.top_k)
        _print_response(response)
        logger.info("=== E2E Test Complete ===")
    finally:
        # Best-effort removal of the scratch Qdrant data; deletion errors are ignored.
        shutil.rmtree(scratch_dir, ignore_errors=True)
        logger.info("Cleaned up temp Qdrant at %s", scratch_dir)
# Run the end-to-end pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|