Zubaish committed on
Commit
c488d16
·
1 Parent(s): 4efaf50

Rollback: stable local RAG

Browse files
Files changed (2) hide show
  1. config.py +3 -1
  2. rag.py +39 -72
config.py CHANGED
@@ -35,4 +35,6 @@ LLM_MODEL = "google/flan-t5-small"
35
  # Text splitting
36
  # -----------------------------
37
  CHUNK_SIZE = 500
38
- CHUNK_OVERLAP = 50
 
 
 
35
# -----------------------------
# Text splitting
# -----------------------------
# Chunk geometry for the RecursiveCharacterTextSplitter used by rag.py.
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50

# Local folder scanned for knowledge-base PDFs.
KB_DIR = "./kb"
rag.py CHANGED
@@ -3,115 +3,83 @@
3
  import os
4
  from typing import List, Tuple
5
 
6
- from huggingface_hub import hf_hub_download, list_repo_files
7
  from langchain_community.document_loaders import PyPDFLoader
8
  from langchain_text_splitters import RecursiveCharacterTextSplitter
9
  from langchain_community.vectorstores import Chroma
10
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
11
  from transformers import pipeline
12
 
13
  from config import (
14
- HF_DATASET_REPO,
 
15
  EMBEDDING_MODEL,
16
  LLM_MODEL,
17
- CHROMA_DIR,
18
- CHUNK_SIZE,
19
- CHUNK_OVERLAP,
20
  )
21
 
22
  # -----------------------------
23
- # Load PDFs from HF Dataset repo
24
  # -----------------------------
25
- def load_documents():
26
  docs = []
27
 
28
- try:
29
- files = list_repo_files(
30
- repo_id=HF_DATASET_REPO,
31
- repo_type="dataset"
32
- )
33
- except Exception as e:
34
- print("❌ Could not access dataset:", e)
35
- return []
36
-
37
- pdf_files = [f for f in files if f.lower().endswith(".pdf")]
38
-
39
- if not pdf_files:
40
- print("⚠️ No PDFs found in dataset")
41
- return []
42
-
43
- os.makedirs("kb", exist_ok=True)
44
 
45
- for pdf in pdf_files:
46
- local_path = hf_hub_download(
47
- repo_id=HF_DATASET_REPO,
48
- filename=pdf,
49
- repo_type="dataset"
50
- )
51
-
52
- loader = PyPDFLoader(local_path)
53
- docs.extend(loader.load())
54
 
55
  return docs
56
 
57
 
58
  # -----------------------------
59
- # Build vector DB (safe)
60
  # -----------------------------
61
- def build_vectorstore():
62
- documents = load_documents()
63
-
64
- if not documents:
65
- print("⚠️ No documents loaded, vector DB will be empty")
66
- return None
67
-
68
- splitter = RecursiveCharacterTextSplitter(
69
- chunk_size=CHUNK_SIZE,
70
- chunk_overlap=CHUNK_OVERLAP,
71
- )
72
-
73
- splits = splitter.split_documents(documents)
74
 
75
- embeddings = HuggingFaceEmbeddings(
76
- model_name=EMBEDDING_MODEL
77
- )
 
78
 
79
- vectordb = Chroma.from_documents(
80
- documents=splits,
81
- embedding=embeddings,
82
- persist_directory=CHROMA_DIR
83
- )
84
 
85
- return vectordb
 
 
86
 
 
 
 
 
 
87
 
88
- # Build once at startup
89
- VECTOR_DB = build_vectorstore()
90
 
91
  # -----------------------------
92
- # LLM (CPU-safe)
93
  # -----------------------------
94
- qa_pipeline = pipeline(
95
  "text2text-generation",
96
  model=LLM_MODEL,
97
- max_new_tokens=256
98
  )
99
 
100
-
101
  # -----------------------------
102
- # Public API
103
  # -----------------------------
104
- def ask_rag_with_status(question: str) -> Tuple[str, List[str]]:
105
  status = []
106
 
107
- if VECTOR_DB is None:
108
- return "No documents available.", ["Vector DB not initialized"]
109
 
110
- retriever = VECTOR_DB.as_retriever(search_kwargs={"k": 3})
111
  docs = retriever.get_relevant_documents(question)
112
 
113
- if not docs:
114
- return "No relevant information found.", ["No matching chunks"]
115
 
116
  context = "\n\n".join(d.page_content for d in docs)
117
 
@@ -123,11 +91,10 @@ Context:
123
 
124
  Question:
125
  {question}
126
- """
127
 
128
- result = qa_pipeline(prompt)[0]["generated_text"]
 
129
 
130
- status.append(f"Retrieved {len(docs)} chunks")
131
- status.append("Answer generated")
132
 
133
  return result.strip(), status
 
import os
from typing import List, Tuple

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from transformers import pipeline

from config import (
    KB_DIR,
    CHROMA_DIR,
    EMBEDDING_MODEL,
    LLM_MODEL,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
)
19
 
20
# -----------------------------
# Load documents
# -----------------------------
def load_documents() -> List[Document]:
    """Load every PDF found in KB_DIR into LangChain Documents.

    Returns:
        A list of per-page Documents (PyPDFLoader emits one Document
        per page); empty when KB_DIR is missing or holds no PDFs.
    """
    docs = []

    # Missing knowledge-base directory is a soft failure: warn and
    # return an empty corpus instead of raising at import time.
    if not os.path.exists(KB_DIR):
        print(f"⚠️ KB_DIR not found: {KB_DIR}")
        return docs

    # sorted() makes the load order deterministic — os.listdir order is
    # filesystem-dependent, which would make chunk/index order vary
    # between runs.
    for file in sorted(os.listdir(KB_DIR)):
        if file.lower().endswith(".pdf"):
            loader = PyPDFLoader(os.path.join(KB_DIR, file))
            docs.extend(loader.load())

    return docs
36
 
37
 
38
# -----------------------------
# Build vector DB (once)
# -----------------------------
# Runs at import time so the index is built exactly once per process.
documents = load_documents()

# Chunk geometry comes from config (CHUNK_SIZE / CHUNK_OVERLAP) so it
# stays in sync with the rest of the app — previously hard-coded here
# as 800/100 while config declared 500/50, leaving the config values dead.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

chunks = splitter.split_documents(documents)

# Embedding model name is configured centrally (EMBEDDING_MODEL).
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL
)

# Persist the Chroma index to CHROMA_DIR so restarts can reuse the
# on-disk store instead of re-embedding everything.
vectordb = Chroma.from_documents(
    documents=chunks,
    embedding=embeddings,
    persist_directory=CHROMA_DIR,
)

# Shared retriever used by the public API: top-3 chunks per query.
retriever = vectordb.as_retriever(search_kwargs={"k": 3})
 
61
 
62
  # -----------------------------
63
+ # LLM (CORRECT task)
64
  # -----------------------------
65
+ llm = pipeline(
66
  "text2text-generation",
67
  model=LLM_MODEL,
68
+ device=-1
69
  )
70
 
 
71
  # -----------------------------
72
+ # RAG call
73
  # -----------------------------
74
+ def ask_rag_with_status(question: str) -> Tuple[str, list]:
75
  status = []
76
 
77
+ if vectordb._collection.count() == 0:
78
+ return "Knowledge base is empty.", ["No documents indexed"]
79
 
 
80
  docs = retriever.get_relevant_documents(question)
81
 
82
+ status.append(f"Retrieved {len(docs)} chunks")
 
83
 
84
  context = "\n\n".join(d.page_content for d in docs)
85
 
 
91
 
92
  Question:
93
  {question}
 
94
 
95
+ Answer:
96
+ """
97
 
98
+ result = llm(prompt, max_new_tokens=256)[0]["generated_text"]
 
99
 
100
  return result.strip(), status