Zubaish committed
Commit 98b93b7 · 1 Parent(s): abd4e0b

Stable HF-ready RAG using HF Datasets

Files changed (7):
  1. .gitignore +3 -17
  2. Dockerfile +0 -2
  3. app.py +5 -10
  4. config.py +14 -5
  5. ingest.py +11 -23
  6. rag.py +49 -57
  7. requirements.txt +2 -9
.gitignore CHANGED
@@ -1,18 +1,4 @@
-# Python
-__pycache__/
-*.pyc
-
-# Chroma DB
-chroma_db/
-data/
-
-# Environment
-.env
-
-# Frontend (DO NOT COMMIT)
-frontend/node_modules/
-frontend/dist/
-frontend/.vite/
-
 kb/
-vectordb/
+chroma_db/
+*.pdf
+__pycache__/
Dockerfile CHANGED
@@ -9,8 +9,6 @@ RUN pip install --no-cache-dir -r requirements.txt
 
 COPY app.py rag.py ingest.py config.py ./
 
-RUN mkdir -p kb vectordb
-
 EXPOSE 7860
 
 CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,18 +1,13 @@
 from fastapi import FastAPI
-from pydantic import BaseModel
 from rag import ask_rag_with_status
 
-app = FastAPI(title="RAG Knowledge Bot")
-
-class Query(BaseModel):
-    question: str
-
+app = FastAPI()
 
 @app.get("/")
 def health():
     return {"status": "ok"}
 
-
-@app.post("/chat")
-def chat(query: Query):
-    return ask_rag_with_status(query.question)
+@app.post("/ask")
+def ask(payload: dict):
+    question = payload.get("question", "")
+    return ask_rag_with_status(question)
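
Note that the new /ask handler takes a raw dict instead of the old pydantic Query model, so FastAPI no longer validates the request body: any JSON object is accepted, and a missing "question" key silently becomes an empty string. A minimal client sketch against the new endpoint; the URL and port are assumptions taken from the Dockerfile's EXPOSE 7860, not part of this commit:

# Client sketch for the new /ask endpoint (assumes a local run of
# this repo's container on port 7860).
import requests

resp = requests.post(
    "http://localhost:7860/ask",
    json={"question": "What do the PDFs in the knowledge base cover?"},
)
resp.raise_for_status()
# Expected shape per the new rag.py:
# {"question": ..., "chunks_used": ..., "context_preview": ...}
print(resp.json())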
config.py CHANGED
@@ -1,9 +1,18 @@
 import os
 
-BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
-KB_DIR = os.path.join(BASE_DIR, "kb")
-VECTOR_DB_DIR = os.path.join(BASE_DIR, "vectordb")
 
-EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
-LLM_MODEL = "microsoft/Phi-3-mini-4k-instruct"
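The added side of this hunk did not survive the page capture: the capture repeats ingest.py's new content under config.py, which cannot be right since config.py cannot import from itself. Based on the names the new ingest.py and rag.py import (HF_DATASET_ID, KB_DIR, CHROMA_DIR, EMBED_MODEL, CHUNK_SIZE, CHUNK_OVERLAP), the paths in the new .gitignore, and the old defaults, the new config.py plausibly looks like the sketch below; the dataset ID is a placeholder, not the real value:

import os

# HF dataset hosting the knowledge-base PDFs; the real ID was not
# captured, this value is a placeholder.
HF_DATASET_ID = "your-username/your-kb-dataset"

# Local paths; both are ignored by the new .gitignore.
KB_DIR = "kb"
CHROMA_DIR = "chroma_db"

# Embedding model carried over from the old config.py.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

# Splitter settings carried over from the old ingest.py.
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50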
ingest.py CHANGED
@@ -1,30 +1,18 @@
 import os
-from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from config import KB_DIR, VECTOR_DB_DIR, EMBEDDING_MODEL
+from huggingface_hub import snapshot_download
+from config import HF_DATASET_ID, KB_DIR
 
-def ingest():
-    if not os.path.exists(KB_DIR) or not os.listdir(KB_DIR):
-        print("⚠️ No PDFs found in kb/. Skipping ingestion.")
-        return
+def download_kb():
+    os.makedirs(KB_DIR, exist_ok=True)
 
-    loader = DirectoryLoader(KB_DIR, glob="**/*.pdf", loader_cls=PyPDFLoader)
-    docs = loader.load()
-
-    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
-    chunks = splitter.split_documents(docs)
-
-    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
-
-    Chroma.from_documents(
-        chunks,
-        embeddings,
-        persist_directory=VECTOR_DB_DIR
+    snapshot_download(
+        repo_id=HF_DATASET_ID,
+        repo_type="dataset",
+        local_dir=KB_DIR,
+        local_dir_use_symlinks=False
     )
 
-    print("✅ Ingestion complete")
+    print("✅ Knowledge base downloaded")
 
 if __name__ == "__main__":
-    ingest()
+    download_kb()
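
ingest.py no longer builds the Chroma index; it only mirrors the dataset repo into kb/ (indexing moved into rag.py). Two caveats: recent huggingface_hub releases deprecate local_dir_use_symlinks and always materialize real files in local_dir, and requirements.txt now leaves huggingface_hub unpinned, so that flag is effectively a no-op. A sketch of a slightly tightened call, restricting the download to PDFs and supporting a private dataset via a token; the HF_TOKEN secret is an assumption, not part of this commit:

# Sketch only: narrows the snapshot to PDFs and adds optional auth.
import os
from huggingface_hub import snapshot_download
from config import HF_DATASET_ID, KB_DIR

snapshot_download(
    repo_id=HF_DATASET_ID,
    repo_type="dataset",
    local_dir=KB_DIR,
    allow_patterns=["*.pdf"],           # skip README/metadata files
    token=os.environ.get("HF_TOKEN"),   # only needed for private datasets
)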
rag.py CHANGED
@@ -1,74 +1,66 @@
-from langchain_community.vectorstores import Chroma
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-
-from config import VECTOR_DB_DIR, EMBEDDING_MODEL, LLM_MODEL
-
-_embeddings = None
-_db = None
-_tokenizer = None
-_model = None
-
-
-def get_vector_db():
-    global _embeddings, _db
-
-    if _db is None:
-        _embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
-        _db = Chroma(
-            persist_directory=VECTOR_DB_DIR,
-            embedding_function=_embeddings,
-        )
-    return _db
-
-
-def get_llm():
-    global _tokenizer, _model
-
-    if _model is None:
-        _tokenizer = AutoTokenizer.from_pretrained(
-            LLM_MODEL, trust_remote_code=True
-        )
-        _model = AutoModelForCausalLM.from_pretrained(
-            LLM_MODEL,
-            trust_remote_code=True,
-            torch_dtype=torch.float32
-        )
-    return _tokenizer, _model
-
-
-def ask_rag_with_status(question: str):
-    status = []
-
-    db = get_vector_db()
-    status.append("📚 Vector DB loaded")
-
-    docs = db.similarity_search(question, k=3)
-    context = "\n\n".join(d.page_content for d in docs)
-    status.append("🔍 Retrieved relevant context")
-
-    tokenizer, model = get_llm()
-    status.append("🤖 LLM loaded")
-
-    prompt = f"""
-    You are a helpful assistant.
-
-    Context:
-    {context}
-
-    Question:
-    {question}
-
-    Answer clearly and concisely.
-    """
-
-    inputs = tokenizer(prompt, return_tensors="pt")
-    outputs = model.generate(**inputs, max_new_tokens=300)
-
-    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
+import os
+from typing import Dict
+
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import Chroma
+
+from ingest import download_kb
+from config import (
+    KB_DIR,
+    CHROMA_DIR,
+    EMBED_MODEL,
+    CHUNK_SIZE,
+    CHUNK_OVERLAP,
+)
+
+# -------------------------
+# Startup: download + index
+# -------------------------
+
+print("⬇️ Downloading KB...")
+download_kb()
+
+print("📄 Loading documents...")
+documents = []
+for file in os.listdir(KB_DIR):
+    if file.endswith(".pdf"):
+        loader = PyPDFLoader(os.path.join(KB_DIR, file))
+        documents.extend(loader.load())
+
+print(f"📚 Loaded {len(documents)} pages")
+
+splitter = RecursiveCharacterTextSplitter(
+    chunk_size=CHUNK_SIZE,
+    chunk_overlap=CHUNK_OVERLAP,
+)
+
+splits = splitter.split_documents(documents)
+
+embeddings = HuggingFaceEmbeddings(
+    model_name=EMBED_MODEL
+)
+
+vectorstore = Chroma.from_documents(
+    documents=splits,
+    embedding=embeddings,
+    persist_directory=CHROMA_DIR
+)
+
+retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
+
+# -------------------------
+# Query API
+# -------------------------
+
+def ask_rag_with_status(question: str) -> Dict:
+    docs = retriever.get_relevant_documents(question)
+
+    context = "\n\n".join(d.page_content for d in docs)
+
     return {
-        "answer": answer,
-        "status": status
+        "question": question,
+        "chunks_used": len(docs),
+        "context_preview": context[:500]
     }
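
Two things stand out in the rewritten rag.py. First, the download/split/embed pipeline now runs at import time, so the index is rebuilt on every container start (chroma_db/ is gitignored and the container filesystem is ephemeral). Second, ask_rag_with_status is retrieval-only: the Phi-3 generation step from the old version is gone, so the endpoint returns retrieved context rather than an answer. With the pinned langchain 0.2.x, retriever.get_relevant_documents(...) still works, though it is deprecated in favor of retriever.invoke(...). A quick smoke test of the retrieval path, with a hypothetical question:

# Smoke test: importing rag downloads the KB and builds the index.
from rag import ask_rag_with_status

out = ask_rag_with_status("What topics do the documents cover?")
print(out["chunks_used"])        # number of retrieved chunks (k=4)
print(out["context_preview"])    # first 500 chars of joined context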
requirements.txt CHANGED
@@ -1,17 +1,10 @@
 fastapi
 uvicorn
-python-dotenv
-
 langchain==0.2.17
 langchain-community==0.2.17
 langchain-text-splitters==0.2.4
-
 chromadb==0.5.5
 sentence-transformers
+huggingface_hub
 pypdf
-
-transformers>=4.39.0
-huggingface_hub<1.0.0
-numpy<2
-SQLAlchemy<3
-requests<3
+numpy<2