# PrepGraph-Backend / chatbot_retriever.py
# Author: 07Codex07
# Commit 41d23d8: changed the context and retrieval
# chatbot_retriever.py
"""
Hybrid retriever:
- loads PDFs & PPTX (robust imports)
- chunks via RecursiveCharacterTextSplitter
- BM25 (rank_bm25) + FAISS (IVF when possible) using SentenceTransformers
- returns a combined context string limited by MAX_CONTEXT_CHARS
"""
import os
import re
import pickle
import logging
import shutil
import random
from typing import List, Optional, Dict, Any
import numpy as np
import faiss
from rank_bm25 import BM25Okapi
from langchain_community.document_loaders import UnstructuredFileLoader
# Document loaders: try langchain first, then community loader
try:
from langchain.document_loaders import PyPDFLoader, UnstructuredPowerPointLoader
except Exception:
# fallback to community package (older installations)
try:
from langchain_community.document_loaders import PyPDFLoader, UnstructuredPowerPointLoader
from langchain_community.document_loaders.powerpoint import UnstructuredPowerPointLoader
except Exception:
raise ImportError("Please install langchain + langchain-community (or upgrade).")
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
# ---------- Config ----------
def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment, using *default* when unset."""
    return int(os.getenv(name, default))


# Source documents, and the directories holding derived artifacts.
DATA_DIR = os.getenv("DATA_DIR", "data")
CACHE_DIR = os.getenv("CACHE_DIR", ".ragg_cache")
FAISS_DIR = os.getenv("FAISS_DIR", "faiss_index")

# Pickled chunk list and BM25 state live under CACHE_DIR.
CHUNKS_CACHE = os.path.join(CACHE_DIR, "chunks.pkl")
BM25_CACHE = os.path.join(CACHE_DIR, "bm25.pkl")

# Serialized FAISS index and its companion texts/metadata live under FAISS_DIR.
FAISS_INDEX_PATH = os.path.join(FAISS_DIR, "index.faiss")
FAISS_META_PATH = os.path.join(FAISS_DIR, "meta.pkl")

# Both directories are created at import time so later writes cannot fail on a missing folder.
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(FAISS_DIR, exist_ok=True)

# Chunking, embedding and retrieval sizes (all overridable through the environment).
CHUNK_SIZE = _env_int("CHUNK_SIZE", 400)
CHUNK_OVERLAP = _env_int("CHUNK_OVERLAP", 80)
EMBED_MODEL = os.getenv("EMBED_MODEL", "all-MiniLM-L6-v2")
TOP_K_DOCS = _env_int("TOP_K_DOCS", 3)
MAX_CONTEXT_CHARS = _env_int("MAX_CONTEXT_CHARS", 4000)

# FAISS params
BATCH_SIZE = _env_int("BATCH_SIZE", 256)
FAISS_NLIST = _env_int("FAISS_NLIST", 100)
FAISS_TRAIN_SIZE = _env_int("FAISS_TRAIN_SIZE", 2000)
FAISS_NPROBE = _env_int("FAISS_NPROBE", 10)
SEARCH_EXPANSION = _env_int("FAISS_SEARCH_EXPANSION", 5)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
from huggingface_hub import hf_hub_download
import os
DATASET_REPO = "07Codex07/PrepGraph-Data"


def ensure_data_dir():
    """Make sure every expected dataset file is present under the data folder.

    The data folder is taken from the ``DATA_DIR`` environment variable
    (default ``"data"``). Each missing file is downloaded from the
    ``DATASET_REPO`` Hugging Face dataset and copied to the same relative
    path, so the ``pyqs/`` sub-folder layout is preserved.

    Returns:
        The list of local paths for all expected files, in declaration order.
    """
    from huggingface_hub import hf_hub_download
    import shutil

    root = os.getenv("DATA_DIR", "data")
    os.makedirs(root, exist_ok=True)

    expected_files = (
        "cn.pdf",
        "dos.pdf",
        "pyqs/cn_pyq_2019.pdf",
        "pyqs/cn_pyq_2020.pdf",
        "pyqs/cn_pyq_2022.pdf",
        "pyqs/cn_pyq_2023.pdf",
        "pyqs/cn_pyq_2024.pdf",
        "pyqs/cn_pyq_2028.pdf",
        "pyqs/dos_pyq_2019.pdf",
        "pyqs/dos_pyq_2020.pdf",
        "pyqs/dos_pyq_2024.pdf",
        "pyqs/se_pyq_2018.pdf",
        "pyqs/se_pyq_2019.pdf",
        "pyqs/se_pyq_2020(S).pdf",
        "pyqs/se_pyq_2020.pdf",
        "pyqs/se_pyq_2022.pdf",
        "pyqs/se_pyq_2024.pdf",
    )

    ensured = []
    fetched = 0
    for rel_name in expected_files:
        # Keep the repository's relative layout under the data folder.
        target = os.path.join(root, rel_name)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        ensured.append(target)
        if os.path.exists(target):
            continue

        logger.info(f"📥 Downloading {rel_name} from Hugging Face (public dataset)...")
        cached_copy = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=rel_name,
            repo_type="dataset",
            force_download=True,
        )
        # Copy rather than rename: the original author notes this works inside HF Spaces.
        shutil.copy(cached_copy, target)
        fetched += 1

    # Only emit the summary at INFO level when something was actually downloaded.
    if fetched == 0:
        logger.debug(f"✅ All {len(ensured)} data files already exist")
        return ensured

    logger.info(f"✅ Downloaded {fetched} new file(s). Total files ensured: {len(ensured)}")
    for shown in ensured[:3]:
        logger.debug(f" → {shown}")
    return ensured
def detect_subject(fname: str) -> Optional[str]:
    """Guess the subject code from a filename or free-text query.

    Long keywords ("network", "distributed", "software") are matched as
    substrings. The short codes ("cn", "dos", "se") are matched only as whole
    alphanumeric tokens, because a substring test on two or three letters fires
    on unrelated words (for example "se" inside "these" or "database").

    Args:
        fname: Filename or query text; ``None`` and ``""`` are accepted.

    Returns:
        ``"cn"``, ``"dos"`` or ``"se"`` (checked in that order), or ``None``
        when nothing matches.
    """
    text = (fname or "").lower()
    # Underscores, dots and brackets all act as separators, so
    # "cn_pyq_2019.pdf" yields the tokens {"cn", "pyq", "2019", "pdf"}.
    tokens = set(re.findall(r"[a-z0-9]+", text))

    rules = (
        ("cn", "network"),
        ("dos", "distributed"),
        ("se", "software"),
    )
    for code, keyword in rules:
        if keyword in text or code in tokens:
            return code
    return None
def extract_year(s: str) -> Optional[str]:
    """Return the first four-digit year of the form 20xx found in *s*.

    The match is delimited by "not a digit" lookarounds rather than ``\\b``.
    An underscore is a word character, so ``\\b`` never matches between ``_``
    and ``2`` and names such as ``cn_pyq_2019.pdf`` would yield no year.

    Args:
        s: Text to scan, typically a filename; ``None`` is treated as empty.

    Returns:
        The year as a string, or ``None`` when no standalone 20xx is present.
    """
    m = re.search(r"(?<!\d)(20\d{2})(?!\d)", s or "")
    return m.group(1) if m else None
# ---------- Embeddings wrapper (SentenceTransformers) ----------
class Embeddings:
    """Small adapter offering document and query embedding over a SentenceTransformer model."""

    def __init__(self, model_name=EMBED_MODEL):
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)

    def _encode(self, texts: List[str]):
        """Encode *texts* as a NumPy array without showing a progress bar."""
        return self.model.encode(texts, show_progress_bar=False, convert_to_numpy=True)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one float32 vector per input text, in input order."""
        return [row.astype("float32") for row in self._encode(texts)]

    def embed_query(self, text: str) -> List[float]:
        """Return the float32 vector for a single query string."""
        encoded = self._encode([text])
        return encoded[0].astype("float32")
# ---------- Load documents ----------
def load_all_docs(base_dir: str = DATA_DIR) -> List:
    """Load every supported document under *base_dir* and tag it with metadata.

    Files directly in *base_dir* (``.pdf`` / ``.pptx``) are tagged with
    category ``"syllabus"``; ``.pdf`` files in the optional ``pyqs``
    sub-folder are tagged ``"pyq"``. Each loaded page also receives
    ``subject`` and ``filename`` metadata, plus ``year`` when one is found in
    the filename.

    Directory listings are sorted so that the resulting document order (and
    therefore chunk order) is the same on every run; ``os.listdir`` gives no
    ordering guarantee.

    Args:
        base_dir: Root folder to scan.

    Returns:
        A list of loaded document pages; empty if *base_dir* does not exist.
        A file that fails to load is logged and skipped.
    """
    docs = []
    if not os.path.isdir(base_dir):
        logger.warning("Data dir does not exist: %s", base_dir)
        return docs

    def load_file(path: str, filename: str, category: str):
        """Load one file and annotate its pages; return [] on any failure."""
        try:
            fname = filename.lower()
            if fname.endswith(".pdf"):
                loader = PyPDFLoader(path)
            elif fname.endswith(".pptx"):
                loader = UnstructuredPowerPointLoader(path)
            else:
                return []
            file_docs = loader.load()
            subject = detect_subject(fname)
            year = extract_year(fname)
            for d in file_docs:
                d.metadata["subject"] = subject
                d.metadata["filename"] = filename
                d.metadata["category"] = category
                if year:
                    d.metadata["year"] = year
            return file_docs
        except Exception:
            # Best-effort: one unreadable file must not abort the whole load.
            logger.exception("Failed to load %s", filename)
            return []

    # root files
    for file in sorted(os.listdir(base_dir)):
        path = os.path.join(base_dir, file)
        if os.path.isfile(path) and file.lower().endswith((".pdf", ".pptx")):
            docs.extend(load_file(path, file, "syllabus"))

    # optional pyqs directory
    pyqs_dir = os.path.join(base_dir, "pyqs")
    if os.path.isdir(pyqs_dir):
        for file in sorted(os.listdir(pyqs_dir)):
            path = os.path.join(pyqs_dir, file)
            if os.path.isfile(path) and file.lower().endswith(".pdf"):
                docs.extend(load_file(path, file, "pyq"))

    logger.info("Loaded %d raw document pages", len(docs))
    return docs
# ---------- Build / load FAISS + BM25 ----------
def _tokenize_texts(texts: List[str]) -> List[List[str]]:
    """Lower-case each text and split it into word tokens for BM25."""
    return [re.findall(r"\w+", t.lower()) for t in texts]


def _load_or_build_chunks(docs: List, force_reindex: bool) -> List:
    """Return cached chunks when usable; otherwise split *docs* and refresh the cache."""
    if os.path.exists(CHUNKS_CACHE) and not force_reindex:
        try:
            # NOTE(security): pickle is acceptable only because this cache is
            # written by this module; never point CACHE_DIR at untrusted data.
            with open(CHUNKS_CACHE, "rb") as f:
                chunks = pickle.load(f)
            logger.info("Loaded %d chunks from cache.", len(chunks))
            return chunks
        except Exception:
            logger.exception("Failed to load chunk cache — rebuilding")

    splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    chunks = splitter.split_documents(docs)
    if not chunks:
        # Do not cache an empty result; it would mask documents added later.
        return chunks
    try:
        with open(CHUNKS_CACHE, "wb") as f:
            pickle.dump(chunks, f)
        logger.info("Created and cached %d chunks.", len(chunks))
    except Exception:
        logger.warning("Could not write chunk cache", exc_info=True)
    return chunks


def _load_or_build_bm25(corpus_texts: List[str], force_reindex: bool):
    """Return ``(bm25, tokenized)``, reusing the cache only if it matches the corpus size."""
    if os.path.exists(BM25_CACHE) and not force_reindex:
        try:
            with open(BM25_CACHE, "rb") as f:
                bm25_data = pickle.load(f)
            bm25 = bm25_data.get("bm25")
            tokenized = bm25_data.get("tokenized", [])
            # A cache built from a different chunk list would return scores
            # whose positions point at the wrong chunks.
            if bm25 is None or len(tokenized) != len(corpus_texts):
                raise ValueError("BM25 cache does not match the current corpus")
            logger.info("Loaded BM25 from cache (n=%d)", len(corpus_texts))
            return bm25, tokenized
        except Exception:
            logger.exception("Failed to load BM25 cache — rebuilding")

    tokenized = _tokenize_texts(corpus_texts)
    bm25 = BM25Okapi(tokenized)
    try:
        with open(BM25_CACHE, "wb") as f:
            pickle.dump({"bm25": bm25, "tokenized": tokenized}, f)
    except Exception:
        logger.warning("Could not write BM25 cache")
    return bm25, tokenized


def _apply_nprobe(index) -> None:
    """Set the probe count on *index*; index types that reject it are left untouched."""
    try:
        index.nprobe = FAISS_NPROBE
    except Exception:
        logger.debug("Index does not accept nprobe; skipping", exc_info=True)


def _embed_batch(embeddings, texts: List[str]):
    """Embed *texts* in one call, falling back to per-text embedding if that fails."""
    try:
        return embeddings.embed_documents(texts)
    except Exception:
        logger.warning("Batch embedding failed; retrying one text at a time", exc_info=True)
        return [embeddings.embed_query(t) for t in texts]


def _try_load_faiss(embeddings, corpus_texts: List[str], metadatas: List):
    """Load the persisted FAISS index, or return ``None`` if it is unusable or stale."""
    try:
        index = faiss.read_index(FAISS_INDEX_PATH)
        with open(FAISS_META_PATH, "rb") as f:
            meta = pickle.load(f)
        texts = meta.get("texts", corpus_texts)
        stored_metadatas = meta.get("metadatas", metadatas)
        # The index rows, stored texts and stored metadata are addressed by the
        # same integer id, and must describe the same corpus BM25 indexes.
        if not (index.ntotal == len(texts) == len(stored_metadatas) == len(corpus_texts)):
            raise ValueError("FAISS index does not match the current corpus")
    except Exception:
        logger.exception("Failed to load FAISS index; rebuilding")
        return None

    _apply_nprobe(index)
    logger.info("Loaded FAISS index from disk (%s), entries=%d", FAISS_INDEX_PATH, len(texts))
    return {"index": index, "texts": texts, "metadatas": stored_metadatas, "embeddings": embeddings}


def _build_faiss_index(embeddings, corpus_texts: List[str]):
    """Embed *corpus_texts* in batches and return a populated FAISS index (IVF when trainable)."""
    logger.info("Building FAISS index (nlist=%d). This may take a while...", FAISS_NLIST)
    total = len(corpus_texts)
    sample_size = min(total, max(1, FAISS_TRAIN_SIZE))
    covers_corpus = sample_size >= total
    sample_indices = list(range(total)) if covers_corpus else random.sample(range(total), sample_size)

    sample_embs = []
    for start in range(0, len(sample_indices), BATCH_SIZE):
        batch_texts = [corpus_texts[j] for j in sample_indices[start:start + BATCH_SIZE]]
        sample_embs.extend(_embed_batch(embeddings, batch_texts))
    sample_np = np.array(sample_embs, dtype="float32")
    if sample_np.ndim == 1:
        sample_np = sample_np.reshape(1, -1)
    d = sample_np.shape[1]
    n_train_samples = sample_np.shape[0]

    use_ivf = True
    if n_train_samples < FAISS_NLIST:
        logger.warning(
            "Not enough training samples (%d) for FAISS_NLIST=%d — using Flat index",
            n_train_samples,
            FAISS_NLIST,
        )
        use_ivf = False

    try:
        if use_ivf:
            index = faiss.index_factory(d, f"IVF{FAISS_NLIST},Flat", faiss.METRIC_L2)
            if not index.is_trained:
                try:
                    index.train(sample_np)
                    logger.info("Trained IVF on %d samples", n_train_samples)
                except Exception:
                    logger.exception("IVF training failed — falling back to Flat")
                    index = faiss.index_factory(d, "Flat", faiss.METRIC_L2)
        else:
            index = faiss.index_factory(d, "Flat", faiss.METRIC_L2)
    except Exception:
        logger.exception("Failed to create FAISS index — using Flat")
        index = faiss.index_factory(d, "Flat", faiss.METRIC_L2)

    # add vectors in batches
    added = 0
    for start in range(0, total, BATCH_SIZE):
        if covers_corpus:
            # The training sample already holds every vector in corpus order,
            # so reuse it instead of embedding the whole corpus a second time.
            batch_np = sample_np[start:start + BATCH_SIZE]
        else:
            batch_np = np.array(
                _embed_batch(embeddings, corpus_texts[start:start + BATCH_SIZE]),
                dtype="float32",
            )
            if batch_np.ndim == 1:
                batch_np = batch_np.reshape(1, -1)
        index.add(batch_np)
        added += batch_np.shape[0]
        logger.info("FAISS: added %d / %d vectors", added, total)

    _apply_nprobe(index)
    return index


def build_or_load_indexes(force_reindex: bool = False):
    """Build or load the BM25 and FAISS indexes over the document chunks.

    Cached artifacts (chunks, BM25 state, FAISS index + metadata) are reused
    unless *force_reindex* is true or the ``FORCE_REINDEX`` environment
    variable is ``1``/``true``/``yes``. A cache that cannot be read, or whose
    size disagrees with the current chunk list, is rebuilt instead of being
    trusted.

    Args:
        force_reindex: Ignore all caches and rebuild everything.

    Returns:
        ``(chunks, bm25, tokenized, corpus_texts, faiss_data)`` where
        ``faiss_data`` is a dict with keys ``index``, ``texts``,
        ``metadatas`` and ``embeddings``. When no documents or no chunk text
        are available the result is ``([], None, [], [], None)``.
    """
    if os.getenv("FORCE_REINDEX", "0").lower() in ("1", "true", "yes"):
        force_reindex = True

    # Only ensure data dir if files don't exist (check a sample file to avoid repeated calls)
    sample_file = os.path.join(DATA_DIR, "cn.pdf")
    if not os.path.exists(sample_file) or force_reindex:
        logger.info("Data files missing or force_reindex=True, ensuring data directory...")
        ensure_data_dir()
    else:
        logger.debug("Data files already exist, skipping ensure_data_dir()")

    docs = load_all_docs(DATA_DIR)
    if not docs:
        logger.warning("No documents found in %s. Returning empty indexes.", DATA_DIR)
        return [], None, [], [], None
    logger.info("Loaded %d document pages from %s", len(docs), DATA_DIR)

    # chunking
    chunks = _load_or_build_chunks(docs, force_reindex)
    corpus_texts = [c.page_content for c in chunks]
    if not corpus_texts:
        # BM25Okapi divides by the corpus size, so an empty corpus must stop here.
        logger.warning("No text chunks produced from %s. Returning empty indexes.", DATA_DIR)
        return [], None, [], [], None

    # BM25
    bm25, tokenized = _load_or_build_bm25(corpus_texts, force_reindex)

    # Embeddings
    embeddings = Embeddings()
    metadatas = [c.metadata for c in chunks]

    # load existing faiss index
    if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(FAISS_META_PATH) and not force_reindex:
        faiss_data = _try_load_faiss(embeddings, corpus_texts, metadatas)
        if faiss_data is not None:
            return chunks, bm25, tokenized, corpus_texts, faiss_data

    # force reindex cleanup
    if force_reindex:
        shutil.rmtree(FAISS_DIR, ignore_errors=True)
        try:
            os.makedirs(FAISS_DIR, exist_ok=True)
        except OSError:
            logger.warning("Could not recreate FAISS directory %s", FAISS_DIR, exc_info=True)

    # Build FAISS (memory-aware, batch)
    index = _build_faiss_index(embeddings, corpus_texts)

    total = len(corpus_texts)
    try:
        faiss.write_index(index, FAISS_INDEX_PATH)
        with open(FAISS_META_PATH, "wb") as f:
            pickle.dump({"texts": corpus_texts, "metadatas": metadatas}, f)
        logger.info("FAISS index saved to %s (entries=%d)", FAISS_INDEX_PATH, total)
    except Exception:
        # Persistence is best-effort; the in-memory index is still returned.
        logger.exception("Failed to persist FAISS index on disk")

    return chunks, bm25, tokenized, corpus_texts, {
        "index": index,
        "texts": corpus_texts,
        "metadatas": metadatas,
        "embeddings": embeddings,
    }
# ---------- Hybrid retrieve ----------
def _ensure_index_built():
    """Build the retrieval indexes once and stash them as attributes on ``hybrid_retrieve``."""
    if getattr(hybrid_retrieve, "_index_built", False):
        return

    logger.info("Initializing indexes for hybrid_retrieve...")
    chunks, bm25, tokenized, corpus, faiss_data = build_or_load_indexes()
    hybrid_retrieve._chunks = chunks
    hybrid_retrieve._bm25 = bm25
    hybrid_retrieve._tokenized = tokenized
    hybrid_retrieve._corpus = corpus
    hybrid_retrieve._faiss = faiss_data
    # The flag is set last, so a failed build is retried on the next call.
    hybrid_retrieve._index_built = True

    chunk_count = len(hybrid_retrieve._chunks) if hybrid_retrieve._chunks else 0
    logger.info("Indexes initialized: %d chunks available", chunk_count)
def _faiss_search(query: str, top_k: int = TOP_K_DOCS, subject: Optional[str] = None):
    """Run a dense similarity search and return up to *top_k* hits.

    Each hit is ``(score, metadata, text)`` where the score is the negated
    FAISS distance, so larger means closer. *subject* is accepted but not
    used: subject filtering was disabled because it removed too many relevant
    chunks.
    """
    top_k = top_k or TOP_K_DOCS
    store = hybrid_retrieve._faiss
    if not store:
        return []

    index = store.get("index")
    texts = store.get("texts", [])
    metadatas = store.get("metadatas", [{}] * len(texts))
    embedder = store.get("embeddings")

    try:
        query_vec = embedder.embed_query(query)
    except Exception:
        query_vec = embedder.embed_documents([query])[0]
    query_matrix = np.array(query_vec, dtype="float32").reshape(1, -1)

    # Over-fetch first; if the wider search fails, retry with exactly top_k.
    fetch_k = max(top_k * SEARCH_EXPANSION, top_k)
    try:
        distances, indices = index.search(query_matrix, int(fetch_k))
    except Exception:
        distances, indices = index.search(query_matrix, int(top_k))

    hits = []
    for distance, position in zip(distances[0], indices[0]):
        # Skip ids that do not address a stored text (negative or out of range).
        if not (0 <= position < len(texts)):
            continue
        hits.append((float(-distance), metadatas[position], texts[position]))
        if len(hits) >= top_k:
            break
    return hits
def _truncate_context(context: str, max_chars: int) -> str:
    """Cut *context* so the result, including a trailing "...", fits in *max_chars*."""
    if len(context) <= max_chars:
        return context
    ellipsis = "..."
    if max_chars <= len(ellipsis):
        # No room for an ellipsis; return a plain prefix.
        return context[:max(max_chars, 0)]
    return context[:max_chars - len(ellipsis)].rstrip() + ellipsis


def hybrid_retrieve(query: str, subject: Optional[str] = None, top_k: int = TOP_K_DOCS, max_chars: int = MAX_CONTEXT_CHARS) -> Dict[str, Any]:
    """Retrieve context for *query* by combining BM25 and FAISS results.

    BM25 hits come first, then FAISS hits. Entries whose text is blank or
    already seen are dropped, so ``meta`` lines up one-to-one with the numbered
    ``DOC`` sections of the context.

    Args:
        query: User question; empty or whitespace-only yields an empty result.
        subject: Optional subject hint, forwarded to the FAISS search.
        top_k: Maximum hits taken from each retriever; falsy values fall back
            to ``TOP_K_DOCS``.
        max_chars: Upper bound on the returned context length, including the
            trailing "..." added on truncation.

    Returns:
        A dict with keys ``context`` (string or ``None``), ``bm25_docs`` and
        ``faiss_docs`` (lists of ``(score, metadata, text)``) and ``meta``
        (list of ``{"source", "subject", "score"}`` dicts).
    """
    if not query or not query.strip():
        logger.warning("hybrid_retrieve called with empty query")
        return {"context": None, "bm25_docs": [], "faiss_docs": [], "meta": []}

    # Same normalization _faiss_search applies, so both retrievers agree.
    top_k = top_k or TOP_K_DOCS

    _ensure_index_built()
    chunks = hybrid_retrieve._chunks
    bm25 = hybrid_retrieve._bm25
    if not chunks:
        logger.error("No chunks available for retrieval. Indexes may not be built correctly.")
        return {"context": None, "bm25_docs": [], "faiss_docs": [], "meta": []}
    logger.debug("Retrieving for query: %s (top_k=%d)", query[:50], top_k)

    # BM25
    results_bm25 = []
    try:
        if bm25:
            q_tokens = re.findall(r"\w+", query.lower())
            scores = bm25.get_scores(q_tokens)
            usable = min(len(scores), len(chunks))
            # Keep only chunks that matched at least one query token; a zero
            # score means "no overlap", and ranking those returns arbitrary text.
            matching = [i for i in range(usable) if scores[i] > 0]
            matching.sort(key=lambda i: scores[i], reverse=True)
            for i in matching[:top_k]:
                results_bm25.append((float(scores[i]), chunks[i].metadata, chunks[i].page_content))
            logger.debug("BM25 found %d results", len(results_bm25))
        else:
            logger.warning("BM25 index is None")
    except Exception:
        logger.exception("BM25 search failed")

    # FAISS
    results_faiss = []
    try:
        results_faiss = _faiss_search(query, top_k=top_k, subject=subject)
        logger.debug("FAISS found %d results", len(results_faiss))
    except Exception:
        logger.exception("FAISS search failed")

    # Merge and dedupe by text
    merged_texts = []
    merged_meta = []
    seen_texts = set()
    for score, meta, text in results_bm25 + results_faiss:
        if not text or not text.strip() or text in seen_texts:
            continue
        seen_texts.add(text)
        has_fields = isinstance(meta, dict)
        merged_texts.append(text)
        merged_meta.append({
            "source": meta.get("filename") if has_fields else None,
            "subject": meta.get("subject") if has_fields else None,
            "score": score,
        })

    # compose context parts with headers (numbered contiguously after dedupe)
    context_parts = [
        f"\n\n===== DOC {number} =====\n" + text
        for number, text in enumerate(merged_texts, start=1)
    ]
    context = "\n".join(context_parts).strip()
    if not context:
        logger.warning("No context generated from retrieval for query: %s (BM25: %d, FAISS: %d results)",
                       query[:50], len(results_bm25), len(results_faiss))
        return {"context": None, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta}

    if len(context) > max_chars:
        full_length = len(context)
        context = _truncate_context(context, max_chars)
        logger.debug("Context truncated from %d to %d characters", full_length, len(context))

    logger.info("Retrieved context: %d characters from %d documents", len(context), len(context_parts))
    return {"context": context, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta}
# ---------- retrieve_node (for reuse) ----------
def _last_n_user_messages(rows: List[tuple], n: int = 1) -> List[str]:
users = [r[1] for r in rows if r[0] == "user"]
return users[-1:] # always return ONLY the latest user query # only keep the last one
def retrieve_node_from_rows(rows: List[tuple], top_k: int = TOP_K_DOCS) -> Dict[str, Any]:
    """Look up document context for the latest user message found in *rows*.

    Returns a dict with ``context`` (retrieved text, or ``None`` when there is
    no user query or nothing was retrieved) and ``direct`` (always ``False``).
    """
    latest = _last_n_user_messages(rows, n=1)
    query_text = " ".join(latest).strip() if latest else ""
    if not query_text:
        logger.warning("retrieve_node_from_rows: No user query found in rows")
        return {"context": None, "direct": False}

    logger.debug("retrieve_node_from_rows: Query='%s'", query_text[:50])

    # Subject detection is a hint only; any failure degrades to "no subject".
    try:
        subject_hint = detect_subject(query_text)
    except Exception:
        subject_hint = None
    else:
        if subject_hint:
            logger.debug("Detected subject: %s", subject_hint)

    retrieval = hybrid_retrieve(query_text, subject=subject_hint, top_k=top_k, max_chars=MAX_CONTEXT_CHARS)
    context = retrieval.get("context")
    if not context:
        logger.warning("retrieve_node_from_rows: No context retrieved for query: %s", query_text[:50])
    else:
        logger.info("retrieve_node_from_rows: Successfully retrieved %d characters of context", len(context))
    return {"context": context, "direct": False}