Zubaish committed on
Commit
ffadad7
·
1 Parent(s): 1e98153

rag update

Browse files
Files changed (1) hide show
  1. rag.py +40 -27
rag.py CHANGED
@@ -14,43 +14,51 @@ from config import (
14
  LLM_MODEL,
15
  )
16
 
17
- # -----------------------------
18
- # Load embeddings (CPU-safe)
19
- # -----------------------------
20
  embeddings = HuggingFaceEmbeddings(
21
  model_name=EMBEDDING_MODEL
22
  )
23
 
24
- # -----------------------------
25
- # Load documents
26
- # -----------------------------
27
- docs = []
 
28
  if os.path.exists(KB_DIR):
29
  for file in os.listdir(KB_DIR):
30
- if file.endswith(".pdf"):
31
  loader = PyPDFLoader(os.path.join(KB_DIR, file))
32
- docs.extend(loader.load())
33
 
 
 
 
34
  splitter = RecursiveCharacterTextSplitter(
35
  chunk_size=500,
36
  chunk_overlap=50
37
  )
38
- splits = splitter.split_documents(docs)
39
-
40
- # -----------------------------
41
- # Vector store
42
- # -----------------------------
43
- vectordb = Chroma.from_documents(
44
- splits,
45
- embedding=embeddings,
46
- persist_directory=VECTOR_DB_DIR
47
- )
48
 
49
- retriever = vectordb.as_retriever(search_kwargs={"k": 3})
 
 
 
 
 
 
50
 
51
- # -----------------------------
 
 
 
 
 
 
 
 
52
  # Load LLM (CPU ONLY, NO ACCELERATE)
53
- # -----------------------------
54
  tokenizer = AutoTokenizer.from_pretrained(
55
  LLM_MODEL,
56
  trust_remote_code=True
@@ -58,8 +66,7 @@ tokenizer = AutoTokenizer.from_pretrained(
58
 
59
  model = AutoModelForCausalLM.from_pretrained(
60
  LLM_MODEL,
61
- trust_remote_code=True,
62
- torch_dtype=None, # CPU-safe
63
  )
64
 
65
  llm = pipeline(
@@ -70,12 +77,18 @@ llm = pipeline(
70
  do_sample=False
71
  )
72
 
73
- # -----------------------------
74
- # RAG function
75
- # -----------------------------
76
  def ask_rag_with_status(question: str):
77
  status = []
78
 
 
 
 
 
 
 
79
  status.append("🔍 Retrieving documents...")
80
  docs = retriever.get_relevant_documents(question)
81
 
 
14
  LLM_MODEL,
15
  )
16
 
17
+ # --------------------------------------------------
18
+ # Embeddings (CPU-safe)
19
+ # --------------------------------------------------
20
  embeddings = HuggingFaceEmbeddings(
21
  model_name=EMBEDDING_MODEL
22
  )
23
 
24
+ # --------------------------------------------------
25
+ # Load PDFs (if any)
26
+ # --------------------------------------------------
27
+ documents = []
28
+
29
  if os.path.exists(KB_DIR):
30
  for file in os.listdir(KB_DIR):
31
+ if file.lower().endswith(".pdf"):
32
  loader = PyPDFLoader(os.path.join(KB_DIR, file))
33
+ documents.extend(loader.load())
34
 
35
+ # --------------------------------------------------
36
+ # Split documents
37
+ # --------------------------------------------------
38
  splitter = RecursiveCharacterTextSplitter(
39
  chunk_size=500,
40
  chunk_overlap=50
41
  )
 
 
 
 
 
 
 
 
 
 
42
 
43
+ splits = splitter.split_documents(documents) if documents else []
44
+
45
+ # --------------------------------------------------
46
+ # Vector DB (ONLY if docs exist)
47
+ # --------------------------------------------------
48
+ vectordb = None
49
+ retriever = None
50
 
51
+ if splits:
52
+ vectordb = Chroma.from_documents(
53
+ splits,
54
+ embedding=embeddings,
55
+ persist_directory=VECTOR_DB_DIR
56
+ )
57
+ retriever = vectordb.as_retriever(search_kwargs={"k": 3})
58
+
59
+ # --------------------------------------------------
60
  # Load LLM (CPU ONLY, NO ACCELERATE)
61
+ # --------------------------------------------------
62
  tokenizer = AutoTokenizer.from_pretrained(
63
  LLM_MODEL,
64
  trust_remote_code=True
 
66
 
67
  model = AutoModelForCausalLM.from_pretrained(
68
  LLM_MODEL,
69
+ trust_remote_code=True
 
70
  )
71
 
72
  llm = pipeline(
 
77
  do_sample=False
78
  )
79
 
80
+ # --------------------------------------------------
81
+ # Public RAG API
82
+ # --------------------------------------------------
83
  def ask_rag_with_status(question: str):
84
  status = []
85
 
86
+ if retriever is None:
87
+ return {
88
+ "answer": "❌ Knowledge base is empty. Please upload PDFs to the dataset or storage.",
89
+ "status": ["⚠️ No documents indexed"]
90
+ }
91
+
92
  status.append("🔍 Retrieving documents...")
93
  docs = retriever.get_relevant_documents(question)
94