Zubaish committed
Commit 9a9d2bd · 1 Parent(s): d322d09

Fix: stable RAG implementation

Files changed (1)
  1. rag.py +102 -25
rag.py CHANGED
@@ -1,42 +1,114 @@
+import os
+from typing import Dict
+
+from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-from config import VECTOR_DIR, EMBED_MODEL
+from langchain_huggingface import HuggingFaceEmbeddings
 
-# Embeddings
-embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    pipeline,
+)
 
-# Vector DB
-db = Chroma(
-    persist_directory=VECTOR_DIR,
-    embedding_function=embeddings
+from config import (
+    KB_DIR,
+    CHROMA_DIR,
+    EMBEDDING_MODEL,
+    LLM_MODEL,
+    CHUNK_SIZE,
+    CHUNK_OVERLAP,
+    TOP_K,
 )
 
-# LLM (CPU-safe)
-MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
+# ---------------------------
+# Load & index documents
+# ---------------------------
+
+def load_documents():
+    loader = DirectoryLoader(
+        KB_DIR,
+        glob="**/*.pdf",
+        loader_cls=PyPDFLoader,
+    )
+    return loader.load()
+
+
+def build_vectorstore():
+    documents = load_documents()
+
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=CHUNK_SIZE,
+        chunk_overlap=CHUNK_OVERLAP,
+    )
+    chunks = splitter.split_documents(documents)
+
+    embeddings = HuggingFaceEmbeddings(
+        model_name=EMBEDDING_MODEL
+    )
+
+    vectordb = Chroma.from_documents(
+        documents=chunks,
+        embedding=embeddings,
+        persist_directory=CHROMA_DIR,
+    )
+
+    vectordb.persist()
+    return vectordb
+
+
+# Build or load Chroma DB
+if os.path.exists(CHROMA_DIR):
+    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+    vectordb = Chroma(
+        persist_directory=CHROMA_DIR,
+        embedding_function=embeddings,
+    )
+else:
+    vectordb = build_vectorstore()
+
+
+# ---------------------------
+# Load LLM (HF Space safe)
+# ---------------------------
 
 tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_ID,
-    trust_remote_code=True
+    LLM_MODEL,
+    trust_remote_code=True,
 )
 
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    trust_remote_code=True
+    LLM_MODEL,
+    trust_remote_code=True,
+    device_map="cpu",
 )
 
-llm = pipeline(
+generator = pipeline(
     "text-generation",
     model=model,
     tokenizer=tokenizer,
-    max_new_tokens=256
+    max_new_tokens=256,
+    do_sample=True,
+    temperature=0.7,
 )
 
-def ask_rag_with_status(question: str):
-    docs = db.similarity_search(question, k=3)
-    context = "\n\n".join(d.page_content for d in docs)
 
-    prompt = f"""Use the context below to answer.
+# ---------------------------
+# RAG Query
+# ---------------------------
+
+def ask_rag_with_status(question: str) -> Dict:
+    docs = vectordb.similarity_search(question, k=TOP_K)
+
+    context = "\n\n".join(
+        [doc.page_content for doc in docs]
+    )
+
+    prompt = f"""
+You are a helpful assistant.
+Answer the question using ONLY the context below.
+If the answer is not in the context, say "I don't know".
 
 Context:
 {context}
@@ -44,10 +116,15 @@ Context:
 Question:
 {question}
 
-Answer:"""
+Answer:
+""".strip()
+
+    output = generator(prompt)[0]["generated_text"]
+
+    answer = output.split("Answer:")[-1].strip()
 
-    output = llm(prompt)[0]["generated_text"]
     return {
-        "answer": output,
-        "sources": len(docs)
+        "question": question,
+        "answer": answer,
+        "sources": [doc.metadata for doc in docs],
    }
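
The rewritten module pulls all of its tunables from config, which is not part of this commit. The sketch below is only a guess at that file's shape: every value is a placeholder assumption chosen to satisfy the names rag.py imports (LLM_MODEL carries over the removed MODEL_ID constant; everything else is invented for illustration).

# Hypothetical config.py -- not in this commit; all values are assumptions.
KB_DIR = "knowledge_base"          # folder scanned for **/*.pdf (assumed)
CHROMA_DIR = "chroma_db"           # Chroma persistence directory (assumed)
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # common default; actual choice unknown
LLM_MODEL = "microsoft/Phi-3-mini-4k-instruct"  # the old hard-coded MODEL_ID
CHUNK_SIZE = 1000     # characters per chunk (assumed)
CHUNK_OVERLAP = 200   # characters shared between neighboring chunks (assumed)
TOP_K = 3             # matches the old hard-coded k=3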
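
A minimal smoke test for the new entry point (the script and question string are illustrative, not from the repo; note that importing rag triggers the module-level index build/load and model download):

from rag import ask_rag_with_status  # import builds/loads the index and the LLM

result = ask_rag_with_status("What topics does the knowledge base cover?")
print(result["answer"])           # text after the final "Answer:" marker
for meta in result["sources"]:    # one metadata dict per retrieved chunk,
    print(meta)                   # e.g. {'source': 'kb/file.pdf', 'page': 3} from PyPDFLoader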
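
One consequence of the build-or-load branch: once CHROMA_DIR exists, newly added PDFs are never indexed. A sketch of a forced re-index under that assumption (this helper is ours; the repo may handle re-indexing elsewhere):

import shutil
from config import CHROMA_DIR

# Hypothetical re-index step, run before importing rag: with the
# directory gone, the module-level branch rebuilds from the PDFs.
shutil.rmtree(CHROMA_DIR, ignore_errors=True)

import rag  # noqa: E402  (import now triggers build_vectorstore())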