# Hosting-page residue kept as a comment so the module parses:
# author DevKX, commit 720964a - "Added guardrail for task formatting"
import os
import logging
import re
from typing import Any, Type
from dotenv import load_dotenv
# Vector DB
from langchain_community.vectorstores.faiss import FAISS
from langchain_mistralai import MistralAIEmbeddings
from crewai import Agent
from crewai.tools import BaseTool
from pydantic import BaseModel, Field, ConfigDict
from crew.tasks.rag_tasks import create_policy_search_task, create_policy_summary_task
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Load variables from a .env file (if one is found) into the process
# environment; MISTRAL_API_KEY is later read from the environment via os.getenv.
load_dotenv()
# Directory the FAISS index is loaded from. The path is relative, so it is
# resolved against the working directory of the running process.
VECTORSTORE_DIR = "rag/vectorstore"
# Status and nationality words stripped from queries before retrieval. The
# policy vector store holds no nationality-specific content, so these words
# only pull the search away from the generic policy text (e.g. "Loan Rates").
_DEMOGRAPHIC_TERMS = (
    # Common status terms
    "Foreigner", "Expat", "Expatriate", "Alien", "Non-Resident", "Resident",
    "Permanent Resident", "Citizen", "PR",
    "Afghan", "Albanian", "Algerian", "American", "Andorran", "Angolan", "Antiguan",
    "Argentine", "Armenian", "Australian", "Austrian", "Azerbaijani",
    "Bahamian", "Bahraini", "Bangladeshi", "Barbadian", "Belarusian", "Belgian",
    "Belizean", "Beninese", "Bhutanese", "Bolivian", "Bosnian", "Botswanan",
    "Brazilian", "British", "Bruneian", "Bulgarian", "Burkinabe", "Burmese", "Burundian",
    "Cambodian", "Cameroonian", "Canadian", "Cape Verdean", "Central African", "Chadian",
    "Chilean", "Chinese", "Colombian", "Comoran", "Congolese", "Costa Rican", "Croatian",
    "Cuban", "Cypriot", "Czech",
    "Danish", "Djiboutian", "Dominican", "Dutch", "East Timorese", "Ecuadorean",
    "Egyptian", "Emirati", "Equatorial Guinean", "Eritrean", "Estonian", "Ethiopian",
    "Fijian", "Filipino", "Finnish", "French", "Gabonese", "Gambian", "Georgian",
    "German", "Ghanaian", "Greek", "Grenadian", "Guatemalan", "Guinea-Bissauan",
    "Guinean", "Guyanese",
    "Haitian", "Herzegovinian", "Honduran", "Hungarian", "Icelander", "Indian",
    "Indonesian", "Iranian", "Iraqi", "Irish", "Israeli", "Italian", "Ivorian",
    "Jamaican", "Japanese", "Jordanian", "Kazakhstani", "Kenyan", "Kittian and Nevisian",
    "Kuwaiti", "Kyrgyz",
    "Laotian", "Latvian", "Lebanese", "Liberian", "Libyan", "Liechtensteiner",
    "Lithuanian", "Luxembourger",
    "Macedonian", "Malagasy", "Malawian", "Malay", "Malaysian", "Maldivian", "Malian",
    "Maltese", "Marshallese", "Mauritanian", "Mauritian", "Mexican", "Micronesian",
    "Moldovan", "Monacan", "Mongolian", "Moroccan", "Mosotho", "Motswana", "Mozambican",
    "Namibian", "Nauruan", "Nepalese", "New Zealander", "Nicaraguan", "Nigerian",
    "Nigerien", "North Korean", "Northern Irish", "Norwegian",
    "Omani", "Pakistani", "Palauan", "Panamanian", "Papua New Guinean", "Paraguayan",
    "Peruvian", "Polish", "Portuguese",
    "Qatari", "Romanian", "Russian", "Rwandan",
    "Saint Lucian", "Salvadoran", "Samoan", "San Marinese", "Sao Tomean", "Saudi",
    "Scottish", "Senegalese", "Serbian", "Seychellois", "Sierra Leonean", "Singaporean",
    "Slovakian", "Slovenian", "Solomon Islander", "Somali", "South African",
    "South Korean", "Spanish", "Sri Lankan", "Sudanese", "Surinamer", "Swazi",
    "Swedish", "Swiss", "Syrian",
    "Taiwanese", "Tajik", "Tanzanian", "Thai", "Togolese", "Tongan", "Trinidadian",
    "Tunisian", "Turkish", "Tuvaluan",
    "Ugandan", "Ukrainian", "Uruguayan", "Uzbekistani", "Venezuelan", "Vietnamese",
    "Welsh", "Yemenite", "Zambian", "Zimbabwean",
)

# Compiled once at import time instead of being rebuilt on every call.
# Longest terms come first so multi-word entries ("Permanent Resident") are
# tried before their shorter parts, and the optional trailing "s" also catches
# plain plurals such as "Filipinos" or "Expats".
_DEMOGRAPHIC_PATTERN = re.compile(
    r"\b(?:"
    + "|".join(
        re.escape(term)
        for term in sorted(_DEMOGRAPHIC_TERMS, key=len, reverse=True)
    )
    + r")s?\b",
    flags=re.IGNORECASE,
)
# "ID 123", "ID123", and separator forms such as "ID: 123", "ID#123", "ID-123".
_ID_PATTERN = re.compile(r"\bID\s*[:#-]?\s*\d+\b", flags=re.IGNORECASE)
# Any whitespace-delimited token containing "@".
_EMAIL_PATTERN = re.compile(r"\S+@\S+")
_WHITESPACE_PATTERN = re.compile(r"\s+")


def sanitize_query(query: str) -> str:
    """Strip PII and demographic terms from a search query.

    Removes, in this order: email-like tokens, ``ID <digits>`` references, and
    the status/nationality words in ``_DEMOGRAPHIC_TERMS`` (case-insensitive,
    whole words, singular or plain ``s`` plural). Whitespace left behind by
    the removals is collapsed to single spaces.

    Args:
        query: Raw query text. Any falsy value (``""``, ``None``) is accepted.

    Returns:
        The cleaned query, or ``""`` when the input is falsy or nothing
        remains after the removals.
    """
    if not query:
        return ""

    # Emails go first: removing a nationality word first would tear apart an
    # address such as "indian@example.com" and leave "@example.com" behind.
    query = _EMAIL_PATTERN.sub("", query)
    query = _ID_PATTERN.sub("", query)

    # Terms are replaced with nothing. If the database ever gains a
    # "Foreigner" section, substituting "Foreigner" here is the alternative.
    query = _DEMOGRAPHIC_PATTERN.sub("", query)

    # NOTE: personal names are deliberately not stripped. A blanket
    # r'\b[A-Z][a-z]+\b' rule would also delete ordinary capitalised words
    # ("Bank", "Loan", "Rate"); a Named Entity Recognition library is the
    # right tool if name removal is needed later.

    # Collapse the gaps created by the removals.
    return _WHITESPACE_PATTERN.sub(" ", query).strip()
# Ensure only simple string
class RAGToolSchema(BaseModel):
    # Argument schema for the RAG search tool: a single required free-text
    # field, so a tool call carries nothing but a plain query string.
    # NOTE(review): documented with comments rather than a class docstring,
    # since pydantic is documented to copy a model docstring into the
    # generated JSON schema - confirm before converting.
    query: str = Field(..., description="The search query string. Example: 'interest rates for high risk'")
class RAGSearchTool(BaseTool):
    # Tool that runs a similarity search over the bank policy vector store and
    # returns the matching chunks as labelled plain text.
    name: str = "rag_search_tool"
    description: str = "Search the bank policy manual. Useful for finding rates, rules, and limits."
    args_schema: Type[BaseModel] = RAGToolSchema
    # Typed as Any (with arbitrary types allowed) so the FAISS object can be
    # stored on the model without pydantic validating it.
    vectorstore: Any = Field(description="FAISS vectorstore instance")
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def _run(self, query: Any) -> str:
        """Search the policy vector store and return the top matches as text.

        Args:
            query: Search text. A dict is also accepted, in which case its
                ``"query"`` entry is used. ``None`` is treated as an empty
                query; any other value is converted with ``str()``.

        Returns:
            Up to four matched chunks joined by blank lines, each prefixed
            with ``[SOURCE: Page/Section N]`` where N is the 1-based rank of
            the result (not a page number taken from the document). A
            ``RESULT: ...`` message is returned when there is nothing to
            search for or nothing was found, and a ``SYSTEM_ERROR: ...``
            message when the vector search raises.
        """
        raw_query = query.get("query", "") if isinstance(query, dict) else query
        # Coerce before sanitizing: the sanitizer applies regexes and would
        # raise TypeError on a nested dict or list, and str(None) would
        # otherwise be searched as the literal text "None".
        query_str = "" if raw_query is None else str(raw_query)

        # Strip PII and demographic terms before the text reaches the index.
        clean_query = sanitize_query(query_str)
        if not clean_query:
            # An empty string would still be embedded and matched against the
            # index, returning arbitrary chunks as if they were relevant.
            return (
                "RESULT: Empty query. Nothing left to search for after removing "
                "personal identifiers; rephrase using generic terms such as loan "
                "type or credit score."
            )
        logger.info("RAG Tool Searching for: '%s'", clean_query)

        try:
            # Retrieve the 4 closest text chunks.
            results = self.vectorstore.similarity_search(clean_query, k=4)
        except Exception as e:
            # Boundary handler: the failure is reported to the agent as text
            # instead of aborting the run, but the traceback is still logged.
            logger.exception("Vector search failed")
            return f"SYSTEM_ERROR: Vector search failed. {e}"

        if not results:
            return "RESULT: Policy Database Silent. No documents found matching this query."

        # Prefix each chunk with the SOURCE keyword.
        return "\n\n".join(
            f"[SOURCE: Page/Section {rank}]: {doc.page_content}"
            for rank, doc in enumerate(results, start=1)
        )
class RAGAgent:
    """Builds the "Bank Policy Researcher" agent and the tasks that use it.

    Attributes:
        search_tool: The ``RAGSearchTool`` wrapping the loaded FAISS index, or
            ``None`` when the index directory is missing or failed to load.
        agent: The configured ``Agent``; it has no tools when ``search_tool``
            is ``None``.
    """

    def __init__(self, llm):
        """Load the vector store (best effort) and create the agent.

        Args:
            llm: Language model object handed to ``Agent`` unchanged.
        """
        logger.info("Initializing RAG Agent...")

        # The embedding model turns query text into vectors in the same
        # format as those stored in the vector DB.
        embeddings = MistralAIEmbeddings(
            model="mistral-embed",
            api_key=os.getenv("MISTRAL_API_KEY")
        )

        # Load the vector DB. Failure is not fatal: the agent is still
        # created, just without a search tool.
        self.search_tool = None
        try:
            if os.path.exists(VECTORSTORE_DIR):
                # NOTE(security): allow_dangerous_deserialization=True lets
                # the loader unpickle data stored next to the index, which can
                # execute code. Only point VECTORSTORE_DIR at an index this
                # project built itself.
                vectorstore = FAISS.load_local(
                    VECTORSTORE_DIR,
                    embeddings,
                    allow_dangerous_deserialization=True
                )
                # Instantiate the tool with the vectorstore.
                self.search_tool = RAGSearchTool(vectorstore=vectorstore)
            else:
                logger.warning("Vector Store not found at %s.", VECTORSTORE_DIR)
        except Exception as e:
            # logger.exception keeps the traceback, which the plain error
            # message alone does not.
            logger.exception("Failed to load Vector DB: %s", e)

        # Safety check: only pass the tool along if it was actually created.
        tools_list = [self.search_tool] if self.search_tool else []

        # ADAPTIVE EXTRACTION: the backstory instructs the agent to change its
        # output format based on the topic.
        # It also carries the instruction to strip PII (Personally
        # Identifiable Information), intended to stop the agent from
        # searching for customer-specific policies, which causes an endless
        # loop.
        self.agent = Agent(
            role="Bank Policy Researcher",
            goal="Retrieve and structure precise policy rules for any banking query.",
            backstory=(
                "You are the **Bank Policy Researcher**.\n"
                "You have access to the bank's entire Policy Manual (via Vector Search).\n"
                "**YOUR JOB**: When the Manager asks a question, you find the relevant page and extract the facts.\n"
                "**YOUR STYLE**: You are adaptive. \n"
                "- If asked about Rates -> Extract the rate table.\n"
                "- If asked about Eligibility -> Extract the qualification criteria.\n"
                "- If asked about Penalties -> Extract the fee structure.\n"
                "You do NOT invent data. You only format what is in the document.\n"
                "You do NOT take in customer name and find specific policy for that customer.\n"
                "⚠️ IMPORTANT: NEVER include or search for any personal identifiers such as customer names, IDs, or emails. "
                "Always convert any customer-specific query into generic attributes like credit score, account status, or loan type before searching.\n"
                "### 🛑 TOOL USAGE DECREE:\n"
                "1. **INPUT IS PLAIN TEXT**: Your tool input must be a simple, continuous string. Do not use { } or [ ].\n"
                "2. **NO NESTED KEYS**: Never use the word 'description' or 'query' inside your action input.\n"
                "3. **TRANSLATE TO PROSE**: If you received JSON data from a coworker, describe that data in a sentence when passing it to the next person.\n"
                # The backslash is escaped ("\\") so the literal stays valid
                # Python; the text seen at runtime is still "(\)".
                "4. **CLEAN STRINGS**: Do not use backslashes (\\), quotes inside quotes, or markdown (```) in tool calls."
            ),
            tools=tools_list,
            verbose=True,
            allow_delegation=False,
            llm=llm
        )

    def get_policy_search_task(self, query: str):
        """Create the policy search task for this agent.

        Args:
            query: Raw query text; it is sanitized before the task is built.

        Returns:
            Whatever ``create_policy_search_task`` returns for this agent and
            the sanitized query.
        """
        clean_query = sanitize_query(query)
        return create_policy_search_task(self.agent, clean_query)

    def get_summary_search_task(self, query: str):
        """Create the policy summary task for this agent.

        Args:
            query: Raw query text; it is sanitized before the task is built.

        Returns:
            Whatever ``create_policy_summary_task`` returns for this agent and
            the sanitized query.
        """
        clean_query = sanitize_query(query)
        return create_policy_summary_task(self.agent, clean_query)