# chatbot_retriever.py """ Hybrid retriever: - loads PDFs & PPTX (robust imports) - chunks via RecursiveCharacterTextSplitter - BM25 (rank_bm25) + FAISS (IVF when possible) using SentenceTransformers - returns a combined context string limited by MAX_CONTEXT_CHARS """ import os import re import pickle import logging import shutil import random from typing import List, Optional, Dict, Any import numpy as np import faiss from rank_bm25 import BM25Okapi from langchain_community.document_loaders import UnstructuredFileLoader # Document loaders: try langchain first, then community loader try: from langchain.document_loaders import PyPDFLoader, UnstructuredPowerPointLoader except Exception: # fallback to community package (older installations) try: from langchain_community.document_loaders import PyPDFLoader, UnstructuredPowerPointLoader from langchain_community.document_loaders.powerpoint import UnstructuredPowerPointLoader except Exception: raise ImportError("Please install langchain + langchain-community (or upgrade).") from langchain.text_splitter import RecursiveCharacterTextSplitter from sentence_transformers import SentenceTransformer # ---------- Config ---------- DATA_DIR = os.getenv("DATA_DIR", "data") CACHE_DIR = os.getenv("CACHE_DIR", ".ragg_cache") CHUNKS_CACHE = os.path.join(CACHE_DIR, "chunks.pkl") BM25_CACHE = os.path.join(CACHE_DIR, "bm25.pkl") FAISS_DIR = os.getenv("FAISS_DIR", "faiss_index") FAISS_INDEX_PATH = os.path.join(FAISS_DIR, "index.faiss") FAISS_META_PATH = os.path.join(FAISS_DIR, "meta.pkl") os.makedirs(CACHE_DIR, exist_ok=True) os.makedirs(FAISS_DIR, exist_ok=True) CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", 400)) CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", 80)) EMBED_MODEL = os.getenv("EMBED_MODEL", "all-MiniLM-L6-v2") TOP_K_DOCS = int(os.getenv("TOP_K_DOCS", 3)) MAX_CONTEXT_CHARS = int(os.getenv("MAX_CONTEXT_CHARS", 4000)) # FAISS params BATCH_SIZE = int(os.getenv("BATCH_SIZE", 256)) FAISS_NLIST = int(os.getenv("FAISS_NLIST", 100)) FAISS_TRAIN_SIZE = int(os.getenv("FAISS_TRAIN_SIZE", 2000)) FAISS_NPROBE = int(os.getenv("FAISS_NPROBE", 10)) SEARCH_EXPANSION = int(os.getenv("FAISS_SEARCH_EXPANSION", 5)) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) from huggingface_hub import hf_hub_download import os DATASET_REPO = "07Codex07/PrepGraph-Data" def ensure_data_dir(): """Ensure data/ folder exists and contains the Hugging Face dataset PDFs properly structured.""" from huggingface_hub import hf_hub_download import shutil data_dir = os.getenv("DATA_DIR", "data") os.makedirs(data_dir, exist_ok=True) files = [ "cn.pdf", "dos.pdf", "pyqs/cn_pyq_2019.pdf", "pyqs/cn_pyq_2020.pdf", "pyqs/cn_pyq_2022.pdf", "pyqs/cn_pyq_2023.pdf", "pyqs/cn_pyq_2024.pdf", "pyqs/cn_pyq_2028.pdf", "pyqs/dos_pyq_2019.pdf", "pyqs/dos_pyq_2020.pdf", "pyqs/dos_pyq_2024.pdf", "pyqs/se_pyq_2018.pdf", "pyqs/se_pyq_2019.pdf", "pyqs/se_pyq_2020(S).pdf", "pyqs/se_pyq_2020.pdf", "pyqs/se_pyq_2022.pdf", "pyqs/se_pyq_2024.pdf", ] local_paths = [] downloaded_count = 0 for f in files: dest_path = os.path.join(data_dir, f) # ✅ keep real folder structure os.makedirs(os.path.dirname(dest_path), exist_ok=True) if not os.path.exists(dest_path): logger.info(f"📥 Downloading {f} from Hugging Face (public dataset)...") downloaded = hf_hub_download( repo_id=DATASET_REPO, filename=f, repo_type="dataset", force_download=True, ) shutil.copy(downloaded, dest_path) # ✅ copy instead of rename (works inside HF Spaces) downloaded_count += 1 local_paths.append(dest_path) # Only print summary if files were actually downloaded if downloaded_count > 0: logger.info(f"✅ Downloaded {downloaded_count} new file(s). Total files ensured: {len(local_paths)}") for p in local_paths[:3]: logger.debug(f" → {p}") else: logger.debug(f"✅ All {len(local_paths)} data files already exist") return local_paths def detect_subject(fname: str) -> Optional[str]: # light heuristic to guess subject code from filename t = (fname or "").lower() if "network" in t or "cn" in t: return "cn" if "distributed" in t or "dos" in t: return "dos" if "software" in t or "se" in t: return "se" return None def extract_year(s: str) -> Optional[str]: m = re.search(r"\b(20\d{2})\b", s) return m.group(1) if m else None # ---------- Embeddings wrapper (SentenceTransformers) ---------- class Embeddings: def __init__(self, model_name=EMBED_MODEL): self.model_name = model_name self.model = SentenceTransformer(model_name) def embed_documents(self, texts: List[str]) -> List[List[float]]: vecs = self.model.encode(texts, show_progress_bar=False, convert_to_numpy=True) return [v.astype("float32") for v in vecs] def embed_query(self, text: str) -> List[float]: v = self.model.encode([text], show_progress_bar=False, convert_to_numpy=True)[0] return v.astype("float32") # ---------- Load documents ---------- def load_all_docs(base_dir: str = DATA_DIR) -> List: docs = [] if not os.path.isdir(base_dir): logger.warning("Data dir does not exist: %s", base_dir) return docs def load_file(path: str, filename: str, category: str): try: fname = filename.lower() if fname.endswith(".pdf"): loader = PyPDFLoader(path) elif fname.endswith(".pptx"): loader = UnstructuredPowerPointLoader(path) else: return [] file_docs = loader.load() subject = detect_subject(fname) year = extract_year(fname) for d in file_docs: d.metadata["subject"] = subject d.metadata["filename"] = filename d.metadata["category"] = category if year: d.metadata["year"] = year return file_docs except Exception: logger.exception("Failed to load %s", filename) return [] # root files for file in os.listdir(base_dir): path = os.path.join(base_dir, file) if os.path.isfile(path) and (file.lower().endswith(".pdf") or file.lower().endswith(".pptx")): docs.extend(load_file(path, file, "syllabus")) # optional pyqs directory pyqs_dir = os.path.join(base_dir, "pyqs") if os.path.isdir(pyqs_dir): for file in os.listdir(pyqs_dir): path = os.path.join(pyqs_dir, file) if os.path.isfile(path) and file.lower().endswith(".pdf"): docs.extend(load_file(path, file, "pyq")) logger.info("Loaded %d raw document pages", len(docs)) return docs # ---------- Build / load FAISS + BM25 ---------- def build_or_load_indexes(force_reindex: bool = False): """Build or load FAISS and BM25 indexes. Returns (chunks, bm25, tokenized, corpus_texts, faiss_data).""" if os.getenv("FORCE_REINDEX", "0").lower() in ("1", "true", "yes"): force_reindex = True # Only ensure data dir if files don't exist (check a sample file to avoid repeated calls) sample_file = os.path.join(DATA_DIR, "cn.pdf") if not os.path.exists(sample_file) or force_reindex: logger.info("Data files missing or force_reindex=True, ensuring data directory...") ensure_data_dir() else: logger.debug("Data files already exist, skipping ensure_data_dir()") docs = load_all_docs(DATA_DIR) if not docs: logger.warning("No documents found in %s. Returning empty indexes.", DATA_DIR) return [], None, [], [], None logger.info("Loaded %d document pages from %s", len(docs), DATA_DIR) # chunking if os.path.exists(CHUNKS_CACHE) and not force_reindex: with open(CHUNKS_CACHE, "rb") as f: chunks = pickle.load(f) logger.info("Loaded %d chunks from cache.", len(chunks)) else: splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) chunks = splitter.split_documents(docs) with open(CHUNKS_CACHE, "wb") as f: pickle.dump(chunks, f) logger.info("Created and cached %d chunks.", len(chunks)) corpus_texts = [c.page_content for c in chunks] # BM25 if os.path.exists(BM25_CACHE) and not force_reindex: try: with open(BM25_CACHE, "rb") as f: bm25_data = pickle.load(f) bm25 = bm25_data.get("bm25") tokenized = bm25_data.get("tokenized", []) logger.info("Loaded BM25 from cache (n=%d)", len(corpus_texts)) except Exception: logger.exception("Failed to load BM25 cache — rebuilding") tokenized = [re.findall(r"\w+", t.lower()) for t in corpus_texts] bm25 = BM25Okapi(tokenized) with open(BM25_CACHE, "wb") as f: pickle.dump({"bm25": bm25, "tokenized": tokenized}, f) else: tokenized = [re.findall(r"\w+", t.lower()) for t in corpus_texts] bm25 = BM25Okapi(tokenized) try: with open(BM25_CACHE, "wb") as f: pickle.dump({"bm25": bm25, "tokenized": tokenized}, f) except Exception: logger.warning("Could not write BM25 cache") # Embeddings embeddings = Embeddings() metadatas = [c.metadata for c in chunks] # load existing faiss index if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(FAISS_META_PATH) and not force_reindex: try: index = faiss.read_index(FAISS_INDEX_PATH) with open(FAISS_META_PATH, "rb") as f: meta = pickle.load(f) texts = meta.get("texts", corpus_texts) metadatas = meta.get("metadatas", metadatas) try: index.nprobe = FAISS_NPROBE except Exception: pass logger.info("Loaded FAISS index from disk (%s), entries=%d", FAISS_INDEX_PATH, len(texts)) return chunks, bm25, tokenized, corpus_texts, {"index": index, "texts": texts, "metadatas": metadatas, "embeddings": embeddings} except Exception: logger.exception("Failed to load FAISS index; rebuilding") # force reindex cleanup if force_reindex: try: shutil.rmtree(FAISS_DIR, ignore_errors=True) os.makedirs(FAISS_DIR, exist_ok=True) except Exception: pass # Build FAISS (memory-aware, batch) logger.info("Building FAISS index (nlist=%d). This may take a while...", FAISS_NLIST) total = len(corpus_texts) sample_size = min(total, FAISS_TRAIN_SIZE) sample_indices = random.sample(range(total), sample_size) if sample_size < total else list(range(total)) sample_embs = [] for i in range(0, len(sample_indices), BATCH_SIZE): batch_idx = sample_indices[i:i + BATCH_SIZE] batch_texts = [corpus_texts[j] for j in batch_idx] try: batch_vecs = embeddings.embed_documents(batch_texts) except Exception: batch_vecs = [embeddings.embed_query(t) for t in batch_texts] sample_embs.extend(batch_vecs) sample_np = np.array(sample_embs, dtype="float32") if sample_np.ndim == 1: sample_np = sample_np.reshape(1, -1) d = sample_np.shape[1] n_train_samples = sample_np.shape[0] use_ivf = True if n_train_samples < FAISS_NLIST: logger.warning("Not enough training samples (%d) for FAISS_NLIST=%d — using Flat index", n_train_samples, FAISS_NLIST) use_ivf = False try: if use_ivf: index_desc = f"IVF{FAISS_NLIST},Flat" index = faiss.index_factory(d, index_desc, faiss.METRIC_L2) if not index.is_trained: try: index.train(sample_np) logger.info("Trained IVF on %d samples", n_train_samples) except Exception: logger.exception("IVF training failed — falling back to Flat") index = faiss.index_factory(d, "Flat", faiss.METRIC_L2) else: index = faiss.index_factory(d, "Flat", faiss.METRIC_L2) except Exception: logger.exception("Failed to create FAISS index — using Flat") index = faiss.index_factory(d, "Flat", faiss.METRIC_L2) # add vectors in batches added = 0 for i in range(0, total, BATCH_SIZE): batch_texts = corpus_texts[i:i + BATCH_SIZE] try: batch_vecs = embeddings.embed_documents(batch_texts) except Exception: batch_vecs = [embeddings.embed_query(t) for t in batch_texts] batch_np = np.array(batch_vecs, dtype="float32") if batch_np.ndim == 1: batch_np = batch_np.reshape(1, -1) index.add(batch_np) added += batch_np.shape[0] logger.info("FAISS: added %d / %d vectors", added, total) try: index.nprobe = FAISS_NPROBE except Exception: pass try: faiss.write_index(index, FAISS_INDEX_PATH) with open(FAISS_META_PATH, "wb") as f: pickle.dump({ "texts": corpus_texts, "metadatas": metadatas }, f) logger.info("FAISS index saved to %s (entries=%d)", FAISS_INDEX_PATH, total) except Exception: logger.exception("Failed to persist FAISS index on disk") return chunks, bm25, tokenized, corpus_texts, {"index": index, "texts": corpus_texts, "metadatas": metadatas, "embeddings": embeddings} # ---------- Hybrid retrieve ---------- def _ensure_index_built(): """Ensure indexes are built. Only rebuilds if not already initialized.""" if not hasattr(hybrid_retrieve, "_index_built") or not hybrid_retrieve._index_built: logger.info("Initializing indexes for hybrid_retrieve...") hybrid_retrieve._chunks, hybrid_retrieve._bm25, hybrid_retrieve._tokenized, hybrid_retrieve._corpus, hybrid_retrieve._faiss = build_or_load_indexes() hybrid_retrieve._index_built = True logger.info("Indexes initialized: %d chunks available", len(hybrid_retrieve._chunks) if hybrid_retrieve._chunks else 0) def _faiss_search(query: str, top_k: int = TOP_K_DOCS, subject: Optional[str] = None): top_k = top_k or TOP_K_DOCS faiss_data = hybrid_retrieve._faiss if not faiss_data: return [] index = faiss_data.get("index") texts = faiss_data.get("texts", []) metadatas = faiss_data.get("metadatas", [{}] * len(texts)) embeddings = faiss_data.get("embeddings") try: q_vec = embeddings.embed_query(query) except Exception: q_vec = embeddings.embed_documents([query])[0] q_np = np.array(q_vec, dtype="float32").reshape(1, -1) search_k = max(top_k * SEARCH_EXPANSION, top_k) try: distances, indices = index.search(q_np, int(search_k)) except Exception: distances, indices = index.search(q_np, int(top_k)) results = [] for dist, idx in zip(distances[0], indices[0]): if idx < 0 or idx >= len(texts): continue meta = metadatas[idx] # subject filtering disabled because it blocks many relevant chunks # if subject and meta.get("subject") != subject: # continue score_like = float(-dist) results.append((score_like, meta, texts[idx])) if len(results) >= top_k: break return results def hybrid_retrieve(query: str, subject: Optional[str] = None, top_k: int = TOP_K_DOCS, max_chars: int = MAX_CONTEXT_CHARS) -> Dict[str, Any]: if not query: logger.warning("hybrid_retrieve called with empty query") return {"context": None, "bm25_docs": [], "faiss_docs": [], "meta": []} _ensure_index_built() chunks = hybrid_retrieve._chunks bm25 = hybrid_retrieve._bm25 if not chunks: logger.error("No chunks available for retrieval. Indexes may not be built correctly.") return {"context": None, "bm25_docs": [], "faiss_docs": [], "meta": []} logger.debug("Retrieving for query: %s (top_k=%d)", query[:50], top_k) # BM25 results_bm25 = [] try: if bm25: q_tokens = re.findall(r"\w+", query.lower()) scores = bm25.get_scores(q_tokens) ranked_idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k] for i in ranked_idx: if i < len(chunks): results_bm25.append((float(scores[i]), chunks[i].metadata, chunks[i].page_content)) logger.debug("BM25 found %d results", len(results_bm25)) else: logger.warning("BM25 index is None") except Exception: logger.exception("BM25 search failed") # FAISS results_faiss = [] try: results_faiss = _faiss_search(query, top_k=top_k, subject=subject) logger.debug("FAISS found %d results", len(results_faiss)) except Exception: logger.exception("FAISS search failed") # Merge and dedupe by text merged_texts = [] merged_meta = [] for score, meta, text in results_bm25: if text and text.strip(): merged_texts.append(text) merged_meta.append({ "source": meta.get("filename"), "subject": meta.get("subject"), "score": score }) for score, meta, text in results_faiss: if text and text.strip(): merged_texts.append(text) merged_meta.append({ "source": meta.get("filename") if isinstance(meta, dict) else None, "subject": meta.get("subject") if isinstance(meta, dict) else None, "score": score }) # compose context parts with headers context_parts = [] seen_texts = set() # Deduplicate by text content for i, t in enumerate(merged_texts): # Deduplicate: skip if we've seen this text before if t in seen_texts: continue seen_texts.add(t) header = f"\n\n===== DOC {i+1} =====\n" context_parts.append(header + t) context = "\n".join(context_parts).strip() if not context: logger.warning("No context generated from retrieval for query: %s (BM25: %d, FAISS: %d results)", query[:50], len(results_bm25), len(results_faiss)) return {"context": None, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta} if len(context) > max_chars: context = context[:max_chars].rstrip() + "..." logger.debug("Context truncated from %d to %d characters", len("\n".join(context_parts)), max_chars) logger.info("Retrieved context: %d characters from %d documents", len(context), len(context_parts)) return {"context": context, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta} # ---------- retrieve_node (for reuse) ---------- def _last_n_user_messages(rows: List[tuple], n: int = 1) -> List[str]: users = [r[1] for r in rows if r[0] == "user"] return users[-1:] # always return ONLY the latest user query # only keep the last one def retrieve_node_from_rows(rows: List[tuple], top_k: int = TOP_K_DOCS) -> Dict[str, Any]: """Retrieve context from documents based on the last user message in rows.""" last_users = _last_n_user_messages(rows, n=1) current_query = " ".join(last_users).strip() if last_users else "" if not current_query: logger.warning("retrieve_node_from_rows: No user query found in rows") return {"context": None, "direct": False} logger.debug("retrieve_node_from_rows: Query='%s'", current_query[:50]) detected = None try: detected = detect_subject(current_query) if detected: logger.debug("Detected subject: %s", detected) except Exception: detected = None result = hybrid_retrieve(current_query, subject=detected, top_k=top_k, max_chars=MAX_CONTEXT_CHARS) context = result.get("context") if context: logger.info("retrieve_node_from_rows: Successfully retrieved %d characters of context", len(context)) else: logger.warning("retrieve_node_from_rows: No context retrieved for query: %s", current_query[:50]) return {"context": context, "direct": False}