|
|
""" |
|
|
RAG (Retrieval Augmented Generation) for FinEE |
|
|
============================================== |
|
|
|
|
|
Provides context-aware entity extraction using: |
|
|
1. Merchant Knowledge Base - built-in catalog of popular Indian merchants (extensible)
|
|
2. Transaction History - Similar past transactions |
|
|
3. Category Taxonomy - Hierarchical categories |
|
|
|
|
|
Author: Ranjit Behera |
|
|
""" |
|
|
|
|
|
import hashlib
import json
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Merchant:
    """Merchant entity with metadata.

    Holds everything the knowledge base needs for lookup and scoring:
    canonical name, UPI VPA, category, alternate spellings (aliases),
    free-text keywords, a typical amount range used for confidence
    boosts, and whether the merchant is P2M (person-to-merchant).
    """
    name: str
    vpa: Optional[str] = None
    category: str = "other"
    # field(default_factory=list) gives every instance its own list and
    # fixes the annotations, which previously claimed List[str] while
    # defaulting to None.
    aliases: List[str] = field(default_factory=list)
    keywords: List[str] = field(default_factory=list)
    typical_amount_range: Tuple[float, float] = (0, 100000)
    is_p2m: bool = True

    def __post_init__(self):
        # Backward compatibility: callers that explicitly pass
        # aliases=None / keywords=None still get empty lists.
        if self.aliases is None:
            self.aliases = []
        if self.keywords is None:
            self.keywords = []
|
|
|
|
|
|
|
|
@dataclass
class Transaction:
    """Stored transaction for similarity search."""
    # Short stable identifier (RAGEngine.add_transaction uses an md5 prefix
    # of the raw text).
    id: str
    # Raw message text; this is what SimpleVectorStore tokenizes for TF-IDF.
    text: str
    amount: Optional[float] = None
    # Transaction type as extracted upstream (e.g. debit/credit) -- TODO
    # confirm the value set against the extractor.
    type: Optional[str] = None
    merchant: Optional[str] = None
    category: Optional[str] = None
    # Optional pre-computed dense embedding; not read by SimpleVectorStore.
    embedding: Optional[List[float]] = None
|
|
|
|
|
|
|
|
@dataclass
class RetrievedContext:
    """Context retrieved for augmentation.

    Aggregates what the RAG engine found for one message: the best
    merchant match, similar past transactions, the inferred category
    hierarchy, and a cumulative confidence boost.
    """
    merchant_info: Optional[Dict] = None
    # default_factory so each context owns its own list: the old `= None`
    # defaults contradicted the List[...] annotations and forced callers
    # to pass empty lists explicitly.
    similar_transactions: List[Dict] = field(default_factory=list)
    category_hierarchy: List[str] = field(default_factory=list)
    confidence_boost: float = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MerchantKnowledgeBase:
    """Knowledge base of Indian merchants for RAG.

    Supports merchant lookup by canonical name, UPI VPA, or alias,
    plus keyword-based search and category grouping.
    """

    def __init__(self):
        # Secondary indexes map lowercased VPAs / aliases back to the
        # canonical merchant name stored in `merchants`.
        self.vpa_index: Dict[str, str] = {}
        self.alias_index: Dict[str, str] = {}
        self.merchants: Dict[str, Merchant] = {}
        self._load_default_merchants()
|
|
|
|
|
    def _load_default_merchants(self):
        """Load built-in merchant database.

        Seeds the KB with a curated set of well-known Indian merchants,
        grouped by segment below. Positional args are
        (name, vpa, category, aliases, keywords, typical_amount_range[, is_p2m]).
        """
        default_merchants = [
            # --- Food delivery & quick-commerce grocery ---
            Merchant("Swiggy", "swiggy@ybl", "food",
                     ["swiggy.in", "swiggy food"], ["delivery", "food order"], (50, 2000)),
            Merchant("Zomato", "zomato@paytm", "food",
                     ["zomato.com", "zomato food"], ["restaurant", "dining"], (100, 3000)),
            Merchant("Zepto", "zepto@ybl", "grocery",
                     [], ["10 min", "quick commerce"], (100, 2000)),
            Merchant("BigBasket", "bigbasket@ybl", "grocery",
                     ["bb", "bigbasket.com"], ["grocery", "vegetables"], (200, 5000)),
            Merchant("Blinkit", "blinkit@ybl", "grocery",
                     ["grofers"], ["instant delivery"], (100, 2000)),

            # --- E-commerce / shopping ---
            Merchant("Amazon", "amazon@apl", "shopping",
                     ["amzn", "amazon.in", "amazonpay"], ["order", "delivery"], (100, 50000)),
            Merchant("Flipkart", "flipkart@ybl", "shopping",
                     ["fk", "flipkart.com"], ["order", "electronics"], (100, 100000)),
            Merchant("Myntra", "myntra@ybl", "shopping",
                     ["myntra.com"], ["fashion", "clothing"], (200, 10000)),
            Merchant("Meesho", "meesho@ybl", "shopping",
                     [], ["reseller", "wholesale"], (100, 5000)),
            Merchant("Ajio", "ajio@ybl", "shopping",
                     ["ajio.com"], ["fashion", "brand"], (200, 15000)),

            # --- Ride hailing ---
            Merchant("Uber", "uber@axisbank", "transport",
                     ["uber india", "uberindia"], ["ride", "cab"], (50, 2000)),
            Merchant("Ola", "ola@ybl", "transport",
                     ["olacabs", "ola cabs"], ["ride", "cab", "auto"], (30, 1500)),
            Merchant("Rapido", "rapido@ybl", "transport",
                     [], ["bike taxi", "auto"], (20, 500)),

            # --- Travel booking ---
            Merchant("IRCTC", "irctc@sbi", "travel",
                     ["indian railways", "railway"], ["train", "ticket"], (100, 10000)),
            Merchant("MakeMyTrip", "makemytrip@ybl", "travel",
                     ["mmt", "makemytrip.com"], ["flight", "hotel", "booking"], (500, 100000)),
            Merchant("Goibibo", "goibibo@ybl", "travel",
                     [], ["flight", "hotel"], (500, 50000)),
            Merchant("Yatra", "yatra@ybl", "travel",
                     ["yatra.com"], ["travel", "booking"], (500, 50000)),

            # --- Telecom & utility bills ---
            Merchant("Airtel", "airtel@paytm", "bills",
                     ["bharti airtel"], ["recharge", "postpaid", "broadband"], (100, 5000)),
            Merchant("Jio", "jio@ybl", "bills",
                     ["reliance jio"], ["recharge", "fiber"], (100, 3000)),
            Merchant("Vi", "vi@ybl", "bills",
                     ["vodafone", "vodafone idea"], ["recharge", "prepaid"], (100, 2000)),
            Merchant("BSNL", "bsnl@ybl", "bills",
                     [], ["landline", "broadband"], (100, 3000)),
            Merchant("Tata Power", "tatapower@ybl", "bills",
                     [], ["electricity", "power bill"], (500, 20000)),
            Merchant("BESCOM", "bescom@ybl", "bills",
                     [], ["electricity", "bangalore"], (200, 10000)),

            # --- Entertainment & subscriptions ---
            Merchant("Netflix", "netflix@ybl", "entertainment",
                     ["netflix.com"], ["subscription", "streaming"], (149, 799)),
            Merchant("Amazon Prime", "amazonprime@apl", "entertainment",
                     ["prime video", "primevideo"], ["subscription", "streaming"], (129, 1499)),
            Merchant("Hotstar", "hotstar@ybl", "entertainment",
                     ["disney hotstar", "disney+"], ["subscription", "cricket"], (299, 1499)),
            Merchant("BookMyShow", "bookmyshow@ybl", "entertainment",
                     ["bms"], ["movie", "ticket", "event"], (100, 5000)),
            Merchant("Spotify", "spotify@ybl", "entertainment",
                     [], ["music", "subscription"], (59, 179)),

            # --- Broking & investment platforms ---
            Merchant("Zerodha", "zerodha@hdfcbank", "investment",
                     ["zerodha kite", "kite"], ["stocks", "trading", "MF"], (100, 1000000)),
            Merchant("Groww", "groww@ybl", "investment",
                     ["groww.in"], ["mutual fund", "SIP", "stocks"], (100, 500000)),
            Merchant("Upstox", "upstox@ybl", "investment",
                     [], ["trading", "demat"], (100, 500000)),
            Merchant("Angel One", "angelone@ybl", "investment",
                     ["angel broking"], ["trading", "stocks"], (100, 500000)),
            Merchant("5Paisa", "5paisa@ybl", "investment",
                     ["five paisa"], ["trading", "broker"], (100, 500000)),
            Merchant("Coin by Zerodha", "coin@zerodha", "investment",
                     ["zerodha coin"], ["mutual fund", "MF"], (500, 100000)),
            Merchant("Kuvera", "kuvera@ybl", "investment",
                     [], ["mutual fund", "goal based"], (500, 100000)),
            Merchant("ET Money", "etmoney@ybl", "investment",
                     [], ["mutual fund", "SIP"], (500, 50000)),

            # --- Pharmacy & healthcare ---
            Merchant("Apollo Pharmacy", "apollo@ybl", "healthcare",
                     ["apollo", "apollo hospitals"], ["medicine", "pharmacy"], (50, 10000)),
            Merchant("PharmEasy", "pharmeasy@ybl", "healthcare",
                     [], ["medicine", "health", "lab test"], (100, 5000)),
            Merchant("1mg", "1mg@ybl", "healthcare",
                     ["onemg", "tata 1mg"], ["medicine", "pharmacy"], (50, 5000)),
            Merchant("Netmeds", "netmeds@ybl", "healthcare",
                     [], ["medicine", "pharmacy"], (100, 3000)),
            Merchant("Practo", "practo@ybl", "healthcare",
                     [], ["doctor", "consultation", "appointment"], (100, 2000)),

            # --- Payment apps (P2P wallets, hence is_p2m=False) ---
            Merchant("PhonePe", "phonepe@ybl", "transfer",
                     [], ["cashback", "reward"], (1, 100000), False),
            Merchant("Paytm", "paytm@paytm", "transfer",
                     ["paytmmall"], ["cashback", "wallet"], (1, 100000), False),
            Merchant("Google Pay", "googlepay@okicici", "transfer",
                     ["gpay", "tez"], ["cashback", "reward"], (1, 100000), False),

            # --- Fuel stations (categorized under transport) ---
            Merchant("HP Petrol", "hpcl@ybl", "transport",
                     ["hindustan petroleum", "HPCL"], ["fuel", "petrol"], (100, 10000)),
            Merchant("Indian Oil", "iocl@ybl", "transport",
                     ["IOCL", "indian oil corporation"], ["fuel", "petrol"], (100, 10000)),
            Merchant("Bharat Petroleum", "bpcl@ybl", "transport",
                     ["BPCL"], ["fuel", "petrol", "diesel"], (100, 10000)),
        ]

        for merchant in default_merchants:
            self.add_merchant(merchant)
|
|
|
|
|
def add_merchant(self, merchant: Merchant): |
|
|
"""Add merchant to knowledge base.""" |
|
|
self.merchants[merchant.name.lower()] = merchant |
|
|
|
|
|
if merchant.vpa: |
|
|
self.vpa_index[merchant.vpa.lower()] = merchant.name |
|
|
|
|
|
for alias in merchant.aliases: |
|
|
self.alias_index[alias.lower()] = merchant.name |
|
|
|
|
|
    def lookup(self, query: str) -> Optional[Merchant]:
        """Look up merchant by name, VPA, or alias.

        Tries progressively looser matches and returns the first hit,
        or None when nothing matches.
        """
        query_lower = query.lower().strip()

        # 1. Exact canonical-name hit.
        if query_lower in self.merchants:
            return self.merchants[query_lower]

        # 2. Exact VPA hit (only attempted when the query looks like a VPA).
        if '@' in query_lower and query_lower in self.vpa_index:
            name = self.vpa_index[query_lower]
            return self.merchants.get(name.lower())

        # 3. Exact alias hit.
        if query_lower in self.alias_index:
            name = self.alias_index[query_lower]
            return self.merchants.get(name.lower())

        # 4. Fuzzy pass over all merchants: two-way substring match on the
        #    name, then a partial-VPA containment check.
        for name, merchant in self.merchants.items():
            if query_lower in name or name in query_lower:
                return merchant

            # NOTE(review): assumed to sit inside the loop so every
            # merchant's VPA is checked -- confirm against the original
            # indentation.
            if merchant.vpa and query_lower in merchant.vpa.lower():
                return merchant

        return None
|
|
|
|
|
def search(self, text: str, limit: int = 5) -> List[Merchant]: |
|
|
"""Search merchants by text content.""" |
|
|
text_lower = text.lower() |
|
|
matches = [] |
|
|
|
|
|
for name, merchant in self.merchants.items(): |
|
|
score = 0 |
|
|
|
|
|
|
|
|
if name in text_lower: |
|
|
score += 10 |
|
|
|
|
|
|
|
|
if merchant.vpa and merchant.vpa.split('@')[0] in text_lower: |
|
|
score += 8 |
|
|
|
|
|
|
|
|
for alias in merchant.aliases: |
|
|
if alias.lower() in text_lower: |
|
|
score += 5 |
|
|
|
|
|
|
|
|
for keyword in merchant.keywords: |
|
|
if keyword.lower() in text_lower: |
|
|
score += 2 |
|
|
|
|
|
if score > 0: |
|
|
matches.append((merchant, score)) |
|
|
|
|
|
|
|
|
matches.sort(key=lambda x: x[1], reverse=True) |
|
|
return [m[0] for m in matches[:limit]] |
|
|
|
|
|
def get_category_merchants(self, category: str) -> List[Merchant]: |
|
|
"""Get all merchants in a category.""" |
|
|
return [m for m in self.merchants.values() if m.category == category] |
|
|
|
|
|
def to_dict(self) -> Dict: |
|
|
"""Export to dictionary.""" |
|
|
return { |
|
|
name: asdict(merchant) |
|
|
for name, merchant in self.merchants.items() |
|
|
} |
|
|
|
|
|
def save(self, path: Path): |
|
|
"""Save to JSON file.""" |
|
|
with open(path, 'w') as f: |
|
|
json.dump(self.to_dict(), f, indent=2) |
|
|
|
|
|
@classmethod |
|
|
def load(cls, path: Path) -> 'MerchantKnowledgeBase': |
|
|
"""Load from JSON file.""" |
|
|
kb = cls() |
|
|
kb.merchants = {} |
|
|
kb.vpa_index = {} |
|
|
kb.alias_index = {} |
|
|
|
|
|
with open(path) as f: |
|
|
data = json.load(f) |
|
|
|
|
|
for name, merchant_data in data.items(): |
|
|
merchant = Merchant(**merchant_data) |
|
|
kb.add_merchant(merchant) |
|
|
|
|
|
return kb |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CategoryTaxonomy:
    """Hierarchical category system for transactions.

    TAXONOMY maps each category to its parent (None for roots), its
    child labels, and the keywords used for inference.
    """

    TAXONOMY = {
        "food": {
            "parent": None,
            "children": ["restaurant", "delivery", "cafe", "street_food"],
            "keywords": ["food", "meal", "lunch", "dinner", "breakfast", "eating"],
        },
        "grocery": {
            "parent": "shopping",
            "children": ["supermarket", "vegetables", "dairy"],
            "keywords": ["grocery", "vegetables", "fruits", "milk", "provisions"],
        },
        "shopping": {
            "parent": None,
            "children": ["electronics", "fashion", "home", "grocery"],
            "keywords": ["purchase", "order", "buy", "shopping"],
        },
        "transport": {
            "parent": None,
            "children": ["cab", "auto", "fuel", "parking"],
            "keywords": ["ride", "travel", "cab", "uber", "ola", "petrol"],
        },
        "travel": {
            "parent": None,
            "children": ["flight", "hotel", "train", "bus"],
            "keywords": ["booking", "ticket", "travel", "trip", "vacation"],
        },
        "bills": {
            "parent": None,
            "children": ["electricity", "mobile", "broadband", "gas", "water"],
            "keywords": ["bill", "recharge", "payment", "utility"],
        },
        "entertainment": {
            "parent": None,
            "children": ["movies", "streaming", "gaming", "events"],
            "keywords": ["movie", "show", "subscription", "netflix", "concert"],
        },
        "healthcare": {
            "parent": None,
            "children": ["pharmacy", "doctor", "hospital", "lab"],
            "keywords": ["medicine", "health", "doctor", "pharmacy", "medical"],
        },
        "investment": {
            "parent": None,
            "children": ["stocks", "mutual_fund", "sip", "ipo"],
            "keywords": ["invest", "trading", "SIP", "mutual fund", "stock"],
        },
        "transfer": {
            "parent": None,
            "children": ["p2p", "salary", "refund"],
            "keywords": ["transfer", "send", "receive", "credited"],
        },
        "emi": {
            "parent": None,
            "children": ["loan", "credit_card"],
            "keywords": ["EMI", "loan", "installment", "auto debit"],
        },
    }

    @classmethod
    def get_hierarchy(cls, category: str) -> List[str]:
        """Get full category hierarchy, root first.

        Unknown categories are returned as a single-element list.
        """
        if category not in cls.TAXONOMY:
            return [category]

        chain = [category]
        node = category
        # Walk parent links upward; a None/missing parent terminates.
        while True:
            parent = cls.TAXONOMY.get(node, {}).get("parent")
            if not parent:
                break
            chain.append(parent)
            node = parent

        chain.reverse()
        return chain

    @classmethod
    def infer_category(cls, text: str, amount: Optional[float] = None) -> str:
        """Infer category from text keyword hits.

        The `amount` parameter is accepted for interface symmetry but is
        currently unused by the scoring.
        """
        lowered = text.lower()

        scores: Dict[str, int] = {}
        for category, info in cls.TAXONOMY.items():
            hits = sum(1 for kw in info["keywords"] if kw.lower() in lowered)
            if hits:
                scores[category] = hits

        if not scores:
            return "other"
        # First maximal score wins (dict preserves TAXONOMY order).
        return max(scores.items(), key=lambda item: item[1])[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimpleVectorStore:
    """In-memory TF-IDF vector store for transaction similarity.

    Stdlib-only: documents are scored with TF-IDF weights and compared
    by cosine similarity (no external dependencies).
    """

    def __init__(self):
        # Parallel lists: tfidf_matrix[i] is the sparse vector for documents[i].
        self.documents: List[Transaction] = []
        self.tfidf_matrix: List[Dict[str, float]] = []
        # Token -> integer id, and token -> inverse document frequency.
        self.vocabulary: Dict[str, int] = {}
        self.idf: Dict[str, float] = {}
|
|
|
|
|
def _tokenize(self, text: str) -> List[str]: |
|
|
"""Simple tokenization.""" |
|
|
import re |
|
|
text = text.lower() |
|
|
tokens = re.findall(r'\b\w+\b', text) |
|
|
return [t for t in tokens if len(t) > 2] |
|
|
|
|
|
def _compute_tf(self, tokens: List[str]) -> Dict[str, float]: |
|
|
"""Compute term frequency.""" |
|
|
tf = defaultdict(int) |
|
|
for token in tokens: |
|
|
tf[token] += 1 |
|
|
|
|
|
total = len(tokens) or 1 |
|
|
return {k: v/total for k, v in tf.items()} |
|
|
|
|
|
def add(self, transaction: Transaction): |
|
|
"""Add transaction to store.""" |
|
|
self.documents.append(transaction) |
|
|
|
|
|
tokens = self._tokenize(transaction.text) |
|
|
for token in set(tokens): |
|
|
if token not in self.vocabulary: |
|
|
self.vocabulary[token] = len(self.vocabulary) |
|
|
|
|
|
|
|
|
self._update_idf() |
|
|
|
|
|
|
|
|
tf = self._compute_tf(tokens) |
|
|
tfidf = {k: v * self.idf.get(k, 0) for k, v in tf.items()} |
|
|
self.tfidf_matrix.append(tfidf) |
|
|
|
|
|
def _update_idf(self): |
|
|
"""Update IDF scores.""" |
|
|
import math |
|
|
n_docs = len(self.documents) |
|
|
|
|
|
doc_freq = defaultdict(int) |
|
|
for doc in self.documents: |
|
|
tokens = set(self._tokenize(doc.text)) |
|
|
for token in tokens: |
|
|
doc_freq[token] += 1 |
|
|
|
|
|
self.idf = { |
|
|
token: math.log((n_docs + 1) / (df + 1)) + 1 |
|
|
for token, df in doc_freq.items() |
|
|
} |
|
|
|
|
|
def search(self, query: str, limit: int = 5) -> List[Tuple[Transaction, float]]: |
|
|
"""Search for similar transactions.""" |
|
|
if not self.documents: |
|
|
return [] |
|
|
|
|
|
query_tokens = self._tokenize(query) |
|
|
query_tf = self._compute_tf(query_tokens) |
|
|
query_tfidf = {k: v * self.idf.get(k, 0) for k, v in query_tf.items()} |
|
|
|
|
|
|
|
|
results = [] |
|
|
for i, doc_tfidf in enumerate(self.tfidf_matrix): |
|
|
score = self._cosine_similarity(query_tfidf, doc_tfidf) |
|
|
if score > 0: |
|
|
results.append((self.documents[i], score)) |
|
|
|
|
|
results.sort(key=lambda x: x[1], reverse=True) |
|
|
return results[:limit] |
|
|
|
|
|
def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float: |
|
|
"""Compute cosine similarity between two sparse vectors.""" |
|
|
import math |
|
|
|
|
|
common_keys = set(vec1.keys()) & set(vec2.keys()) |
|
|
if not common_keys: |
|
|
return 0.0 |
|
|
|
|
|
dot_product = sum(vec1[k] * vec2[k] for k in common_keys) |
|
|
norm1 = math.sqrt(sum(v**2 for v in vec1.values())) |
|
|
norm2 = math.sqrt(sum(v**2 for v in vec2.values())) |
|
|
|
|
|
if norm1 == 0 or norm2 == 0: |
|
|
return 0.0 |
|
|
|
|
|
return dot_product / (norm1 * norm2) |
|
|
|
|
|
def save(self, path: Path): |
|
|
"""Save vector store to file.""" |
|
|
data = { |
|
|
"documents": [asdict(d) for d in self.documents], |
|
|
"vocabulary": self.vocabulary, |
|
|
"idf": self.idf, |
|
|
} |
|
|
with open(path, 'w') as f: |
|
|
json.dump(data, f) |
|
|
|
|
|
@classmethod |
|
|
def load(cls, path: Path) -> 'SimpleVectorStore': |
|
|
"""Load vector store from file.""" |
|
|
store = cls() |
|
|
|
|
|
with open(path) as f: |
|
|
data = json.load(f) |
|
|
|
|
|
store.vocabulary = data["vocabulary"] |
|
|
store.idf = data["idf"] |
|
|
|
|
|
for doc_data in data["documents"]: |
|
|
doc = Transaction(**doc_data) |
|
|
store.documents.append(doc) |
|
|
|
|
|
tokens = store._tokenize(doc.text) |
|
|
tf = store._compute_tf(tokens) |
|
|
tfidf = {k: v * store.idf.get(k, 0) for k, v in tf.items()} |
|
|
store.tfidf_matrix.append(tfidf) |
|
|
|
|
|
return store |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RAGEngine:
    """RAG engine for context-aware entity extraction.

    Combines three retrieval sources:
    1. Merchant Knowledge Base
    2. Transaction History (vector store)
    3. Category Taxonomy
    """

    def __init__(self):
        # Each sub-store is independently replaceable via load().
        self.merchant_kb = MerchantKnowledgeBase()
        self.vector_store = SimpleVectorStore()
        self.taxonomy = CategoryTaxonomy()
|
|
|
|
|
    def retrieve(self, message: str, amount: Optional[float] = None) -> RetrievedContext:
        """
        Retrieve relevant context for a message.

        Confidence boost accumulates: +0.1 for a merchant hit, +0.05 if the
        amount fits the merchant's typical range, +0.05 per strong
        similar-transaction match.

        Returns:
            RetrievedContext with merchant info, similar transactions, and category hierarchy
        """
        context = RetrievedContext(
            similar_transactions=[],
            confidence_boost=0.0
        )

        # 1. Merchant lookup: take the top keyword-search hit, if any.
        merchants = self.merchant_kb.search(message, limit=3)
        if merchants:
            top_merchant = merchants[0]
            context.merchant_info = {
                "name": top_merchant.name,
                "category": top_merchant.category,
                "vpa": top_merchant.vpa,
                "is_p2m": top_merchant.is_p2m,
                "typical_range": top_merchant.typical_amount_range,
            }
            context.confidence_boost += 0.1

            # Extra boost when the amount falls in the merchant's typical
            # range. NOTE(review): truthiness check skips amount == 0 --
            # presumably intentional; confirm.
            if amount and top_merchant.typical_amount_range:
                min_amt, max_amt = top_merchant.typical_amount_range
                if min_amt <= amount <= max_amt:
                    context.confidence_boost += 0.05

        # 2. Similar past transactions from the TF-IDF store.
        similar = self.vector_store.search(message, limit=3)
        for txn, score in similar:
            context.similar_transactions.append({
                "text": txn.text[:100],
                "merchant": txn.merchant,
                "category": txn.category,
                "score": round(score, 3)
            })

            # Strong matches add confidence. NOTE(review): assumed to sit
            # inside the loop (per matching transaction) -- confirm against
            # the original indentation.
            if score > 0.7:
                context.confidence_boost += 0.05

        # 3. Category inference from taxonomy keywords.
        inferred_category = self.taxonomy.infer_category(message, amount)
        context.category_hierarchy = self.taxonomy.get_hierarchy(inferred_category)

        return context
|
|
|
|
|
def augment_prompt(self, message: str, context: RetrievedContext) -> str: |
|
|
""" |
|
|
Augment extraction prompt with retrieved context. |
|
|
""" |
|
|
prompt_parts = [f"Message: {message}"] |
|
|
|
|
|
if context.merchant_info: |
|
|
prompt_parts.append(f"\nKnown Merchant: {context.merchant_info['name']} ({context.merchant_info['category']})") |
|
|
|
|
|
if context.similar_transactions: |
|
|
prompt_parts.append("\nSimilar past transactions:") |
|
|
for txn in context.similar_transactions[:2]: |
|
|
if txn['merchant']: |
|
|
prompt_parts.append(f" - {txn['merchant']} ({txn['category']})") |
|
|
|
|
|
if context.category_hierarchy: |
|
|
prompt_parts.append(f"\nLikely category: {' > '.join(context.category_hierarchy)}") |
|
|
|
|
|
return "\n".join(prompt_parts) |
|
|
|
|
|
def add_transaction(self, text: str, extracted: Dict): |
|
|
"""Add extracted transaction to history for future retrieval.""" |
|
|
txn = Transaction( |
|
|
id=hashlib.md5(text.encode()).hexdigest()[:8], |
|
|
text=text, |
|
|
amount=extracted.get("amount"), |
|
|
type=extracted.get("type"), |
|
|
merchant=extracted.get("merchant"), |
|
|
category=extracted.get("category"), |
|
|
) |
|
|
self.vector_store.add(txn) |
|
|
|
|
|
def save(self, directory: Path): |
|
|
"""Save RAG state.""" |
|
|
directory.mkdir(parents=True, exist_ok=True) |
|
|
self.merchant_kb.save(directory / "merchants.json") |
|
|
self.vector_store.save(directory / "transactions.json") |
|
|
|
|
|
def load(self, directory: Path): |
|
|
"""Load RAG state.""" |
|
|
merchant_path = directory / "merchants.json" |
|
|
if merchant_path.exists(): |
|
|
self.merchant_kb = MerchantKnowledgeBase.load(merchant_path) |
|
|
|
|
|
txn_path = directory / "transactions.json" |
|
|
if txn_path.exists(): |
|
|
self.vector_store = SimpleVectorStore.load(txn_path) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Quick demo: retrieve context for a few sample bank messages and show
    # the augmented prompt that would be fed to the extractor.
    rag = RAGEngine()

    test_messages = [
        "HDFC Bank: Rs.499 debited from A/c XX1234. UPI:swiggy@ybl. Ref:123456",
        "SBI: Rs.25,000 credited. NEFT from ZERODHA BROKING. Ref:N123456",
        "ICICI: Rs.199 debited. Netflix subscription. Card XX5678",
    ]

    banner = "=" * 60
    print(banner)
    print("RAG Engine Demo")
    print(banner)

    for msg in test_messages:
        print(f"\n📩 Message: {msg[:60]}...")

        context = rag.retrieve(msg)

        print(f"\n🔍 Retrieved Context:")
        if context.merchant_info:
            print(f" Merchant: {context.merchant_info['name']} ({context.merchant_info['category']})")
        if context.category_hierarchy:
            print(f" Category: {' > '.join(context.category_hierarchy)}")
        print(f" Confidence Boost: +{context.confidence_boost:.0%}")

        augmented = rag.augment_prompt(msg, context)
        print(f"\n📝 Augmented Prompt:\n{augmented}")
        print("-" * 60)
|
|
|