statementsetu / categorize.py
perceptron01's picture
Upload 16 files
10ec275 verified
Raw
History Blame Contribute Delete
11 kB
"""Step D: categorize each transaction into a closed-set ledger head.
Two layers:
1. Rules layer (free, deterministic, runs FIRST) -- keyword/regex -> category.
Catches the obvious majority and cuts GPU time ~60%.
2. Model layer (optional, MiniCPM text model) -- only for rule-misses, batched.
Anything the model returns off-list, or below the confidence floor, is mapped
to "Suspense / Unclassified". The closed-set guarantee is enforced here.
"""
import json
import os
import re
from constants import (CATEGORIES, CATEGORY_SET, CONFIDENCE_FLOOR, CONTRA,
SUSPENSE, voucher_type_for)
# Checkpoint tried by _load_text_model() when the primary model fails to load.
# NOTE(review): confirm this repo id resolves on the Hugging Face Hub.
TEXT_FALLBACK_ID = "Qwen/Qwen3-4B-Instruct"
# Primary text model; override through the TEXT_MODEL_ID environment variable.
TEXT_MODEL_ID = os.getenv("TEXT_MODEL_ID", "openbmb/MiniCPM3-4B")
# Prompt template for the model layer, filled in with str.format():
#   {categories} -> JSON-encoded list of the allowed category names
#   {rows}       -> one line per transaction: <index>: "<narration>" (<direction>)
# Doubled braces ({{ }}) come out as literal braces in the formatted prompt.
CATEGORIZE_PROMPT = """You are an Indian accountant's assistant. Classify each bank transaction into EXACTLY ONE
category from this list: {categories}.
Use the narration text. Common Indian patterns:
- "UPI/..." person names -> likely Sales receipt (credit) or Purchases (debit)
- "NEFT/RTGS/IMPS" + company names -> Debtors/Creditors
- "ACH/NACH" + "LIC"/"SBILIFE" -> Insurance; + bank/finance names -> Loan EMI
- "GST"/"GSTIN"/"CBIC" -> GST Payment; "TDS"/"CPC" -> TDS/Income Tax
- "INT.PD"/"INT CREDIT" -> Interest; "CHRG"/"SMS CHGS"/"AMC" -> Bank Charges
- "ATM"/"CSH WDL" -> Cash Withdrawal; "CSH DEP" -> Cash Deposit
- Same-name transfers -> Contra
Output ONLY a JSON array: [{{"index": <i>, "category": "<from list>", "confidence": <0-1>}}].
If genuinely unsure, use "Suspense / Unclassified" with low confidence. Never invent a category
outside the list.
Transactions:
{rows}"""
# --------------------------------------------------------------------------- #
# Rules layer
# --------------------------------------------------------------------------- #
# Each rule: (compiled regex, category, confidence, direction)
# direction: "debit", "credit", or None (any). Checked top to bottom; first
# match wins, so put specific rules before generic ones.
def _r(pattern):
    """Compile *pattern* as a case-insensitive regular expression."""
    compiled = re.compile(pattern, flags=re.I)
    return compiled
# Ordered rule table: the first entry whose regex AND direction both match wins,
# so narrower patterns are listed ahead of broader ones. The raw rows below are
# (pattern, category, confidence, direction); each pattern is compiled through
# _r() (case-insensitive) while the list is being built.
_RULES = [
    (_r(pattern), category, confidence, direction)
    for pattern, category, confidence, direction in (
        # Statutory / government payments.
        (r"\b(income\s*tax|advance\s*tax|cbdt)\b", "Income Tax Payment", 0.95, None),
        (r"\btds\b|\bcpc\b", "TDS Payment", 0.95, None),
        (r"\bgst\b|gstin|cbic|gstr", "GST Payment", 0.95, None),
        # Payroll.
        (r"salary|payroll|wages\b", "Salary & Wages", 0.95, None),
        # Rent.
        (r"\brent\b", "Office Rent", 0.9, None),
        # Repairs & maintenance -- listed ahead of Bank Charges, whose pattern
        # also matches "service charge".
        (r"repair|maintenance|servicing|service\s*charge\s*ac",
         "Repairs & Maintenance", 0.88, None),
        # Insurance premiums.
        (r"\blic\b|insurance|sbilife|sbi\s*life|premium|policy\b", "Insurance", 0.9, None),
        # Loans and EMIs.
        (r"\bemi\b|home\s*loan|\bloan\b|car\s*loan", "Loan EMI", 0.9, None),
        # Phone and internet.
        (r"broadband|internet|telephone|\bjio\b|airtel|\bbsnl\b|mobile|recharge|postpaid|landline",
         "Telephone & Internet", 0.88, None),
        # Power and other utilities.
        (r"electric|kesco|kseb|\bbses\b|power\s*bill|utility|\bwater\s*bill",
         "Electricity & Utilities", 0.88, None),
        # Bank fees.
        (r"\bchrg|\bchgs|sms\s*chg|service\s*charge|\bamc\b|debit\s*card|annual\s*fee|bank\s*charge",
         "Bank Charges", 0.9, None),
        # Interest: the transaction direction picks received vs paid; the two
        # bare "interest" rows are the lower-confidence catch-alls.
        (r"int\.?\s*pd|int\s*credit|interest\s*credit|fixed\s*deposit\s*interest|fd\s*interest",
         "Interest Received", 0.9, "credit"),
        (r"interest\s*paid|overdraft|\bod\s*a/c|od\s*interest", "Interest Paid", 0.9, "debit"),
        (r"interest", "Interest Received", 0.7, "credit"),
        (r"interest", "Interest Paid", 0.7, "debit"),
        # Cash movements.
        (r"csh\s*dep|cash\s*dep|cash\s*deposit", "Cash Deposit", 0.92, None),
        (r"\batm\b|csh\s*wdl|cash\s*wdl|cash\s*withdraw", "Cash Withdrawal", 0.92, None),
        # Contra: transfers between the holder's own accounts.
        (r"own\s*account|own\s*a/c|self\b|to\s*self|transfer.*self", CONTRA, 0.9, None),
        # Owner drawings and capital.
        (r"drawings|proprietor|personal\s*withdraw", "Drawings", 0.9, "debit"),
        (r"capital\s*introduc|capital\s*infus|partner.*capital",
         "Capital Introduced", 0.9, "credit"),
        # Professional fees.
        (r"prof\s*fees|professional|consult|\bca\s*fee|audit\s*fee",
         "Professional Fees", 0.85, None),
        # Travel and conveyance.
        (r"irctc|travel|petrol|diesel|\bfuel\b|conveyance|\buber\b|\bola\b|rail|flight|hpcl|iocl|bpcl",
         "Travel & Conveyance", 0.85, None),
    )
]
# Matches narrations that name a generic payment rail (UPI/NEFT/RTGS/IMPS/ACH/
# NACH). Used only as a last-resort hint: money in -> Debtors, out -> Creditors.
_TRANSFER_HINT = re.compile(
    r"\bupi\b|\bneft\b|\brtgs\b|\bimps\b|\bach\b|\bnach\b",
    re.IGNORECASE,
)
def categorize_rules(txn):
    """Apply the deterministic rules layer to a single transaction.

    Reads ``narration``, ``debit`` and ``credit`` from *txn*. Returns
    ``(category, confidence)`` for the first rule in ``_RULES`` whose pattern
    and direction both match. If no rule matches but the narration names a
    generic transfer rail, returns a 0.6-confidence Debtors (credit) or
    Creditors (debit) guess. Returns ``None`` when nothing applies.
    """
    text = txn.get("narration") or ""
    debit_amt = txn.get("debit")
    credit_amt = txn.get("credit")

    # A positive debit takes precedence over a positive credit.
    if debit_amt and debit_amt > 0:
        flow = "debit"
    elif credit_amt and credit_amt > 0:
        flow = "credit"
    else:
        flow = None

    # First rule whose regex hits and whose direction is either unconstrained
    # or equal to this transaction's direction.
    hit = next(
        (
            (label, score)
            for pattern, label, score, required in _RULES
            if pattern.search(text) and (required is None or required == flow)
        ),
        None,
    )
    if hit is not None:
        return hit

    # Generic transfer fallback (lower confidence -- a model could do better).
    if not _TRANSFER_HINT.search(text):
        return None
    fallback = {
        "credit": "Sales / Receipts from Debtors",
        "debit": "Purchases / Payments to Creditors",
    }.get(flow)
    return None if fallback is None else (fallback, 0.6)
# --------------------------------------------------------------------------- #
# Closed-set enforcement
# --------------------------------------------------------------------------- #
def enforce_closed_set(category, confidence):
    """Guarantee a valid (category, confidence). Off-list or low-conf -> Suspense.

    *confidence* is coerced to a float in [0.0, 1.0]; anything that cannot be
    interpreted as a number (including NaN) becomes 0.0. *category* is kept
    only if it is a member of ``CATEGORY_SET`` and the confidence reaches
    ``CONFIDENCE_FLOOR``; otherwise ``SUSPENSE`` is returned with the
    normalized confidence. Never raises for malformed model output.
    """
    import math

    try:
        confidence = float(confidence)
    except (TypeError, ValueError, OverflowError):
        confidence = 0.0
    # min(1.0, nan) evaluates to 1.0, so without this guard a NaN confidence
    # would be clamped to full confidence and sail past the floor.
    if math.isnan(confidence):
        confidence = 0.0
    confidence = max(0.0, min(1.0, confidence))

    try:
        known = category in CATEGORY_SET
    except TypeError:
        # Unhashable value (e.g. a list or dict parsed from model JSON).
        known = False

    if not known or confidence < CONFIDENCE_FLOOR:
        return SUSPENSE, confidence
    return category, confidence
# --------------------------------------------------------------------------- #
# Model layer (optional)
# --------------------------------------------------------------------------- #
# Module-level cache filled in by _load_text_model(): the loaded model, its
# tokenizer, and the id of the checkpoint that actually loaded.
_TEXT = dict(model=None, tokenizer=None, id=None)
def text_model_available():
    """Report whether the optional model layer can run in this environment.

    True only when ``torch`` and ``transformers`` both import and transformers
    reports torch as available. Any exception raised while probing counts as
    "not available".
    """
    usable = False
    try:
        import torch  # noqa: F401
        from transformers.utils import is_torch_available as torch_ready
        usable = bool(torch_ready())
    except Exception:
        usable = False
    return usable
def _load_text_model():
    """Load the text model and tokenizer into the ``_TEXT`` cache (once).

    Tries ``TEXT_MODEL_ID`` first; if loading the model or its tokenizer raises,
    retries with ``TEXT_FALLBACK_ID``. A failure of the fallback propagates to
    the caller. Does nothing when a model is already cached.
    """
    if _TEXT["model"] is not None:
        return  # already loaded

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def _fetch(repo_id):
        # Model first, then tokenizer, both from the same checkpoint.
        weights = AutoModelForCausalLM.from_pretrained(
            repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
        return weights, tokenizer

    loaded_id = TEXT_MODEL_ID
    try:
        model, tok = _fetch(loaded_id)
    except Exception:
        loaded_id = TEXT_FALLBACK_ID
        model, tok = _fetch(loaded_id)

    _TEXT["model"] = model
    _TEXT["tokenizer"] = tok
    _TEXT["id"] = loaded_id
def _classify_with_model(batch):
    """Entry point for the model layer; delegates to _classify_with_model_impl."""
    return _classify_with_model_impl(batch)


# When the `spaces` package is importable, wrap the entry point with
# spaces.GPU(duration=120); on any failure the plain function above stays.
try:
    import spaces
    _classify_with_model = spaces.GPU(duration=120)(_classify_with_model)
except Exception:
    pass
def _classify_with_model_impl(batch, chunk_size=20, max_new_tokens=512):
    """batch = list of (index, narration, direction). Returns {index: (cat, conf)}.

    Transactions are sent to the text model in chunks of *chunk_size* rows, with
    up to *max_new_tokens* generated per chunk. Only predictions whose ``index``
    refers to a row of the chunk just sent are kept; anything else the model
    invents is dropped, so one malformed item cannot poison or abort the batch.
    Categories and confidences are returned raw -- callers are expected to pass
    them through ``enforce_closed_set``.

    Raises ValueError if *chunk_size* is less than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    if not batch:
        return {}  # nothing to classify: skip loading the model entirely

    import torch
    _load_text_model()
    model, tok = _TEXT["model"], _TEXT["tokenizer"]
    if torch.cuda.is_available():
        model = model.to("cuda")
    model.eval()

    categories_json = json.dumps(CATEGORIES, ensure_ascii=False)  # loop-invariant
    results = {}
    for start in range(0, len(batch), chunk_size):
        chunk = batch[start:start + chunk_size]
        # Map the textual form of each index back to the caller's own index
        # object, so "3", 3 and 3.0 from the model all resolve to the same row.
        label_to_index = {str(idx): idx for idx, _, _ in chunk}
        # json.dumps quotes the narration and escapes embedded quotes/newlines,
        # keeping the prompt at exactly one row per line.
        rows = "\n".join(
            f'{idx}: {json.dumps(str(narr), ensure_ascii=False)} ({d or "?"})'
            for idx, narr, d in chunk)
        prompt = CATEGORIZE_PROMPT.format(categories=categories_json, rows=rows)
        msgs = [{"role": "user", "content": prompt}]
        text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        inputs = tok(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Decode only the newly generated tokens, not the echoed prompt.
        decoded = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        for item in _parse_model_categories(decoded):
            raw = item["index"]
            if isinstance(raw, bool):
                continue  # True/False would otherwise alias rows 1/0
            if isinstance(raw, float) and raw.is_integer():
                raw = int(raw)
            idx = label_to_index.get(str(raw).strip())
            if idx is None:
                continue  # index not in this chunk (or not an index at all)
            results[idx] = (item.get("category"), item.get("confidence", 0.0))
    return results
def _parse_model_categories(text):
    """Extract prediction objects from raw model output.

    Returns the list of JSON objects (dicts) in *text* that carry an ``"index"``
    key. The output is first parsed as a single JSON array; if that fails --
    for example the array was truncated by the token limit, or extra bracketed
    text follows it -- every complete top-level ``{...}`` object that can be
    decoded is salvaged instead. Returns ``[]`` when nothing usable is found
    or *text* is empty/None.
    """
    text = text or ""
    data = None

    match = re.search(r"\[.*\]", text, re.DOTALL)
    if match:
        try:
            data = json.loads(match.group(0))
        except (ValueError, RecursionError):
            data = None

    if not isinstance(data, list):
        # Salvage pass: decode objects one at a time, starting at each "{".
        data = []
        decoder = json.JSONDecoder()
        pos = text.find("{")
        while pos != -1:
            try:
                obj, end = decoder.raw_decode(text, pos)
            except (ValueError, RecursionError):
                # Not a complete object here; try the next opening brace.
                pos = text.find("{", pos + 1)
                continue
            data.append(obj)
            pos = text.find("{", end)

    return [d for d in data if isinstance(d, dict) and "index" in d]
# --------------------------------------------------------------------------- #
# Orchestration
# --------------------------------------------------------------------------- #
def categorize(transactions, use_model=True):
    """Categorize a list of transactions in place (adds category/confidence/voucher_type).

    Each transaction is a dict; ``narration``, ``debit`` and ``credit`` are read
    from it. The rules layer runs first. Transactions no rule matched start out
    as ``SUSPENSE`` with confidence 0.0 and, when *use_model* is true and the
    text model is available, are sent to the model layer in one batch. Model
    predictions are applied only to those rule-misses, never to a transaction a
    rule already classified, and always pass through ``enforce_closed_set``.

    Returns (transactions, stats). stats reports how many were rule-hit vs model.
    A model failure is logged and otherwise ignored: rule-misses stay in Suspense.
    """
    misses = []  # (index, narration, direction)
    rule_hits = 0
    for i, txn in enumerate(transactions):
        rule = categorize_rules(txn)
        if rule is not None:
            txn["category"], txn["confidence"] = enforce_closed_set(*rule)
            rule_hits += 1
            continue
        txn["category"], txn["confidence"] = SUSPENSE, 0.0
        d = "debit" if txn.get("debit") else ("credit" if txn.get("credit") else None)
        misses.append((i, txn.get("narration") or "", d))

    model_hits = 0
    if use_model and misses and text_model_available():
        try:
            preds = _classify_with_model(misses)
        except Exception:
            # Degrade gracefully, but leave a trace of why the model was skipped.
            import logging
            logging.getLogger(__name__).warning(
                "model categorization failed; rule-misses stay in Suspense",
                exc_info=True)
            preds = {}
        if not isinstance(preds, dict):
            preds = {}

        # Only indices that were actually sent to the model may be updated.
        pending = {i for i, _, _ in misses}
        for idx, pred in preds.items():
            if isinstance(idx, bool) or not isinstance(idx, int) or idx not in pending:
                continue
            try:
                cat, conf = pred
                c, cf = enforce_closed_set(cat, conf)
            except (TypeError, ValueError):
                continue  # malformed prediction: skip it, keep the rest
            transactions[idx]["category"] = c
            transactions[idx]["confidence"] = cf
            model_hits += 1

    for txn in transactions:
        txn["voucher_type"] = voucher_type_for(
            txn.get("debit"), txn.get("credit"), txn.get("category"))

    stats = {
        "total": len(transactions),
        "rule_hits": rule_hits,
        "model_hits": model_hits,
        "suspense": sum(1 for t in transactions if t["category"] == SUSPENSE),
        "model_used": model_hits > 0,
    }
    return transactions, stats