# (HuggingFace Spaces page header residue removed: "Spaces: Sleeping")
"""
Adaptive Multimodal Fusion for DocVQA — Gradio Demo
====================================================
Run locally:
    python app.py
Run with public URL (72hr):
    python app.py --share
Deploy to HuggingFace Spaces:
    - Push this file + requirements.txt + checkpoints/ folder to a Space repo
    - HF Spaces auto-launches on port 7860
"""
| import argparse, os, sys, copy, json, warnings | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| import editdistance | |
| from PIL import Image as PILImage, ImageDraw as PILDraw, ImageFont as PILFont | |
| warnings.filterwarnings("ignore") | |
# ── CLI args ──────────────────────────────────────────────────────────
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true", help="Create public Gradio URL")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--ckpt_dir", type=str, default="./checkpoints",
                    help="Folder containing all saved files")
# parse_known_args (not parse_args) so launchers that inject extra flags
# (e.g. HF Spaces) do not crash the script.
args, _ = parser.parse_known_args()
# ──────────────────────────────────────────────────────────────────────
# CONFIGURATION — edit paths here if needed
# ──────────────────────────────────────────────────────────────────────
CKPT_DIR = args.ckpt_dir
ORACLE_CACHE = os.path.join(CKPT_DIR, "oracle_cache.json")
FEAT_PATH = os.path.join(CKPT_DIR, "feature_tensors.pt")
RESULTS_PATH = os.path.join(CKPT_DIR, "final_results.json")
# Try final checkpoint first, fall back to intermediate
CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint_final.pt")
if not os.path.exists(CKPT_PATH):
    CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint.pt")
# Dataset / model IDs
DATASET_NAME = "nielsr/docvqa_1200_examples"
FEAT_MODEL_ID = "microsoft/layoutlmv3-base"
VQA_MODEL_ID = "rubentito/layoutlmv3-base-mpdocvqa"
SBERT_ID = "all-MiniLM-L6-v2"
# Field names in the dataset records
WORD_FIELD = "words"            # OCR tokens
BOX_FIELD = "bounding_boxes"    # pixel-space [x0, y0, x1, y1] per word
QUERY_FIELD = "query"
ANSWER_FIELD = "answers"
# Architecture / data sizes
MAX_WORDS = 64     # max OCR words fed to the feature extractor
N_PATCHES = 49     # visual-patch count assumed by the slicing below — TODO confirm for this processor config
N_VAL = 100
N_TRAIN = 100
FEAT_DIM = 2701    # 3*768 modality vecs + 10 spatial stats + 384 question + 3 scores
PROJ_DIM = 128
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MODEL CLASSES | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
class CrossAttentionFusionPredictor(nn.Module):
    """Predict per-document fusion weights (alpha, beta, gamma).

    Input is the packed 2701-D vector built by build_feature_vector():
        [0:768]      text embedding    (LayoutLMv3 token mean)
        [768:1536]   visual embedding  (LayoutLMv3 patch mean)
        [1536:2304]  spatial embedding (projected layout stats)
        [2304:2314]  10 raw spatial statistics (unused here)
        [2314:2698]  384-D SBERT question embedding
        [2698:2701]  3 heuristic modality scores
    The question acts as the attention query over the three projected
    modality vectors; a small MLP head emits 3 logits and forward()
    softmaxes them into a probability simplex.

    NOTE: attribute names are part of the interface — checkpoint
    state_dict keys depend on them.
    """

    def __init__(self, feat_dim=FEAT_DIM, proj_dim=PROJ_DIM,
                 n_heads=4, dropout=0.15):
        # feat_dim is kept for signature compatibility; the slices in
        # _logits encode the layout directly.
        super().__init__()
        # Per-modality projections: 768-D encoder outputs -> proj_dim.
        self.text_proj = nn.Linear(768, proj_dim)
        self.visual_proj = nn.Linear(768, proj_dim)
        self.spatial_proj_lyr = nn.Linear(768, proj_dim)
        # Question (384-D SBERT) becomes the cross-attention query.
        self.q_proj = nn.Sequential(
            nn.Linear(384, proj_dim), nn.LayerNorm(proj_dim), nn.GELU()
        )
        self.cross_attn = nn.MultiheadAttention(
            proj_dim, n_heads, dropout=dropout, batch_first=True
        )
        self.attn_norm = nn.LayerNorm(proj_dim)
        # Head also sees the 3 raw heuristic scores (shortcut features).
        self.head = nn.Sequential(
            nn.Linear(proj_dim + 3, proj_dim), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(proj_dim, 3)
        )

    def _logits(self, x):
        # Slice the packed feature vector (layout documented on the class).
        h_t = self.text_proj(x[:, 0:768])
        h_v = self.visual_proj(x[:, 768:1536])
        h_s = self.spatial_proj_lyr(x[:, 1536:2304])
        q = self.q_proj(x[:, 2314:2698]).unsqueeze(1)   # (B, 1, proj_dim)
        kv = torch.stack([h_t, h_v, h_s], dim=1)        # (B, 3, proj_dim)
        ctx, _ = self.cross_attn(q, kv, kv)
        ctx = self.attn_norm(ctx.squeeze(1))
        # Concatenate the 3 heuristic scores before the MLP head.
        return self.head(torch.cat([ctx, x[:, 2698:2701]], dim=-1))

    def forward(self, x):
        # Softmax -> (alpha, beta, gamma) summing to 1 per row.
        return F.softmax(self._logits(x), dim=-1)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # LOAD BASE MODELS | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
print("Loading base models (takes ~5 min on first run, cached after)...")
from transformers import AutoProcessor, AutoModel, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer
# Feature extractor: frozen LayoutLMv3; OCR is supplied by the dataset,
# hence apply_ocr=False.
feat_processor = AutoProcessor.from_pretrained(FEAT_MODEL_ID, apply_ocr=False)
feat_model = AutoModel.from_pretrained(FEAT_MODEL_ID).to(device).eval()
for p in feat_model.parameters(): p.requires_grad_(False)
print(" β LayoutLMv3 feature model")
# Answer model: frozen extractive-QA LayoutLMv3 fine-tuned on MP-DocVQA.
vqa_processor = AutoProcessor.from_pretrained(VQA_MODEL_ID, apply_ocr=False)
vqa_model = AutoModelForQuestionAnswering.from_pretrained(
    VQA_MODEL_ID).to(device).eval()
for p in vqa_model.parameters(): p.requires_grad_(False)
print(" β VQA model")
sbert = SentenceTransformer(SBERT_ID)
sbert.to(device)
print(" β SBERT")
# Small MLP mapping the 10 layout statistics -> a 768-D "spatial"
# embedding; its weights are loaded from the RL checkpoint further down.
spatial_proj = nn.Sequential(
    nn.Linear(10, 256), nn.ReLU(), nn.Linear(256, 768)
).to(device)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # HELPER FUNCTIONS | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def get_question(item):
    """Pull the question string from a dataset record.

    Looks up QUERY_FIELD first, then a plain "question" key. Multilingual
    dict values prefer English, falling back to any available translation.
    """
    question = item.get(QUERY_FIELD, item.get("question", ""))
    if isinstance(question, dict):
        # Multilingual record: prefer "en", else take the first value.
        question = question.get("en", next(iter(question.values()), ""))
    return str(question).strip()
def normalize_boxes(boxes, w, h):
    """Scale pixel-space boxes into LayoutLMv3's 0–1000 coordinate space.

    Each coordinate is divided by the image extent (floored at 1 to avoid
    division by zero), clamped to [0, 1], scaled by 1000 and truncated to
    an int. Returns a new list of 4-int boxes.
    """
    w_safe, h_safe = max(w, 1), max(h, 1)

    def _scale(coord, extent):
        # Clamp to [0, 1] before scaling so off-image boxes stay legal.
        return int(max(0, min(coord / extent, 1)) * 1000)

    return [
        [_scale(b[0], w_safe), _scale(b[1], h_safe),
         _scale(b[2], w_safe), _scale(b[3], h_safe)]
        for b in boxes
    ]
def extract_rich_features(item):
    """Build the per-document multimodal feature dict consumed by CAFP.

    Returns a dict with three 768-D modality embeddings (text, visual,
    spatial), the raw 10-D spatial statistics, the 384-D SBERT question
    embedding, three heuristic modality scores in [0, 1], and the OCR
    word count. Falls back to zero features with default scores on any
    error so a malformed example never crashes the demo.
    """
    try:
        img = item["image"].convert("RGB")
        W, H = img.size
        # Truncate to MAX_WORDS; substitute dummies when OCR is empty.
        words = list(item.get(WORD_FIELD, []))[:MAX_WORDS] or ["[PAD]"]
        boxes = list(item.get(BOX_FIELD, []))[:MAX_WORDS] or [[0, 0, 1, 1]]
        question = get_question(item)
        bn = normalize_boxes(boxes, W, H)
        enc = feat_processor(img, text=words, boxes=bn,
                             return_tensors="pt", truncation=True,
                             max_length=512, padding="max_length")
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            hidden = feat_model(**enc).last_hidden_state[0]
        # Assumes the sequence is [CLS] + text tokens followed by
        # N_PATCHES visual patch tokens at the end — TODO confirm for
        # this processor configuration.
        n_txt = max(2, hidden.shape[0] - N_PATCHES)
        H_text = hidden[1:n_txt-1].mean(0) if n_txt > 2 else hidden[0]
        H_visual = hidden[-N_PATCHES:].mean(0)
        # 10 layout statistics derived from the normalized (0-1000) boxes.
        bx = np.array(bn, dtype=np.float32)
        cx = ((bx[:, 0] + bx[:, 2]) / 2) / 1000.0   # box centers in [0, 1]
        cy = ((bx[:, 1] + bx[:, 3]) / 2) / 1000.0
        sp = np.array([
            W / 1000.0, H / 1000.0, min(W, H) / max(W, H),
            len(words) / MAX_WORDS,
            cx.mean(), cy.mean(), cx.std() + 1e-6, cy.std() + 1e-6,
            H_text.norm().item() / 10.0,
            H_visual.norm().item() / 10.0,
        ], dtype=np.float32)
        sp10 = torch.tensor(sp).to(device)
        # Project the 10 stats to the 768-D "spatial" modality embedding.
        H_spat = spatial_proj(sp10.unsqueeze(0)).squeeze(0)
        q_emb = torch.tensor(sbert.encode(question),
                             dtype=torch.float32).to(device)
        return {
            "H_text": H_text, "H_visual": H_visual, "H_spatial": H_spat,
            "spatial_10": sp10, "question_emb": q_emb,
            # Heuristic scores reuse sp[8] (text norm), sp[9] (visual
            # norm) and sp[6] (cx spread), clipped into [0, 1].
            "text_score": float(np.clip(sp[8], 0, 1)),
            "visual_score": float(np.clip(sp[9], 0, 1)),
            "spatial_score": float(np.clip(sp[6], 0, 1)),
            "n_tokens": len(words),
        }
    except Exception as e:
        # Best-effort fallback: zero features, neutral default scores.
        print(f" extract_rich_features error: {e}")
        dummy = torch.zeros(768, device=device)
        return {
            "H_text": dummy, "H_visual": dummy, "H_spatial": dummy,
            "spatial_10": torch.zeros(10, device=device),
            "question_emb": torch.zeros(384, device=device),
            "text_score": 0.5, "visual_score": 0.3, "spatial_score": 0.2,
            "n_tokens": 0,
        }
def build_feature_vector(feat):
    """Flatten a feature dict into the packed 2701-D CAFP input vector.

    Layout: [0:768] text | [768:1536] visual | [1536:2304] spatial |
    [2304:2314] raw spatial stats | [2314:2698] question | [2698:2701]
    heuristic scores.
    """
    scores = torch.tensor(
        [feat["text_score"], feat["visual_score"], feat["spatial_score"]],
        dtype=torch.float32, device=device,
    )
    segments = (
        feat["H_text"],
        feat["H_visual"],
        feat["H_spatial"],
        feat["spatial_10"],
        feat["question_emb"],
        scores,
    )
    return torch.cat(segments)
def vqa_infer(item, alpha, beta, gamma):
    """Run extractive VQA with fusion-weight-driven context compression.

    alpha sets the fraction of OCR words kept (floored at 0.30, with a
    minimum of 5 words or all words if fewer). When gamma dominates both
    alpha and beta, words are selected in reading order (top of page
    first) instead of dataset order. Returns the decoded answer string,
    or "" on any failure.
    """
    try:
        img = item["image"].convert("RGB")
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        question = get_question(item)
        if not words:
            return ""
        W, H = img.size
        n = len(words)
        # Keep at least 30% of the words and at least 5 (or all, if fewer).
        n_keep = max(int(n * max(float(alpha), 0.30)), min(5, n))
        if float(gamma) > max(float(alpha), float(beta)) and boxes:
            # Spatial-dominant: sort by (y, x) and keep the topmost words,
            # then restore original order for the encoder.
            order = sorted(range(n), key=lambda i: (boxes[i][1], boxes[i][0]))
            sel_idx = sorted(order[:n_keep])
        else:
            sel_idx = list(range(n_keep))
        sw = [words[i] for i in sel_idx]
        sb = ([boxes[i] for i in sel_idx]
              if boxes else [[0, 0, W, H]] * len(sw))
        enc = vqa_processor(
            img, text=question, text_pair=sw,
            boxes=normalize_boxes(sb, W, H),
            return_tensors="pt", truncation=True,
            max_length=512, padding=True
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            out = vqa_model(**enc)
        # Extractive span: independent argmax of start/end logits,
        # clamping end >= start.
        s = int(out.start_logits.argmax())
        e = int(out.end_logits.argmax())
        if e < s: e = s
        return vqa_processor.tokenizer.decode(
            enc["input_ids"][0][s:e+1], skip_special_tokens=True
        ).strip()
    except Exception:
        # Best-effort: a failed example simply scores 0 downstream.
        return ""
def compute_anls(pred, gts, threshold=0.5):
    """Best Normalized Levenshtein Similarity of *pred* over all GTs.

    Similarities below *threshold* are zeroed (standard ANLS rule).
    Accepts a single ground-truth string or a list; returns 0.0 for an
    empty prediction or empty ground-truth list.
    """
    if isinstance(gts, str):
        gts = [gts]
    if not pred or not gts:
        return 0.0
    norm_pred = str(pred).lower().strip()

    def _nls(gt):
        # Similarity against one ground-truth answer.
        norm_gt = str(gt).lower().strip()
        longest = max(len(norm_pred), len(norm_gt))
        if longest == 0:
            return 1.0  # both strings empty -> perfect match
        sim = 1.0 - editdistance.eval(norm_pred, norm_gt) / longest
        return sim if sim >= threshold else 0.0

    return max(_nls(gt) for gt in gts)
def compute_f1(pred, gts):
    """Best bag-of-words token F1 of *pred* over all ground-truth answers.

    Tokens are lowercase whitespace splits treated as sets (duplicates
    collapse). Accepts a single string or a list of ground truths;
    returns 0.0 when either side has no tokens.
    """
    if isinstance(gts, str):
        gts = [gts]
    if not pred or not gts:
        return 0.0
    pred_tokens = set(str(pred).lower().split())
    if not pred_tokens:
        return 0.0
    best = 0.0
    for gt in gts:
        gt_tokens = set(str(gt).lower().split())
        overlap = pred_tokens & gt_tokens
        if not gt_tokens or not overlap:
            continue  # no shared tokens -> F1 is 0 for this GT
        precision = len(overlap) / len(pred_tokens)
        recall = len(overlap) / len(gt_tokens)
        best = max(best, 2 * precision * recall / (precision + recall))
    return best
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # WORD SELECTION VISUALIZER HELPERS | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def get_sel_idx(item, alpha, beta, gamma):
    """Return the SET of word indices kept by this (alpha, beta, gamma).

    Mirrors the exact selection logic in vqa_infer so visualized boxes
    always match what the model actually sees.
    """
    words = list(item.get(WORD_FIELD, []))
    boxes = list(item.get(BOX_FIELD, []))
    total = len(words)
    if total == 0:
        return set()
    a, b, g = float(alpha), float(beta), float(gamma)
    # At least 30% of the words and at least 5 (or all, if fewer),
    # capped at the word count.
    keep = max(int(total * max(a, 0.30)), min(5, total))
    keep = min(keep, total)
    if g > max(a, b) and boxes:
        # Spatial-dominant: reading order (top-to-bottom, left-to-right).
        reading_order = sorted(range(total),
                               key=lambda i: (boxes[i][1], boxes[i][0]))
        return set(reading_order[:keep])
    return set(range(keep))
def draw_selection(item, alpha, beta, gamma, title=""):
    """Return a PIL Image with coloured bounding boxes overlaid.

    Green fill + outline -> word KEPT (used for VQA);
    red fill + outline   -> word DROPPED (compressed out).
    An info strip (dark) and a colour-legend strip are appended below
    the document image so the panel is self-explanatory at a glance.
    Falls back to the raw image (or None) on any error.
    """
    try:
        img = item["image"].convert("RGB").copy()
        W, H = img.size
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        n = min(len(words), len(boxes))
        if n == 0:
            return img
        # Same selection the VQA model saw (see get_sel_idx).
        sel_idx = get_sel_idx(item, alpha, beta, gamma)
        n_keep = len(sel_idx)
        pct = 100 * n_keep / max(n, 1)
        # ── Draw semi-transparent coloured overlays ───────────────────
        overlay = PILImage.new("RGBA", img.size, (0, 0, 0, 0))
        od = PILDraw.Draw(overlay)
        for i in range(n):
            try:
                x0, y0, x1, y1 = (int(boxes[i][0]), int(boxes[i][1]),
                                  int(boxes[i][2]), int(boxes[i][3]))
                # Clamp to image bounds
                x0, x1 = max(0, x0), min(W - 1, x1)
                y0, y1 = max(0, y0), min(H - 1, y1)
                if x1 <= x0 or y1 <= y0:
                    continue  # degenerate / fully off-image box
                if i in sel_idx:
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(0, 210, 0, 55),
                                 outline=(0, 160, 0, 230), width=2)
                else:
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(220, 30, 30, 40),
                                 outline=(200, 0, 0, 170), width=1)
            except Exception:
                continue  # skip malformed boxes, keep drawing the rest
        img = PILImage.alpha_composite(img.convert("RGBA"), overlay).convert("RGB")
        # ── Load font (graceful fallback) ─────────────────────────────
        # Default bitmap font unless a TrueType font is found for this OS.
        font_sm = PILFont.load_default()
        for _fp in [
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "/Windows/Fonts/arial.ttf",
        ]:
            try:
                font_sm = PILFont.truetype(_fp, 13)
                break
            except Exception:
                continue  # font missing on this OS — try the next path
        # ── Info strip (dark bar showing title + stats) ───────────────
        strip_h = 36
        strip = PILImage.new("RGB", (W, strip_h), (22, 22, 32))
        sd = PILDraw.Draw(strip)
        info_text = (f"{title} | β Kept: {n_keep}/{n} ({pct:.0f}%)"
                     f" | Ξ±={alpha:.2f} Ξ²={beta:.2f} Ξ³={gamma:.2f}")
        sd.text((8, 11), info_text, fill=(220, 220, 220), font=font_sm)
        # ── Legend strip (light bar explaining colours) ───────────────
        leg_h = 28
        leg = PILImage.new("RGB", (W, leg_h), (246, 246, 246))
        ld = PILDraw.Draw(leg)
        ld.rectangle([8, 7, 24, 21], fill=(0, 180, 0), outline=(0, 130, 0, 255))
        ld.text( [30, 8], "= Kept (used for VQA)", fill=(0, 110, 0), font=font_sm)
        ld.rectangle([210, 7, 226, 21], fill=(220, 30, 30), outline=(170, 0, 0, 255))
        ld.text( [232, 8], "= Dropped (compressed out)", fill=(140, 0, 0), font=font_sm)
        # ── Stack: image, then dark strip, then legend ────────────────
        final = PILImage.new("RGB", (W, H + strip_h + leg_h), (255, 255, 255))
        final.paste(img, (0, 0))
        final.paste(strip, (0, H))
        final.paste(leg, (0, H + strip_h))
        return final
    except Exception as e:
        # Last resort: hand back the raw document image (or None).
        print(f" draw_selection error: {e}")
        return item.get("image", None)
def make_compression_md(item, cfgs):
    """Build a markdown table of kept / dropped word statistics plus a
    sample of the words each method discards.

    cfgs: dict {method_name: (alpha, beta, gamma)}.
    """
    words = list(item.get(WORD_FIELD, []))
    total = len(words)
    if total == 0:
        return "*No OCR words available for this document.*"
    parts = [
        "### π What Gets Compressed?\n\n",
        f"**Total OCR words in document:** {total}\n\n",
        "| Method | Ξ± | Ξ² | Ξ³ | Words Kept | % Context |"
        " Sample Dropped Words |\n",
        "|--------|---|---|---|:----------:|:---------:|"
        "----------------------|\n",
    ]
    for name, (a, b, g) in cfgs.items():
        kept = get_sel_idx(item, a, b, g)
        pct = 100 * len(kept) / max(total, 1)
        dropped = [w for i, w in enumerate(words) if i not in kept]
        preview = " Β· ".join(dropped[:8])
        if len(dropped) > 8:
            preview += f" β¦ (+{len(dropped) - 8} more)"
        parts.append(
            f"| **{name}** | {a:.2f} | {b:.2f} | {g:.2f}"
            f" | {len(kept)} / {total} | {pct:.0f}% | `{preview}` |\n"
        )
    # Show the actual kept words for the CAFP+REINFORCE method.
    if "CAFP+REINFORCE" in cfgs:
        a, b, g = cfgs["CAFP+REINFORCE"]
        kept = get_sel_idx(item, a, b, g)
        shown = [words[i] for i in sorted(kept)[:25]]
        parts.append(
            f"\n**CAFP+REINFORCE β kept words (first 25 shown):** \n"
            f"`{' Β· '.join(shown)}`\n"
        )
    return "".join(parts)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # LOAD CHECKPOINTS & DATA | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
print("\nLoading checkpoints and data...")
# ── RL checkpoint ─────────────────────────────────────────────────────
if not os.path.exists(CKPT_PATH):
    sys.exit(f"β Checkpoint not found: {CKPT_PATH}\n"
             f" Copy cafp_rl_checkpoint_final.pt into {CKPT_DIR}/")
# weights_only=False unpickles arbitrary objects — only load trusted
# checkpoints produced by our own training script.
ck = torch.load(CKPT_PATH, map_location=device, weights_only=False)
spatial_proj.load_state_dict(ck["spatial_proj_state"])
# Two CAFP heads: soft-oracle pre-trained, then REINFORCE fine-tuned.
cafp_soft = CrossAttentionFusionPredictor().to(device)
cafp_soft.load_state_dict(ck["cafp_soft_state"]); cafp_soft.eval()
cafp_rl = copy.deepcopy(cafp_soft)
cafp_rl.load_state_dict(ck["cafp_rl_state"]); cafp_rl.eval()
rl_train_anls = ck["rl_train_anls"]
# Older checkpoints lack a stored val score; approximate with best train.
rl_val_anls = ck.get("rl_val_anls",
                     max(rl_train_anls) if rl_train_anls else 0.0)
print(f" β CAFP+REINFORCE: {len(rl_train_anls)} epochs | "
      f"best_train={max(rl_train_anls):.4f} | val={rl_val_anls:.4f}")
# ── Dataset ───────────────────────────────────────────────────────────
print(" Loading dataset (~30s)...")
from datasets import load_dataset
_ds = load_dataset(DATASET_NAME, split="train")
# Deterministic 80/20 split + fixed-seed subsampling so example indices
# line up with the cached feature tensors and oracle files.
_split = _ds.train_test_split(test_size=0.2, seed=42)
rng = np.random.RandomState(42)
val_idx = rng.permutation(len(_split["test"])).tolist()[:N_VAL]
train_idx = rng.permutation(len(_split["train"])).tolist()[:N_TRAIN]
val_items = [_split["test"][i] for i in val_idx]
train_items = [_split["train"][i] for i in train_idx]
val_gts = [item[ANSWER_FIELD] for item in val_items]
train_gts = [item[ANSWER_FIELD] for item in train_items]
print(f" β Dataset: {len(val_items)} val, {len(train_items)} train")
# ── Feature tensors ───────────────────────────────────────────────────
if os.path.exists(FEAT_PATH):
    t = torch.load(FEAT_PATH, map_location=device, weights_only=False)
    val_feats = t["val_feats"]
    train_feats = t["train_feats"]
    print(f" β Features: {tuple(val_feats.shape)}")
else:
    print(" β οΈ feature_tensors.pt not found β recomputing (~2 min)...")
    def _feats(items, tag):
        # Stack one 2701-D feature vector per item into an (N, 2701) tensor.
        out = []
        for i, item in enumerate(items):
            out.append(build_feature_vector(
                extract_rich_features(item)).unsqueeze(0))
            if (i + 1) % 10 == 0:
                print(f" {tag}: {i+1}/{len(items)}", end="\r")
        print()
        return torch.cat(out).to(device)
    val_feats = _feats(val_items, "val")
    train_feats = _feats(train_items, "train")
    # Cache to disk so subsequent launches skip the expensive recompute.
    torch.save({"val_feats": val_feats, "train_feats": train_feats,
                "val_gts": val_gts, "train_gts": train_gts}, FEAT_PATH)
    print(f" β Features computed and saved to {FEAT_PATH}")
# ── Oracle cache ──────────────────────────────────────────────────────
# Optional per-example oracle scores; the demo works without them.
# (Fix: the original `val_oracle = train_oracle = []` aliased ONE list
# to both names, and `json.load(open(...))` leaked the file handle.)
val_oracle = []
train_oracle = []
if os.path.exists(ORACLE_CACHE):
    with open(ORACLE_CACHE) as _oc_f:
        _oc = json.load(_oc_f)
    train_oracle = _oc.get("train", [])
    val_oracle = _oc.get("val", [])
    print(f" β Oracle cache: {len(train_oracle)} train, {len(val_oracle)} val")
else:
    print(" β οΈ oracle_cache.json not found β demo works without it")
# ── Results from JSON ─────────────────────────────────────────────────
# Load precomputed benchmark numbers for the dashboard, trying several
# candidate paths. Only known method keys are taken; each entry is
# normalized to {"mean_anls": float, "mean_f1": float}.
results = {}
_RKEYS = [
    "Equal Fusion", "Proposed Fixed", "Text-Only",
    "LLMLingua-style", "Selective Context-style",
    "CAFP (paper checkpoint)", "CAFP-Hard Oracle", "CAFP-Soft Oracle",
]
for _rpath in [RESULTS_PATH, "./final_results.json", "./results_condensed.json"]:
    try:
        # with-block closes the handle (original `json.load(open(...))`
        # leaked it); FileNotFoundError falls through to the next path.
        with open(_rpath) as _res_f:
            _raw = json.load(_res_f)
        for k in _RKEYS:
            if k in _raw and isinstance(_raw[k], dict):
                r = _raw[k]
                # Accept both long ("mean_anls") and short ("anls") keys.
                results[k] = {
                    "mean_anls": float(r.get("mean_anls", r.get("anls", 0.0))),
                    "mean_f1": float(r.get("mean_f1", r.get("f1", 0.0))),
                }
        if results:
            print(f" β Results: {len(results)} methods from {_rpath}")
            break
    except Exception:
        continue
if not results:
    print(" β οΈ Results JSON not found β dashboard will show partial data")
# ── Find best demo documents ──────────────────────────────────────────
# Rank val docs by (adaptive ANLS - fixed ANLS) so the demo defaults to
# a document where REINFORCE visibly beats the fixed weights.
print("\nPre-scoring documents for demo (this takes ~2 min)...")
demo_scores = []
cafp_rl.eval()
with torch.no_grad():
    for i in range(len(val_items)):
        fv = val_feats[i].unsqueeze(0)
        # Same weight parameterization as get_rl_weights():
        # softplus + 0.1 floor, normalized to a simplex.
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        w = (conc / conc.sum()).squeeze(0).cpu().tolist()
        rl_s = compute_anls(vqa_infer(val_items[i], w[0], w[1], w[2]),
                            val_gts[i])
        fx_s = compute_anls(vqa_infer(val_items[i], 0.5, 0.3, 0.2),
                            val_gts[i])
        demo_scores.append((i, round(rl_s - fx_s, 4),
                            round(rl_s, 4), round(fx_s, 4)))
        if (i + 1) % 20 == 0:
            print(f" {i+1}/100", end="\r")
# Sort by gap, largest first; demo_scores[0] is the showcase document.
demo_scores.sort(key=lambda x: -x[1])
best_idx = demo_scores[0][0]
top5_str = ", ".join([f"#{x[0]}(+{x[1]:.2f})" for x in demo_scores[:5]])
print(f"\n β Best docs: {top5_str}")
print(f"\n{'='*55}")
print("ALL MODELS LOADED β ready to demo")
print(f"{'='*55}\n")
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # GRADIO FUNCTIONS | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def get_rl_weights(idx, custom_q=None):
    """Predict (alpha, beta, gamma) for val document *idx*.

    With a non-empty custom question, the 2701-D feature vector is
    rebuilt from scratch (the question embedding changes it); otherwise
    the precomputed val_feats row is reused. Returns a list of 3 floats
    summing to 1.
    """
    if custom_q and custom_q.strip():
        _item = dict(val_items[idx])
        _item[QUERY_FIELD] = custom_q.strip()
        fv = build_feature_vector(extract_rich_features(_item)).unsqueeze(0)
    else:
        fv = val_feats[idx].unsqueeze(0)
    with torch.no_grad():
        # softplus + 0.1 keeps every weight strictly positive before
        # normalizing to a simplex (matches the RL parameterization).
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        w = (conc / conc.sum()).squeeze(0).cpu().tolist()
    return w
def make_weight_chart(mw):
    """Grouped bar chart of the (alpha, beta, gamma) weights per method.

    mw: dict {method_name: (alpha, beta, gamma)}. Returns a matplotlib
    Figure (Agg backend) consumed by gr.Plot.
    """
    fig, ax = plt.subplots(figsize=(9, 3.5))
    labels = list(mw.keys())
    weights = list(mw.values())  # hoisted: original rebuilt this list per bar group
    x, bw = np.arange(len(labels)), 0.25
    for j, (lbl, col) in enumerate([
        ("\u03b1 Text", "#2196F3"),
        ("\u03b2 Visual", "#4CAF50"),
        ("\u03b3 Spatial", "#FF9800"),
    ]):
        vals = [weights[i][j] for i in range(len(labels))]
        bars = ax.bar(x + (j - 1) * bw, vals, bw,
                      label=lbl, color=col, alpha=0.85)
        # Annotate each bar with its numeric weight.
        for bar in bars:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, h + 0.01,
                    f"{h:.2f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=10)
    ax.set_ylabel("Weight"); ax.set_ylim(0, 1.2)
    ax.set_title("Fusion Weights (\u03b1, \u03b2, \u03b3) per Method",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=9); ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    return fig
def run_demo(doc_idx, custom_q):
    """Main demo callback: evaluate all fusion configs on one document.

    Returns (document image, step-by-step markdown, weight chart, fixed
    selection image, RL selection image, compression markdown) in the
    order expected by the Gradio outputs list.
    """
    doc_idx = int(doc_idx)
    item = val_items[doc_idx]
    gt = val_gts[doc_idx]
    q = (custom_q.strip()
         if custom_q and custom_q.strip()
         else get_question(item))
    # Show at most the first two ground-truth answers in the header.
    gt_str = (", ".join(str(g) for g in gt[:2])
              if isinstance(gt, list) else str(gt))
    n_words = len(list(item.get(WORD_FIELD, [])))
    # Cosmetic label only (40-word threshold heuristic).
    doc_type = "Text-dominant" if n_words > 40 else "Visual-dominant"
    alpha, beta, gamma = get_rl_weights(doc_idx, custom_q)
    dom = ("Text" if alpha > 0.65 else
           "Visual" if beta > 0.40 else "Balanced")
    # Baselines plus the adaptive policy, all on the same question.
    cfgs = {
        "Equal Fusion": (1/3, 1/3, 1/3),
        "Fixed (0.5,0.3,0.2)": (0.5, 0.3, 0.2),
        "Text-Only": (1.0, 0.0, 0.0),
        "CAFP+REINFORCE": (alpha, beta, gamma),
    }
    res = {}
    for name, (a, b, g) in cfgs.items():
        demo_item = dict(item); demo_item[QUERY_FIELD] = q
        pred = vqa_infer(demo_item, a, b, g)
        res[name] = {
            "pred": pred,
            "anls": compute_anls(pred, gt),
            "f1": compute_f1(pred, gt),
            "w": (a, b, g),
        }
    best = max(res, key=lambda k: res[k]["anls"])
    rl_vs_fixed = res["CAFP+REINFORCE"]["anls"] - res["Fixed (0.5,0.3,0.2)"]["anls"]
    # ── Step-by-step markdown report ──────────────────────────────────
    md = f"## Document #{doc_idx} \u2014 {doc_type} ({n_words} words)\n\n"
    md += f"**Question:** {q}\n\n"
    md += f"**Ground Truth:** `{gt_str}`\n\n---\n"
    md += "### Step 1 \u2014 Text Extraction\n"
    md += f"`{n_words}` OCR words extracted via LayoutLMv3\n\n"
    md += "### Step 2 \u2014 Multimodal Feature Extraction\n"
    md += ("- **Text** \u2192 LayoutLMv3 token embeddings [768-D]\n"
           "- **Visual** \u2192 LayoutLMv3 patch features [768-D]\n"
           "- **Spatial** \u2192 Bounding box layout encoding [768-D]\n\n")
    md += "### Step 3 \u2014 CAFP+REINFORCE Weight Prediction\n"
    md += "| Modality | Weight |\n|----------|--------|\n"
    md += f"| \u03b1 Text | **{alpha:.3f}** |\n"
    md += f"| \u03b2 Visual | **{beta:.3f}** |\n"
    md += f"| \u03b3 Spatial | **{gamma:.3f}** |\n\n"
    md += f"\u2192 **Dominant: {dom}**\n\n"
    md += "### Step 4 \u2014 Adaptive Fusion \u2192 Answer\n"
    md += "| Method | \u03b1 | \u03b2 | \u03b3 | Answer | ANLS | F1 |\n"
    md += "|--------|---|---|---|--------|------|----|\n"
    for name, d in res.items():
        a, b, g = d["w"]
        star = " \u2b50" if name == best else ""
        md += (f"| {name}{star} | {a:.2f} | {b:.2f} | {g:.2f}"
               f" | `{d['pred']}` | **{d['anls']:.4f}** | {d['f1']:.4f} |\n")
    sign = "+" if rl_vs_fixed >= 0 else ""
    md += (f"\n---\n**Best:** {best} (ANLS: {res[best]['anls']:.4f})\n\n"
           f"**CAFP+REINFORCE answer:** `{res['CAFP+REINFORCE']['pred']}`\n\n"
           f"**\u0394 over Fixed:** {sign}{rl_vs_fixed:.4f}\n")
    chart = make_weight_chart({k: v["w"] for k, v in res.items()})
    # ── Word Selection Visualizations ─────────────────────────────────
    _item_q = dict(item); _item_q[QUERY_FIELD] = q
    fixed_vis = draw_selection(
        _item_q, 0.5, 0.3, 0.2,
        "Fixed Weights (0.5, 0.3, 0.2)"
    )
    rl_vis = draw_selection(
        _item_q, alpha, beta, gamma,
        f"CAFP+REINFORCE (Ξ±={alpha:.2f} Ξ²={beta:.2f} Ξ³={gamma:.2f})"
    )
    comp_md = make_compression_md(item, cfgs)
    return item.get("image", None), md, chart, fixed_vis, rl_vis, comp_md
def show_dashboard():
    """Build the Results-Dashboard outputs.

    Returns (markdown summary table, method bar chart, REINFORCE
    training curve). Methods missing from the results JSON render as 0.
    """
    # Safe accessors: missing methods score 0 rather than raising.
    def sg(k): return results.get(k, {}).get("mean_anls", 0.0)
    def sf(k): return results.get(k, {}).get("mean_f1", 0.0)
    # Oracle upper bound hard-coded from the offline grid search —
    # keep in sync with the About tab (TODO: load from results JSON).
    fixed = sg("Proposed Fixed"); oracle = 0.8377
    rv = rl_val_anls
    rows = [
        ("Equal Fusion", sg("Equal Fusion"), sf("Equal Fusion")),
        ("Proposed Fixed (paper)", sg("Proposed Fixed"), sf("Proposed Fixed")),
        ("Text-Only", sg("Text-Only"), sf("Text-Only")),
        ("LLMLingua-style [NEW]", sg("LLMLingua-style"), sf("LLMLingua-style")),
        ("Selective Context [NEW]", sg("Selective Context-style"), sf("Selective Context-style")),
        ("CAFP paper checkpoint", sg("CAFP (paper checkpoint)"), sf("CAFP (paper checkpoint)")),
        ("CAFP Hard Oracle [NEW]", sg("CAFP-Hard Oracle"), sf("CAFP-Hard Oracle")),
        ("CAFP Soft Oracle [NEW]", sg("CAFP-Soft Oracle"), sf("CAFP-Soft Oracle")),
        ("CAFP+REINFORCE [NEW][BEST]", rv, 0.0),   # no F1 tracked for these two
        ("Oracle Upper Bound", oracle, 0.0),
    ]
    md = "## Full Experiment Results\n\n"
    md += "| Method | ANLS | F1 | \u0394 Fixed | % Oracle |\n"
    md += "|--------|------|----|----------|----------|\n"
    for name, anls, f1 in rows:
        is_oracle = "Oracle Upper" in name
        d = f"{anls - fixed:+.4f}" if not is_oracle else "\u2014"
        pct = f"{anls / oracle * 100:.1f}%" if anls > 0 else "\u2014"
        md += f"| {name} | {anls:.4f} | {f1:.4f} | {d} | {pct} |\n"
    md += (f"\n**CAFP+REINFORCE: {rv/oracle*100:.1f}% of Oracle ANLS**\n"
           f"**Improvement over Fixed: {rv - fixed:+.4f} ANLS**\n")
    # Bar chart — one horizontal bar per method, colour-coded by family.
    fig1, ax1 = plt.subplots(figsize=(11, 5))
    bv = [r[1] for r in rows]
    bc = ["#bbb","#999","#bbb","#2196F3","#2196F3",
          "#777","#4CAF50","#4CAF50","#FF5722","#d32f2f"]
    bars = ax1.barh([r[0] for r in rows], bv,
                    color=bc, edgecolor="white", height=0.65)
    ax1.axvline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle {oracle:.4f}")
    ax1.axvline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed {fixed:.4f}")
    # Annotate non-zero bars with their value.
    for bar, val in zip(bars, bv):
        if val > 0:
            ax1.text(val + 0.003, bar.get_y() + bar.get_height() / 2,
                     f"{val:.4f}", va="center", fontsize=8)
    ax1.set_xlabel("Val ANLS", fontsize=11)
    ax1.set_title("All Methods \u2014 Val ANLS", fontsize=13, fontweight="bold")
    ax1.legend(fontsize=9); ax1.invert_yaxis()
    ax1.set_xlim(0, oracle * 1.1); ax1.grid(axis="x", alpha=0.3)
    plt.tight_layout()
    # Training curve — per-epoch train ANLS with reference lines.
    fig2, ax2 = plt.subplots(figsize=(10, 3.5))
    eps = list(range(1, len(rl_train_anls) + 1))
    ax2.plot(eps, rl_train_anls, "o-", color="#FF5722",
             lw=2.5, ms=7, label="Train ANLS")
    ax2.axhline(rv, color="#FF5722", linestyle=":", lw=2,
                label=f"Val ANLS = {rv:.4f}")
    ax2.axhline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle = {oracle:.4f}")
    ax2.axhline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed = {fixed:.4f}")
    # Shade the gain over the fixed baseline.
    ax2.fill_between(eps, rl_train_anls, fixed, alpha=0.15, color="#FF5722")
    ax2.set_xlabel("Epoch"); ax2.set_ylabel("ANLS")
    ax2.set_title("REINFORCE Fine-tuning Progress",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=9); ax2.grid(True, alpha=0.3)
    ax2.set_xticks(eps); plt.tight_layout()
    return md, fig1, fig2
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # GRADIO UI | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ──────────────────────────────────────────────────────────────────────
# GRADIO UI
# ──────────────────────────────────────────────────────────────────────
_fixed_anls = results.get("Proposed Fixed", {}).get("mean_anls", 0.0)
_best_label = ("Best docs (REINFORCE wins most): "
               + ", ".join([f"#{x[0]}" for x in demo_scores[:5]]))
CSS = ".tab-nav button { font-size: 15px !important; font-weight: 600 !important; }"
with gr.Blocks(
    title="Adaptive Multimodal Fusion β DocVQA Demo",
    theme=gr.themes.Soft(primary_hue="blue"),
    css=CSS,
) as demo_app:
    gr.Markdown("""
    # Adaptive Multimodal Fusion for Document VQA
    ### Cross-Attention Fusion Predictor (CAFP) + REINFORCE Fine-tuning
    """)
    with gr.Tabs():
        # ── Tab 1: Live Demo ──────────────────────────────────────────
        with gr.TabItem("\U0001f3af Live Demo"):
            gr.Markdown(f"**{_best_label}**")
            with gr.Row():
                with gr.Column(scale=1):
                    # Default to the document with the biggest RL-vs-fixed gap.
                    doc_slider = gr.Slider(
                        0, len(val_items) - 1,
                        value=best_idx, step=1,
                        label=f"Document Index (0\u2013{len(val_items)-1})"
                    )
                    custom_q = gr.Textbox(
                        label="Custom Question (optional)",
                        placeholder="Leave blank to use original question"
                    )
                    run_btn = gr.Button(
                        "\u25b6 Run Adaptive Fusion",
                        variant="primary", size="lg"
                    )
                    gr.Markdown(
                        "*Compares: Equal \u00b7 Fixed \u00b7 "
                        "Text-Only \u00b7 CAFP+REINFORCE*"
                    )
                with gr.Column(scale=2):
                    doc_image = gr.Image(label="Document Image", height=400)
                    step_md = gr.Markdown()
                    weight_chart = gr.Plot(label="Fusion Weights Comparison")
            # ── Word Selection Visualizer ─────────────────────────────
            gr.Markdown("""
            ---
            ### π¨ Word Selection Visualization
            *See **exactly** which OCR words each method keeps vs discards.*
            π’ **Green** = kept and fed to the VQA model Β· π΄ **Red** = compressed out
            """)
            with gr.Row():
                fixed_vis_img = gr.Image(
                    label="π Fixed Weights (Ξ±=0.5 Ξ²=0.3 Ξ³=0.2)",
                    height=520
                )
                rl_vis_img = gr.Image(
                    label="π€ CAFP+REINFORCE (Adaptive Weights)",
                    height=520
                )
            comp_md_out = gr.Markdown()
            # Single click handler drives all six outputs of run_demo().
            run_btn.click(
                fn=run_demo,
                inputs=[doc_slider, custom_q],
                outputs=[doc_image, step_md, weight_chart,
                         fixed_vis_img, rl_vis_img, comp_md_out],
            )
        # ── Tab 2: Results Dashboard ──────────────────────────────────
        with gr.TabItem("\U0001f4ca Results Dashboard"):
            gr.Markdown("### All methods compared + REINFORCE training curve")
            # Loaded on demand: the charts are cheap but not free.
            load_btn = gr.Button("Load Results", variant="secondary")
            res_md = gr.Markdown()
            with gr.Row():
                bar_chart = gr.Plot(label="ANLS \u2014 All Methods")
                rl_curve = gr.Plot(label="REINFORCE Training Curve")
            load_btn.click(
                fn=show_dashboard,
                inputs=[],
                outputs=[res_md, bar_chart, rl_curve],
            )
        # ── Tab 3: About ──────────────────────────────────────────────
        with gr.TabItem("\u2139\ufe0f About"):
            gr.Markdown(f"""
            ## Adaptive Multimodal Fusion for DocVQA
            ### Problem
            DocVQA requires reasoning over three modalities simultaneously:
            - **Text** β OCR words and their semantics
            - **Visual** β Document appearance and image patches
            - **Spatial** β Bounding box positions and layout structure
            Fixed weights (Ξ±=0.5, Ξ²=0.3, Ξ³=0.2) cannot adapt to different document types.
            ### Architecture: CAFP (428K params)
            1. Projects each modality embedding to 128-D
            2. Cross-attention: question attends to all modality representations
            3. Predicts per-document (Ξ±, Ξ², Ξ³) fusion weights
            ### Training Pipeline
            1. **Hard Oracle** (MSE) β argmax weights from 20-combo grid search
            2. **Soft Oracle** (KL-div) β temperature-smoothed ANLS-weighted targets
            3. **REINFORCE** β Policy gradient on direct ANLS reward (K=3 samples/step)
            ### Novel Contributions
            1. Soft Oracle training eliminates hard-oracle label noise
            2. REINFORCE fine-tuning directly maximises DocVQA metric
            3. LLMLingua-style and Selective Context baselines for fair comparison
            ### Key Result
            **CAFP+REINFORCE achieves {rl_val_anls/0.8377*100:.1f}% of Oracle ANLS**
            Improvement over fixed-weight baseline: {rl_val_anls - _fixed_anls:+.4f} ANLS
            """)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # LAUNCH | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
if __name__ == "__main__":
    demo_app.launch(
        server_name="0.0.0.0",   # bind all interfaces (required on HF Spaces)
        server_port=args.port,   # HF Spaces expects 7860 (the default)
        share=args.share,        # optional public tunnel URL
        show_error=True,         # surface Python tracebacks in the UI
    )