"""
RAG (Retrieval Augmented Generation) for FinEE
==============================================

Provides context-aware entity extraction using:
1. Merchant Knowledge Base - a built-in set of common Indian merchants,
   extensible via ``MerchantKnowledgeBase.add_merchant`` / ``load``
2. Transaction History - Similar past transactions
3. Category Taxonomy - Hierarchical categories

Author: Ranjit Behera
"""

import json
import re
import numpy as np
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, asdict
from collections import defaultdict
import hashlib


# Terms this short ("vi", "ola", "jio", "bb", "fk", "1mg", ...) occur inside
# many unrelated words ("via", "service", "cola", "ajio"), so they must match
# a whole alphanumeric token instead of an arbitrary substring.
_SHORT_TERM_MAX_LEN = 3


def _term_in_text(term: str, text: str) -> bool:
    """Return True if lowercase *term* occurs in lowercase *text*.

    Terms longer than ``_SHORT_TERM_MAX_LEN`` characters use plain substring
    matching, so "swiggy" still matches "upi:swiggy@ybl". Shorter terms must
    not be directly preceded or followed by a letter or digit. An empty term
    never matches; otherwise it would be "found" in every text.
    """
    if not term:
        return False
    if len(term) > _SHORT_TERM_MAX_LEN:
        return term in text
    pattern = r"(?<![a-z0-9])" + re.escape(term) + r"(?![a-z0-9])"
    return re.search(pattern, text) is not None


# ============================================================================
# DATA STRUCTURES
# ============================================================================

@dataclass
class Merchant:
    """Merchant entity with metadata.

    Attributes:
        name: Display name; its lowercase form is the knowledge-base key.
        vpa: UPI-style payment address such as "swiggy@ybl", or None.
        category: Top-level category label, e.g. "food" or "bills".
        aliases: Alternative names; ``None`` is normalised to ``[]``.
        keywords: Contextual hint words; ``None`` is normalised to ``[]``.
        typical_amount_range: ``(min, max)`` amount, checked inclusively by
            callers. Always stored as a tuple, even when given as a list
            (as happens after a JSON round-trip).
        is_p2m: ``False`` for the built-in payment-app entries (PhonePe,
            Paytm, Google Pay), ``True`` otherwise.
    """
    name: str
    vpa: Optional[str] = None
    category: str = "other"
    aliases: Optional[List[str]] = None
    keywords: Optional[List[str]] = None
    typical_amount_range: Tuple[float, float] = (0, 100000)
    is_p2m: bool = True

    def __post_init__(self):
        if self.aliases is None:
            self.aliases = []
        if self.keywords is None:
            self.keywords = []
        # json.dump writes tuples as lists; restore the declared tuple type.
        if self.typical_amount_range is not None:
            self.typical_amount_range = tuple(self.typical_amount_range)


@dataclass
class Transaction:
    """Stored transaction for similarity search."""
    id: str  # short content hash used as a record id
    text: str  # raw message text; this is what gets indexed
    amount: Optional[float] = None
    type: Optional[str] = None
    merchant: Optional[str] = None
    category: Optional[str] = None
    embedding: Optional[List[float]] = None


@dataclass
class RetrievedContext:
    """Context retrieved for augmentation.

    ``similar_transactions`` and ``category_hierarchy`` default to fresh,
    per-instance empty lists. Callers can append to or iterate them without
    first checking for ``None``.
    """
    merchant_info: Optional[Dict] = None
    similar_transactions: Optional[List[Dict]] = None
    category_hierarchy: Optional[List[str]] = None
    confidence_boost: float = 0.0

    def __post_init__(self):
        if self.similar_transactions is None:
            self.similar_transactions = []
        if self.category_hierarchy is None:
            self.category_hierarchy = []


# ============================================================================
# MERCHANT KNOWLEDGE BASE
# ============================================================================

class MerchantKnowledgeBase:
    """
    Knowledge base of Indian merchants for RAG.

    Provides merchant lookup (by name, VPA, or alias), free-text merchant
    search, and per-category merchant listing.
    """

    def __init__(self):
        self.merchants: Dict[str, Merchant] = {}  # lowercase name -> merchant
        self.vpa_index: Dict[str, str] = {}  # lowercase VPA -> merchant name
        self.alias_index: Dict[str, str] = {}  # lowercase alias -> merchant name
        self._load_default_merchants()

    def _load_default_merchants(self):
        """Load built-in merchant database."""
        default_merchants = [
            # Food Delivery
            Merchant("Swiggy", "swiggy@ybl", "food", ["swiggy.in", "swiggy food"], ["delivery", "food order"], (50, 2000)),
            Merchant("Zomato", "zomato@paytm", "food", ["zomato.com", "zomato food"], ["restaurant", "dining"], (100, 3000)),
            Merchant("Zepto", "zepto@ybl", "grocery", [], ["10 min", "quick commerce"], (100, 2000)),
            Merchant("BigBasket", "bigbasket@ybl", "grocery", ["bb", "bigbasket.com"], ["grocery", "vegetables"], (200, 5000)),
            Merchant("Blinkit", "blinkit@ybl", "grocery", ["grofers"], ["instant delivery"], (100, 2000)),

            # E-commerce
            Merchant("Amazon", "amazon@apl", "shopping", ["amzn", "amazon.in", "amazonpay"], ["order", "delivery"], (100, 50000)),
            Merchant("Flipkart", "flipkart@ybl", "shopping", ["fk", "flipkart.com"], ["order", "electronics"], (100, 100000)),
            Merchant("Myntra", "myntra@ybl", "shopping", ["myntra.com"], ["fashion", "clothing"], (200, 10000)),
            Merchant("Meesho", "meesho@ybl", "shopping", [], ["reseller", "wholesale"], (100, 5000)),
            Merchant("Ajio", "ajio@ybl", "shopping", ["ajio.com"], ["fashion", "brand"], (200, 15000)),

            # Transport
            Merchant("Uber", "uber@axisbank", "transport", ["uber india", "uberindia"], ["ride", "cab"], (50, 2000)),
            Merchant("Ola", "ola@ybl", "transport", ["olacabs", "ola cabs"], ["ride", "cab", "auto"], (30, 1500)),
            Merchant("Rapido", "rapido@ybl", "transport", [], ["bike taxi", "auto"], (20, 500)),

            # Travel
            Merchant("IRCTC", "irctc@sbi", "travel", ["indian railways", "railway"], ["train", "ticket"], (100, 10000)),
            Merchant("MakeMyTrip", "makemytrip@ybl", "travel", ["mmt", "makemytrip.com"], ["flight", "hotel", "booking"], (500, 100000)),
            Merchant("Goibibo", "goibibo@ybl", "travel", [], ["flight", "hotel"], (500, 50000)),
            Merchant("Yatra", "yatra@ybl", "travel", ["yatra.com"], ["travel", "booking"], (500, 50000)),

            # Bills & Utilities
            Merchant("Airtel", "airtel@paytm", "bills", ["bharti airtel"], ["recharge", "postpaid", "broadband"], (100, 5000)),
            Merchant("Jio", "jio@ybl", "bills", ["reliance jio"], ["recharge", "fiber"], (100, 3000)),
            Merchant("Vi", "vi@ybl", "bills", ["vodafone", "vodafone idea"], ["recharge", "prepaid"], (100, 2000)),
            Merchant("BSNL", "bsnl@ybl", "bills", [], ["landline", "broadband"], (100, 3000)),
            Merchant("Tata Power", "tatapower@ybl", "bills", [], ["electricity", "power bill"], (500, 20000)),
            Merchant("BESCOM", "bescom@ybl", "bills", [], ["electricity", "bangalore"], (200, 10000)),

            # Entertainment
            Merchant("Netflix", "netflix@ybl", "entertainment", ["netflix.com"], ["subscription", "streaming"], (149, 799)),
            Merchant("Amazon Prime", "amazonprime@apl", "entertainment", ["prime video", "primevideo"], ["subscription", "streaming"], (129, 1499)),
            Merchant("Hotstar", "hotstar@ybl", "entertainment", ["disney hotstar", "disney+"], ["subscription", "cricket"], (299, 1499)),
            Merchant("BookMyShow", "bookmyshow@ybl", "entertainment", ["bms"], ["movie", "ticket", "event"], (100, 5000)),
            Merchant("Spotify", "spotify@ybl", "entertainment", [], ["music", "subscription"], (59, 179)),

            # Investment
            Merchant("Zerodha", "zerodha@hdfcbank", "investment", ["zerodha kite", "kite"], ["stocks", "trading", "MF"], (100, 1000000)),
            Merchant("Groww", "groww@ybl", "investment", ["groww.in"], ["mutual fund", "SIP", "stocks"], (100, 500000)),
            Merchant("Upstox", "upstox@ybl", "investment", [], ["trading", "demat"], (100, 500000)),
            Merchant("Angel One", "angelone@ybl", "investment", ["angel broking"], ["trading", "stocks"], (100, 500000)),
            Merchant("5Paisa", "5paisa@ybl", "investment", ["five paisa"], ["trading", "broker"], (100, 500000)),
            Merchant("Coin by Zerodha", "coin@zerodha", "investment", ["zerodha coin"], ["mutual fund", "MF"], (500, 100000)),
            Merchant("Kuvera", "kuvera@ybl", "investment", [], ["mutual fund", "goal based"], (500, 100000)),
            Merchant("ET Money", "etmoney@ybl", "investment", [], ["mutual fund", "SIP"], (500, 50000)),

            # Healthcare
            Merchant("Apollo Pharmacy", "apollo@ybl", "healthcare", ["apollo", "apollo hospitals"], ["medicine", "pharmacy"], (50, 10000)),
            Merchant("PharmEasy", "pharmeasy@ybl", "healthcare", [], ["medicine", "health", "lab test"], (100, 5000)),
            Merchant("1mg", "1mg@ybl", "healthcare", ["onemg", "tata 1mg"], ["medicine", "pharmacy"], (50, 5000)),
            Merchant("Netmeds", "netmeds@ybl", "healthcare", [], ["medicine", "pharmacy"], (100, 3000)),
            Merchant("Practo", "practo@ybl", "healthcare", [], ["doctor", "consultation", "appointment"], (100, 2000)),

            # Payment Apps (for cashback detection)
            Merchant("PhonePe", "phonepe@ybl", "transfer", [], ["cashback", "reward"], (1, 100000), False),
            Merchant("Paytm", "paytm@paytm", "transfer", ["paytmmall"], ["cashback", "wallet"], (1, 100000), False),
            Merchant("Google Pay", "googlepay@okicici", "transfer", ["gpay", "tez"], ["cashback", "reward"], (1, 100000), False),

            # Fuel
            Merchant("HP Petrol", "hpcl@ybl", "transport", ["hindustan petroleum", "HPCL"], ["fuel", "petrol"], (100, 10000)),
            Merchant("Indian Oil", "iocl@ybl", "transport", ["IOCL", "indian oil corporation"], ["fuel", "petrol"], (100, 10000)),
            Merchant("Bharat Petroleum", "bpcl@ybl", "transport", ["BPCL"], ["fuel", "petrol", "diesel"], (100, 10000)),
        ]

        for merchant in default_merchants:
            self.add_merchant(merchant)

    def add_merchant(self, merchant: Merchant):
        """Add *merchant* and index its VPA and aliases (all lowercased).

        A merchant with the same (case-insensitive) name replaces the
        previous entry in ``merchants``.
        """
        self.merchants[merchant.name.lower()] = merchant
        if merchant.vpa:
            self.vpa_index[merchant.vpa.lower()] = merchant.name
        for alias in merchant.aliases:
            self.alias_index[alias.lower()] = merchant.name

    def lookup(self, query: str) -> Optional[Merchant]:
        """Look up a single merchant by name, VPA, or alias.

        Resolution order: exact name, exact VPA, exact alias, then a fuzzy
        pass. The fuzzy pass returns the first merchant whose name contains
        the query, whose name appears in the query (short names must be
        whole tokens), or whose VPA contains the query.

        Returns:
            The matching ``Merchant``, or ``None`` if nothing matches or the
            query is empty/whitespace.
        """
        query_lower = query.lower().strip()
        # "" is a substring of everything and would match the first merchant.
        if not query_lower:
            return None

        if query_lower in self.merchants:
            return self.merchants[query_lower]

        if '@' in query_lower and query_lower in self.vpa_index:
            return self.merchants.get(self.vpa_index[query_lower].lower())

        if query_lower in self.alias_index:
            return self.merchants.get(self.alias_index[query_lower].lower())

        for name, merchant in self.merchants.items():
            if query_lower in name or _term_in_text(name, query_lower):
                return merchant
            if merchant.vpa and query_lower in merchant.vpa.lower():
                return merchant

        return None

    def search(self, text: str, limit: int = 5) -> List[Merchant]:
        """Rank merchants mentioned in free *text*.

        Scoring per merchant: name +10, VPA prefix (before '@') +8, each
        alias +5, each keyword +2. Terms of up to ``_SHORT_TERM_MAX_LEN``
        characters must appear as whole tokens.

        Returns:
            Up to *limit* merchants with a positive score, highest first.
            Ties keep the knowledge-base insertion order.
        """
        text_lower = text.lower()
        scored = []

        for name, merchant in self.merchants.items():
            score = 0
            if _term_in_text(name, text_lower):
                score += 10
            if merchant.vpa and _term_in_text(merchant.vpa.split('@')[0].lower(), text_lower):
                score += 8
            score += 5 * sum(_term_in_text(alias.lower(), text_lower) for alias in merchant.aliases)
            score += 2 * sum(_term_in_text(kw.lower(), text_lower) for kw in merchant.keywords)

            if score > 0:
                scored.append((merchant, score))

        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [merchant for merchant, _ in scored[:limit]]

    def get_category_merchants(self, category: str) -> List[Merchant]:
        """Get all merchants whose category equals *category* exactly."""
        return [m for m in self.merchants.values() if m.category == category]

    def to_dict(self) -> Dict:
        """Export as ``{lowercase name: asdict(merchant)}``."""
        return {
            name: asdict(merchant)
            for name, merchant in self.merchants.items()
        }

    def save(self, path: Path):
        """Save the knowledge base to a UTF-8 JSON file at *path*."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: Path) -> 'MerchantKnowledgeBase':
        """Load a knowledge base written by ``save``.

        The built-in defaults are discarded; only merchants from the file
        are present afterwards.
        """
        kb = cls()
        kb.merchants = {}  # Clear defaults
        kb.vpa_index = {}
        kb.alias_index = {}

        with open(path, encoding='utf-8') as f:
            data = json.load(f)

        for merchant_data in data.values():
            kb.add_merchant(Merchant(**merchant_data))

        return kb
# ============================================================================
# CATEGORY TAXONOMY
# ============================================================================

class CategoryTaxonomy:
    """
    Hierarchical category system for transactions.

    ``TAXONOMY`` maps each category to its ``parent`` (None for a root), its
    ``children`` sub-categories, and the ``keywords`` used by
    ``infer_category``.
    """

    TAXONOMY = {
        "food": {
            "parent": None,
            "children": ["restaurant", "delivery", "cafe", "street_food"],
            "keywords": ["food", "meal", "lunch", "dinner", "breakfast", "eating"],
        },
        "grocery": {
            "parent": "shopping",
            "children": ["supermarket", "vegetables", "dairy"],
            "keywords": ["grocery", "vegetables", "fruits", "milk", "provisions"],
        },
        "shopping": {
            "parent": None,
            "children": ["electronics", "fashion", "home", "grocery"],
            "keywords": ["purchase", "order", "buy", "shopping"],
        },
        "transport": {
            "parent": None,
            "children": ["cab", "auto", "fuel", "parking"],
            "keywords": ["ride", "travel", "cab", "uber", "ola", "petrol"],
        },
        "travel": {
            "parent": None,
            "children": ["flight", "hotel", "train", "bus"],
            "keywords": ["booking", "ticket", "travel", "trip", "vacation"],
        },
        "bills": {
            "parent": None,
            "children": ["electricity", "mobile", "broadband", "gas", "water"],
            "keywords": ["bill", "recharge", "payment", "utility"],
        },
        "entertainment": {
            "parent": None,
            "children": ["movies", "streaming", "gaming", "events"],
            "keywords": ["movie", "show", "subscription", "netflix", "concert"],
        },
        "healthcare": {
            "parent": None,
            "children": ["pharmacy", "doctor", "hospital", "lab"],
            "keywords": ["medicine", "health", "doctor", "pharmacy", "medical"],
        },
        "investment": {
            "parent": None,
            "children": ["stocks", "mutual_fund", "sip", "ipo"],
            "keywords": ["invest", "trading", "SIP", "mutual fund", "stock"],
        },
        "transfer": {
            "parent": None,
            "children": ["p2p", "salary", "refund"],
            "keywords": ["transfer", "send", "receive", "credited"],
        },
        "emi": {
            "parent": None,
            "children": ["loan", "credit_card"],
            "keywords": ["EMI", "loan", "installment", "auto debit"],
        },
    }

    # Keywords this short must match a whole token; otherwise "emi" hits
    # "premium", "sip" hits "gossip", "ola" hits "cola", "cab" hits "cabinet".
    _SHORT_KEYWORD_MAX_LEN = 3

    @classmethod
    def _keyword_in_text(cls, keyword: str, text: str) -> bool:
        """Return True if lowercase *keyword* occurs in lowercase *text*."""
        import re
        if not keyword:
            return False
        if len(keyword) > cls._SHORT_KEYWORD_MAX_LEN:
            return keyword in text
        pattern = r"(?<![a-z0-9])" + re.escape(keyword) + r"(?![a-z0-9])"
        return re.search(pattern, text) is not None

    @classmethod
    def get_hierarchy(cls, category: str) -> List[str]:
        """Return the root-to-leaf path ending at *category*.

        For example, ``get_hierarchy("grocery")`` returns
        ``["shopping", "grocery"]``. An unknown category is returned as a
        one-element list. A visited-set stops the walk if ``TAXONOMY`` is
        ever extended with a parent cycle.
        """
        if category not in cls.TAXONOMY:
            return [category]

        hierarchy = [category]
        seen = {category}
        parent = cls.TAXONOMY[category]["parent"]
        while parent and parent not in seen:
            hierarchy.insert(0, parent)
            seen.add(parent)
            parent = cls.TAXONOMY.get(parent, {}).get("parent")
        return hierarchy

    @classmethod
    def infer_category(cls, text: str, amount: Optional[float] = None) -> str:
        """Infer a category for *text* by counting keyword hits.

        Args:
            text: Raw message text (matched case-insensitively).
            amount: Accepted for API compatibility; currently unused.

        Returns:
            The category with the most keyword hits, or ``"other"`` if none
            match. On a tie, the category listed first in ``TAXONOMY`` wins.
        """
        text_lower = text.lower()
        scores = defaultdict(int)

        for category, info in cls.TAXONOMY.items():
            for keyword in info["keywords"]:
                if cls._keyword_in_text(keyword.lower(), text_lower):
                    scores[category] += 1

        if scores:
            return max(scores.items(), key=lambda x: x[1])[0]
        return "other"


# ============================================================================
# VECTOR STORE (Simple)
# ============================================================================

class SimpleVectorStore:
    """
    Simple in-memory vector store for transaction similarity.

    Uses TF-IDF weighting and cosine similarity (no external dependencies).
    Term frequencies are kept per document and re-weighted with the current
    IDF before each search. Every document is therefore scored with the same
    IDF, regardless of when it was added, and scores are identical before
    and after a save/load round-trip.
    """

    def __init__(self):
        self.documents: List[Transaction] = []
        self.vocabulary: Dict[str, int] = {}  # token -> insertion index
        self.idf: Dict[str, float] = {}
        # TF-IDF vector per document, parallel to ``documents``; rebuilt
        # lazily by ``_ensure_tfidf`` once the IDF has changed.
        self.tfidf_matrix: List[Dict[str, float]] = []
        self._tf_vectors: List[Dict[str, float]] = []  # parallel to documents
        self._doc_freq: Dict[str, int] = defaultdict(int)  # token -> #docs
        self._tfidf_stale = False

    def _tokenize(self, text: str) -> List[str]:
        """Lowercase *text* and return its word tokens longer than 2 chars."""
        import re
        tokens = re.findall(r'\b\w+\b', text.lower())
        return [t for t in tokens if len(t) > 2]

    def _compute_tf(self, tokens: List[str]) -> Dict[str, float]:
        """Compute term frequency (count / total tokens) for *tokens*."""
        tf = defaultdict(int)
        for token in tokens:
            tf[token] += 1
        total = len(tokens) or 1
        return {k: v / total for k, v in tf.items()}

    def _weight(self, tf: Dict[str, float]) -> Dict[str, float]:
        """Apply the current IDF to a TF vector (unknown terms weigh 0)."""
        return {k: v * self.idf.get(k, 0) for k, v in tf.items()}

    def _index_document(self, transaction: Transaction):
        """Record *transaction* in documents, vocabulary, doc-freq and TF tables."""
        self.documents.append(transaction)
        tokens = self._tokenize(transaction.text)
        for token in set(tokens):
            if token not in self.vocabulary:
                self.vocabulary[token] = len(self.vocabulary)
            self._doc_freq[token] += 1
        self._tf_vectors.append(self._compute_tf(tokens))

    def add(self, transaction: Transaction):
        """Add *transaction* to the store.

        Costs O(vocabulary) rather than re-tokenising the whole corpus.
        Older document vectors are marked stale and re-weighted on the next
        search.
        """
        self._index_document(transaction)
        self._update_idf()
        self.tfidf_matrix.append(self._weight(self._tf_vectors[-1]))
        self._tfidf_stale = True

    def _update_idf(self):
        """Recompute smoothed IDF: log((N + 1) / (df + 1)) + 1."""
        import math
        n_docs = len(self.documents)
        self.idf = {
            token: math.log((n_docs + 1) / (df + 1)) + 1
            for token, df in self._doc_freq.items()
        }

    def _ensure_tfidf(self):
        """Re-weight every stored document with the current IDF if needed."""
        if self._tfidf_stale:
            self.tfidf_matrix = [self._weight(tf) for tf in self._tf_vectors]
            self._tfidf_stale = False

    def search(self, query: str, limit: int = 5) -> List[Tuple[Transaction, float]]:
        """Return up to *limit* ``(transaction, cosine score)`` pairs.

        Only documents with a positive score are returned, best first. The
        result is empty when the store is empty or *limit* <= 0.
        """
        if not self.documents or limit <= 0:
            return []

        self._ensure_tfidf()
        query_tfidf = self._weight(self._compute_tf(self._tokenize(query)))

        results = []
        for doc, doc_tfidf in zip(self.documents, self.tfidf_matrix):
            score = self._cosine_similarity(query_tfidf, doc_tfidf)
            if score > 0:
                results.append((doc, score))

        results.sort(key=lambda x: x[1], reverse=True)
        return results[:limit]

    def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
        """Compute cosine similarity between two sparse vectors (0.0 if disjoint)."""
        import math
        common_keys = vec1.keys() & vec2.keys()
        if not common_keys:
            return 0.0

        dot_product = sum(vec1[k] * vec2[k] for k in common_keys)
        norm1 = math.sqrt(sum(v ** 2 for v in vec1.values()))
        norm2 = math.sqrt(sum(v ** 2 for v in vec2.values()))

        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot_product / (norm1 * norm2)

    def save(self, path: Path):
        """Save documents, vocabulary and IDF to a UTF-8 JSON file."""
        data = {
            "documents": [asdict(d) for d in self.documents],
            "vocabulary": self.vocabulary,
            "idf": self.idf,
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f)

    @classmethod
    def load(cls, path: Path) -> 'SimpleVectorStore':
        """Load a vector store written by ``save``.

        The IDF is recomputed from the stored documents. For any file
        produced by ``save`` this equals the saved ``"idf"``, and it keeps
        document frequencies consistent for later ``add`` calls.
        """
        store = cls()
        with open(path, encoding='utf-8') as f:
            data = json.load(f)

        store.vocabulary = data["vocabulary"]
        for doc_data in data["documents"]:
            store._index_document(Transaction(**doc_data))

        store._update_idf()
        store.tfidf_matrix = [store._weight(tf) for tf in store._tf_vectors]
        return store


# ============================================================================
# RAG ENGINE
# ============================================================================

class RAGEngine:
    """
    RAG Engine for context-aware entity extraction.

    Combines:
    1. Merchant Knowledge Base
    2. Transaction History (Vector Store)
    3. Category Taxonomy
    """

    def __init__(self):
        self.merchant_kb = MerchantKnowledgeBase()
        self.vector_store = SimpleVectorStore()
        self.taxonomy = CategoryTaxonomy()

    def retrieve(self, message: str, amount: Optional[float] = None) -> RetrievedContext:
        """
        Retrieve relevant context for a message.

        Confidence boost: +0.1 when a merchant matches, +0.05 more when a
        truthy *amount* lies inside the merchant's typical range, and +0.05
        for each similar past transaction scoring above 0.7.

        Returns:
            RetrievedContext with merchant info, similar transactions, and
            category hierarchy.
        """
        context = RetrievedContext(
            similar_transactions=[],
            confidence_boost=0.0
        )

        # 1. Search merchants
        merchants = self.merchant_kb.search(message, limit=3)
        if merchants:
            top_merchant = merchants[0]
            context.merchant_info = {
                "name": top_merchant.name,
                "category": top_merchant.category,
                "vpa": top_merchant.vpa,
                "is_p2m": top_merchant.is_p2m,
                "typical_range": top_merchant.typical_amount_range,
            }
            context.confidence_boost += 0.1

            # Validate amount range
            if amount and top_merchant.typical_amount_range:
                min_amt, max_amt = top_merchant.typical_amount_range
                if min_amt <= amount <= max_amt:
                    context.confidence_boost += 0.05

        # 2. Search similar transactions
        similar = self.vector_store.search(message, limit=3)
        for txn, score in similar:
            context.similar_transactions.append({
                "text": txn.text[:100],
                "merchant": txn.merchant,
                "category": txn.category,
                "score": round(score, 3)
            })
            if score > 0.7:
                context.confidence_boost += 0.05

        # 3. Infer category hierarchy. Keyword inference takes precedence;
        # when it finds nothing, fall back to the matched merchant's category
        # so that, e.g., a plain "UPI:swiggy@ybl" debit still maps to "food".
        inferred_category = self.taxonomy.infer_category(message, amount)
        if inferred_category == "other" and context.merchant_info:
            inferred_category = context.merchant_info["category"]
        context.category_hierarchy = self.taxonomy.get_hierarchy(inferred_category)

        return context

    def augment_prompt(self, message: str, context: RetrievedContext) -> str:
        """
        Augment extraction prompt with retrieved context.

        Adds the known merchant, up to two similar past transactions that
        have a merchant, and the likely category path.
        """
        prompt_parts = [f"Message: {message}"]

        if context.merchant_info:
            prompt_parts.append(f"\nKnown Merchant: {context.merchant_info['name']} ({context.merchant_info['category']})")

        if context.similar_transactions:
            prompt_parts.append("\nSimilar past transactions:")
            for txn in context.similar_transactions[:2]:
                if txn['merchant']:
                    prompt_parts.append(f"  - {txn['merchant']} ({txn['category']})")

        if context.category_hierarchy:
            prompt_parts.append(f"\nLikely category: {' > '.join(context.category_hierarchy)}")

        return "\n".join(prompt_parts)

    def add_transaction(self, text: str, extracted: Dict):
        """Add an extracted transaction to history for future retrieval.

        The record id is the first 8 hex chars of the MD5 of *text*, so
        identical texts share an id. ``None`` for *extracted* is treated as
        an empty dict.
        """
        extracted = extracted or {}
        txn = Transaction(
            id=hashlib.md5(text.encode()).hexdigest()[:8],
            text=text,
            amount=extracted.get("amount"),
            type=extracted.get("type"),
            merchant=extracted.get("merchant"),
            category=extracted.get("category"),
        )
        self.vector_store.add(txn)

    def save(self, directory: Path):
        """Save RAG state as merchants.json and transactions.json under *directory*.

        *directory* may be a ``Path`` or a ``str``; it is created if missing.
        """
        directory = Path(directory)
        directory.mkdir(parents=True, exist_ok=True)
        self.merchant_kb.save(directory / "merchants.json")
        self.vector_store.save(directory / "transactions.json")

    def load(self, directory: Path):
        """Load whichever of merchants.json / transactions.json exist in *directory*."""
        directory = Path(directory)
        merchant_path = directory / "merchants.json"
        if merchant_path.exists():
            self.merchant_kb = MerchantKnowledgeBase.load(merchant_path)

        txn_path = directory / "transactions.json"
        if txn_path.exists():
            self.vector_store = SimpleVectorStore.load(txn_path)


# ============================================================================
# USAGE EXAMPLE
# ============================================================================

if __name__ == "__main__":
    # Initialize RAG
    rag = RAGEngine()

    # Test retrieval
    test_messages = [
        "HDFC Bank: Rs.499 debited from A/c XX1234. UPI:swiggy@ybl. Ref:123456",
        "SBI: Rs.25,000 credited. NEFT from ZERODHA BROKING. Ref:N123456",
        "ICICI: Rs.199 debited. Netflix subscription. Card XX5678",
    ]

    print("=" * 60)
    print("RAG Engine Demo")
    print("=" * 60)

    for msg in test_messages:
        print(f"\n📩 Message: {msg[:60]}...")

        # Retrieve context
        context = rag.retrieve(msg)
        print("\n🔍 Retrieved Context:")
        if context.merchant_info:
            print(f" Merchant: {context.merchant_info['name']} ({context.merchant_info['category']})")
        if context.category_hierarchy:
            print(f" Category: {' > '.join(context.category_hierarchy)}")
        print(f" Confidence Boost: +{context.confidence_boost:.0%}")

        # Augmented prompt
        augmented = rag.augment_prompt(msg, context)
        print(f"\n📝 Augmented Prompt:\n{augmented}")
        print("-" * 60)