iris-ir-platform / backend /train_document.py
rajvivan's picture
sync: push iris-ir-platform to HuggingFace Space
2a5d15a
Raw
History Blame Contribute Delete
65.3 kB
#!/usr/bin/env python3
from __future__ import annotations
"""
IRIS — IR Document Training Script
=====================================
Trains the system on any IR PDF to reduce retrieval errors and hallucination.
This script does two distinct kinds of "training" (A and B), plus anti-hallucination hardening (C) and ingestion (D):
─────────────────────────────────────────────────────────────────
A. RETRIEVAL TRAINING (always runs)
─────────────────────────────────────────────────────────────────
1. Parse PDF → extract text, tables, section headings
2. Auto-generate page → section → KPI ground-truth mapping
3. Build (query, correct_page) training pairs from the mapping
4. Calibrate retrieval scoring weights via grid search (MRR metric)
- metric_match_weight
- section_heading_boost
- caption_boost
- min_confidence_threshold
5. Save calibrated weights to data/retrieval_config.json
6. Re-embed text chunks with updated config in ChromaDB
7. Re-index tables with enriched metadata
─────────────────────────────────────────────────────────────────
B. EMBEDDING FINE-TUNING (optional flag --fine-tune)
─────────────────────────────────────────────────────────────────
8. Build contrastive training pairs:
Positive : (KPI query, correct page text chunk)
Hard negatives: (KPI query, wrong-section chunks)
9. Fine-tune BAAI/bge-small-en-v1.5 using
MultipleNegativesRankingLoss (sentence-transformers)
10. Save fine-tuned model to data/models/<doc_id>_embed/
11. Re-embed all chunks with the fine-tuned model
12. Re-index ChromaDB with new embeddings
─────────────────────────────────────────────────────────────────
C. ANTI-HALLUCINATION HARDENING (always runs, steps 13–16)
─────────────────────────────────────────────────────────────────
13. Extract EXACT numeric KPI values from verified tables
14. Save to data/kpi_ground_truth.json
15. Update generation agent to cross-check produced numbers
against ground truth before returning a response
16. Apply confidence threshold — below threshold returns
"Insufficient evidence" instead of a hallucinated answer
─────────────────────────────────────────────────────────────────
D. INGESTION PIPELINE (always runs)
─────────────────────────────────────────────────────────────────
17. Render PDF pages to PNG images
18. Generate ColPali visual embeddings (batch_size=1, MPS safe)
Usage:
python train_document.py --pdf ../documents/enbd_q1_2026.pdf
python train_document.py --pdf ../documents/enbd_q1_2026.pdf --fine-tune
python train_document.py --pdf ../documents/enbd_q1_2026.pdf --skip-colpali
"""
import argparse
import json
import logging
import math
import os
import re
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
# Make sibling modules next to this script importable when it is run directly
# (NOTE(review): presumably needed by project imports outside this view).
sys.path.insert(0, str(Path(__file__).parent))
# Root logging is configured at import time so every stage's progress is
# printed; this module logs under the "iris.train" name.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-8s %(name)s %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger("iris.train")
# ── Filesystem layout: every artefact lives under data/ next to this script ──
BASE_DIR = Path(__file__).parent
DATA_DIR = BASE_DIR / "data"
PAGES_DIR = DATA_DIR / "pages"                      # presumably rendered page PNGs (header step 17)
COLPALI_DIR = DATA_DIR / "colpali_index"            # presumably the ColPali index (header step 18)
CHROMA_DIR = DATA_DIR / "chroma"                    # presumably the ChromaDB store (header steps 6-7, 12)
TABLES_FILE = DATA_DIR / "tables.json"
DOCS_FILE = DATA_DIR / "documents.json"
PROC_DIR = DATA_DIR / "processed"
PAGEMAP_DIR = DATA_DIR / "page_maps"
MODELS_DIR = DATA_DIR / "models"                    # fine-tuned embedders go to <doc_id>_embed/
CONFIG_FILE = DATA_DIR / "retrieval_config.json"    # calibrated retrieval weights (header step 5)
KPI_GT_FILE = DATA_DIR / "kpi_ground_truth.json"    # extracted KPI ground truth (header step 14)
# ── Default retrieval weights (overridden by calibration) ────────────────────
# Baseline retrieval / generation settings.
# - calibrate_weights() returns fresh values for "metric_match_weight",
#   "section_heading_boost" and "min_confidence_threshold".
# - update_generation_agent() reads "ollama_temperature" and
#   "min_confidence_threshold" from the config it is handed.
# NOTE(review): readers of the remaining keys are not visible in this file.
DEFAULT_CONFIG = {
"embed_model": "BAAI/bge-small-en-v1.5",
"chunk_size": 300,
"chunk_overlap": 60,
"top_k_text": 6,
"top_k_tables": 4,
"top_k_visual": 3,
"metric_match_weight": 3.0,
"section_heading_boost": 4.0,
"caption_boost": 2.0,
"term_overlap_weight": 0.3,
"min_confidence_threshold": 0.20,    # calibration derives this from query/chunk similarity scores
"ollama_temperature": 0.05,
"ollama_num_predict": 2048,
"generation_grounding": True,
"reject_below_threshold": True,
}
# ── Section heading detection patterns ───────────────────────────────────────
# (regex, section) pairs tried IN ORDER by detect_section() against the
# lower-cased page text. The first pattern found anywhere on the page decides
# the section, so order matters: e.g. "liquidity coverage ratio" (Liquidity)
# must come before the more general "coverage ratio" (Asset Quality).
# No regex flags are used, so ".*" never crosses a line break.
# build_page_map() only uses this as a fallback, after SECTION_BY_PAGE and
# the page's own section_heading.
SECTION_PATTERNS = [
(r"funding.*liquidity", "Liquidity"),
(r"liquidity coverage ratio", "Liquidity"),
(r"advances.*deposit.*ratio", "Liquidity"),
(r"liquid assets.*aed", "Liquidity"),
(r"\blcr\b.*\badr\b", "Liquidity"),
(r"income statement", "Income Statement"),
(r"profit before tax", "Income Statement"),
(r"aed\s+[\d.]+bn.*profit", "Income Statement"),
(r"net interest margin", "Net Interest Margin"),
(r"margins remain", "Net Interest Margin"),
(r"non[- ]funded income", "Non-Funded Income"),
(r"loan growth", "Loans & Deposits"),
(r"gross loan", "Loans & Deposits"),
(r"deposit growth", "Loans & Deposits"),
(r"robust credit quality", "Asset Quality"),
(r"npl ratio", "Asset Quality"),
(r"coverage ratio", "Asset Quality"),
(r"cost of risk", "Asset Quality"),
(r"cost[- ]to[- ]income", "Cost to Income"),
(r"operating expense", "Cost to Income"),
(r"common equity tier", "Capital Adequacy"),
(r"cet[- ]?1.*ratio", "Capital Adequacy"),
(r"capital adequacy", "Capital Adequacy"),
(r"divisional performance", "Divisional Performance"),
(r"\besg\b", "ESG"),
(r"sustainability", "ESG"),
(r"gdp.*growth", "Economic Environment"),
(r"denizbank", "DenizBank / TΓΌrkiye"),  # NOTE(review): label looks like mis-decoded UTF-8 for "Türkiye"; confirm the file encoding
(r"hyperinflation", "DenizBank / TΓΌrkiye"),  # NOTE(review): same encoding issue as the entry above
(r"credit rating", "Credit Ratings"),
(r"investment case", "Investment Case"),
(r"financial results.*q", "Financial Appendix"),
]
# Fixed page-number -> section map. build_page_map() consults this BEFORE the
# page's own heading or SECTION_PATTERNS, so it overrides detection for every
# document processed.
# NOTE(review): the labels look specific to one deck (the ENBD Q1 2026
# presentation used in the usage examples); other PDFs probably need their
# own map.
SECTION_BY_PAGE = {
1: "Cover",
2: "Important Information",
3: "Economic Environment",
4: "Economic Environment",
5: "Economic Environment",
6: "Group Overview",
7: "Group Overview",
8: "International Presence",
9: "Credit Ratings",
10: "Shareholder Base",
11: "Investment Case",
12: "Peer Comparison",
13: "Profitability",
14: "Financial & Operating Performance",
15: "Executive Summary",
16: "Income Statement",
17: "Net Interest Margin",
18: "Non-Funded Income",
19: "Loans & Deposits",
20: "Asset Quality",
21: "Cost to Income",
22: "Liquidity",
23: "Capital Adequacy",
24: "Divisional Performance",
25: "ESG",
26: "ESG",
27: "ESG",
28: "ESG",
29: "Appendix",
30: "Financial Results",
31: "USD Translation",
32: "Hyperinflation",
33: "Turkey Macro",
34: "Egypt Macro",
35: "KSA Macro",
36: "Contact",
}
# Default KPI tags per section. build_page_map() copies these into a page's
# "kpis" unless curated slide-directory metadata exists for that page.
# Sections without an entry here (e.g. "Profitability", "ESG") get no tags.
SECTION_KPI_TAGS = {
"Income Statement": ["net profit", "revenue", "nim", "cor", "cost_income"],
"Net Interest Margin": ["nim"],
"Non-Funded Income": ["revenue"],
"Loans & Deposits": ["loans", "deposits"],
"Asset Quality": ["npl", "cor"],
"Cost to Income": ["cost_income"],
"Liquidity": ["lcr", "deposits"],
"Capital Adequacy": ["car"],
"Divisional Performance": ["net profit", "revenue", "nim", "cor", "npl"],
"Group Overview": ["net profit", "revenue", "deposits", "loans", "car"],
"Financial Appendix": ["net profit", "revenue", "nim", "cor"],
}
# ── Canonical queries per KPI for training pair generation ───────────────────
# generate_training_pairs() pairs every query listed under a section with the
# best chunk from that section's pages. Sections that do not occur in the
# page map are skipped.
KPI_CANONICAL_QUERIES = {
"Income Statement": [
"net profit", "profit before tax", "total income", "operating profit",
"group earnings", "net earnings", "profit after tax", "PAT",
"profit growth year on year", "Q1 2026 profitability",
],
"Net Interest Margin": [
"net interest margin", "NIM", "NIM trend", "interest margin performance",
"margin compression", "NII growth", "net interest income",
],
"Non-Funded Income": [
"non-funded income", "NFI", "fee income", "non-interest income",
"fee and commission income", "trading income",
],
"Loans & Deposits": [
"loan growth", "deposit growth", "advances growth", "loan book",
"credit growth", "total loans", "customer deposits",
],
"Asset Quality": [
"NPL ratio", "non-performing loans", "cost of risk", "credit quality",
"coverage ratio", "impairment charges", "provisions",
],
"Cost to Income": [
"cost to income", "cost income ratio", "operating efficiency",
"operating expenses", "CIR",
],
"Liquidity": [
"liquidity coverage ratio", "LCR", "advances to deposit ratio", "ADR",
"liquid assets", "funding structure", "liquidity position",
"what is the lcr", "lcr in q1 2026", "liquidity coverage ratio in q1 2026",
"how strong is the lcr", "lcr performance",
],
"Capital Adequacy": [
"capital adequacy ratio", "CAR", "CET1", "CET-1", "tier 1 ratio",
"capital position", "common equity tier 1", "regulatory capital",
],
"Divisional Performance": [
"divisional performance", "business segments", "retail banking",
"corporate banking", "DenizBank performance", "segment results",
],
}
# ═══════════════════════════════════════════════════════════════════════════════
# A. RETRIEVAL TRAINING
# ═══════════════════════════════════════════════════════════════════════════════
def detect_section(page_text: str) -> str:
    """
    Classify a page into a canonical section name.

    The page text is lower-cased and ``SECTION_PATTERNS`` is scanned in
    declaration order. The section of the first pattern found anywhere in
    the text is returned, or ``""`` when none matches.
    """
    haystack = page_text.lower()
    return next(
        (name for regex, name in SECTION_PATTERNS if re.search(regex, haystack)),
        "",
    )
def detect_slide_number(page_text: str) -> Optional[int]:
    """
    Return the slide number printed at the top of a page, or ``None``.

    Only the first four non-blank lines are inspected, each stripped of
    surrounding whitespace. The first one made up solely of one to three
    digits is converted with ``int`` and returned.
    """
    stripped = (raw.strip() for raw in page_text.split("\n"))
    head = [text for text in stripped if text][:4]
    # The full-match guarantees int() succeeds, so no error handling is needed.
    return next(
        (int(text) for text in head if re.fullmatch(r"\d{1,3}", text)),
        None,
    )
def load_slide_directory(path: Optional[Path] = None) -> dict[int, dict]:
    """
    Load the curated slide directory and index it by slide number.

    Reads ``data/slide_directory_index.json``, or *path* when given. The file
    is expected to be a JSON list of row objects with a ``"slide"`` field.

    Rows are keyed by ``int(slide)``. A row is skipped when:
    * it is not a JSON object, or
    * its slide value is not a plain decimal integer (surrounding whitespace
      is allowed).

    A missing file yields ``{}``.

    Args:
        path: Optional override for the index file location.

    Returns:
        Mapping of slide number -> row dict. Later rows win on duplicate
        numbers.
    """
    slide_dir_file = DATA_DIR / "slide_directory_index.json" if path is None else Path(path)
    if not slide_dir_file.exists():
        return {}
    with open(slide_dir_file, "r", encoding="utf-8") as f:
        rows = json.load(f)

    index: dict[int, dict] = {}
    for row in rows:
        if not isinstance(row, dict):
            continue
        slide = str(row.get("slide", "")).strip()
        # isdecimal(), not isdigit(): isdigit() accepts characters such as
        # superscript "²" that int() rejects with ValueError.
        if slide.isdecimal():
            index[int(slide)] = row
    return index
def split_slide_list(value: object) -> list[str]:
    """
    Split a comma-separated metadata field into trimmed, non-empty items.

    Accepts either form:
    * a string such as ``"nim, cor , npl"``;
    * a list or tuple of items, where each item may itself contain commas.

    ``None`` items are ignored. Falsy values yield ``[]``, and any other
    object is converted with ``str()`` first.

    Examples:
        >>> split_slide_list("nim, cor , ,npl")
        ['nim', 'cor', 'npl']
        >>> split_slide_list(["nim ", "cor,npl"])
        ['nim', 'cor', 'npl']
    """
    if isinstance(value, (list, tuple)):
        # Previously a list was str()-ed into "['a', 'b']" and split into junk
        # items like "['a'".
        value = ",".join(str(item) for item in value if item is not None)
    return [item.strip() for item in str(value or "").split(",") if item.strip()]
def _unique_terms(terms: list[str]) -> list[str]:
seen = set()
result = []
for term in terms:
cleaned = re.sub(r"\s+", " ", str(term or "")).strip()
if len(cleaned) < 3:
continue
key = cleaned.lower()
if key not in seen:
seen.add(key)
result.append(cleaned)
return result
def metadata_query_terms(info: dict) -> list[str]:
    """
    Turn one slide's curated metadata into retrieval training queries.

    Gathers the following from *info*:
    * the list-like fields ``kpis``, ``mapping_items``, ``synonyms`` and
      ``kpi_tags`` (string values are comma-split first);
    * the whole ``description``, plus every ``.``/``;``/``:``-separated
      fragment of it that is 12–140 characters long;
    * the ``title``, unless it is purely numeric.

    The collected terms go through ``_unique_terms`` for whitespace cleanup
    and case-insensitive de-duplication. This keeps training aligned with the
    slide directory rather than only the broad section labels.
    """
    collected: list[str] = []
    for field_name in ("kpis", "mapping_items", "synonyms", "kpi_tags"):
        raw = info.get(field_name, []) or []
        collected.extend(split_slide_list(raw) if isinstance(raw, str) else raw)

    description = str(info.get("description", "") or "").strip()
    if description:
        collected.append(description)
        collected.extend(
            piece
            for piece in (part.strip() for part in re.split(r"[.;:]", description))
            if 12 <= len(piece) <= 140
        )

    title = str(info.get("title", "") or "").strip()
    if title and re.match(r"^\d+$", title) is None:
        collected.append(title)

    return _unique_terms(collected)
def build_page_map(extracted) -> dict:
    """
    Build the per-page map that later stages use as retrieval ground truth.

    Each page in ``extracted.pages`` gets an entry with these keys:
    * ``section``: taken from ``SECTION_BY_PAGE``, else the page's
      ``section_heading`` attribute, else ``detect_section`` on its text.
    * ``slide``: the page's ``slide_number`` attribute, else the number
      printed at the top of the page, else the page number. Curated slide
      metadata overrides it.
    * ``kpis``: ``SECTION_KPI_TAGS`` for the section. When slide metadata
      exists, its ``kpis`` field replaces this, even if empty.
    * ``title``: the first line longer than 8 characters that is not purely
      numeric, truncated to 120 characters.
    * ``text_preview``: the first 200 characters with newlines flattened.

    Slide metadata is looked up by page number first, then by slide number.
    When found, it also adds ``period``, ``mapping_items`` (from its
    ``topics`` field), ``synonyms``, ``description`` and ``visual_layout``.

    Returns:
        dict keyed by ``page.page_number``.
    """
    directory = load_slide_directory()
    result: dict = {}
    for pg in extracted.pages:
        number = pg.page_number
        text = pg.text
        section = (
            SECTION_BY_PAGE.get(number)
            or getattr(pg, "section_heading", "")
            or detect_section(text)
        )
        slide = getattr(pg, "slide_number", None) or detect_slide_number(text) or number
        meta = directory.get(number) or directory.get(slide)

        # First reasonably long, non-numeric line serves as the page title.
        title = next(
            (
                line[:120]
                for line in (raw.strip() for raw in text.split("\n"))
                if len(line) > 8 and not re.match(r"^\d+$", line)
            ),
            "",
        )

        if meta:
            slide = int(meta.get("slide") or number)
            kpis = split_slide_list(meta.get("kpis", ""))
        else:
            kpis = list(SECTION_KPI_TAGS.get(section, []))

        entry = {
            "section": section,
            "slide": slide,
            "kpis": kpis,
            "title": title,
            "text_preview": text[:200].replace("\n", " ").strip(),
        }
        if meta:
            entry.update({
                "period": meta.get("period", ""),
                "mapping_items": split_slide_list(meta.get("topics", "")),
                "synonyms": split_slide_list(meta.get("synonyms", "")),
                "description": meta.get("description", ""),
                "visual_layout": meta.get("visual_layout", ""),
            })
        result[number] = entry
    return result
@dataclass
class TrainingPair:
    """
    One retrieval training example.

    Links a query to the page/chunk that should be retrieved for it, plus a
    few hard-negative chunks taken from other pages or sections.
    """

    query: str                  # query text
    positive_page: int          # page number of the positive chunk
    positive_text: str          # text of the positive chunk
    negative_pages: list[int]   # page numbers of the hard-negative chunks
    negative_texts: list[str]   # hard-negative chunk texts, same order as negative_pages
    section: str                # section the pair belongs to ("" when unknown)
    kpi_tag: str                # section lower-cased, spaces replaced by "_"
def _pick_positive_chunk(candidates: list):
    """Return the chunk with the most financial keywords; longer text breaks ties."""
    return max(
        candidates,
        key=lambda c: (
            len(getattr(c, "financial_keywords_found", []) or []),
            len(getattr(c, "text", "") or ""),
        ),
    )


def generate_training_pairs(page_map: dict, chunks: list) -> list[TrainingPair]:
    """
    Generate (query, positive chunk, hard-negative chunks) training pairs.

    Queries come from two sources:

    1. Section level. Every query in ``KPI_CANONICAL_QUERIES`` whose section
       appears in *page_map* is paired with the best chunk from that
       section's pages. Hard negatives are up to five chunks flagged
       ``has_financial_data`` from pages of other sections.
    2. Slide level. Every term ``metadata_query_terms`` derives for a page is
       paired with that page's best chunk. (query, page) combinations already
       present are skipped, comparing queries case-insensitively. Hard
       negatives are up to five financial chunks from other pages whose
       section differs. These pairs are tied to the exact slide, so queries
       such as ESG ratings or segment performance cannot drift to NIM.

    The "best" chunk is the one with the most ``financial_keywords_found``,
    with the longest text winning ties. Both sources use this same rule.

    Args:
        page_map: Page number (int or str key) -> info dict from
            ``build_page_map``.
        chunks: Text chunks exposing ``page_number`` and ``text``. The
            ``has_financial_data`` and ``financial_keywords_found``
            attributes are optional.

    Returns:
        List of ``TrainingPair`` in generation order.
    """
    chunks_by_page: dict[int, list] = {}
    for c in chunks:
        chunks_by_page.setdefault(c.page_number, []).append(c)

    pages_by_section: dict[str, list[int]] = {}
    for pg, info in page_map.items():
        sec = info.get("section")
        if sec:
            pages_by_section.setdefault(sec, []).append(int(pg))

    def page_info_for(page_number: int) -> dict:
        # page_map may be keyed by int (fresh) or str (reloaded from JSON).
        return page_map.get(page_number) or page_map.get(str(page_number)) or {}

    def is_financial(c) -> bool:
        return bool(getattr(c, "has_financial_data", False))

    def make_pair(query: str, positive, negatives: list, section: str) -> TrainingPair:
        # Fresh lists per pair so no two pairs share mutable state.
        return TrainingPair(
            query=query,
            positive_page=positive.page_number,
            positive_text=positive.text,
            negative_pages=[c.page_number for c in negatives],
            negative_texts=[c.text for c in negatives],
            section=section,
            kpi_tag=section.lower().replace(" ", "_"),
        )

    financial_chunks = [c for c in chunks if is_financial(c)]
    pairs: list[TrainingPair] = []

    # 1. Section-level pairs from the canonical KPI queries.
    for section, queries in KPI_CANONICAL_QUERIES.items():
        pos_chunks = [
            c
            for pg in pages_by_section.get(section, [])
            for c in chunks_by_page.get(pg, [])
        ]
        if not pos_chunks:
            continue
        hard_negatives = [
            c
            for other_sec, other_pages in pages_by_section.items()
            if other_sec != section
            for pg in other_pages
            for c in chunks_by_page.get(pg, [])
            if is_financial(c)
        ][:5]
        best_pos = _pick_positive_chunk(pos_chunks)  # loop-invariant: pick once
        pairs.extend(make_pair(q, best_pos, hard_negatives, section) for q in queries)

    # 2. Slide-level pairs from curated slide metadata.
    existing = {(p.query.lower(), p.positive_page) for p in pairs}
    for pg, info in page_map.items():
        page_chunks = chunks_by_page.get(int(pg), [])
        if not page_chunks:
            continue
        terms = metadata_query_terms(info)
        if not terms:
            continue
        best_pos = _pick_positive_chunk(page_chunks)
        section = info.get("section") or ""
        hard_negatives = [
            c for c in financial_chunks
            if c.page_number != best_pos.page_number
            and page_info_for(c.page_number).get("section") != section
        ][:5]
        for query in terms:
            key = (query.lower(), best_pos.page_number)
            if key in existing:
                continue
            existing.add(key)
            pairs.append(make_pair(query, best_pos, hard_negatives, section))

    logger.info(f" Generated {len(pairs)} training pairs from {len(pages_by_section)} sections")
    return pairs
def calibrate_weights(
    pairs: list[TrainingPair],
    chunks: list,
    embed_model_name: str = "BAAI/bge-small-en-v1.5",
) -> dict:
    """
    Calibrate retrieval scoring weights against the generated training pairs.

    Scoring model:
    * Each query scores each chunk as
      ``cosine_similarity + 0.1 * section_heading_boost``.
    * The boost applies only when the chunk's section heading and the pair's
      section are both non-empty and one contains the other.
    * Chunks are collapsed to a unique page order, and the reciprocal rank of
      the pair's positive page is averaged into a Mean Reciprocal Rank (MRR).
    * A positive page outside the top 10 counts as rank 99.
    * The boost with the highest MRR is chosen, the first one on ties.

    ``metric_match_weight`` plays no part in this scoring model. It is kept
    at its ``DEFAULT_CONFIG`` value rather than set to an arbitrary grid
    point (previously the first grid value, 1.0, always "won").

    The confidence threshold is the 10th percentile of each query's best
    chunk similarity, floored at 0.05, so about 90% of training queries
    clear it.

    Args:
        pairs: Training pairs from ``generate_training_pairs``.
        chunks: Text chunks exposing ``text``, ``page_number`` and
            ``section_heading``.
        embed_model_name: sentence-transformers model used for scoring.

    Returns:
        dict with ``metric_match_weight``, ``section_heading_boost`` and
        ``min_confidence_threshold``. With no pairs or no chunks the default
        values are returned unchanged.
    """
    best_config = {
        "metric_match_weight": DEFAULT_CONFIG["metric_match_weight"],
        "section_heading_boost": DEFAULT_CONFIG["section_heading_boost"],
        "min_confidence_threshold": DEFAULT_CONFIG["min_confidence_threshold"],
    }
    if not pairs or not chunks:
        # Previously this raised ZeroDivisionError / ValueError further down.
        logger.warning(" [Calibration] No training pairs or chunks - keeping default weights")
        return best_config

    import numpy as np
    from sentence_transformers import SentenceTransformer

    logger.info("\n [Calibration] Loading embedding model for weight calibration…")
    embed_model = SentenceTransformer(embed_model_name)
    chunk_embs = np.asarray(embed_model.encode(
        [c.text for c in chunks], normalize_embeddings=True, show_progress_bar=False,
    ))
    # Queries are encoded once; the grid search and threshold step both reuse them.
    query_embs = np.asarray(embed_model.encode(
        [p.query for p in pairs], batch_size=64, normalize_embeddings=True, show_progress_bar=False,
    ))
    # Embeddings are L2-normalised, so dot products are cosine similarities.
    sims = query_embs @ chunk_embs.T  # (num_pairs, num_chunks)

    chunk_pages = [c.page_number for c in chunks]
    headings = [(getattr(c, "section_heading", "") or "").lower() for c in chunks]

    def section_mask(section: str):
        # Both sides must be non-empty: "" is a substring of everything, so the
        # old check boosted every heading-less chunk for every query.
        target = (section or "").lower()
        return np.array(
            [bool(target and h and (target in h or h in target)) for h in headings],
            dtype=float,
        )

    masks = {sec: section_mask(sec) for sec in {p.section for p in pairs}}
    boost_mask = np.vstack([masks[p.section] for p in pairs])  # 1.0 where the boost applies
    positive_pages = [p.positive_page for p in pairs]

    def rank_of_positive(row_scores, positive_page) -> int:
        """1-based rank of *positive_page* among the top-10 unique pages, else 99."""
        seen: set = set()
        # Stable sort keeps chunk order on score ties.
        for idx in np.argsort(-row_scores, kind="stable"):
            page = chunk_pages[idx]
            if page in seen:
                continue
            if page == positive_page:
                return len(seen) + 1
            seen.add(page)
            if len(seen) >= 10:
                break
        return 99

    def compute_mrr(section_boost: float) -> float:
        scores = sims + (section_boost * 0.1) * boost_mask
        total = sum(
            1.0 / rank_of_positive(row, pos) for row, pos in zip(scores, positive_pages)
        )
        return total / len(pairs)

    logger.info(" [Calibration] Running grid search over weight combinations…")
    section_boosts = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0]
    best_mrr = 0.0
    for evaluated, sb in enumerate(section_boosts, start=1):
        mrr = compute_mrr(sb)
        if mrr > best_mrr:
            best_mrr = mrr
            best_config["section_heading_boost"] = sb
            logger.info(
                f" [Calibration] New best MRR={mrr:.4f} "
                f"(section_boost={sb}) "
                f"[{evaluated}/{len(section_boosts)}]"
            )

    logger.info(" [Calibration] Calibrating confidence threshold…")
    max_scores = sims.max(axis=1)  # best chunk similarity per query
    # 10th percentile keeps ~90% of real queries above the threshold.
    threshold = float(np.percentile(max_scores, 10))
    best_config["min_confidence_threshold"] = round(max(0.05, threshold), 3)
    logger.info(
        f"\n ✅ Calibration complete | MRR = {best_mrr:.4f} | "
        f"Threshold = {best_config['min_confidence_threshold']:.3f}"
    )
    return best_config
# ═══════════════════════════════════════════════════════════════════════════════
# B. EMBEDDING FINE-TUNING
# ═══════════════════════════════════════════════════════════════════════════════
def fine_tune_embeddings(
    pairs: list[TrainingPair],
    doc_id: str,
    base_model: str,
    epochs: int = 3,
    batch_size: int = 16,
) -> Path:
    """
    Fine-tune a sentence-transformers embedder on the IR training pairs.

    Example construction:
    * Each pair contributes a (query, positive_text) example.
    * When the pair has a section, it also contributes a
      ``"<section>: <query>"`` variant with the same positive.
    * Empty queries or positives are dropped, as are duplicate examples
      (same positive, same case-insensitive query).

    Training uses MultipleNegativesRankingLoss, so the other positives in a
    batch act as negatives. The pairs' explicit hard negatives are not used.

    NOTE(review): many pairs share one positive chunk, so a batch can hold
    the same positive twice, and MNRL then treats it as a negative.

    Args:
        pairs: Training pairs from ``generate_training_pairs``.
        doc_id: Names the output directory ``data/models/<doc_id>_embed/``.
        base_model: sentence-transformers model name or path to start from.
        epochs: Training epochs (default 3).
        batch_size: Examples per batch (default 16).

    Returns:
        Path of the saved model, or ``Path(base_model)`` when there is
        nothing to train on.
    """
    # Build the examples first, so an empty training set returns before the
    # heavy ML imports, the model download and the MODELS_DIR creation.
    texts: list[tuple[str, str]] = []
    seen: set[tuple[str, str]] = set()

    def add(query: str, positive: str) -> None:
        key = (query.strip().lower(), positive)
        if key[0] and key not in seen:
            seen.add(key)
            texts.append((query, positive))

    for pair in pairs:
        positive = pair.positive_text or ""
        if not positive.strip():
            continue
        add(pair.query, positive)
        if pair.section:
            # Section-prefixed variant; skipped when there is no section, which
            # previously produced junk queries like ": net profit".
            add(f"{pair.section}: {pair.query}", positive)

    if not texts:
        logger.warning(" No training examples generated — skipping fine-tuning")
        return Path(base_model)

    from sentence_transformers import SentenceTransformer, InputExample, losses
    from torch.utils.data import DataLoader

    model_out = MODELS_DIR / f"{doc_id}_embed"
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    logger.info(f"\n [Fine-tune] Loading base model: {base_model}")
    model = SentenceTransformer(base_model)

    examples = [InputExample(texts=[query, positive]) for query, positive in texts]
    logger.info(f" [Fine-tune] Training on {len(examples)} examples")
    dataloader = DataLoader(examples, shuffle=True, batch_size=batch_size)
    loss_fn = losses.MultipleNegativesRankingLoss(model)
    # Warm up over roughly the first 20% of one epoch's batches.
    warmup_steps = max(1, len(dataloader) // 5)
    model.fit(
        train_objectives=[(dataloader, loss_fn)],
        epochs=epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True,
        output_path=str(model_out),
        save_best_model=True,
    )
    logger.info(f" ✅ Fine-tuned model saved → {model_out}")
    return model_out
# ═══════════════════════════════════════════════════════════════════════════════
# C. ANTI-HALLUCINATION β€” KPI GROUND TRUTH EXTRACTION
# ═══════════════════════════════════════════════════════════════════════════════
# KPIs stored with unit "%". extract_kpi_ground_truth() never rescales these;
# every other KPI is normalised to AED billions ("bn").
PERCENTAGE_KPIS = {
"nim_percent", "npl_ratio", "coverage_ratio", "cet1_ratio", "lcr", "adr",
"cost_income_ratio", "casa_ratio", "gross_loans_yoy", "gross_loans_qoq",
"total_deposits_yoy", "total_deposits_qoq", "nfi_yoy"
}
# Label patterns for table rows. extract_kpi_ground_truth() lower-cases and
# strips each row's first cell and searches it case-insensitively.
# - The ^...$ anchors mean only exact labels qualify.
# - *_yoy / *_qoq keys reuse their base KPI's label. Their value is read from
#   a "%"-bearing change column (3 then 2 for YoY, 5 then 4 for QoQ) instead
#   of column 1.
KPI_TABLE_LABEL_PATTERNS = {
"net_profit_current": [
r"^(?:group net profit|net profit|profit|profit after tax)$",
],
"profit_before_tax": [
r"^(?:profit before tax)$",
],
"nim_percent": [
r"^(?:nim|net interest margin|net interest margin \(%\))$",
],
"npl_ratio": [
r"^(?:npl ratio|npl ratio \(%\))$",
],
"coverage_ratio": [
r"^(?:coverage ratio|npl coverage|npl coverage ratio)$",
],
"cet1_ratio": [
r"^(?:cet[- ]?1|cet[- ]?1 ratio|common equity tier 1 ratio)$",
],
"lcr": [
r"^(?:lcr|liquidity coverage ratio)$",
],
"adr": [
r"^(?:adr|advances to deposit ratio)$",
],
"cost_income_ratio": [
r"^(?:cost to income ratio|cost-to-income ratio)$",
],
"total_income": [
r"^(?:total income)$",
],
"total_assets": [
r"^(?:total assets)$",
],
# New KPIs
"gross_loans_total": [
r"^(?:total gross loans|gross loans)$",
],
"gross_loans_yoy": [
r"^(?:total gross loans|gross loans)$",
],
"gross_loans_qoq": [
r"^(?:total gross loans|gross loans)$",
],
"total_deposits": [
r"^(?:deposits|customer deposits|total deposits)$",
],
"total_deposits_yoy": [
r"^(?:deposits|customer deposits|total deposits)$",
],
"total_deposits_qoq": [
r"^(?:deposits|customer deposits|total deposits)$",
],
"nfi_total": [
r"^(?:total non-funded income|total non funded income|nfi)$",
],
"nfi_yoy": [
r"^(?:total non-funded income|total non funded income|nfi)$",
],
}
# Regex patterns that pull a KPI value out of free page text.
# extract_kpi_ground_truth() searches them case-insensitively, adding DOTALL
# for the segment-PBT patterns.
# - Group 1 is always the number. For monetary KPIs, group 2 is the unit.
# - The unit has no leading word-boundary anchor. In "AED 6.2bn" there is no
#   boundary between "2" and "bn", so with the anchor such figures were never
#   matched. A trailing boundary still stops "m" from matching "months".
KPI_TEXT_PATTERNS = {
"net_profit_current": [
r"\b(?:net profit|group net profit|profit after tax)\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
],
"profit_before_tax": [
r"\bprofit before tax\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
],
"nim_percent": [
r"\b(?:nim|net interest margin)\b[^\n]{0,30}?([\d.]+)\s*%",
],
"npl_ratio": [
r"\bnpl ratio\b[^\n]{0,30}?([\d.]+)\s*%",
r"\bnon[- ]performing loan ratio\b[^\n]{0,30}?([\d.]+)\s*%",
],
"coverage_ratio": [
r"\b(?:provision\s+)?coverage ratio\b[^\n]{0,30}?([\d.]+)\s*%",
],
"cet1_ratio": [
r"\b(?:cet[- ]?1|common equity tier 1)\s*(?:ratio)?\b[^\n]{0,30}?([\d.]+)\s*%",
],
"lcr": [
r"\b(?:lcr|liquidity coverage ratio)\b[^\n]{0,30}?([\d.]+)\s*%",
],
"adr": [
r"\b(?:adr|advances[- ]to[- ]deposit ratio)\b[^\n]{0,30}?([\d.]+)\s*%",
],
"cost_income_ratio": [
r"\bcost[- ]to[- ]income\s*(?:ratio)?\b[^\n]{0,30}?([\d.]+)\s*%",
r"\bcost\s+income\s+ratio\b[^\n]{0,30}?([\d.]+)\s*%",
],
"total_income": [
r"\btotal income\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
],
"total_assets": [
r"\btotal assets\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m|trillion)\b",
],
# New KPIs
"gross_loans_total": [
r"\b(?:gross loans|total gross loans)\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
],
"total_deposits": [
r"\b(?:deposits|customer deposits|total deposits)\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
],
"casa_ratio": [
r"\bcasa\s*(?:mix|ratio|stability ratio)?\b[^\n]{0,30}?([\d.]+)\s*%",
],
# Segment PBT figures carry no unit in the text; the extractor assumes
# millions for values above 10.
"retail_pbt": [
r"Retail Banking.*?PBT\s+([\d,.]+)",
],
"cib_pbt": [
r"Corporate and.*?PBT\s+([\d,.]+)",
],
"gmt_pbt": [
r"Global Markets.*?PBT\s+([\d,.]+)",
],
"denizbank_pbt": [
r"DenizBank.*?PBT\s+([\d,.]+)",
],
"nfi_total": [
r"\b(?:total non-funded income|total non funded income|nfi)\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
],
"nfi_yoy": [
r"non[- ]funded income,?\s+up\s+([\d.]+)\s*%\s*yoy",
r"non[- ]funded income up\s+([\d.]+)\s*%\s*yoy",
],
}
# Sections whose pages may supply each KPI. extract_kpi_ground_truth()
# compares a page's section label against these lists with exact,
# case-sensitive membership. The labels must therefore match SECTION_BY_PAGE /
# SECTION_PATTERNS verbatim: "Divisional Performance" was previously written
# "Divisional performance", so segment PBTs were never taken from that page.
KPI_ALLOWED_SECTIONS = {
"net_profit_current": ["Income Statement", "Group Overview", "Financial Appendix"],
"profit_before_tax": ["Income Statement", "Group Overview", "Financial Appendix"],
"nim_percent": ["Net Interest Margin", "Income Statement", "Group Overview", "Financial Appendix"],
"npl_ratio": ["Asset Quality", "Loans & Deposits", "Group Overview", "Financial Appendix"],
"coverage_ratio": ["Asset Quality", "Group Overview", "Financial Appendix"],
"cet1_ratio": ["Capital Adequacy", "Group Overview", "Financial Appendix"],
"lcr": ["Liquidity", "Group Overview", "Financial Appendix"],
"adr": ["Liquidity", "Group Overview", "Financial Appendix"],
"cost_income_ratio": ["Cost to Income", "Income Statement", "Group Overview", "Financial Appendix"],
"total_income": ["Income Statement", "Group Overview", "Financial Appendix"],
"total_assets": ["Loans & Deposits", "Liquidity", "Group Overview", "Financial Appendix", "Income Statement"],
# New KPIs
"gross_loans_total": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix"],
"gross_loans_yoy": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix"],
"gross_loans_qoq": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix"],
"total_deposits": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix", "Liquidity"],
"total_deposits_yoy": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix", "Liquidity"],
"total_deposits_qoq": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix", "Liquidity"],
"casa_ratio": ["Loans & Deposits", "Liquidity", "Group Overview", "Financial Appendix"],
"retail_pbt": ["Income Statement", "Group Overview", "Financial Appendix", "Asset Quality", "Divisional Performance"],
"cib_pbt": ["Income Statement", "Group Overview", "Financial Appendix", "Divisional Performance"],
"gmt_pbt": ["Income Statement", "Group Overview", "Financial Appendix", "Divisional Performance"],
"denizbank_pbt": ["Income Statement", "Group Overview", "Financial Appendix", "Hyperinflation", "Divisional Performance"],
"nfi_total": ["Income Statement", "Group Overview", "Financial Appendix", "Non-Funded Income"],
"nfi_yoy": ["Income Statement", "Group Overview", "Financial Appendix", "Non-Funded Income"],
}
def detect_table_scale(tbl, page_text: str) -> str:
    """
    Infer whether a table's monetary cells are in millions ("mn") or billions ("bn").

    The table's headers and caption are checked first, and a billion marker
    there wins over a million marker. Otherwise the page text decides, but
    only when it mentions millions and not billions. Ambiguous or unmarked
    tables default to "bn".

    Unit markers are recognised both as separate words ("AED mn") and glued
    to a number ("6.2bn", "500mn"). Only a preceding letter disqualifies a
    marker, so "column" is not read as "mn".

    Args:
        tbl: Table dict with optional "headers" (list or str) and "caption".
        page_text: Full text of the page containing the table.

    Returns:
        "mn" or "bn".
    """
    # A plain word-boundary anchor misses "6.2bn": digit -> letter is not a
    # word boundary. A negative look-behind for letters handles both forms.
    bn_re = r"(?<![a-z])(?:bn|billion)\b"
    mn_re = r"(?<![a-z])(?:mn|million)\b"

    headers = tbl.get("headers") or []
    if isinstance(headers, str):
        headers = [headers]
    table_text = (
        " ".join(str(h) for h in headers if h) + " " + str(tbl.get("caption") or "")
    ).lower()
    if re.search(bn_re, table_text):
        return "bn"
    if re.search(mn_re, table_text):
        return "mn"

    page_lower = (page_text or "").lower()
    has_bn = re.search(bn_re, page_lower) is not None
    has_mn = re.search(mn_re, page_lower) is not None
    if has_mn and not has_bn:
        return "mn"
    return "bn"
# Segment-level PBT KPIs are matched in prose without an explicit unit.
_SEGMENT_PBT_KPIS = frozenset({"retail_pbt", "cib_pbt", "gmt_pbt", "denizbank_pbt"})


def _section_for_page(page_map: dict, page_number) -> str:
    """Return the mapped section for *page_number* ("" if none); accepts int or str keys."""
    info = page_map.get(page_number) or page_map.get(str(page_number)) or {}
    return info.get("section", "") or ""


def _kpi_value_column(kpi_key: str, row) -> Optional[int]:
    """
    Return the row index holding *kpi_key*'s value, or None to skip the row.

    Level KPIs read column 1. ``*_yoy`` KPIs read the first of columns 3, 2
    whose cell contains "%"; ``*_qoq`` KPIs do the same with columns 5, 4.
    """
    if kpi_key.endswith("_yoy"):
        candidates = (3, 2)
    elif kpi_key.endswith("_qoq"):
        candidates = (5, 4)
    else:
        return 1
    for idx in candidates:
        if len(row) > idx and "%" in str(row[idx]):
            return idx
    return None


def _first_number(text: str) -> Optional[float]:
    """Return the first number in *text* (thousands commas allowed), or None.

    NOTE(review): the sign is dropped, so "(5.1%)" and "-5.1%" yield 5.1,
    as before. The raw cell text is kept alongside the value.
    """
    match = re.search(r"\d[\d,]*(?:\.\d+)?|\.\d+", text)
    return float(match.group(0).replace(",", "")) if match else None


def _cell_scale(cell: str) -> str:
    """Return "bn"/"mn" when *cell* carries its own unit right after a number, else ""."""
    low = cell.lower()
    if re.search(r"\d\s*(?:bn|billion|b)\b", low):
        return "bn"
    if re.search(r"\d\s*(?:mn|million|m)\b", low):
        return "mn"
    return ""


def _text_match_value(kpi_key: str, match) -> Optional[float]:
    """
    Convert a ``KPI_TEXT_PATTERNS`` match into a value in billions (or % as-is).

    Group 1 holds the number. If the pattern also captured a unit (group 2):
    * "mn"/"m"/"million" divide by 1000;
    * "trillion" multiplies by 1000;
    * "bn"/"b" keep the value.

    Unit-less monetary matches (the segment PBT patterns) keep the
    heuristic: assume millions when the value exceeds 10. Returns None when
    group 1 is not a valid number.
    """
    try:
        value = float(match.group(1).replace(",", ""))
    except ValueError:
        return None
    if kpi_key in PERCENTAGE_KPIS:
        return value

    unit_token = (match.group(2) or "").lower() if match.re.groups >= 2 else ""
    if unit_token in ("mn", "m", "million"):
        return value / 1000
    if unit_token == "trillion":
        return value * 1000
    if unit_token:
        return value  # "bn" / "b": already billions

    matched = match.group(0)
    looks_like_millions = (
        kpi_key in _SEGMENT_PBT_KPIS
        or "mn" in matched.lower()
        or re.search(r"\b(?:mn|million)\b", matched, re.I)
    )
    if looks_like_millions and value > 10.0:
        value = value / 1000
    return value


def extract_kpi_ground_truth(extracted, page_map: dict) -> dict:
    """
    Extract numeric KPI values from the pages' tables and text.

    The values serve as ground truth for cross-checking generated answers
    (module header, steps 13-15).

    Sources, in priority order:
    1. Table rows whose label cell matches ``KPI_TABLE_LABEL_PATTERNS``.
       Monetary cells are scaled to billions using the cell's own unit when
       present, else ``detect_table_scale`` for the table.
    2. Page text matching ``KPI_TEXT_PATTERNS``. This is used only for KPIs
       that no table supplied, on any page.

    A KPI is only taken from a page whose mapped section is listed for it in
    ``KPI_ALLOWED_SECTIONS``. The first value found for a KPI wins, both
    globally and per page.

    Args:
        extracted: Parsed document exposing ``doc_id``, ``doc_name`` and
            ``pages``. Each page has ``page_number``, ``text`` and ``tables``
            (dicts with "rows", "headers", "caption").
        page_map: Output of ``build_page_map`` (int or str page keys).

    Returns:
        dict with ``doc_id``, ``doc_name``, ``extracted`` (UTC timestamp),
        ``kpis`` and ``page_values``. ``kpis`` maps kpi_key to
        {value, unit, raw, page, section}. ``page_values`` maps a page string
        to {section, values} and contains only pages where something was
        found.
    """
    ground_truth = {
        "doc_id": extracted.doc_id,
        "doc_name": extracted.doc_name,
        "extracted": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "kpis": {},
        "page_values": {},
    }
    kpis = ground_truth["kpis"]
    page_values = ground_truth["page_values"]

    def record(kpi_key: str, page, section: str, value: float, unit: str, raw: str) -> None:
        kpi_data = {
            "value": round(value, 3),
            "unit": unit,
            "raw": raw,
            "page": page.page_number,
            "section": section,
        }
        # First value wins, both globally and per page.
        kpis.setdefault(kpi_key, kpi_data)
        page_values[str(page.page_number)]["values"].setdefault(kpi_key, kpi_data)

    # Pre-create a slot for every page that has a section.
    for page in extracted.pages:
        section = _section_for_page(page_map, page.page_number)
        if section:
            page_values[str(page.page_number)] = {"section": section, "values": {}}

    # 1. Tables first: high-priority structured data.
    for page in extracted.pages:
        section = _section_for_page(page_map, page.page_number)
        if not section:
            continue
        for tbl in getattr(page, "tables", None) or []:
            for row in tbl.get("rows", []) or []:
                if len(row) < 2:
                    continue
                label = str(row[0]).lower().strip()
                for kpi_key, patterns in KPI_TABLE_LABEL_PATTERNS.items():
                    if section not in KPI_ALLOWED_SECTIONS.get(kpi_key, []):
                        continue
                    if not any(re.search(pat, label, re.IGNORECASE) for pat in patterns):
                        continue
                    col_idx = _kpi_value_column(kpi_key, row)
                    if col_idx is None:
                        continue
                    cell_str = str(row[col_idx]).strip()
                    value = _first_number(cell_str)
                    if value is None:
                        continue
                    unit = "%" if kpi_key in PERCENTAGE_KPIS else "bn"
                    if unit != "%":
                        # An explicit unit in the cell beats the table-level guess;
                        # the old `"m" in cell` test fired on any letter "m".
                        scale = _cell_scale(cell_str) or detect_table_scale(tbl, page.text)
                        if scale == "mn":
                            value = value / 1000
                    record(kpi_key, page, section, value, unit, cell_str)

    # 2. Page text: low-priority fallback for KPIs still missing.
    for page in extracted.pages:
        section = _section_for_page(page_map, page.page_number)
        if not section:
            continue
        for kpi_key, patterns in KPI_TEXT_PATTERNS.items():
            if kpi_key in kpis:
                continue
            if section not in KPI_ALLOWED_SECTIONS.get(kpi_key, []):
                continue
            for pattern in patterns:
                # Segment PBT patterns span lines between the segment name and "PBT".
                dotall = "PBT" in pattern or kpi_key in _SEGMENT_PBT_KPIS
                flags = re.IGNORECASE | (re.DOTALL if dotall else 0)
                match = re.search(pattern, page.text, flags)
                if not match:
                    continue
                value = _text_match_value(kpi_key, match)
                if value is not None:
                    unit = "%" if kpi_key in PERCENTAGE_KPIS else "bn"
                    record(kpi_key, page, section, value, unit, match.group(0)[:80].strip())
                break  # first matching pattern decides, even if unparsable

    # Drop page slots where nothing was found.
    for key in [k for k, v in page_values.items() if not v["values"]]:
        del page_values[key]

    logger.info(f" Extracted {len(kpis)} ground truth KPI values")
    return ground_truth
# ═══════════════════════════════════════════════════════════════════════════════
# UPDATE GENERATION AGENT WITH CALIBRATED PARAMS
# ═══════════════════════════════════════════════════════════════════════════════
def update_generation_agent(config: dict, kpi_gt: dict):
    """
    Push calibrated parameters into the generation layer.

    Three updates are made:
      1. The default ``temperature: float = <n>`` in the first ``def __init__``
         line of services/generation/financial_analyst_agent.py is set to
         ``config["ollama_temperature"]`` (default 0.05).
      2. Every ``min_confidence_threshold = <n>`` assignment in that same file
         is set to ``config["min_confidence_threshold"]`` (default 0.20).
      3. ``verified_kpis`` / ``verified_kpis_updated`` in data/response_rules.json
         are overwritten with one line per known KPI found in ``kpi_gt`` — only
         when that JSON file already exists and at least one KPI was found.

    Args:
        config: Calibrated retrieval/generation config dict.
        kpi_gt: Ground truth for one document; ``kpi_gt["kpis"]`` maps KPI keys
            to dicts with ``value``, ``unit`` and ``page`` entries.

    Returns None. A missing agent file is logged and the function returns early;
    failures while updating response_rules.json are logged, not raised.
    """
    agent_path = BASE_DIR / "services" / "generation" / "financial_analyst_agent.py"
    if not agent_path.exists():
        logger.warning(f" Agent file not found: {agent_path}")
        return
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # can fail on non-ASCII characters in the agent source.
    with open(agent_path, "r", encoding="utf-8") as f:
        original_content = f.read()

    # 1. Update default temperature (first match only). No re.DOTALL is used,
    #    so `def __init__` and the `temperature: float = ...` parameter must be
    #    on the same line for this to match.
    new_temp = config.get("ollama_temperature", 0.05)
    content, temp_hits = re.subn(
        r"(def __init__.*?temperature:\s*float\s*=\s*)[\d.]+",
        lambda m: m.group(0).rsplit("=", 1)[0] + f"= {new_temp}",
        original_content,
        count=1,
    )
    if not temp_hits:
        logger.warning(" Default temperature not found in agent __init__; left unchanged")

    # 2. Update every confidence-threshold assignment.
    threshold = config.get("min_confidence_threshold", 0.20)
    content, threshold_hits = re.subn(
        r"(min_confidence_threshold\s*=\s*)[\d.]+",
        f"\\g<1>{threshold}",
        content,
    )
    if not threshold_hits:
        logger.warning(" min_confidence_threshold assignment not found in agent; left unchanged")

    # Only rewrite the agent source when a substitution actually changed it.
    if content != original_content:
        with open(agent_path, "w", encoding="utf-8") as f:
            f.write(content)

    # 3. Write verified KPI constraints to response_rules.json
    rules_file = BASE_DIR / "data" / "response_rules.json"
    kpi_label_map = {
        "net_profit_current": "Net Profit",
        "profit_before_tax": "Profit Before Tax",
        "nim_percent": "Net Interest Margin (NIM)",
        "npl_ratio": "NPL Ratio",
        "coverage_ratio": "Coverage Ratio",
        "cet1_ratio": "CET-1 Ratio",
        "lcr": "Liquidity Coverage Ratio (LCR)",
        "adr": "Advances-to-Deposit Ratio (ADR)",
        "cost_income_ratio": "Cost-to-Income Ratio",
        "total_income": "Total Income",
        "total_assets": "Total Assets",
        "gross_loans_total": "Total Gross Loans",
        "gross_loans_yoy": "Gross Loans YoY Change",
        "total_deposits": "Total Deposits",
        "total_deposits_yoy": "Total Deposits YoY Change",
        "casa_ratio": "CASA Ratio",
        "retail_pbt": "Retail Segment PBT",
        "cib_pbt": "CIB Segment PBT",
        "gmt_pbt": "GM&T Segment PBT",
        "denizbank_pbt": "DenizBank Segment PBT",
        "nfi_total": "Total Non-Funded Income",
        "nfi_yoy": "Non-Funded Income YoY Change",
    }
    kpis = kpi_gt.get("kpis", {})
    kpi_lines = []
    for key, label in kpi_label_map.items():
        gt = kpis.get(key)
        if not gt:
            continue
        unit = gt.get("unit", "")
        val = gt.get("value", "")
        pg = gt.get("page", "")
        # Percentages attach directly ("12.3%"); other units are space-separated;
        # an empty/None unit adds nothing (previously produced " None" / a stray space).
        if "%" in str(unit):
            unit_str = "%"
        elif unit:
            unit_str = f" {unit}"
        else:
            unit_str = ""
        kpi_lines.append(f" - {label}: {val}{unit_str} (Page {pg})")

    if kpi_lines and rules_file.exists():
        try:
            with open(rules_file, encoding="utf-8") as f:
                rules = json.load(f)
            rules["verified_kpis"] = kpi_lines
            rules["verified_kpis_updated"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
            # UTF-8 is required because ensure_ascii=False may emit non-ASCII text.
            with open(rules_file, "w", encoding="utf-8") as f:
                json.dump(rules, f, indent=2, ensure_ascii=False)
            logger.info(f" βœ“ {len(kpi_lines)} KPI constraints written to response_rules.json")
        except Exception as e:
            logger.warning(f" Could not update response_rules.json: {e}")
    logger.info(
        f" βœ“ Agent updated: temperature={new_temp}, "
        f"threshold={threshold}, "
        f"{len(kpi_lines)} KPI constraints injected"
    )
# ═══════════════════════════════════════════════════════════════════════════════
# TABLE INDEXING WITH ENRICHED METADATA
# ═══════════════════════════════════════════════════════════════════════════════
def enrich_and_index_tables(extracted, page_map: dict, doc_id: str, doc_name: str) -> list:
    """
    Build enriched table records for one document and merge them into TABLES_FILE.

    Every table on every parsed page becomes one record with page/slide/section
    metadata, the raw headers/rows/caption, a flat " | "-joined text rendering
    (caption, headers, first 30 rows), and ``metrics_found``. That field is the
    sorted union of metric keys whose aliases occur in the text and the KPI tags
    configured for the page's section in SECTION_KPI_TAGS.

    Records already stored for other documents are kept. Records previously
    stored for ``doc_id`` are replaced.

    Args:
        extracted: Parsed document exposing ``.pages``; each page has
            ``page_number`` and ``tables`` (dicts with optional ``headers``,
            ``rows`` and ``caption`` keys).
        page_map: page_number -> {"section": ..., "slide": ...}.
        doc_id: Document identifier stored on every record.
        doc_name: Human-readable name stored on every record.

    Returns:
        The records created for this document only.
    """
    from services.retrieval.table_retriever import FINANCIAL_METRIC_ALIASES

    def cell_text(value) -> str:
        # Empty cells may arrive as None (NOTE(review): confirm for the PDF
        # parser); render them as "" instead of the literal string "None".
        return "" if value is None else str(value)

    def detect_metrics(text: str, section: str) -> set:
        # Plain substring match of every alias against the lower-cased text +
        # section (assumes aliases are stored lower-case — TODO confirm).
        tl = (text + " " + section).lower()
        return {
            key
            for key, aliases in FINANCIAL_METRIC_ALIASES.items()
            if any(a in tl for a in aliases)
        }

    def table_to_text(tbl: dict) -> str:
        headers = tbl.get("headers") or []
        rows = tbl.get("rows") or []
        caption = tbl.get("caption", "")
        parts = []
        if caption:
            parts.append(f"Table: {caption}")
        if headers:
            # str() so numeric headers don't break join; falsy headers are skipped.
            parts.append(" | ".join(cell_text(h) for h in headers if h))
        # Only the first 30 rows are rendered into the text representation.
        for row in rows[:30]:
            parts.append(" | ".join(cell_text(c) for c in row))
        return "\n".join(parts)

    records = []
    for page in extracted.pages:
        pmap = page_map.get(page.page_number, {})
        section = pmap.get("section", "")
        slide = pmap.get("slide", page.page_number)
        sec_kpis = set(SECTION_KPI_TAGS.get(section, []))
        for tbl in page.tables:
            text_rep = table_to_text(tbl)
            text_metrics = detect_metrics(text_rep, section)
            records.append({
                "doc_id": doc_id,
                "doc_name": doc_name,
                "page_number": page.page_number,
                "slide_number": slide,
                "section_heading": section,
                "headers": tbl.get("headers", []),
                "rows": tbl.get("rows", []),
                "caption": tbl.get("caption", ""),
                "text_representation": text_rep,
                "metrics_found": sorted(text_metrics | sec_kpis),
            })

    # Merge — keep records for other docs
    existing = []
    if TABLES_FILE.exists():
        try:
            # Read with the same UTF-8 encoding used for writing below; the
            # locale default could fail on non-ASCII text and drop everything.
            with open(TABLES_FILE, encoding="utf-8") as f:
                loaded = json.load(f)
            existing = [r for r in loaded if r.get("doc_id") != doc_id]
        except Exception as e:
            logger.warning(
                f" Could not read existing {TABLES_FILE.name} ({e}); "
                f"tables of other documents will not be preserved"
            )
            existing = []
    all_records = existing + records
    TABLES_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(TABLES_FILE, "w", encoding="utf-8") as f:
        json.dump(all_records, f, indent=2, ensure_ascii=False)
    return records
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN PIPELINE
# ═══════════════════════════════════════════════════════════════════════════════
def run(pdf_path: Path, doc_id: str, skip_colpali: bool, do_fine_tune: bool):
    """
    Run the full IR-document training pipeline for one PDF.

    Steps (numbered as in the log output):
      1. Parse the PDF and save the extraction to PROC_DIR.
      2. Build the page -> section -> KPI map and save it to PAGEMAP_DIR.
      3. Extract KPI ground truth and merge it into KPI_GT_FILE under ``doc_id``.
      4. Drop the ChromaDB collection and chunk the page text.
      5. Generate training pairs and calibrate retrieval weights.
      6. Optionally fine-tune the embedding model, then save the config.
      7. Embed and index the chunks in ChromaDB.
      8. Index tables, update the generation agent, and build the smart
         response cache (cache failures are logged and non-fatal).
      9. Render page images and, unless skipped, build ColPali embeddings.
    Finally the document is (re-)registered in DOCS_FILE and a summary is logged.

    Args:
        pdf_path: Source PDF; the process exits with status 1 if it is missing.
        doc_id: Identifier used as the key in every output store.
        skip_colpali: Skip ColPali visual embeddings (pages are still rendered).
        do_fine_tune: Fine-tune the embedding model (only when pairs exist).
    """
    doc_name = " ".join(w.capitalize() for w in doc_id.replace("_", " ").split())
    total_start = time.time()
    logger.info("=" * 72)
    logger.info(" IRIS β€” IR Document Training Pipeline")
    logger.info(f" PDF : {pdf_path.name}")
    logger.info(f" Doc ID : {doc_id}")
    logger.info(f" Options : fine-tune={do_fine_tune}, skip-colpali={skip_colpali}")
    logger.info("=" * 72)
    if not pdf_path.exists():
        logger.error(f"PDF not found: {pdf_path}")
        sys.exit(1)

    # Load existing config or use defaults
    config = {**DEFAULT_CONFIG}
    if CONFIG_FILE.exists():
        with open(CONFIG_FILE, encoding="utf-8") as f:
            config.update(json.load(f))

    # ── 1. Parse PDF ──────────────────────────────────────────────────────────
    logger.info("\n[1/9] Parsing PDF…")
    from services.ingestion.pdf_parser import parse_pdf, save_extraction
    t0 = time.time()
    extracted = parse_pdf(pdf_path, doc_id, doc_name)
    save_extraction(extracted, PROC_DIR)
    logger.info(f" βœ“ {extracted.total_pages} pages | {time.time()-t0:.1f}s")

    # ── 2. Build page map ─────────────────────────────────────────────────────
    logger.info("\n[2/9] Building page β†’ section β†’ KPI mapping…")
    page_map = build_page_map(extracted)
    PAGEMAP_DIR.mkdir(parents=True, exist_ok=True)
    pagemap_file = PAGEMAP_DIR / f"{doc_id}_pagemap.json"
    with open(pagemap_file, "w", encoding="utf-8") as f:
        json.dump({
            "doc_id": doc_id, "doc_name": doc_name,
            "generated": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "note": "Auto-generated. Edit section/kpis and re-run to apply corrections.",
            "pages": {str(k): v for k, v in page_map.items()},
        }, f, indent=2, ensure_ascii=False)
    # Print mapping table
    logger.info(f" {'Page':>4} | {'Slide':>5} | {'Section':<28} | Title")
    logger.info(f" {'----':>4}-+-{'-----':>5}-+-{'-'*28}-+-{'-'*30}")
    for pg in sorted(page_map):
        info = page_map[pg]
        logger.info(
            f" P{pg:02d} | S{info['slide']:02d} | {info['section']:<28} | {info['title'][:40]}"
        )

    # ── 3. Extract KPI ground truth ───────────────────────────────────────────
    logger.info("\n[3/9] Extracting KPI ground truth values for hallucination prevention…")
    kpi_gt = extract_kpi_ground_truth(extracted, page_map)
    KPI_GT_FILE.parent.mkdir(parents=True, exist_ok=True)
    # Merge with existing ground truth
    existing_gt = {}
    if KPI_GT_FILE.exists():
        try:
            # Must match the UTF-8 used when writing below (ensure_ascii=False);
            # a locale-dependent decode failure here would otherwise silently
            # discard every other document's ground truth on the rewrite.
            with open(KPI_GT_FILE, encoding="utf-8") as f:
                existing_gt = json.load(f)
        except Exception as e:
            logger.warning(f" Could not read {KPI_GT_FILE.name}; starting fresh: {e}")
            existing_gt = {}
    if not isinstance(existing_gt, dict):
        logger.warning(f" Unexpected structure in {KPI_GT_FILE.name}; starting fresh")
        existing_gt = {}
    existing_gt[doc_id] = kpi_gt
    with open(KPI_GT_FILE, "w", encoding="utf-8") as f:
        json.dump(existing_gt, f, indent=2, ensure_ascii=False)
    logger.info(" Extracted KPI values:")
    for key, val in kpi_gt.get("kpis", {}).items():
        logger.info(f" {key:<25} = {val['value']} {val['unit']} (Page {val['page']}, {val['section']})")

    # ── 4. Chunk text ─────────────────────────────────────────────────────────
    logger.info("\n[4/9] Chunking text with section metadata…")
    import chromadb
    from chromadb.config import Settings
    # NOTE(review): this drops the entire collection, i.e. the chunks of every
    # previously trained document, whereas tables, KPI ground truth and the
    # docs registry are merged per document. Confirm whether only this
    # document's chunks should be removed.
    try:
        client = chromadb.PersistentClient(
            path=str(CHROMA_DIR), settings=Settings(anonymized_telemetry=False)
        )
        client.delete_collection("finbot_ir_chunks")
        logger.info(" βœ“ Old ChromaDB collection cleared")
    except Exception as e:
        # e.g. the collection does not exist yet; not fatal.
        logger.debug(f" ChromaDB collection not cleared: {e}")
    from services.ingestion.text_chunker import TextChunker
    t0 = time.time()
    chunker = TextChunker(
        chunk_size=config["chunk_size"],
        chunk_overlap=config["chunk_overlap"],
    )
    page_texts = []
    for p in extracted.pages:
        if not p.text.strip():
            continue
        pmap = page_map.get(p.page_number, {})
        page_texts.append({
            "page_number": p.page_number,
            "text": p.text,
            "section_heading": pmap.get("section", getattr(p, "section_heading", "") or ""),
            "slide_number": pmap.get("slide", getattr(p, "slide_number", None) or p.page_number),
        })
    chunks = chunker.chunk_document(page_texts, doc_id)
    for chunk in chunks:
        pmap = page_map.get(chunk.page_number, {}) or page_map.get(str(chunk.page_number), {})
        setattr(chunk, "section_heading", pmap.get("section", ""))
    logger.info(f" βœ“ {len(chunks)} chunks | {time.time()-t0:.1f}s")

    # ── 5. Generate training pairs & calibrate weights ────────────────────────
    logger.info("\n[5/9] Generating training pairs and calibrating retrieval weights…")
    pairs = generate_training_pairs(page_map, chunks)
    calibrated = calibrate_weights(pairs, chunks)
    # Merge calibrated weights into full config
    config.update(calibrated)
    config["last_calibrated"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    config["calibrated_on_doc"] = doc_id

    # ── 6. Optional: Fine-tune embedding model ────────────────────────────────
    embed_model_path = config["embed_model"]
    if do_fine_tune and pairs:
        logger.info("\n[6/9] Fine-tuning embedding model on domain data…")
        fine_tuned_path = fine_tune_embeddings(pairs, doc_id, config["embed_model"])
        embed_model_path = str(fine_tuned_path)
        config["embed_model"] = embed_model_path
        config["fine_tuned"] = True
    else:
        logger.info("\n[6/9] Skipping fine-tuning (use --fine-tune to enable)")
    # Save config BEFORE embedding so embedding uses correct model
    CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(CONFIG_FILE, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    logger.info(f" βœ“ Retrieval config saved β†’ {CONFIG_FILE.name}")

    # ── 7. Embed + index chunks ───────────────────────────────────────────────
    logger.info("\n[7/9] Embedding chunks and indexing in ChromaDB…")
    from services.retrieval.text_retriever import TextRetriever
    t0 = time.time()
    chunks = chunker.embed_chunks(chunks)
    text_retriever = TextRetriever(persist_dir=CHROMA_DIR)
    n_chunks = text_retriever.index_chunks(chunks)
    logger.info(f" βœ“ {n_chunks} chunks embedded and indexed | {time.time()-t0:.1f}s")

    # ── 8. Index tables ───────────────────────────────────────────────────────
    logger.info("\n[8/9] Indexing tables with enriched metadata…")
    table_records = enrich_and_index_tables(extracted, page_map, doc_id, doc_name)
    logger.info(f" βœ“ {len(table_records)} tables indexed")

    # ── Update generation agent ───────────────────────────────────────────────
    logger.info("\n Updating generation agent with calibrated config…")
    update_generation_agent(config, kpi_gt)

    # ── 8b. Auto-generate Smart Response Cache ────────────────────────────────
    # Pre-fills document-agnostic templates with real KPI values so the Smart
    # Response Engine can serve instant (<10ms) answers for any trained PDF.
    logger.info("\n[8b] Auto-generating smart response cache for new engine…")
    try:
        from services.classification.kpi_context_builder import KPIContextBuilder, INTENT_KPI_KEYS
        from services.generation.smart_response_engine import (
            SmartResponseEngine, save_cached_response, _INTENT_BUILDERS
        )
        _cache_builder = KPIContextBuilder(data_dir=DATA_DIR)
        cache_results: dict[str, str] = {}
        for intent in INTENT_KPI_KEYS:
            ctx = _cache_builder.build(intent=intent, doc_ids=[doc_id])
            if not ctx.has_data:
                cache_results[intent] = "β€” (no KPI data extracted)"
                continue
            builder_fn = _INTENT_BUILDERS.get(intent)
            if not builder_fn:
                cache_results[intent] = "β€” (no template builder)"
                continue
            try:
                resp = builder_fn(ctx)
                if resp:
                    resp["question"] = f"[auto-generated for intent: {intent}]"
                    resp["latency_ms"] = 0
                    saved = save_cached_response(doc_id, intent, resp)
                    cache_results[intent] = "βœ“" if saved else "βœ—"
                else:
                    cache_results[intent] = "β€” (no KPI match in template)"
            except Exception as be:
                # One failing template must not abort the others.
                cache_results[intent] = f"βœ— ({be})"
        for intent, status in cache_results.items():
            logger.info(f" {status} {intent}")
        n_cached = sum(1 for s in cache_results.values() if s == "βœ“")
        logger.info(f" βœ“ {n_cached}/{len(cache_results)} smart response templates generated")
        logger.info(f" Cached to: data/response_cache/{doc_id}/")
    except Exception as e:
        logger.warning(f" ⚠ Smart cache generation failed (non-fatal): {e}")

    # ── 9. Render pages + ColPali ─────────────────────────────────────────────
    logger.info("\n[9/9] Rendering PDF pages…")
    from services.ingestion.page_renderer import render_pdf_pages
    t0 = time.time()
    PAGES_DIR.mkdir(parents=True, exist_ok=True)
    page_records = render_pdf_pages(
        pdf_path=pdf_path,
        output_dir=PAGES_DIR,
        doc_id=doc_id,
    )
    logger.info(f" βœ“ {len(page_records)} pages rendered | {time.time()-t0:.1f}s")
    # Count only pages that actually received ColPali embeddings: previously the
    # rendered-page count was registered even when ColPali was skipped.
    colpali_pages = 0
    if not skip_colpali:
        logger.info(" Generating ColPali visual embeddings (batch_size=1)…")
        import torch
        if torch.backends.mps.is_available():
            os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
        from services.ingestion.colpali_indexer import ColPaliIndexer
        t0 = time.time()
        indexer = ColPaliIndexer(store_dir=COLPALI_DIR)
        page_records = indexer.index_pages(page_records, doc_id=doc_id, batch_size=1)
        colpali_pages = len(page_records)
        logger.info(f" βœ“ {colpali_pages} ColPali pages | {time.time()-t0:.1f}s")
    else:
        logger.info(" Skipping ColPali (--skip-colpali)")

    # ── Update documents registry ─────────────────────────────────────────────
    docs = []
    if DOCS_FILE.exists():
        with open(DOCS_FILE, encoding="utf-8") as f:
            docs = json.load(f)
    docs = [d for d in docs if d.get("doc_id") != doc_id]
    docs.append({
        "doc_id": doc_id,
        "name": doc_name,
        "doc_type": extracted.metadata.get("doc_type", "Financial Document"),
        "period": extracted.metadata.get("period", "Unknown"),
        "institution": extracted.metadata.get("institution", "Unknown"),
        "total_pages": extracted.total_pages,
        "status": "indexed",
        "filename": pdf_path.name,
        "chunks_indexed": n_chunks,
        "tables_indexed": len(table_records),
        "colpali_pages": colpali_pages,
        "pagemap_file": pagemap_file.name,
        "page_section_map": {str(k): v["section"] for k, v in page_map.items()},
        "retrieval_config": str(CONFIG_FILE.name),
    })
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    with open(DOCS_FILE, "w", encoding="utf-8") as f:
        json.dump(docs, f, indent=2)

    # ── Final report ──────────────────────────────────────────────────────────
    total_elapsed = time.time() - total_start
    mrr = calibrated.get("mrr", "β€”")
    logger.info("")
    logger.info("=" * 72)
    logger.info(f" βœ… Training complete in {total_elapsed:.1f}s")
    logger.info(f" {'─'*60}")
    logger.info(f" Text chunks indexed : {n_chunks}")
    logger.info(f" Tables indexed : {len(table_records)}")
    logger.info(f" KPI ground truths : {len(kpi_gt.get('kpis', {}))}")
    logger.info(f" Training pairs : {len(pairs)}")
    logger.info(f" Retrieval MRR : {mrr}")
    logger.info(f" Embedding model : {embed_model_path}")
    logger.info(f" Fine-tuned : {do_fine_tune}")
    logger.info(" Calibrated weights:")
    logger.info(f" metric_match : {config.get('metric_match_weight')}")
    logger.info(f" section_boost : {config.get('section_heading_boost')}")
    logger.info(f" confidence floor : {config.get('min_confidence_threshold')}")
    logger.info(f" LLM temperature : {config.get('ollama_temperature')}")
    logger.info(f" {'─'*60}")
    logger.info(f" Config saved to : data/{CONFIG_FILE.name}")
    logger.info(f" KPI ground truth : data/{KPI_GT_FILE.name}")
    logger.info(f" Page mapping : data/page_maps/{pagemap_file.name}")
    logger.info("=" * 72)
if __name__ == "__main__":
    # Command-line entry point: parse the options, derive a document ID when
    # none is given, then hand everything to run().
    parser = argparse.ArgumentParser(
        description="IRIS β€” Train on IR PDF, calibrate weights, reduce hallucination",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Standard training (parse + calibrate + index)
python train_document.py --pdf ../documents/enbd_q1_2026.pdf
# Full training including embedding fine-tuning (~20 min extra)
python train_document.py --pdf ../documents/enbd_q1_2026.pdf --fine-tune
# Fast mode β€” skip ColPali visual embeddings
python train_document.py --pdf ../documents/report.pdf --skip-colpali
# Custom document ID
python train_document.py --pdf ../documents/report.pdf --doc-id enbd_fy2025
""",
    )
    # (flag, add_argument keyword options) — registered in this exact order.
    cli_option_specs = (
        ("--pdf", {"required": True, "help": "Path to IR PDF"}),
        ("--doc-id", {"default": None, "help": "Document ID (derived from filename if omitted)"}),
        ("--skip-colpali", {"action": "store_true", "help": "Skip ColPali visual embedding"}),
        ("--fine-tune", {"action": "store_true", "help": "Fine-tune the embedding model on domain data"}),
    )
    for option_flag, option_kwargs in cli_option_specs:
        parser.add_argument(option_flag, **option_kwargs)
    args = parser.parse_args()
    pdf = Path(args.pdf).resolve()
    if args.doc_id:
        d_id = args.doc_id
    else:
        # File stem, lower-cased, with spaces and hyphens turned into underscores.
        d_id = pdf.stem.lower().replace(" ", "_").replace("-", "_")
    run(
        pdf_path=pdf,
        doc_id=d_id,
        skip_colpali=args.skip_colpali,
        do_fine_tune=args.fine_tune,
    )