# Nolej / rag.py
# (Hugging Face Space file-page header preserved as a comment:
#  "AshJem's picture — Update rag.py — b50f84e verified")
import os
import requests
from pathlib import Path
from typing import List, Dict, Any, Tuple
import time
import numpy as np
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from dotenv import load_dotenv
load_dotenv()
# -----------------------
# CONFIG
# -----------------------
# Hugging Face dataset repo that holds the chunked corpus + embeddings.
REPO_ID = os.getenv("HF_DATASET_REPO", "AshJem/NOLEJ1")
# Direct URL of the prebuilt FAISS index stored alongside the dataset.
FAISS_URL = os.getenv(
    "FAISS_URL",
    "https://huggingface.co/datasets/AshJem/NOLEJ1/resolve/main/global.faiss"
)
# Number of candidate chunks pulled from FAISS per query.
TOP_K = int(os.getenv("TOP_K", "8"))
# Number of chunks kept after cross-encoder reranking.
RERANK_TOP_N = int(os.getenv("RERANK_TOP_N", "4"))
# Minimum context length (characters) required before the LLM is queried;
# below this threshold answer_question() returns the refusal message.
MIN_CONTEXT_CHARS = int(os.getenv("MIN_CONTEXT_CHARS", "250"))
# Feature flags: enable reranking; choose OpenAI (else Together) as the LLM.
RERANK = os.getenv("RERANK", "1") == "1"
USE_OPENAI = os.getenv("USE_OPENAI", "1") == "1"
# -----------------------
# OPENAI CLIENT (GLOBAL)
# -----------------------
from openai import AsyncOpenAI
# Single shared async client created at import time.
# NOTE(review): constructed even when USE_OPENAI is off; if OPENAI_API_KEY is
# unset the api_key is None and requests fail only at call time — confirm
# that is the intended behavior.
_openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# -----------------------
# LOAD ONCE (global)
# -----------------------
# Writable scratch directory for the cached index (defaults to ./tmp;
# presumably chosen because the deploy target restricts writable paths).
WRITE_DIR = Path(os.getenv("WRITE_DIR", "./tmp"))
WRITE_DIR.mkdir(parents=True, exist_ok=True)
# Local cache path for the downloaded FAISS index.
FAISS_LOCAL = WRITE_DIR / "global.faiss"
def _download_faiss_if_needed():
    """Download the prebuilt FAISS index to FAISS_LOCAL if not already cached.

    Streams the file to a temporary ``.part`` path and renames it into place
    only after the download completes. The original wrote directly to
    FAISS_LOCAL, so an interrupted download left a truncated file that the
    ``exists()`` check then treated as a complete index on every later run.

    Raises:
        requests.HTTPError: if the server returns an error status.
        requests.RequestException: on connection/timeout failures.
    """
    if FAISS_LOCAL.exists():
        return
    tmp_path = FAISS_LOCAL.with_suffix(FAISS_LOCAL.suffix + ".part")
    try:
        with requests.get(FAISS_URL, stream=True, timeout=120) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        tmp_path.replace(FAISS_LOCAL)  # atomic rename on POSIX
    finally:
        # Remove the partial file on any failure path; harmless after the
        # successful rename (missing_ok swallows FileNotFoundError).
        tmp_path.unlink(missing_ok=True)
# Importing this module triggers the full load sequence: index download,
# dataset + FAISS attach, embedder, and (optionally) the reranker.
_download_faiss_if_needed()
dataset = load_dataset(REPO_ID, split="train", keep_in_memory=True)
# Attach the downloaded FAISS index to the dataset's "embedding" column so
# get_nearest_examples() can be used for retrieval.
dataset.load_faiss_index("embedding", str(FAISS_LOCAL))
# ✅ Must match the embedding model used to build the FAISS index
embedder = SentenceTransformer("BAAI/bge-m3")
if RERANK:
    # ✅ Multilingual reranker aligned with bge-m3
    reranker_model_id = "BAAI/bge-reranker-v2-m3"
    reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_model_id)
    reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_model_id)
    reranker_model.eval()  # inference mode (disables dropout)
else:
    # Reranking disabled: keep the names defined so rerank() stays callable.
    reranker_tokenizer = None
    reranker_model = None
def embed_query(text: str) -> np.ndarray:
    """Encode *text* with the global embedder, L2-normalized to match the index."""
    vector = embedder.encode(text, normalize_embeddings=True)
    return vector
def retrieve_topk(query: str, k: int = TOP_K) -> List[Dict[str, Any]]:
    """Return the k nearest chunks for *query* from the FAISS-backed dataset.

    Each result dict carries the raw FAISS similarity as both ``score`` and
    ``final_score``; a later rerank pass (if enabled) overwrites the latter.
    """
    query_vec = embed_query(query)
    scores, hits = dataset.get_nearest_examples("embedding", query_vec, k=k)
    results: List[Dict[str, Any]] = []
    for idx, text in enumerate(hits["text"]):
        faiss_score = float(scores[idx]) if scores is not None else None
        results.append({
            "text": text,
            "metadata": hits["metadata"][idx],
            "score": faiss_score,        # FAISS similarity score (as returned)
            "final_score": faiss_score,  # default; rerank() may overwrite
        })
    return results
def rerank(query: str, items: List[Dict[str, Any]], top_n: int = RERANK_TOP_N) -> List[Dict[str, Any]]:
    """Re-score retrieved chunks against *query* with the cross-encoder.

    Mutates each item in place (adds ``rerank_score``, overwrites
    ``final_score``) and returns the top_n items sorted by rerank score,
    descending. When reranking is disabled or *items* is empty, returns the
    first top_n items unchanged.
    """
    if (not RERANK) or (not items):
        return items[:top_n]
    # FIX: bge-reranker cross-encoders expect (query, passage) pairs passed
    # as the tokenizer's two text arguments, so it inserts the model's own
    # special/separator tokens. The previous code glued the strings with a
    # literal "[SEP]" substring, which is not the input layout the model was
    # trained on. max_length bounds very long chunks explicitly.
    enc = reranker_tokenizer(
        [query] * len(items),
        [it["text"] for it in items],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    with torch.no_grad():
        # view(-1) flattens the (N, 1) logits; unlike squeeze(), it also
        # yields a 1-D tensor for a single-item batch (no 0-dim edge case).
        logits = reranker_model(**enc).logits.view(-1)
    rr_scores = [float(x) for x in logits.tolist()]
    for it, s in zip(items, rr_scores):
        it["rerank_score"] = s
        it["final_score"] = s  # final = rerank score when reranking is on
    items = sorted(items, key=lambda x: x.get("rerank_score", -1e9), reverse=True)
    return items[:top_n]
def build_context(items: List[Dict[str, Any]]) -> Tuple[str, List[Dict[str, Any]]]:
    """Join chunk texts into one context string and collect citation records.

    Metadata schema (minimal): source_file, source_path, doc_id, file_hash,
    chunk_index, chunk_id. The FAISS/rerank/final scores are carried along
    in each citation for debugging and UI display.
    """
    meta_keys = ("source_file", "source_path", "doc_id",
                 "file_hash", "chunk_index", "chunk_id")
    citations: List[Dict[str, Any]] = []
    for it in items:
        meta = it.get("metadata") or {}
        record = {key: meta.get(key) for key in meta_keys}
        # scores (for debug/UI)
        record["faiss_score"] = it.get("score")
        record["rerank_score"] = it.get("rerank_score")
        record["final_score"] = it.get("final_score")
        citations.append(record)
    context = "\n\n---\n\n".join(it["text"] for it in items)
    return context, citations
def should_answer(context: str) -> bool:
    """Gate generation: True only when *context* is a sufficiently long string."""
    if not isinstance(context, str):
        return False
    return len(context.strip()) >= MIN_CONTEXT_CHARS
async def generate_answer_openai(question: str, context: str) -> str:
    """Answer *question* strictly from *context* via the OpenAI chat API.

    Returns the stripped model reply, or "" when the API returns no content.
    FIX: ``message.content`` can be None (e.g. refusals / tool-call replies),
    and the original ``.strip()`` on it would raise AttributeError; the
    ``or ""`` guard mirrors generate_answer_together's handling of
    ``out.text``.
    """
    system = (
        "You are a QA assistant.\n"
        "Use the CONVERSATION CONTEXT only to understand the user's question.\n"
        "Answer ONLY using the provided DOCUMENT CONTEXT.\n"
        "If the answer is not in the document context, reply exactly:\n"
        "\"I don't know based on the provided documents\".\n"
        "Be concise and factual."
    )
    user = f"QUESTION:\n{question}\n\nCONTEXT:\n{context}"
    resp = await _openai_client.chat.completions.create(
        model=os.getenv("OPENAI_MODEL", "gpt-4.1-mini"),
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
    )
    return (resp.choices[0].message.content or "").strip()
async def generate_answer_together(question: str, context: str) -> str:
    """Answer *question* strictly from *context* via Together AI.

    Returns the stripped model reply, or "" when the model returns no text.
    FIX: ``llm.complete()`` is a synchronous blocking HTTP call; invoking it
    directly inside this coroutine would stall the entire event loop for the
    duration of the request. It is offloaded to a worker thread instead.
    """
    import asyncio
    from llama_index.llms.together import TogetherLLM
    llm = TogetherLLM(
        model=os.getenv("TOGETHER_MODEL", "openai/gpt-oss-20b"),
        api_key=os.environ.get("TOGETHER_API_KEY"),
        streaming=False,
    )
    prompt = (
        "Answer ONLY using the CONTEXT below.\n"
        "If you cannot find the answer in the context, reply exactly:\n"
        "\"I don't know based on the provided documents\".\n\n"
        f"QUESTION:\n{question}\n\nCONTEXT:\n{context}\n"
    )
    # Run the blocking client call off the event loop.
    out = await asyncio.to_thread(llm.complete, prompt)
    return (out.text or "").strip()
async def answer_question(question: str) -> Dict[str, Any]:
    """Full RAG pipeline: retrieve -> rerank -> context gate -> generate.

    Returns a dict with the answer text, citation records, and the chunks
    that backed the answer. Per-stage timings are printed to stdout.
    """
    start = time.perf_counter()
    candidates = retrieve_topk(question, k=TOP_K)
    after_retrieve = time.perf_counter()
    selected = rerank(question, candidates, top_n=RERANK_TOP_N)
    after_rerank = time.perf_counter()
    context, citations = build_context(selected)

    retrieve_s = after_retrieve - start
    rerank_s = after_rerank - after_retrieve

    # Refuse outright when the assembled context is too thin to answer from.
    if not should_answer(context):
        end = time.perf_counter()
        print(
            f"[timing] retrieve={retrieve_s:.2f}s "
            f"rerank={rerank_s:.2f}s "
            f"llm=0.00s total={end - start:.2f}s (refusal)"
        )
        return {
            "answer": "I don't know based on the provided documents",
            "citations": citations,
            "used_chunks": selected,
        }

    if USE_OPENAI:
        answer = await generate_answer_openai(question, context)
    else:
        answer = await generate_answer_together(question, context)
    end = time.perf_counter()
    print(
        f"[timing] retrieve={retrieve_s:.2f}s "
        f"rerank={rerank_s:.2f}s "
        f"llm={end - after_rerank:.2f}s total={end - start:.2f}s"
    )
    return {
        "answer": answer,
        "citations": citations,
        "used_chunks": selected,
    }