"""Step D: categorize each transaction into a closed-set ledger head.

Two layers:

1. Rules layer (free, deterministic, runs FIRST) -- keyword/regex -> category.
   Catches the obvious majority and cuts GPU time ~60%.
2. Model layer (optional, MiniCPM text model) -- only for rule-misses, batched.

Anything the model returns off-list, or below the confidence floor, is mapped
to "Suspense / Unclassified". The closed-set guarantee is enforced here.
"""

import json
import os
import re

from constants import (CATEGORIES, CATEGORY_SET, CONFIDENCE_FLOOR, CONTRA,
                       SUSPENSE, voucher_type_for)

# Primary text model used for rule-misses; overridable via the environment.
TEXT_MODEL_ID = os.environ.get("TEXT_MODEL_ID", "openbmb/MiniCPM3-4B")
# Loaded only when the primary model fails to load; overridable like the primary.
# NOTE(review): confirm this repo id exists on the Hugging Face Hub -- Qwen's
# published 4B instruct checkpoint appears to be "Qwen/Qwen3-4B-Instruct-2507".
TEXT_FALLBACK_ID = os.environ.get("TEXT_FALLBACK_ID", "Qwen/Qwen3-4B-Instruct")

# Filled with str.format(categories=..., rows=...); literal JSON braces are
# therefore doubled ({{ }}). Do not add new placeholders without updating the
# caller that formats this prompt.
CATEGORIZE_PROMPT = """You are an Indian accountant's assistant. Classify each bank transaction into EXACTLY ONE category from this list:
{categories}.

Use the narration text. Common Indian patterns:
- "UPI/..." person names -> likely Sales receipt (credit) or Purchases (debit)
- "NEFT/RTGS/IMPS" + company names -> Debtors/Creditors
- "ACH/NACH" + "LIC"/"SBILIFE" -> Insurance; + bank/finance names -> Loan EMI
- "GST"/"GSTIN"/"CBIC" -> GST Payment; "TDS"/"CPC" -> TDS/Income Tax
- "INT.PD"/"INT CREDIT" -> Interest; "CHRG"/"SMS CHGS"/"AMC" -> Bank Charges
- "ATM"/"CSH WDL" -> Cash Withdrawal; "CSH DEP" -> Cash Deposit
- Same-name transfers -> Contra

Each transaction is one line: <index>: "<narration>" (<debit|credit|?>).
Output ONLY a JSON array with one object per transaction:
[{{"index": <index as an integer>, "category": "<exact category name from the list>", "confidence": <0-1>}}]
If genuinely unsure, use "Suspense / Unclassified" with low confidence.
Never invent a category outside the list.

Transactions:
{rows}"""

# --------------------------------------------------------------------------- #
# Rules layer
# --------------------------------------------------------------------------- #
# Each rule: (compiled regex, category, confidence, direction)
# direction: "debit", "credit", or None (any). Checked top to bottom; first
# match wins, so put specific rules before generic ones.
def _r(pattern):
    """Compile *pattern* case-insensitively (shorthand for the rule table)."""
    return re.compile(pattern, re.IGNORECASE)


_RULES = [
    # --- highly specific government / statutory ---
    (_r(r"\b(income\s*tax|advance\s*tax|cbdt)\b"), "Income Tax Payment", 0.95, None),
    (_r(r"\btds\b|\bcpc\b"), "TDS Payment", 0.95, None),
    (_r(r"\bgst\b|gstin|cbic|gstr"), "GST Payment", 0.95, None),
    # --- salary / payroll ---
    (_r(r"salary|payroll|wages\b"), "Salary & Wages", 0.95, None),
    # --- rent ---
    (_r(r"\brent\b"), "Office Rent", 0.9, None),
    # --- repairs & maintenance (before bank-charge AMC) ---
    # NOTE(review): "maintenance" also matches narrations such as
    # "ACCOUNT MAINTENANCE CHARGES"; confirm those should not be Bank Charges.
    (_r(r"repair|maintenance|servicing|service\s*charge\s*ac"), "Repairs & Maintenance", 0.88, None),
    # --- insurance ---
    (_r(r"\blic\b|insurance|sbilife|sbi\s*life|premium|policy\b"), "Insurance", 0.9, None),
    # --- loan / emi ---
    (_r(r"\bemi\b|home\s*loan|\bloan\b|car\s*loan"), "Loan EMI", 0.9, None),
    # --- telephone & internet ---
    (_r(r"broadband|internet|telephone|\bjio\b|airtel|\bbsnl\b|mobile|recharge|postpaid|landline"), "Telephone & Internet", 0.88, None),
    # --- electricity & utilities ---
    (_r(r"electric|kesco|kseb|\bbses\b|power\s*bill|utility|\bwater\s*bill"), "Electricity & Utilities", 0.88, None),
    # --- bank charges ---
    (_r(r"\bchrg|\bchgs|sms\s*chg|service\s*charge|\bamc\b|debit\s*card|annual\s*fee|bank\s*charge"), "Bank Charges", 0.9, None),
    # --- interest (direction decides received vs paid) ---
    (_r(r"int\.?\s*pd|int\s*credit|interest\s*credit|fixed\s*deposit\s*interest|fd\s*interest"), "Interest Received", 0.9, "credit"),
    (_r(r"interest\s*paid|overdraft|\bod\s*a/c|od\s*interest"), "Interest Paid", 0.9, "debit"),
    (_r(r"interest"), "Interest Received", 0.7, "credit"),
    (_r(r"interest"), "Interest Paid", 0.7, "debit"),
    # --- cash ---
    (_r(r"csh\s*dep|cash\s*dep|cash\s*deposit"), "Cash Deposit", 0.92, None),
    (_r(r"\batm\b|csh\s*wdl|cash\s*wdl|cash\s*withdraw"), "Cash Withdrawal", 0.92, None),
    # --- contra (own-account transfer) ---
    # NOTE(review): "self\b" has no leading boundary, so it also matches words
    # ending in "self" (e.g. "HIMSELF"); confirm that is acceptable.
    (_r(r"own\s*account|own\s*a/c|self\b|to\s*self|transfer.*self"), CONTRA, 0.9, None),
    # --- drawings / capital ---
    (_r(r"drawings|proprietor|personal\s*withdraw"), "Drawings", 0.9, "debit"),
    (_r(r"capital\s*introduc|capital\s*infus|partner.*capital"), "Capital Introduced", 0.9, "credit"),
    # --- professional fees ---
    (_r(r"prof\s*fees|professional|consult|\bca\s*fee|audit\s*fee"), "Professional Fees", 0.85, None),
    # --- travel & conveyance ---
    (_r(r"irctc|travel|petrol|diesel|\bfuel\b|conveyance|\buber\b|\bola\b|rail|flight|hpcl|iocl|bpcl"), "Travel & Conveyance", 0.85, None),
]

# Generic bank-transfer direction fallback: money in -> Debtors, out -> Creditors.
_TRANSFER_HINT = _r(r"\bupi\b|\bneft\b|\brtgs\b|\bimps\b|\bach\b|\bnach\b")


def _is_positive(amount):
    """Return True when *amount* converts to a number strictly greater than 0.

    Non-numeric values (None, unparsable strings) count as not positive instead
    of raising, so one odd cell cannot abort the whole categorization pass.
    """
    try:
        return float(amount) > 0
    except (TypeError, ValueError, OverflowError):
        return False


def _direction(txn):
    """Return "debit", "credit", or None for a transaction dict.

    Shared by the rules layer and the model-miss batching so both agree on
    what direction a row has.
    """
    if _is_positive(txn.get("debit")):
        return "debit"
    if _is_positive(txn.get("credit")):
        return "credit"
    return None


def _narration(txn):
    """Return the transaction's narration as a str ("" when missing).

    Non-str values (e.g. a float NaN from a dataframe) are stringified so the
    regex layer never receives a non-string.
    """
    narration = txn.get("narration")
    if isinstance(narration, str):
        return narration
    return "" if narration is None else str(narration)


def categorize_rules(txn):
    """Return (category, confidence) from rules, or None if no rule matched."""
    narration = _narration(txn)
    direction = _direction(txn)
    for rx, cat, conf, want_dir in _RULES:
        if rx.search(narration) and (want_dir is None or want_dir == direction):
            return cat, conf
    # Generic transfer fallback (lower confidence -- a model could do better).
    if _TRANSFER_HINT.search(narration):
        if direction == "credit":
            return "Sales / Receipts from Debtors", 0.6
        if direction == "debit":
            return "Purchases / Payments to Creditors", 0.6
    return None


# --------------------------------------------------------------------------- #
# Closed-set enforcement
# --------------------------------------------------------------------------- #

def enforce_closed_set(category, confidence):
    """Guarantee a valid (category, confidence). Off-list or low-conf -> Suspense.

    - ``confidence`` is coerced to float. Unparsable, NaN or infinite values
      become 0.0, and the result is clamped to [0, 1].
    - ``category`` must be a str that is in CATEGORY_SET (after stripping
      surrounding whitespace) and ``confidence`` must reach CONFIDENCE_FLOOR.
      Otherwise SUSPENSE is returned with the coerced confidence.
    """
    import math

    try:
        confidence = float(confidence)
    except (TypeError, ValueError, OverflowError):
        confidence = 0.0
    # float("nan") would otherwise survive min/max as 1.0 and pass the floor.
    if not math.isfinite(confidence):
        confidence = 0.0
    confidence = max(0.0, min(1.0, confidence))

    if isinstance(category, str):
        category = category.strip()
    else:
        # Model output may be a list/dict/None; an unhashable value would make
        # the set lookup below raise TypeError.
        return SUSPENSE, confidence

    if category not in CATEGORY_SET or confidence < CONFIDENCE_FLOOR:
        return SUSPENSE, confidence
    return category, confidence


# --------------------------------------------------------------------------- #
# Model layer (optional)
# --------------------------------------------------------------------------- #

_TEXT = {"model": None, "tokenizer": None, "id": None}

# Rows sent to the model per prompt.
_MODEL_CHUNK_SIZE = 20
# Generation budget: each JSON object in the answer costs a few dozen tokens,
# so a fixed 512 can truncate a full chunk and lose all of it.
_TOKENS_PER_ROW = 48
_MIN_NEW_TOKENS = 512


def text_model_available():
    """Return True when torch and transformers import and torch is usable."""
    try:
        import torch  # noqa: F401
        from transformers.utils import is_torch_available
        return bool(is_torch_available())
    except Exception:
        return False


def _load_text_model():
    """Load the text model/tokenizer once into _TEXT (primary, then fallback id).

    If the primary id fails to load, TEXT_FALLBACK_ID is tried. A failure there
    propagates to the caller.
    """
    if _TEXT["model"] is not None:
        return
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def _load(model_id):
        # trust_remote_code executes repository code; only point these ids at
        # repositories you trust.
        model = AutoModelForCausalLM.from_pretrained(
            model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
        tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        return model, tok

    model_id = TEXT_MODEL_ID
    try:
        model, tok = _load(model_id)
    except Exception:
        model_id = TEXT_FALLBACK_ID
        model, tok = _load(model_id)
    _TEXT.update(model=model, tokenizer=tok, id=model_id)


try:
    import spaces

    @spaces.GPU(duration=120)
    def _classify_with_model(batch):
        return _classify_with_model_impl(batch)
except Exception:
    def _classify_with_model(batch):
        return _classify_with_model_impl(batch)


def _classify_with_model_impl(batch):
    """batch = list of (index, narration, direction). Returns {index: (cat, conf)}.

    Only indices that were part of the prompt's chunk are kept, so a model that
    echoes or invents another index cannot affect unrelated transactions.
    Categories/confidences are returned raw; callers must pass them through
    enforce_closed_set.
    """
    import torch

    _load_text_model()
    model, tok = _TEXT["model"], _TEXT["tokenizer"]
    if torch.cuda.is_available():
        model = model.to("cuda")
    model.eval()

    categories_json = json.dumps(CATEGORIES, ensure_ascii=False)
    results = {}
    for start in range(0, len(batch), _MODEL_CHUNK_SIZE):
        chunk = batch[start:start + _MODEL_CHUNK_SIZE]
        wanted = {idx for idx, _, _ in chunk}

        row_lines = []
        for idx, narr, d in chunk:
            # json.dumps quotes and escapes the narration so embedded quotes or
            # newlines cannot break the one-row-per-line layout.
            quoted = json.dumps(str(narr), ensure_ascii=False)
            row_lines.append(f"{idx}: {quoted} ({d or '?'})")
        prompt = CATEGORIZE_PROMPT.format(
            categories=categories_json, rows="\n".join(row_lines))

        msgs = [{"role": "user", "content": prompt}]
        text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        inputs = tok(text, return_tensors="pt").to(model.device)
        max_new = max(_MIN_NEW_TOKENS, _TOKENS_PER_ROW * len(chunk))
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new, do_sample=False)
        decoded = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

        for item in _parse_model_categories(decoded):
            idx = item["index"]
            if idx in wanted:
                results[idx] = (item.get("category"), item.get("confidence", 0.0))
    return results


_JSON_ARRAY_RE = re.compile(r"\[.*\]", re.DOTALL)
_JSON_OBJECT_RE = re.compile(r"\{[^{}]*\}")


def _coerce_index(value):
    """Return *value* as an int index, or None if it is not integral.

    Accepts ints, integral floats (3.0) and numeric strings ("3"); rejects
    bools and everything else.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        return int(value) if value.is_integer() else None
    if isinstance(value, str):
        try:
            return int(value.strip())
        except ValueError:
            return None
    return None


def _parse_model_categories(text):
    """Extract prediction dicts (each with an int "index") from model output.

    First tries to parse the outermost [...] span as a JSON array. If that
    fails (surrounding prose, trailing junk), falls back to parsing each flat
    {...} object individually so one bad fragment does not lose the chunk.
    """
    text = text or ""
    items = None
    m = _JSON_ARRAY_RE.search(text)
    if m:
        try:
            data = json.loads(m.group(0))
        except (ValueError, RecursionError):
            data = None
        if isinstance(data, list):
            items = data
    if items is None:
        items = []
        for fragment in _JSON_OBJECT_RE.findall(text):
            try:
                items.append(json.loads(fragment))
            except ValueError:
                continue

    parsed = []
    for d in items:
        if not isinstance(d, dict) or "index" not in d:
            continue
        idx = _coerce_index(d["index"])
        if idx is None:
            continue
        parsed.append({**d, "index": idx})
    return parsed


# --------------------------------------------------------------------------- #
# Orchestration
# --------------------------------------------------------------------------- #

def categorize(transactions, use_model=True):
    """Categorize a list of transactions in place (adds category/confidence/voucher_type).

    Returns (transactions, stats). stats reports how many were rule-hit vs model:

    - ``rule_hits``: rows whose rule result survived enforce_closed_set.
      A rule result that the closed-set check turns into Suspense is treated
      as a miss and offered to the model. If the model is unused or fails, it
      keeps (Suspense, rule confidence).
    - ``model_hits``: model predictions applied (after closed-set enforcement).
    - ``model_error``: "<ExcType>: <message>" if the model pass failed, else None.
    """
    misses = []  # (index, narration, direction)
    rule_hits = 0
    for i, txn in enumerate(transactions):
        rule = categorize_rules(txn)
        if rule is not None:
            cat, conf = enforce_closed_set(*rule)
        else:
            cat, conf = SUSPENSE, 0.0
        txn["category"], txn["confidence"] = cat, conf
        if rule is not None and cat != SUSPENSE:
            rule_hits += 1
        else:
            misses.append((i, _narration(txn), _direction(txn)))

    model_hits = 0
    model_error = None
    if use_model and misses and text_model_available():
        try:
            preds = _classify_with_model(misses)
        except Exception as exc:
            # Degrade: rule-misses stay in Suspense, but record why.
            import logging
            logging.getLogger(__name__).warning(
                "Model categorization failed; leaving rule-misses in Suspense",
                exc_info=True)
            model_error = f"{type(exc).__name__}: {exc}"
            preds = {}

        miss_indices = {i for i, _, _ in misses}
        for idx, (cat, conf) in preds.items():
            # Never let a model prediction overwrite a row the rules decided.
            if idx not in miss_indices:
                continue
            c, cf = enforce_closed_set(cat, conf)
            transactions[idx]["category"] = c
            transactions[idx]["confidence"] = cf
            model_hits += 1

    for txn in transactions:
        txn["voucher_type"] = voucher_type_for(
            txn.get("debit"), txn.get("credit"), txn.get("category"))

    stats = {
        "total": len(transactions),
        "rule_hits": rule_hits,
        "model_hits": model_hits,
        "suspense": sum(1 for t in transactions if t["category"] == SUSPENSE),
        "model_used": model_hits > 0,
        "model_error": model_error,
    }
    return transactions, stats