Spaces:
Sleeping
Sleeping
working guardrail
Browse files- src/search_final.py +13 -27
src/search_final.py
CHANGED
|
@@ -113,20 +113,6 @@ FAISS_PATH = os.path.join(OUT_DIR, "faiss_merged.index")
|
|
| 113 |
BM25_PATH = os.path.join(OUT_DIR, "bm25_merged.pkl")
|
| 114 |
META_PATH = os.path.join(OUT_DIR, "meta_merged.pkl")
|
| 115 |
|
| 116 |
-
BLOCKED_TERMS = ["weather","cricket","movie","song","football","holiday",
|
| 117 |
-
"travel","recipe","music","game","sports","politics","election"]
|
| 118 |
-
|
| 119 |
-
FINANCE_DOMAINS = [
|
| 120 |
-
"financial reporting","balance sheet","income statement","assets and liabilities",
|
| 121 |
-
"equity","revenue","profit and loss","goodwill impairment","cash flow","dividends",
|
| 122 |
-
"taxation","investment","valuation","capital structure","ownership interests",
|
| 123 |
-
"subsidiaries","shareholders equity","expenses","earnings","debt","amortization","depreciation"
|
| 124 |
-
]
|
| 125 |
-
|
| 126 |
-
ALLOWED_COMPANY = ["make my trip","mmt"]
|
| 127 |
-
|
| 128 |
-
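# Crude regex to detect company-like names: one or more words followed by a suffix such as Ltd, Inc, Company, etc. NOTE: the pattern is compiled with re.IGNORECASE, so capitalization is not actually enforced.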
|
| 129 |
-
COMPANY_PATTERN = re.compile(r"\b([A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*\s+(?:Ltd|Limited|Inc|Corporation|Corp|LLC|Group|Company|Bank))\b", re.IGNORECASE)
|
| 130 |
|
| 131 |
# ---------------- Load Indexes ----------------
|
| 132 |
logger.info("Loading FAISS, BM25, metadata, and models...")
|
|
@@ -184,29 +170,28 @@ def get_mistral_answer(query: str, context: str) -> str:
|
|
| 184 |
return f"Error fetching answer from LLM: {e}"
|
| 185 |
|
| 186 |
# ---------------- Guardrails ----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)
|
| 188 |
|
| 189 |
def validate_query(query: str, threshold: float = 0.5) -> bool:
|
| 190 |
q_lower = query.lower()
|
| 191 |
-
|
| 192 |
-
# Blocklist check
|
| 193 |
if any(bad in q_lower for bad in BLOCKED_TERMS):
|
| 194 |
print("[Guardrail] Rejected by blocklist.")
|
| 195 |
return False
|
| 196 |
-
|
| 197 |
-
# Check for company mentions
|
| 198 |
-
companies_found = COMPANY_PATTERN.findall(query)
|
| 199 |
-
if companies_found:
|
| 200 |
-
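# If any company is mentioned, only allow MakeMyTrip. NOTE: ALLOWED_COMPANY is a list, so the `ALLOWED_COMPANY in c.lower()` test below raises TypeError instead of matching; each alias would need to be checked individually.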
|
| 201 |
-
if not any(ALLOWED_COMPANY in c.lower() for c in companies_found):
|
| 202 |
-
print(f"[Guardrail] Rejected: company mention {companies_found}, not {ALLOWED_COMPANY}.")
|
| 203 |
-
return False
|
| 204 |
-
|
| 205 |
-
# Semantic similarity check with financial domain
|
| 206 |
q_emb = embed_model.encode(query, convert_to_tensor=True)
|
| 207 |
sim_scores = util.cos_sim(q_emb, finance_embeds)
|
| 208 |
max_score = float(sim_scores.max())
|
| 209 |
-
|
| 210 |
if max_score > threshold:
|
| 211 |
print(f"[Guardrail] Accepted (semantic match {max_score:.2f})")
|
| 212 |
return True
|
|
@@ -214,6 +199,7 @@ def validate_query(query: str, threshold: float = 0.5) -> bool:
|
|
| 214 |
print(f"[Guardrail] Rejected (low semantic score {max_score:.2f})")
|
| 215 |
return False
|
| 216 |
|
|
|
|
| 217 |
#-------------------Output Guardrail------------------
|
| 218 |
def validate_output(answer: str, context_docs: List[Dict]) -> str:
|
| 219 |
combined_context = " ".join([doc["content"].lower() for doc in context_docs])
|
|
|
|
| 113 |
BM25_PATH = os.path.join(OUT_DIR, "bm25_merged.pkl")
|
| 114 |
META_PATH = os.path.join(OUT_DIR, "meta_merged.pkl")
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
# ---------------- Load Indexes ----------------
|
| 118 |
logger.info("Loading FAISS, BM25, metadata, and models...")
|
|
|
|
| 170 |
return f"Error fetching answer from LLM: {e}"
|
| 171 |
|
| 172 |
# ---------------- Guardrails ----------------
|
| 173 |
+
# Guardrail configuration: blocked off-topic terms and finance-domain reference phrases
|
| 174 |
+
BLOCKED_TERMS = ["weather", "cricket", "movie", "song", "football", "holiday",
|
| 175 |
+
"travel", "recipe", "music", "game", "sports", "politics", "election"]
|
| 176 |
+
|
| 177 |
+
FINANCE_DOMAINS = [
|
| 178 |
+
"financial reporting", "balance sheet", "income statement",
|
| 179 |
+
"assets and liabilities", "equity", "revenue", "profit and loss",
|
| 180 |
+
"goodwill impairment", "cash flow", "dividends", "taxation",
|
| 181 |
+
"investment", "valuation", "capital structure", "ownership interests",
|
| 182 |
+
"subsidiaries", "shareholders equity", "expenses", "earnings",
|
| 183 |
+
"debt", "amortization", "depreciation"
|
| 184 |
+
]
|
| 185 |
finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)
|
| 186 |
|
| 187 |
def validate_query(query: str, threshold: float = 0.5) -> bool:
|
| 188 |
q_lower = query.lower()
|
|
|
|
|
|
|
| 189 |
if any(bad in q_lower for bad in BLOCKED_TERMS):
|
| 190 |
print("[Guardrail] Rejected by blocklist.")
|
| 191 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
q_emb = embed_model.encode(query, convert_to_tensor=True)
|
| 193 |
sim_scores = util.cos_sim(q_emb, finance_embeds)
|
| 194 |
max_score = float(sim_scores.max())
|
|
|
|
| 195 |
if max_score > threshold:
|
| 196 |
print(f"[Guardrail] Accepted (semantic match {max_score:.2f})")
|
| 197 |
return True
|
|
|
|
| 199 |
print(f"[Guardrail] Rejected (low semantic score {max_score:.2f})")
|
| 200 |
return False
|
| 201 |
|
| 202 |
+
|
| 203 |
#-------------------Output Guardrail------------------
|
| 204 |
def validate_output(answer: str, context_docs: List[Dict]) -> str:
|
| 205 |
combined_context = " ".join([doc["content"].lower() for doc in context_docs])
|