Zubaish committed on
Commit
4efaf50
·
1 Parent(s): 1715fb7

Fix: HF dataset PDF loading + stable RAG

Browse files
Files changed (1) hide show
  1. rag.py +90 -46
rag.py CHANGED
@@ -1,89 +1,133 @@
1
  # rag.py
2
 
3
- from datasets import load_dataset
 
 
 
 
 
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.embeddings import HuggingFaceEmbeddings
6
- from langchain_text_splitters import RecursiveCharacterTextSplitter
7
- from langchain.schema import Document
8
  from transformers import pipeline
9
 
10
- from config import HF_DATASET_REPO, EMBEDDING_MODEL, LLM_MODEL
 
 
 
 
 
 
 
11
 
12
- # ------------------------
13
- # Load documents from HF Dataset
14
- # ------------------------
15
  def load_documents():
16
- ds = load_dataset(HF_DATASET_REPO, split="train")
17
-
18
  docs = []
19
- for row in ds:
20
- text = row.get("text") or row.get("content")
21
- if text and text.strip():
22
- docs.append(Document(page_content=text))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  return docs
25
 
26
 
27
- # ------------------------
28
- # Build Vector DB (ONCE)
29
- # ------------------------
30
- documents = load_documents()
 
31
 
32
- if not documents:
33
- raise RuntimeError("No documents loaded from HF Dataset")
 
34
 
35
- splitter = RecursiveCharacterTextSplitter(
36
- chunk_size=500,
37
- chunk_overlap=50,
38
- )
39
 
40
- chunks = splitter.split_documents(documents)
41
 
42
- embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
 
 
 
 
 
 
 
 
 
 
43
 
44
- vectordb = Chroma.from_documents(
45
- documents=chunks,
46
- embedding=embeddings,
47
- )
48
 
49
- retriever = vectordb.as_retriever(search_kwargs={"k": 3})
 
50
 
51
- # ------------------------
52
- # LLM (CPU SAFE)
53
- # ------------------------
54
- llm = pipeline(
55
  "text2text-generation",
56
  model=LLM_MODEL,
57
- max_new_tokens=256,
58
  )
59
 
60
- # ------------------------
61
- # RAG Query
62
- # ------------------------
63
- def ask_rag_with_status(question: str):
 
64
  status = []
65
 
66
- status.append("🔎 Retrieving documents")
 
 
 
67
  docs = retriever.get_relevant_documents(question)
68
 
69
  if not docs:
70
- return "No relevant documents found.", status
71
 
72
  context = "\n\n".join(d.page_content for d in docs)
73
 
74
  prompt = f"""
75
- Answer the question using the context below.
76
 
77
  Context:
78
  {context}
79
 
80
  Question:
81
  {question}
82
-
83
- Answer:
84
  """
85
 
86
- status.append("🧠 Generating answer")
87
- result = llm(prompt)[0]["generated_text"]
 
 
88
 
89
  return result.strip(), status
 
1
  # rag.py
2
 
3
+ import os
4
+ from typing import List, Tuple
5
+
6
+ from huggingface_hub import hf_hub_download, list_repo_files
7
+ from langchain_community.document_loaders import PyPDFLoader
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
  from langchain_community.vectorstores import Chroma
10
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
11
  from transformers import pipeline
12
 
13
+ from config import (
14
+ HF_DATASET_REPO,
15
+ EMBEDDING_MODEL,
16
+ LLM_MODEL,
17
+ CHROMA_DIR,
18
+ CHUNK_SIZE,
19
+ CHUNK_OVERLAP,
20
+ )
21
 
22
# -----------------------------
# Load PDFs from HF Dataset repo
# -----------------------------
def load_documents():
    """Download every PDF in the HF dataset repo and load it into LangChain docs.

    Lists the files of ``HF_DATASET_REPO`` (repo_type="dataset"), downloads
    each ``*.pdf`` via the HF hub cache, and parses it with ``PyPDFLoader``.

    Returns:
        list: LangChain ``Document`` objects (one per PDF page). Returns an
        empty list — never raises — when the repo is unreachable, contains
        no PDFs, or individual files fail, keeping startup best-effort.
    """
    docs = []

    try:
        files = list_repo_files(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset"
        )
    except Exception as e:
        # Best-effort: an unreachable/private repo must not crash startup.
        print("❌ Could not access dataset:", e)
        return []

    # Case-insensitive match so "Report.PDF" is picked up too.
    pdf_files = [f for f in files if f.lower().endswith(".pdf")]

    if not pdf_files:
        print("⚠️ No PDFs found in dataset")
        return []

    # NOTE(review): "kb" is created but never used below — hf_hub_download
    # stores files in the HF cache dir. Confirm whether this is still needed.
    os.makedirs("kb", exist_ok=True)

    for pdf in pdf_files:
        # A single corrupt or unreadable PDF must not abort the whole load:
        # skip it with a warning and keep processing the rest.
        try:
            local_path = hf_hub_download(
                repo_id=HF_DATASET_REPO,
                filename=pdf,
                repo_type="dataset"
            )
            loader = PyPDFLoader(local_path)
            docs.extend(loader.load())
        except Exception as e:
            print(f"⚠️ Skipping {pdf}:", e)

    return docs
56
 
57
 
58
# -----------------------------
# Build vector DB (safe)
# -----------------------------
def build_vectorstore():
    """Load the corpus, chunk it, embed it, and persist a Chroma store.

    Returns the populated ``Chroma`` vector store, or ``None`` when no
    documents could be loaded (so callers can degrade gracefully).
    """
    loaded = load_documents()

    # Guard clause: without source documents there is nothing to index.
    if not loaded:
        print("⚠️ No documents loaded, vector DB will be empty")
        return None

    # Chunk pages into overlapping windows sized via config.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
    )
    chunked_docs = text_splitter.split_documents(loaded)

    embedding_fn = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL
    )

    # Embed all chunks and persist to disk so later runs can reuse the index.
    store = Chroma.from_documents(
        documents=chunked_docs,
        embedding=embedding_fn,
        persist_directory=CHROMA_DIR
    )

    return store
86
 
 
 
 
 
87
 
88
# Build once at startup
# Module-level side effect: importing this module triggers document download,
# embedding, and indexing. VECTOR_DB may be None when loading failed (see
# build_vectorstore); ask_rag_with_status checks for that.
VECTOR_DB = build_vectorstore()

# -----------------------------
# LLM (CPU-safe)
# -----------------------------
# Loaded once at import; text2text-generation implies a seq2seq model
# (e.g. a T5/FLAN-style checkpoint named by LLM_MODEL in config).
qa_pipeline = pipeline(
    "text2text-generation",
    model=LLM_MODEL,
    max_new_tokens=256
)
99
 
100
+
101
# -----------------------------
# Public API
# -----------------------------
def ask_rag_with_status(question: str) -> Tuple[str, List[str]]:
    """Answer *question* over the indexed corpus.

    Retrieves the top-3 chunks from the global vector store, stuffs them
    into a prompt, and runs the seq2seq pipeline.

    Returns:
        (answer, status): the generated answer plus a list of progress /
        diagnostic messages. Early exits return explanatory placeholders
        instead of raising.
    """
    status: List[str] = []

    # Degrade gracefully when startup indexing produced nothing.
    if VECTOR_DB is None:
        return "No documents available.", ["Vector DB not initialized"]

    retriever = VECTOR_DB.as_retriever(search_kwargs={"k": 3})
    matches = retriever.get_relevant_documents(question)

    if not matches:
        return "No relevant information found.", ["No matching chunks"]

    context = "\n\n".join(doc.page_content for doc in matches)

    prompt = f"""
Answer the question using ONLY the context below.

Context:
{context}

Question:
{question}
"""

    answer = qa_pipeline(prompt)[0]["generated_text"]

    status.append(f"Retrieved {len(matches)} chunks")
    status.append("Answer generated")

    return answer.strip(), status