Upload folder using huggingface_hub

- .DS_Store +0 -0
- aoo2.poy +890 -0
- app.py +890 -0
- checkpoints/.DS_Store +0 -0
- checkpoints/cafp_rl_checkpoint_final.pt +3 -0
- checkpoints/feature_tensors.pt +3 -0
- checkpoints/final_results.json +56 -0
- checkpoints/oracle_cache.json +0 -0
- requirements.txt +12 -0
.DS_Store
ADDED
Binary file (10.2 kB)
aoo2.poy
ADDED
@@ -0,0 +1,890 @@
"""
Adaptive Multimodal Fusion for DocVQA — Gradio Demo
====================================================
Run locally:
    python app.py

Run with public URL (72hr):
    python app.py --share

Deploy to HuggingFace Spaces:
    - Push this file + requirements.txt + checkpoints/ folder to a Space repo
    - HF Spaces auto-launches on port 7860
"""

import argparse, os, sys, copy, json, warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import gradio as gr
import editdistance
from PIL import Image as PILImage, ImageDraw as PILDraw, ImageFont as PILFont
warnings.filterwarnings("ignore")

# ── CLI args ──────────────────────────────────────────────────────────
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true", help="Create public Gradio URL")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--ckpt_dir", type=str, default="./checkpoints",
                    help="Folder containing all saved files")
args, _ = parser.parse_known_args()

# ──────────────────────────────────────────────────────────────────────
# CONFIGURATION — edit paths here if needed
# ──────────────────────────────────────────────────────────────────────
CKPT_DIR = args.ckpt_dir
ORACLE_CACHE = os.path.join(CKPT_DIR, "oracle_cache.json")
FEAT_PATH = os.path.join(CKPT_DIR, "feature_tensors.pt")
RESULTS_PATH = os.path.join(CKPT_DIR, "final_results.json")

# Try final checkpoint first, fall back to intermediate
CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint_final.pt")
if not os.path.exists(CKPT_PATH):
    CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint.pt")

# Dataset / model IDs
DATASET_NAME = "nielsr/docvqa_1200_examples"
FEAT_MODEL_ID = "microsoft/layoutlmv3-base"
VQA_MODEL_ID = "rubentito/layoutlmv3-base-mpdocvqa"
SBERT_ID = "all-MiniLM-L6-v2"

# Field names
WORD_FIELD = "words"
BOX_FIELD = "bounding_boxes"
QUERY_FIELD = "query"
ANSWER_FIELD = "answers"

# Architecture
MAX_WORDS = 64
N_PATCHES = 49
N_VAL = 100
N_TRAIN = 100
FEAT_DIM = 2701
PROJ_DIM = 128

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# ──────────────────────────────────────────────────────────────────────
# MODEL CLASSES
# ──────────────────────────────────────────────────────────────────────
class CrossAttentionFusionPredictor(nn.Module):
    def __init__(self, feat_dim=FEAT_DIM, proj_dim=PROJ_DIM,
                 n_heads=4, dropout=0.15):
        super().__init__()
        self.text_proj = nn.Linear(768, proj_dim)
        self.visual_proj = nn.Linear(768, proj_dim)
        self.spatial_proj_lyr = nn.Linear(768, proj_dim)
        self.q_proj = nn.Sequential(
            nn.Linear(384, proj_dim), nn.LayerNorm(proj_dim), nn.GELU()
        )
        self.cross_attn = nn.MultiheadAttention(
            proj_dim, n_heads, dropout=dropout, batch_first=True
        )
        self.attn_norm = nn.LayerNorm(proj_dim)
        self.head = nn.Sequential(
            nn.Linear(proj_dim + 3, proj_dim), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(proj_dim, 3)
        )

    def _logits(self, x):
        h_t = self.text_proj(x[:, 0:768])
        h_v = self.visual_proj(x[:, 768:1536])
        h_s = self.spatial_proj_lyr(x[:, 1536:2304])
        q = self.q_proj(x[:, 2314:2698]).unsqueeze(1)
        kv = torch.stack([h_t, h_v, h_s], dim=1)
        ctx, _ = self.cross_attn(q, kv, kv)
        ctx = self.attn_norm(ctx.squeeze(1))
        return self.head(torch.cat([ctx, x[:, 2698:2701]], dim=-1))

    def forward(self, x):
        return F.softmax(self._logits(x), dim=-1)

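
# Illustrative sketch (inferred from the slices in _logits above, not spelled
# out elsewhere in this file): the 2701-D input packs [0:768] text, [768:1536]
# visual, [1536:2304] spatial, [2304:2314] raw 10-D spatial stats (skipped by
# _logits), [2314:2698] SBERT question embedding, [2698:2701] prior scores.
#
#     cafp = CrossAttentionFusionPredictor()
#     x = torch.randn(1, FEAT_DIM)              # one document-question pair
#     alpha, beta, gamma = cafp(x)[0].tolist()  # softmax weights, sum to 1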

# ──────────────────────────────────────────────────────────────────────
# LOAD BASE MODELS
# ──────────────────────────────────────────────────────────────────────
print("Loading base models (takes ~5 min on first run, cached after)...")

from transformers import AutoProcessor, AutoModel, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer

feat_processor = AutoProcessor.from_pretrained(FEAT_MODEL_ID, apply_ocr=False)
feat_model = AutoModel.from_pretrained(FEAT_MODEL_ID).to(device).eval()
for p in feat_model.parameters(): p.requires_grad_(False)
print("  ✅ LayoutLMv3 feature model")

vqa_processor = AutoProcessor.from_pretrained(VQA_MODEL_ID, apply_ocr=False)
vqa_model = AutoModelForQuestionAnswering.from_pretrained(
    VQA_MODEL_ID).to(device).eval()
for p in vqa_model.parameters(): p.requires_grad_(False)
print("  ✅ VQA model")

sbert = SentenceTransformer(SBERT_ID)
sbert.to(device)
print("  ✅ SBERT")

spatial_proj = nn.Sequential(
    nn.Linear(10, 256), nn.ReLU(), nn.Linear(256, 768)
).to(device)

# ──────────────────────────────────────────────────────────────────────
# HELPER FUNCTIONS
# ──────────────────────────────────────────────────────────────────────
def get_question(item):
    q = item.get(QUERY_FIELD, item.get("question", ""))
    if isinstance(q, dict):
        q = q.get("en", next(iter(q.values()), ""))
    return str(q).strip()


def normalize_boxes(boxes, w, h):
    return [
        [
            int(max(0, min(b[0] / max(w, 1), 1)) * 1000),
            int(max(0, min(b[1] / max(h, 1), 1)) * 1000),
            int(max(0, min(b[2] / max(w, 1), 1)) * 1000),
            int(max(0, min(b[3] / max(h, 1), 1)) * 1000),
        ]
        for b in boxes
    ]

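
# Illustrative example: LayoutLMv3 expects boxes on a 0-1000 grid, so pixel
# coordinates are rescaled and clamped. A box [100, 50, 300, 80] on an
# 800x600 page maps to [125, 83, 375, 133]:
#     normalize_boxes([[100, 50, 300, 80]], 800, 600)
#     # -> [[125, 83, 375, 133]]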

def extract_rich_features(item):
    try:
        img = item["image"].convert("RGB")
        W, H = img.size
        words = list(item.get(WORD_FIELD, []))[:MAX_WORDS] or ["[PAD]"]
        boxes = list(item.get(BOX_FIELD, []))[:MAX_WORDS] or [[0, 0, 1, 1]]
        question = get_question(item)
        bn = normalize_boxes(boxes, W, H)
        enc = feat_processor(img, text=words, boxes=bn,
                             return_tensors="pt", truncation=True,
                             max_length=512, padding="max_length")
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            hidden = feat_model(**enc).last_hidden_state[0]
        n_txt = max(2, hidden.shape[0] - N_PATCHES)
        H_text = hidden[1:n_txt-1].mean(0) if n_txt > 2 else hidden[0]
        H_visual = hidden[-N_PATCHES:].mean(0)
        bx = np.array(bn, dtype=np.float32)
        cx = ((bx[:, 0] + bx[:, 2]) / 2) / 1000.0
        cy = ((bx[:, 1] + bx[:, 3]) / 2) / 1000.0
        sp = np.array([
            W / 1000.0, H / 1000.0, min(W, H) / max(W, H),
            len(words) / MAX_WORDS,
            cx.mean(), cy.mean(), cx.std() + 1e-6, cy.std() + 1e-6,
            H_text.norm().item() / 10.0,
            H_visual.norm().item() / 10.0,
        ], dtype=np.float32)
        sp10 = torch.tensor(sp).to(device)
        H_spat = spatial_proj(sp10.unsqueeze(0)).squeeze(0)
        q_emb = torch.tensor(sbert.encode(question),
                             dtype=torch.float32).to(device)
        return {
            "H_text": H_text, "H_visual": H_visual, "H_spatial": H_spat,
            "spatial_10": sp10, "question_emb": q_emb,
            "text_score": float(np.clip(sp[8], 0, 1)),
            "visual_score": float(np.clip(sp[9], 0, 1)),
            "spatial_score": float(np.clip(sp[6], 0, 1)),
            "n_tokens": len(words),
        }
    except Exception as e:
        print(f"  extract_rich_features error: {e}")
        dummy = torch.zeros(768, device=device)
        return {
            "H_text": dummy, "H_visual": dummy, "H_spatial": dummy,
            "spatial_10": torch.zeros(10, device=device),
            "question_emb": torch.zeros(384, device=device),
            "text_score": 0.5, "visual_score": 0.3, "spatial_score": 0.2,
            "n_tokens": 0,
        }


def build_feature_vector(feat):
    return torch.cat([
        feat["H_text"], feat["H_visual"], feat["H_spatial"],
        feat["spatial_10"], feat["question_emb"],
        torch.tensor(
            [feat["text_score"], feat["visual_score"], feat["spatial_score"]],
            dtype=torch.float32, device=device
        ),
    ])

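
# Sanity sketch: the concatenation above yields exactly FEAT_DIM = 2701 dims
# (768 text + 768 visual + 768 spatial + 10 raw stats + 384 SBERT + 3 prior
# scores), matching the slices consumed by CrossAttentionFusionPredictor:
#     fv = build_feature_vector(extract_rich_features(item))
#     assert fv.shape == (FEAT_DIM,)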

def vqa_infer(item, alpha, beta, gamma):
    try:
        img = item["image"].convert("RGB")
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        question = get_question(item)
        if not words:
            return ""
        W, H = img.size
        n = len(words)
        n_keep = max(int(n * max(float(alpha), 0.30)), min(5, n))
        if float(gamma) > max(float(alpha), float(beta)) and boxes:
            order = sorted(range(n), key=lambda i: (boxes[i][1], boxes[i][0]))
            sel_idx = sorted(order[:n_keep])
        else:
            sel_idx = list(range(n_keep))
        sw = [words[i] for i in sel_idx]
        sb = ([boxes[i] for i in sel_idx]
              if boxes else [[0, 0, W, H]] * len(sw))
        enc = vqa_processor(
            img, text=question, text_pair=sw,
            boxes=normalize_boxes(sb, W, H),
            return_tensors="pt", truncation=True,
            max_length=512, padding=True
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            out = vqa_model(**enc)
        s = int(out.start_logits.argmax())
        e = int(out.end_logits.argmax())
        if e < s: e = s
        return vqa_processor.tokenizer.decode(
            enc["input_ids"][0][s:e+1], skip_special_tokens=True
        ).strip()
    except Exception:
        return ""

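
# Worked example of the keep rule above: with n = 100 OCR words, alpha = 0.72
# keeps max(int(100 * 0.72), 5) = 72 words; alpha = 0.10 is floored at 0.30,
# keeping 30. When gamma dominates both alpha and beta, the kept words are
# re-ranked in reading order (top-to-bottom, left-to-right) before truncation.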

def compute_anls(pred, gts, threshold=0.5):
    if isinstance(gts, str): gts = [gts]
    if not gts or not pred: return 0.0
    p, best = str(pred).lower().strip(), 0.0
    for gt in gts:
        g = str(gt).lower().strip()
        ml = max(len(p), len(g))
        if ml == 0:
            best = max(best, 1.0); continue
        nls = 1.0 - editdistance.eval(p, g) / ml
        if nls < threshold: nls = 0.0
        best = max(best, nls)
    return best


def compute_f1(pred, gts):
    if isinstance(gts, str): gts = [gts]
    if not pred or not gts: return 0.0
    pt = set(str(pred).lower().split())
    if not pt: return 0.0
    best = 0.0
    for gt in gts:
        gt_t = set(str(gt).lower().split())
        if not gt_t: continue
        common = pt & gt_t
        if not common: continue
        p = len(common) / len(pt)
        r = len(common) / len(gt_t)
        best = max(best, 2 * p * r / (p + r))
    return best

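
# Illustrative scores for the metrics above: ANLS is 1 minus the normalized
# Levenshtein distance, zeroed below the 0.5 threshold and maxed over all
# ground-truth answers.
#     compute_anls("2019", ["2019"])      # 1.0  (exact match)
#     compute_anls("209", ["2019"])       # 0.75 (1 edit over 4 chars)
#     compute_anls("cat", ["2019"])       # 0.0  (falls below threshold)
#     compute_f1("april 2019", ["2019"])  # 0.667 (p = 0.5, r = 1.0)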

# ──────────────────────────────────────────────────────────────────────
# WORD SELECTION VISUALIZER HELPERS
# ──────────────────────────────────────────────────────────────────────

def get_sel_idx(item, alpha, beta, gamma):
    """Return the SET of word indices kept by this (alpha, beta, gamma) config.

    Mirrors the exact selection logic in vqa_infer so the boxes always
    match what the model actually sees.
    """
    words = list(item.get(WORD_FIELD, []))
    boxes = list(item.get(BOX_FIELD, []))
    n = len(words)
    if n == 0:
        return set()
    n_keep = max(int(n * max(float(alpha), 0.30)), min(5, n))
    n_keep = min(n_keep, n)
    if float(gamma) > max(float(alpha), float(beta)) and boxes:
        order = sorted(range(n), key=lambda i: (boxes[i][1], boxes[i][0]))
        sel_idx = set(order[:n_keep])
    else:
        sel_idx = set(range(n_keep))
    return sel_idx


def draw_selection(item, alpha, beta, gamma, title=""):
    """Return a PIL Image with coloured bounding boxes overlaid.

    🟢 Green fill + outline → word KEPT (used for VQA)
    🔴 Red fill + outline → word DROPPED (compressed out)

    An info strip (dark) and a colour legend strip are appended below the
    document image so the panel is self-explanatory at a glance.
    """
    try:
        img = item["image"].convert("RGB").copy()
        W, H = img.size
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        n = min(len(words), len(boxes))
        if n == 0:
            return img

        sel_idx = get_sel_idx(item, alpha, beta, gamma)
        n_keep = len(sel_idx)
        pct = 100 * n_keep / max(n, 1)

        # ── Draw semi-transparent coloured overlays ───────────────────
        overlay = PILImage.new("RGBA", img.size, (0, 0, 0, 0))
        od = PILDraw.Draw(overlay)
        for i in range(n):
            try:
                x0, y0, x1, y1 = (int(boxes[i][0]), int(boxes[i][1]),
                                  int(boxes[i][2]), int(boxes[i][3]))
                # Clamp to image bounds
                x0, x1 = max(0, x0), min(W - 1, x1)
                y0, y1 = max(0, y0), min(H - 1, y1)
                if x1 <= x0 or y1 <= y0:
                    continue
                if i in sel_idx:
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(0, 210, 0, 55),
                                 outline=(0, 160, 0, 230), width=2)
                else:
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(220, 30, 30, 40),
                                 outline=(200, 0, 0, 170), width=1)
            except Exception:
                continue
        img = PILImage.alpha_composite(img.convert("RGBA"), overlay).convert("RGB")

        # ── Load font (graceful fallback) ──────────────────────────────
        font_sm = PILFont.load_default()
        for _fp in [
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "/Windows/Fonts/arial.ttf",
        ]:
            try:
                font_sm = PILFont.truetype(_fp, 13)
                break
            except Exception:
                continue

        # ── Info strip (dark bar showing title + stats) ────────────────
        strip_h = 36
        strip = PILImage.new("RGB", (W, strip_h), (22, 22, 32))
        sd = PILDraw.Draw(strip)
        info_text = (f"{title} | ✓ Kept: {n_keep}/{n} ({pct:.0f}%)"
                     f" | α={alpha:.2f} β={beta:.2f} γ={gamma:.2f}")
        sd.text((8, 11), info_text, fill=(220, 220, 220), font=font_sm)

        # ── Legend strip (light bar explaining colours) ────────────────
        leg_h = 28
        leg = PILImage.new("RGB", (W, leg_h), (246, 246, 246))
        ld = PILDraw.Draw(leg)
        ld.rectangle([8, 7, 24, 21], fill=(0, 180, 0), outline=(0, 130, 0, 255))
        ld.text((30, 8), "= Kept (used for VQA)", fill=(0, 110, 0), font=font_sm)
        ld.rectangle([210, 7, 226, 21], fill=(220, 30, 30), outline=(170, 0, 0, 255))
        ld.text((232, 8), "= Dropped (compressed out)", fill=(140, 0, 0), font=font_sm)

        # ── Stack: image → dark strip → legend ─────────────────────────
        final = PILImage.new("RGB", (W, H + strip_h + leg_h), (255, 255, 255))
        final.paste(img, (0, 0))
        final.paste(strip, (0, H))
        final.paste(leg, (0, H + strip_h))
        return final

    except Exception as e:
        print(f"  draw_selection error: {e}")
        return item.get("image", None)


def make_compression_md(item, cfgs):
    """Build a markdown table showing kept / dropped word statistics and
    a sample of the words that each method discards.

    cfgs — OrderedDict/dict {method_name: (alpha, beta, gamma)}
    """
    words = list(item.get(WORD_FIELD, []))
    n = len(words)
    if n == 0:
        return "*No OCR words available for this document.*"

    md = "### 📊 What Gets Compressed?\n\n"
    md += f"**Total OCR words in document:** {n}\n\n"
    md += ("| Method | α | β | γ | Words Kept | % Context |"
           " Sample Dropped Words |\n")
    md += ("|--------|---|---|---|:----------:|:---------:|"
           "----------------------|\n")

    for name, (a, b, g) in cfgs.items():
        sel = get_sel_idx(item, a, b, g)
        n_keep = len(sel)
        pct = 100 * n_keep / max(n, 1)
        dropped = [words[i] for i in range(n) if i not in sel]
        d_preview = " · ".join(dropped[:8])
        if len(dropped) > 8:
            d_preview += f" … (+{len(dropped) - 8} more)"
        md += (f"| **{name}** | {a:.2f} | {b:.2f} | {g:.2f}"
               f" | {n_keep} / {n} | {pct:.0f}% | `{d_preview}` |\n")

    # Show the actual kept words for the CAFP+REINFORCE method
    if "CAFP+REINFORCE" in cfgs:
        a, b, g = cfgs["CAFP+REINFORCE"]
        sel = get_sel_idx(item, a, b, g)
        kept_w = [words[i] for i in sorted(sel)[:25]]
        md += (f"\n**CAFP+REINFORCE — kept words (first 25 shown):** \n"
               f"`{' · '.join(kept_w)}`\n")

    return md

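
# Usage sketch for the helper above: make_compression_md expects a
# {name: (alpha, beta, gamma)} mapping, e.g.
#     make_compression_md(item, {"Equal Fusion": (1/3, 1/3, 1/3),
#                                "Text-Only":    (1.0, 0.0, 0.0)})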

# ──────────────────────────────────────────────────────────────────────
# LOAD CHECKPOINTS & DATA
# ──────────────────────────────────────────────────────────────────────
print("\nLoading checkpoints and data...")

# ── RL checkpoint ─────────────────────────────────────────────────────
if not os.path.exists(CKPT_PATH):
    sys.exit(f"❌ Checkpoint not found: {CKPT_PATH}\n"
             f"   Copy cafp_rl_checkpoint_final.pt into {CKPT_DIR}/")

ck = torch.load(CKPT_PATH, map_location=device, weights_only=False)
spatial_proj.load_state_dict(ck["spatial_proj_state"])

cafp_soft = CrossAttentionFusionPredictor().to(device)
cafp_soft.load_state_dict(ck["cafp_soft_state"]); cafp_soft.eval()

cafp_rl = copy.deepcopy(cafp_soft)
cafp_rl.load_state_dict(ck["cafp_rl_state"]); cafp_rl.eval()

rl_train_anls = ck["rl_train_anls"]
rl_val_anls = ck.get("rl_val_anls",
                     max(rl_train_anls) if rl_train_anls else 0.0)
print(f"  ✅ CAFP+REINFORCE: {len(rl_train_anls)} epochs | "
      f"best_train={max(rl_train_anls):.4f} | val={rl_val_anls:.4f}")

# ── Dataset ───────────────────────────────────────────────────────────
print("  Loading dataset (~30s)...")
from datasets import load_dataset
_ds = load_dataset(DATASET_NAME, split="train")
_split = _ds.train_test_split(test_size=0.2, seed=42)
rng = np.random.RandomState(42)
val_idx = rng.permutation(len(_split["test"])).tolist()[:N_VAL]
train_idx = rng.permutation(len(_split["train"])).tolist()[:N_TRAIN]
val_items = [_split["test"][i] for i in val_idx]
train_items = [_split["train"][i] for i in train_idx]
val_gts = [item[ANSWER_FIELD] for item in val_items]
train_gts = [item[ANSWER_FIELD] for item in train_items]
print(f"  ✅ Dataset: {len(val_items)} val, {len(train_items)} train")

# ── Feature tensors ───────────────────────────────────────────────────
if os.path.exists(FEAT_PATH):
    t = torch.load(FEAT_PATH, map_location=device, weights_only=False)
    val_feats = t["val_feats"]
    train_feats = t["train_feats"]
    print(f"  ✅ Features: {tuple(val_feats.shape)}")
else:
    print("  ⚠️ feature_tensors.pt not found — recomputing (~2 min)...")
    def _feats(items, tag):
        out = []
        for i, item in enumerate(items):
            out.append(build_feature_vector(
                extract_rich_features(item)).unsqueeze(0))
            if (i + 1) % 10 == 0:
                print(f"    {tag}: {i+1}/{len(items)}", end="\r")
        print()
        return torch.cat(out).to(device)
    val_feats = _feats(val_items, "val")
    train_feats = _feats(train_items, "train")
    torch.save({"val_feats": val_feats, "train_feats": train_feats,
                "val_gts": val_gts, "train_gts": train_gts}, FEAT_PATH)
    print(f"  ✅ Features computed and saved to {FEAT_PATH}")

# ── Oracle cache ──────────────────────────────────────────────────────
val_oracle = train_oracle = []
if os.path.exists(ORACLE_CACHE):
    _oc = json.load(open(ORACLE_CACHE))
    train_oracle = _oc.get("train", [])
    val_oracle = _oc.get("val", [])
    print(f"  ✅ Oracle cache: {len(train_oracle)} train, {len(val_oracle)} val")
else:
    print("  ⚠️ oracle_cache.json not found — demo works without it")

# ── Results from JSON ─────────────────────────────────────────────────
results = {}
_RKEYS = [
    "Equal Fusion", "Proposed Fixed", "Text-Only",
    "LLMLingua-style", "Selective Context-style",
    "CAFP (paper checkpoint)", "CAFP-Hard Oracle", "CAFP-Soft Oracle",
]
for _rpath in [RESULTS_PATH, "./final_results.json", "./results_condensed.json"]:
    try:
        _raw = json.load(open(_rpath))
        for k in _RKEYS:
            if k in _raw and isinstance(_raw[k], dict):
                r = _raw[k]
                results[k] = {
                    "mean_anls": float(r.get("mean_anls", r.get("anls", 0.0))),
                    "mean_f1": float(r.get("mean_f1", r.get("f1", 0.0))),
                }
        if results:
            print(f"  ✅ Results: {len(results)} methods from {_rpath}")
            break
    except Exception:
        continue
if not results:
    print("  ⚠️ Results JSON not found — dashboard will show partial data")

# ── Find best demo documents ──────────────────────────────────────────
print("\nPre-scoring documents for demo (this takes ~2 min)...")
demo_scores = []
cafp_rl.eval()
with torch.no_grad():
    for i in range(len(val_items)):
        fv = val_feats[i].unsqueeze(0)
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        w = (conc / conc.sum()).squeeze(0).cpu().tolist()
        rl_s = compute_anls(vqa_infer(val_items[i], w[0], w[1], w[2]),
                            val_gts[i])
        fx_s = compute_anls(vqa_infer(val_items[i], 0.5, 0.3, 0.2),
                            val_gts[i])
        demo_scores.append((i, round(rl_s - fx_s, 4),
                            round(rl_s, 4), round(fx_s, 4)))
        if (i + 1) % 20 == 0:
            print(f"  {i+1}/100", end="\r")
demo_scores.sort(key=lambda x: -x[1])
best_idx = demo_scores[0][0]
top5_str = ", ".join([f"#{x[0]}(+{x[1]:.2f})" for x in demo_scores[:5]])
print(f"\n  ✅ Best docs: {top5_str}")
print(f"\n{'='*55}")
print("ALL MODELS LOADED — ready to demo")
print(f"{'='*55}\n")

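
# Caveat (general PyTorch behavior, not specific to this repo): the
# torch.load(..., weights_only=False) calls above unpickle arbitrary objects,
# so the checkpoint files should only come from a trusted source.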

# ──────────────────────────────────────────────────────────────────────
# GRADIO FUNCTIONS
# ──────────────────────────────────────────────────────────────────────
def get_rl_weights(idx, custom_q=None):
    if custom_q and custom_q.strip():
        _item = dict(val_items[idx])
        _item[QUERY_FIELD] = custom_q.strip()
        fv = build_feature_vector(extract_rich_features(_item)).unsqueeze(0)
    else:
        fv = val_feats[idx].unsqueeze(0)
    with torch.no_grad():
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        w = (conc / conc.sum()).squeeze(0).cpu().tolist()
    return w

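
# Interpretive sketch (an assumption, not stated in this file): mapping logits
# through softplus(...) + 0.1 yields strictly positive concentrations, and the
# normalized vector equals the mean of a Dirichlet(conc) distribution, e.g.
#     conc = F.softplus(torch.tensor([[2.0, 0.5, -1.0]])) + 0.1
#     w = conc / conc.sum()   # -> approximately [[0.60, 0.29, 0.11]]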

def make_weight_chart(mw):
    fig, ax = plt.subplots(figsize=(9, 3.5))
    labels = list(mw.keys())
    x, bw = np.arange(len(labels)), 0.25
    for j, (lbl, col) in enumerate([
        ("\u03b1 Text", "#2196F3"),
        ("\u03b2 Visual", "#4CAF50"),
        ("\u03b3 Spatial", "#FF9800"),
    ]):
        vals = [list(mw.values())[i][j] for i in range(len(labels))]
        bars = ax.bar(x + (j - 1) * bw, vals, bw,
                      label=lbl, color=col, alpha=0.85)
        for bar in bars:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, h + 0.01,
                    f"{h:.2f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=10)
    ax.set_ylabel("Weight"); ax.set_ylim(0, 1.2)
    ax.set_title("Fusion Weights (\u03b1, \u03b2, \u03b3) per Method",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=9); ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    return fig


def run_demo(doc_idx, custom_q):
    doc_idx = int(doc_idx)
    item = val_items[doc_idx]
    gt = val_gts[doc_idx]
    q = (custom_q.strip()
         if custom_q and custom_q.strip()
         else get_question(item))
    gt_str = (", ".join(str(g) for g in gt[:2])
              if isinstance(gt, list) else str(gt))
    n_words = len(list(item.get(WORD_FIELD, [])))
    doc_type = "Text-dominant" if n_words > 40 else "Visual-dominant"

    alpha, beta, gamma = get_rl_weights(doc_idx, custom_q)
    dom = ("Text" if alpha > 0.65 else
           "Visual" if beta > 0.40 else "Balanced")

    cfgs = {
        "Equal Fusion": (1/3, 1/3, 1/3),
        "Fixed (0.5,0.3,0.2)": (0.5, 0.3, 0.2),
        "Text-Only": (1.0, 0.0, 0.0),
        "CAFP+REINFORCE": (alpha, beta, gamma),
    }
    res = {}
    for name, (a, b, g) in cfgs.items():
        demo_item = dict(item); demo_item[QUERY_FIELD] = q
        pred = vqa_infer(demo_item, a, b, g)
        res[name] = {
            "pred": pred,
            "anls": compute_anls(pred, gt),
            "f1": compute_f1(pred, gt),
            "w": (a, b, g),
        }

    best = max(res, key=lambda k: res[k]["anls"])
    rl_vs_fixed = res["CAFP+REINFORCE"]["anls"] - res["Fixed (0.5,0.3,0.2)"]["anls"]

    md = f"## Document #{doc_idx} \u2014 {doc_type} ({n_words} words)\n\n"
    md += f"**Question:** {q}\n\n"
    md += f"**Ground Truth:** `{gt_str}`\n\n---\n"
    md += "### Step 1 \u2014 Text Extraction\n"
    md += f"`{n_words}` OCR words extracted via LayoutLMv3\n\n"
    md += "### Step 2 \u2014 Multimodal Feature Extraction\n"
    md += ("- **Text** \u2192 LayoutLMv3 token embeddings [768-D]\n"
           "- **Visual** \u2192 LayoutLMv3 patch features [768-D]\n"
           "- **Spatial** \u2192 Bounding box layout encoding [768-D]\n\n")
    md += "### Step 3 \u2014 CAFP+REINFORCE Weight Prediction\n"
    md += "| Modality | Weight |\n|----------|--------|\n"
    md += f"| \u03b1 Text | **{alpha:.3f}** |\n"
    md += f"| \u03b2 Visual | **{beta:.3f}** |\n"
    md += f"| \u03b3 Spatial | **{gamma:.3f}** |\n\n"
    md += f"\u2192 **Dominant: {dom}**\n\n"
    md += "### Step 4 \u2014 Adaptive Fusion \u2192 Answer\n"
    md += "| Method | \u03b1 | \u03b2 | \u03b3 | Answer | ANLS | F1 |\n"
    md += "|--------|---|---|---|--------|------|----|\n"
    for name, d in res.items():
        a, b, g = d["w"]
        star = " \u2b50" if name == best else ""
        md += (f"| {name}{star} | {a:.2f} | {b:.2f} | {g:.2f}"
               f" | `{d['pred']}` | **{d['anls']:.4f}** | {d['f1']:.4f} |\n")
    sign = "+" if rl_vs_fixed >= 0 else ""
    md += (f"\n---\n**Best:** {best} (ANLS: {res[best]['anls']:.4f})\n\n"
           f"**CAFP+REINFORCE answer:** `{res['CAFP+REINFORCE']['pred']}`\n\n"
           f"**\u0394 over Fixed:** {sign}{rl_vs_fixed:.4f}\n")

    chart = make_weight_chart({k: v["w"] for k, v in res.items()})

    # ── Word Selection Visualizations ──────────────────────────────────
    _item_q = dict(item); _item_q[QUERY_FIELD] = q
    fixed_vis = draw_selection(
        _item_q, 0.5, 0.3, 0.2,
        "Fixed Weights (0.5, 0.3, 0.2)"
    )
    rl_vis = draw_selection(
        _item_q, alpha, beta, gamma,
        f"CAFP+REINFORCE (α={alpha:.2f} β={beta:.2f} γ={gamma:.2f})"
    )
    comp_md = make_compression_md(item, cfgs)

    return item.get("image", None), md, chart, fixed_vis, rl_vis, comp_md


def show_dashboard():
    def sg(k): return results.get(k, {}).get("mean_anls", 0.0)
    def sf(k): return results.get(k, {}).get("mean_f1", 0.0)
    fixed = sg("Proposed Fixed"); oracle = 0.8377
    rv = rl_val_anls

    rows = [
        ("Equal Fusion", sg("Equal Fusion"), sf("Equal Fusion")),
        ("Proposed Fixed (paper)", sg("Proposed Fixed"), sf("Proposed Fixed")),
        ("Text-Only", sg("Text-Only"), sf("Text-Only")),
        ("LLMLingua-style [NEW]", sg("LLMLingua-style"), sf("LLMLingua-style")),
        ("Selective Context [NEW]", sg("Selective Context-style"), sf("Selective Context-style")),
        ("CAFP paper checkpoint", sg("CAFP (paper checkpoint)"), sf("CAFP (paper checkpoint)")),
        ("CAFP Hard Oracle [NEW]", sg("CAFP-Hard Oracle"), sf("CAFP-Hard Oracle")),
        ("CAFP Soft Oracle [NEW]", sg("CAFP-Soft Oracle"), sf("CAFP-Soft Oracle")),
        ("CAFP+REINFORCE [NEW][BEST]", rv, 0.0),
        ("Oracle Upper Bound", oracle, 0.0),
    ]

    md = "## Full Experiment Results\n\n"
    md += "| Method | ANLS | F1 | \u0394 Fixed | % Oracle |\n"
    md += "|--------|------|----|----------|----------|\n"
    for name, anls, f1 in rows:
        is_oracle = "Oracle Upper" in name
        d = f"{anls - fixed:+.4f}" if not is_oracle else "\u2014"
        pct = f"{anls / oracle * 100:.1f}%" if anls > 0 else "\u2014"
        md += f"| {name} | {anls:.4f} | {f1:.4f} | {d} | {pct} |\n"
    md += (f"\n**CAFP+REINFORCE: {rv/oracle*100:.1f}% of Oracle ANLS**\n"
           f"**Improvement over Fixed: {rv - fixed:+.4f} ANLS**\n")

    # Bar chart
    fig1, ax1 = plt.subplots(figsize=(11, 5))
    bv = [r[1] for r in rows]
    bc = ["#bbb", "#999", "#bbb", "#2196F3", "#2196F3",
          "#777", "#4CAF50", "#4CAF50", "#FF5722", "#d32f2f"]
    bars = ax1.barh([r[0] for r in rows], bv,
                    color=bc, edgecolor="white", height=0.65)
    ax1.axvline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle {oracle:.4f}")
    ax1.axvline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed {fixed:.4f}")
    for bar, val in zip(bars, bv):
        if val > 0:
            ax1.text(val + 0.003, bar.get_y() + bar.get_height() / 2,
                     f"{val:.4f}", va="center", fontsize=8)
    ax1.set_xlabel("Val ANLS", fontsize=11)
    ax1.set_title("All Methods \u2014 Val ANLS", fontsize=13, fontweight="bold")
    ax1.legend(fontsize=9); ax1.invert_yaxis()
    ax1.set_xlim(0, oracle * 1.1); ax1.grid(axis="x", alpha=0.3)
    plt.tight_layout()

    # Training curve
    fig2, ax2 = plt.subplots(figsize=(10, 3.5))
    eps = list(range(1, len(rl_train_anls) + 1))
    ax2.plot(eps, rl_train_anls, "o-", color="#FF5722",
             lw=2.5, ms=7, label="Train ANLS")
    ax2.axhline(rv, color="#FF5722", linestyle=":", lw=2,
                label=f"Val ANLS = {rv:.4f}")
    ax2.axhline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle = {oracle:.4f}")
    ax2.axhline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed = {fixed:.4f}")
    ax2.fill_between(eps, rl_train_anls, fixed, alpha=0.15, color="#FF5722")
    ax2.set_xlabel("Epoch"); ax2.set_ylabel("ANLS")
    ax2.set_title("REINFORCE Fine-tuning Progress",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=9); ax2.grid(True, alpha=0.3)
    ax2.set_xticks(eps); plt.tight_layout()

    return md, fig1, fig2


# ──────────────────────────────────────────────────────────────────────
# GRADIO UI
# ──────────────────────────────────────────────────────────────────────
_fixed_anls = results.get("Proposed Fixed", {}).get("mean_anls", 0.0)
_best_label = ("Best docs (REINFORCE wins most): "
               + ", ".join([f"#{x[0]}" for x in demo_scores[:5]]))

CSS = ".tab-nav button { font-size: 15px !important; font-weight: 600 !important; }"

with gr.Blocks(
    title="Adaptive Multimodal Fusion — DocVQA Demo",
    theme=gr.themes.Soft(primary_hue="blue"),
    css=CSS,
) as demo_app:

    gr.Markdown("""
    # Adaptive Multimodal Fusion for Document VQA
    ### Cross-Attention Fusion Predictor (CAFP) + REINFORCE Fine-tuning
    """)

    with gr.Tabs():

        # ── Tab 1: Live Demo ──────────────────────────────────────────
        with gr.TabItem("\U0001f3af Live Demo"):
            gr.Markdown(f"**{_best_label}**")
            with gr.Row():
                with gr.Column(scale=1):
                    doc_slider = gr.Slider(
                        0, len(val_items) - 1,
                        value=best_idx, step=1,
                        label=f"Document Index (0\u2013{len(val_items)-1})"
                    )
                    custom_q = gr.Textbox(
                        label="Custom Question (optional)",
                        placeholder="Leave blank to use original question"
                    )
                    run_btn = gr.Button(
                        "\u25b6 Run Adaptive Fusion",
                        variant="primary", size="lg"
                    )
                    gr.Markdown(
                        "*Compares: Equal \u00b7 Fixed \u00b7 "
                        "Text-Only \u00b7 CAFP+REINFORCE*"
                    )
                with gr.Column(scale=2):
                    doc_image = gr.Image(label="Document Image", height=400)
            step_md = gr.Markdown()
            weight_chart = gr.Plot(label="Fusion Weights Comparison")

            # ── Word Selection Visualizer ─────────────────────────────
            gr.Markdown("""
            ---
            ### 🎨 Word Selection Visualization
            *See **exactly** which OCR words each method keeps vs discards.*
            🟢 **Green** = kept and fed to the VQA model · 🔴 **Red** = compressed out
            """)
            with gr.Row():
                fixed_vis_img = gr.Image(
                    label="📌 Fixed Weights (α=0.5 β=0.3 γ=0.2)",
                    height=520, show_download_button=True
                )
                rl_vis_img = gr.Image(
                    label="🤖 CAFP+REINFORCE (Adaptive Weights)",
                    height=520, show_download_button=True
                )
            comp_md_out = gr.Markdown()

            run_btn.click(
                fn=run_demo,
                inputs=[doc_slider, custom_q],
                outputs=[doc_image, step_md, weight_chart,
                         fixed_vis_img, rl_vis_img, comp_md_out],
            )

        # ── Tab 2: Results Dashboard ──────────────────────────────────
        with gr.TabItem("\U0001f4ca Results Dashboard"):
            gr.Markdown("### All methods compared + REINFORCE training curve")
            load_btn = gr.Button("Load Results", variant="secondary")
            res_md = gr.Markdown()
            with gr.Row():
                bar_chart = gr.Plot(label="ANLS \u2014 All Methods")
                rl_curve = gr.Plot(label="REINFORCE Training Curve")
            load_btn.click(
                fn=show_dashboard,
                inputs=[],
                outputs=[res_md, bar_chart, rl_curve],
            )

        # ── Tab 3: About ──────────────────────────────────────────────
        with gr.TabItem("\u2139\ufe0f About"):
            gr.Markdown(f"""
            ## Adaptive Multimodal Fusion for DocVQA

            ### Problem
            DocVQA requires reasoning over three modalities simultaneously:
            - **Text** → OCR words and their semantics
            - **Visual** → Document appearance and image patches
            - **Spatial** → Bounding box positions and layout structure

            Fixed weights (α=0.5, β=0.3, γ=0.2) cannot adapt to different document types.

            ### Architecture: CAFP (428K params)
            1. Projects each modality embedding to 128-D
            2. Cross-attention: question attends to all modality representations
            3. Predicts per-document (α, β, γ) fusion weights

            ### Training Pipeline
            1. **Hard Oracle** (MSE) — argmax weights from 20-combo grid search
            2. **Soft Oracle** (KL-div) — temperature-smoothed ANLS-weighted targets
            3. **REINFORCE** — Policy gradient on direct ANLS reward (K=3 samples/step)

            ### Novel Contributions
            1. Soft Oracle training eliminates hard-oracle label noise
            2. REINFORCE fine-tuning directly maximises DocVQA metric
            3. LLMLingua-style and Selective Context baselines for fair comparison

            ### Key Result
            **CAFP+REINFORCE achieves {rl_val_anls/0.8377*100:.1f}% of Oracle ANLS**
            Improvement over fixed-weight baseline: {rl_val_anls - _fixed_anls:+.4f} ANLS
            """)

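
# Hedged sketch of the REINFORCE step described in the About tab (an
# assumption about the training script, which is not part of this file):
# K = 3 weight vectors sampled per document, with ANLS as the reward and a
# `baseline` assumed to be something like a running mean of recent rewards.
#     conc = F.softplus(cafp_rl._logits(fv)) + 0.1
#     dist = torch.distributions.Dirichlet(conc)
#     for _ in range(3):                            # K = 3 samples/step
#         w = dist.sample()                         # one (alpha, beta, gamma)
#         reward = compute_anls(vqa_infer(item, *w[0].tolist()), gts)
#         loss = -(reward - baseline) * dist.log_prob(w).mean()
#         loss.backward(retain_graph=True)          # policy-gradient update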

# ──────────────────────────────────────────────────────────────────────
# LAUNCH
# ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    demo_app.launch(
        server_name="0.0.0.0",
        server_port=args.port,
        share=args.share,
        show_error=True,
    )
app.py
ADDED
|
@@ -0,0 +1,890 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adaptive Multimodal Fusion for DocVQA β Gradio Demo
|
| 3 |
+
====================================================
|
| 4 |
+
Run locally:
|
| 5 |
+
python app.py
|
| 6 |
+
|
| 7 |
+
Run with public URL (72hr):
|
| 8 |
+
python app.py --share
|
| 9 |
+
|
| 10 |
+
Deploy to HuggingFace Spaces:
|
| 11 |
+
- Push this file + requirements.txt + checkpoints/ folder to a Space repo
|
| 12 |
+
- HF Spaces auto-launches on port 7860
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse, os, sys, copy, json, warnings
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import matplotlib
|
| 21 |
+
matplotlib.use("Agg")
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
import gradio as gr
|
| 24 |
+
import editdistance
|
| 25 |
+
from PIL import Image as PILImage, ImageDraw as PILDraw, ImageFont as PILFont
|
| 26 |
+
warnings.filterwarnings("ignore")
|
| 27 |
+
|
| 28 |
+
# ββ CLI args ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
+
parser = argparse.ArgumentParser()
|
| 30 |
+
parser.add_argument("--share", action="store_true", help="Create public Gradio URL")
|
| 31 |
+
parser.add_argument("--port", type=int, default=7860)
|
| 32 |
+
parser.add_argument("--ckpt_dir", type=str, default="./checkpoints",
|
| 33 |
+
help="Folder containing all saved files")
|
| 34 |
+
args, _ = parser.parse_known_args()
|
| 35 |
+
|
| 36 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
# CONFIGURATION β edit paths here if needed
|
| 38 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
+
CKPT_DIR = args.ckpt_dir
|
| 40 |
+
ORACLE_CACHE = os.path.join(CKPT_DIR, "oracle_cache.json")
|
| 41 |
+
FEAT_PATH = os.path.join(CKPT_DIR, "feature_tensors.pt")
|
| 42 |
+
RESULTS_PATH = os.path.join(CKPT_DIR, "final_results.json")
|
| 43 |
+
|
| 44 |
+
# Try final checkpoint first, fall back to intermediate
|
| 45 |
+
CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint_final.pt")
|
| 46 |
+
if not os.path.exists(CKPT_PATH):
|
| 47 |
+
CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint.pt")
|
| 48 |
+
|
| 49 |
+
# Dataset / model IDs
|
| 50 |
+
DATASET_NAME = "nielsr/docvqa_1200_examples"
|
| 51 |
+
FEAT_MODEL_ID = "microsoft/layoutlmv3-base"
|
| 52 |
+
VQA_MODEL_ID = "rubentito/layoutlmv3-base-mpdocvqa"
|
| 53 |
+
SBERT_ID = "all-MiniLM-L6-v2"
|
| 54 |
+
|
| 55 |
+
# Field names
|
| 56 |
+
WORD_FIELD = "words"
|
| 57 |
+
BOX_FIELD = "bounding_boxes"
|
| 58 |
+
QUERY_FIELD = "query"
|
| 59 |
+
ANSWER_FIELD = "answers"
|
| 60 |
+
|
| 61 |
+
# Architecture
|
| 62 |
+
MAX_WORDS = 64
|
| 63 |
+
N_PATCHES = 49
|
| 64 |
+
N_VAL = 100
|
| 65 |
+
N_TRAIN = 100
|
| 66 |
+
FEAT_DIM = 2701
|
| 67 |
+
PROJ_DIM = 128
|
| 68 |
+
|
| 69 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 70 |
+
print(f"Device: {device}")
|
| 71 |
+
|
| 72 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
+
# MODEL CLASSES
|
| 74 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 75 |
+
class CrossAttentionFusionPredictor(nn.Module):
|
| 76 |
+
def __init__(self, feat_dim=FEAT_DIM, proj_dim=PROJ_DIM,
|
| 77 |
+
n_heads=4, dropout=0.15):
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.text_proj = nn.Linear(768, proj_dim)
|
| 80 |
+
self.visual_proj = nn.Linear(768, proj_dim)
|
| 81 |
+
self.spatial_proj_lyr = nn.Linear(768, proj_dim)
|
| 82 |
+
self.q_proj = nn.Sequential(
|
| 83 |
+
nn.Linear(384, proj_dim), nn.LayerNorm(proj_dim), nn.GELU()
|
| 84 |
+
)
|
| 85 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 86 |
+
proj_dim, n_heads, dropout=dropout, batch_first=True
|
| 87 |
+
)
|
| 88 |
+
self.attn_norm = nn.LayerNorm(proj_dim)
|
| 89 |
+
self.head = nn.Sequential(
|
| 90 |
+
nn.Linear(proj_dim + 3, proj_dim), nn.GELU(),
|
| 91 |
+
nn.Dropout(dropout), nn.Linear(proj_dim, 3)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
def _logits(self, x):
|
| 95 |
+
h_t = self.text_proj(x[:, 0:768])
|
| 96 |
+
h_v = self.visual_proj(x[:, 768:1536])
|
| 97 |
+
h_s = self.spatial_proj_lyr(x[:, 1536:2304])
|
| 98 |
+
q = self.q_proj(x[:, 2314:2698]).unsqueeze(1)
|
| 99 |
+
kv = torch.stack([h_t, h_v, h_s], dim=1)
|
| 100 |
+
ctx, _ = self.cross_attn(q, kv, kv)
|
| 101 |
+
ctx = self.attn_norm(ctx.squeeze(1))
|
| 102 |
+
return self.head(torch.cat([ctx, x[:, 2698:2701]], dim=-1))
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
|
| 105 |
+
return F.softmax(self._logits(x), dim=-1)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+

# ──────────────────────────────────────────────────────────────────────
# LOAD BASE MODELS
# ──────────────────────────────────────────────────────────────────────
print("Loading base models (takes ~5 min on first run, cached after)...")

from transformers import AutoProcessor, AutoModel, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer

feat_processor = AutoProcessor.from_pretrained(FEAT_MODEL_ID, apply_ocr=False)
feat_model = AutoModel.from_pretrained(FEAT_MODEL_ID).to(device).eval()
for p in feat_model.parameters():
    p.requires_grad_(False)
print(" ✅ LayoutLMv3 feature model")

vqa_processor = AutoProcessor.from_pretrained(VQA_MODEL_ID, apply_ocr=False)
vqa_model = AutoModelForQuestionAnswering.from_pretrained(
    VQA_MODEL_ID).to(device).eval()
for p in vqa_model.parameters():
    p.requires_grad_(False)
print(" ✅ VQA model")

sbert = SentenceTransformer(SBERT_ID)
sbert.to(device)
print(" ✅ SBERT")

spatial_proj = nn.Sequential(
    nn.Linear(10, 256), nn.ReLU(), nn.Linear(256, 768)
).to(device)

# ──────────────────────────────────────────────────────────────────────
# HELPER FUNCTIONS
# ──────────────────────────────────────────────────────────────────────
def get_question(item):
    q = item.get(QUERY_FIELD, item.get("question", ""))
    if isinstance(q, dict):
        q = q.get("en", next(iter(q.values()), ""))
    return str(q).strip()


def normalize_boxes(boxes, w, h):
    # Scale pixel boxes to the 0-1000 grid LayoutLMv3 expects, clamping
    # out-of-bounds coordinates
    return [
        [
            int(max(0, min(b[0] / max(w, 1), 1)) * 1000),
            int(max(0, min(b[1] / max(h, 1), 1)) * 1000),
            int(max(0, min(b[2] / max(w, 1), 1)) * 1000),
            int(max(0, min(b[3] / max(h, 1), 1)) * 1000),
        ]
        for b in boxes
    ]

def extract_rich_features(item):
    try:
        img = item["image"].convert("RGB")
        W, H = img.size
        words = list(item.get(WORD_FIELD, []))[:MAX_WORDS] or ["[PAD]"]
        boxes = list(item.get(BOX_FIELD, []))[:MAX_WORDS] or [[0, 0, 1, 1]]
        question = get_question(item)
        bn = normalize_boxes(boxes, W, H)
        enc = feat_processor(img, text=words, boxes=bn,
                             return_tensors="pt", truncation=True,
                             max_length=512, padding="max_length")
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            hidden = feat_model(**enc).last_hidden_state[0]
        # LayoutLMv3 appends N_PATCHES visual tokens after the text tokens
        n_txt = max(2, hidden.shape[0] - N_PATCHES)
        H_text = hidden[1:n_txt-1].mean(0) if n_txt > 2 else hidden[0]
        H_visual = hidden[-N_PATCHES:].mean(0)
        bx = np.array(bn, dtype=np.float32)
        cx = ((bx[:, 0] + bx[:, 2]) / 2) / 1000.0
        cy = ((bx[:, 1] + bx[:, 3]) / 2) / 1000.0
        sp = np.array([
            W / 1000.0, H / 1000.0, min(W, H) / max(W, H),
            len(words) / MAX_WORDS,
            cx.mean(), cy.mean(), cx.std() + 1e-6, cy.std() + 1e-6,
            H_text.norm().item() / 10.0,
            H_visual.norm().item() / 10.0,
        ], dtype=np.float32)
        sp10 = torch.tensor(sp).to(device)
        H_spat = spatial_proj(sp10.unsqueeze(0)).squeeze(0)
        q_emb = torch.tensor(sbert.encode(question),
                             dtype=torch.float32).to(device)
        return {
            "H_text": H_text, "H_visual": H_visual, "H_spatial": H_spat,
            "spatial_10": sp10, "question_emb": q_emb,
            "text_score": float(np.clip(sp[8], 0, 1)),
            "visual_score": float(np.clip(sp[9], 0, 1)),
            "spatial_score": float(np.clip(sp[6], 0, 1)),
            "n_tokens": len(words),
        }
    except Exception as e:
        print(f" extract_rich_features error: {e}")
        dummy = torch.zeros(768, device=device)
        return {
            "H_text": dummy, "H_visual": dummy, "H_spatial": dummy,
            "spatial_10": torch.zeros(10, device=device),
            "question_emb": torch.zeros(384, device=device),
            "text_score": 0.5, "visual_score": 0.3, "spatial_score": 0.2,
            "n_tokens": 0,
        }

def build_feature_vector(feat):
    return torch.cat([
        feat["H_text"], feat["H_visual"], feat["H_spatial"],
        feat["spatial_10"], feat["question_emb"],
        torch.tensor(
            [feat["text_score"], feat["visual_score"], feat["spatial_score"]],
            dtype=torch.float32, device=device
        ),
    ])

def vqa_infer(item, alpha, beta, gamma):
    try:
        img = item["image"].convert("RGB")
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        question = get_question(item)
        if not words:
            return ""
        W, H = img.size
        n = len(words)
        # Keep at least 30% of the words (alpha-scaled), never fewer than 5
        n_keep = max(int(n * max(float(alpha), 0.30)), min(5, n))
        if float(gamma) > max(float(alpha), float(beta)) and boxes:
            # Spatial-dominant: keep words in reading order (top-left first)
            order = sorted(range(n), key=lambda i: (boxes[i][1], boxes[i][0]))
            sel_idx = sorted(order[:n_keep])
        else:
            sel_idx = list(range(n_keep))
        sw = [words[i] for i in sel_idx]
        sb = ([boxes[i] for i in sel_idx]
              if boxes else [[0, 0, W, H]] * len(sw))
        enc = vqa_processor(
            img, text=question, text_pair=sw,
            boxes=normalize_boxes(sb, W, H),
            return_tensors="pt", truncation=True,
            max_length=512, padding=True
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            out = vqa_model(**enc)
        s = int(out.start_logits.argmax())
        e = int(out.end_logits.argmax())
        if e < s:
            e = s
        return vqa_processor.tokenizer.decode(
            enc["input_ids"][0][s:e+1], skip_special_tokens=True
        ).strip()
    except Exception:
        return ""

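
# Behaviour sketch: with 100 words and alpha=0.5, n_keep=50. When gamma
# dominates (gamma > alpha and gamma > beta), words are re-ranked by
# (y, x) position so the kept context follows page reading order rather
# than raw OCR order.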
def compute_anls(pred, gts, threshold=0.5):
    # ANLS: 1 - normalized Levenshtein distance, zeroed below the threshold
    if isinstance(gts, str):
        gts = [gts]
    if not gts or not pred:
        return 0.0
    p, best = str(pred).lower().strip(), 0.0
    for gt in gts:
        g = str(gt).lower().strip()
        ml = max(len(p), len(g))
        if ml == 0:
            best = max(best, 1.0)
            continue
        nls = 1.0 - editdistance.eval(p, g) / ml
        if nls < threshold:
            nls = 0.0
        best = max(best, nls)
    return best

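
# Worked example: compute_anls("1998", ["1998"]) -> 1.0, while
# compute_anls("199", ["1998"]) -> 0.75 (edit distance 1 over length 4);
# any score below the 0.5 threshold is zeroed, as in the standard
# DocVQA ANLS metric.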
def compute_f1(pred, gts):
    if isinstance(gts, str):
        gts = [gts]
    if not pred or not gts:
        return 0.0
    pt = set(str(pred).lower().split())
    if not pt:
        return 0.0
    best = 0.0
    for gt in gts:
        gt_t = set(str(gt).lower().split())
        if not gt_t:
            continue
        common = pt & gt_t
        if not common:
            continue
        p = len(common) / len(pt)
        r = len(common) / len(gt_t)
        best = max(best, 2 * p * r / (p + r))
    return best
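
# Worked example: pred "coca cola" vs gt "coca cola company" shares two
# tokens: precision 2/2, recall 2/3, so token-level F1 = 0.8.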

# ──────────────────────────────────────────────────────────────────────
# WORD SELECTION VISUALIZER HELPERS
# ──────────────────────────────────────────────────────────────────────
def get_sel_idx(item, alpha, beta, gamma):
    """Return the SET of word indices kept by this (alpha, beta, gamma) config.

    Mirrors the exact selection logic in vqa_infer so the boxes always
    match what the model actually sees.
    """
    words = list(item.get(WORD_FIELD, []))
    boxes = list(item.get(BOX_FIELD, []))
    n = len(words)
    if n == 0:
        return set()
    n_keep = max(int(n * max(float(alpha), 0.30)), min(5, n))
    n_keep = min(n_keep, n)
    if float(gamma) > max(float(alpha), float(beta)) and boxes:
        order = sorted(range(n), key=lambda i: (boxes[i][1], boxes[i][0]))
        sel_idx = set(order[:n_keep])
    else:
        sel_idx = set(range(n_keep))
    return sel_idx

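
# Example: alpha=0.5, beta=0.3, gamma=0.2 on a 40-word page gives
# n_keep = max(int(40 * 0.5), 5) = 20; gamma is not dominant, so the
# kept set is simply indices 0..19.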
def draw_selection(item, alpha, beta, gamma, title=""):
    """Return a PIL Image with coloured bounding boxes overlaid.

    🟢 Green fill + outline → word KEPT (used for VQA)
    🔴 Red fill + outline → word DROPPED (compressed out)

    An info strip (dark) and a colour legend strip are appended below the
    document image so the panel is self-explanatory at a glance.
    """
    try:
        img = item["image"].convert("RGB").copy()
        W, H = img.size
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        n = min(len(words), len(boxes))
        if n == 0:
            return img

        sel_idx = get_sel_idx(item, alpha, beta, gamma)
        n_keep = len(sel_idx)
        pct = 100 * n_keep / max(n, 1)

        # ── Draw semi-transparent coloured overlays ───────────────────
        overlay = PILImage.new("RGBA", img.size, (0, 0, 0, 0))
        od = PILDraw.Draw(overlay)
        for i in range(n):
            try:
                x0, y0, x1, y1 = (int(boxes[i][0]), int(boxes[i][1]),
                                  int(boxes[i][2]), int(boxes[i][3]))
                # Clamp to image bounds
                x0, x1 = max(0, x0), min(W - 1, x1)
                y0, y1 = max(0, y0), min(H - 1, y1)
                if x1 <= x0 or y1 <= y0:
                    continue
                if i in sel_idx:
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(0, 210, 0, 55),
                                 outline=(0, 160, 0, 230), width=2)
                else:
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(220, 30, 30, 40),
                                 outline=(200, 0, 0, 170), width=1)
            except Exception:
                continue
        img = PILImage.alpha_composite(img.convert("RGBA"), overlay).convert("RGB")

        # ── Load font (graceful fallback) ─────────────────────────────
        font_sm = PILFont.load_default()
        for _fp in [
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "/Windows/Fonts/arial.ttf",
        ]:
            try:
                font_sm = PILFont.truetype(_fp, 13)
                break
            except Exception:
                continue

        # ── Info strip (dark bar showing title + stats) ───────────────
        strip_h = 36
        strip = PILImage.new("RGB", (W, strip_h), (22, 22, 32))
        sd = PILDraw.Draw(strip)
        info_text = (f"{title} | ✓ Kept: {n_keep}/{n} ({pct:.0f}%)"
                     f" | α={alpha:.2f} β={beta:.2f} γ={gamma:.2f}")
        sd.text((8, 11), info_text, fill=(220, 220, 220), font=font_sm)

        # ── Legend strip (light bar explaining colours) ───────────────
        leg_h = 28
        leg = PILImage.new("RGB", (W, leg_h), (246, 246, 246))
        ld = PILDraw.Draw(leg)
        ld.rectangle([8, 7, 24, 21], fill=(0, 180, 0), outline=(0, 130, 0, 255))
        ld.text((30, 8), "= Kept (used for VQA)", fill=(0, 110, 0), font=font_sm)
        ld.rectangle([210, 7, 226, 21], fill=(220, 30, 30), outline=(170, 0, 0, 255))
        ld.text((232, 8), "= Dropped (compressed out)", fill=(140, 0, 0), font=font_sm)

        # ── Stack: image → dark strip → legend ────────────────────────
        final = PILImage.new("RGB", (W, H + strip_h + leg_h), (255, 255, 255))
        final.paste(img, (0, 0))
        final.paste(strip, (0, H))
        final.paste(leg, (0, H + strip_h))
        return final

    except Exception as e:
        print(f" draw_selection error: {e}")
        return item.get("image", None)

def make_compression_md(item, cfgs):
    """Build a markdown table showing kept / dropped word statistics and
    a sample of the words that each method discards.

    cfgs → OrderedDict/dict {method_name: (alpha, beta, gamma)}
    """
    words = list(item.get(WORD_FIELD, []))
    n = len(words)
    if n == 0:
        return "*No OCR words available for this document.*"

    md = "### 🔍 What Gets Compressed?\n\n"
    md += f"**Total OCR words in document:** {n}\n\n"
    md += ("| Method | α | β | γ | Words Kept | % Context |"
           " Sample Dropped Words |\n")
    md += ("|--------|---|---|---|:----------:|:---------:|"
           "----------------------|\n")

    for name, (a, b, g) in cfgs.items():
        sel = get_sel_idx(item, a, b, g)
        n_keep = len(sel)
        pct = 100 * n_keep / max(n, 1)
        dropped = [words[i] for i in range(n) if i not in sel]
        d_preview = " · ".join(dropped[:8])
        if len(dropped) > 8:
            d_preview += f" … (+{len(dropped) - 8} more)"
        md += (f"| **{name}** | {a:.2f} | {b:.2f} | {g:.2f}"
               f" | {n_keep} / {n} | {pct:.0f}% | `{d_preview}` |\n")

    # Show the actual kept words for the CAFP+REINFORCE method
    if "CAFP+REINFORCE" in cfgs:
        a, b, g = cfgs["CAFP+REINFORCE"]
        sel = get_sel_idx(item, a, b, g)
        kept_w = [words[i] for i in sorted(sel)[:25]]
        md += (f"\n**CAFP+REINFORCE — kept words (first 25 shown):** \n"
               f"`{' · '.join(kept_w)}`\n")

    return md

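
# Note: the "CAFP+REINFORCE" key checked above must match the label used
# in run_demo's cfgs dict, otherwise the kept-words preview is skipped.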

# ──────────────────────────────────────────────────────────────────────
# LOAD CHECKPOINTS & DATA
# ──────────────────────────────────────────────────────────────────────
print("\nLoading checkpoints and data...")

# ── RL checkpoint ─────────────────────────────────────────────────────
if not os.path.exists(CKPT_PATH):
    sys.exit(f"❌ Checkpoint not found: {CKPT_PATH}\n"
             f"   Copy cafp_rl_checkpoint_final.pt into {CKPT_DIR}/")

ck = torch.load(CKPT_PATH, map_location=device, weights_only=False)
spatial_proj.load_state_dict(ck["spatial_proj_state"])

cafp_soft = CrossAttentionFusionPredictor().to(device)
cafp_soft.load_state_dict(ck["cafp_soft_state"])
cafp_soft.eval()

cafp_rl = copy.deepcopy(cafp_soft)
cafp_rl.load_state_dict(ck["cafp_rl_state"])
cafp_rl.eval()

rl_train_anls = ck["rl_train_anls"]
rl_val_anls = ck.get("rl_val_anls",
                     max(rl_train_anls) if rl_train_anls else 0.0)
print(f" ✅ CAFP+REINFORCE: {len(rl_train_anls)} epochs | "
      f"best_train={max(rl_train_anls):.4f} | val={rl_val_anls:.4f}")

# ── Dataset ───────────────────────────────────────────────────────────
print("  Loading dataset (~30s)...")
from datasets import load_dataset
_ds = load_dataset(DATASET_NAME, split="train")
_split = _ds.train_test_split(test_size=0.2, seed=42)
rng = np.random.RandomState(42)
val_idx = rng.permutation(len(_split["test"])).tolist()[:N_VAL]
train_idx = rng.permutation(len(_split["train"])).tolist()[:N_TRAIN]
val_items = [_split["test"][i] for i in val_idx]
train_items = [_split["train"][i] for i in train_idx]
val_gts = [item[ANSWER_FIELD] for item in val_items]
train_gts = [item[ANSWER_FIELD] for item in train_items]
print(f" ✅ Dataset: {len(val_items)} val, {len(train_items)} train")

# ── Feature tensors ───────────────────────────────────────────────────
if os.path.exists(FEAT_PATH):
    t = torch.load(FEAT_PATH, map_location=device, weights_only=False)
    val_feats = t["val_feats"]
    train_feats = t["train_feats"]
    print(f" ✅ Features: {tuple(val_feats.shape)}")
else:
    print(" ⚠️ feature_tensors.pt not found — recomputing (~2 min)...")
    def _feats(items, tag):
        out = []
        for i, item in enumerate(items):
            out.append(build_feature_vector(
                extract_rich_features(item)).unsqueeze(0))
            if (i + 1) % 10 == 0:
                print(f"  {tag}: {i+1}/{len(items)}", end="\r")
        print()
        return torch.cat(out).to(device)
    val_feats = _feats(val_items, "val")
    train_feats = _feats(train_items, "train")
    torch.save({"val_feats": val_feats, "train_feats": train_feats,
                "val_gts": val_gts, "train_gts": train_gts}, FEAT_PATH)
    print(f" ✅ Features computed and saved to {FEAT_PATH}")
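
# Note: the cached tensors are not validated against N_VAL/N_TRAIN or the
# dataset seed; delete feature_tensors.pt after changing those constants,
# or stale features will be loaded silently.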

# ── Oracle cache ──────────────────────────────────────────────────────
# Two separate lists (a chained assignment would alias the same object)
val_oracle, train_oracle = [], []
if os.path.exists(ORACLE_CACHE):
    _oc = json.load(open(ORACLE_CACHE))
    train_oracle = _oc.get("train", [])
    val_oracle = _oc.get("val", [])
    print(f" ✅ Oracle cache: {len(train_oracle)} train, {len(val_oracle)} val")
else:
    print(" ⚠️ oracle_cache.json not found — demo works without it")

# ── Results from JSON ─────────────────────────────────────────────────
results = {}
_RKEYS = [
    "Equal Fusion", "Proposed Fixed", "Text-Only",
    "LLMLingua-style", "Selective Context-style",
    "CAFP (paper checkpoint)", "CAFP-Hard Oracle", "CAFP-Soft Oracle",
]
for _rpath in [RESULTS_PATH, "./final_results.json", "./results_condensed.json"]:
    try:
        _raw = json.load(open(_rpath))
        for k in _RKEYS:
            if k in _raw and isinstance(_raw[k], dict):
                r = _raw[k]
                results[k] = {
                    "mean_anls": float(r.get("mean_anls", r.get("anls", 0.0))),
                    "mean_f1": float(r.get("mean_f1", r.get("f1", 0.0))),
                }
        if results:
            print(f" ✅ Results: {len(results)} methods from {_rpath}")
            break
    except Exception:
        continue
if not results:
    print(" ⚠️ Results JSON not found — dashboard will show partial data")

# ── Find best demo documents ──────────────────────────────────────────
print("\nPre-scoring documents for demo (this takes ~2 min)...")
demo_scores = []
cafp_rl.eval()
with torch.no_grad():
    for i in range(len(val_items)):
        fv = val_feats[i].unsqueeze(0)
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        w = (conc / conc.sum()).squeeze(0).cpu().tolist()
        rl_s = compute_anls(vqa_infer(val_items[i], w[0], w[1], w[2]),
                            val_gts[i])
        fx_s = compute_anls(vqa_infer(val_items[i], 0.5, 0.3, 0.2),
                            val_gts[i])
        demo_scores.append((i, round(rl_s - fx_s, 4),
                            round(rl_s, 4), round(fx_s, 4)))
        if (i + 1) % 20 == 0:
            print(f"  {i+1}/{len(val_items)}", end="\r")
demo_scores.sort(key=lambda x: -x[1])
best_idx = demo_scores[0][0]
top5_str = ", ".join([f"#{x[0]}(+{x[1]:.2f})" for x in demo_scores[:5]])
print(f"\n ✅ Best docs: {top5_str}")
print(f"\n{'='*55}")
print("ALL MODELS LOADED — ready to demo")
print(f"{'='*55}\n")
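
# The softplus(logits) + 0.1 transform above treats the three outputs as
# positive concentrations and renormalises them onto the simplex; the
# +0.1 floor keeps every modality weight strictly positive. The same
# transform is reused in get_rl_weights below.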

# ──────────────────────────────────────────────────────────────────────
# GRADIO FUNCTIONS
# ──────────────────────────────────────────────────────────────────────
def get_rl_weights(idx, custom_q=None):
    if custom_q and custom_q.strip():
        # A custom question changes the SBERT embedding, so recompute features
        _item = dict(val_items[idx])
        _item[QUERY_FIELD] = custom_q.strip()
        fv = build_feature_vector(extract_rich_features(_item)).unsqueeze(0)
    else:
        fv = val_feats[idx].unsqueeze(0)
    with torch.no_grad():
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        w = (conc / conc.sum()).squeeze(0).cpu().tolist()
    return w

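
# Illustrative call (weight values are made up, not from a real run):
#   alpha, beta, gamma = get_rl_weights(0)   # e.g. [0.71, 0.18, 0.11]
# Passing custom_q forces a fresh feature pass because the question
# embedding is part of the input vector.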
def make_weight_chart(mw):
    fig, ax = plt.subplots(figsize=(9, 3.5))
    labels = list(mw.keys())
    x, bw = np.arange(len(labels)), 0.25
    for j, (lbl, col) in enumerate([
        ("\u03b1 Text", "#2196F3"),
        ("\u03b2 Visual", "#4CAF50"),
        ("\u03b3 Spatial", "#FF9800"),
    ]):
        vals = [list(mw.values())[i][j] for i in range(len(labels))]
        bars = ax.bar(x + (j - 1) * bw, vals, bw,
                      label=lbl, color=col, alpha=0.85)
        for bar in bars:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, h + 0.01,
                    f"{h:.2f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=10)
    ax.set_ylabel("Weight")
    ax.set_ylim(0, 1.2)
    ax.set_title("Fusion Weights (\u03b1, \u03b2, \u03b3) per Method",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=9)
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    return fig

def run_demo(doc_idx, custom_q):
    doc_idx = int(doc_idx)
    item = val_items[doc_idx]
    gt = val_gts[doc_idx]
    q = (custom_q.strip()
         if custom_q and custom_q.strip()
         else get_question(item))
    gt_str = (", ".join(str(g) for g in gt[:2])
              if isinstance(gt, list) else str(gt))
    n_words = len(list(item.get(WORD_FIELD, [])))
    doc_type = "Text-dominant" if n_words > 40 else "Visual-dominant"

    alpha, beta, gamma = get_rl_weights(doc_idx, custom_q)
    dom = ("Text" if alpha > 0.65 else
           "Visual" if beta > 0.40 else "Balanced")

    cfgs = {
        "Equal Fusion": (1/3, 1/3, 1/3),
        "Fixed (0.5,0.3,0.2)": (0.5, 0.3, 0.2),
        "Text-Only": (1.0, 0.0, 0.0),
        "CAFP+REINFORCE": (alpha, beta, gamma),
    }
    res = {}
    for name, (a, b, g) in cfgs.items():
        demo_item = dict(item)
        demo_item[QUERY_FIELD] = q
        pred = vqa_infer(demo_item, a, b, g)
        res[name] = {
            "pred": pred,
            "anls": compute_anls(pred, gt),
            "f1": compute_f1(pred, gt),
            "w": (a, b, g),
        }

    best = max(res, key=lambda k: res[k]["anls"])
    rl_vs_fixed = res["CAFP+REINFORCE"]["anls"] - res["Fixed (0.5,0.3,0.2)"]["anls"]

    md = f"## Document #{doc_idx} \u2014 {doc_type} ({n_words} words)\n\n"
    md += f"**Question:** {q}\n\n"
    md += f"**Ground Truth:** `{gt_str}`\n\n---\n"
    md += "### Step 1 \u2014 Text Extraction\n"
    md += f"`{n_words}` OCR words extracted via LayoutLMv3\n\n"
    md += "### Step 2 \u2014 Multimodal Feature Extraction\n"
    md += ("- **Text** \u2192 LayoutLMv3 token embeddings [768-D]\n"
           "- **Visual** \u2192 LayoutLMv3 patch features [768-D]\n"
           "- **Spatial** \u2192 Bounding box layout encoding [768-D]\n\n")
    md += "### Step 3 \u2014 CAFP+REINFORCE Weight Prediction\n"
    md += "| Modality | Weight |\n|----------|--------|\n"
    md += f"| \u03b1 Text | **{alpha:.3f}** |\n"
    md += f"| \u03b2 Visual | **{beta:.3f}** |\n"
    md += f"| \u03b3 Spatial | **{gamma:.3f}** |\n\n"
    md += f"\u2192 **Dominant: {dom}**\n\n"
    md += "### Step 4 \u2014 Adaptive Fusion \u2192 Answer\n"
    md += "| Method | \u03b1 | \u03b2 | \u03b3 | Answer | ANLS | F1 |\n"
    md += "|--------|---|---|---|--------|------|----|\n"
    for name, d in res.items():
        a, b, g = d["w"]
        star = " \u2b50" if name == best else ""
        md += (f"| {name}{star} | {a:.2f} | {b:.2f} | {g:.2f}"
               f" | `{d['pred']}` | **{d['anls']:.4f}** | {d['f1']:.4f} |\n")
    sign = "+" if rl_vs_fixed >= 0 else ""
    md += (f"\n---\n**Best:** {best} (ANLS: {res[best]['anls']:.4f})\n\n"
           f"**CAFP+REINFORCE answer:** `{res['CAFP+REINFORCE']['pred']}`\n\n"
           f"**\u0394 over Fixed:** {sign}{rl_vs_fixed:.4f}\n")

    chart = make_weight_chart({k: v["w"] for k, v in res.items()})

    # ── Word Selection Visualizations ─────────────────────────────────
    _item_q = dict(item)
    _item_q[QUERY_FIELD] = q
    fixed_vis = draw_selection(
        _item_q, 0.5, 0.3, 0.2,
        "Fixed Weights (0.5, 0.3, 0.2)"
    )
    rl_vis = draw_selection(
        _item_q, alpha, beta, gamma,
        f"CAFP+REINFORCE (α={alpha:.2f} β={beta:.2f} γ={gamma:.2f})"
    )
    comp_md = make_compression_md(item, cfgs)

    return item.get("image", None), md, chart, fixed_vis, rl_vis, comp_md

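
# run_demo's six return values map 1:1 onto the outputs list wired to
# run_btn.click in the UI below: document image, step-by-step markdown,
# weight chart, the two selection visualisations, and the compression table.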
def show_dashboard():
    def sg(k): return results.get(k, {}).get("mean_anls", 0.0)
    def sf(k): return results.get(k, {}).get("mean_f1", 0.0)
    fixed = sg("Proposed Fixed")
    oracle = 0.8377  # oracle upper bound on the val split (see About tab)
    rv = rl_val_anls

    rows = [
        ("Equal Fusion", sg("Equal Fusion"), sf("Equal Fusion")),
        ("Proposed Fixed (paper)", sg("Proposed Fixed"), sf("Proposed Fixed")),
        ("Text-Only", sg("Text-Only"), sf("Text-Only")),
        ("LLMLingua-style [NEW]", sg("LLMLingua-style"), sf("LLMLingua-style")),
        ("Selective Context [NEW]", sg("Selective Context-style"), sf("Selective Context-style")),
        ("CAFP paper checkpoint", sg("CAFP (paper checkpoint)"), sf("CAFP (paper checkpoint)")),
        ("CAFP Hard Oracle [NEW]", sg("CAFP-Hard Oracle"), sf("CAFP-Hard Oracle")),
        ("CAFP Soft Oracle [NEW]", sg("CAFP-Soft Oracle"), sf("CAFP-Soft Oracle")),
        ("CAFP+REINFORCE [NEW][BEST]", rv, 0.0),
        ("Oracle Upper Bound", oracle, 0.0),
    ]

    md = "## Full Experiment Results\n\n"
    md += "| Method | ANLS | F1 | \u0394 Fixed | % Oracle |\n"
    md += "|--------|------|----|----------|----------|\n"
    for name, anls, f1 in rows:
        is_oracle = "Oracle Upper" in name
        d = f"{anls - fixed:+.4f}" if not is_oracle else "\u2014"
        pct = f"{anls / oracle * 100:.1f}%" if anls > 0 else "\u2014"
        md += f"| {name} | {anls:.4f} | {f1:.4f} | {d} | {pct} |\n"
    md += (f"\n**CAFP+REINFORCE: {rv/oracle*100:.1f}% of Oracle ANLS**\n"
           f"**Improvement over Fixed: {rv - fixed:+.4f} ANLS**\n")

    # Bar chart
    fig1, ax1 = plt.subplots(figsize=(11, 5))
    bv = [r[1] for r in rows]
    bc = ["#bbb", "#999", "#bbb", "#2196F3", "#2196F3",
          "#777", "#4CAF50", "#4CAF50", "#FF5722", "#d32f2f"]
    bars = ax1.barh([r[0] for r in rows], bv,
                    color=bc, edgecolor="white", height=0.65)
    ax1.axvline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle {oracle:.4f}")
    ax1.axvline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed {fixed:.4f}")
    for bar, val in zip(bars, bv):
        if val > 0:
            ax1.text(val + 0.003, bar.get_y() + bar.get_height() / 2,
                     f"{val:.4f}", va="center", fontsize=8)
    ax1.set_xlabel("Val ANLS", fontsize=11)
    ax1.set_title("All Methods \u2014 Val ANLS", fontsize=13, fontweight="bold")
    ax1.legend(fontsize=9)
    ax1.invert_yaxis()
    ax1.set_xlim(0, oracle * 1.1)
    ax1.grid(axis="x", alpha=0.3)
    plt.tight_layout()

    # Training curve
    fig2, ax2 = plt.subplots(figsize=(10, 3.5))
    eps = list(range(1, len(rl_train_anls) + 1))
    ax2.plot(eps, rl_train_anls, "o-", color="#FF5722",
             lw=2.5, ms=7, label="Train ANLS")
    ax2.axhline(rv, color="#FF5722", linestyle=":", lw=2,
                label=f"Val ANLS = {rv:.4f}")
    ax2.axhline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle = {oracle:.4f}")
    ax2.axhline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed = {fixed:.4f}")
    ax2.fill_between(eps, rl_train_anls, fixed, alpha=0.15, color="#FF5722")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("ANLS")
    ax2.set_title("REINFORCE Fine-tuning Progress",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3)
    ax2.set_xticks(eps)
    plt.tight_layout()

    return md, fig1, fig2


# ──────────────────────────────────────────────────────────────────────
# GRADIO UI
# ──────────────────────────────────────────────────────────────────────
_fixed_anls = results.get("Proposed Fixed", {}).get("mean_anls", 0.0)
_best_label = ("Best docs (REINFORCE wins most): "
               + ", ".join([f"#{x[0]}" for x in demo_scores[:5]]))

CSS = ".tab-nav button { font-size: 15px !important; font-weight: 600 !important; }"

with gr.Blocks(
    title="Adaptive Multimodal Fusion — DocVQA Demo",
    theme=gr.themes.Soft(primary_hue="blue"),
    css=CSS,
) as demo_app:

    gr.Markdown("""
    # Adaptive Multimodal Fusion for Document VQA
    ### Cross-Attention Fusion Predictor (CAFP) + REINFORCE Fine-tuning
    """)

    with gr.Tabs():

        # ── Tab 1: Live Demo ──────────────────────────────────────────
        with gr.TabItem("\U0001f3af Live Demo"):
            gr.Markdown(f"**{_best_label}**")
            with gr.Row():
                with gr.Column(scale=1):
                    doc_slider = gr.Slider(
                        0, len(val_items) - 1,
                        value=best_idx, step=1,
                        label=f"Document Index (0\u2013{len(val_items)-1})"
                    )
                    custom_q = gr.Textbox(
                        label="Custom Question (optional)",
                        placeholder="Leave blank to use original question"
                    )
                    run_btn = gr.Button(
                        "\u25b6 Run Adaptive Fusion",
                        variant="primary", size="lg"
                    )
                    gr.Markdown(
                        "*Compares: Equal \u00b7 Fixed \u00b7 "
                        "Text-Only \u00b7 CAFP+REINFORCE*"
                    )
                with gr.Column(scale=2):
                    doc_image = gr.Image(label="Document Image", height=400)
                    step_md = gr.Markdown()
                    weight_chart = gr.Plot(label="Fusion Weights Comparison")

            # ── Word Selection Visualizer ─────────────────────────────
            gr.Markdown("""
            ---
            ### 🎨 Word Selection Visualization
            *See **exactly** which OCR words each method keeps vs discards.*
            🟢 **Green** = kept and fed to the VQA model · 🔴 **Red** = compressed out
            """)
            with gr.Row():
                fixed_vis_img = gr.Image(
                    label="📌 Fixed Weights (α=0.5 β=0.3 γ=0.2)",
                    height=520, show_download_button=True
                )
                rl_vis_img = gr.Image(
                    label="🤖 CAFP+REINFORCE (Adaptive Weights)",
                    height=520, show_download_button=True
                )
            comp_md_out = gr.Markdown()

            run_btn.click(
                fn=run_demo,
                inputs=[doc_slider, custom_q],
                outputs=[doc_image, step_md, weight_chart,
                         fixed_vis_img, rl_vis_img, comp_md_out],
            )

        # ── Tab 2: Results Dashboard ──────────────────────────────────
        with gr.TabItem("\U0001f4ca Results Dashboard"):
            gr.Markdown("### All methods compared + REINFORCE training curve")
            load_btn = gr.Button("Load Results", variant="secondary")
            res_md = gr.Markdown()
            with gr.Row():
                bar_chart = gr.Plot(label="ANLS \u2014 All Methods")
                rl_curve = gr.Plot(label="REINFORCE Training Curve")
            load_btn.click(
                fn=show_dashboard,
                inputs=[],
                outputs=[res_md, bar_chart, rl_curve],
            )

        # ── Tab 3: About ──────────────────────────────────────────────
        with gr.TabItem("\u2139\ufe0f About"):
            gr.Markdown(f"""
            ## Adaptive Multimodal Fusion for DocVQA

            ### Problem
            DocVQA requires reasoning over three modalities simultaneously:
            - **Text** → OCR words and their semantics
            - **Visual** → Document appearance and image patches
            - **Spatial** → Bounding box positions and layout structure

            Fixed weights (α=0.5, β=0.3, γ=0.2) cannot adapt to different document types.

            ### Architecture: CAFP (428K params)
            1. Projects each modality embedding to 128-D
            2. Cross-attention: question attends to all modality representations
            3. Predicts per-document (α, β, γ) fusion weights

            ### Training Pipeline
            1. **Hard Oracle** (MSE) → argmax weights from 20-combo grid search
            2. **Soft Oracle** (KL-div) → temperature-smoothed ANLS-weighted targets
            3. **REINFORCE** → Policy gradient on direct ANLS reward (K=3 samples/step)

            ### Novel Contributions
            1. Soft Oracle training eliminates hard-oracle label noise
            2. REINFORCE fine-tuning directly maximises the DocVQA metric
            3. LLMLingua-style and Selective Context baselines for fair comparison

            ### Key Result
            **CAFP+REINFORCE achieves {rl_val_anls/0.8377*100:.1f}% of Oracle ANLS**
            Improvement over fixed-weight baseline: {rl_val_anls - _fixed_anls:+.4f} ANLS
            """)


# ──────────────────────────────────────────────────────────────────────
# LAUNCH
# ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    demo_app.launch(
        server_name="0.0.0.0",
        server_port=args.port,
        share=args.share,
        show_error=True,
    )
checkpoints/.DS_Store
ADDED
Binary file (6.15 kB)
checkpoints/cafp_rl_checkpoint_final.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a736aac8aa4ed8cb5ac34d53aa2ef6a896aad4f54ec241ab787d37147d62cce
size 7688445
checkpoints/feature_tensors.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:487d11bbbab8ee1e26ddd6791dbef4f14d4f649b58ef6623d2da3edb6a3e7cbe
size 2172325
checkpoints/final_results.json
ADDED
@@ -0,0 +1,56 @@
{
  "Equal Fusion": {
    "mean_anls": 0.5136354004166895,
    "mean_f1": 0.5237649141912695,
    "mean_em": 0.45
  },
  "Proposed Fixed": {
    "mean_anls": 0.5451455242797546,
    "mean_f1": 0.5509710563188824,
    "mean_em": 0.48
  },
  "Text-Only": {
    "mean_anls": 0.7642224473566777,
    "mean_f1": 0.7935144736858293,
    "mean_em": 0.7
  },
  "LLMLingua-style": {
    "mean_anls": 0.1892322383498854,
    "mean_f1": 0.210276221599751,
    "mean_em": 0.17
  },
  "Selective Context-style": {
    "mean_anls": 0.3046072383498854,
    "mean_f1": 0.3247206660441955,
    "mean_em": 0.27
  },
  "CAFP (paper checkpoint)": {
    "mean_anls": 0.7531033997376301,
    "mean_f1": 0.7755144736858292,
    "mean_em": 0.68
  },
  "CAFP-Hard Oracle": {
    "mean_anls": 0.7542224473566776,
    "mean_f1": 0.7721811403524959,
    "mean_em": 0.69
  },
  "CAFP-Soft Oracle": {
    "mean_anls": 0.5610757568378942,
    "mean_f1": 0.5647372900851162,
    "mean_em": 0.5
  },
  "CAFP+REINFORCE": {
    "mean_anls": 0.7084701316595168,
    "rl_curve": [
      0.5759529617062865,
      0.5927631990609283,
      0.6183613555884967,
      0.655312951977503,
      0.6726479882862236,
      0.6637648504900422,
      0.6849515365259237,
      0.6891216978125647,
      0.7084701316595168
    ]
  }
}
checkpoints/oracle_cache.json
ADDED
The diff for this file is too large to render.
requirements.txt
ADDED
@@ -0,0 +1,12 @@
datasets
transformers
torch
torchvision
sentencepiece
editdistance
sentence-transformers
accelerate
gradio
pillow
numpy
matplotlib