Zubaish committed on
Commit
79ff3c4
·
1 Parent(s): c488d16

Rollback: stable local RAG

Browse files
Files changed (1) hide show
  1. rag.py +53 -57
rag.py CHANGED
@@ -1,90 +1,86 @@
1
  # rag.py
2
 
3
  import os
4
- from typing import List, Tuple
5
-
6
- from langchain_community.document_loaders import PyPDFLoader
7
- from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from langchain_community.vectorstores import Chroma
9
  from langchain_community.embeddings import HuggingFaceEmbeddings
10
- from langchain.schema import Document
11
- from transformers import pipeline
12
 
13
- from config import (
14
- KB_DIR,
15
- CHROMA_DIR,
16
- EMBEDDING_MODEL,
17
- LLM_MODEL,
18
- )
19
 
20
  # -----------------------------
21
- # Load documents
22
  # -----------------------------
23
- def load_documents() -> List[Document]:
24
- docs = []
25
 
26
- if not os.path.exists(KB_DIR):
27
- print(f"⚠️ KB_DIR not found: {KB_DIR}")
28
- return docs
 
 
29
 
30
- for file in os.listdir(KB_DIR):
31
- if file.lower().endswith(".pdf"):
32
- loader = PyPDFLoader(os.path.join(KB_DIR, file))
33
- docs.extend(loader.load())
 
34
 
35
- return docs
 
36
 
37
 
38
  # -----------------------------
39
- # Build vector DB (once)
40
  # -----------------------------
41
- documents = load_documents()
42
-
43
- splitter = RecursiveCharacterTextSplitter(
44
- chunk_size=800,
45
- chunk_overlap=100
46
- )
47
-
48
- chunks = splitter.split_documents(documents)
49
-
50
  embeddings = HuggingFaceEmbeddings(
51
  model_name=EMBEDDING_MODEL
52
  )
53
 
54
- vectordb = Chroma.from_documents(
55
- documents=chunks,
56
- embedding=embeddings,
57
- persist_directory=CHROMA_DIR
58
- )
59
 
60
- retriever = vectordb.as_retriever(search_kwargs={"k": 3})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # -----------------------------
63
- # LLM (CORRECT task)
64
  # -----------------------------
65
- llm = pipeline(
66
- "text2text-generation",
67
  model=LLM_MODEL,
68
- device=-1
69
  )
70
 
 
71
  # -----------------------------
72
- # RAG call
73
  # -----------------------------
74
- def ask_rag_with_status(question: str) -> Tuple[str, list]:
75
- status = []
76
-
77
- if vectordb._collection.count() == 0:
78
- return "Knowledge base is empty.", ["No documents indexed"]
79
 
80
- docs = retriever.get_relevant_documents(question)
81
 
82
- status.append(f"Retrieved {len(docs)} chunks")
 
83
 
84
  context = "\n\n".join(d.page_content for d in docs)
85
 
86
- prompt = f"""
87
- Answer the question using ONLY the context below.
88
 
89
  Context:
90
  {context}
@@ -92,9 +88,9 @@ Context:
92
  Question:
93
  {question}
94
 
95
- Answer:
96
- """
97
 
98
- result = llm(prompt, max_new_tokens=256)[0]["generated_text"]
 
99
 
100
- return result.strip(), status
 
1
  # rag.py
2
 
3
  import os
4
+ from datasets import load_dataset
5
+ from transformers import pipeline
6
+ from langchain.schema import Document
 
7
  from langchain_community.vectorstores import Chroma
8
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
9
 
10
+ from config import HF_DATASET_REPO, EMBEDDING_MODEL, LLM_MODEL
11
+
 
 
 
 
12
 
13
  # -----------------------------
14
+ # Load documents from HF dataset
15
  # -----------------------------
16
def load_documents():
    """Fetch the knowledge base from the HF dataset and wrap each text row.

    Returns a list of langchain ``Document`` objects, or an empty list when
    the dataset cannot be loaded (best-effort: the error is printed, not
    raised, so module import still succeeds).
    """
    try:
        ds = load_dataset(HF_DATASET_REPO, split="train")
    except Exception as e:
        # Best-effort load: report and fall back to an empty KB.
        print(f"❌ Could not load dataset: {e}")
        return []

    # Expecting dataset rows like: { "text": "..." }
    documents = [
        Document(page_content=text)
        for row in ds
        if (text := row.get("text")) and isinstance(text, str)
    ]

    print(f"✅ Loaded {len(documents)} documents from dataset")
    return documents
33
 
34
 
35
  # -----------------------------
36
+ # Embeddings
37
  # -----------------------------
 
 
 
 
 
 
 
 
 
38
# Sentence-embedding model used for both indexing and querying the vector DB.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
41
 
 
 
 
 
 
42
 
43
# -----------------------------
# Vector DB (safe creation)
# -----------------------------
# Build the in-memory Chroma index at import time; leave it disabled (None)
# when the dataset yielded nothing so queries can degrade gracefully.
documents = load_documents()

if documents:
    vectordb = Chroma.from_documents(documents=documents, embedding=embeddings)
    print("✅ Vector DB initialized")
else:
    print("⚠️ No documents loaded. Vector DB will be disabled.")
    vectordb = None
57
+
58
 
59
  # -----------------------------
60
+ # LLM Pipeline (CPU safe)
61
  # -----------------------------
62
# -----------------------------
# LLM Pipeline (CPU safe)
# -----------------------------
# Generation pipeline for answering; capped at 256 new tokens per call.
qa_pipeline = pipeline("text-generation", model=LLM_MODEL, max_new_tokens=256)
67
 
68
+
69
  # -----------------------------
70
+ # RAG Query Function
71
  # -----------------------------
72
def ask_rag_with_status(question: str):
    """Answer *question* from the local vector DB.

    Returns a ``(answer, status)`` tuple where ``status`` is one of
    ``"NO_KB"`` (vector DB disabled), ``"NO_MATCH"`` (no retrieved chunks),
    or ``"OK"``.
    """
    if vectordb is None:
        return "Knowledge base is empty.", "NO_KB"

    # Top-3 nearest chunks by embedding similarity.
    docs = vectordb.similarity_search(question, k=3)

    if not docs:
        return "No relevant documents found.", "NO_MATCH"

    context = "\n\n".join(d.page_content for d in docs)

    prompt = f"""Use the context below to answer the question.

Context:
{context}

Question:
{question}

Answer:"""

    # BUG FIX: "text-generation" pipelines default to return_full_text=True,
    # so generated_text would echo the entire prompt (context included) before
    # the completion. Request only the newly generated text, and strip stray
    # whitespace as the pre-rollback version did.
    result = qa_pipeline(prompt, return_full_text=False)
    answer = result[0]["generated_text"].strip()

    return answer, "OK"