Zubaish committed on
Commit
e34c59e
·
1 Parent(s): 13ac6ca

Fix: use existing HF dataset hubrag-kb

Browse files
Files changed (1) hide show
  1. rag.py +31 -56
rag.py CHANGED
@@ -1,82 +1,57 @@
1
- # rag.py
2
-
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_huggingface import HuggingFaceEmbeddings
6
- from langchain.schema import Document
7
- from datasets import load_dataset
8
 
9
- from config import MODEL_ID, EMBEDDING_MODEL, HF_DATASET_REPO, TOP_K
 
 
10
 
11
-
12
- # ----------------------------
13
- # Load PDFs from HF Dataset
14
- # ----------------------------
15
  def load_documents():
 
16
  ds = load_dataset(HF_DATASET_REPO, split="train")
17
 
18
- docs = []
19
  for row in ds:
20
- text = row.get("text", "").strip()
21
- if text:
22
  docs.append(Document(page_content=text))
23
 
24
  return docs
25
 
26
-
27
- # ----------------------------
28
- # Build vector store (in-memory)
29
- # ----------------------------
30
  documents = load_documents()
31
 
32
  if not documents:
33
- raise RuntimeError("No documents loaded from HF Dataset")
34
 
35
- embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
36
 
37
- vectordb = Chroma.from_documents(
38
- documents=documents,
39
- embedding=embeddings
40
- )
41
-
42
-
43
- # ----------------------------
44
- # Load LLM (NO device_map)
45
- # ----------------------------
46
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
47
-
48
- model = AutoModelForCausalLM.from_pretrained(
49
- MODEL_ID,
50
- torch_dtype="auto"
51
- )
52
 
53
  llm = pipeline(
54
  "text-generation",
55
- model=model,
56
- tokenizer=tokenizer,
57
- max_new_tokens=256,
58
- temperature=0.2
59
  )
60
 
61
-
62
- # ----------------------------
63
- # Public API
64
- # ----------------------------
65
  def ask_rag_with_status(question: str):
66
- status = []
67
-
68
- status.append("Retrieving relevant documents…")
69
- docs = vectordb.similarity_search(question, k=TOP_K)
70
-
71
- if not docs:
72
  return {
73
- "answer": "No relevant documents found.",
74
- "status": status
75
  }
76
 
 
77
  context = "\n\n".join(d.page_content for d in docs)
78
 
79
- prompt = f"""Use the context below to answer the question.
80
 
81
  Context:
82
  {context}
@@ -86,12 +61,12 @@ Question:
86
 
87
  Answer:"""
88
 
89
- status.append("Generating answer…")
90
  result = llm(prompt)[0]["generated_text"]
91
 
92
- answer = result.split("Answer:")[-1].strip()
93
-
94
  return {
95
- "answer": answer,
96
- "status": status
 
 
 
97
  }
 
1
+ from datasets import load_dataset
2
+ from langchain.schema import Document
 
3
  from langchain_community.vectorstores import Chroma
4
  from langchain_huggingface import HuggingFaceEmbeddings
5
+ from transformers import pipeline
 
6
 
7
+ HF_DATASET_REPO = "Zubaish/hubrag-kb"
8
+ EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
9
+ CHROMA_DIR = "./chroma"
10
 
 
 
 
 
11
def load_documents():
    """Load the knowledge base from the Hugging Face dataset repo.

    Pulls the ``train`` split of ``HF_DATASET_REPO`` and wraps every
    non-blank ``text`` field in a LangChain ``Document``.

    Returns:
        list[Document]: one Document per usable text row; may be empty
        when the dataset has no text content.
    """
    docs = []
    ds = load_dataset(HF_DATASET_REPO, split="train")

    for row in ds:
        text = row.get("text")
        # Guard against missing/None fields as well as whitespace-only rows.
        if text and text.strip():
            # Strip so the stored content agrees with the filter condition
            # (previously the unstripped text was appended).
            docs.append(Document(page_content=text.strip()))

    return docs
21
 
 
 
 
 
22
# --- Module-level setup: optional vector store + generation pipeline. ---

documents = load_documents()

if not documents:
    print("⚠️ No text documents found in dataset. PDFs must be converted to text.")

embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

# Only materialise a Chroma store when there is something to index;
# a None store signals "knowledge base empty" to the query path below.
vectordb = (
    Chroma.from_documents(
        documents,
        embedding=embeddings,
        persist_directory=CHROMA_DIR,
    )
    if documents
    else None
)

# Text-generation pipeline used to produce answers from retrieved context.
llm = pipeline(
    "text-generation",
    model="microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    max_new_tokens=256,
)
43
 
 
 
 
 
44
def ask_rag_with_status(question: str):
    """Answer *question* against the vector store, with progress notes.

    Args:
        question: the user's natural-language question.

    Returns:
        dict: ``{"answer": str, "status": list[str]}`` — the generated
        answer plus human-readable notes about what happened.
    """
    # No store was built at import time (dataset had no usable text rows).
    if not vectordb:
        return {
            "answer": "Knowledge base is empty. Please upload text documents to the dataset.",
            "status": ["No text documents loaded"]
        }

    docs = vectordb.similarity_search(question, k=3)

    # Nothing similar enough was retrieved — bail out instead of letting
    # the LLM answer from an empty context (guard restored from the
    # previous revision of this function).
    if not docs:
        return {
            "answer": "No relevant documents found.",
            "status": ["Retrieved 0 chunks"]
        }

    context = "\n\n".join(d.page_content for d in docs)

    prompt = f"""Answer the question using only the context.

Context:
{context}

Question:
{question}

Answer:"""

    result = llm(prompt)[0]["generated_text"]

    # The pipeline echoes the prompt; keep only what follows "Answer:".
    return {
        "answer": result.split("Answer:")[-1].strip(),
        "status": [
            f"Loaded {len(documents)} documents",
            f"Retrieved {len(docs)} chunks"
        ]
    }