Ryanfafa commited on
Commit
36db703
·
verified ·
1 Parent(s): 778eb7b

Upload rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +186 -0
rag_engine.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG Engine
3
+ ──────────
4
+ - Embeddings : sentence-transformers/all-MiniLM-L6-v2 (HuggingFace, free)
5
+ - Vector DB : ChromaDB (local, in-memory / persistent)
6
+ - LLM : HuggingFace Router API (Mistral-7B-Instruct-v0.3, free tier)
7
+ - Chunking : Recursive character splitter with overlap
8
+ """
9
+
10
+ import os
11
+ import re
12
+ import requests
13
+ import tempfile
14
+ from typing import Tuple, List
15
+
16
+ import chromadb
17
+ from chromadb.config import Settings
18
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
19
+ from langchain_community.embeddings import HuggingFaceEmbeddings
20
+ from langchain_community.vectorstores import Chroma
21
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader
22
+ from langchain.schema import Document
23
+
24
# ─── Configuration ─────────────────────────────────────────────────────────────
# Local CPU sentence-embedding model used to vectorize document chunks.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Chat model served through the HuggingFace Router (OpenAI-compatible endpoint).
HF_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
HF_API_URL = f"https://router.huggingface.co/hf-inference/models/{HF_MODEL_ID}/v1/chat/completions"
# Chunking: ~800 characters per chunk, 150-character overlap so sentences that
# straddle a chunk boundary are not lost.
CHUNK_SIZE = 800
CHUNK_OVERLAP = 150
# Number of chunks retrieved per query (fetch_k for MMR is derived from this).
TOP_K = 4
COLLECTION_NAME = "docmind_collection"
# On-disk location of the persistent Chroma index.
CHROMA_DIR = "./chroma_db"
33
+
34
+
35
class RAGEngine:
    """Full RAG pipeline: ingest → embed → store → retrieve → generate."""

    def __init__(self):
        # Embeddings and vectorstore are created lazily: the model download
        # and DB creation happen on first use, not at construction time.
        self._embeddings = None
        self._vectorstore = None
        self._splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", " ", ""],
        )

    # ── Lazy-load embeddings ───────────────────────────────────────────────────
    @property
    def embeddings(self):
        """HuggingFace sentence-embedding model, instantiated on first access."""
        if self._embeddings is None:
            self._embeddings = HuggingFaceEmbeddings(
                model_name=EMBED_MODEL,
                model_kwargs={"device": "cpu"},
                # Normalized vectors make cosine similarity ≡ dot product.
                encode_kwargs={"normalize_embeddings": True},
            )
        return self._embeddings

    # ── Ingest an uploaded Streamlit file object ───────────────────────────────
    def ingest_file(self, uploaded_file) -> int:
        """Write an uploaded file to a temp path, ingest it, return chunk count.

        Parameters:
            uploaded_file: Streamlit ``UploadedFile``-like object exposing
                ``.name`` and ``.read()``.
        """
        suffix = Path_suffix(uploaded_file.name)
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(uploaded_file.read())
            tmp_path = tmp.name
        try:
            return self.ingest_path(tmp_path, uploaded_file.name)
        finally:
            # BUGFIX: with delete=False nothing else removes the temp file;
            # the previous version leaked one file per upload.
            os.unlink(tmp_path)

    # ── Ingest from a file path ────────────────────────────────────────────────
    def ingest_path(self, path: str, name: str = "") -> int:
        """Load, chunk, embed, and index the document at *path*.

        Returns the number of chunks written to the vectorstore.
        """
        suffix = Path_suffix(name or path)

        if suffix == ".pdf":
            loader = PyPDFLoader(path)
        else:
            # Anything non-PDF is treated as UTF-8 text.
            loader = TextLoader(path, encoding="utf-8")

        raw_docs = loader.load()

        # Tag every page/doc with a display name for source attribution.
        for doc in raw_docs:
            doc.metadata["source"] = name or os.path.basename(path)

        chunks = self._splitter.split_documents(raw_docs)

        # BUGFIX: Chroma.from_documents appends to an existing collection, so
        # the old "reset & recreate" comment was not backed by code — drop the
        # previous collection first so answers come only from the new document.
        # NOTE(review): a collection persisted in CHROMA_DIR by an *earlier
        # process* is still appended to on the first ingest — confirm whether
        # CHROMA_DIR should be wiped at startup as well.
        if self._vectorstore is not None:
            self._vectorstore.delete_collection()

        self._vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=self.embeddings,
            collection_name=COLLECTION_NAME,
            persist_directory=CHROMA_DIR,
            client_settings=Settings(anonymized_telemetry=False),
        )

        return len(chunks)

    # ── Query: retrieve + generate ─────────────────────────────────────────────
    def query(self, question: str) -> Tuple[str, List[str]]:
        """Answer *question* from the indexed document.

        Returns ``(answer, sources)`` where *sources* is the list of unique
        source names of the retrieved chunks (order is arbitrary — it comes
        from a set).
        """
        if self._vectorstore is None:
            return "⚠️ Please upload a document first.", []

        # 1. Retrieve top-k relevant chunks (MMR trades a little relevance
        #    for diversity among the returned chunks).
        retriever = self._vectorstore.as_retriever(
            search_type="mmr",
            search_kwargs={"k": TOP_K, "fetch_k": TOP_K * 3},
        )
        docs = retriever.invoke(question)

        # 2. Build the context string fed to the LLM.
        context = "\n\n---\n\n".join(
            f"[Chunk {i+1}]\n{d.page_content}" for i, d in enumerate(docs)
        )

        # 3. Unique source names for display.
        sources = list({d.metadata.get("source", "Document") for d in docs})

        # 4. Generate the answer.
        answer = self._generate(question, context)

        return answer, sources

    # ── LLM call via NEW HuggingFace Router API ────────────────────────────────
    def _generate(self, question: str, context: str) -> str:
        """Call the HF Router chat-completions endpoint; fall back to an
        extractive answer (via ``_fallback_answer``) on any failure.
        """
        try:
            hf_token = os.environ.get("HF_TOKEN", "")

            headers = {"Content-Type": "application/json"}
            # Anonymous calls are allowed on the free tier; the token only
            # raises rate limits.
            if hf_token:
                headers["Authorization"] = f"Bearer {hf_token}"

            payload = {
                "model": HF_MODEL_ID,
                "messages": [
                    {
                        "role": "system",
                        "content": (
                            "You are DocMind, an expert document analyst. "
                            "Answer the user's question using ONLY the provided document context. "
                            "Be concise, accurate, and cite specific details from the context. "
                            "If the answer is not in the context, say so clearly."
                        ),
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Document context:\n{context}\n\n"
                            f"Question: {question}"
                        ),
                    },
                ],
                "max_tokens": 512,
                # Low temperature: grounded, repeatable answers.
                "temperature": 0.2,
            }

            resp = requests.post(HF_API_URL, headers=headers, json=payload, timeout=60)
            resp.raise_for_status()

            answer = resp.json()["choices"][0]["message"]["content"].strip()
            return answer or "I could not generate a response. Please try rephrasing."

        # Broad on purpose: network, HTTP, and schema errors all degrade to
        # the extractive fallback rather than crashing the UI.
        except Exception as e:
            return _fallback_answer(question, context, str(e))
160
+
161
+
162
+ # ─── Fallback (no LLM) ─────────────────────────────────────────────────────────
163
+ def _fallback_answer(question: str, context: str, error: str) -> str:
164
+ """Simple extractive answer when LLM is unavailable."""
165
+ keywords = set(re.findall(r'\b\w{4,}\b', question.lower()))
166
+ best_chunk, best_score = "", 0
167
+
168
+ for chunk in context.split("---"):
169
+ words = set(re.findall(r'\b\w{4,}\b', chunk.lower()))
170
+ score = len(keywords & words)
171
+ if score > best_score:
172
+ best_score = score
173
+ best_chunk = chunk.strip()
174
+
175
+ if best_chunk:
176
+ excerpt = best_chunk[:600] + ("..." if len(best_chunk) > 600 else "")
177
+ return (
178
+ f"*(LLM unavailable – showing most relevant excerpt)*\n\n{excerpt}\n\n"
179
+ f"<small>Error: {error}</small>"
180
+ )
181
+ return f"⚠️ Could not generate answer. Error: {error}"
182
+
183
+
184
# ─── Helper ────────────────────────────────────────────────────────────────────
def path_suffix(name: str) -> str:
    """Return the lowercased file extension of *name* (e.g. ``".pdf"``).

    Falls back to ``".txt"`` when the name has no extension, so extension-less
    uploads are handled by the plain-text loader.
    """
    return os.path.splitext(name)[-1].lower() or ".txt"


# Backward-compatible alias: in-file callers use the original non-PEP 8 name.
Path_suffix = path_suffix