Shoaib-33 commited on
Commit
9a55d40
·
verified ·
1 Parent(s): 471ac15

Update backend/rag_pipeline.py

Browse files
Files changed (1) hide show
  1. backend/rag_pipeline.py +190 -41
backend/rag_pipeline.py CHANGED
@@ -4,7 +4,8 @@ from langchain_google_genai import ChatGoogleGenerativeAI
4
  from langchain.prompts import PromptTemplate
5
  from langchain_core.output_parsers import StrOutputParser
6
  from langchain_core.runnables import RunnablePassthrough
7
- from .data_loader import all_chunks
 
8
  import os
9
  from dotenv import load_dotenv
10
 
@@ -13,59 +14,80 @@ load_dotenv()
13
  # --- API Key ---
14
  os.environ["GOOGLE_API_KEY"] = os.environ.get("GOOGLE_API_KEY", "")
15
 
16
- # --- Embeddings ---
17
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
 
 
18
 
19
- # --- Vector DB ---
20
  vectordb = Chroma(
21
  collection_name="bus_data",
22
  embedding_function=embedding_model,
23
  persist_directory="vectorstore"
24
  )
25
 
26
- # Convert list metadata → strings
27
- def clean_metadata(metadata):
 
 
 
 
28
  cleaned = {}
29
  for key, value in metadata.items():
30
  if isinstance(value, list):
31
  cleaned[key] = ", ".join(str(v) for v in value)
 
 
32
  else:
33
  cleaned[key] = value
34
  return cleaned
35
 
36
- # Load chunks if empty
 
 
 
37
  if len(vectordb.get()["ids"]) == 0:
38
  print("Adding chunks to vector DB...")
39
  for chunk in all_chunks:
40
  metadata = chunk["metadata"].copy()
41
  if "provider" in metadata and metadata["provider"]:
42
  metadata["provider"] = metadata["provider"].strip().lower()
43
- vectordb.add_texts([chunk["content"]], metadatas=[clean_metadata(metadata)])
 
 
 
44
  print(f"✅ Added {len(all_chunks)} chunks.")
45
  else:
46
- print(f"ℹ️ Vector database already contains {len(vectordb.get()['ids'])} chunks.")
 
47
 
48
- # --- Gemini LLM ---
 
 
49
  gemini_llm = ChatGoogleGenerativeAI(
50
  temperature=0.3,
51
  model="gemini-2.5-flash",
52
  google_api_key=os.environ["GOOGLE_API_KEY"]
53
  )
54
 
55
- # --- Enhanced Prompt ---
 
 
 
56
  prompt_template = """You are a friendly and helpful bus service assistant for Bangladesh bus services.
57
 
58
- CRITICAL INSTRUCTIONS - READ CAREFULLY:
59
- 1. If the user asks about a SPECIFIC bus provider (like Hanif, Ena, Desh Travel, etc.), ONLY use information from that provider's context.
60
  2. NEVER mix contact information, policies, or details between different providers.
61
- 3. When answering about contact information, address, or policy, make absolutely sure you're looking at the correct provider's data.
62
  4. If you're not certain which provider the information belongs to, say you don't know.
63
 
64
  GENERAL INSTRUCTIONS:
65
- - Answer ONLY from the context provided below
66
- - Be conversational, friendly, and concise
67
- - Always mention prices in "Taka"
68
- - Use bullet points for lists
69
  - If information is missing, say: "I don't have that information. Please contact the bus service directly."
70
 
71
  Context Information:
@@ -80,37 +102,148 @@ PROMPT = PromptTemplate(
80
  input_variables=["context", "question"]
81
  )
82
 
 
83
  # ======================================================
84
- # Detect provider from query
85
  # ======================================================
86
- def detect_provider_from_query(query: str):
 
 
 
 
 
 
87
  query_lower = query.lower()
88
- providers = ["hanif", "ena", "desh travel", "green line", "soudia", "shyamoli"]
89
- for provider in providers:
90
  if provider in query_lower:
91
  return provider
92
  return None
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # ======================================================
95
- # Format retrieved docs into a string
96
  # ======================================================
97
- def format_docs(docs):
98
  return "\n\n".join(doc.page_content for doc in docs)
99
 
 
100
  # ======================================================
101
- # Build RAG chain
102
  # ======================================================
103
- def get_rag_chain(provider: str = None):
104
- if provider:
105
- retriever = vectordb.as_retriever(
106
- search_type="similarity",
107
- search_kwargs={"k": 10, "filter": {"provider": provider.strip().lower()}}
108
- )
 
 
 
 
 
 
 
 
 
 
 
109
  else:
110
- retriever = vectordb.as_retriever(
111
- search_type="similarity",
112
- search_kwargs={"k": 20}
113
- )
 
 
 
 
 
 
114
 
115
  chain = (
116
  {
@@ -124,22 +257,38 @@ def get_rag_chain(provider: str = None):
124
 
125
  return chain, retriever
126
 
 
127
  # ======================================================
128
- # Get answer
129
  # ======================================================
130
- def get_answer(query: str, provider: str = None):
 
 
 
 
131
  provider = provider or detect_provider_from_query(query)
132
- chain, _ = get_rag_chain(provider)
133
  return chain.invoke(query)
134
 
135
- def get_answer_with_sources(query: str, provider: str = None):
 
 
 
 
 
136
  provider = provider or detect_provider_from_query(query)
137
- chain, retriever = get_rag_chain(provider)
138
 
139
  docs = retriever.invoke(query)
140
  answer = chain.invoke(query)
141
 
142
  return {
143
  "answer": answer,
144
- "source_documents": docs
 
 
 
 
 
 
145
  }
 
4
  from langchain.prompts import PromptTemplate
5
  from langchain_core.output_parsers import StrOutputParser
6
  from langchain_core.runnables import RunnablePassthrough
7
+ from .data_loader import all_chunks, providers as raw_providers
8
+ import re
9
  import os
10
  from dotenv import load_dotenv
11
 
 
14
  # --- API Key ---
15
  os.environ["GOOGLE_API_KEY"] = os.environ.get("GOOGLE_API_KEY", "")
16
 
17
# ======================================================
# Embeddings & Vector DB
# ======================================================
# Sentence-level embedder that backs the Chroma collection.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

# Persistent store: the "vectorstore" directory is reused across restarts,
# so embeddings are computed only once.
vectordb = Chroma(
    persist_directory="vectorstore",
    collection_name="bus_data",
    embedding_function=embedding_model,
)
29
 
30
+
31
# ======================================================
# Metadata Cleaner
# ======================================================
def clean_metadata(metadata: dict) -> dict:
    """Normalize metadata values for Chroma compatibility.

    Chroma only accepts scalar metadata values, so list values are
    flattened to comma-separated strings.  Numeric values are kept
    as numbers so range filters ($gte/$lte on price) keep working.

    Args:
        metadata: Raw metadata mapping for a chunk.

    Returns:
        A new dict with Chroma-safe values; the input is not mutated.
    """
    cleaned = {}
    for key, value in metadata.items():
        if isinstance(value, list):
            cleaned[key] = ", ".join(str(v) for v in value)
        elif isinstance(value, (int, float)):
            # keep numbers as numbers for filtering
            cleaned[key] = value
        else:
            cleaned[key] = value
    return cleaned
45
 
46
+
47
# ======================================================
# Load Chunks into Vector DB (once)
# ======================================================
# Ingest only when the persisted collection is empty, so restarting the
# app does not duplicate documents.
existing_ids = vectordb.get()["ids"]
if existing_ids:
    print(f"ℹ️ Vector DB already has {len(existing_ids)} chunks.")
else:
    print("Adding chunks to vector DB...")
    for chunk in all_chunks:
        metadata = chunk["metadata"].copy()
        # Normalize provider names so metadata filters match reliably.
        if "provider" in metadata and metadata["provider"]:
            metadata["provider"] = metadata["provider"].strip().lower()
        vectordb.add_texts(
            [chunk["content"]],
            metadatas=[clean_metadata(metadata)],
        )
    print(f"✅ Added {len(all_chunks)} chunks.")
63
+
64
 
65
# ======================================================
# LLM
# ======================================================
# Low temperature keeps answers factual and close to the retrieved context.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0.3,
    google_api_key=os.environ["GOOGLE_API_KEY"],
)
73
 
74
+
75
+ # ======================================================
76
+ # Prompt
77
+ # ======================================================
78
  prompt_template = """You are a friendly and helpful bus service assistant for Bangladesh bus services.
79
 
80
+ CRITICAL INSTRUCTIONS:
81
+ 1. If the user asks about a SPECIFIC bus provider (Hanif, Ena, Desh Travel, etc.), ONLY use information from that provider's context.
82
  2. NEVER mix contact information, policies, or details between different providers.
83
+ 3. When answering about contact info, address, or policy make sure you're reading the correct provider's data.
84
  4. If you're not certain which provider the information belongs to, say you don't know.
85
 
86
  GENERAL INSTRUCTIONS:
87
+ - Answer ONLY from the context provided below.
88
+ - Be conversational, friendly, and concise.
89
+ - Always mention prices in "Taka".
90
+ - Use bullet points for lists.
91
  - If information is missing, say: "I don't have that information. Please contact the bus service directly."
92
 
93
  Context Information:
 
102
  input_variables=["context", "question"]
103
  )
104
 
105
+
106
# ======================================================
# Query Understanding Helpers
# ======================================================

# Known provider names, derived from the loaded data and lowercased for
# case-insensitive substring matching against user queries.
KNOWN_PROVIDERS = [p["name"].lower() for p in raw_providers]


def detect_provider_from_query(query: str) -> str | None:
    """Return the first known provider mentioned in *query*, else None."""
    text = query.lower()
    hits = (name for name in KNOWN_PROVIDERS if name in text)
    return next(hits, None)
121
 
122
+
123
+ def detect_query_type(query: str) -> str | None:
124
+ """
125
+ Detect the type of information the user is looking for.
126
+ Returns: 'policy' | 'dropping_point' | 'provider' | None
127
+ """
128
+ query_lower = query.lower()
129
+
130
+ policy_keywords = ["policy", "cancel", "refund", "reschedule", "terms", "rules", "luggage"]
131
+ fare_keywords = ["fare", "price", "taka", "cost", "cheap", "expensive", "affordable", "route", "ticket"]
132
+ provider_keywords = ["contact", "phone", "address", "office", "helpline", "number", "location"]
133
+
134
+ if any(w in query_lower for w in policy_keywords):
135
+ return "policy"
136
+ if any(w in query_lower for w in fare_keywords):
137
+ return "dropping_point"
138
+ if any(w in query_lower for w in provider_keywords):
139
+ return "provider"
140
+
141
+ return None # broad search — no type filter applied
142
+
143
+
144
+ def extract_price_filter(query: str) -> dict | None:
145
+ """
146
+ Extract numeric price constraints from natural language.
147
+ Returns a Chroma-compatible filter dict or None.
148
+ """
149
+ # between X and Y
150
+ match = re.search(r'between\s*(\d+)\s*and\s*(\d+)', query, re.IGNORECASE)
151
+ if match:
152
+ return {
153
+ "$and": [
154
+ {"price": {"$gte": int(match.group(1))}},
155
+ {"price": {"$lte": int(match.group(2))}}
156
+ ]
157
+ }
158
+
159
+ # under / below / less than X
160
+ match = re.search(r'(under|below|less than)\s*(\d+)', query, re.IGNORECASE)
161
+ if match:
162
+ return {"price": {"$lte": int(match.group(2))}}
163
+
164
+ # above / over / more than X
165
+ match = re.search(r'(above|over|more than)\s*(\d+)', query, re.IGNORECASE)
166
+ if match:
167
+ return {"price": {"$gte": int(match.group(2))}}
168
+
169
+ # exactly X taka
170
+ match = re.search(r'exactly\s*(\d+)', query, re.IGNORECASE)
171
+ if match:
172
+ return {"price": {"$eq": int(match.group(1))}}
173
+
174
+ return None
175
+
176
+
177
def build_filter(provider: str = None, query: str = None) -> dict | None:
    """
    Combine provider, query-type, and price constraints into a single
    Chroma-compatible where clause.

    Returns None when nothing applies, a single condition dict when only
    one applies, or an ``$and`` of all conditions otherwise.
    """
    conditions = []

    # 1. Restrict to a single provider when one is known.
    if provider:
        conditions.append({"provider": {"$eq": provider.strip().lower()}})

    if query:
        # 2. Restrict by the detected information type.
        query_type = detect_query_type(query)
        if query_type:
            conditions.append({"type": {"$eq": query_type}})

        # 3. Price constraints only make sense on fare records, so a price
        #    filter forces type=dropping_point when no type was detected.
        price_filter = extract_price_filter(query)
        if price_filter:
            if not query_type:
                conditions.append({"type": {"$eq": "dropping_point"}})
            conditions.append(price_filter)

    if not conditions:
        return None
    if len(conditions) == 1:
        return conditions[0]
    return {"$and": conditions}
207
+
208
+
209
# ======================================================
# Format Retrieved Docs
# ======================================================
def format_docs(docs) -> str:
    """Concatenate document page contents, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
214
 
215
+
216
  # ======================================================
217
+ # Build RAG Chain
218
  # ======================================================
219
+ def get_rag_chain(provider: str = None, query: str = None):
220
+ """
221
+ Build a LangChain RAG chain with smart filtering.
222
+ - Provider filter: only chunks from that provider
223
+ - Type filter: policy / dropping_point / provider
224
+ - Price filter: $lte / $gte / $eq on metadata price field
225
+ """
226
+ where_filter = build_filter(provider=provider, query=query)
227
+
228
+ # Adaptive k:
229
+ # Policy queries need more chunks (long text split into many pieces)
230
+ # Fare/price queries need fewer (very specific records)
231
+ query_type = detect_query_type(query) if query else None
232
+ if query_type == "policy":
233
+ k = 8
234
+ elif query_type == "dropping_point":
235
+ k = 6
236
  else:
237
+ k = 10
238
+
239
+ search_kwargs = {"k": k}
240
+ if where_filter:
241
+ search_kwargs["filter"] = where_filter
242
+
243
+ retriever = vectordb.as_retriever(
244
+ search_type="similarity",
245
+ search_kwargs=search_kwargs
246
+ )
247
 
248
  chain = (
249
  {
 
257
 
258
  return chain, retriever
259
 
260
+
261
  # ======================================================
262
+ # Public API
263
  # ======================================================
264
+ def get_answer(query: str, provider: str = None) -> str:
265
+ """
266
+ Get a plain answer string for a user query.
267
+ Provider is auto-detected from query if not passed explicitly.
268
+ """
269
  provider = provider or detect_provider_from_query(query)
270
+ chain, _ = get_rag_chain(provider=provider, query=query)
271
  return chain.invoke(query)
272
 
273
+
274
def get_answer_with_sources(query: str, provider: str = None) -> dict:
    """
    Answer a query and return the supporting documents plus debug info.

    If *provider* is not passed explicitly, it is auto-detected from the
    query text.

    Returns:
        dict with keys ``answer`` (str), ``source_documents``
        (list[Document]), and ``debug`` (filter/retrieval diagnostics).
    """
    resolved = provider or detect_provider_from_query(query)
    chain, retriever = get_rag_chain(provider=resolved, query=query)

    docs = retriever.invoke(query)
    answer = chain.invoke(query)

    return {
        "answer": answer,
        "source_documents": docs,
        "debug": {
            "provider_detected": resolved,
            "query_type": detect_query_type(query),
            "price_filter": extract_price_filter(query),
            "chunks_retrieved": len(docs),
        },
    }