Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| """ | |
| IRIS — IR Document Training Script | |
| ===================================== | |
| Trains the system on any IR PDF to reduce retrieval errors and hallucination. | |
| This script does TWO distinct kinds of "training": | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| A. RETRIEVAL TRAINING (always runs) | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| 1. Parse PDF → extract text, tables, section headings | |
| 2. Auto-generate page → section → KPI ground-truth mapping | |
| 3. Build (query, correct_page) training pairs from the mapping | |
| 4. Calibrate retrieval scoring weights via grid search (MRR metric) | |
| - metric_match_weight | |
| - section_heading_boost | |
| - caption_boost | |
| - min_confidence_threshold | |
| 5. Save calibrated weights to data/retrieval_config.json | |
| 6. Re-embed text chunks with updated config in ChromaDB | |
| 7. Re-index tables with enriched metadata | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| B. EMBEDDING FINE-TUNING (optional flag --fine-tune) | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| 8. Build contrastive training pairs: | |
| Positive : (KPI query, correct page text chunk) | |
| Hard negatives: (KPI query, wrong-section chunks) | |
| 9. Fine-tune BAAI/bge-small-en-v1.5 using | |
| MultipleNegativesRankingLoss (sentence-transformers) | |
| 10. Save fine-tuned model to data/models/<doc_id>_embed/ | |
| 11. Re-embed all chunks with the fine-tuned model | |
| 12. Re-index ChromaDB with new embeddings | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| C. ANTI-HALLUCINATION HARDENING (always runs, steps 13–16) | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| 13. Extract EXACT numeric KPI values from verified tables | |
| 14. Save to data/kpi_ground_truth.json | |
| 15. Update generation agent to cross-check produced numbers | |
| against ground truth before returning a response | |
| 16. Apply confidence threshold — below threshold returns | |
| "Insufficient evidence" instead of a hallucinated answer | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| D. INGESTION PIPELINE (always runs) | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| 17. Render PDF pages to PNG images | |
| 18. Generate ColPali visual embeddings (batch_size=1, MPS safe) | |
| Usage: | |
| python train_document.py --pdf ../documents/enbd_q1_2026.pdf | |
| python train_document.py --pdf ../documents/enbd_q1_2026.pdf --fine-tune | |
| python train_document.py --pdf ../documents/enbd_q1_2026.pdf --skip-colpali | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import math | |
| import os | |
| import re | |
| import sys | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional | |
# Directory holding this script. It is pushed to the front of sys.path so that
# imports resolve relative to the script's own folder.
BASE_DIR = Path(__file__).parent
sys.path.insert(0, str(BASE_DIR))

# Process-wide logging setup; all training steps log through "iris.train".
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(name)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("iris.train")

# Every artefact read or written by the script lives under <script dir>/data.
DATA_DIR = BASE_DIR.joinpath("data")
PAGES_DIR = DATA_DIR.joinpath("pages")
COLPALI_DIR = DATA_DIR.joinpath("colpali_index")
CHROMA_DIR = DATA_DIR.joinpath("chroma")
TABLES_FILE = DATA_DIR.joinpath("tables.json")
DOCS_FILE = DATA_DIR.joinpath("documents.json")
PROC_DIR = DATA_DIR.joinpath("processed")
PAGEMAP_DIR = DATA_DIR.joinpath("page_maps")
MODELS_DIR = DATA_DIR.joinpath("models")
CONFIG_FILE = DATA_DIR.joinpath("retrieval_config.json")
KPI_GT_FILE = DATA_DIR.joinpath("kpi_ground_truth.json")
# -- Default retrieval weights (overridden by calibration) --------------------
# Baseline settings for chunking, retrieval scoring and generation. Insertion
# order is kept as-is so any serialised copy lists the keys in the same order.
DEFAULT_CONFIG = dict(
    # Embedding model and text chunking
    embed_model="BAAI/bge-small-en-v1.5",
    chunk_size=300,
    chunk_overlap=60,
    # Result counts per retrieval channel
    top_k_text=6,
    top_k_tables=4,
    top_k_visual=3,
    # Scoring weights
    metric_match_weight=3.0,
    section_heading_boost=4.0,
    caption_boost=2.0,
    term_overlap_weight=0.3,
    min_confidence_threshold=0.20,
    # Generation settings
    ollama_temperature=0.05,
    ollama_num_predict=2048,
    generation_grounding=True,
    reject_below_threshold=True,
)
# -- Section heading detection patterns ---------------------------------------
# Ordered (regex, section label) pairs. The patterns are matched against
# lower-cased page text and the FIRST hit wins, so order matters: for example
# "liquidity coverage ratio" must stay ahead of the generic "coverage ratio".
# The Türkiye label previously carried a mis-decoded "ü" (mojibake); it is
# spelled correctly here so it is stored and displayed properly.
SECTION_PATTERNS = [
    # Liquidity
    (r"funding.*liquidity", "Liquidity"),
    (r"liquidity coverage ratio", "Liquidity"),
    (r"advances.*deposit.*ratio", "Liquidity"),
    (r"liquid assets.*aed", "Liquidity"),
    (r"\blcr\b.*\badr\b", "Liquidity"),
    # Income statement
    (r"income statement", "Income Statement"),
    (r"profit before tax", "Income Statement"),
    (r"aed\s+[\d.]+bn.*profit", "Income Statement"),
    # Net interest margin
    (r"net interest margin", "Net Interest Margin"),
    (r"margins remain", "Net Interest Margin"),
    # Non-funded income
    (r"non[- ]funded income", "Non-Funded Income"),
    # Loans and deposits
    (r"loan growth", "Loans & Deposits"),
    (r"gross loan", "Loans & Deposits"),
    (r"deposit growth", "Loans & Deposits"),
    # Asset quality
    (r"robust credit quality", "Asset Quality"),
    (r"npl ratio", "Asset Quality"),
    (r"coverage ratio", "Asset Quality"),
    (r"cost of risk", "Asset Quality"),
    # Cost to income
    (r"cost[- ]to[- ]income", "Cost to Income"),
    (r"operating expense", "Cost to Income"),
    # Capital adequacy
    (r"common equity tier", "Capital Adequacy"),
    (r"cet[- ]?1.*ratio", "Capital Adequacy"),
    (r"capital adequacy", "Capital Adequacy"),
    # Remaining sections
    (r"divisional performance", "Divisional Performance"),
    (r"\besg\b", "ESG"),
    (r"sustainability", "ESG"),
    (r"gdp.*growth", "Economic Environment"),
    (r"denizbank", "DenizBank / Türkiye"),
    (r"hyperinflation", "DenizBank / Türkiye"),
    (r"credit rating", "Credit Ratings"),
    (r"investment case", "Investment Case"),
    (r"financial results.*q", "Financial Appendix"),
]
# Fixed page -> section lookup for pages 1..36. The tuple below is positional:
# entry N (1-based) is the section of page N, so keep it in page order.
SECTION_BY_PAGE = dict(
    enumerate(
        (
            "Cover",                              # 1
            "Important Information",              # 2
            "Economic Environment",               # 3
            "Economic Environment",               # 4
            "Economic Environment",               # 5
            "Group Overview",                     # 6
            "Group Overview",                     # 7
            "International Presence",             # 8
            "Credit Ratings",                     # 9
            "Shareholder Base",                   # 10
            "Investment Case",                    # 11
            "Peer Comparison",                    # 12
            "Profitability",                      # 13
            "Financial & Operating Performance",  # 14
            "Executive Summary",                  # 15
            "Income Statement",                   # 16
            "Net Interest Margin",                # 17
            "Non-Funded Income",                  # 18
            "Loans & Deposits",                   # 19
            "Asset Quality",                      # 20
            "Cost to Income",                     # 21
            "Liquidity",                          # 22
            "Capital Adequacy",                   # 23
            "Divisional Performance",             # 24
            "ESG",                                # 25
            "ESG",                                # 26
            "ESG",                                # 27
            "ESG",                                # 28
            "Appendix",                           # 29
            "Financial Results",                  # 30
            "USD Translation",                    # 31
            "Hyperinflation",                     # 32
            "Turkey Macro",                       # 33
            "Egypt Macro",                        # 34
            "KSA Macro",                          # 35
            "Contact",                            # 36
        ),
        start=1,
    )
)
# KPI tags associated with each section. Written as comma-separated strings
# for compactness; each entry is expanded into its own fresh list.
SECTION_KPI_TAGS = {
    section_name: tag_csv.split(",")
    for section_name, tag_csv in (
        ("Income Statement", "net profit,revenue,nim,cor,cost_income"),
        ("Net Interest Margin", "nim"),
        ("Non-Funded Income", "revenue"),
        ("Loans & Deposits", "loans,deposits"),
        ("Asset Quality", "npl,cor"),
        ("Cost to Income", "cost_income"),
        ("Liquidity", "lcr,deposits"),
        ("Capital Adequacy", "car"),
        ("Divisional Performance", "net profit,revenue,nim,cor,npl"),
        ("Group Overview", "net profit,revenue,deposits,loans,car"),
        ("Financial Appendix", "net profit,revenue,nim,cor"),
    )
}
# -- Canonical queries per KPI for training pair generation -------------------
# Keyed by section label; each query is paired with a chunk from that section
# when training pairs are generated.
KPI_CANONICAL_QUERIES = {
    "Income Statement": [
        "net profit",
        "profit before tax",
        "total income",
        "operating profit",
        "group earnings",
        "net earnings",
        "profit after tax",
        "PAT",
        "profit growth year on year",
        "Q1 2026 profitability",
    ],
    "Net Interest Margin": [
        "net interest margin",
        "NIM",
        "NIM trend",
        "interest margin performance",
        "margin compression",
        "NII growth",
        "net interest income",
    ],
    "Non-Funded Income": [
        "non-funded income",
        "NFI",
        "fee income",
        "non-interest income",
        "fee and commission income",
        "trading income",
    ],
    "Loans & Deposits": [
        "loan growth",
        "deposit growth",
        "advances growth",
        "loan book",
        "credit growth",
        "total loans",
        "customer deposits",
    ],
    "Asset Quality": [
        "NPL ratio",
        "non-performing loans",
        "cost of risk",
        "credit quality",
        "coverage ratio",
        "impairment charges",
        "provisions",
    ],
    "Cost to Income": [
        "cost to income",
        "cost income ratio",
        "operating efficiency",
        "operating expenses",
        "CIR",
    ],
    "Liquidity": [
        "liquidity coverage ratio",
        "LCR",
        "advances to deposit ratio",
        "ADR",
        "liquid assets",
        "funding structure",
        "liquidity position",
        "what is the lcr",
        "lcr in q1 2026",
        "liquidity coverage ratio in q1 2026",
        "how strong is the lcr",
        "lcr performance",
    ],
    "Capital Adequacy": [
        "capital adequacy ratio",
        "CAR",
        "CET1",
        "CET-1",
        "tier 1 ratio",
        "capital position",
        "common equity tier 1",
        "regulatory capital",
    ],
    "Divisional Performance": [
        "divisional performance",
        "business segments",
        "retail banking",
        "corporate banking",
        "DenizBank performance",
        "segment results",
    ],
}
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # A. RETRIEVAL TRAINING | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def detect_section(page_text: str) -> str:
    """
    Return the section label of the first SECTION_PATTERNS regex found in the
    lower-cased page text, or an empty string when no pattern matches.
    """
    haystack = page_text.lower()
    return next(
        (label for regex, label in SECTION_PATTERNS if re.search(regex, haystack)),
        "",
    )
def detect_slide_number(page_text: str) -> Optional[int]:
    """
    Look for a bare 1-3 digit number among the first four non-blank lines of
    the page text and return it as an int; return None when none is found.
    """
    non_blank = []
    for raw_line in page_text.split("\n"):
        stripped = raw_line.strip()
        if stripped:
            non_blank.append(stripped)

    for candidate in non_blank[:4]:
        if re.match(r"^\d{1,3}$", candidate) is None:
            continue
        try:
            return int(candidate)
        except ValueError:
            # Not convertible after all: keep scanning the remaining lines.
            continue
    return None
def load_slide_directory() -> dict[int, dict]:
    """
    Load data/slide_directory_index.json and index its rows by slide number.

    Returns an empty dict when the file does not exist. Rows whose "slide"
    value is not made of digits are skipped; when two rows share a slide
    number the later row wins.
    """
    index_path = DATA_DIR / "slide_directory_index.json"
    if not index_path.exists():
        return {}

    with open(index_path, "r", encoding="utf-8") as handle:
        entries = json.load(handle)

    by_slide: dict[int, dict] = {}
    for entry in entries:
        if str(entry.get("slide", "")).isdigit():
            by_slide[int(entry["slide"])] = entry
    return by_slide
def split_slide_list(value: "str | list | tuple | None") -> list[str]:
    """
    Normalise a slide-metadata field into a list of trimmed, non-empty strings.

    Args:
        value: Either a comma-separated string (the original contract), or an
            already-split list/tuple of items. ``None`` and other falsy scalars
            yield an empty list.

    Returns:
        The items in their original order with surrounding whitespace removed
        and blank entries dropped.
    """
    if isinstance(value, (list, tuple)):
        # Previously a list was passed through str(), producing fragments such
        # as "['a'" and "'b']". Treat each element as one item instead.
        items = ["" if item is None else str(item) for item in value]
    else:
        items = str(value or "").split(",")
    return [item.strip() for item in items if item.strip()]
def _unique_terms(terms: list[str]) -> list[str]:
    """
    Collapse whitespace in each term, drop terms shorter than three characters,
    and remove case-insensitive duplicates while keeping first-seen order and
    the first-seen spelling.
    """
    # dict preserves insertion order; setdefault keeps the first spelling.
    first_spelling: dict[str, str] = {}
    for raw in terms:
        normalised = re.sub(r"\s+", " ", str(raw or "")).strip()
        if len(normalised) >= 3:
            first_spelling.setdefault(normalised.lower(), normalised)
    return list(first_spelling.values())
def metadata_query_terms(info: dict) -> list[str]:
    """
    Turn one page-map / slide-metadata entry into retrieval training queries.

    Terms are gathered, in this order, from the "kpis", "mapping_items",
    "synonyms" and "kpi_tags" fields (comma-separated strings are split), then
    the full "description" plus each of its 12-140 character fragments split
    on '.', ';' or ':', and finally the "title" unless it is purely numeric.
    The combined list is de-duplicated via _unique_terms.
    """
    collected: list[str] = []

    for field in ("kpis", "mapping_items", "synonyms", "kpi_tags"):
        raw = info.get(field, []) or []
        collected.extend(split_slide_list(raw) if isinstance(raw, str) else raw)

    summary = str(info.get("description", "") or "").strip()
    if summary:
        collected.append(summary)
        pieces = [piece.strip() for piece in re.split(r"[.;:]", summary)]
        collected.extend(piece for piece in pieces if 12 <= len(piece) <= 140)

    heading = str(info.get("title", "") or "").strip()
    if heading and re.match(r"^\d+$", heading) is None:
        collected.append(heading)

    return _unique_terms(collected)
def build_page_map(extracted) -> dict:
    """
    Build the page_number -> metadata mapping for every page in extracted.pages.

    Each entry holds "section", "slide", "kpis", "title" and "text_preview".
    When the slide directory has a row for the page (looked up by page number,
    then by slide number), that row overrides the slide number and KPI list and
    adds "period", "mapping_items", "synonyms", "description" and
    "visual_layout".
    """
    directory = load_slide_directory()
    result = {}

    for page in extracted.pages:
        number = page.page_number
        body = page.text

        # Resolution order: fixed page table, parser-provided heading, regex.
        section = (
            SECTION_BY_PAGE.get(number)
            or getattr(page, "section_heading", "")
            or detect_section(body)
        )
        slide = (
            getattr(page, "slide_number", None)
            or detect_slide_number(body)
            or number
        )
        meta = directory.get(number) or directory.get(slide)

        # Title = first line longer than 8 characters that is not all digits.
        title = ""
        for raw_line in body.split("\n"):
            candidate = raw_line.strip()
            if len(candidate) > 8 and re.match(r"^\d+$", candidate) is None:
                title = candidate[:120]
                break

        kpis = list(SECTION_KPI_TAGS.get(section, []))
        extras = {}
        if meta:
            slide = int(meta.get("slide") or number)
            kpis = split_slide_list(meta.get("kpis", ""))
            extras = {
                "period": meta.get("period", ""),
                "mapping_items": split_slide_list(meta.get("topics", "")),
                "synonyms": split_slide_list(meta.get("synonyms", "")),
                "description": meta.get("description", ""),
                "visual_layout": meta.get("visual_layout", ""),
            }

        result[number] = {
            "section": section,
            "slide": slide,
            "kpis": kpis,
            "title": title,
            "text_preview": body[:200].replace("\n", " ").strip(),
            **extras,
        }

    return result
@dataclass
class TrainingPair:
    """
    One retrieval training example: a query, its positive chunk, and the hard
    negatives selected for it.

    The class must be a dataclass: the rest of the module constructs it with
    keyword arguments, which a plain class with bare annotations would reject
    with ``TypeError``.
    """

    # Query text used as the anchor.
    query: str
    # Page number of the chunk chosen as the positive example.
    positive_page: int
    # Text of the positive chunk.
    positive_text: str
    # Page numbers of the hard-negative chunks (parallel to negative_texts).
    negative_pages: list
    # Texts of the hard-negative chunks.
    negative_texts: list
    # Section label the pair was generated for.
    section: str
    # Section label lower-cased with spaces replaced by underscores.
    kpi_tag: str
def generate_training_pairs(page_map: dict, chunks: list) -> list[TrainingPair]:
    """
    Generate (query, positive_chunk, hard_negative_chunks) training pairs.

    Two passes are made:

    1. Section pass - for every section in KPI_CANONICAL_QUERIES that has
       pages in ``page_map``, each canonical query is paired with the chunk of
       that section carrying the most financial keywords. Negatives are
       financial chunks from other sections.
    2. Slide pass - for every page, the terms from ``metadata_query_terms`` are
       paired with that page's best chunk (most financial keywords, then
       longest text). Negatives are financial chunks from other pages whose
       section differs. (query, page) combinations already produced are
       skipped.

    At most five negatives are attached to each pair.

    Args:
        page_map: page number (int or str key) -> page metadata dict.
        chunks: chunk objects exposing ``page_number`` and ``text``;
            ``has_financial_data`` and ``financial_keywords_found`` are
            optional and treated as False / empty when missing or None.

    Returns:
        The list of TrainingPair objects, section-pass pairs first.
    """

    # Both passes now read chunk attributes through the same guarded helpers;
    # previously the section pass accessed them directly and could fail on
    # chunks lacking the attribute or carrying None.
    def is_financial(chunk) -> bool:
        return bool(getattr(chunk, "has_financial_data", False))

    def keyword_count(chunk) -> int:
        return len(getattr(chunk, "financial_keywords_found", None) or [])

    def section_of(info: dict) -> str:
        # A missing or None section is treated as "no section".
        return info.get("section") or ""

    def page_info_for(page_number: int) -> dict:
        # page_map may be keyed by int or by str page numbers.
        return page_map.get(page_number) or page_map.get(str(page_number)) or {}

    def make_pair(query: str, positive, negatives: list, section: str) -> TrainingPair:
        hard_negatives = negatives[:5]
        return TrainingPair(
            query=query,
            positive_page=positive.page_number,
            positive_text=positive.text,
            negative_pages=[c.page_number for c in hard_negatives],
            negative_texts=[c.text for c in hard_negatives],
            section=section,
            kpi_tag=section.lower().replace(" ", "_"),
        )

    # Group chunks by page
    chunks_by_page: dict[int, list] = {}
    for chunk in chunks:
        chunks_by_page.setdefault(chunk.page_number, []).append(chunk)

    # Group pages by section
    pages_by_section: dict[str, list[int]] = {}
    for pg, info in page_map.items():
        sec = section_of(info)
        if sec:
            pages_by_section.setdefault(sec, []).append(int(pg))

    # Shared pool of hard-negative candidates for the slide pass.
    all_financial_negatives = [c for c in chunks if is_financial(c)]

    pairs: list[TrainingPair] = []

    # ---- Pass 1: canonical queries per section -----------------------------
    for section, queries in KPI_CANONICAL_QUERIES.items():
        # Positive chunks = all chunks from pages in this section
        pos_chunks = []
        for pg in pages_by_section.get(section, []):
            pos_chunks.extend(chunks_by_page.get(pg, []))
        if not pos_chunks:
            continue

        # Hard negatives = financial chunks from every other section
        neg_chunks = []
        for other_sec, other_pages in pages_by_section.items():
            if other_sec == section:
                continue
            for pg in other_pages:
                neg_chunks.extend(c for c in chunks_by_page.get(pg, []) if is_financial(c))
        neg_chunks = neg_chunks[:10]

        # The best positive does not depend on the query, so pick it once
        # instead of re-running max() for every query.
        best_pos = max(pos_chunks, key=keyword_count)
        pairs.extend(make_pair(query, best_pos, neg_chunks, section) for query in queries)

    # ---- Pass 2: page-specific queries from curated slide metadata ---------
    # These pairs are tied to the exact slide, so queries such as ESG ratings,
    # green bond framework, or business segment performance cannot drift to NIM.
    existing = {(p.query.lower(), p.positive_page) for p in pairs}
    for pg, info in page_map.items():
        page_chunks = chunks_by_page.get(int(pg), [])
        if not page_chunks:
            continue
        terms = metadata_query_terms(info)
        if not terms:
            continue

        best_pos = max(
            page_chunks,
            key=lambda c: (keyword_count(c), len(getattr(c, "text", "") or "")),
        )
        section = section_of(info)
        neg_chunks = [
            c
            for c in all_financial_negatives
            if c.page_number != best_pos.page_number
            and section_of(page_info_for(c.page_number)) != section
        ][:10]

        for query in terms:
            key = (query.lower(), best_pos.page_number)
            if key in existing:
                continue
            existing.add(key)
            pairs.append(make_pair(query, best_pos, neg_chunks, section))

    logger.info(f" Generated {len(pairs)} training pairs from {len(pages_by_section)} sections")
    return pairs
def calibrate_weights(pairs: list[TrainingPair], chunks: list) -> dict:
    """
    Grid-search the section-heading boost and derive a confidence threshold.

    Each candidate boost is evaluated with Mean Reciprocal Rank (MRR): for
    every training pair the chunks are ranked by
    ``cosine_similarity + 0.1 * boost`` (the boost only applies to chunks whose
    section heading matches the pair's section), pages are de-duplicated in
    rank order, and the reciprocal rank of the pair's positive page is
    averaged over all pairs. A positive page outside the first 10 pages counts
    as rank 99.

    The confidence threshold is the 10th percentile of each query's best
    cosine similarity, floored at 0.05 and rounded to three decimals.

    Args:
        pairs: training pairs; ``query``, ``section`` and ``positive_page``
            are read.
        chunks: chunk objects; ``text``, ``page_number`` and
            ``section_heading`` are read.

    Returns:
        Dict with ``metric_match_weight``, ``section_heading_boost`` and
        ``min_confidence_threshold``. With no pairs or no chunks the defaults
        from DEFAULT_CONFIG are returned unchanged.
    """
    best_config = {
        "metric_match_weight": DEFAULT_CONFIG["metric_match_weight"],
        "section_heading_boost": DEFAULT_CONFIG["section_heading_boost"],
    }

    # Nothing to calibrate against: avoid loading the model and dividing by
    # zero when averaging reciprocal ranks.
    if not pairs or not chunks:
        logger.warning(" [Calibration] No training pairs or chunks - keeping default weights")
        best_config["min_confidence_threshold"] = DEFAULT_CONFIG["min_confidence_threshold"]
        return best_config

    import numpy as np
    from sentence_transformers import SentenceTransformer

    logger.info("\n [Calibration] Loading embedding model for weight calibrationβ¦")
    embed_model = SentenceTransformer("BAAI/bge-small-en-v1.5")

    # Encode chunks and queries exactly once. Cosine similarity does not
    # depend on the weights being searched, so the whole pair x chunk matrix
    # is computed up front and reused for every grid point and for the
    # threshold calibration.
    chunk_embs = np.asarray(
        embed_model.encode(
            [c.text for c in chunks],
            normalize_embeddings=True,
            show_progress_bar=False,
        )
    )
    pair_query_embs = np.asarray(
        embed_model.encode(
            [pair.query for pair in pairs],
            batch_size=64,
            normalize_embeddings=True,
            show_progress_bar=False,
        )
    )
    similarities = np.asarray(pair_query_embs @ chunk_embs.T, dtype=np.float64)

    chunk_pages = [c.page_number for c in chunks]
    chunk_headings = [(getattr(c, "section_heading", None) or "").lower() for c in chunks]

    # 1.0 where the chunk heading matches the pair's section, else 0.0.
    # Both sides must be non-empty: the empty string is a substring of every
    # string, so a chunk without a heading used to "match" every section and
    # always received the boost.
    mask_by_section: dict[str, "np.ndarray"] = {}

    def section_mask(section: str):
        wanted = (section or "").lower()
        if wanted not in mask_by_section:
            mask_by_section[wanted] = np.array(
                [
                    1.0 if wanted and heading and (wanted in heading or heading in wanted) else 0.0
                    for heading in chunk_headings
                ],
                dtype=np.float64,
            )
        return mask_by_section[wanted]

    def rank_of_positive(scores, positive_page) -> int:
        """Rank (1-based) of the positive page among the top-10 pages, else 99."""
        seen_pages: list = []
        # Stable sort keeps the original chunk order among equal scores.
        for idx in np.argsort(-scores, kind="stable"):
            page = chunk_pages[int(idx)]
            if page in seen_pages:
                continue
            seen_pages.append(page)
            if page == positive_page:
                return len(seen_pages)
            if len(seen_pages) >= 10:
                break
        return 99

    def compute_mrr(section_boost: float) -> float:
        rr_sum = 0.0
        for row, pair in enumerate(pairs):
            scores = similarities[row] + (section_boost * 0.1) * section_mask(pair.section)
            rr_sum += 1.0 / rank_of_positive(scores, pair.positive_page)
        return rr_sum / len(pairs)

    # Grid search.
    # NOTE: the scoring above has no term that depends on metric_match_weight,
    # so every value of it yields the same MRR. The former outer loop over
    # metric weights therefore repeated identical work and replaced the
    # default with the first grid value. Only the section boost is searched;
    # metric_match_weight keeps its default.
    logger.info(" [Calibration] Running grid search over weight combinationsβ¦")
    best_mrr = 0.0
    section_boosts = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0]
    total_combinations = len(section_boosts)
    mw = best_config["metric_match_weight"]
    for evaluated, sb in enumerate(section_boosts, start=1):
        mrr = compute_mrr(sb)
        if mrr > best_mrr:
            best_mrr = mrr
            best_config["section_heading_boost"] = sb
            logger.info(
                f" [Calibration] New best MRR={mrr:.4f} "
                f"(metric_w={mw}, section_boost={sb}) "
                f"[{evaluated}/{total_combinations}]"
            )

    # Confidence threshold from the distribution of each query's best score.
    logger.info(" [Calibration] Calibrating confidence thresholdβ¦")
    max_scores = similarities.max(axis=1)
    # 10th percentile keeps 90% of the training queries above the threshold.
    threshold = float(np.percentile(max_scores, 10))
    best_config["min_confidence_threshold"] = round(max(0.05, threshold), 3)

    logger.info(
        f"\n β Calibration complete | MRR = {best_mrr:.4f} | "
        f"Threshold = {best_config['min_confidence_threshold']:.3f}"
    )
    return best_config
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # B. EMBEDDING FINE-TUNING | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def fine_tune_embeddings(pairs: list[TrainingPair], doc_id: str, base_model: str) -> Path:
    """
    Fine-tune ``base_model`` on (query, positive chunk) examples with
    MultipleNegativesRankingLoss and save it to data/models/<doc_id>_embed/.

    Every pair with non-blank positive text contributes two examples: the raw
    query and a "<section>: <query>" variant, both against the same positive
    text. Only ``query``, ``section`` and ``positive_text`` are read from each
    pair; the pair's negative fields are not passed to the loss.

    Returns:
        The output directory of the fine-tuned model, or ``Path(base_model)``
        when no example could be built.
    """
    from sentence_transformers import SentenceTransformer, InputExample, losses
    from torch.utils.data import DataLoader

    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    model_out = MODELS_DIR / f"{doc_id}_embed"

    logger.info(f"\n [Fine-tune] Loading base model: {base_model}")
    model = SentenceTransformer(base_model)

    # Two anchors per pair: plain query, then the section-prefixed query.
    examples = []
    for pair in pairs:
        if not pair.positive_text.strip():
            continue
        for anchor in (pair.query, f"{pair.section}: {pair.query}"):
            examples.append(InputExample(texts=[anchor, pair.positive_text]))

    if not examples:
        logger.warning(" No training examples generated β skipping fine-tuning")
        return Path(base_model)

    logger.info(f" [Fine-tune] Training on {len(examples)} examples")
    dataloader = DataLoader(examples, shuffle=True, batch_size=16)
    objective = (dataloader, losses.MultipleNegativesRankingLoss(model))

    model.fit(
        train_objectives=[objective],
        epochs=3,
        # Warm up over a fifth of one epoch's batches, at least one step.
        warmup_steps=max(1, len(dataloader) // 5),
        show_progress_bar=True,
        output_path=str(model_out),
        save_best_model=True,
    )

    logger.info(f" β Fine-tuned model saved β {model_out}")
    return model_out
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # C. ANTI-HALLUCINATION β KPI GROUND TRUTH EXTRACTION | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# KPI keys listed as percentage-type values (ratios and growth rates).
PERCENTAGE_KPIS = set(
    (
        # Ratios
        "nim_percent",
        "npl_ratio",
        "coverage_ratio",
        "cet1_ratio",
        "lcr",
        "adr",
        "cost_income_ratio",
        "casa_ratio",
        # Growth rates
        "gross_loans_yoy",
        "gross_loans_qoq",
        "total_deposits_yoy",
        "total_deposits_qoq",
        "nfi_yoy",
    )
)
# Regex patterns matched against a table row's label (first cell) to decide
# which KPI the row represents. Every KPI currently has exactly one pattern;
# each is wrapped in its own list so more can be added per KPI. The *_yoy and
# *_qoq keys share the label of their base KPI.
KPI_TABLE_LABEL_PATTERNS = {
    kpi_key: [label_regex]
    for kpi_key, label_regex in (
        ("net_profit_current", r"^(?:group net profit|net profit|profit|profit after tax)$"),
        ("profit_before_tax", r"^(?:profit before tax)$"),
        ("nim_percent", r"^(?:nim|net interest margin|net interest margin \(%\))$"),
        ("npl_ratio", r"^(?:npl ratio|npl ratio \(%\))$"),
        ("coverage_ratio", r"^(?:coverage ratio|npl coverage|npl coverage ratio)$"),
        ("cet1_ratio", r"^(?:cet[- ]?1|cet[- ]?1 ratio|common equity tier 1 ratio)$"),
        ("lcr", r"^(?:lcr|liquidity coverage ratio)$"),
        ("adr", r"^(?:adr|advances to deposit ratio)$"),
        ("cost_income_ratio", r"^(?:cost to income ratio|cost-to-income ratio)$"),
        ("total_income", r"^(?:total income)$"),
        ("total_assets", r"^(?:total assets)$"),
        # New KPIs
        ("gross_loans_total", r"^(?:total gross loans|gross loans)$"),
        ("gross_loans_yoy", r"^(?:total gross loans|gross loans)$"),
        ("gross_loans_qoq", r"^(?:total gross loans|gross loans)$"),
        ("total_deposits", r"^(?:deposits|customer deposits|total deposits)$"),
        ("total_deposits_yoy", r"^(?:deposits|customer deposits|total deposits)$"),
        ("total_deposits_qoq", r"^(?:deposits|customer deposits|total deposits)$"),
        ("nfi_total", r"^(?:total non-funded income|total non funded income|nfi)$"),
        ("nfi_yoy", r"^(?:total non-funded income|total non funded income|nfi)$"),
    )
}
# Regex patterns that pull KPI values out of free page text.
#
# Amount patterns capture (number, unit); percentage and PBT patterns capture
# the number only.
#
# The unit group used to be preceded by `\b`. A word boundary cannot exist
# between a digit and a letter, so amounts written without a space ("AED 6.2bn")
# could never match - only "AED 6.2 bn" did. The leading `\b` is removed; the
# trailing `\b` still prevents matching the first letter of a longer word
# (e.g. "12 basis points").
KPI_TEXT_PATTERNS = {
    "net_profit_current": [
        r"\b(?:net profit|group net profit|profit after tax)\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
    ],
    "profit_before_tax": [
        r"\bprofit before tax\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
    ],
    "nim_percent": [
        r"\b(?:nim|net interest margin)\b[^\n]{0,30}?([\d.]+)\s*%",
    ],
    "npl_ratio": [
        r"\bnpl ratio\b[^\n]{0,30}?([\d.]+)\s*%",
        r"\bnon[- ]performing loan ratio\b[^\n]{0,30}?([\d.]+)\s*%",
    ],
    "coverage_ratio": [
        r"\b(?:provision\s+)?coverage ratio\b[^\n]{0,30}?([\d.]+)\s*%",
    ],
    "cet1_ratio": [
        r"\b(?:cet[- ]?1|common equity tier 1)\s*(?:ratio)?\b[^\n]{0,30}?([\d.]+)\s*%",
    ],
    "lcr": [
        r"\b(?:lcr|liquidity coverage ratio)\b[^\n]{0,30}?([\d.]+)\s*%",
    ],
    "adr": [
        r"\b(?:adr|advances[- ]to[- ]deposit ratio)\b[^\n]{0,30}?([\d.]+)\s*%",
    ],
    "cost_income_ratio": [
        r"\bcost[- ]to[- ]income\s*(?:ratio)?\b[^\n]{0,30}?([\d.]+)\s*%",
        r"\bcost\s+income\s+ratio\b[^\n]{0,30}?([\d.]+)\s*%",
    ],
    "total_income": [
        r"\btotal income\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
    ],
    "total_assets": [
        r"\btotal assets\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m|trillion)\b",
    ],
    # New KPIs
    "gross_loans_total": [
        r"\b(?:gross loans|total gross loans)\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
    ],
    "total_deposits": [
        r"\b(?:deposits|customer deposits|total deposits)\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
    ],
    "casa_ratio": [
        r"\bcasa\s*(?:mix|ratio|stability ratio)?\b[^\n]{0,30}?([\d.]+)\s*%",
    ],
    "retail_pbt": [
        r"Retail Banking.*?PBT\s+([\d,.]+)",
    ],
    "cib_pbt": [
        r"Corporate and.*?PBT\s+([\d,.]+)",
    ],
    "gmt_pbt": [
        r"Global Markets.*?PBT\s+([\d,.]+)",
    ],
    "denizbank_pbt": [
        r"DenizBank.*?PBT\s+([\d,.]+)",
    ],
    "nfi_total": [
        r"\b(?:total non-funded income|total non funded income|nfi)\b[^\n]{0,30}?(?:aed\s*)?([\d,.]+)\s*(bn|mn|b|m)\b",
    ],
    "nfi_yoy": [
        r"non[- ]funded income,?\s+up\s+([\d.]+)\s*%\s*yoy",
        r"non[- ]funded income up\s+([\d.]+)\s*%\s*yoy",
    ],
}
# Sections a KPI may be extracted from. Membership is tested with an exact,
# case-sensitive string comparison against the page's section label, so every
# name here must be spelled exactly like the labels used in SECTION_BY_PAGE /
# SECTION_PATTERNS. The divisional PBT entries previously used
# "Divisional performance" (lower-case "p"), which can never equal the
# "Divisional Performance" label and so silently excluded those pages.
KPI_ALLOWED_SECTIONS = {
    "net_profit_current": ["Income Statement", "Group Overview", "Financial Appendix"],
    "profit_before_tax": ["Income Statement", "Group Overview", "Financial Appendix"],
    "nim_percent": ["Net Interest Margin", "Income Statement", "Group Overview", "Financial Appendix"],
    "npl_ratio": ["Asset Quality", "Loans & Deposits", "Group Overview", "Financial Appendix"],
    "coverage_ratio": ["Asset Quality", "Group Overview", "Financial Appendix"],
    "cet1_ratio": ["Capital Adequacy", "Group Overview", "Financial Appendix"],
    "lcr": ["Liquidity", "Group Overview", "Financial Appendix"],
    "adr": ["Liquidity", "Group Overview", "Financial Appendix"],
    "cost_income_ratio": ["Cost to Income", "Income Statement", "Group Overview", "Financial Appendix"],
    "total_income": ["Income Statement", "Group Overview", "Financial Appendix"],
    "total_assets": ["Loans & Deposits", "Liquidity", "Group Overview", "Financial Appendix", "Income Statement"],
    # New KPIs
    "gross_loans_total": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix"],
    "gross_loans_yoy": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix"],
    "gross_loans_qoq": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix"],
    "total_deposits": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix", "Liquidity"],
    "total_deposits_yoy": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix", "Liquidity"],
    "total_deposits_qoq": ["Loans & Deposits", "Income Statement", "Group Overview", "Financial Appendix", "Liquidity"],
    "casa_ratio": ["Loans & Deposits", "Liquidity", "Group Overview", "Financial Appendix"],
    "retail_pbt": ["Income Statement", "Group Overview", "Financial Appendix", "Asset Quality", "Divisional Performance"],
    "cib_pbt": ["Income Statement", "Group Overview", "Financial Appendix", "Divisional Performance"],
    "gmt_pbt": ["Income Statement", "Group Overview", "Financial Appendix", "Divisional Performance"],
    "denizbank_pbt": ["Income Statement", "Group Overview", "Financial Appendix", "Hyperinflation", "Divisional Performance"],
    "nfi_total": ["Income Statement", "Group Overview", "Financial Appendix", "Non-Funded Income"],
    "nfi_yoy": ["Income Statement", "Group Overview", "Financial Appendix", "Non-Funded Income"],
}
def detect_table_scale(tbl, page_text: str) -> str:
    """
    Detect whether table cells represent Millions ("mn") or Billions ("bn").

    The table's own headers and caption are checked first (billions before
    millions). If they give no hint, the page text decides, but only when it
    mentions exactly one of the two scales. Anything else falls back to "bn".

    Args:
        tbl: table dict; the optional "headers" (list) and "caption" values
            are read. Missing or None values are treated as empty.
        page_text: text of the page the table sits on (None is treated as "").

    Returns:
        "bn" or "mn".
    """
    # A unit counts when it ends at a word boundary and is not glued to a
    # preceding letter/underscore. Unlike a plain leading `\b`, this also
    # accepts a unit written directly after a digit ("6.2bn", "650mn").
    billions = r"(?<![^\W\d])(?:bn|billions?)\b"
    millions = r"(?<![^\W\d])(?:mn|millions?)\b"

    headers_str = " ".join(str(h) for h in (tbl.get("headers") or []) if h).lower()
    caption_str = str(tbl.get("caption") or "").lower()
    combined = headers_str + " " + caption_str
    if re.search(billions, combined):
        return "bn"
    if re.search(millions, combined):
        return "mn"

    # Fall back to the surrounding page text.
    page_text_lower = (page_text or "").lower()
    has_bn = re.search(billions, page_text_lower) is not None
    has_mn = re.search(millions, page_text_lower) is not None
    if has_mn and not has_bn:
        return "mn"
    # Only billions mentioned, both mentioned, or neither: default to "bn".
    return "bn"
def _kpi_value_column(kpi_key: str, row):
    """Return the index of the cell in *row* holding *kpi_key*'s value, or None.

    Plain KPIs read column 1. ``*_yoy`` KPIs read column 3 (else 2) and
    ``*_qoq`` KPIs read column 5 (else 4), but only if that cell contains a
    ``%`` sign; otherwise there is no usable column and None is returned.
    """
    if kpi_key.endswith("_yoy"):
        candidates = (3, 2)
    elif kpi_key.endswith("_qoq"):
        candidates = (5, 4)
    else:
        return 1
    for idx in candidates:
        if len(row) > idx and "%" in str(row[idx]):
            return idx
    return None


def _store_kpi_value(ground_truth: dict, kpi_key: str, kpi_data: dict) -> None:
    """Record *kpi_data* globally and for its page; the first value seen wins."""
    ground_truth["kpis"].setdefault(kpi_key, kpi_data)
    page_values = ground_truth["page_values"][str(kpi_data["page"])]["values"]
    page_values.setdefault(kpi_key, kpi_data)


def extract_kpi_ground_truth(extracted, page_map: dict) -> dict:
    """
    Extract exact numeric KPI values from verified table and text pages.

    These values act as ground truth — any generated response that contradicts
    them is flagged or overridden.

    Tables are scanned first; page text is only used as a fallback for KPIs
    that no table produced. Pages without a section in *page_map* are ignored,
    and a KPI is only accepted from sections listed for it in
    ``KPI_ALLOWED_SECTIONS``.

    Args:
        extracted: Parsed document; ``doc_id``, ``doc_name`` and ``pages``
            (each with ``page_number``, ``text`` and ``tables``) are read.
        page_map: Mapping of page number to page info; only the ``"section"``
            entry is used here.

    Returns:
        Dict with ``doc_id``, ``doc_name``, ``extracted`` (UTC timestamp),
        ``kpis`` (first value found per KPI) and ``page_values`` (per-page
        values, keyed by the page number as a string).
    """
    segment_pbt_keys = ("retail_pbt", "cib_pbt", "gmt_pbt", "denizbank_pbt")
    # A millions marker must be a standalone unit token (e.g. "1,234mn",
    # "12 m", "3 million"), not merely any word containing the letter "m".
    million_suffix = r"(?<![a-z])(?:mn|m|million)\b"

    ground_truth = {
        "doc_id": extracted.doc_id,
        "doc_name": extracted.doc_name,
        "extracted": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "kpis": {},
        "page_values": {},
    }

    def section_of(page) -> str:
        return page_map.get(page.page_number, {}).get("section", "")

    # Ensure page_values has section structures ready
    for page in extracted.pages:
        section = section_of(page)
        if section:
            ground_truth["page_values"][str(page.page_number)] = {
                "section": section,
                "values": {},
            }

    # 1. EXTRACT FROM TABLES FIRST (High priority structured data)
    for page in extracted.pages:
        section = section_of(page)
        if not section:
            continue

        # The section filter is identical for every row on the page.
        page_kpis = [
            (kpi_key, patterns)
            for kpi_key, patterns in KPI_TABLE_LABEL_PATTERNS.items()
            if section in KPI_ALLOWED_SECTIONS.get(kpi_key, [])
        ]
        if not page_kpis:
            continue

        for tbl in page.tables:
            table_scale = None  # resolved lazily, at most once per table
            for row in tbl.get("rows", []):
                if len(row) < 2:
                    continue
                label = str(row[0]).lower().strip()

                for kpi_key, patterns in page_kpis:
                    # One matching label pattern is enough; the same row may
                    # still feed other KPIs (e.g. the *_yoy / *_qoq columns).
                    if not any(re.search(pat, label, re.IGNORECASE) for pat in patterns):
                        continue

                    col_idx = _kpi_value_column(kpi_key, row)
                    if col_idx is None:
                        continue

                    cell_str = str(row[col_idx]).strip()
                    # Strip brackets/symbols for clean numeric extraction.
                    # NOTE(review): bracketed or minus-signed cells therefore
                    # lose their sign — confirm that magnitude-only values
                    # are intended for the change (%) KPIs.
                    cell_str_clean = (
                        cell_str.replace("(", "").replace(")", "").replace("%", "").strip()
                    )
                    num_match = re.search(r"([\d,.]+)", cell_str_clean)
                    if not num_match:
                        continue
                    try:
                        val = float(num_match.group(1).replace(",", ""))
                    except ValueError:
                        # e.g. a lone "." or "," — not a number, skip the cell.
                        continue

                    unit = "%" if kpi_key in PERCENTAGE_KPIS else "bn"
                    # Normalize millions to billions
                    if unit != "%":
                        if table_scale is None:
                            table_scale = detect_table_scale(tbl, page.text)
                        if table_scale == "mn" or re.search(million_suffix, cell_str.lower()):
                            val = val / 1000

                    _store_kpi_value(
                        ground_truth,
                        kpi_key,
                        {
                            "value": round(val, 3),
                            "unit": unit,
                            "raw": cell_str,
                            "page": page.page_number,
                            "section": section,
                        },
                    )

    # 2. EXTRACT FROM PAGE TEXT (Low priority fallback for missing values)
    for page in extracted.pages:
        section = section_of(page)
        if not section:
            continue
        full_text = page.text

        for kpi_key, patterns in KPI_TEXT_PATTERNS.items():
            if kpi_key in ground_truth["kpis"]:
                continue
            if section not in KPI_ALLOWED_SECTIONS.get(kpi_key, []):
                continue

            is_segment_pbt = kpi_key in segment_pbt_keys
            for pattern in patterns:
                flags = re.DOTALL if "PBT" in pattern or is_segment_pbt else 0
                match = re.search(pattern, full_text, re.IGNORECASE | flags)
                if not match:
                    continue
                try:
                    raw_val = match.group(1).replace(",", "")
                    unit = "%" if kpi_key in PERCENTAGE_KPIS else "bn"
                    value = float(raw_val)
                    # Normalize text PBT values or other millions values to
                    # billions; values of 10 or less are left untouched.
                    matched_text = match.group(0)
                    looks_like_millions = (
                        is_segment_pbt
                        or "mn" in matched_text.lower()
                        or re.search(r"\b(?:mn|million)\b", matched_text, re.I)
                    )
                    if unit != "%" and looks_like_millions and value > 10.0:
                        value = value / 1000
                    _store_kpi_value(
                        ground_truth,
                        kpi_key,
                        {
                            "value": round(value, 3),
                            "unit": unit,
                            "raw": matched_text[:80].strip(),
                            "page": page.page_number,
                            "section": section,
                        },
                    )
                except ValueError:
                    pass
                # The first matching pattern decides, even if it did not parse.
                break

    # Clean up empty page_values keys
    empty_pages = [k for k, v in ground_truth["page_values"].items() if not v["values"]]
    for k in empty_pages:
        del ground_truth["page_values"][k]

    logger.info(f" Extracted {len(ground_truth['kpis'])} ground truth KPI values")
    return ground_truth
# ──────────────────────────────────────────────────────────────────────────────
# UPDATE GENERATION AGENT WITH CALIBRATED PARAMS
# ──────────────────────────────────────────────────────────────────────────────
def update_generation_agent(config: dict, kpi_gt: dict):
    """
    Update financial_analyst_agent.py with calibrated parameters:
    1. Temperature from calibrated config
    2. Confidence threshold update
    3. KPI constraints written into response_rules.json

    The agent source file is rewritten in place via regex substitution. If it
    does not exist a warning is logged and nothing else is done. The rules
    file is only touched when it already exists and at least one KPI value is
    available; failures there are logged and do not propagate.

    Args:
        config: Calibrated config; ``ollama_temperature`` (default 0.05) and
            ``min_confidence_threshold`` (default 0.20) are read.
        kpi_gt: Ground-truth dict whose ``"kpis"`` mapping supplies the
            ``value``, ``unit`` and ``page`` of each verified KPI.
    """
    agent_path = BASE_DIR / "services" / "generation" / "financial_analyst_agent.py"
    if not agent_path.exists():
        logger.warning(f" Agent file not found: {agent_path}")
        return

    # Python sources are UTF-8 by default; do not depend on the locale.
    with open(agent_path, "r", encoding="utf-8") as f:
        content = f.read()

    # 1. Update default temperature (first ``__init__`` signature only).
    # NOTE(review): without re.DOTALL this only matches when ``def __init__``
    # and the ``temperature`` parameter share a line — confirm the agent's
    # signature layout.
    new_temp = config.get("ollama_temperature", 0.05)
    content = re.sub(
        r"(def __init__.*?temperature:\s*float\s*=\s*)[\d.]+",
        lambda m: m.group(0).rsplit("=", 1)[0] + f"= {new_temp}",
        content,
        count=1,
    )

    # 2. Update confidence threshold (every assignment in the file).
    threshold = config.get("min_confidence_threshold", 0.20)
    content = re.sub(
        r"(min_confidence_threshold\s*=\s*)[\d.]+",
        f"\\g<1>{threshold}",
        content,
    )

    with open(agent_path, "w", encoding="utf-8") as f:
        f.write(content)

    # 3. Write verified KPI constraints to response_rules.json
    rules_file = BASE_DIR / "data" / "response_rules.json"
    kpi_label_map = {
        "net_profit_current": "Net Profit",
        "profit_before_tax": "Profit Before Tax",
        "nim_percent": "Net Interest Margin (NIM)",
        "npl_ratio": "NPL Ratio",
        "coverage_ratio": "Coverage Ratio",
        "cet1_ratio": "CET-1 Ratio",
        "lcr": "Liquidity Coverage Ratio (LCR)",
        "adr": "Advances-to-Deposit Ratio (ADR)",
        "cost_income_ratio": "Cost-to-Income Ratio",
        "total_income": "Total Income",
        "total_assets": "Total Assets",
        "gross_loans_total": "Total Gross Loans",
        "gross_loans_yoy": "Gross Loans YoY Change",
        "total_deposits": "Total Deposits",
        "total_deposits_yoy": "Total Deposits YoY Change",
        "casa_ratio": "CASA Ratio",
        "retail_pbt": "Retail Segment PBT",
        "cib_pbt": "CIB Segment PBT",
        "gmt_pbt": "GM&T Segment PBT",
        "denizbank_pbt": "DenizBank Segment PBT",
        "nfi_total": "Total Non-Funded Income",
        "nfi_yoy": "Non-Funded Income YoY Change",
    }

    verified = kpi_gt.get("kpis", {})
    kpi_lines = []
    for key, label in kpi_label_map.items():
        gt = verified.get(key)
        if not gt:
            continue
        unit = gt.get("unit", "")
        val = gt.get("value", "")
        pg = gt.get("page", "")
        unit_str = "%" if "%" in str(unit) else f" {unit}"
        kpi_lines.append(f" - {label}: {val}{unit_str} (Page {pg})")

    if kpi_lines and rules_file.exists():
        try:
            with open(rules_file, encoding="utf-8") as f:
                rules = json.load(f)
            rules["verified_kpis"] = kpi_lines
            rules["verified_kpis_updated"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
            # Serialise before opening for writing so a failure cannot leave
            # a truncated rules file behind.
            payload = json.dumps(rules, indent=2, ensure_ascii=False)
            with open(rules_file, "w", encoding="utf-8") as f:
                f.write(payload)
            logger.info(f" ✓ {len(kpi_lines)} KPI constraints written to response_rules.json")
        except Exception as e:
            # Best effort: the agent file has already been updated.
            logger.warning(f" Could not update response_rules.json: {e}")

    logger.info(
        f" ✓ Agent updated: temperature={new_temp}, "
        f"threshold={threshold}, "
        f"{len(kpi_lines)} KPI constraints injected"
    )
# ──────────────────────────────────────────────────────────────────────────────
# TABLE INDEXING WITH ENRICHED METADATA
# ──────────────────────────────────────────────────────────────────────────────
def enrich_and_index_tables(extracted, page_map: dict, doc_id: str, doc_name: str) -> list:
    """Build enriched records for every table in the document and persist them.

    Each record carries the table's headers, rows and caption, a plain-text
    rendering (caption, header line and at most the first 30 rows), the page's
    section and slide from *page_map*, and the sorted union of metrics
    detected in the text and the KPI tags of the section.

    Records of other documents already stored in ``TABLES_FILE`` are kept;
    records previously stored for *doc_id* are replaced.

    Args:
        extracted: Parsed document; ``pages`` (with ``page_number`` and
            ``tables``) is read.
        page_map: Mapping of page number to page info (``section``, ``slide``).
        doc_id: Identifier written into every record and used for replacement.
        doc_name: Display name written into every record.

    Returns:
        The list of records created for this document.
    """
    from services.retrieval.table_retriever import FINANCIAL_METRIC_ALIASES

    def detect_metrics(text: str, section: str) -> set:
        """Return the metric keys with at least one alias in text + section."""
        tl = (text + " " + section).lower()
        return {
            key
            for key, aliases in FINANCIAL_METRIC_ALIASES.items()
            if any(a in tl for a in aliases)
        }

    def table_to_text(tbl: dict) -> str:
        """Render caption, headers and the first 30 rows as pipe-joined lines."""
        headers = tbl.get("headers") or []
        rows = tbl.get("rows") or []
        caption = tbl.get("caption", "")
        parts = []
        if caption:
            parts.append(f"Table: {caption}")
        if headers:
            # str() keeps non-string headers (e.g. numeric years) from
            # breaking the join.
            parts.append(" | ".join(str(h) for h in headers if h))
        for row in rows[:30]:
            parts.append(" | ".join(str(c) for c in row))
        return "\n".join(parts)

    records = []
    for page in extracted.pages:
        pmap = page_map.get(page.page_number, {})
        section = pmap.get("section", "")
        slide = pmap.get("slide", page.page_number)
        sec_kpis = set(SECTION_KPI_TAGS.get(section, []))
        for tbl in page.tables:
            text_rep = table_to_text(tbl)
            text_metrics = detect_metrics(text_rep, section)
            records.append({
                "doc_id": doc_id,
                "doc_name": doc_name,
                "page_number": page.page_number,
                "slide_number": slide,
                "section_heading": section,
                "headers": tbl.get("headers", []),
                "rows": tbl.get("rows", []),
                "caption": tbl.get("caption", ""),
                "text_representation": text_rep,
                "metrics_found": sorted(text_metrics | sec_kpis),
            })

    # Merge — keep records for other docs
    existing = []
    if TABLES_FILE.exists():
        try:
            # Read with the same encoding the file is written with below.
            with open(TABLES_FILE, encoding="utf-8") as f:
                existing = json.load(f)
            existing = [r for r in existing if r.get("doc_id") != doc_id]
        except Exception as exc:
            # Best effort: an unreadable index is rebuilt from this document
            # only, but say so instead of dropping other documents silently.
            logger.warning(f" Could not read existing table index, starting fresh: {exc}")
            existing = []

    all_records = existing + records
    TABLES_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(TABLES_FILE, "w", encoding="utf-8") as f:
        json.dump(all_records, f, indent=2, ensure_ascii=False)
    return records
# ──────────────────────────────────────────────────────────────────────────────
# MAIN PIPELINE
# ──────────────────────────────────────────────────────────────────────────────
def run(pdf_path: Path, doc_id: str, skip_colpali: bool, do_fine_tune: bool):
    """Run the full training pipeline for one IR PDF.

    Steps, in order: parse the PDF, build and save the page map, extract and
    save KPI ground truth, chunk the text, generate training pairs and
    calibrate retrieval weights, optionally fine-tune the embedding model,
    save the retrieval config, embed and index chunks, index tables, update
    the generation agent, pre-generate the smart response cache (non-fatal on
    failure), render pages and optionally build ColPali embeddings, update
    the documents registry, and log a final report.

    Args:
        pdf_path: Path of the PDF to train on. The process exits with status
            1 if it does not exist.
        doc_id: Identifier for the document; also used to derive its display
            name and output file names.
        skip_colpali: When true, ColPali visual embeddings are not generated.
        do_fine_tune: When true (and training pairs exist), the embedding
            model is fine-tuned and the config points at the new model.
    """
    doc_name = " ".join(w.capitalize() for w in doc_id.replace("_", " ").split())
    total_start = time.time()
    logger.info("=" * 72)
    logger.info(" IRIS — IR Document Training Pipeline")
    logger.info(f" PDF : {pdf_path.name}")
    logger.info(f" Doc ID : {doc_id}")
    logger.info(f" Options : fine-tune={do_fine_tune}, skip-colpali={skip_colpali}")
    logger.info("=" * 72)

    if not pdf_path.exists():
        logger.error(f"PDF not found: {pdf_path}")
        sys.exit(1)

    # Load existing config or use defaults
    config = {**DEFAULT_CONFIG}
    if CONFIG_FILE.exists():
        with open(CONFIG_FILE, encoding="utf-8") as f:
            config.update(json.load(f))

    # ── 1. Parse PDF ─────────────────────────────────────────────────────────
    logger.info("\n[1/9] Parsing PDF…")
    from services.ingestion.pdf_parser import parse_pdf, save_extraction
    t0 = time.time()
    extracted = parse_pdf(pdf_path, doc_id, doc_name)
    save_extraction(extracted, PROC_DIR)
    logger.info(f" ✓ {extracted.total_pages} pages | {time.time()-t0:.1f}s")

    # ── 2. Build page map ────────────────────────────────────────────────────
    logger.info("\n[2/9] Building page → section → KPI mapping…")
    page_map = build_page_map(extracted)
    PAGEMAP_DIR.mkdir(parents=True, exist_ok=True)
    pagemap_file = PAGEMAP_DIR / f"{doc_id}_pagemap.json"
    with open(pagemap_file, "w", encoding="utf-8") as f:
        json.dump({
            "doc_id": doc_id, "doc_name": doc_name,
            "generated": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "note": "Auto-generated. Edit section/kpis and re-run to apply corrections.",
            "pages": {str(k): v for k, v in page_map.items()},
        }, f, indent=2, ensure_ascii=False)

    # Print mapping table
    logger.info(f" {'Page':>4} | {'Slide':>5} | {'Section':<28} | Title")
    logger.info(f" {'----':>4}-+-{'-----':>5}-+-{'-'*28}-+-{'-'*30}")
    for pg in sorted(page_map):
        info = page_map[pg]
        logger.info(
            f" P{pg:02d} | S{info['slide']:02d} | {info['section']:<28} | {info['title'][:40]}"
        )

    # ── 3. Extract KPI ground truth ──────────────────────────────────────────
    logger.info("\n[3/9] Extracting KPI ground truth values for hallucination prevention…")
    kpi_gt = extract_kpi_ground_truth(extracted, page_map)
    KPI_GT_FILE.parent.mkdir(parents=True, exist_ok=True)

    # Merge with existing ground truth (other documents are kept).
    existing_gt = {}
    if KPI_GT_FILE.exists():
        try:
            # Read with the same encoding the file is written with below.
            with open(KPI_GT_FILE, encoding="utf-8") as f:
                existing_gt = json.load(f)
        except Exception as exc:
            logger.warning(f" Could not read existing KPI ground truth, starting fresh: {exc}")
            existing_gt = {}
    existing_gt[doc_id] = kpi_gt
    with open(KPI_GT_FILE, "w", encoding="utf-8") as f:
        json.dump(existing_gt, f, indent=2, ensure_ascii=False)

    logger.info(" Extracted KPI values:")
    for key, val in kpi_gt.get("kpis", {}).items():
        logger.info(f" {key:<25} = {val['value']} {val['unit']} (Page {val['page']}, {val['section']})")

    # ── 4. Chunk text ────────────────────────────────────────────────────────
    logger.info("\n[4/9] Chunking text with section metadata…")
    import chromadb
    from chromadb.config import Settings
    try:
        client = chromadb.PersistentClient(
            path=str(CHROMA_DIR), settings=Settings(anonymized_telemetry=False)
        )
        client.delete_collection("finbot_ir_chunks")
        logger.info(" ✓ Old ChromaDB collection cleared")
    except Exception as exc:
        # Best effort: failing to clear the old collection must not abort
        # the run, but leave a trace of why it happened.
        logger.debug(f" ChromaDB collection not cleared: {exc}")

    from services.ingestion.text_chunker import TextChunker
    t0 = time.time()
    chunker = TextChunker(
        chunk_size=config["chunk_size"],
        chunk_overlap=config["chunk_overlap"],
    )
    page_texts = []
    for p in extracted.pages:
        if not p.text.strip():
            continue
        pmap = page_map.get(p.page_number, {})
        page_texts.append({
            "page_number": p.page_number,
            "text": p.text,
            "section_heading": pmap.get("section", getattr(p, "section_heading", "") or ""),
            "slide_number": pmap.get("slide", getattr(p, "slide_number", None) or p.page_number),
        })
    chunks = chunker.chunk_document(page_texts, doc_id)
    for chunk in chunks:
        pmap = page_map.get(chunk.page_number, {}) or page_map.get(str(chunk.page_number), {})
        setattr(chunk, "section_heading", pmap.get("section", ""))
    logger.info(f" ✓ {len(chunks)} chunks | {time.time()-t0:.1f}s")

    # ── 5. Generate training pairs & calibrate weights ───────────────────────
    logger.info("\n[5/9] Generating training pairs and calibrating retrieval weights…")
    pairs = generate_training_pairs(page_map, chunks)
    calibrated = calibrate_weights(pairs, chunks)

    # Merge calibrated weights into full config
    config.update(calibrated)
    config["last_calibrated"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    config["calibrated_on_doc"] = doc_id

    # ── 6. Optional: Fine-tune embedding model ───────────────────────────────
    embed_model_path = config["embed_model"]
    # Fine-tuning needs training pairs; remember whether it really ran so the
    # final report does not claim a fine-tune that was skipped.
    fine_tuned = bool(do_fine_tune and pairs)
    if fine_tuned:
        logger.info("\n[6/9] Fine-tuning embedding model on domain data…")
        fine_tuned_path = fine_tune_embeddings(pairs, doc_id, config["embed_model"])
        embed_model_path = str(fine_tuned_path)
        config["embed_model"] = embed_model_path
        config["fine_tuned"] = True
    elif do_fine_tune:
        logger.info("\n[6/9] Skipping fine-tuning (no training pairs generated)")
    else:
        logger.info("\n[6/9] Skipping fine-tuning (use --fine-tune to enable)")

    # Save config BEFORE embedding so embedding uses correct model
    CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(CONFIG_FILE, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    logger.info(f" ✓ Retrieval config saved → {CONFIG_FILE.name}")

    # ── 7. Embed + index chunks ──────────────────────────────────────────────
    logger.info("\n[7/9] Embedding chunks and indexing in ChromaDB…")
    from services.retrieval.text_retriever import TextRetriever
    t0 = time.time()
    chunks = chunker.embed_chunks(chunks)
    text_retriever = TextRetriever(persist_dir=CHROMA_DIR)
    n_chunks = text_retriever.index_chunks(chunks)
    logger.info(f" ✓ {n_chunks} chunks embedded and indexed | {time.time()-t0:.1f}s")

    # ── 8. Index tables ──────────────────────────────────────────────────────
    logger.info("\n[8/9] Indexing tables with enriched metadata…")
    table_records = enrich_and_index_tables(extracted, page_map, doc_id, doc_name)
    logger.info(f" ✓ {len(table_records)} tables indexed")

    # ── Update generation agent ──────────────────────────────────────────────
    logger.info("\n Updating generation agent with calibrated config…")
    update_generation_agent(config, kpi_gt)

    # ── 8b. Auto-generate Smart Response Cache ───────────────────────────────
    # Pre-fills document-agnostic templates with real KPI values so the Smart
    # Response Engine can serve instant (<10ms) answers for any trained PDF.
    logger.info("\n[8b] Auto-generating smart response cache for new engine…")
    status_ok = "✓"
    try:
        from services.classification.kpi_context_builder import KPIContextBuilder, INTENT_KPI_KEYS
        from services.generation.smart_response_engine import (
            SmartResponseEngine, save_cached_response, _INTENT_BUILDERS
        )
        _cache_builder = KPIContextBuilder(data_dir=DATA_DIR)
        cache_results: dict[str, str] = {}
        for intent in INTENT_KPI_KEYS.keys():
            ctx = _cache_builder.build(intent=intent, doc_ids=[doc_id])
            if not ctx.has_data:
                cache_results[intent] = "⚠ (no KPI data extracted)"
                continue
            builder_fn = _INTENT_BUILDERS.get(intent)
            if not builder_fn:
                cache_results[intent] = "⚠ (no template builder)"
                continue
            try:
                resp = builder_fn(ctx)
                if resp:
                    resp["question"] = f"[auto-generated for intent: {intent}]"
                    resp["latency_ms"] = 0
                    saved = save_cached_response(doc_id, intent, resp)
                    # Success and failure must differ: only exact successes
                    # are counted below.
                    cache_results[intent] = status_ok if saved else "✗"
                else:
                    cache_results[intent] = "⚠ (no KPI match in template)"
            except Exception as be:
                cache_results[intent] = f"✗ ({be})"
        for intent, status in cache_results.items():
            logger.info(f" {status} {intent}")
        n_cached = sum(1 for s in cache_results.values() if s == status_ok)
        logger.info(f" ✓ {n_cached}/{len(cache_results)} smart response templates generated")
        logger.info(f" Cached to: data/response_cache/{doc_id}/")
    except Exception as e:
        logger.warning(f" ⚠ Smart cache generation failed (non-fatal): {e}")

    # ── 9. Render pages + ColPali ────────────────────────────────────────────
    logger.info("\n[9/9] Rendering PDF pages…")
    from services.ingestion.page_renderer import render_pdf_pages
    t0 = time.time()
    PAGES_DIR.mkdir(parents=True, exist_ok=True)
    page_records = render_pdf_pages(
        pdf_path=pdf_path,
        output_dir=PAGES_DIR,
        doc_id=doc_id,
    )
    logger.info(f" ✓ {len(page_records)} pages rendered | {time.time()-t0:.1f}s")

    colpali_pages = len(page_records)
    if not skip_colpali:
        logger.info(" Generating ColPali visual embeddings (batch_size=1)…")
        import torch
        if torch.backends.mps.is_available():
            os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
        from services.ingestion.colpali_indexer import ColPaliIndexer
        t0 = time.time()
        indexer = ColPaliIndexer(store_dir=COLPALI_DIR)
        page_records = indexer.index_pages(page_records, doc_id=doc_id, batch_size=1)
        colpali_pages = len(page_records)
        logger.info(f" ✓ {colpali_pages} ColPali pages | {time.time()-t0:.1f}s")
    else:
        logger.info(" Skipping ColPali (--skip-colpali)")

    # ── Update documents registry ────────────────────────────────────────────
    docs = []
    if DOCS_FILE.exists():
        with open(DOCS_FILE, encoding="utf-8") as f:
            docs = json.load(f)
    # .get() tolerates registry entries that lack a doc_id.
    docs = [d for d in docs if d.get("doc_id") != doc_id]
    docs.append({
        "doc_id": doc_id,
        "name": doc_name,
        "doc_type": extracted.metadata.get("doc_type", "Financial Document"),
        "period": extracted.metadata.get("period", "Unknown"),
        "institution": extracted.metadata.get("institution", "Unknown"),
        "total_pages": extracted.total_pages,
        "status": "indexed",
        "filename": pdf_path.name,
        "chunks_indexed": n_chunks,
        "tables_indexed": len(table_records),
        "colpali_pages": colpali_pages,
        "pagemap_file": pagemap_file.name,
        "page_section_map": {str(k): v["section"] for k, v in page_map.items()},
        "retrieval_config": str(CONFIG_FILE.name),
    })
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    with open(DOCS_FILE, "w", encoding="utf-8") as f:
        json.dump(docs, f, indent=2)

    # ── Final report ─────────────────────────────────────────────────────────
    total_elapsed = time.time() - total_start
    logger.info("")
    logger.info("=" * 72)
    logger.info(f" ✓ Training complete in {total_elapsed:.1f}s")
    logger.info(f" {'─'*60}")
    logger.info(f" Text chunks indexed : {n_chunks}")
    logger.info(f" Tables indexed : {len(table_records)}")
    logger.info(f" KPI ground truths : {len(kpi_gt.get('kpis', {}))}")
    logger.info(f" Training pairs : {len(pairs)}")
    logger.info(f" Embedding model : {embed_model_path}")
    logger.info(f" Fine-tuned : {fine_tuned}")
    logger.info(f" Calibrated weights:")
    logger.info(f" metric_match : {config.get('metric_match_weight')}")
    logger.info(f" section_boost : {config.get('section_heading_boost')}")
    logger.info(f" confidence floor : {config.get('min_confidence_threshold')}")
    logger.info(f" LLM temperature : {config.get('ollama_temperature')}")
    logger.info(f" {'─'*60}")
    logger.info(f" Config saved to : data/{CONFIG_FILE.name}")
    logger.info(f" KPI ground truth : data/{KPI_GT_FILE.name}")
    logger.info(f" Page mapping : data/page_maps/{pagemap_file.name}")
    logger.info("=" * 72)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to run().
    parser = argparse.ArgumentParser(
        description="IRIS — Train on IR PDF, calibrate weights, reduce hallucination",
        # Raw formatter keeps the line breaks of the examples below.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Standard training (parse + calibrate + index)
  python train_document.py --pdf ../documents/enbd_q1_2026.pdf

  # Full training including embedding fine-tuning (~20 min extra)
  python train_document.py --pdf ../documents/enbd_q1_2026.pdf --fine-tune

  # Fast mode — skip ColPali visual embeddings
  python train_document.py --pdf ../documents/report.pdf --skip-colpali

  # Custom document ID
  python train_document.py --pdf ../documents/report.pdf --doc-id enbd_fy2025
""",
    )
    parser.add_argument("--pdf", required=True, help="Path to IR PDF")
    parser.add_argument("--doc-id", default=None, help="Document ID (derived from filename if omitted)")
    parser.add_argument("--skip-colpali", action="store_true", help="Skip ColPali visual embedding")
    parser.add_argument("--fine-tune", action="store_true", help="Fine-tune the embedding model on domain data")
    args = parser.parse_args()

    pdf = Path(args.pdf).resolve()
    # Default document ID: lower-cased file stem with spaces and hyphens
    # turned into underscores.
    d_id = args.doc_id or pdf.stem.lower().replace(" ", "_").replace("-", "_")
    run(pdf_path=pdf, doc_id=d_id, skip_colpali=args.skip_colpali, do_fine_tune=args.fine_tune)