kundan621 committed on
Commit
41f1bec
·
1 Parent(s): 03b3405

working guardrail

Browse files
Files changed (1) hide show
  1. src/search_final.py +13 -27
src/search_final.py CHANGED
@@ -113,20 +113,6 @@ FAISS_PATH = os.path.join(OUT_DIR, "faiss_merged.index")
113
  BM25_PATH = os.path.join(OUT_DIR, "bm25_merged.pkl")
114
  META_PATH = os.path.join(OUT_DIR, "meta_merged.pkl")
115
 
116
- BLOCKED_TERMS = ["weather","cricket","movie","song","football","holiday",
117
- "travel","recipe","music","game","sports","politics","election"]
118
-
119
- FINANCE_DOMAINS = [
120
- "financial reporting","balance sheet","income statement","assets and liabilities",
121
- "equity","revenue","profit and loss","goodwill impairment","cash flow","dividends",
122
- "taxation","investment","valuation","capital structure","ownership interests",
123
- "subsidiaries","shareholders equity","expenses","earnings","debt","amortization","depreciation"
124
- ]
125
-
126
- ALLOWED_COMPANY = ["make my trip","mmt"]
127
-
128
- # crude regex to detect "company-like" words (any capitalized word(s) followed by Ltd, Inc, Company, etc.)
129
- COMPANY_PATTERN = re.compile(r"\b([A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*\s+(?:Ltd|Limited|Inc|Corporation|Corp|LLC|Group|Company|Bank))\b", re.IGNORECASE)
130
 
131
  # ---------------- Load Indexes ----------------
132
  logger.info("Loading FAISS, BM25, metadata, and models...")
@@ -184,29 +170,28 @@ def get_mistral_answer(query: str, context: str) -> str:
184
  return f"Error fetching answer from LLM: {e}"
185
 
186
  # ---------------- Guardrails ----------------
 
 
 
 
 
 
 
 
 
 
 
 
187
  finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)
188
 
189
  def validate_query(query: str, threshold: float = 0.5) -> bool:
190
  q_lower = query.lower()
191
-
192
- # Blocklist check
193
  if any(bad in q_lower for bad in BLOCKED_TERMS):
194
  print("[Guardrail] Rejected by blocklist.")
195
  return False
196
-
197
- # Check for company mentions
198
- companies_found = COMPANY_PATTERN.findall(query)
199
- if companies_found:
200
- # If any company is mentioned, only allow MakeMyTrip
201
- if not any(ALLOWED_COMPANY in c.lower() for c in companies_found):
202
- print(f"[Guardrail] Rejected: company mention {companies_found}, not {ALLOWED_COMPANY}.")
203
- return False
204
-
205
- # Semantic similarity check with financial domain
206
  q_emb = embed_model.encode(query, convert_to_tensor=True)
207
  sim_scores = util.cos_sim(q_emb, finance_embeds)
208
  max_score = float(sim_scores.max())
209
-
210
  if max_score > threshold:
211
  print(f"[Guardrail] Accepted (semantic match {max_score:.2f})")
212
  return True
@@ -214,6 +199,7 @@ def validate_query(query: str, threshold: float = 0.5) -> bool:
214
  print(f"[Guardrail] Rejected (low semantic score {max_score:.2f})")
215
  return False
216
 
 
217
  #-------------------Output Guardrail------------------
218
  def validate_output(answer: str, context_docs: List[Dict]) -> str:
219
  combined_context = " ".join([doc["content"].lower() for doc in context_docs])
 
113
  BM25_PATH = os.path.join(OUT_DIR, "bm25_merged.pkl")
114
  META_PATH = os.path.join(OUT_DIR, "meta_merged.pkl")
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  # ---------------- Load Indexes ----------------
118
  logger.info("Loading FAISS, BM25, metadata, and models...")
 
170
  return f"Error fetching answer from LLM: {e}"
171
 
# ---------------- Guardrails ----------------
# Off-topic keywords consulted by validate_query's blocklist step.
BLOCKED_TERMS = [
    "weather",
    "cricket",
    "movie",
    "song",
    "football",
    "holiday",
    "travel",
    "recipe",
    "music",
    "game",
    "sports",
    "politics",
    "election",
]

# Reference finance phrases; a query is compared against their embeddings.
FINANCE_DOMAINS = [
    "financial reporting",
    "balance sheet",
    "income statement",
    "assets and liabilities",
    "equity",
    "revenue",
    "profit and loss",
    "goodwill impairment",
    "cash flow",
    "dividends",
    "taxation",
    "investment",
    "valuation",
    "capital structure",
    "ownership interests",
    "subsidiaries",
    "shareholders equity",
    "expenses",
    "earnings",
    "debt",
    "amortization",
    "depreciation",
]

# Encode the reference phrases a single time so validate_query can reuse them.
finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)
186
 
def validate_query(query: str, threshold: float = 0.5) -> bool:
    """Input guardrail: decide whether a query may proceed.

    A query is rejected when it is blank, when any word in it starts with an
    entry of ``BLOCKED_TERMS``, or when its highest cosine similarity against
    the pre-computed ``finance_embeds`` does not exceed ``threshold``.

    Args:
        query: Raw user question.
        threshold: Exclusive lower bound on the cosine similarity between the
            query embedding and the closest ``FINANCE_DOMAINS`` embedding.

    Returns:
        True when the query passes every check, False otherwise. Each
        decision is also printed with a ``[Guardrail]`` prefix.
    """
    import re  # local import so the block does not rely on file-level imports

    q_lower = query.lower()

    # Nothing to classify; skip the embedding call entirely.
    if not q_lower.strip():
        print("[Guardrail] Rejected: empty query.")
        return False

    # Match blocked terms only at the start of a word. Plain substring
    # matching rejected legitimate finance questions, e.g. "selection"
    # contains "election" and "transports" contains "sports". Anchoring at
    # the word start still catches plurals/inflections such as "movies".
    if any(
        re.search(r"(?<!\w)" + re.escape(bad), q_lower)
        for bad in BLOCKED_TERMS
    ):
        print("[Guardrail] Rejected by blocklist.")
        return False

    # Semantic check: compare the query with every finance reference phrase
    # and keep the best score.
    q_emb = embed_model.encode(query, convert_to_tensor=True)
    sim_scores = util.cos_sim(q_emb, finance_embeds)
    max_score = float(sim_scores.max())

    if max_score > threshold:
        print(f"[Guardrail] Accepted (semantic match {max_score:.2f})")
        return True

    print(f"[Guardrail] Rejected (low semantic score {max_score:.2f})")
    return False
201
 
202
+
203
  #-------------------Output Guardrail------------------
204
  def validate_output(answer: str, context_docs: List[Dict]) -> str:
205
  combined_context = " ".join([doc["content"].lower() for doc in context_docs])