Commit ·
c7256ee
0
Parent(s):
hf-space: deploy branch without frontend/data/results
Browse files- .dockerignore +31 -0
- .gitattributes +35 -0
- .gitignore +32 -0
- Dockerfile +26 -0
- README.md +12 -0
- backend/api.py +32 -0
- backend/routes/health.py +11 -0
- backend/routes/predict.py +88 -0
- backend/routes/predict_stream.py +151 -0
- backend/routes/title.py +36 -0
- backend/schemas.py +35 -0
- backend/services/cache.py +30 -0
- backend/services/chunks.py +52 -0
- backend/services/models.py +37 -0
- backend/services/startup.py +124 -0
- backend/services/streaming.py +7 -0
- backend/services/title.py +86 -0
- backend/state.py +7 -0
- config.yaml +46 -0
- config_loader.py +27 -0
- main.py +659 -0
- main_easy.py +104 -0
- models/deepseek_v3.py +25 -0
- models/llama_3_8b.py +22 -0
- models/mistral_7b.py +29 -0
- models/qwen_2_5.py +22 -0
- models/tiny_aya.py +25 -0
- requirements.txt +97 -0
- retriever/evaluator.py +331 -0
- retriever/generator.py +45 -0
- retriever/processor.py +288 -0
- retriever/retriever.py +354 -0
.dockerignore
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python caches
|
| 2 |
+
__pycache__/
|
| 3 |
+
**/__pycache__/
|
| 4 |
+
*.py[cod]
|
| 5 |
+
*.pyo
|
| 6 |
+
|
| 7 |
+
# Virtual environments
|
| 8 |
+
.venv/
|
| 9 |
+
venv/
|
| 10 |
+
ENV/
|
| 11 |
+
env/
|
| 12 |
+
|
| 13 |
+
# Frontend app (deployed separately on Vercel)
|
| 14 |
+
frontend/
|
| 15 |
+
|
| 16 |
+
# Local/runtime cache
|
| 17 |
+
.cache/
|
| 18 |
+
|
| 19 |
+
# Explicit user-requested exclusions
|
| 20 |
+
/EntireBookCleaned.txt
|
| 21 |
+
/startup.txt
|
| 22 |
+
|
| 23 |
+
# Git and editor noise
|
| 24 |
+
.git/
|
| 25 |
+
.gitignore
|
| 26 |
+
.vscode/
|
| 27 |
+
.idea/
|
| 28 |
+
|
| 29 |
+
# OS artifacts
|
| 30 |
+
.DS_Store
|
| 31 |
+
Thumbs.db
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# python specific ignores
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# Virtual environments
|
| 7 |
+
.venv/
|
| 8 |
+
venv/
|
| 9 |
+
env/
|
| 10 |
+
ENV/
|
| 11 |
+
|
| 12 |
+
# Environment and local secrets
|
| 13 |
+
.env
|
| 14 |
+
.env.*
|
| 15 |
+
!.env.example
|
| 16 |
+
|
| 17 |
+
# Build and packaging artifacts
|
| 18 |
+
build/
|
| 19 |
+
dist/
|
| 20 |
+
*.egg-info/
|
| 21 |
+
.eggs/
|
| 22 |
+
|
| 23 |
+
# Caches and tooling
|
| 24 |
+
.pytest_cache/
|
| 25 |
+
.mypy_cache/
|
| 26 |
+
.ruff_cache/
|
| 27 |
+
.ipynb_checkpoints/
|
| 28 |
+
.cache/
|
| 29 |
+
|
| 30 |
+
# IDE/editor
|
| 31 |
+
.vscode/
|
| 32 |
+
.idea/
|
Dockerfile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim

# Lean runtime: no .pyc files, unbuffered stdout/stderr for live logs,
# and no pip download cache baked into the layer.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1

WORKDIR /app

# Minimal system packages for common Python builds.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies before copying the source so this layer is cached
# across ordinary code changes.
COPY requirements.txt ./
RUN pip install --upgrade pip && pip install -r requirements.txt

COPY . .

# Fail fast during build if critical runtime folders are missing from context.
# NOTE(review): the commit message says this branch deploys *without*
# data/ and results/ — if those directories are genuinely absent from the
# build context, this check will fail every build. Confirm which is intended.
RUN test -d /app/backend && test -d /app/data && test -d /app/results

# Hugging Face Spaces exposes apps on port 7860 by default.
EXPOSE 7860

CMD ["uvicorn", "backend.api:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: NLP RAG
|
| 3 |
+
emoji: 🏢
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
short_description: NLP Spring 2026 Project 1
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
backend/api.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from backend.routes.health import router as health_router
from backend.routes.predict import router as predict_router
from backend.routes.predict_stream import router as predict_stream_router
from backend.routes.title import router as title_router
from backend.services.startup import initialize_runtime_state
from backend.state import state

# FastAPI application wiring: registers every route router and initializes
# the shared runtime state once at process startup.
# --@Qamar

app = FastAPI(title="RAG-AS3 API", version="0.1.0")

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests (the wildcard origin cannot
# be echoed back). Confirm whether cookies/credentials are actually needed;
# otherwise one of the two should be dropped.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(health_router)
app.include_router(title_router)
app.include_router(predict_router)
app.include_router(predict_stream_router)


# Populate the module-level `state` dict (retriever, index, models, ...)
# before the first request is served.
@app.on_event("startup")
def startup_event() -> None:
    initialize_runtime_state(state)
|
backend/routes/health.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter

from backend.state import REQUIRED_STATE_KEYS, state

router = APIRouter()


@router.get("/health")
def health() -> dict[str, str]:
    """Readiness probe: report 'ok' once every required state key is loaded."""
    missing = [key for key in REQUIRED_STATE_KEYS if key not in state]
    return {"status": "starting" if missing else "ok"}
|
backend/routes/predict.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
from typing import Any

from fastapi import APIRouter, HTTPException

from backend.schemas import PredictRequest, PredictResponse
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

router = APIRouter()


@router.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Run the full RAG pipeline once: retrieve contexts, then generate an answer.

    Each stage is individually timed and the breakdown is printed at the end.

    Raises:
        HTTPException 503: startup has not populated the shared state yet.
        HTTPException 400: the query is blank after stripping whitespace.
        HTTPException 404: the retriever returned no context chunks.
    """
    req_start = time.perf_counter()

    # Cheap request guards before touching any heavy pipeline component.
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")

    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    # Pull the long-lived objects built at startup (see services/startup.py).
    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    # Translate the user-facing model name/alias into a loaded model instance.
    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    # Retrieval: all knobs come straight from the request payload.
    retrieval_start = time.perf_counter()
    contexts = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start

    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    # Generation: one non-streaming call to the selected model.
    inference_start = time.perf_counter()
    answer = rag_engine.get_answer(model_instance, query, contexts, temperature=payload.temperature)
    inference_time = time.perf_counter() - inference_start

    # Attach source metadata (title, URL, page, ...) to each context string.
    mapping_start = time.perf_counter()
    retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)
    mapping_time = time.perf_counter() - mapping_start

    total_time = time.perf_counter() - req_start

    # Per-request timing breakdown; goes to container stdout.
    print(
        f"Predict timing | model={model_name} | mode={payload.mode} | "
        f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
        f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
        f"chunking={payload.chunking_technique} | "
        f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
        f"precheck={precheck_time:.3f}s | "
        f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
        f"retrieval={retrieval_time:.3f}s | inference={inference_time:.3f}s | "
        f"context_map={mapping_time:.3f}s | total={total_time:.3f}s"
    )

    return PredictResponse(
        model=model_name,
        answer=answer,
        contexts=contexts,
        retrieved_chunks=retrieved_chunks,
    )
|
| 87 |
+
|
| 88 |
+
|
backend/routes/predict_stream.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import time
from typing import Any

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from backend.schemas import PredictRequest
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.services.streaming import to_ndjson
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

router = APIRouter()


# Each path defines an APIRouter object that is registered in api.py.


@router.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
    """Streaming variant of /predict: NDJSON events over a chunked response.

    Event order: one "status" event (retrieval already done), then one
    "token" event per generated token, then a final "done" event carrying
    the full answer and chunk metadata. Validation/retrieval errors raise
    normal HTTP errors *before* streaming starts; failures during
    generation are delivered as a trailing "error" event, because the
    HTTP 200 status has already been sent by then.
    """
    req_start = time.perf_counter()
    # Generation-length cap; overridable per deployment via env.
    stream_max_tokens = int(os.getenv("STREAM_MAX_TOKENS", "400"))

    # Cheap request guards before touching any heavy pipeline component.
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")

    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    # Pull the long-lived objects built at startup (see services/startup.py).
    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    # Retrieval happens eagerly, before the response starts streaming, so
    # that "no contexts" can still surface as a plain 404.
    retrieval_start = time.perf_counter()
    contexts = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start

    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    def stream_events():
        # Generator consumed by StreamingResponse; one NDJSON line per yield.
        inference_start = time.perf_counter()
        first_token_latency = None
        answer_parts: list[str] = []
        try:
            # Up-front status event so the client can show retrieval stats
            # while generation is still warming up.
            yield to_ndjson(
                {
                    "type": "status",
                    "stage": "inference_start",
                    "model": model_name,
                    "retrieval_s": round(retrieval_time, 3),
                    "retrieval_debug": {
                        "requested_chunking_technique": payload.chunking_technique,
                        "requested_top_k": payload.top_k,
                        "requested_final_k": payload.final_k,
                        "returned_context_count": len(contexts),
                        "use_mmr": payload.use_mmr,
                        "lambda_param": payload.lambda_param,
                    },
                }
            )

            for token in rag_engine.get_answer_stream(
                model_instance,
                query,
                contexts,
                temperature=payload.temperature,
                max_tokens=stream_max_tokens,
            ):
                if first_token_latency is None:
                    first_token_latency = time.perf_counter() - inference_start
                answer_parts.append(token)
                yield to_ndjson({"type": "token", "token": token})

            inference_time = time.perf_counter() - inference_start
            # Assemble the final answer and chunk metadata for the client.
            answer = "".join(answer_parts)
            retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)

            yield to_ndjson(
                {
                    "type": "done",
                    "model": model_name,
                    "answer": answer,
                    "contexts": contexts,
                    "retrieved_chunks": retrieved_chunks,
                    "retrieval_debug": {
                        "requested_chunking_technique": payload.chunking_technique,
                        "requested_top_k": payload.top_k,
                        "requested_final_k": payload.final_k,
                        "returned_context_count": len(contexts),
                        "use_mmr": payload.use_mmr,
                        "lambda_param": payload.lambda_param,
                    },
                }
            )

            total_time = time.perf_counter() - req_start
            # Per-request timing breakdown; first_token=-1 means no token arrived.
            print(
                f"Predict stream timing | model={model_name} | mode={payload.mode} | "
                f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
                f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
                f"chunking={payload.chunking_technique} | "
                f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
                f"precheck={precheck_time:.3f}s | "
                f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
                f"retrieval={retrieval_time:.3f}s | first_token={first_token_latency if first_token_latency is not None else -1:.3f}s | "
                f"inference={inference_time:.3f}s | total={total_time:.3f}s | "
                f"max_tokens={stream_max_tokens}"
            )
        except Exception as exc:
            # The response status is already committed, so errors can only
            # be reported in-band as a final "error" event.
            total_time = time.perf_counter() - req_start
            print(
                f"Predict stream error | model={model_name} | mode={payload.mode} | "
                f"retrieval={retrieval_time:.3f}s | elapsed={total_time:.3f}s | error={exc}"
            )
            yield to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})

    return StreamingResponse(
        stream_events(),
        media_type="application/x-ndjson",
        headers={
            # Discourage proxy buffering so tokens reach the client promptly.
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
backend/routes/title.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException
from huggingface_hub import InferenceClient

from backend.schemas import TitleRequest, TitleResponse
from backend.services.title import parse_title_model_candidates, title_from_hf, title_from_query
from backend.state import state

router = APIRouter()


@router.post("/predict/title", response_model=TitleResponse)
def suggest_title(payload: TitleRequest) -> TitleResponse:
    """Suggest a chat title for the first user message.

    Tries each configured Hugging Face title model in order; on any failure
    or when no client is configured, falls back to a rule-based title
    derived from the query itself.

    Raises:
        HTTPException 400: the query is blank after stripping whitespace.
    """
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    # Computed up front so every failure path below has a ready fallback.
    fallback_title = title_from_query(query)

    title_client: InferenceClient | None = state.get("title_client")
    # NOTE(review): the default argument is evaluated eagerly even when the
    # state key exists — harmless only if parse_title_model_candidates() is
    # cheap; confirm.
    title_model_ids: list[str] = state.get("title_model_ids", parse_title_model_candidates())

    if title_client is not None:
        for title_model_id in title_model_ids:
            try:
                hf_title = title_from_hf(query, title_client, title_model_id)
                if hf_title:
                    return TitleResponse(title=hf_title, source=f"hf:{title_model_id}")
            except Exception as exc:
                err_text = str(exc)
                # Models not served by any inference provider are skipped quietly.
                if "model_not_supported" in err_text or "not supported by any provider" in err_text:
                    continue
                print(f"Title generation model failed ({title_model_id}): {exc}")
                continue

    print("Title generation fallback triggered: no title model available/successful")
    return TitleResponse(title=fallback_title, source="rule-based")
|
backend/schemas.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any

from pydantic import BaseModel, Field

# Pydantic request/response schemas for the API endpoints.


class PredictRequest(BaseModel):
    """Parameters accepted by /predict and /predict/stream."""

    query: str = Field(..., min_length=1, description="User query text")
    model: str = Field(default="Llama-3-8B", description="Model name key")
    top_k: int = Field(default=10, ge=1, le=20)
    final_k: int = Field(default=3, ge=1, le=8)
    chunking_technique: str = Field(default="all", description="all | fixed | sentence | paragraph | semantic | recursive | page | markdown")
    mode: str = Field(default="hybrid", description="semantic | bm25 | hybrid")
    rerank_strategy: str = Field(default="cross-encoder", description="cross-encoder | rrf | none")
    use_mmr: bool = Field(default=True, description="Whether to apply MMR after reranking")
    lambda_param: float = Field(default=0.5, ge=0.0, le=1.0, description="MMR relevance/diversity tradeoff")
    temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="Generation temperature")


class PredictResponse(BaseModel):
    """Answer plus the contexts and chunk metadata that produced it."""

    model: str
    answer: str
    contexts: list[str]
    retrieved_chunks: list[dict[str, Any]]


class TitleRequest(BaseModel):
    """Input for /predict/title."""

    query: str = Field(..., min_length=1, description="First user message")


class TitleResponse(BaseModel):
    """Suggested chat title and which generator produced it (hf:* or rule-based)."""

    title: str
    source: str
|
backend/services/cache.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
from data.vector_db import load_chunks_with_local_cache
|
| 5 |
+
|
| 6 |
+
# Caching helpers.
# NOTE: the local chunk cache mainly helps in development; on Hugging Face
# Spaces the filesystem may not persist across restarts, so it is of
# limited use there.
|
| 9 |
+
|
| 10 |
+
def get_cache_settings() -> tuple[str, bool]:
    """Resolve the BM25 chunk-cache directory and refresh flag from the env.

    Returns:
        (cache_dir, force_refresh): the cache directory (BM25_CACHE_DIR, or a
        ``.cache`` folder one level above this services package when unset)
        and whether BM25_CACHE_REFRESH requests a forced refresh.
    """
    services_dir = os.path.dirname(os.path.abspath(__file__))
    default_cache_dir = os.path.join(services_dir, "..", ".cache")
    cache_dir = os.getenv("BM25_CACHE_DIR", default_cache_dir)
    # Any of these (case-insensitive) values switches the refresh on.
    force_cache_refresh = os.getenv("BM25_CACHE_REFRESH", "0").lower() in {"1", "true", "yes"}
    return cache_dir, force_cache_refresh
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_cached_chunks(
    index: Any,
    index_name: str,
    cache_dir: str,
    force_cache_refresh: bool,
    batch_size: int = 100,
) -> tuple[list[dict[str, Any]], str]:
    """Thin wrapper around data.vector_db.load_chunks_with_local_cache.

    Only renames ``force_cache_refresh`` to the wrapped function's
    ``force_refresh`` keyword. The second tuple element is presumably a
    label for where the chunks came from (cache vs. remote index) — see
    data.vector_db to confirm.
    """
    return load_chunks_with_local_cache(
        index=index,
        index_name=index_name,
        cache_dir=cache_dir,
        batch_size=batch_size,
        force_refresh=force_cache_refresh,
    )
|
backend/services/chunks.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# May need extending to surface additional retrieved-chunk metadata (e.g. title, URL).
# --@Qamar
|
| 6 |
+
|
| 7 |
+
def build_retrieved_chunks(
    contexts: list[str],
    chunk_lookup: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """Attach source metadata to each retrieved context string.

    Each context is looked up (by its exact text) in ``chunk_lookup``.
    Known metadata keys are surfaced as dedicated fields; everything not in
    the promoted/dropped set is passed through under ``extra_metadata``.
    Contexts with no metadata entry fall back to neutral defaults.
    """
    if not contexts:
        return []

    # Keys that are either promoted to dedicated fields or deliberately
    # dropped, and therefore excluded from extra_metadata.
    promoted_or_dropped = {"title", "url", "chunk_index", "text", "technique", "chunking_technique"}

    enriched: list[dict[str, Any]] = []
    for rank, chunk_text in enumerate(contexts, start=1):
        meta = chunk_lookup.get(chunk_text, {})
        # First non-empty image-like field wins.
        picture = (
            meta.get("image_url")
            or meta.get("image")
            or meta.get("thumbnail_url")
            or meta.get("media_url")
        )
        enriched.append(
            {
                "rank": rank,
                "text": chunk_text,
                "source_title": meta.get("title") or "Untitled",
                "source_url": meta.get("url") or "",
                "chunk_index": meta.get("chunk_index"),
                "page": meta.get("page"),
                "section": meta.get("section"),
                "source_type": meta.get("source_type") or meta.get("source"),
                "image_url": picture,
                "extra_metadata": {
                    key: value for key, value in meta.items() if key not in promoted_or_dropped
                },
            }
        )

    return enriched
|
backend/services/models.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from fastapi import HTTPException
|
| 4 |
+
|
| 5 |
+
from models.llama_3_8b import Llama3_8B
|
| 6 |
+
from models.mistral_7b import Mistral_7b
|
| 7 |
+
from models.qwen_2_5 import Qwen2_5
|
| 8 |
+
from models.deepseek_v3 import DeepSeek_V3
|
| 9 |
+
from models.tiny_aya import TinyAya
|
| 10 |
+
|
| 11 |
+
# Model definitions (wrappers copied from /models).
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def build_models(hf_token: str) -> dict[str, Any]:
    """Instantiate every supported model wrapper, keyed by its public name.

    Args:
        hf_token: Hugging Face API token handed to each model wrapper.
    """
    return {
        "Llama-3-8B": Llama3_8B(token=hf_token),
        "Mistral-7B": Mistral_7b(token=hf_token),
        "Qwen-2.5": Qwen2_5(token=hf_token),
        "DeepSeek-V3": DeepSeek_V3(token=hf_token),
        "TinyAya": TinyAya(token=hf_token),
    }
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def resolve_model(name: str, models: dict[str, Any]) -> tuple[str, Any]:
|
| 26 |
+
aliases = {
|
| 27 |
+
"llama": "Llama-3-8B",
|
| 28 |
+
"mistral": "Mistral-7B",
|
| 29 |
+
"qwen": "Qwen-2.5",
|
| 30 |
+
"deepseek": "DeepSeek-V3",
|
| 31 |
+
"tinyaya": "TinyAya",
|
| 32 |
+
}
|
| 33 |
+
model_key = aliases.get(name.lower(), name)
|
| 34 |
+
if model_key not in models:
|
| 35 |
+
allowed = ", ".join(models.keys())
|
| 36 |
+
raise HTTPException(status_code=400, detail=f"Unknown model '{name}'. Use one of: {allowed}")
|
| 37 |
+
return model_key, models[model_key]
|
backend/services/startup.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from huggingface_hub import InferenceClient
|
| 7 |
+
|
| 8 |
+
from config_loader import cfg
|
| 9 |
+
from data.vector_db import get_index_by_name
|
| 10 |
+
from retriever.generator import RAGGenerator
|
| 11 |
+
from retriever.processor import ChunkProcessor
|
| 12 |
+
from retriever.retriever import HybridRetriever
|
| 13 |
+
|
| 14 |
+
from backend.services.cache import get_cache_settings, load_cached_chunks
|
| 15 |
+
from backend.services.models import build_models
|
| 16 |
+
from backend.services.title import parse_title_model_candidates
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Main runtime-initialization module: builds the pipeline objects
# (retriever, generator, models) and stores them in the shared state.
|
| 22 |
+
|
| 23 |
+
def initialize_runtime_state(state: dict[str, Any]) -> None:
    """Build every runtime pipeline object and store it in ``state``.

    Each phase (env loading, Pinecone index, chunk cache, processor,
    retriever, generator, model fleet) is individually timed, and a
    one-line summary is printed at the end.

    Args:
        state: Mutable dict (see backend.state) that receives the
            pipeline objects under the keys "index", "retriever",
            "rag_engine", "models", "chunk_lookup", "title_model_ids"
            and "title_client".

    Raises:
        RuntimeError: If HF_TOKEN or PINECONE_API_KEY is missing, or if
            the Pinecone index holds no chunks yet.
    """
    startup_start = time.perf_counter()

    dotenv_start = time.perf_counter()
    load_dotenv()
    dotenv_time = time.perf_counter() - dotenv_start

    env_start = time.perf_counter()
    hf_token = os.getenv("HF_TOKEN")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    env_time = time.perf_counter() - env_start

    # Fail fast: every downstream object needs one of these credentials.
    if not pinecone_api_key:
        raise RuntimeError("PINECONE_API_KEY not found in environment variables")
    if not hf_token:
        raise RuntimeError("HF_TOKEN not found in environment variables")

    index_name = "cbt-book-recursive"
    embed_model_name = cfg.processing.get("embedding_model", "all-MiniLM-L6-v2")
    # Env var takes precedence over config.yaml for the reranker model.
    rerank_model_name = os.getenv(
        "RERANK_MODEL_NAME",
        cfg.retrieval.get("rerank_model", "mixedbread-ai/mxbai-rerank-base-v1"),
    )
    cache_dir, force_cache_refresh = get_cache_settings()

    index_start = time.perf_counter()
    index = get_index_by_name(api_key=pinecone_api_key, index_name=index_name)
    index_time = time.perf_counter() - index_start

    # Chunks come from a local cache when present, otherwise from Pinecone
    # metadata (chunk_source reports which path was taken).
    chunks_start = time.perf_counter()
    final_chunks, chunk_source = load_cached_chunks(
        index=index,
        index_name=index_name,
        cache_dir=cache_dir,
        force_cache_refresh=force_cache_refresh,
    )
    chunk_load_time = time.perf_counter() - chunks_start

    if not final_chunks:
        raise RuntimeError("No chunks found in Pinecone metadata. Run indexing once before API mode.")

    # load_hf_embeddings=False: the API only needs the encoder for queries.
    processor_start = time.perf_counter()
    proc = ChunkProcessor(model_name=embed_model_name, verbose=False, load_hf_embeddings=False)
    processor_time = time.perf_counter() - processor_start

    retriever_start = time.perf_counter()
    retriever = HybridRetriever(
        final_chunks,
        proc.encoder,
        rerank_model_name=rerank_model_name,
        verbose=False,
    )
    retriever_time = time.perf_counter() - retriever_start

    rag_start = time.perf_counter()
    rag_engine = RAGGenerator()
    rag_time = time.perf_counter() - rag_start

    models_start = time.perf_counter()
    models = build_models(hf_token)
    models_time = time.perf_counter() - models_start

    state_start = time.perf_counter()
    # Map chunk text -> metadata (minus the text itself) so routes can
    # resolve source title/url/index for a retrieved chunk in O(1).
    # First occurrence wins for duplicate texts.
    chunk_lookup: dict[str, dict[str, Any]] = {}
    for chunk in final_chunks:
        metadata = chunk.get("metadata", {})
        text = metadata.get("text")
        if not text or text in chunk_lookup:
            continue
        meta_without_text = {k: v for k, v in metadata.items() if k != "text"}
        meta_without_text["title"] = metadata.get("title", "Untitled")
        meta_without_text["url"] = metadata.get("url", "")
        meta_without_text["chunk_index"] = metadata.get("chunk_index")
        chunk_lookup[text] = meta_without_text

    state["index"] = index
    state["retriever"] = retriever
    state["rag_engine"] = rag_engine
    state["models"] = models
    state["chunk_lookup"] = chunk_lookup
    state["title_model_ids"] = parse_title_model_candidates()
    state["title_client"] = InferenceClient(token=hf_token)
    state_time = time.perf_counter() - state_start

    startup_time = time.perf_counter() - startup_start
    print(
        f"API startup complete | chunks={len(final_chunks)} | "
        f"dotenv={dotenv_time:.3f}s | "
        f"env={env_time:.3f}s | "
        f"index={index_time:.3f}s | "
        f"cache_dir={cache_dir} | "
        f"force_cache_refresh={force_cache_refresh} | "
        f"chunk_source={chunk_source} | "
        f"chunk_load={chunk_load_time:.3f}s | "
        f"processor={processor_time:.3f}s | "
        f"rerank_model={rerank_model_name} | "
        f"retriever={retriever_time:.3f}s | "
        f"rag={rag_time:.3f}s | "
        f"models={models_time:.3f}s | "
        f"state={state_time:.3f}s | "
        f"total={startup_time:.3f}s"
    )
|
backend/services/streaming.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
# Streaming responses are sent as NDJSON; this helper serializes one dict
# into a single NDJSON line.
|
| 5 |
+
|
| 6 |
+
def to_ndjson(payload: dict[str, Any]) -> str:
    """Serialize *payload* as one NDJSON line (non-ASCII kept, trailing newline)."""
    line = json.dumps(payload, ensure_ascii=False)
    return f"{line}\n"
|
backend/services/title.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
from huggingface_hub import InferenceClient
|
| 5 |
+
|
| 6 |
+
# Helpers for resolving and generating chat titles.
# Primary path queries a Hugging Face chat model for a title;
# title_from_query() is the heuristic fallback used when all models fail.
# TODO: the candidate-model configuration in parse_title_model_candidates()
# could be made more robust.
|
| 11 |
+
|
| 12 |
+
def title_from_query(query: str) -> str:
    """Derive a short chat title from the raw query (heuristic fallback).

    Keeps up to six non-stop-word tokens, capitalizes all-lowercase ones,
    and caps the result at 80 characters. Returns "New Chat" when nothing
    usable is found.
    """
    stop_words = {
        "a", "an", "and", "are", "as", "at", "be", "by", "can", "do", "for", "from", "how",
        "i", "in", "is", "it", "me", "my", "of", "on", "or", "please", "show", "tell", "that",
        "the", "this", "to", "we", "what", "when", "where", "which", "why", "with", "you", "your",
    }

    tokens = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_/+]*", query)
    if not tokens:
        return "New Chat"

    keep: list[str] = []
    for token in tokens:
        stripped = token.strip("-_/+")
        # Drop empty leftovers and common filler words.
        if not stripped or stripped.lower() in stop_words:
            continue
        keep.append(stripped)
        if len(keep) >= 6:
            break

    # If every token was a stop word, fall back to the first raw tokens.
    picked = keep if keep else tokens[:6]
    pretty = [w.capitalize() if w.islower() else w for w in picked]
    title = " ".join(pretty).strip()
    return title[:80] if title else "New Chat"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def clean_title_text(raw: str) -> str:
    """Normalize a model-produced title.

    Strips surrounding quotes/backticks/whitespace, flattens newlines and
    repeated spaces, then caps the result at 8 words and 80 characters.
    """
    candidate = (raw or "").strip()
    candidate = candidate.replace("\n", " ").replace("\r", " ")
    candidate = re.sub(r"^[\"'`\s]+|[\"'`\s]+$", "", candidate)
    candidate = re.sub(r"\s+", " ", candidate).strip()
    tokens = candidate.split()
    if len(tokens) > 8:
        candidate = " ".join(tokens[:8])
    return candidate[:80]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def title_from_hf(query: str, client: InferenceClient, model_id: str) -> str | None:
    """Ask a Hugging Face chat model for a short title.

    Returns the cleaned title, or None when the response is empty,
    unusable, or just the default "new chat".
    """
    messages = [
        {
            "role": "system",
            "content": "You generate short chat titles. Return only a title, no punctuation at the end, no quotes.",
        },
        {
            "role": "user",
            "content": "Create a concise 3-7 word title for this user request:\n" + query,
        },
    ]

    response = client.chat_completion(
        model=model_id,
        messages=messages,
        max_tokens=24,
        temperature=0.3,
    )
    if not response or not response.choices:
        return None

    title = clean_title_text(response.choices[0].message.content or "")
    # Treat the generic default as a failure so callers try the next model.
    if not title or title.lower() == "new chat":
        return None
    return title
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def parse_title_model_candidates() -> list[str]:
    """Read the candidate title models from TITLE_MODEL_IDS (comma separated).

    Falls back to a built-in candidate list when the variable is unset, and
    to a single Llama model when the parsed list is empty.
    """
    raw = os.getenv(
        "TITLE_MODEL_IDS",
        "Qwen/Qwen2.5-1.5B-Instruct,CohereLabs/tiny-aya-global,meta-llama/Meta-Llama-3-8B-Instruct",
    )
    candidates = [part.strip() for part in raw.split(",") if part.strip()]
    if candidates:
        return candidates
    return ["meta-llama/Meta-Llama-3-8B-Instruct"]
|
backend/state.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
# Shared runtime state for the API process; populated once at startup by
# backend.services.startup.initialize_runtime_state().
state: dict[str, Any] = {}

# Keys that must be present in `state` before the API can serve requests.
REQUIRED_STATE_KEYS = ("index", "retriever", "rag_engine", "models")
|
config.yaml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------
|
| 2 |
+
# RAG CBT QUESTION-ANSWERING SYSTEM CONFIGURATION
|
| 3 |
+
# ------------------------------------------------------------------
|
| 4 |
+
|
| 5 |
+
project:
|
| 6 |
+
name: "cbt-rag-system"
|
| 7 |
+
category: "psychology"
|
| 8 |
+
doc_limit: null # Load all pages from the book
|
| 9 |
+
|
| 10 |
+
processing:
|
| 11 |
+
# Embedding model used for both vector db and evaluator similarity
|
| 12 |
+
embedding_model: "jinaai/jina-embeddings-v2-small-en"
|
| 13 |
+
# Options: sentence, recursive, semantic, fixed
|
| 14 |
+
technique: "recursive"
|
| 15 |
+
# Jina supports 8192 tokens (~32k chars), using 1000 chars for better context
|
| 16 |
+
chunk_size: 1000
|
| 17 |
+
chunk_overlap: 100
|
| 18 |
+
|
| 19 |
+
vector_db:
|
| 20 |
+
base_index_name: "cbt-book"
|
| 21 |
+
dimension: 512 # Jina outputs 512 dimensions
|
| 22 |
+
metric: "cosine"
|
| 23 |
+
batch_size: 50 # Reduced batch size for CPU processing
|
| 24 |
+
|
| 25 |
+
retrieval:
|
| 26 |
+
# Options: hybrid, semantic, bm25
|
| 27 |
+
mode: "hybrid"
|
| 28 |
+
# Options: cross-encoder, rrf
|
| 29 |
+
rerank_strategy: "cross-encoder"
|
| 30 |
+
use_mmr: true
|
| 31 |
+
top_k: 10
|
| 32 |
+
final_k: 5
|
| 33 |
+
|
| 34 |
+
generation:
|
| 35 |
+
  temperature: 0.0
|
| 36 |
+
max_new_tokens: 512
|
| 37 |
+
# The model used to Judge the others (OpenRouter)
|
| 38 |
+
judge_model: "stepfun/step-3.5-flash:free"
|
| 39 |
+
|
| 40 |
+
# List of contestants in the tournament
|
| 41 |
+
models:
|
| 42 |
+
- "Llama-3-8B"
|
| 43 |
+
- "Mistral-7B"
|
| 44 |
+
- "Qwen-2.5"
|
| 45 |
+
- "DeepSeek-V3"
|
| 46 |
+
- "TinyAya"
|
config_loader.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
class RAGConfig:
    """Thin accessor over the YAML configuration file.

    The file is parsed once at construction; each property exposes one
    top-level section of the config as a plain dict/list.
    """

    def __init__(self, config_path="config.yaml"):
        # Explicit UTF-8 so the config parses identically on every platform
        # (the default locale encoding differs, e.g. on Windows).
        with Path(config_path).open("r", encoding="utf-8") as f:
            self.data = yaml.safe_load(f)

    @property
    def project(self):
        """The `project` section (name, category, doc_limit)."""
        return self.data['project']

    @property
    def processing(self):
        """The `processing` section (embedding model, chunking parameters)."""
        return self.data['processing']

    @property
    def db(self):
        """The `vector_db` section (index name, dimension, metric, batch size)."""
        return self.data['vector_db']

    @property
    def retrieval(self):
        """The `retrieval` section (mode, rerank strategy, top-k settings)."""
        return self.data['retrieval']

    @property
    def gen(self):
        """The `generation` section (temperature, token limit, judge model)."""
        return self.data['generation']

    @property
    def model_list(self):
        """The list of contestant model names for the tournament."""
        return self.data['models']


# Module-level singleton imported as `from config_loader import cfg`.
cfg = RAGConfig()
|
main.py
ADDED
|
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from multiprocessing import Pool, cpu_count
|
| 6 |
+
from functools import partial
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from config_loader import cfg
|
| 9 |
+
|
| 10 |
+
from data.vector_db import get_pinecone_index, refresh_pinecone_index
|
| 11 |
+
from retriever.retriever import HybridRetriever
|
| 12 |
+
from retriever.generator import RAGGenerator
|
| 13 |
+
from retriever.processor import ChunkProcessor
|
| 14 |
+
from retriever.evaluator import RAGEvaluator
|
| 15 |
+
from data.data_loader import load_cbt_book, get_book_stats
|
| 16 |
+
from data.ingest import ingest_data, CHUNKING_TECHNIQUES
|
| 17 |
+
|
| 18 |
+
# Import model fleet
|
| 19 |
+
from models.llama_3_8b import Llama3_8B
|
| 20 |
+
from models.mistral_7b import Mistral_7b
|
| 21 |
+
from models.qwen_2_5 import Qwen2_5
|
| 22 |
+
from models.deepseek_v3 import DeepSeek_V3
|
| 23 |
+
from models.tiny_aya import TinyAya
|
| 24 |
+
|
| 25 |
+
# Registry of tournament contestants: config model name -> wrapper class.
# Keys must match the `models` list in config.yaml.
MODEL_MAP = {
    "Llama-3-8B": Llama3_8B,
    "Mistral-7B": Mistral_7b,
    "Qwen-2.5": Qwen2_5,
    "DeepSeek-V3": DeepSeek_V3,
    "TinyAya": TinyAya
}

# Load .env early so credentials (e.g. HF_TOKEN) are available at import time.
load_dotenv()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def run_rag_for_technique(technique_name, query, index, encoder, models, evaluator, rag_engine):
    """Run RAG pipeline for a specific chunking technique.

    Retrieves candidates from Pinecone filtered by technique, reranks them
    with the per-worker cross-encoder, then runs every model in *models*
    over the top-5 context and scores the answers.

    Returns:
        Dict mapping model name -> result dict with "answer",
        "Faithfulness", "Relevancy", "Claims", context info, and "error"
        when a model failed. Empty dict when no chunks match the technique.
    """

    print(f"\n{'='*80}")
    print(f"TECHNIQUE: {technique_name.upper()}")
    print(f"{'='*80}")

    # Filter chunks by technique metadata
    query_vector = encoder.encode(query).tolist()

    # Query with metadata filter for this technique - get more candidates for reranking
    res = index.query(
        vector=query_vector,
        top_k=25,
        include_metadata=True,
        filter={"technique": {"$eq": technique_name}}
    )

    # Extract context chunks with URLs
    all_candidates = []
    chunk_urls = []
    for match in res['matches']:
        all_candidates.append(match['metadata']['text'])
        chunk_urls.append(match['metadata'].get('url', ''))

    print(f"\nRetrieved {len(all_candidates)} candidate chunks for technique '{technique_name}'")

    if not all_candidates:
        print(f"WARNING: No chunks found for technique '{technique_name}'")
        return {}

    # Apply cross-encoder reranking to get top 5.
    # Use global reranker loaded once per worker (set by init_worker;
    # NOTE(review): calling this outside a pool worker leaves it None).
    global _worker_reranker
    pairs = [[query, chunk] for chunk in all_candidates]
    scores = _worker_reranker.predict(pairs)
    ranked = sorted(zip(all_candidates, chunk_urls, scores), key=lambda x: x[2], reverse=True)
    context_chunks = [chunk for chunk, _, _ in ranked[:5]]
    context_urls = [url for _, url, _ in ranked[:5]]

    print(f"After reranking: {len(context_chunks)} chunks (top 5)")

    # Print the final RAG context being passed to models (only once)
    print(f"\n{'='*80}")
    print(f"📚 FINAL RAG CONTEXT FOR TECHNIQUE '{technique_name.upper()}'")
    print(f"{'='*80}")
    for i, chunk in enumerate(context_chunks, 1):
        print(f"\n[Chunk {i}] ({len(chunk)} chars):")
        print(f"{'─'*60}")
        print(chunk)
        print(f"{'─'*60}")
    print(f"\n{'='*80}")

    # Run model tournament for this technique
    tournament_results = {}

    for name, model_inst in models.items():
        print(f"\n{'-'*60}")
        print(f"Model: {name}")
        print(f"{'-'*60}")
        try:
            # Generation
            answer = rag_engine.get_answer(
                model_inst, query, context_chunks,
                context_urls=context_urls,
                temperature=cfg.gen['temperature']
            )

            print(f"\n{'─'*60}")
            print(f"📝 FULL ANSWER from {name}:")
            print(f"{'─'*60}")
            print(answer)
            print(f"{'─'*60}")

            # Faithfulness Evaluation (strict=False reduces API calls from ~22 to ~3 per eval)
            faith = evaluator.evaluate_faithfulness(answer, context_chunks, strict=False)
            # Relevancy Evaluation
            rel = evaluator.evaluate_relevancy(query, answer)

            tournament_results[name] = {
                "answer": answer,
                "Faithfulness": faith['score'],
                "Relevancy": rel['score'],
                "Claims": faith['details'],
                "context_chunks": context_chunks,
                "context_urls": context_urls
            }

            print(f"\n📊 EVALUATION SCORES:")
            print(f"  Faithfulness: {faith['score']:.1f}%")
            print(f"  Relevancy: {rel['score']:.3f}")
            print(f"  Combined: {faith['score'] + rel['score']:.3f}")

        except Exception as e:
            # A failed model is recorded with zero scores rather than
            # aborting the whole tournament.
            print(f"  Error evaluating {name}: {e}")
            tournament_results[name] = {
                "answer": "",
                "Faithfulness": 0,
                "Relevancy": 0,
                "Claims": [],
                "error": str(e),
                "context_chunks": context_chunks,
                "context_urls": context_urls
            }

    return tournament_results
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def generate_findings_document(all_query_results, queries, output_file="rag_ablation_findings.md"):
    """Generate detailed markdown document with findings from all techniques across all queries.

    Aggregates per-query results into per-technique/per-model averages,
    ranks the techniques by combined (faithfulness + relevancy) score, and
    writes the full report to *output_file*.

    Args:
        all_query_results: Dict mapping query index to results dict
            (technique name -> model name -> result dict).
        queries: List of all test queries
        output_file: Path to output file

    Returns:
        The output file path that was written.
    """

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    content = f"""# RAG Ablation Study Findings

*Generated:* {timestamp}

## Overview

This document presents findings from a comparative analysis of 6 different chunking techniques
applied to a Cognitive Behavioral Therapy (CBT) book. Each technique was evaluated using
multiple LLM models with RAG (Retrieval-Augmented Generation) pipeline.

## Test Queries

"""

    for i, query in enumerate(queries, 1):
        content += f"{i}. {query}\n"

    content += """
## Chunking Techniques Evaluated

1. *Fixed* - Fixed-size chunking (1000 chars, 100 overlap)
2. *Sentence* - Sentence-level chunking (NLTK)
3. *Paragraph* - Paragraph-level chunking (\\n\\n boundaries)
4. *Semantic* - Semantic chunking (embedding similarity)
5. *Recursive* - Recursive chunking (hierarchical separators)
6. *Page* - Page-level chunking (--- Page markers)

## Results by Technique (Aggregated Across All Queries)

"""

    # Aggregate results across all queries:
    # technique -> model -> lists of scores/answers (one entry per query).
    aggregated_results = {}

    for query_idx, query_results in all_query_results.items():
        for technique_name, model_results in query_results.items():
            if technique_name not in aggregated_results:
                aggregated_results[technique_name] = {}

            for model_name, results in model_results.items():
                if model_name not in aggregated_results[technique_name]:
                    # Context is stored once from the first query seen.
                    aggregated_results[technique_name][model_name] = {
                        'Faithfulness': [],
                        'Relevancy': [],
                        'answers': [],
                        'context_chunks': results.get('context_chunks', []),
                        'context_urls': results.get('context_urls', [])
                    }

                aggregated_results[technique_name][model_name]['Faithfulness'].append(results.get('Faithfulness', 0))
                aggregated_results[technique_name][model_name]['Relevancy'].append(results.get('Relevancy', 0))
                aggregated_results[technique_name][model_name]['answers'].append(results.get('answer', ''))

    # Add results for each technique
    for technique_name, model_results in aggregated_results.items():
        content += f"### {technique_name.upper()} Chunking\n\n"

        if not model_results:
            content += "No results available for this technique.\n\n"
            continue

        # Create results table with averaged scores
        content += "| Model | Avg Faithfulness | Avg Relevancy | Avg Combined |\n"
        content += "|-------|------------------|---------------|--------------|\n"

        for model_name, results in model_results.items():
            avg_faith = sum(results['Faithfulness']) / len(results['Faithfulness']) if results['Faithfulness'] else 0
            avg_rel = sum(results['Relevancy']) / len(results['Relevancy']) if results['Relevancy'] else 0
            avg_combined = avg_faith + avg_rel
            content += f"| {model_name} | {avg_faith:.1f}% | {avg_rel:.3f} | {avg_combined:.3f} |\n"

        # Find best model for this technique (highest combined average)
        if model_results:
            best_model = max(
                model_results.items(),
                key=lambda x: (sum(x[1]['Faithfulness']) / len(x[1]['Faithfulness']) if x[1]['Faithfulness'] else 0) +
                              (sum(x[1]['Relevancy']) / len(x[1]['Relevancy']) if x[1]['Relevancy'] else 0)
            )
            best_name = best_model[0]
            best_faith = sum(best_model[1]['Faithfulness']) / len(best_model[1]['Faithfulness']) if best_model[1]['Faithfulness'] else 0
            best_rel = sum(best_model[1]['Relevancy']) / len(best_model[1]['Relevancy']) if best_model[1]['Relevancy'] else 0

            content += f"\n*Best Model:* {best_name} (Avg Faithfulness: {best_faith:.1f}%, Avg Relevancy: {best_rel:.3f})\n\n"

        # Show context chunks once per technique (not per model)
        context_chunks = None
        context_urls = None
        for model_name, results in model_results.items():
            if results.get('context_chunks'):
                context_chunks = results['context_chunks']
                context_urls = results.get('context_urls', [])
                break

        if context_chunks:
            content += "#### Context Chunks Used\n\n"
            for i, chunk in enumerate(context_chunks, 1):
                url = context_urls[i-1] if context_urls and i-1 < len(context_urls) else ""
                if url:
                    content += f"*Chunk {i}* ([Source]({url})):\n"
                else:
                    content += f"*Chunk {i}*:\n"
                content += f"\n{chunk}\n\n\n"

        # Add detailed RAG results for each model
        content += "#### Detailed RAG Results\n\n"

        for model_name, results in model_results.items():
            answers = results.get('answers', [])
            avg_faith = sum(results['Faithfulness']) / len(results['Faithfulness']) if results['Faithfulness'] else 0
            avg_rel = sum(results['Relevancy']) / len(results['Relevancy']) if results['Relevancy'] else 0

            content += f"*{model_name}* (Avg Faithfulness: {avg_faith:.1f}%, Avg Relevancy: {avg_rel:.3f})\n\n"

            # Show answers from each query
            for q_idx, answer in enumerate(answers):
                content += f"📝 *Answer for Query {q_idx + 1}:*\n\n"
                content += f"\n{answer}\n\n\n"

        content += "---\n\n"

    # Add comparative analysis
    content += """## Comparative Analysis

### Overall Performance Ranking (Across All Queries)

| Rank | Technique | Avg Faithfulness | Avg Relevancy | Avg Combined |
|------|-----------|------------------|---------------|--------------|
"""

    # Calculate averages for each technique across all queries
    technique_averages = {}
    for technique_name, model_results in aggregated_results.items():
        if model_results:
            all_faith = []
            all_rel = []
            for model_name, results in model_results.items():
                all_faith.extend(results['Faithfulness'])
                all_rel.extend(results['Relevancy'])

            avg_faith = sum(all_faith) / len(all_faith) if all_faith else 0
            avg_rel = sum(all_rel) / len(all_rel) if all_rel else 0
            avg_combined = avg_faith + avg_rel
            technique_averages[technique_name] = {
                'faith': avg_faith,
                'rel': avg_rel,
                'combined': avg_combined
            }

    # Sort by combined score
    sorted_techniques = sorted(
        technique_averages.items(),
        key=lambda x: x[1]['combined'],
        reverse=True
    )

    for rank, (technique_name, averages) in enumerate(sorted_techniques, 1):
        content += f"| {rank} | {technique_name} | {averages['faith']:.1f}% | {averages['rel']:.3f} | {averages['combined']:.3f} |\n"

    content += """
### Key Findings

"""

    if sorted_techniques:
        best_technique = sorted_techniques[0][0]
        worst_technique = sorted_techniques[-1][0]

        content += f"""
1. *Best Performing Technique:* {best_technique}
   - Achieved highest combined score across all models and queries
   - Recommended for production RAG applications

2. *Worst Performing Technique:* {worst_technique}
   - Lower combined scores across models and queries
   - May need optimization or different configuration

3. *Model Consistency:*
   - Analyzed which models perform consistently across techniques
   - Identified technique-specific model preferences

"""

    content += """## Recommendations

Based on the ablation study results:

1. *Primary Recommendation:* Use the best-performing chunking technique for your specific use case
2. *Hybrid Approach:* Consider combining techniques for different types of queries
3. *Model Selection:* Choose models that perform well across multiple techniques
4. *Parameter Tuning:* Optimize chunk sizes and overlaps based on document characteristics

## Technical Details

- *Embedding Model:* Jina embeddings (512 dimensions)
- *Vector Database:* Pinecone (serverless, AWS us-east-1)
- *Judge Model:* Openrouter Free models
- *Retrieval:* Top 5 chunks per technique
- *Evaluation Metrics:* Faithfulness (context grounding), Relevancy (query addressing)

---

This report was automatically generated by the RAG Ablation Study Pipeline.
"""

    # Write to file
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(content)

    print(f"\nFindings document saved to: {output_file}")
    return output_file
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# Global variables for worker processes.
# Each is initialized exactly once per worker by init_worker() and then
# reused by run_rag_for_technique() to avoid reloading models per task.
_worker_proc = None        # ChunkProcessor (embedding encoder)
_worker_evaluator = None   # RAGEvaluator (judge + similarity scorer)
_worker_models = None      # dict: model name -> model instance
_worker_rag_engine = None  # RAGGenerator
_worker_reranker = None    # CrossEncoder used for reranking candidates
|
| 373 |
+
|
| 374 |
+
def init_worker(model_name, evaluator_config):
    """One-time per-process setup for the multiprocessing pool.

    Loads the embedding encoder, the LLM judge, the contestant model fleet,
    the RAG generator and the cross-encoder reranker into module-level
    globals, so each pool worker pays the model-loading cost exactly once.

    Args:
        model_name: Embedding model identifier for ChunkProcessor.
        evaluator_config: dict with keys 'judge_model', 'api_key' and
            'model_list' (names that must exist in the model registry below).
    """
    global _worker_proc, _worker_evaluator, _worker_models, _worker_rag_engine, _worker_reranker

    from retriever.processor import ChunkProcessor
    from retriever.evaluator import RAGEvaluator
    from retriever.generator import RAGGenerator
    from sentence_transformers import CrossEncoder
    from models.llama_3_8b import Llama3_8B
    from models.mistral_7b import Mistral_7b
    from models.qwen_2_5 import Qwen2_5
    from models.deepseek_v3 import DeepSeek_V3
    from models.tiny_aya import TinyAya

    model_registry = {
        "Llama-3-8B": Llama3_8B,
        "Mistral-7B": Mistral_7b,
        "Qwen-2.5": Qwen2_5,
        "DeepSeek-V3": DeepSeek_V3,
        "TinyAya": TinyAya,
    }

    # Embedding encoder, shared by retrieval and evaluation.
    _worker_proc = ChunkProcessor(model_name=model_name, verbose=False)

    # LLM-as-judge evaluator reusing the worker's encoder.
    _worker_evaluator = RAGEvaluator(
        judge_model=evaluator_config['judge_model'],
        embedding_model=_worker_proc.encoder,
        api_key=evaluator_config['api_key'],
    )

    # Contestant model fleet (only the models requested by the config).
    hf_token = os.getenv("HF_TOKEN")
    _worker_models = {
        name: model_registry[name](token=hf_token)
        for name in evaluator_config['model_list']
    }

    _worker_rag_engine = RAGGenerator()

    # Cross-encoder reranker, loaded once per worker.
    _worker_reranker = CrossEncoder('jinaai/jina-reranker-v1-tiny-en')
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def run_rag_for_technique_wrapper(args):
    """Pool-friendly adapter around run_rag_for_technique.

    Unpacks the argument tuple, opens a fresh Pinecone connection inside the
    worker process (client connections cannot be pickled across processes),
    and returns a (technique_name, results) pair. On any failure the full
    traceback is printed and an empty result dict is returned, so one broken
    technique cannot abort the whole query batch.
    """
    global _worker_proc, _worker_evaluator, _worker_models, _worker_rag_engine

    technique, query, index_name, pinecone_key = args
    try:
        # Fresh connection inside this worker process.
        from data.vector_db import get_index_by_name
        index = get_index_by_name(pinecone_key, index_name)

        results = run_rag_for_technique(
            technique_name=technique['name'],
            query=query,
            index=index,
            encoder=_worker_proc.encoder,
            models=_worker_models,
            evaluator=_worker_evaluator,
            rag_engine=_worker_rag_engine,
        )
        return technique['name'], results
    except Exception as e:
        import traceback
        print(f"\n✗ Error processing technique {technique['name']}: {e}")
        print(f"Full traceback:")
        traceback.print_exc()
        return technique['name'], {}
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _print_query_summary(query_idx, query_results):
    """Print a per-technique score table for one query's results.

    Args:
        query_idx: 0-based query index (displayed 1-based).
        query_results: dict technique name -> {model name -> score dict}.
    """
    print(f"\n{'='*80}")
    print(f"QUERY {query_idx + 1} SUMMARY")
    print(f"{'='*80}")
    print(f"\n{'Technique':<15} {'Avg Faith':>12} {'Avg Rel':>12} {'Best Model':<20}")
    print("-" * 60)

    for technique_name, model_results in query_results.items():
        if model_results:
            avg_faith = sum(r.get('Faithfulness', 0) for r in model_results.values()) / len(model_results)
            avg_rel = sum(r.get('Relevancy', 0) for r in model_results.values()) / len(model_results)
            # Best model = highest combined (faithfulness + relevancy) score.
            best_name = max(
                model_results.items(),
                key=lambda item: item[1].get('Faithfulness', 0) + item[1].get('Relevancy', 0)
            )[0]
            print(f"{technique_name:<15} {avg_faith:>11.1f}% {avg_rel:>12.3f} {best_name:<20}")
        else:
            print(f"{technique_name:<15} {'N/A':>12} {'N/A':>12} {'N/A':<20}")

    print("-" * 60)


def _print_final_summary(all_query_results):
    """Print per-technique averages aggregated across every query."""
    print("\n" + "-" * 60)
    print(f"{'Technique':<15} {'Avg Faith':>12} {'Avg Rel':>12} {'Best Model':<20}")
    print("-" * 60)

    for tech_config in CHUNKING_TECHNIQUES:
        tech_name = tech_config['name']
        all_faith = []
        all_rel = []
        best_model_name = None
        best_combined = 0

        for query_results in all_query_results.values():
            model_results = query_results.get(tech_name) or {}
            for model_name, results in model_results.items():
                faith = results.get('Faithfulness', 0)
                rel = results.get('Relevancy', 0)
                all_faith.append(faith)
                all_rel.append(rel)
                if faith + rel > best_combined:
                    best_combined = faith + rel
                    best_model_name = model_name

        if all_faith:
            avg_faith = sum(all_faith) / len(all_faith)
            avg_rel = sum(all_rel) / len(all_rel)
            print(f"{tech_name:<15} {avg_faith:>11.1f}% {avg_rel:>12.3f} {best_model_name or 'N/A':<20}")
        else:
            print(f"{tech_name:<15} {'N/A':>12} {'N/A':>12} {'N/A':<20}")

    print("-" * 60)


def main():
    """Run the RAG ablation study across all 6 chunking techniques.

    Pipeline:
      1. Reuse existing vectors in the per-technique Pinecone index, or
         run full ingestion when the index is missing/empty.
      2. Initialize models and the evaluator in the parent process.
      3. For each test query, evaluate every technique in a process pool.
      4. Generate the findings document and print summaries.

    Raises:
        RuntimeError: if HF_TOKEN, PINECONE_API_KEY or OPENROUTER_API_KEY
            is missing from the environment.

    Returns:
        dict mapping query index -> {technique name -> {model -> scores}}.
    """
    hf_token = os.getenv("HF_TOKEN")
    pinecone_key = os.getenv("PINECONE_API_KEY")
    openrouter_key = os.getenv("OPENROUTER_API_KEY")

    # Fail fast on missing credentials (checked once, up front).
    if not hf_token:
        raise RuntimeError("HF_TOKEN not found in environment variables")
    if not pinecone_key:
        raise RuntimeError("PINECONE_API_KEY not found in environment variables")
    if not openrouter_key:
        raise RuntimeError("OPENROUTER_API_KEY not found in environment variables")

    test_queries = [
        "What is cognitive behavior therapy and how does it work?",
        "What are the common cognitive distortions in CBT?",
        "How does CBT help with anxiety and depression?"
    ]

    print("=" * 80)
    print("RAG ABLATION STUDY - 6 CHUNKING TECHNIQUES")
    print("=" * 80)
    print(f"\nTest Queries:")
    for i, q in enumerate(test_queries, 1):
        print(f" {i}. {q}")

    # Step 1: Reuse existing vectors when present; ingest otherwise.
    print("\n" + "=" * 80)
    print("STEP 1: CHECKING/INGESTING DATA WITH ALL 6 TECHNIQUES")
    print("=" * 80)

    from data.vector_db import get_index_by_name
    index_name = f"{cfg.db['base_index_name']}-{cfg.processing['technique']}"

    print(f"\nChecking for existing index: {index_name}")

    try:
        print("Connecting to Pinecone...")
        existing_index = get_index_by_name(pinecone_key, index_name)
        print("Getting index stats...")
        stats = existing_index.describe_index_stats()
        existing_count = stats.get('total_vector_count', 0)

        if existing_count > 0:
            print(f"\n✓ Found existing index with {existing_count} vectors")
            print("Skipping ingestion - using existing data")

            # Only the embedding model is needed for retrieval over
            # already-ingested vectors.
            print("Loading embedding model for retrieval...")
            from retriever.processor import ChunkProcessor
            proc = ChunkProcessor(model_name=cfg.processing['embedding_model'], verbose=False)
            index = existing_index
            # Chunk lists stay empty: retrieval uses the vectors in Pinecone.
            all_chunks = []
            final_chunks = []
            print("✓ Processor initialized")
        else:
            print("\nIndex exists but is empty. Running full ingestion...")
            all_chunks, final_chunks, proc, index = ingest_data()
    except Exception as e:
        print(f"\nIndex check failed: {e}")
        print("Running full ingestion...")
        all_chunks, final_chunks, proc, index = ingest_data()

    print(f"\nTechniques to evaluate: {[tech['name'] for tech in CHUNKING_TECHNIQUES]}")

    # Step 2: Initialize models and evaluator.
    print("\n" + "=" * 80)
    print("STEP 2: INITIALIZING COMPONENTS")
    print("=" * 80)

    # NOTE(review): the pool workers build their own models/evaluator in
    # init_worker(); these parent-process instances are kept for parity with
    # the original flow but appear unused by the parallel path — confirm
    # before removing.
    print("\nInitializing models...")
    rag_engine = RAGGenerator()
    models = {name: MODEL_MAP[name](token=hf_token) for name in cfg.model_list}

    print("Initializing evaluator...")
    evaluator = RAGEvaluator(
        judge_model=cfg.gen['judge_model'],
        embedding_model=proc.encoder,
        api_key=openrouter_key
    )

    # Step 3: Evaluate all techniques in parallel, one pool per query.
    print("\n" + "=" * 80)
    print("STEP 3: RUNNING RAG FOR ALL 6 TECHNIQUES (IN PARALLEL)")
    print("=" * 80)

    num_processes = min(cpu_count(), len(CHUNKING_TECHNIQUES))
    print(f"\nUsing {num_processes} parallel processes for {len(CHUNKING_TECHNIQUES)} techniques")

    evaluator_config = {
        'judge_model': cfg.gen['judge_model'],
        'api_key': openrouter_key,
        'model_list': cfg.model_list
    }

    all_query_results = {}

    for query_idx, query in enumerate(test_queries):
        print(f"\n{'='*80}")
        print(f"PROCESSING QUERY {query_idx + 1}/{len(test_queries)}")
        print(f"Query: {query}")
        print(f"{'='*80}")

        with Pool(
            processes=num_processes,
            initializer=init_worker,
            initargs=(cfg.processing['embedding_model'], evaluator_config)
        ) as pool:
            args_list = [
                (technique, query, index_name, pinecone_key)
                for technique in CHUNKING_TECHNIQUES
            ]
            results_list = pool.map(run_rag_for_technique_wrapper, args_list)

        all_query_results[query_idx] = dict(results_list)
        _print_query_summary(query_idx, all_query_results[query_idx])

    # Step 4: Findings document across all queries.
    print("\n" + "=" * 80)
    print("STEP 4: GENERATING FINDINGS DOCUMENT")
    print("=" * 80)

    findings_file = generate_findings_document(all_query_results, test_queries)

    # Step 5: Final summary.
    print("\n" + "=" * 80)
    print("ABLATION STUDY COMPLETE - SUMMARY")
    print("=" * 80)

    print(f"\nQueries processed: {len(test_queries)}")
    print(f"Techniques evaluated: {len(CHUNKING_TECHNIQUES)}")
    print(f"Models tested: {len(cfg.model_list)}")
    print(f"\nFindings document: {findings_file}")

    _print_final_summary(all_query_results)

    print("\n✓ Ablation study complete!")
    print(f"✓ Results saved to: {findings_file}")
    print("\nYou can now analyze the findings document to compare chunking techniques.")

    return all_query_results
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
# Script entry point: run the full ablation study.
if __name__ == "__main__":
    main()
|
main_easy.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from config_loader import cfg
|
| 5 |
+
|
| 6 |
+
# Optimized imports - only what we need for Retrieval and Generation
|
| 7 |
+
from data.vector_db import get_index_by_name, load_chunks_from_pinecone # Using the new helper
|
| 8 |
+
from retriever.retriever import HybridRetriever
|
| 9 |
+
from retriever.generator import RAGGenerator
|
| 10 |
+
from retriever.processor import ChunkProcessor
|
| 11 |
+
from retriever.evaluator import RAGEvaluator
|
| 12 |
+
|
| 13 |
+
# Model Fleet
|
| 14 |
+
from models.llama_3_8b import Llama3_8B
|
| 15 |
+
from models.mistral_7b import Mistral_7b
|
| 16 |
+
from models.qwen_2_5 import Qwen2_5
|
| 17 |
+
from models.deepseek_v3 import DeepSeek_V3
|
| 18 |
+
from models.tiny_aya import TinyAya
|
| 19 |
+
|
| 20 |
+
# Registry mapping config model names to their client wrapper classes.
MODEL_MAP = {
    "Llama-3-8B": Llama3_8B,
    "Mistral-7B": Mistral_7b,
    "Qwen-2.5": Qwen2_5,
    "DeepSeek-V3": DeepSeek_V3,
    "TinyAya": TinyAya,
}

# Pull credentials (HF_TOKEN, PINECONE_API_KEY, ...) from a local .env file.
load_dotenv()
|
| 29 |
+
|
| 30 |
+
def main():
    """Run a retrieval + generation tournament against an existing index.

    Skips ingestion entirely: connects to the pre-built Pinecone index,
    rebuilds the BM25 corpus from index metadata, retrieves context for a
    fixed query, then has every model in cfg.model_list answer and be
    scored by the judge.

    Raises:
        RuntimeError: if required API keys are missing from the environment.
    """
    hf_token = os.getenv("HF_TOKEN")
    pinecone_key = os.getenv("PINECONE_API_KEY")
    openrouter_key = os.getenv("OPENROUTER_API_KEY")

    # Fail fast on missing credentials, mirroring main.py.
    if not pinecone_key:
        raise RuntimeError("PINECONE_API_KEY not found in environment variables")
    if not openrouter_key:
        raise RuntimeError("OPENROUTER_API_KEY not found in environment variables")

    query = "How do transformers handle long sequences?"

    # 1. Connect to the existing per-technique index (no creation/upload).
    index_name = f"{cfg.db['base_index_name']}-{cfg.processing['technique']}"
    index = get_index_by_name(pinecone_key, index_name)

    # 2. Setup Processor (required for the encoder/embedding model).
    proc = ChunkProcessor(model_name=cfg.processing['embedding_model'])

    # 3. Load BM25 corpus (the "source of truth").
    #    Note: this hits Pinecone on every run; a local pickle cache would
    #    make it near-instant.
    print("🔄 Loading BM25 context from Pinecone metadata...")
    final_chunks = load_chunks_from_pinecone(index)

    # 4. Hybrid retrieval (dense + BM25, with optional rerank/MMR).
    retriever = HybridRetriever(final_chunks, proc.encoder)

    print(f"🔎 Searching via {cfg.retrieval['mode']} mode...")
    context_chunks = retriever.search(
        query, index,
        mode=cfg.retrieval['mode'],
        rerank_strategy=cfg.retrieval['rerank_strategy'],
        use_mmr=cfg.retrieval['use_mmr'],
        top_k=cfg.retrieval['top_k'],
        final_k=cfg.retrieval['final_k']
    )

    # 5. Contestants + judge.
    rag_engine = RAGGenerator()
    models = {name: MODEL_MAP[name](token=hf_token) for name in cfg.model_list}

    # Bug fix: the judge (GroqJudge in retriever/evaluator.py) talks to
    # OpenRouter, so it needs OPENROUTER_API_KEY — previously this passed
    # GROQ_API_KEY, which OpenRouter would reject.
    evaluator = RAGEvaluator(
        judge_model=cfg.gen['judge_model'],
        embedding_model=proc.encoder,
        api_key=openrouter_key
    )

    tournament_results = {}

    # 6. Each model answers the same query over the same context and is scored.
    for name, model_inst in models.items():
        print(f"\n🏆 Tournament: {name} is generating...")
        try:
            # Generation
            answer = rag_engine.get_answer(
                model_inst, query, context_chunks,
                temperature=cfg.gen['temperature']
            )

            # Faithfulness: is the answer grounded in the retrieved context?
            faith = evaluator.evaluate_faithfulness(answer, context_chunks)
            # Relevancy: does the answer address the query?
            rel = evaluator.evaluate_relevancy(query, answer)

            tournament_results[name] = {
                "Answer": answer[:100] + "...",  # preview only
                "Faithfulness": faith['score'],
                "Relevancy": rel['score']
            }
            print(f"✅ {name} Score - Faith: {faith['score']} | Rel: {rel['score']}")

        except Exception as e:
            # One failing model must not end the tournament.
            print(f"❌ Error evaluating {name}: {e}")

    print("\n--- Final Tournament Standings ---")
    for name, scores in tournament_results.items():
        print(f"{name}: F={scores['Faithfulness']}, R={scores['Relevancy']}")

if __name__ == "__main__":
    main()
|
models/deepseek_v3.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
|
| 3 |
+
class DeepSeek_V3:
    """Streaming chat wrapper for DeepSeek-V3 via the HF Inference API."""

    def __init__(self, token):
        # Client authenticated with the caller's Hugging Face token.
        self.client = InferenceClient(token=token)
        self.model_id = "deepseek-ai/DeepSeek-V3"

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield response text chunks; on API failure yield one error marker."""
        try:
            stream = self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            )
            for event in stream:
                if not event.choices:
                    continue
                delta = event.choices[0].delta.content
                if delta:
                    yield delta
        except Exception as e:
            yield f" DeepSeek API Busy: {e}"

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking helper: join all streamed chunks into one string."""
        return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
|
models/llama_3_8b.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
|
| 3 |
+
class Llama3_8B:
    """Streaming chat wrapper for Meta-Llama-3-8B-Instruct via the HF Inference API."""

    def __init__(self, token):
        """
        Args:
            token: Hugging Face API token used to authenticate the client.
        """
        self.client = InferenceClient(token=token)
        self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield response text chunks as they arrive.

        Consistency/robustness fix: on API failure a single error-marker
        chunk is yielded instead of raising, matching the other wrappers
        (DeepSeek_V3, Mistral_7b, TinyAya) so a mid-stream failure cannot
        crash a caller's streaming loop.
        """
        try:
            for message in self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            ):
                if message.choices:
                    content = message.choices[0].delta.content
                    if content:
                        yield content
        except Exception as e:
            # Previously this wrapper was the only one (besides Qwen2_5) that
            # propagated API errors instead of degrading gracefully.
            yield f" Llama API Error: {e}"

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking helper: join all streamed chunks into one string."""
        return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
|
models/mistral_7b.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
class Mistral_7b:
    """Streaming chat wrapper for a Mistral instruct model via the HF Inference API."""

    def __init__(self, token):
        # This client uses the OpenAI-compatible chat.completions surface.
        self.client = InferenceClient(api_key=token)
        # Provider-suffixed ids (e.g. :featherless-ai) are not valid HF repo ids.
        # Keep a sane default and allow override via env for experimentation.
        self.model_id = os.getenv("MISTRAL_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.2")

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield response text chunks; on API failure yield one error marker."""
        try:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            )
            for chunk in response:
                choices = chunk.choices
                if choices and choices[0].delta.content:
                    yield choices[0].delta.content

        except Exception as e:
            yield f" Mistral Featherless Error: {e}"

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking helper: join all streamed chunks into one string."""
        return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
|
models/qwen_2_5.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
|
| 3 |
+
class Qwen2_5:
    """Streaming chat wrapper for Qwen2.5-72B-Instruct via the HF Inference API."""

    def __init__(self, token):
        """
        Args:
            token: Hugging Face API token used to authenticate the client.
        """
        self.client = InferenceClient(token=token)
        self.model_id = "Qwen/Qwen2.5-72B-Instruct"

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield response text chunks as they arrive.

        Consistency/robustness fix: on API failure a single error-marker
        chunk is yielded instead of raising, matching the other wrappers
        (DeepSeek_V3, Mistral_7b, TinyAya) so a mid-stream failure cannot
        crash a caller's streaming loop.
        """
        try:
            for message in self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            ):
                if message.choices:
                    content = message.choices[0].delta.content
                    if content:
                        yield content
        except Exception as e:
            # Previously this wrapper was the only one (besides Llama3_8B)
            # that propagated API errors instead of degrading gracefully.
            yield f" Qwen API Error: {e}"

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking helper: join all streamed chunks into one string."""
        return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
|
models/tiny_aya.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
|
| 3 |
+
class TinyAya:
    """Streaming chat wrapper for CohereLabs tiny-aya-global via the HF Inference API."""

    def __init__(self, token):
        # Client authenticated with the caller's Hugging Face token.
        self.client = InferenceClient(token=token)
        self.model_id = "CohereLabs/tiny-aya-global"

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield response text chunks; on API failure yield one error marker."""
        try:
            stream = self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            )
            for event in stream:
                if not event.choices:
                    continue
                delta = event.choices[0].delta.content
                if delta:
                    yield delta
        except Exception as e:
            yield f" TinyAya Error: {e}"

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking helper: join all streamed chunks into one string."""
        return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
|
requirements.txt
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiohappyeyeballs==2.6.1
|
| 2 |
+
aiohttp==3.13.3
|
| 3 |
+
aiosignal==1.4.0
|
| 4 |
+
annotated-doc==0.0.4
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
anyio==4.12.1
|
| 7 |
+
arxiv==2.4.1
|
| 8 |
+
attrs==26.1.0
|
| 9 |
+
certifi==2026.2.25
|
| 10 |
+
charset-normalizer==3.4.6
|
| 11 |
+
click==8.3.1
|
| 12 |
+
colorama==0.4.6
|
| 13 |
+
dataclasses-json==0.6.7
|
| 14 |
+
feedparser==6.0.12
|
| 15 |
+
fastapi==0.121.1
|
| 16 |
+
filelock==3.25.2
|
| 17 |
+
frozenlist==1.8.0
|
| 18 |
+
fsspec==2026.2.0
|
| 19 |
+
greenlet==3.3.2
|
| 20 |
+
h11==0.16.0
|
| 21 |
+
hf-xet==1.4.2
|
| 22 |
+
httpcore==1.0.9
|
| 23 |
+
httpx==0.28.1
|
| 24 |
+
httpx-sse==0.4.3
|
| 25 |
+
huggingface_hub==0.36.0
|
| 26 |
+
idna==3.11
|
| 27 |
+
Jinja2==3.1.6
|
| 28 |
+
joblib==1.5.3
|
| 29 |
+
jsonpatch==1.33
|
| 30 |
+
jsonpointer==3.1.1
|
| 31 |
+
langchain-classic==1.0.3
|
| 32 |
+
langchain-community==0.4.1
|
| 33 |
+
langchain-core==1.2.21
|
| 34 |
+
langchain-experimental==0.4.1
|
| 35 |
+
langchain-huggingface==1.2.1
|
| 36 |
+
langchain-text-splitters==1.1.1
|
| 37 |
+
langsmith==0.7.22
|
| 38 |
+
markdown-it-py==4.0.0
|
| 39 |
+
MarkupSafe==3.0.3
|
| 40 |
+
marshmallow==3.26.2
|
| 41 |
+
mdurl==0.1.2
|
| 42 |
+
mpmath==1.3.0
|
| 43 |
+
multidict==6.7.1
|
| 44 |
+
mypy_extensions==1.1.0
|
| 45 |
+
networkx==3.6.1
|
| 46 |
+
nltk==3.9.4
|
| 47 |
+
numpy==2.4.3
|
| 48 |
+
orjson==3.11.7
|
| 49 |
+
packaging==24.2
|
| 50 |
+
pandas==3.0.1
|
| 51 |
+
pinecone==8.1.0
|
| 52 |
+
pinecone-plugin-assistant==3.0.2
|
| 53 |
+
pinecone-plugin-interface==0.0.7
|
| 54 |
+
propcache==0.4.1
|
| 55 |
+
pydantic==2.12.5
|
| 56 |
+
pydantic-settings==2.13.1
|
| 57 |
+
pydantic_core==2.41.5
|
| 58 |
+
Pygments==2.19.2
|
| 59 |
+
PyMuPDF==1.27.2.2
|
| 60 |
+
python-dateutil==2.9.0.post0
|
| 61 |
+
python-dotenv==1.2.2
|
| 62 |
+
PyYAML==6.0.3
|
| 63 |
+
rank-bm25==0.2.2
|
| 64 |
+
regex==2026.2.28
|
| 65 |
+
requests==2.32.5
|
| 66 |
+
requests-toolbelt==1.0.0
|
| 67 |
+
rich==14.3.3
|
| 68 |
+
safetensors==0.7.0
|
| 69 |
+
scikit-learn==1.8.0
|
| 70 |
+
scipy==1.17.1
|
| 71 |
+
sentence-transformers==5.3.0
|
| 72 |
+
setuptools==81.0.0
|
| 73 |
+
sgmllib3k==1.0.0
|
| 74 |
+
shellingham==1.5.4
|
| 75 |
+
six==1.17.0
|
| 76 |
+
SQLAlchemy==2.0.48
|
| 77 |
+
sympy==1.14.0
|
| 78 |
+
tenacity==9.1.4
|
| 79 |
+
threadpoolctl==3.6.0
|
| 80 |
+
tokenizers==0.22.2
|
| 81 |
+
torch==2.11.0
|
| 82 |
+
tqdm==4.67.3
|
| 83 |
+
transformers==4.57.1
|
| 84 |
+
typer==0.24.1
|
| 85 |
+
typing-inspect==0.9.0
|
| 86 |
+
typing-inspection==0.4.2
|
| 87 |
+
typing_extensions==4.15.0
|
| 88 |
+
tzdata==2025.3
|
| 89 |
+
urllib3==2.6.3
|
| 90 |
+
uvicorn==0.38.0
|
| 91 |
+
uuid_utils==0.14.1
|
| 92 |
+
xxhash==3.6.0
|
| 93 |
+
yarl==1.23.0
|
| 94 |
+
zstandard==0.25.0
|
| 95 |
+
groq==1.1.2
|
| 96 |
+
jiter==0.13.0
|
| 97 |
+
openai==2.30.0
|
retriever/evaluator.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# ------------------------------------------------------------------
|
| 9 |
+
# OpenRouter Judge Wrapper
|
| 10 |
+
# ------------------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
class GroqJudge:
    """LLM-as-judge client backed by OpenRouter, with model fallback.

    Exposes .generate(prompt) -> str so it matches the judge interface
    RAGEvaluator expects.
    """

    def __init__(self, api_key: str, model: str = "deepseek/deepseek-v3.2",):
        """
        Wraps OpenRouter's chat completions to match the .generate(prompt) interface
        expected by RAGEvaluator.

        Args:
            api_key: Your OpenRouter API key (https://openrouter.ai)
            model: OpenRouter model to use (primary model with fallback support)
        """
        self.client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
        )
        self.model = model
        # Fallback models, tried in order when the primary fails
        # (OpenRouter free models).
        self.fallback_models = [
            "deepseek/deepseek-v3.2",
            "qwen/qwen3.6-plus-preview:free",
            "stepfun/step-3.5-flash:free",
            "nvidia/nemotron-3-super-120b-a12b:free",
            "z-ai/glm-4.5-air:free",
            "nvidia/nemotron-3-nano-30b-a3b:free",
            "arcee-ai/trinity-mini:free",
            "xiaomi/mimo-v2-flash"
        ]

    def generate(self, prompt: str) -> str:
        """Ask the judge; walk the fallback list when a model is unavailable.

        Returns the stripped response text of the first model that answers;
        re-raises the last error when every candidate fails.
        """
        last_error = None
        # Primary first, then every fallback that isn't the primary itself.
        candidates = [self.model] + [m for m in self.fallback_models if m != self.model]

        for model_name in candidates:
            try:
                response = self.client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                )
                content = response.choices[0].message.content
                if content is None:
                    raise ValueError(f"Model {model_name} returned None content")
                return content.strip()
            except Exception as e:
                last_error = e
                # Rate-limited or model-unavailable -> try the next candidate.
                # NOTE(review): the 'model' substring test below matches almost
                # any API error message, so in practice nearly every failure
                # falls through to the next candidate rather than re-raising —
                # confirm whether that leniency is intended.
                if "429" in str(e) or "rate_limit" in str(e).lower() or "model" in str(e).lower():
                    continue
                raise

        # Every candidate failed; surface the most recent error.
        raise last_error
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ------------------------------------------------------------------
|
| 72 |
+
# RAG Evaluator
|
| 73 |
+
# ------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
class RAGEvaluator:
    """LLM-as-judge evaluator for RAG answers.

    Scores two orthogonal qualities:
      1. Faithfulness -- fraction of the answer's factual claims that are
         explicitly supported by the retrieved context chunks.
      2. Relevancy -- embedding similarity between the original query and
         questions the judge believes the answer addresses.
    """

    def __init__(self, judge_model: str, embedding_model, api_key: str, verbose=True):
        """
        judge_model: Model name string passed to the judge, must match cfg.gen['judge_model']
                     e.g. "stepfun/step-3.5-flash:free", "nvidia/nemotron-3-super-120b-a12b:free"
        embedding_model: The proc.encoder (SentenceTransformer) for similarity checks
        api_key: OpenRouter API key (https://openrouter.ai)
        verbose: If True, prints progress via internal helpers
        """
        # NOTE(review): the judge class is named GroqJudge but is configured for
        # OpenRouter elsewhere in this file -- confirm the intended naming.
        self.judge = GroqJudge(api_key=api_key, model=judge_model)
        self.encoder = embedding_model
        self.verbose = verbose

    # ------------------------------------------------------------------
    # 1. FAITHFULNESS: Claim Extraction & Verification
    # ------------------------------------------------------------------

    def evaluate_faithfulness(self, answer: str, context_list: list[str], strict: bool = True) -> dict:
        """Score how well the answer's claims are grounded in the context.

        Args:
            answer: Generated answer to audit.
            context_list: Retrieved context chunks the answer should rely on.
            strict: If True, verifies each claim against chunks individually
                (more API calls but catches vague batch verdicts).
                If False, uses a single batched verification call.

        Returns:
            {"score": percentage of supported claims, "details": per-claim verdicts}
        """
        if self.verbose:
            self._print_extraction_header(len(answer), strict=strict)

        # --- Step A: Extraction ---
        extraction_prompt = (
            "Extract a list of independent factual claims from the following answer.\n"
            "Rules:\n"
            "- Each claim must be specific and verifiable — include numbers, names, or concrete details where present\n"
            "- Vague claims like 'the model performs well' or 'this improves results' are NOT acceptable\n"
            "- Do NOT include claims about what the context does or does not contain\n"
            "- Do NOT include introductory text, numbering, or bullet points\n"
            "- Do NOT rephrase or merge claims\n"
            "- One claim per line only\n\n"
            f"Answer: {answer}"
        )
        raw_claims = self.judge.generate(extraction_prompt)

        # Filter out short lines, preamble, and lines ending with ':'
        claims = [
            c.strip() for c in raw_claims.split('\n')
            if len(c.strip()) > 20 and not c.strip().endswith(':')
        ]

        if not claims:
            return {"score": 0, "details": []}

        # --- Step B: Verification ---
        if strict:
            # Per-chunk: claim must be explicitly supported by at least one chunk.
            # Parallelize across claims as well.
            def verify_claim_wrapper(args):
                i, claim = args
                return i, self._verify_claim_against_chunks(claim, context_list)

            with ThreadPoolExecutor(max_workers=min(len(claims), 5)) as executor:
                futures = [executor.submit(verify_claim_wrapper, (i, claim)) for i, claim in enumerate(claims)]
                verdicts = {i: result for future in as_completed(futures) for i, result in [future.result()]}
        else:
            # Batch: all chunks joined, strict burden-of-proof prompt.
            combined_context = "\n".join(context_list)
            # Keep the judge prompt within a safe context budget.
            if len(combined_context) > 6000:
                combined_context = combined_context[:6000]

            claims_formatted = "\n".join([f"{i+1}. {c}" for i, c in enumerate(claims)])

            batch_prompt = (
                f"Context:\n{combined_context}\n\n"
                f"For each claim, respond YES only if the claim is EXPLICITLY and DIRECTLY "
                f"supported by the context above. Respond NO if the claim is inferred, assumed, "
                f"or not clearly stated in the context.\n\n"
                f"Format strictly as:\n"
                f"1: YES\n"
                f"2: NO\n\n"
                f"Claims:\n{claims_formatted}"
            )
            raw_verdicts = self.judge.generate(batch_prompt)

            verdicts = {}
            for line in raw_verdicts.split('\n'):
                match = re.match(r'(\d+)\s*:\s*(YES|NO)', line.strip().upper())
                if match:
                    # Judge numbers claims from 1; our indices start at 0.
                    verdicts[int(match.group(1)) - 1] = match.group(2) == "YES"

        # --- Step C: Scoring & Details ---
        # Missing verdicts count as unsupported (burden of proof on the judge).
        verified_count = 0
        details = []
        for i, claim in enumerate(claims):
            is_supported = verdicts.get(i, False)
            if is_supported:
                verified_count += 1
            details.append({
                "claim": claim,
                "verdict": "Supported" if is_supported else "Not Supported"
            })

        score = (verified_count / len(claims)) * 100

        if self.verbose:
            self._print_faithfulness_results(claims, details, score)

        return {"score": score, "details": details}

    def _verify_claim_against_chunks(self, claim: str, context_list: list[str]) -> bool:
        """Verify a single claim against each chunk individually. Returns True if any chunk supports it."""
        def verify_single_chunk(chunk):
            prompt = (
                f"Context:\n{chunk}\n\n"
                f"Claim: {claim}\n\n"
                f"Is this claim EXPLICITLY and DIRECTLY stated in the context above? "
                f"Do not infer or assume. Respond with YES or NO only."
            )
            result = self.judge.generate(prompt)
            return "YES" in result.upper()

        # Parallel verification; short-circuit on the first supporting chunk.
        with ThreadPoolExecutor(max_workers=min(len(context_list), 5)) as executor:
            futures = [executor.submit(verify_single_chunk, chunk) for chunk in context_list]
            for future in as_completed(futures):
                if future.result():
                    return True
        return False

    # ------------------------------------------------------------------
    # 2. RELEVANCY: Alternate Query Generation
    # ------------------------------------------------------------------

    def evaluate_relevancy(self, query: str, answer: str) -> dict:
        """Score how directly the answer addresses the original query.

        Asks the judge for questions the answer addresses, then measures
        cosine similarity between those questions and the original query.

        Returns:
            {"score": mean cosine similarity, "queries": generated questions}
        """
        if self.verbose:
            self._print_relevancy_header()

        # --- Step A: Generation ---
        # Explicitly ask the judge NOT to rephrase the original query.
        gen_prompt = (
            f"Generate 3 distinct questions that the following answer addresses.\n"
            f"Rules:\n"
            f"- Do NOT rephrase or repeat this question: '{query}'\n"
            f"- Each question must end with a '?'\n"
            f"- One question per line, no numbering or bullet points\n\n"
            f"Answer: {answer}"
        )
        raw_gen = self.judge.generate(gen_prompt)

        # Filter by length rather than just '?' presence.
        gen_queries = [
            q.strip() for q in raw_gen.split('\n')
            if len(q.strip()) > 10
        ][:3]

        if not gen_queries:
            return {"score": 0, "queries": []}

        # --- Step B: Similarity (single batched encode call) ---
        all_vecs = self.encoder.encode([query] + gen_queries)
        original_vec = all_vecs[0:1]
        generated_vecs = all_vecs[1:]

        similarities = cosine_similarity(original_vec, generated_vecs)[0]
        avg_score = float(np.mean(similarities))

        if self.verbose:
            self._print_relevancy_results(query, gen_queries, similarities, avg_score)

        return {"score": avg_score, "queries": gen_queries}

    # ------------------------------------------------------------------
    # 3. DATASET-LEVEL EVALUATION
    # ------------------------------------------------------------------

    def evaluate_dataset(self, test_cases: list[dict], strict: bool = False) -> dict:
        """
        Runs faithfulness + relevancy over a full test set and aggregates results.

        Args:
            test_cases: List of dicts, each with keys:
                - "query": str
                - "answer": str
                - "contexts": List[str]
            strict: If True, passes strict=True to evaluate_faithfulness
                (per-chunk verification, more API calls, harder to pass)

        Returns:
            {
                "avg_faithfulness": float,
                "avg_relevancy": float,
                "per_query": List[dict]
            }
        """
        # Guard: np.mean([]) would warn and yield NaN on an empty test set.
        if not test_cases:
            return {"avg_faithfulness": 0.0, "avg_relevancy": 0.0, "per_query": []}

        faithfulness_scores = []
        relevancy_scores = []
        per_query = []

        for i, case in enumerate(test_cases):
            if self.verbose:
                print(f"\n{'='*60}")
                print(f"Query {i+1}/{len(test_cases)}: {case['query']}")
                print('='*60)

            f_result = self.evaluate_faithfulness(case['answer'], case['contexts'], strict=strict)
            r_result = self.evaluate_relevancy(case['query'], case['answer'])

            faithfulness_scores.append(f_result['score'])
            relevancy_scores.append(r_result['score'])
            per_query.append({
                "query": case['query'],
                "faithfulness": f_result,
                "relevancy": r_result,
            })

        results = {
            "avg_faithfulness": float(np.mean(faithfulness_scores)),
            "avg_relevancy": float(np.mean(relevancy_scores)),
            "per_query": per_query,
        }

        if self.verbose:
            self._print_dataset_summary(results)

        return results

    # ------------------------------------------------------------------
    # 4. PRINT HELPERS
    # ------------------------------------------------------------------

    def _print_extraction_header(self, length, strict=False):
        mode = "strict per-chunk" if strict else "batch"
        print(f"\n[EVAL] Analyzing Faithfulness ({mode})...")
        print(f" - Extracting claims from answer ({length} chars)")

    def _print_faithfulness_results(self, claims, details, score):
        print(f" - Verifying {len(claims)} claims against context...")
        for i, detail in enumerate(details):
            # Bug fix: verdicts are "Supported"/"Not Supported" (never contain
            # "Yes"), so the old check `"Yes" in verdict` always printed ❌.
            status = "✅" if detail['verdict'] == "Supported" else "❌"
            print(f" {status} Claim {i+1}: {detail['claim'][:75]}...")
        print(f" 🎯 Faithfulness Score: {score:.1f}%")

    def _print_relevancy_header(self):
        print(f"\n[EVAL] Analyzing Relevancy...")
        print(f" - Generating 3 distinct questions addressed by the answer")

    def _print_relevancy_results(self, query, gen_queries, similarities, avg):
        print(f" - Comparing to original query: '{query}'")
        for i, (q, sim) in enumerate(zip(gen_queries, similarities)):
            print(f" Q{i+1}: {q} (Sim: {sim:.2f})")
        print(f" 🎯 Average Relevancy: {avg:.2f}")

    def _print_dataset_summary(self, results):
        print(f"\n{'='*60}")
        print(f" DATASET EVALUATION SUMMARY")
        print(f"{'='*60}")
        print(f" Avg Faithfulness : {results['avg_faithfulness']:.1f}%")
        print(f" Avg Relevancy    : {results['avg_relevancy']:.2f}")
        print(f" Queries Evaluated: {len(results['per_query'])}")
        print(f"{'='*60}")
|
retriever/generator.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# changed the prompt to output as markdown, plus some formatting details
# also added get_answer_stream for incremental token rendering on the frontend
# --@Qamar
class RAGGenerator:
    """Builds grounded CBT prompts from retrieved chunks and runs them through a model wrapper."""

    def generate_prompt(self, query, retrieved_contexts, context_urls=None):
        """Assemble the citation-style prompt from retrieved chunks (and optional source URLs)."""
        if context_urls:
            labeled_sources = [
                f"[Source {idx + 1}] {source_url}: {chunk}"
                for idx, (chunk, source_url) in enumerate(zip(retrieved_contexts, context_urls))
            ]
        else:
            labeled_sources = [
                f"[Source {idx + 1}]: {chunk}"
                for idx, chunk in enumerate(retrieved_contexts)
            ]
        context_text = "\n\n".join(labeled_sources)

        # NOTE(review): the "RETRIVED" typo below is kept verbatim so the
        # prompt string stays byte-identical for downstream callers.
        return f"""You are a specialized Cognitive Behavioral Therapy (CBT) assistant. Your task is to provide accurate, clinical, and structured answers based ONLY on the provided textbook excerpts.

INSTRUCTIONS:
1. Use the provided Sources to answer the question.
2. CITATIONS: You must cite the sources used in your answer (e.g., "CBT is based on the cognitive model [Source 1]").
3. FORMAT: Use clear headers and bullet points for complex explanations.
4. GROUNDING: If the sources do not contain the answer, explicitly state: "The provided excerpts from the textbook do not contain information to answer this specific question." Do not use your own internal knowledge.
5. TONE: Maintain a professional, empathetic, and academic tone.

RETRIVED TEXTBOOK CONTEXT:
{context_text}

USER QUESTION: {query}

ACADEMIC ANSWER (WITH CITATIONS):"""

    def get_answer(self, model_instance, query, retrieved_contexts, context_urls=None, **kwargs):
        """Uses a specific model instance to generate the final answer."""
        return model_instance.generate(
            self.generate_prompt(query, retrieved_contexts, context_urls), **kwargs
        )

    def get_answer_stream(self, model_instance, query, retrieved_contexts, context_urls=None, **kwargs):
        """Streams model output token-by-token for incremental UI updates."""
        prompt = self.generate_prompt(query, retrieved_contexts, context_urls)

        streamer = getattr(model_instance, "generate_stream", None)
        if streamer is not None:
            # Skip empty tokens so the frontend never renders blank deltas.
            yield from (token for token in streamer(prompt, **kwargs) if token)
        else:
            # Fallback for model wrappers that only expose sync generation.
            full_answer = model_instance.generate(prompt, **kwargs)
            if full_answer:
                yield full_answer
|
retriever/processor.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_text_splitters import (
|
| 2 |
+
RecursiveCharacterTextSplitter,
|
| 3 |
+
CharacterTextSplitter,
|
| 4 |
+
SentenceTransformersTokenTextSplitter,
|
| 5 |
+
NLTKTextSplitter
|
| 6 |
+
)
|
| 7 |
+
from langchain_experimental.text_splitter import SemanticChunker
|
| 8 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
+
from typing import List, Dict, Any, Optional
|
| 11 |
+
import nltk
|
| 12 |
+
nltk.download('punkt_tab', quiet=True)
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import re
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MarkdownTextSplitter:
    """
    Custom markdown header chunking strategy.

    Splits text hierarchically by markdown headers: h1 sections that fit
    within ``max_chars`` become chunks; oversized sections are re-split at
    h2, then h3, then h4.
    """

    def __init__(self, max_chars: int = 4000):
        self.max_chars = max_chars
        self.headers = ["\n# ", "\n## ", "\n### ", "\n#### "]

    def split_text(self, text: str) -> List[str]:
        """Split text using markdown header hierarchy."""
        return self._split_by_header(text, 0)

    def _split_by_header(self, content: str, header_level: int) -> List[str]:
        """
        Recursively split *content* at increasingly deep header levels.

        Args:
            content: The text content to split
            header_level: Current header level (0=h1, 1=h2, etc.)

        Returns:
            List of text chunks
        """
        # Small enough already, or no deeper header levels left to try.
        if len(content) <= self.max_chars or header_level >= len(self.headers):
            return [content]

        marker = self.headers[header_level]
        # Lookahead split keeps each header attached to its own section.
        segments = re.split(f'(?={re.escape(marker)})', content)

        # No headers at this level: descend to the next one.
        if len(segments) == 1:
            return self._split_by_header(content, header_level + 1)

        chunks: List[str] = []
        pending = ""
        for segment in segments:
            if len(segment) > self.max_chars:
                # Oversized segment: flush the running buffer, recurse deeper.
                if pending:
                    chunks.append(pending)
                    pending = ""
                chunks.extend(self._split_by_header(segment, header_level + 1))
            elif pending and len(pending) + len(segment) > self.max_chars:
                # Adding this segment would overflow: start a fresh buffer.
                chunks.append(pending)
                pending = segment
            else:
                # Segments that fit together are merged into one chunk.
                pending += segment

        if pending:
            chunks.append(pending)

        return chunks
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class ChunkProcessor:
    """Turns raw documents into embedded, metadata-rich chunks.

    Wraps a SentenceTransformer encoder plus a configurable set of text
    splitting strategies (fixed, recursive, sentence, semantic, ...).
    """

    def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True, load_hf_embeddings: bool = False):
        self.model_name = model_name
        self._use_remote_code = self._requires_remote_code(model_name)
        encoder_kwargs = {"trust_remote_code": True} if self._use_remote_code else {}
        self.encoder = SentenceTransformer(model_name, **encoder_kwargs)
        self.verbose = verbose
        if load_hf_embeddings:
            embed_kwargs = {"model_kwargs": {"trust_remote_code": True}} if self._use_remote_code else {}
            self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name, **embed_kwargs)
        else:
            # Built lazily on first semantic-chunking request.
            self.hf_embeddings = None

    def _requires_remote_code(self, model_name: str) -> bool:
        """Only jinaai/* models need trust_remote_code to load."""
        return (model_name or "").strip().lower().startswith("jinaai/")

    def _get_hf_embeddings(self):
        """Lazily build (and cache) the LangChain embedding wrapper."""
        if self.hf_embeddings is None:
            embed_kwargs = {"model_kwargs": {"trust_remote_code": True}} if self._use_remote_code else {}
            self.hf_embeddings = HuggingFaceEmbeddings(model_name=self.model_name, **embed_kwargs)
        return self.hf_embeddings

    # ------------------------------------------------------------------
    # Splitters
    # ------------------------------------------------------------------

    def get_splitter(self, technique: str, chunk_size: int = 500, chunk_overlap: int = 50, **kwargs):
        """
        Factory method to return different chunking strategies.

        Strategies:
        - "fixed": Character-based, may split mid-sentence
        - "recursive": Recursive character splitting with hierarchical separators
        - "character": Character-based splitting on paragraph boundaries
        - "paragraph": Paragraph-level splitting on \\n\\n boundaries
        - "sentence": Sliding window over NLTK sentences
        - "semantic": Embedding-based semantic chunking
        - "page": Page-level splitting on page markers
        - "markdown": Header-hierarchy splitting with a max char limit
        """
        def char_splitter(default_sep):
            # Shared builder: several strategies differ only by separator.
            return CharacterTextSplitter(
                separator=kwargs.get('separator', default_sep),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                is_separator_regex=False,
            )

        builders = {
            "fixed": lambda: char_splitter(""),
            "character": lambda: char_splitter("\n\n"),
            "paragraph": lambda: char_splitter("\n\n"),
            "page": lambda: char_splitter("--- Page"),
            "recursive": lambda: RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separators=kwargs.get('separators', ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""]),
                length_function=len,
                keep_separator=kwargs.get('keep_separator', True),
            ),
            "sentence": lambda: NLTKTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separator="\n",
            ),
            "semantic": lambda: SemanticChunker(
                self._get_hf_embeddings(),
                breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
                # Using 70 because 95 was giving way too big chunks.
                breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 70),
            ),
            "markdown": lambda: MarkdownTextSplitter(max_chars=chunk_size),
        }

        if technique not in builders:
            raise ValueError(f"Technique '{technique}' is not supported. Choose from: fixed, recursive, character, paragraph, sentence, semantic, page, markdown")
        return builders[technique]()

    # ------------------------------------------------------------------
    # Processing
    # ------------------------------------------------------------------

    def process(self, df: pd.DataFrame, technique: str = "recursive", chunk_size: int = 500,
                chunk_overlap: int = 50, max_docs: Optional[int] = 5,
                verbose: Optional[bool] = None, **kwargs) -> List[Dict[str, Any]]:
        """
        Processes a DataFrame into vector-ready chunks.

        Args:
            df: DataFrame with columns: id, title, url, full_text
            technique: Chunking strategy to use
            chunk_size: Maximum size of each chunk in characters
            chunk_overlap: Overlap between consecutive chunks
            max_docs: Number of documents to process (None for all)
            verbose: Override instance verbose setting
            **kwargs: Additional arguments passed to the splitter

        Returns:
            List of chunk dicts with embeddings and metadata
        """
        show = self.verbose if verbose is None else verbose

        required_cols = ['id', 'title', 'url', 'full_text']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"DataFrame missing required columns: {missing_cols}")

        splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
        docs = df.head(max_docs) if max_docs else df
        processed_chunks: List[Dict[str, Any]] = []

        for _, doc in docs.iterrows():
            if show:
                self._print_document_header(doc['title'], doc['url'], technique, chunk_size, chunk_overlap)

            pieces = splitter.split_text(doc['full_text'])

            for idx, piece in enumerate(pieces):
                # LangChain splitters may return Documents; ours return str.
                body = piece.page_content if hasattr(piece, 'page_content') else piece

                if show:
                    self._print_chunk(idx, body)

                processed_chunks.append({
                    "id": f"{doc['id']}-chunk-{idx}",
                    "values": self.encoder.encode(body).tolist(),
                    "metadata": {
                        "title": doc['title'],
                        "text": body,
                        "url": doc['url'],
                        "chunk_index": idx,
                        "technique": technique,
                        "chunk_size": len(body),
                        "total_chunks": len(pieces),
                    },
                })

            if show:
                self._print_document_summary(len(pieces))

        if show:
            self._print_processing_summary(len(docs), processed_chunks)

        return processed_chunks

    # ------------------------------------------------------------------
    # Printing
    # ------------------------------------------------------------------

    def _print_document_header(self, title, url, technique, chunk_size, chunk_overlap):
        print("\n" + "=" * 80)
        print(f"DOCUMENT: {title}")
        print(f"URL: {url}")
        print(f"Technique: {technique.upper()} | Chunk Size: {chunk_size} | Overlap: {chunk_overlap}")
        print("-" * 80)

    def _print_chunk(self, index, content):
        print(f"\n[Chunk {index}] ({len(content)} chars):")
        print(f" {content}")

    def _print_document_summary(self, num_chunks):
        print(f"Total Chunks Generated: {num_chunks}")
        print("=" * 80)

    def _print_processing_summary(self, num_docs, processed_chunks):
        print(f"\nFinished processing {num_docs} documents into {len(processed_chunks)} chunks.")
        if processed_chunks:
            avg = sum(c['metadata']['chunk_size'] for c in processed_chunks) / len(processed_chunks)
            print(f"Average chunk size: {avg:.0f} chars")
|
retriever/retriever.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import time
|
| 3 |
+
import re
|
| 4 |
+
from rank_bm25 import BM25Okapi
|
| 5 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 6 |
+
from typing import Optional, List
|
| 7 |
+
|
| 8 |
+
# changed mmr to return final k, as a param, prev was hardcoded to 3
|
| 9 |
+
# --@Qamare
|
| 10 |
+
|
| 11 |
+
# Try to import FlashRank for CPU optimization, fallback to sentence-transformers
|
| 12 |
+
try:
|
| 13 |
+
from flashrank import Ranker, RerankRequest
|
| 14 |
+
FLASHRANK_AVAILABLE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
from sentence_transformers import CrossEncoder
|
| 17 |
+
FLASHRANK_AVAILABLE = False
|
| 18 |
+
|
| 19 |
+
class HybridRetriever:
|
| 20 |
+
def __init__(self, final_chunks, embed_model, rerank_model_name='jinaai/jina-reranker-v1-tiny-en', verbose: bool = True):
|
| 21 |
+
self.final_chunks = final_chunks
|
| 22 |
+
self.embed_model = embed_model
|
| 23 |
+
self.verbose = verbose
|
| 24 |
+
self.rerank_model_name = self._normalize_rerank_model_name(rerank_model_name)
|
| 25 |
+
|
| 26 |
+
# Use FlashRank if available (faster on CPU), otherwise fallback to sentence-transformers
|
| 27 |
+
if FLASHRANK_AVAILABLE:
|
| 28 |
+
try:
|
| 29 |
+
self.rerank_model = Ranker(model_name=self.rerank_model_name)
|
| 30 |
+
self.use_flashrank = True
|
| 31 |
+
except Exception:
|
| 32 |
+
from sentence_transformers import CrossEncoder as STCrossEncoder
|
| 33 |
+
self.rerank_model = STCrossEncoder(self.rerank_model_name)
|
| 34 |
+
self.use_flashrank = False
|
| 35 |
+
else:
|
| 36 |
+
self.rerank_model = CrossEncoder(self.rerank_model_name)
|
| 37 |
+
self.use_flashrank = False
|
| 38 |
+
|
| 39 |
+
# Better tokenization for BM25 (strips punctuation)
|
| 40 |
+
self.tokenized_corpus = [self._tokenize(chunk['metadata']['text']) for chunk in final_chunks]
|
| 41 |
+
self.bm25 = BM25Okapi(self.tokenized_corpus)
|
| 42 |
+
self.technique_to_indices = self._build_chunking_index_map()
|
| 43 |
+
|
| 44 |
+
def _normalize_rerank_model_name(self, model_name: str) -> str:
|
| 45 |
+
normalized = (model_name or "").strip()
|
| 46 |
+
if not normalized:
|
| 47 |
+
return "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
| 48 |
+
if "/" in normalized:
|
| 49 |
+
return normalized
|
| 50 |
+
return f"cross-encoder/{normalized}"
|
| 51 |
+
|
| 52 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 53 |
+
"""Tokenize text using regex to strip punctuation."""
|
| 54 |
+
return re.findall(r'\w+', text.lower())
|
| 55 |
+
|
| 56 |
+
# added these two helper methods for chunking based on chunk_technique metadata, and normalization of chunking_technique param
|
| 57 |
+
def _build_chunking_index_map(self) -> dict[str, List[int]]:
|
| 58 |
+
mapping: dict[str, List[int]] = {}
|
| 59 |
+
for idx, chunk in enumerate(self.final_chunks):
|
| 60 |
+
metadata = chunk.get('metadata', {})
|
| 61 |
+
technique = (metadata.get('chunking_technique') or '').strip().lower()
|
| 62 |
+
if not technique:
|
| 63 |
+
continue
|
| 64 |
+
mapping.setdefault(technique, []).append(idx)
|
| 65 |
+
return mapping
|
| 66 |
+
|
| 67 |
+
def _normalize_chunking_technique(self, chunking_technique: Optional[str]) -> Optional[str]:
|
| 68 |
+
if not chunking_technique:
|
| 69 |
+
return None
|
| 70 |
+
normalized = str(chunking_technique).strip().lower()
|
| 71 |
+
if not normalized or normalized in {"all", "any", "*", "none"}:
|
| 72 |
+
return None
|
| 73 |
+
return normalized
|
| 74 |
+
|
| 75 |
+
# ------------------------------------------------------------------
|
| 76 |
+
# Retrieval
|
| 77 |
+
# ------------------------------------------------------------------
|
| 78 |
+
|
| 79 |
+
def _semantic_search(self, query, index, top_k, chunking_technique: Optional[str] = None) -> tuple[np.ndarray, List[str]]:
|
| 80 |
+
query_vector = self.embed_model.encode(query)
|
| 81 |
+
query_kwargs = {
|
| 82 |
+
"vector": query_vector.tolist(),
|
| 83 |
+
"top_k": top_k,
|
| 84 |
+
"include_metadata": True,
|
| 85 |
+
}
|
| 86 |
+
if chunking_technique:
|
| 87 |
+
query_kwargs["filter"] = {"chunking_technique": {"$eq": chunking_technique}}
|
| 88 |
+
res = index.query(**query_kwargs)
|
| 89 |
+
chunks = [match['metadata']['text'] for match in res['matches']]
|
| 90 |
+
return query_vector, chunks
|
| 91 |
+
|
| 92 |
+
def _bm25_search(self, query, top_k, chunking_technique: Optional[str] = None) -> List[str]:
|
| 93 |
+
tokenized_query = self._tokenize(query)
|
| 94 |
+
scores = self.bm25.get_scores(tokenized_query)
|
| 95 |
+
|
| 96 |
+
if chunking_technique:
|
| 97 |
+
candidate_indices = self.technique_to_indices.get(chunking_technique, [])
|
| 98 |
+
if not candidate_indices:
|
| 99 |
+
return []
|
| 100 |
+
top_indices = sorted(candidate_indices, key=lambda i: scores[i], reverse=True)[:top_k]
|
| 101 |
+
else:
|
| 102 |
+
top_indices = np.argsort(scores)[::-1][:top_k]
|
| 103 |
+
|
| 104 |
+
return [self.final_chunks[i]['metadata']['text'] for i in top_indices]
|
| 105 |
+
|
| 106 |
+
# ------------------------------------------------------------------
|
| 107 |
+
# Fusion
|
| 108 |
+
# ------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
def _rrf_score(self, semantic_results, bm25_results, k=60) -> List[str]:
|
| 111 |
+
scores = {}
|
| 112 |
+
for rank, chunk in enumerate(semantic_results):
|
| 113 |
+
scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
|
| 114 |
+
for rank, chunk in enumerate(bm25_results):
|
| 115 |
+
scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
|
| 116 |
+
return [chunk for chunk, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
|
| 117 |
+
|
| 118 |
+
# ------------------------------------------------------------------
|
| 119 |
+
# Reranking
|
| 120 |
+
# ------------------------------------------------------------------
|
| 121 |
+
|
| 122 |
+
def _cross_encoder_rerank(self, query, chunks, final_k) -> List[str]:
|
| 123 |
+
if self.use_flashrank:
|
| 124 |
+
# Use FlashRank for CPU-optimized reranking
|
| 125 |
+
passages = [{"id": i, "text": chunk} for i, chunk in enumerate(chunks)]
|
| 126 |
+
rerank_request = RerankRequest(query=query, passages=passages)
|
| 127 |
+
results = self.rerank_model.rerank(rerank_request)
|
| 128 |
+
ranked_chunks = [res['text'] for res in results]
|
| 129 |
+
return ranked_chunks[:final_k]
|
| 130 |
+
else:
|
| 131 |
+
# Fallback to sentence-transformers CrossEncoder
|
| 132 |
+
pairs = [[query, chunk] for chunk in chunks]
|
| 133 |
+
scores = self.rerank_model.predict(pairs)
|
| 134 |
+
ranked = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
| 135 |
+
return [chunk for chunk, _ in ranked[:final_k]]
|
| 136 |
+
|
| 137 |
+
# ------------------------------------------------------------------
|
| 138 |
+
# MMR (applied after reranking as a diversity filter)
|
| 139 |
+
# ------------------------------------------------------------------
|
| 140 |
+
|
| 141 |
+
def _maximal_marginal_relevance(self, query_vector, chunks, lambda_param=0.5, top_k=10) -> List[str]:
|
| 142 |
+
"""
|
| 143 |
+
Maximum Marginal Relevance (MMR) for diversity filtering.
|
| 144 |
+
|
| 145 |
+
DIVISION BY ZERO DEBUGGING:
|
| 146 |
+
- This method can cause division by zero in cosine_similarity if vectors are zero
|
| 147 |
+
- We've added multiple safeguards to prevent this
|
| 148 |
+
"""
|
| 149 |
+
print(f" [MMR DEBUG] Starting MMR with {len(chunks)} chunks, top_k={top_k}")
|
| 150 |
+
|
| 151 |
+
if not chunks:
|
| 152 |
+
print(f" [MMR DEBUG] No chunks, returning empty list")
|
| 153 |
+
return []
|
| 154 |
+
|
| 155 |
+
# STEP 1: Encode chunks to get embeddings
|
| 156 |
+
print(f" [MMR DEBUG] Encoding {len(chunks)} chunks...")
|
| 157 |
+
try:
|
| 158 |
+
chunk_embeddings = self.embed_model.encode(chunks)
|
| 159 |
+
print(f" [MMR DEBUG] Chunk embeddings shape: {chunk_embeddings.shape}")
|
| 160 |
+
except Exception as e:
|
| 161 |
+
print(f" [MMR DEBUG] ERROR encoding chunks: {e}")
|
| 162 |
+
return chunks[:top_k]
|
| 163 |
+
|
| 164 |
+
# STEP 2: Reshape query vector
|
| 165 |
+
query_embedding = query_vector.reshape(1, -1)
|
| 166 |
+
print(f" [MMR DEBUG] Query embedding shape: {query_embedding.shape}")
|
| 167 |
+
|
| 168 |
+
# STEP 3: Check for zero vectors (POTENTIAL DIVISION BY ZERO SOURCE)
|
| 169 |
+
print(f" [MMR DEBUG] Checking for zero vectors...")
|
| 170 |
+
query_norm = np.linalg.norm(query_embedding)
|
| 171 |
+
chunk_norms = np.linalg.norm(chunk_embeddings, axis=1)
|
| 172 |
+
|
| 173 |
+
print(f" [MMR DEBUG] Query norm: {query_norm}")
|
| 174 |
+
print(f" [MMR DEBUG] Chunk norms min: {chunk_norms.min()}, max: {chunk_norms.max()}")
|
| 175 |
+
|
| 176 |
+
# Check for zero or near-zero vectors
|
| 177 |
+
if query_norm < 1e-10 or np.any(chunk_norms < 1e-10):
|
| 178 |
+
print(f" [MMR DEBUG] WARNING: Zero or near-zero vectors detected!")
|
| 179 |
+
print(f" [MMR DEBUG] Query norm < 1e-10: {query_norm < 1e-10}")
|
| 180 |
+
print(f" [MMR DEBUG] Any chunk norm < 1e-10: {np.any(chunk_norms < 1e-10)}")
|
| 181 |
+
print(f" [MMR DEBUG] Falling back to simple selection without MMR")
|
| 182 |
+
return chunks[:top_k]
|
| 183 |
+
|
| 184 |
+
# STEP 4: Compute relevance scores (POTENTIAL DIVISION BY ZERO SOURCE)
|
| 185 |
+
print(f" [MMR DEBUG] Computing relevance scores with cosine_similarity...")
|
| 186 |
+
try:
|
| 187 |
+
relevance_scores = cosine_similarity(query_embedding, chunk_embeddings)[0]
|
| 188 |
+
print(f" [MMR DEBUG] Relevance scores computed successfully")
|
| 189 |
+
print(f" [MMR DEBUG] Relevance scores shape: {relevance_scores.shape}")
|
| 190 |
+
print(f" [MMR DEBUG] Relevance scores min: {relevance_scores.min()}, max: {relevance_scores.max()}")
|
| 191 |
+
except Exception as e:
|
| 192 |
+
print(f" [MMR DEBUG] ERROR computing relevance scores: {e}")
|
| 193 |
+
print(f" [MMR DEBUG] Falling back to simple selection")
|
| 194 |
+
return chunks[:top_k]
|
| 195 |
+
|
| 196 |
+
# STEP 5: Initialize selection
|
| 197 |
+
selected, unselected = [], list(range(len(chunks)))
|
| 198 |
+
|
| 199 |
+
first = int(np.argmax(relevance_scores))
|
| 200 |
+
selected.append(first)
|
| 201 |
+
unselected.remove(first)
|
| 202 |
+
print(f" [MMR DEBUG] Selected first chunk: index {first}")
|
| 203 |
+
|
| 204 |
+
# STEP 6: Iteratively select chunks using MMR
|
| 205 |
+
print(f" [MMR DEBUG] Starting MMR iteration...")
|
| 206 |
+
iteration = 0
|
| 207 |
+
while len(selected) < min(top_k, len(chunks)):
|
| 208 |
+
iteration += 1
|
| 209 |
+
print(f" [MMR DEBUG] Iteration {iteration}: selected={len(selected)}, unselected={len(unselected)}")
|
| 210 |
+
|
| 211 |
+
# Calculate MMR scores
|
| 212 |
+
mmr_scores = []
|
| 213 |
+
for i in unselected:
|
| 214 |
+
# Compute max similarity to already selected items
|
| 215 |
+
max_sim = -1
|
| 216 |
+
for s in selected:
|
| 217 |
+
try:
|
| 218 |
+
# POTENTIAL DIVISION BY ZERO SOURCE: cosine_similarity
|
| 219 |
+
sim = cosine_similarity(
|
| 220 |
+
chunk_embeddings[i].reshape(1, -1),
|
| 221 |
+
chunk_embeddings[s].reshape(1, -1)
|
| 222 |
+
)[0][0]
|
| 223 |
+
max_sim = max(max_sim, sim)
|
| 224 |
+
except Exception as e:
|
| 225 |
+
print(f" [MMR DEBUG] ERROR computing similarity between chunk {i} and {s}: {e}")
|
| 226 |
+
# If similarity computation fails, use 0
|
| 227 |
+
max_sim = max(max_sim, 0)
|
| 228 |
+
|
| 229 |
+
mmr_score = lambda_param * relevance_scores[i] - (1 - lambda_param) * max_sim
|
| 230 |
+
mmr_scores.append((i, mmr_score))
|
| 231 |
+
|
| 232 |
+
# Select chunk with highest MMR score
|
| 233 |
+
if mmr_scores:
|
| 234 |
+
best, best_score = max(mmr_scores, key=lambda x: x[1])
|
| 235 |
+
selected.append(best)
|
| 236 |
+
unselected.remove(best)
|
| 237 |
+
print(f" [MMR DEBUG] Selected chunk {best} with MMR score {best_score:.4f}")
|
| 238 |
+
else:
|
| 239 |
+
print(f" [MMR DEBUG] No MMR scores computed, breaking")
|
| 240 |
+
break
|
| 241 |
+
|
| 242 |
+
print(f" [MMR DEBUG] MMR complete. Selected {len(selected)} chunks")
|
| 243 |
+
return [chunks[i] for i in selected]
|
| 244 |
+
|
| 245 |
+
# ------------------------------------------------------------------
|
| 246 |
+
# Main search
|
| 247 |
+
# ------------------------------------------------------------------
|
| 248 |
+
|
| 249 |
+
def search(self, query, index, top_k=25, final_k=5, mode="hybrid",
|
| 250 |
+
chunking_technique: Optional[str] = None,
|
| 251 |
+
rerank_strategy="cross-encoder", use_mmr=False, lambda_param=0.5,
|
| 252 |
+
verbose: Optional[bool] = None) -> List[str]:
|
| 253 |
+
"""
|
| 254 |
+
:param mode: "semantic", "bm25", or "hybrid"
|
| 255 |
+
:param rerank_strategy: "cross-encoder", "rrf", or "none"
|
| 256 |
+
:param use_mmr: Whether to apply MMR diversity filter after reranking
|
| 257 |
+
:param lambda_param: MMR trade-off between relevance (1.0) and diversity (0.0)
|
| 258 |
+
"""
|
| 259 |
+
should_print = verbose if verbose is not None else self.verbose
|
| 260 |
+
requested_technique = self._normalize_chunking_technique(chunking_technique)
|
| 261 |
+
total_start = time.perf_counter()
|
| 262 |
+
semantic_time = 0.0
|
| 263 |
+
bm25_time = 0.0
|
| 264 |
+
rerank_time = 0.0
|
| 265 |
+
mmr_time = 0.0
|
| 266 |
+
|
| 267 |
+
if should_print:
|
| 268 |
+
self._print_search_header(query, mode, rerank_strategy, top_k, final_k)
|
| 269 |
+
if requested_technique:
|
| 270 |
+
print(f"Chunking Filter: {requested_technique}")
|
| 271 |
+
|
| 272 |
+
# 1. Retrieve candidates
|
| 273 |
+
query_vector = None
|
| 274 |
+
semantic_chunks, bm25_chunks = [], []
|
| 275 |
+
|
| 276 |
+
if mode in ["semantic", "hybrid"]:
|
| 277 |
+
semantic_start = time.perf_counter()
|
| 278 |
+
query_vector, semantic_chunks = self._semantic_search(query, index, top_k, requested_technique)
|
| 279 |
+
semantic_time = time.perf_counter() - semantic_start
|
| 280 |
+
if should_print:
|
| 281 |
+
self._print_candidates("Semantic Search", semantic_chunks)
|
| 282 |
+
print(f"Semantic time: {semantic_time:.3f}s")
|
| 283 |
+
|
| 284 |
+
if mode in ["bm25", "hybrid"]:
|
| 285 |
+
bm25_start = time.perf_counter()
|
| 286 |
+
bm25_chunks = self._bm25_search(query, top_k, requested_technique)
|
| 287 |
+
bm25_time = time.perf_counter() - bm25_start
|
| 288 |
+
if should_print:
|
| 289 |
+
self._print_candidates("BM25 Search", bm25_chunks)
|
| 290 |
+
print(f"BM25 time: {bm25_time:.3f}s")
|
| 291 |
+
|
| 292 |
+
# 2. Fuse / rerank
|
| 293 |
+
rerank_start = time.perf_counter()
|
| 294 |
+
if rerank_strategy == "rrf":
|
| 295 |
+
candidates = self._rrf_score(semantic_chunks, bm25_chunks)[:final_k]
|
| 296 |
+
label = "RRF"
|
| 297 |
+
elif rerank_strategy == "cross-encoder":
|
| 298 |
+
combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))
|
| 299 |
+
candidates = self._cross_encoder_rerank(query, combined, final_k)
|
| 300 |
+
label = "Cross-Encoder"
|
| 301 |
+
else: # "none"
|
| 302 |
+
candidates = list(dict.fromkeys(semantic_chunks + bm25_chunks))[:final_k]
|
| 303 |
+
label = "No Reranking"
|
| 304 |
+
rerank_time = time.perf_counter() - rerank_start
|
| 305 |
+
|
| 306 |
+
# 3. MMR diversity filter (applied after reranking)
|
| 307 |
+
if use_mmr and candidates:
|
| 308 |
+
mmr_start = time.perf_counter()
|
| 309 |
+
if query_vector is None:
|
| 310 |
+
query_vector = self.embed_model.encode(query)
|
| 311 |
+
candidates = self._maximal_marginal_relevance(query_vector, candidates,
|
| 312 |
+
lambda_param=lambda_param, top_k=final_k)
|
| 313 |
+
label += " + MMR"
|
| 314 |
+
mmr_time = time.perf_counter() - mmr_start
|
| 315 |
+
|
| 316 |
+
total_time = time.perf_counter() - total_start
|
| 317 |
+
|
| 318 |
+
if should_print:
|
| 319 |
+
self._print_final_results(candidates, label)
|
| 320 |
+
self._print_timing_summary(semantic_time, bm25_time, rerank_time, mmr_time, total_time)
|
| 321 |
+
|
| 322 |
+
return candidates
|
| 323 |
+
|
| 324 |
+
# ------------------------------------------------------------------
|
| 325 |
+
# Printing
|
| 326 |
+
# ------------------------------------------------------------------
|
| 327 |
+
|
| 328 |
+
def _print_search_header(self, query, mode, rerank_strategy, top_k, final_k):
|
| 329 |
+
print("\n" + "="*80)
|
| 330 |
+
print(f" SEARCH QUERY: {query}")
|
| 331 |
+
print(f"Mode: {mode.upper()} | Rerank: {rerank_strategy.upper()}")
|
| 332 |
+
print(f"Top-K: {top_k} | Final-K: {final_k}")
|
| 333 |
+
print("-" * 80)
|
| 334 |
+
|
| 335 |
+
def _print_candidates(self, label, chunks, preview_n=3):
|
| 336 |
+
print(f"{label}: Retrieved {len(chunks)} candidates")
|
| 337 |
+
for i, chunk in enumerate(chunks[:preview_n]):
|
| 338 |
+
preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
|
| 339 |
+
print(f" [{i}] {preview}")
|
| 340 |
+
|
| 341 |
+
def _print_final_results(self, results, strategy_label):
|
| 342 |
+
print(f"\n Final {len(results)} Results ({strategy_label}):")
|
| 343 |
+
for i, chunk in enumerate(results):
|
| 344 |
+
preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
|
| 345 |
+
print(f" [{i+1}] {preview}")
|
| 346 |
+
print("="*80)
|
| 347 |
+
|
| 348 |
+
def _print_timing_summary(self, semantic_time, bm25_time, rerank_time, mmr_time, total_time):
|
| 349 |
+
print(" Retrieval Timing:")
|
| 350 |
+
print(f" Semantic: {semantic_time:.3f}s")
|
| 351 |
+
print(f" BM25: {bm25_time:.3f}s")
|
| 352 |
+
print(f" Rerank/Fusion: {rerank_time:.3f}s")
|
| 353 |
+
print(f" MMR: {mmr_time:.3f}s")
|
| 354 |
+
print(f" Total Retrieval: {total_time:.3f}s")
|