Spaces:
Sleeping
Sleeping
| """Step D: categorize each transaction into a closed-set ledger head. | |
| Two layers: | |
| 1. Rules layer (free, deterministic, runs FIRST) -- keyword/regex -> category. | |
| Catches the obvious majority and cuts GPU time ~60%. | |
| 2. Model layer (optional, MiniCPM text model) -- only for rule-misses, batched. | |
| Anything the model returns off-list, or below the confidence floor, is mapped | |
| to "Suspense / Unclassified". The closed-set guarantee is enforced here. | |
| """ | |
| import json | |
| import os | |
| import re | |
| from constants import (CATEGORIES, CATEGORY_SET, CONFIDENCE_FLOOR, CONTRA, | |
| SUSPENSE, voucher_type_for) | |
# Hugging Face model id of the primary text model; can be overridden through
# the TEXT_MODEL_ID environment variable.
TEXT_MODEL_ID = os.environ.get("TEXT_MODEL_ID", "openbmb/MiniCPM3-4B")
# Model id that _load_text_model() tries when loading TEXT_MODEL_ID raises.
TEXT_FALLBACK_ID = "Qwen/Qwen3-4B-Instruct"
# Prompt template for the model layer, filled in with str.format():
#   {categories} -> JSON dump of CATEGORIES
#   {rows}       -> one line per transaction: <index>: "<narration>" (<direction>)
# The braces of the JSON example are doubled ({{ }}) so that format() emits
# them literally instead of treating them as placeholders.
CATEGORIZE_PROMPT = """You are an Indian accountant's assistant. Classify each bank transaction into EXACTLY ONE
category from this list: {categories}.
Use the narration text. Common Indian patterns:
- "UPI/..." person names -> likely Sales receipt (credit) or Purchases (debit)
- "NEFT/RTGS/IMPS" + company names -> Debtors/Creditors
- "ACH/NACH" + "LIC"/"SBILIFE" -> Insurance; + bank/finance names -> Loan EMI
- "GST"/"GSTIN"/"CBIC" -> GST Payment; "TDS"/"CPC" -> TDS/Income Tax
- "INT.PD"/"INT CREDIT" -> Interest; "CHRG"/"SMS CHGS"/"AMC" -> Bank Charges
- "ATM"/"CSH WDL" -> Cash Withdrawal; "CSH DEP" -> Cash Deposit
- Same-name transfers -> Contra
Output ONLY a JSON array: [{{"index": <i>, "category": "<from list>", "confidence": <0-1>}}].
If genuinely unsure, use "Suspense / Unclassified" with low confidence. Never invent a category
outside the list.
Transactions:
{rows}"""
| # --------------------------------------------------------------------------- # | |
| # Rules layer | |
| # --------------------------------------------------------------------------- # | |
| # Each rule: (compiled regex, category, confidence, direction) | |
| # direction: "debit", "credit", or None (any). Checked top to bottom; first | |
| # match wins, so put specific rules before generic ones. | |
def _r(pattern):
    """Compile *pattern* into a case-insensitive regular expression object."""
    compiled = re.compile(pattern, flags=re.I)
    return compiled
# Ordered rule table consumed by categorize_rules(). Each entry is
# (compiled regex, category, confidence, direction); the regex is searched
# anywhere in the narration, case-insensitively. Order matters: the first entry
# whose regex and direction both fit is returned.
_RULES = [
    # --- highly specific government / statutory ---
    (_r(r"\b(income\s*tax|advance\s*tax|cbdt)\b"), "Income Tax Payment", 0.95, None),
    (_r(r"\btds\b|\bcpc\b"), "TDS Payment", 0.95, None),
    (_r(r"\bgst\b|gstin|cbic|gstr"), "GST Payment", 0.95, None),
    # --- salary / payroll ---
    (_r(r"salary|payroll|wages\b"), "Salary & Wages", 0.95, None),
    # --- rent ---
    (_r(r"\brent\b"), "Office Rent", 0.9, None),
    # --- repairs & maintenance (before bank-charge AMC) ---
    (_r(r"repair|maintenance|servicing|service\s*charge\s*ac"), "Repairs & Maintenance", 0.88, None),
    # --- insurance ---
    # NOTE(review): this rule sits above the loan rule, so a narration containing
    # both "LIC" and "EMI" is returned as Insurance -- confirm that is intended.
    (_r(r"\blic\b|insurance|sbilife|sbi\s*life|premium|policy\b"), "Insurance", 0.9, None),
    # --- loan / emi ---
    (_r(r"\bemi\b|home\s*loan|\bloan\b|car\s*loan"), "Loan EMI", 0.9, None),
    # --- telephone & internet ---
    # NOTE(review): "mobile" has no word boundary, so it also matches inside
    # "automobile".
    (_r(r"broadband|internet|telephone|\bjio\b|airtel|\bbsnl\b|mobile|recharge|postpaid|landline"),
     "Telephone & Internet", 0.88, None),
    # --- electricity & utilities ---
    (_r(r"electric|kesco|kseb|\bbses\b|power\s*bill|utility|\bwater\s*bill"),
     "Electricity & Utilities", 0.88, None),
    # --- bank charges ---
    (_r(r"\bchrg|\bchgs|sms\s*chg|service\s*charge|\bamc\b|debit\s*card|annual\s*fee|bank\s*charge"),
     "Bank Charges", 0.9, None),
    # --- interest (direction decides received vs paid) ---
    (_r(r"int\.?\s*pd|int\s*credit|interest\s*credit|fixed\s*deposit\s*interest|fd\s*interest"),
     "Interest Received", 0.9, "credit"),
    (_r(r"interest\s*paid|overdraft|\bod\s*a/c|od\s*interest"), "Interest Paid", 0.9, "debit"),
    # Generic "interest" catch-alls at lower confidence, one per direction.
    (_r(r"interest"), "Interest Received", 0.7, "credit"),
    (_r(r"interest"), "Interest Paid", 0.7, "debit"),
    # --- cash ---
    (_r(r"csh\s*dep|cash\s*dep|cash\s*deposit"), "Cash Deposit", 0.92, None),
    (_r(r"\batm\b|csh\s*wdl|cash\s*wdl|cash\s*withdraw"), "Cash Withdrawal", 0.92, None),
    # --- contra (own-account transfer) ---
    # NOTE(review): "self\b" has no leading boundary, so any word ending in
    # "self" (e.g. "himself") matches.
    (_r(r"own\s*account|own\s*a/c|self\b|to\s*self|transfer.*self"), CONTRA, 0.9, None),
    # --- drawings / capital ---
    (_r(r"drawings|proprietor|personal\s*withdraw"), "Drawings", 0.9, "debit"),
    (_r(r"capital\s*introduc|capital\s*infus|partner.*capital"), "Capital Introduced", 0.9, "credit"),
    # --- professional fees ---
    (_r(r"prof\s*fees|professional|consult|\bca\s*fee|audit\s*fee"), "Professional Fees", 0.85, None),
    # --- travel & conveyance ---
    # NOTE(review): "rail" is unanchored and also matches inside "trail".
    (_r(r"irctc|travel|petrol|diesel|\bfuel\b|conveyance|\buber\b|\bola\b|rail|flight|hpcl|iocl|bpcl"),
     "Travel & Conveyance", 0.85, None),
]
# Generic bank-transfer direction fallback: money in -> Debtors, out -> Creditors.
# Matches the payment-rail keywords as whole words only.
_TRANSFER_HINT = _r(r"\bupi\b|\bneft\b|\brtgs\b|\bimps\b|\bach\b|\bnach\b")
def categorize_rules(txn):
    """Classify one transaction with the deterministic rule table.

    The transaction's direction is "debit" when its ``debit`` value is
    positive, otherwise "credit" when its ``credit`` value is positive,
    otherwise unknown. Rules are tried in table order; the first rule whose
    pattern occurs in the narration and whose direction requirement (if any)
    equals the transaction's direction is returned.

    When no rule fits but the narration names a payment rail (UPI, NEFT, ...),
    the direction alone selects a low-confidence debtor/creditor category.

    Returns:
        A ``(category, confidence)`` tuple, or ``None`` if nothing matched.
    """
    text = txn.get("narration") or ""
    money_out = txn.get("debit")
    money_in = txn.get("credit")

    if money_out and money_out > 0:
        flow = "debit"
    elif money_in and money_in > 0:
        flow = "credit"
    else:
        flow = None

    for pattern, category, confidence, required_flow in _RULES:
        # Skip direction-specific rules that do not apply to this transaction.
        if required_flow is not None and required_flow != flow:
            continue
        if pattern.search(text):
            return category, confidence

    # Generic transfer fallback (lower confidence -- a model could do better).
    if _TRANSFER_HINT.search(text):
        by_flow = {
            "credit": "Sales / Receipts from Debtors",
            "debit": "Purchases / Payments to Creditors",
        }
        fallback = by_flow.get(flow)
        if fallback is not None:
            return fallback, 0.6
    return None
| # --------------------------------------------------------------------------- # | |
| # Closed-set enforcement | |
| # --------------------------------------------------------------------------- # | |
def enforce_closed_set(category, confidence):
    """Guarantee a valid ``(category, confidence)`` pair.

    Args:
        category: Candidate category; may be any object (model output is
            untrusted), including ``None`` or an unhashable list/dict.
        confidence: Candidate confidence; anything ``float()`` accepts.
            Values that cannot be converted, and NaN, are treated as 0.0.
            The result is clamped to the range [0.0, 1.0].

    Returns:
        ``(category, confidence)`` when the category is in ``CATEGORY_SET``
        and the confidence reaches ``CONFIDENCE_FLOOR``; otherwise
        ``(SUSPENSE, confidence)``.
    """
    import math

    try:
        confidence = float(confidence)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an int too large to be represented as a float.
        confidence = 0.0
    # NaN compares False against everything, so min(1.0, nan) would return 1.0
    # and silently turn a garbage score into full confidence.
    if math.isnan(confidence):
        confidence = 0.0
    confidence = max(0.0, min(1.0, confidence))

    try:
        known = category in CATEGORY_SET
    except TypeError:
        # Unhashable value (e.g. the model returned a list) -- never on-list.
        known = False

    if not known or confidence < CONFIDENCE_FLOOR:
        return SUSPENSE, confidence
    return category, confidence
| # --------------------------------------------------------------------------- # | |
| # Model layer (optional) | |
| # --------------------------------------------------------------------------- # | |
# Process-wide cache for the lazily loaded text model: the model object, its
# tokenizer, and the id it was loaded from. All three stay None until
# _load_text_model() fills them in.
_TEXT = dict(model=None, tokenizer=None, id=None)
def text_model_available():
    """Report whether the optional model layer can run in this environment.

    Returns ``True`` only if ``torch`` and ``transformers`` both import and
    transformers reports torch as available. Any exception raised while
    probing is treated as "not available".
    """
    try:
        import torch  # noqa: F401
        from transformers.utils import is_torch_available

        available = bool(is_torch_available())
    except Exception:
        available = False
    return available
def _load_text_model():
    """Fill the module-level ``_TEXT`` cache with a causal LM and its tokenizer.

    Returns immediately when a model is already cached. Otherwise loads
    ``TEXT_MODEL_ID``; if anything in that load raises, loads
    ``TEXT_FALLBACK_ID`` instead. Errors from the fallback load propagate to
    the caller.
    """
    if _TEXT["model"] is not None:
        return

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def _fetch(repo_id):
        # trust_remote_code=True lets the model repository supply its own
        # modelling code.
        lm = AutoModelForCausalLM.from_pretrained(
            repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
        return lm, tokenizer

    chosen = TEXT_MODEL_ID
    try:
        lm, tokenizer = _fetch(chosen)
    except Exception:
        chosen = TEXT_FALLBACK_ID
        lm, tokenizer = _fetch(chosen)

    _TEXT.update(model=lm, tokenizer=tokenizer, id=chosen)
def _classify_with_model(batch):
    """Entry point for the model layer; delegates to _classify_with_model_impl.

    ``batch`` is a list of ``(index, narration, direction)`` tuples and the
    return value is ``{index: (category, confidence)}``.
    """
    return _classify_with_model_impl(batch)


# On Hugging Face ZeroGPU Spaces a GPU is attached only while a function
# wrapped by ``spaces.GPU`` is running, so the entry point has to be wrapped
# when the ``spaces`` package is present. Importing ``spaces`` without applying
# the decorator has no effect.
# NOTE(review): the bare decorator uses the default GPU duration -- pass
# ``duration=...`` if large statements need more time.
try:
    import spaces

    _classify_with_model = spaces.GPU(_classify_with_model)
except Exception:
    # Not running on Spaces (or the package is unusable): keep the plain
    # function so local / CPU runs behave as before.
    pass
def _as_index(raw):
    """Return a model-supplied ``index`` value as an int, or None if it is not one."""
    if isinstance(raw, bool):
        return None
    if isinstance(raw, int):
        return raw
    if isinstance(raw, float):
        # JSON may spell 3 as 3.0; NaN/inf/fractions are rejected.
        return int(raw) if raw.is_integer() else None
    if isinstance(raw, str):
        try:
            return int(raw)
        except ValueError:
            return None
    return None


def _classify_with_model_impl(batch, chunk_size=20, max_new_tokens=512):
    """Classify rule-misses with the text model.

    Args:
        batch: list of ``(index, narration, direction)`` tuples; ``direction``
            is "debit", "credit" or None (rendered as "?" in the prompt).
        chunk_size: number of transactions placed in one prompt. Must be a
            positive integer.
        max_new_tokens: generation budget per prompt.

    Returns:
        ``{index: (category, confidence)}`` holding the raw, unvalidated
        category and confidence the model gave. Only indices that were actually
        sent in the corresponding prompt are kept, so a hallucinated or
        malformed index cannot leak into the result. Transactions the model
        skipped are simply absent.
    """
    if not batch:
        # Nothing to classify: avoid importing torch / loading the model.
        return {}

    import torch
    _load_text_model()
    model, tok = _TEXT["model"], _TEXT["tokenizer"]
    if torch.cuda.is_available():
        model = model.to("cuda")
    model.eval()

    # The category list is identical for every chunk; serialize it once.
    categories_json = json.dumps(CATEGORIES, ensure_ascii=False)
    results = {}
    for start in range(0, len(batch), chunk_size):
        chunk = batch[start:start + chunk_size]
        sent = {idx for idx, _, _ in chunk}
        rows = "\n".join(f'{idx}: "{narr}" ({d or "?"})' for idx, narr, d in chunk)
        prompt = CATEGORIZE_PROMPT.format(categories=categories_json, rows=rows)
        msgs = [{"role": "user", "content": prompt}]
        text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        inputs = tok(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Drop the prompt tokens; decode only what the model generated.
        decoded = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        for item in _parse_model_categories(decoded):
            idx = _as_index(item["index"])
            if idx is None or idx not in sent:
                continue
            results[idx] = (item.get("category"), item.get("confidence", 0.0))
    return results
def _parse_model_categories(text):
    """Extract the prediction objects from raw model output.

    The model is asked for a bare JSON array but may wrap it in prose or
    markdown, or run out of tokens mid-array. Parsing therefore proceeds in
    two steps:

    1. Decode a JSON value at each ``[`` in turn and return the first array
       that contains at least one dict with an ``"index"`` key. Text after the
       array (even text containing ``]``) is ignored.
    2. If no such array exists (e.g. the output was truncated), salvage every
       complete top-level ``{...}`` object that has an ``"index"`` key.

    Args:
        text: decoded model output; ``None`` or ``""`` yields ``[]``.

    Returns:
        A list of dicts, each containing an ``"index"`` key. Empty when nothing
        usable was found.
    """
    text = text or ""
    decoder = json.JSONDecoder()

    def _values_starting_with(opener):
        # Yield each JSON value that begins at an `opener` character. After a
        # successful decode the scan resumes past that value, so structures
        # nested inside it are not re-examined.
        pos = text.find(opener)
        while pos != -1:
            try:
                value, end = decoder.raw_decode(text, pos)
            except (ValueError, RecursionError):
                end = pos + 1
            else:
                yield value
            pos = text.find(opener, end)

    for value in _values_starting_with("["):
        if isinstance(value, list):
            items = [d for d in value if isinstance(d, dict) and "index" in d]
            if items:
                return items

    return [v for v in _values_starting_with("{")
            if isinstance(v, dict) and "index" in v]
| # --------------------------------------------------------------------------- # | |
| # Orchestration | |
| # --------------------------------------------------------------------------- # | |
def categorize(transactions, use_model=True):
    """Categorize a list of transactions in place.

    Each transaction dict gains ``category``, ``confidence`` and
    ``voucher_type``. The rules layer runs first; transactions it cannot place
    start in Suspense and, when ``use_model`` is true and the text model is
    available, are offered to the model in one batch. Model predictions are
    applied only to those rule-misses -- a prediction carrying any other index
    is ignored, so the model can never overwrite a rule hit.

    A failure anywhere in the model layer is logged and otherwise ignored:
    affected transactions simply stay in Suspense.

    Args:
        transactions: list of dicts read via ``narration``, ``debit``, ``credit``.
        use_model: set to False to skip the model layer entirely.

    Returns:
        ``(transactions, stats)`` where ``stats`` has the keys ``total``,
        ``rule_hits``, ``model_hits``, ``suspense`` and ``model_used``.
    """
    import logging

    misses = []  # (index, narration, direction)
    rule_hits = 0
    for i, txn in enumerate(transactions):
        rule = categorize_rules(txn)
        if rule is not None:
            txn["category"], txn["confidence"] = enforce_closed_set(*rule)
            rule_hits += 1
            continue
        txn["category"], txn["confidence"] = SUSPENSE, 0.0
        # Same direction test as categorize_rules (amount must be positive),
        # so the hint given to the model agrees with what the rules saw.
        debit, credit = txn.get("debit"), txn.get("credit")
        if debit and debit > 0:
            direction = "debit"
        elif credit and credit > 0:
            direction = "credit"
        else:
            direction = None
        misses.append((i, txn.get("narration") or "", direction))

    model_hits = 0
    if use_model and misses and text_model_available():
        pending = {i for i, _, _ in misses}
        try:
            preds = _classify_with_model(misses)
            for idx, (cat, conf) in preds.items():
                # Model output is untrusted: accept only integer indices that
                # belong to a transaction we actually asked about.
                if not isinstance(idx, int) or idx not in pending:
                    continue
                c, cf = enforce_closed_set(cat, conf)
                transactions[idx]["category"] = c
                transactions[idx]["confidence"] = cf
                model_hits += 1
        except Exception:
            # Degrade: remaining rule-misses stay in Suspense, but leave a
            # trace so a broken model layer is not invisible.
            logging.getLogger(__name__).warning(
                "model categorization failed; rule-misses stay in Suspense",
                exc_info=True)

    for txn in transactions:
        txn["voucher_type"] = voucher_type_for(
            txn.get("debit"), txn.get("credit"), txn.get("category"))

    stats = {
        "total": len(transactions),
        "rule_hits": rule_hits,
        "model_hits": model_hits,
        "suspense": sum(1 for t in transactions if t["category"] == SUSPENSE),
        "model_used": model_hits > 0,
    }
    return transactions, stats