Zubaish committed on
Commit
f6f60e8
·
1 Parent(s): 7167638
Files changed (1) hide show
  1. rag.py +25 -52
rag.py CHANGED
@@ -1,63 +1,33 @@
1
  # rag.py
2
-
3
  import os
4
- from datasets import load_dataset
5
  from transformers import pipeline
6
- from langchain.schema import Document
7
- from langchain_community.vectorstores import Chroma
8
- from langchain_community.embeddings import HuggingFaceEmbeddings
9
-
10
- from config import HF_DATASET_REPO, EMBEDDING_MODEL, LLM_MODEL
11
-
12
-
13
# -----------------------------
# Load documents from HF dataset
# -----------------------------
def load_documents():
    """Fetch the 'train' split of HF_DATASET_REPO and wrap each row's
    "text" field in a LangChain Document.

    Returns an empty list when the dataset cannot be fetched, so callers
    can degrade gracefully instead of crashing at import time.
    """
    try:
        ds = load_dataset(HF_DATASET_REPO, split="train")
    except Exception as e:
        print(f"❌ Could not load dataset: {e}")
        return []

    # Rows are expected to look like {"text": "..."}; skip anything
    # missing, empty, or non-string.
    documents = [
        Document(page_content=text)
        for text in (row.get("text") for row in ds)
        if isinstance(text, str) and text
    ]

    print(f"✅ Loaded {len(documents)} documents from dataset")
    return documents
33
-
34
 
35
  # -----------------------------
36
- # Embeddings
37
  # -----------------------------
38
  embeddings = HuggingFaceEmbeddings(
39
  model_name=EMBEDDING_MODEL
40
  )
41
 
42
-
43
  # -----------------------------
44
- # Vector DB (safe creation)
45
  # -----------------------------
46
- documents = load_documents()
47
-
48
- if not documents:
49
- print("⚠️ No documents loaded. Vector DB will be disabled.")
50
- vectordb = None
51
- else:
52
- vectordb = Chroma.from_documents(
53
- documents=documents,
54
- embedding=embeddings
55
  )
56
- print("✅ Vector DB initialized")
57
-
 
 
58
 
59
  # -----------------------------
60
- # LLM Pipeline (CPU safe)
61
  # -----------------------------
62
  qa_pipeline = pipeline(
63
  task="text-generation",
@@ -65,22 +35,22 @@ qa_pipeline = pipeline(
65
  max_new_tokens=256
66
  )
67
 
68
-
69
  # -----------------------------
70
- # RAG Query Function
71
  # -----------------------------
72
  def ask_rag_with_status(question: str):
73
  if vectordb is None:
74
- return "Knowledge base is empty.", "NO_KB"
75
 
 
76
  docs = vectordb.similarity_search(question, k=3)
77
 
78
  if not docs:
79
- return "No relevant documents found.", "NO_MATCH"
80
 
81
  context = "\n\n".join(d.page_content for d in docs)
82
 
83
- prompt = f"""Use the context below to answer the question.
84
 
85
  Context:
86
  {context}
@@ -91,6 +61,9 @@ Question:
91
  Answer:"""
92
 
93
  result = qa_pipeline(prompt)
94
- answer = result[0]["generated_text"]
 
 
 
95
 
96
- return answer, "OK"
 
1
  # rag.py
 
2
  import os
 
3
  from transformers import pipeline
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ from langchain_chroma import Chroma
6
+ from config import EMBEDDING_MODEL, LLM_MODEL, CHROMA_DIR
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# -----------------------------
# 1. Initialize Embeddings (LangChain-HuggingFace)
# -----------------------------
# Sentence-embedding model used for query-time similarity search; the
# model name comes from config.py so deployments can swap it out.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
14
 
 
15
# -----------------------------
# 2. Load Vector DB (Safe Loading)
# -----------------------------
# The persisted Chroma index is expected to have been pre-built by
# ingest.py (e.g. during the Docker build). Default to a disabled store
# and only attach Chroma when the directory exists and is non-empty, so
# the app can still start without a knowledge base.
vectordb = None
if os.path.exists(CHROMA_DIR) and os.listdir(CHROMA_DIR):
    vectordb = Chroma(
        persist_directory=CHROMA_DIR,
        embedding_function=embeddings,
    )
    print(f"✅ Vector DB loaded from {CHROMA_DIR}")
else:
    print(f"⚠️ Vector DB not found at {CHROMA_DIR}. Please check ingestion.")
28
 
29
  # -----------------------------
30
+ # 3. LLM Pipeline (CPU safe)
31
  # -----------------------------
32
  qa_pipeline = pipeline(
33
  task="text-generation",
 
35
  max_new_tokens=256
36
  )
37
 
 
38
# -----------------------------
# 4. RAG Query Function
# -----------------------------
def ask_rag_with_status(question: str):
    """Answer *question* via retrieval-augmented generation.

    Returns a ``(answer, status)`` tuple where ``status`` is one of the
    machine-readable codes ``"NO_KB"``, ``"NO_MATCH"``, or ``"OK"`` —
    callers branch on this string.
    """
    if vectordb is None:
        return "Knowledge base is empty. Technical error during ingestion.", "NO_KB"

    # Search for relevant context
    docs = vectordb.similarity_search(question, k=3)

    if not docs:
        return "No relevant documents found in the knowledge base.", "NO_MATCH"

    context = "\n\n".join(d.page_content for d in docs)

    # NOTE(review): the "Question:" section below sits in an elided diff
    # hunk; reconstructed from the hunk header — confirm against the repo.
    prompt = f"""Use the context below to answer the question accurately.

Context:
{context}

Question:
{question}

Answer:"""

    result = qa_pipeline(prompt)

    # text-generation pipelines echo the prompt; keep only the text that
    # follows the final "Answer:" marker.
    full_text = result[0]["generated_text"]
    answer = full_text.split("Answer:")[-1].strip()

    # BUG FIX: this branch returned a list (["Context retrieved", ...]),
    # breaking the (message, status-code) contract the other branches and
    # the previous revision follow. Restore the "OK" status string.
    return answer, "OK"