# demo-deploy / app.py
# Uploaded by ravindranv via huggingface_hub (commit 763cd3b, verified).
"""
Adaptive Multimodal Fusion for DocVQA — Gradio Demo
====================================================
Run locally:
    python app.py
Run with a public share URL (valid for 72 hours):
    python app.py --share
Other options: --port (default 7860), --ckpt_dir (default ./checkpoints).
Deploy to Hugging Face Spaces:
  - Push this file + requirements.txt + checkpoints/ folder to a Space repo
  - HF Spaces auto-launches on port 7860
"""
import argparse, os, sys, copy, json, warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import gradio as gr
import editdistance
from PIL import Image as PILImage, ImageDraw as PILDraw, ImageFont as PILFont
warnings.filterwarnings("ignore")
# ── CLI args ──────────────────────────────────────────────────────────
# parse_known_args (not parse_args): unrecognised extra arguments are
# collected and ignored instead of aborting startup.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--share",
    action="store_true",
    help="Create public Gradio URL",
)
parser.add_argument("--port", type=int, default=7860)
parser.add_argument(
    "--ckpt_dir",
    type=str,
    default="./checkpoints",
    help="Folder containing all saved files",
)
args, _ = parser.parse_known_args()
# ══════════════════════════════════════════════════════════════════════
# CONFIGURATION — edit paths here if needed
# ══════════════════════════════════════════════════════════════════════
CKPT_DIR = args.ckpt_dir
# Cached artefacts expected inside CKPT_DIR.
ORACLE_CACHE, FEAT_PATH, RESULTS_PATH = (
    os.path.join(CKPT_DIR, fname)
    for fname in ("oracle_cache.json", "feature_tensors.pt", "final_results.json")
)
# Prefer the final RL checkpoint; otherwise point at the intermediate one
# (whether or not it exists -- existence is checked when it is loaded).
CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint_final.pt")
if not os.path.exists(CKPT_PATH):
    CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint.pt")

# Hugging Face Hub IDs: dataset, feature backbone, VQA model, sentence encoder.
DATASET_NAME = "nielsr/docvqa_1200_examples"
FEAT_MODEL_ID = "microsoft/layoutlmv3-base"
VQA_MODEL_ID = "rubentito/layoutlmv3-base-mpdocvqa"
SBERT_ID = "all-MiniLM-L6-v2"

# Column names read from each dataset row.
WORD_FIELD, BOX_FIELD = "words", "bounding_boxes"
QUERY_FIELD, ANSWER_FIELD = "query", "answers"

# Architecture / sampling sizes.
MAX_WORDS = 64   # OCR words fed to the feature extractor
N_PATCHES = 49   # trailing hidden states pooled as the visual embedding
N_VAL = N_TRAIN = 100
# Packed feature vector: 3 x 768 modality embeddings + 10 layout stats
# + 384-D question embedding + 3 scores = 2701.
FEAT_DIM = 3 * 768 + 10 + 384 + 3
PROJ_DIM = 128

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
# ══════════════════════════════════════════════════════════════════════
# MODEL CLASSES
# ══════════════════════════════════════════════════════════════════════
class CrossAttentionFusionPredictor(nn.Module):
    """Predict per-document fusion weights for (text, visual, spatial).

    The projected question embedding attends (multi-head cross-attention)
    over the three projected modality embeddings. The layer-normalised
    attended context is concatenated with three scalar scores and mapped by
    a small MLP to 3 logits.

    Column layout of the last dimension of ``x`` (2701 values, the order
    produced by ``build_feature_vector``):
        [0, 768)      text embedding
        [768, 1536)   visual embedding
        [1536, 2304)  spatial embedding
        [2304, 2314)  10 raw layout statistics (not read here)
        [2314, 2698)  384-D question embedding
        [2698, 2701)  text / visual / spatial scores
    """

    # Column ranges of the packed feature vector (see class docstring).
    _TEXT_COLS = slice(0, 768)
    _VISUAL_COLS = slice(768, 1536)
    _SPATIAL_COLS = slice(1536, 2304)
    _QUESTION_COLS = slice(2314, 2698)
    _SCORE_COLS = slice(2698, 2701)

    def __init__(self, feat_dim=FEAT_DIM, proj_dim=PROJ_DIM,
                 n_heads=4, dropout=0.15):
        super().__init__()
        # NOTE: feat_dim is accepted for signature compatibility but unused;
        # the column layout above is fixed. Submodule names and creation
        # order are kept stable because checkpoints are restored with
        # load_state_dict (keyed by these attribute names).
        self.text_proj = nn.Linear(768, proj_dim)
        self.visual_proj = nn.Linear(768, proj_dim)
        self.spatial_proj_lyr = nn.Linear(768, proj_dim)
        self.q_proj = nn.Sequential(
            nn.Linear(384, proj_dim),
            nn.LayerNorm(proj_dim),
            nn.GELU(),
        )
        self.cross_attn = nn.MultiheadAttention(
            proj_dim, n_heads, dropout=dropout, batch_first=True,
        )
        self.attn_norm = nn.LayerNorm(proj_dim)
        self.head = nn.Sequential(
            nn.Linear(proj_dim + 3, proj_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(proj_dim, 3),
        )

    def _logits(self, x):
        """Return raw (batch, 3) logits for packed feature rows ``x`` (batch, 2701)."""
        # One key/value token per modality: (batch, 3, proj_dim).
        modality_tokens = torch.stack(
            [
                self.text_proj(x[:, self._TEXT_COLS]),
                self.visual_proj(x[:, self._VISUAL_COLS]),
                self.spatial_proj_lyr(x[:, self._SPATIAL_COLS]),
            ],
            dim=1,
        )
        # Single question query token: (batch, 1, proj_dim).
        query = self.q_proj(x[:, self._QUESTION_COLS]).unsqueeze(1)
        attended, _ = self.cross_attn(query, modality_tokens, modality_tokens)
        context = self.attn_norm(attended.squeeze(1))
        return self.head(torch.cat([context, x[:, self._SCORE_COLS]], dim=-1))

    def forward(self, x):
        """Return softmax-normalised fusion weights of shape (batch, 3)."""
        return self._logits(x).softmax(dim=-1)
# ══════════════════════════════════════════════════════════════════════
# LOAD BASE MODELS
# ══════════════════════════════════════════════════════════════════════
# All models below are used for inference only: each is moved to `device`,
# switched to eval mode and has its parameters frozen.
print("Loading base models (takes ~5 min on first run, cached after)...")
from transformers import AutoProcessor, AutoModel, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer

# LayoutLMv3 backbone for feature extraction; OCR is disabled because the
# dataset already provides words and boxes.
feat_processor = AutoProcessor.from_pretrained(FEAT_MODEL_ID, apply_ocr=False)
feat_model = AutoModel.from_pretrained(FEAT_MODEL_ID).to(device).eval()
for p in feat_model.parameters():
    p.requires_grad_(False)
print(" \u2705 LayoutLMv3 feature model")

# Extractive QA model that produces the answers.
vqa_processor = AutoProcessor.from_pretrained(VQA_MODEL_ID, apply_ocr=False)
vqa_model = AutoModelForQuestionAnswering.from_pretrained(
    VQA_MODEL_ID).to(device).eval()
for p in vqa_model.parameters():
    p.requires_grad_(False)
print(" \u2705 VQA model")

# Sentence encoder for the question embedding.
sbert = SentenceTransformer(SBERT_ID)
sbert.to(device)
print(" \u2705 SBERT")

# Maps the 10 layout statistics to a 768-D spatial embedding. Its weights
# are overwritten from the RL checkpoint later. It is frozen here like the
# other models so its outputs never carry autograd history.
spatial_proj = nn.Sequential(
    nn.Linear(10, 256), nn.ReLU(), nn.Linear(256, 768)
).to(device).eval()
spatial_proj.requires_grad_(False)
# ══════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS
# ══════════════════════════════════════════════════════════════════════
def get_question(item):
    """Return the item's question as a stripped string.

    Reads ``QUERY_FIELD``, falling back to ``"question"``. A dict value
    (presumably per-language text) yields its ``"en"`` entry, otherwise its
    first value. A missing or ``None`` question returns ``""`` instead of
    the literal text ``"None"``.
    """
    q = item.get(QUERY_FIELD, item.get("question", ""))
    if isinstance(q, dict):
        q = q.get("en", next(iter(q.values()), ""))
    if q is None:
        return ""
    return str(q).strip()
def normalize_boxes(boxes, w, h):
    """Scale pixel boxes to the 0-1000 integer range used by the LayoutLMv3 processors.

    Each coordinate of ``[x0, y0, x1, y1]`` is divided by the image width
    (x) or height (y), where a dimension below 1 is treated as 1. The result
    is clipped to [0, 1], multiplied by 1000 and truncated to ``int``. Only
    the first four values of each box are read.
    """
    width = max(w, 1)
    height = max(h, 1)

    def _scaled(value, extent):
        return int(max(0, min(value / extent, 1)) * 1000)

    result = []
    for box in boxes:
        result.append([
            _scaled(box[0], width),
            _scaled(box[1], height),
            _scaled(box[2], width),
            _scaled(box[3], height),
        ])
    return result
def extract_rich_features(item):
    """Compute the per-document features consumed by the CAFP predictor.

    Runs the frozen LayoutLMv3 feature model on the first ``MAX_WORDS`` OCR
    words/boxes and mean-pools text and visual hidden states. It then builds
    a 10-D layout-statistics vector, projects that with ``spatial_proj``,
    and embeds the question with SBERT. Everything runs under
    ``torch.no_grad()``, so the returned tensors carry no autograd history.
    This matters when they are cached or saved with ``torch.save``.

    Args:
        item: dataset row with an ``"image"`` (PIL image), OCR words under
            ``WORD_FIELD``, boxes under ``BOX_FIELD`` and a question.

    Returns:
        dict with keys ``H_text``, ``H_visual``, ``H_spatial`` (768-D
        tensors), ``spatial_10`` (10-D), ``question_emb``, three heuristic
        scores clipped to [0, 1] and ``n_tokens``. On any error, the error is
        printed and a fallback dict of zero tensors with fixed scores
        (0.5 / 0.3 / 0.2) is returned.
    """
    try:
        img = item["image"].convert("RGB")
        W, H = img.size
        words = list(item.get(WORD_FIELD, []))[:MAX_WORDS] or ["[PAD]"]
        boxes = list(item.get(BOX_FIELD, []))[:MAX_WORDS] or [[0, 0, 1, 1]]
        question = get_question(item)
        bn = normalize_boxes(boxes, W, H)
        enc = feat_processor(img, text=words, boxes=bn,
                             return_tensors="pt", truncation=True,
                             max_length=512, padding="max_length")
        enc = {k: v.to(device) for k, v in enc.items()}
        # Pure inference: build no autograd graph anywhere (spatial_proj
        # included) so nothing downstream holds on to graph buffers.
        with torch.no_grad():
            hidden = feat_model(**enc).last_hidden_state[0]
            # NOTE(review): assumes the last N_PATCHES positions are visual
            # patch tokens and everything before them is text. With
            # padding="max_length" the text mean likely also covers padding
            # positions. Left as-is because the CAFP checkpoint was
            # presumably trained on features computed exactly this way.
            n_txt = max(2, hidden.shape[0] - N_PATCHES)
            H_text = hidden[1:n_txt - 1].mean(0) if n_txt > 2 else hidden[0]
            H_visual = hidden[-N_PATCHES:].mean(0)
            # Layout statistics: page size and aspect ratio, word-count
            # ratio, mean and spread of box centres, pooled-embedding norms.
            bx = np.array(bn, dtype=np.float32)
            cx = ((bx[:, 0] + bx[:, 2]) / 2) / 1000.0
            cy = ((bx[:, 1] + bx[:, 3]) / 2) / 1000.0
            sp = np.array([
                W / 1000.0, H / 1000.0, min(W, H) / max(W, H),
                len(words) / MAX_WORDS,
                cx.mean(), cy.mean(), cx.std() + 1e-6, cy.std() + 1e-6,
                H_text.norm().item() / 10.0,
                H_visual.norm().item() / 10.0,
            ], dtype=np.float32)
            sp10 = torch.tensor(sp).to(device)
            H_spat = spatial_proj(sp10.unsqueeze(0)).squeeze(0)
            q_emb = torch.tensor(sbert.encode(question),
                                 dtype=torch.float32).to(device)
        return {
            "H_text": H_text, "H_visual": H_visual, "H_spatial": H_spat,
            "spatial_10": sp10, "question_emb": q_emb,
            "text_score": float(np.clip(sp[8], 0, 1)),
            "visual_score": float(np.clip(sp[9], 0, 1)),
            "spatial_score": float(np.clip(sp[6], 0, 1)),
            "n_tokens": len(words),
        }
    except Exception as e:
        # Best-effort: one unreadable document must not abort feature
        # computation for the whole split.
        print(f" extract_rich_features error: {e}")
        dummy = torch.zeros(768, device=device)
        return {
            "H_text": dummy, "H_visual": dummy, "H_spatial": dummy,
            "spatial_10": torch.zeros(10, device=device),
            "question_emb": torch.zeros(384, device=device),
            "text_score": 0.5, "visual_score": 0.3, "spatial_score": 0.2,
            "n_tokens": 0,
        }
def build_feature_vector(feat):
    """Pack a feature dict from ``extract_rich_features`` into one flat tensor.

    Concatenation order: H_text, H_visual, H_spatial, spatial_10,
    question_emb, then the [text, visual, spatial] scores as a float32
    tensor on ``device``. CrossAttentionFusionPredictor slices its input by
    this exact column layout.
    """
    pieces = [
        feat[key]
        for key in ("H_text", "H_visual", "H_spatial", "spatial_10", "question_emb")
    ]
    score_keys = ("text_score", "visual_score", "spatial_score")
    pieces.append(
        torch.tensor([feat[key] for key in score_keys],
                     dtype=torch.float32, device=device)
    )
    return torch.cat(pieces)
def vqa_infer(item, alpha, beta, gamma):
    """Answer the item's question from the OCR words kept by (alpha, beta, gamma).

    Word selection is delegated to ``get_sel_idx``, so inference and the
    word-selection overlays always use the same words. The selected words
    and their boxes (or one whole-page box per word when the item has no
    boxes) go to the extractive VQA model. The argmax start/end span is then
    decoded.

    Args:
        item: dataset row with ``"image"``, OCR words/boxes and a question.
        alpha, beta, gamma: text / visual / spatial fusion weights.

    Returns:
        The decoded answer string. Returns ``""`` when the item has no OCR
        words, or when inference raises (the error is printed).
    """
    try:
        img = item["image"].convert("RGB")
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        question = get_question(item)
        if not words:
            return ""
        W, H = img.size
        # Single source of truth for the selection rule (also drives the
        # visualiser); sorted to keep the original word order.
        sel_idx = sorted(get_sel_idx(item, alpha, beta, gamma))
        sw = [words[i] for i in sel_idx]
        sb = ([boxes[i] for i in sel_idx]
              if boxes else [[0, 0, W, H]] * len(sw))
        enc = vqa_processor(
            img, text=question, text_pair=sw,
            boxes=normalize_boxes(sb, W, H),
            return_tensors="pt", truncation=True,
            max_length=512, padding=True,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            out = vqa_model(**enc)
        s = int(out.start_logits.argmax())
        e = int(out.end_logits.argmax())
        # An end before the start is not a valid span; fall back to the start token.
        e = max(e, s)
        return vqa_processor.tokenizer.decode(
            enc["input_ids"][0][s:e + 1], skip_special_tokens=True
        ).strip()
    except Exception as exc:
        # Best-effort: a single failing document must not abort scoring or
        # the UI, but the failure is reported instead of silently hidden.
        print(f" vqa_infer error: {exc}")
        return ""
def compute_anls(pred, gts, threshold=0.5):
    """Return the ANLS of ``pred`` against its best-matching ground truth.

    Both strings are lower-cased, and all whitespace runs are collapsed to
    single spaces (leading/trailing removed), so formatting-only spacing
    differences cost nothing. Per ground truth the score is
    ``1 - edit_distance / max_len``. A score below ``threshold`` counts as
    0, and two empty strings score 1.

    Args:
        pred: predicted answer (any object; converted with ``str``).
        gts: one ground-truth string or a list of them.
        threshold: minimum similarity that still earns credit.

    Returns:
        Best score over ``gts`` in [0, 1]; 0.0 when ``pred`` or ``gts`` is empty.
    """
    if isinstance(gts, str):
        gts = [gts]
    if not gts or not pred:
        return 0.0
    p = " ".join(str(pred).lower().split())
    best = 0.0
    for gt in gts:
        g = " ".join(str(gt).lower().split())
        ml = max(len(p), len(g))
        if ml == 0:
            best = max(best, 1.0)
            continue
        nls = 1.0 - editdistance.eval(p, g) / ml
        if nls < threshold:
            nls = 0.0
        best = max(best, nls)
    return best
def compute_f1(pred, gts):
    """Return the best token-level F1 of ``pred`` against any ground truth.

    Tokens are lower-cased, whitespace-split words. Overlap is counted as a
    multiset (each shared token counts min(count_pred, count_gt) times), so
    repeated tokens are neither over- nor under-credited.

    Args:
        pred: predicted answer (converted with ``str``).
        gts: one ground-truth string or a list of them.

    Returns:
        Best F1 in [0, 1]; 0.0 when ``pred``/``gts`` is empty or nothing overlaps.
    """
    from collections import Counter

    if isinstance(gts, str):
        gts = [gts]
    if not pred or not gts:
        return 0.0
    pred_counts = Counter(str(pred).lower().split())
    if not pred_counts:
        return 0.0
    n_pred = sum(pred_counts.values())
    best = 0.0
    for gt in gts:
        gt_counts = Counter(str(gt).lower().split())
        if not gt_counts:
            continue
        n_common = sum((pred_counts & gt_counts).values())
        if n_common == 0:
            continue
        p = n_common / n_pred
        r = n_common / sum(gt_counts.values())
        best = max(best, 2 * p * r / (p + r))
    return best
# ══════════════════════════════════════════════════════════════════════
# WORD SELECTION VISUALIZER HELPERS
# ══════════════════════════════════════════════════════════════════════
def get_sel_idx(item, alpha, beta, gamma):
    """Return the set of OCR-word indices kept for weights (alpha, beta, gamma).

    Budget: ``max(int(n * max(alpha, 0.30)), min(5, n))`` words, capped at n.
    If gamma beats both alpha and beta and the item has boxes, the budget
    goes to the words sorted by box top (index 1), then left (index 0).
    Otherwise the first words in their stored order are kept. This is the
    selection rule used for VQA inference, so overlays show what the model
    actually sees.
    """
    words = list(item.get(WORD_FIELD, []))
    boxes = list(item.get(BOX_FIELD, []))
    total = len(words)
    if not total:
        return set()
    alpha_f, beta_f, gamma_f = float(alpha), float(beta), float(gamma)
    budget = min(max(int(total * max(alpha_f, 0.30)), min(5, total)), total)
    layout_first = bool(boxes) and gamma_f > max(alpha_f, beta_f)
    if not layout_first:
        return set(range(budget))
    top_to_bottom = sorted(range(total), key=lambda k: (boxes[k][1], boxes[k][0]))
    return set(top_to_bottom[:budget])
def draw_selection(item, alpha, beta, gamma, title=""):
    """Return a PIL image of the document with per-word keep/drop overlays.

    Green fill + outline -> word KEPT (used for VQA).
    Red fill + outline   -> word DROPPED (compressed out).

    Boxes are treated as pixel coordinates of ``item["image"]`` and clipped
    to its bounds. A dark info strip (title, kept count, weights) and a
    light colour-legend strip are appended below the document image, so the
    panel explains itself.

    Returns the plain RGB image when there are no word/box pairs. If drawing
    fails, the error is printed and ``item.get("image")`` is returned.
    """
    try:
        img = item["image"].convert("RGB").copy()
        W, H = img.size
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        n = min(len(words), len(boxes))
        if n == 0:
            return img
        sel_idx = get_sel_idx(item, alpha, beta, gamma)
        n_keep = len(sel_idx)
        pct = 100 * n_keep / max(n, 1)

        # ── Draw semi-transparent coloured overlays ───────────────────
        overlay = PILImage.new("RGBA", img.size, (0, 0, 0, 0))
        od = PILDraw.Draw(overlay)
        for i in range(n):
            try:
                x0, y0, x1, y1 = (int(boxes[i][0]), int(boxes[i][1]),
                                  int(boxes[i][2]), int(boxes[i][3]))
                # Clamp to image bounds
                x0, x1 = max(0, x0), min(W - 1, x1)
                y0, y1 = max(0, y0), min(H - 1, y1)
                if x1 <= x0 or y1 <= y0:
                    continue
                if i in sel_idx:
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(0, 210, 0, 55),
                                 outline=(0, 160, 0, 230), width=2)
                else:
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(220, 30, 30, 40),
                                 outline=(200, 0, 0, 170), width=1)
            except Exception:
                # Skip a malformed box rather than losing the whole overlay.
                continue
        img = PILImage.alpha_composite(img.convert("RGBA"), overlay).convert("RGB")

        # ── Load font (graceful fallback) ─────────────────────────────
        font_sm = PILFont.load_default()
        for _fp in [
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "/Windows/Fonts/arial.ttf",
        ]:
            try:
                font_sm = PILFont.truetype(_fp, 13)
                break
            except Exception:
                continue

        # ── Info strip (dark bar showing title + stats) ───────────────
        strip_h = 36
        strip = PILImage.new("RGB", (W, strip_h), (22, 22, 32))
        sd = PILDraw.Draw(strip)
        info_text = (f"{title} | \u2713 Kept: {n_keep}/{n} ({pct:.0f}%)"
                     f" | \u03b1={alpha:.2f} \u03b2={beta:.2f} \u03b3={gamma:.2f}")
        sd.text((8, 11), info_text, fill=(220, 220, 220), font=font_sm)

        # ── Legend strip (light bar explaining colours) ───────────────
        leg_h = 28
        leg = PILImage.new("RGB", (W, leg_h), (246, 246, 246))
        ld = PILDraw.Draw(leg)
        ld.rectangle([8, 7, 24, 21], fill=(0, 180, 0), outline=(0, 130, 0, 255))
        ld.text([30, 8], "= Kept (used for VQA)", fill=(0, 110, 0), font=font_sm)
        ld.rectangle([210, 7, 226, 21], fill=(220, 30, 30), outline=(170, 0, 0, 255))
        ld.text([232, 8], "= Dropped (compressed out)", fill=(140, 0, 0), font=font_sm)

        # ── Stack: image -> dark strip -> legend ──────────────────────
        final = PILImage.new("RGB", (W, H + strip_h + leg_h), (255, 255, 255))
        final.paste(img, (0, 0))
        final.paste(strip, (0, H))
        final.paste(leg, (0, H + strip_h))
        return final
    except Exception as e:
        print(f" draw_selection error: {e}")
        return item.get("image", None)
def make_compression_md(item, cfgs):
"""Build a markdown table showing kept / dropped word statistics and
a sample of the words that each method discards.
cfgs – OrderedDict/dict {method_name: (alpha, beta, gamma)}
"""
words = list(item.get(WORD_FIELD, []))
n = len(words)
if n == 0:
return "*No OCR words available for this document.*"
md = "### πŸ” What Gets Compressed?\n\n"
md += f"**Total OCR words in document:** {n}\n\n"
md += ("| Method | Ξ± | Ξ² | Ξ³ | Words Kept | % Context |"
" Sample Dropped Words |\n")
md += ("|--------|---|---|---|:----------:|:---------:|"
"----------------------|\n")
for name, (a, b, g) in cfgs.items():
sel = get_sel_idx(item, a, b, g)
n_keep = len(sel)
pct = 100 * n_keep / max(n, 1)
dropped = [words[i] for i in range(n) if i not in sel]
d_preview = " Β· ".join(dropped[:8])
if len(dropped) > 8:
d_preview += f" … (+{len(dropped) - 8} more)"
md += (f"| **{name}** | {a:.2f} | {b:.2f} | {g:.2f}"
f" | {n_keep} / {n} | {pct:.0f}% | `{d_preview}` |\n")
# Show the actual kept words for the CAFP+REINFORCE method
if "CAFP+REINFORCE" in cfgs:
a, b, g = cfgs["CAFP+REINFORCE"]
sel = get_sel_idx(item, a, b, g)
kept_w = [words[i] for i in sorted(sel)[:25]]
md += (f"\n**CAFP+REINFORCE β€” kept words (first 25 shown):** \n"
f"`{' Β· '.join(kept_w)}`\n")
return md
# ══════════════════════════════════════════════════════════════════════
# LOAD CHECKPOINTS & DATA
# ══════════════════════════════════════════════════════════════════════
print("\nLoading checkpoints and data...")
# ── RL checkpoint ─────────────────────────────────────────────────────
if not os.path.exists(CKPT_PATH):
    sys.exit(f"\u274c Checkpoint not found: {CKPT_PATH}\n"
             f" Copy cafp_rl_checkpoint_final.pt into {CKPT_DIR}/")
# SECURITY: weights_only=False unpickles arbitrary objects; only load
# checkpoint files from a trusted source.
ck = torch.load(CKPT_PATH, map_location=device, weights_only=False)
spatial_proj.load_state_dict(ck["spatial_proj_state"])
cafp_soft = CrossAttentionFusionPredictor().to(device)
cafp_soft.load_state_dict(ck["cafp_soft_state"])
cafp_soft.eval()
# Same architecture, REINFORCE-fine-tuned weights.
cafp_rl = copy.deepcopy(cafp_soft)
cafp_rl.load_state_dict(ck["cafp_rl_state"])
cafp_rl.eval()
# Per-epoch train ANLS history. It may be absent or empty, so every use
# below must tolerate an empty list (max() of an empty list would raise).
rl_train_anls = list(ck.get("rl_train_anls", []))
_best_train_anls = max(rl_train_anls) if rl_train_anls else 0.0
rl_val_anls = ck.get("rl_val_anls", _best_train_anls)
print(f" \u2705 CAFP+REINFORCE: {len(rl_train_anls)} epochs | "
      f"best_train={_best_train_anls:.4f} | val={rl_val_anls:.4f}")
# ── Dataset ───────────────────────────────────────────────────────────
print(" Loading dataset (~30s)...")
from datasets import load_dataset

_ds = load_dataset(DATASET_NAME, split="train")
# Seeded 80/20 split, then fixed random subsets of each side.
_split = _ds.train_test_split(test_size=0.2, seed=42)
# One RNG drawn val-first then train: this order determines which rows are
# picked, so keep it stable (cached features are indexed by these rows).
rng = np.random.RandomState(42)
val_idx = rng.permutation(len(_split["test"])).tolist()[:N_VAL]
train_idx = rng.permutation(len(_split["train"])).tolist()[:N_TRAIN]
val_items = [_split["test"][i] for i in val_idx]
train_items = [_split["train"][i] for i in train_idx]
val_gts = [item[ANSWER_FIELD] for item in val_items]
train_gts = [item[ANSWER_FIELD] for item in train_items]
print(f" \u2705 Dataset: {len(val_items)} val, {len(train_items)} train")
# ── Feature tensors ───────────────────────────────────────────────────
def _feats(items, tag):
    """Stack one ``build_feature_vector`` row per item; shape (len(items), feature_dim)."""
    rows = []
    for i, item in enumerate(items):
        rows.append(build_feature_vector(extract_rich_features(item)).unsqueeze(0))
        if (i + 1) % 10 == 0:
            print(f" {tag}: {i+1}/{len(items)}", end="\r")
    print()
    return torch.cat(rows).to(device)


def _feature_cache_is_usable(cache):
    """True if ``cache`` holds val/train matrices sized for the current split and FEAT_DIM."""
    try:
        return (tuple(cache["val_feats"].shape) == (len(val_items), FEAT_DIM)
                and tuple(cache["train_feats"].shape) == (len(train_items), FEAT_DIM))
    except (KeyError, TypeError, AttributeError):
        return False


_feat_cache = None
if os.path.exists(FEAT_PATH):
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # cache files you created yourself.
    _feat_cache = torch.load(FEAT_PATH, map_location=device, weights_only=False)

if _feat_cache is not None and _feature_cache_is_usable(_feat_cache):
    val_feats = _feat_cache["val_feats"]
    train_feats = _feat_cache["train_feats"]
    print(f" \u2705 Features: {tuple(val_feats.shape)}")
else:
    # Recompute when the cache is missing, or stale (e.g. N_VAL / N_TRAIN or
    # the feature layout changed), since stale rows would mis-index later.
    if _feat_cache is None:
        print(" \u26a0\ufe0f feature_tensors.pt not found \u2014 recomputing (~2 min)...")
    else:
        print(" \u26a0\ufe0f feature_tensors.pt does not match the current data"
              " \u2014 recomputing (~2 min)...")
    val_feats = _feats(val_items, "val")
    train_feats = _feats(train_items, "train")
    torch.save({"val_feats": val_feats, "train_feats": train_feats,
                "val_gts": val_gts, "train_gts": train_gts}, FEAT_PATH)
    print(f" \u2705 Features computed and saved to {FEAT_PATH}")
# ── Oracle cache ──────────────────────────────────────────────────────
# Optional: the demo runs without it. Two independent lists (not one
# aliased object) so later code can mutate either safely.
train_oracle, val_oracle = [], []
if os.path.exists(ORACLE_CACHE):
    try:
        with open(ORACLE_CACHE, encoding="utf-8") as _fh:
            _oc = json.load(_fh)
        train_oracle = _oc.get("train", [])
        val_oracle = _oc.get("val", [])
        print(f" \u2705 Oracle cache: {len(train_oracle)} train, {len(val_oracle)} val")
    except (OSError, ValueError, AttributeError) as _err:
        # Unreadable / malformed cache: treat it like a missing one.
        train_oracle, val_oracle = [], []
        print(f" \u26a0\ufe0f oracle_cache.json unreadable ({_err})"
              " \u2014 demo works without it")
else:
    print(" \u26a0\ufe0f oracle_cache.json not found \u2014 demo works without it")
# ── Results from JSON ─────────────────────────────────────────────────
# Take per-method {mean_anls, mean_f1} from the first candidate file that
# parses and contains at least one known method. Each file is parsed into a
# scratch dict first, so a file that fails halfway never leaves partial
# entries mixed with another file's.
results = {}
_RKEYS = [
    "Equal Fusion", "Proposed Fixed", "Text-Only",
    "LLMLingua-style", "Selective Context-style",
    "CAFP (paper checkpoint)", "CAFP-Hard Oracle", "CAFP-Soft Oracle",
]
for _rpath in [RESULTS_PATH, "./final_results.json", "./results_condensed.json"]:
    try:
        with open(_rpath, encoding="utf-8") as _fh:
            _raw = json.load(_fh)
        _parsed = {}
        if isinstance(_raw, dict):
            for k in _RKEYS:
                r = _raw.get(k)
                if isinstance(r, dict):
                    _parsed[k] = {
                        "mean_anls": float(r.get("mean_anls", r.get("anls", 0.0))),
                        "mean_f1": float(r.get("mean_f1", r.get("f1", 0.0))),
                    }
    except (OSError, ValueError, TypeError):
        # Missing, unreadable or malformed file: try the next candidate.
        continue
    if _parsed:
        results = _parsed
        print(f" \u2705 Results: {len(results)} methods from {_rpath}")
        break
if not results:
    print(" \u26a0\ufe0f Results JSON not found \u2014 dashboard will show partial data")
# ── Find best demo documents ──────────────────────────────────────────
# Score every val document with the CAFP+REINFORCE weights and with the
# fixed (0.5, 0.3, 0.2) weights, then rank documents by how much ANLS the
# adaptive weights gain. Entries: (index, delta, rl_anls, fixed_anls).
print("\nPre-scoring documents for demo (this takes ~2 min)...")
demo_scores = []
cafp_rl.eval()
with torch.no_grad():
    for i, (item, gt) in enumerate(zip(val_items, val_gts)):
        fv = val_feats[i].unsqueeze(0)
        # softplus + 0.1 gives positive per-modality values; normalising them
        # to sum 1 yields (alpha, beta, gamma).
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        w = (conc / conc.sum()).squeeze(0).cpu().tolist()
        rl_s = compute_anls(vqa_infer(item, w[0], w[1], w[2]), gt)
        fx_s = compute_anls(vqa_infer(item, 0.5, 0.3, 0.2), gt)
        demo_scores.append((i, round(rl_s - fx_s, 4),
                            round(rl_s, 4), round(fx_s, 4)))
        if (i + 1) % 20 == 0:
            print(f" {i+1}/{len(val_items)}", end="\r")
demo_scores.sort(key=lambda x: -x[1])
best_idx = demo_scores[0][0] if demo_scores else 0
# ":+.2f" prints the sign itself, so losses show as "-0.05" (not "+-0.05").
top5_str = ", ".join(f"#{x[0]}({x[1]:+.2f})" for x in demo_scores[:5])
print(f"\n \u2705 Best docs: {top5_str}")
print(f"\n{'='*55}")
print("ALL MODELS LOADED \u2014 ready to demo")
print(f"{'='*55}\n")
# ══════════════════════════════════════════════════════════════════════
# GRADIO FUNCTIONS
# ══════════════════════════════════════════════════════════════════════
def get_rl_weights(idx, custom_q=None):
    """Return the CAFP+REINFORCE weights [alpha, beta, gamma] for val document ``idx``.

    With a non-blank ``custom_q``, features are recomputed with that question
    substituted; otherwise the precomputed ``val_feats[idx]`` row is used.
    The policy logits go through ``softplus(.) + 0.1`` and are normalised to
    sum to 1 (not a softmax). Returns a list of three Python floats.
    """
    override = custom_q.strip() if custom_q else ""
    if override:
        patched_item = dict(val_items[idx])
        patched_item[QUERY_FIELD] = override
        features = build_feature_vector(extract_rich_features(patched_item))
    else:
        features = val_feats[idx]
    with torch.no_grad():
        concentration = F.softplus(cafp_rl._logits(features.unsqueeze(0))) + 0.1
        normalised = concentration / concentration.sum()
    return normalised.squeeze(0).cpu().tolist()
def make_weight_chart(mw):
    """Grouped bar chart of the (alpha, beta, gamma) weights of each method.

    Args:
        mw: mapping {method_name: (alpha, beta, gamma)}; insertion order sets
            the x-axis order.

    Returns:
        A ``matplotlib.figure.Figure``. It is built with the object-oriented
        API instead of ``pyplot``, so it is never registered in pyplot's
        global figure list. Repeated Gradio calls therefore don't accumulate
        open figures, and concurrent requests can't lay out each other's
        "current" figure.
    """
    from matplotlib.figure import Figure

    labels = list(mw.keys())
    weights = list(mw.values())
    x, bw = np.arange(len(labels)), 0.25
    fig = Figure(figsize=(9, 3.5))
    ax = fig.subplots()
    for j, (lbl, col) in enumerate([
        ("\u03b1 Text", "#2196F3"),
        ("\u03b2 Visual", "#4CAF50"),
        ("\u03b3 Spatial", "#FF9800"),
    ]):
        vals = [w[j] for w in weights]
        bars = ax.bar(x + (j - 1) * bw, vals, bw,
                      label=lbl, color=col, alpha=0.85)
        for bar in bars:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, h + 0.01,
                    f"{h:.2f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=10)
    ax.set_ylabel("Weight")
    ax.set_ylim(0, 1.2)
    ax.set_title("Fusion Weights (\u03b1, \u03b2, \u03b3) per Method",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=9)
    ax.grid(axis="y", alpha=0.3)
    fig.tight_layout()
    return fig
def run_demo(doc_idx, custom_q):
    """Gradio callback for the Live Demo tab.

    Runs VQA on validation document ``doc_idx`` with four fusion configs
    (equal, fixed, text-only, and the CAFP+REINFORCE prediction). It then
    builds every output shown in the tab.

    Args:
        doc_idx: index into ``val_items`` (slider value; cast to int).
        custom_q: optional question that overrides the document's own one.

    Returns:
        (document image, step-by-step markdown, weight chart,
         fixed-weight selection overlay, CAFP+REINFORCE selection overlay,
         compression markdown)
    """
    doc_idx = int(doc_idx)
    item = val_items[doc_idx]
    gt = val_gts[doc_idx]
    q = (custom_q.strip()
         if custom_q and custom_q.strip()
         else get_question(item))
    # One copy of the item carrying the question actually asked. All the
    # inference and overlay calls below only read it.
    item_q = dict(item)
    item_q[QUERY_FIELD] = q
    gt_str = (", ".join(str(g) for g in gt[:2])
              if isinstance(gt, list) else str(gt))
    n_words = len(list(item.get(WORD_FIELD, [])))
    # Simple word-count heuristic for the header label only.
    doc_type = "Text-dominant" if n_words > 40 else "Visual-dominant"
    alpha, beta, gamma = get_rl_weights(doc_idx, custom_q)
    dom = ("Text" if alpha > 0.65 else
           "Visual" if beta > 0.40 else "Balanced")
    cfgs = {
        "Equal Fusion": (1/3, 1/3, 1/3),
        "Fixed (0.5,0.3,0.2)": (0.5, 0.3, 0.2),
        "Text-Only": (1.0, 0.0, 0.0),
        "CAFP+REINFORCE": (alpha, beta, gamma),
    }
    res = {}
    for name, (a, b, g) in cfgs.items():
        pred = vqa_infer(item_q, a, b, g)
        res[name] = {
            "pred": pred,
            "anls": compute_anls(pred, gt),
            "f1": compute_f1(pred, gt),
            "w": (a, b, g),
        }
    best = max(res, key=lambda k: res[k]["anls"])
    rl_vs_fixed = res["CAFP+REINFORCE"]["anls"] - res["Fixed (0.5,0.3,0.2)"]["anls"]

    md = f"## Document #{doc_idx} \u2014 {doc_type} ({n_words} words)\n\n"
    md += f"**Question:** {q}\n\n"
    md += f"**Ground Truth:** `{gt_str}`\n\n---\n"
    md += "### Step 1 \u2014 Text Extraction\n"
    md += f"`{n_words}` OCR words extracted via LayoutLMv3\n\n"
    md += "### Step 2 \u2014 Multimodal Feature Extraction\n"
    md += ("- **Text** \u2192 LayoutLMv3 token embeddings [768-D]\n"
           "- **Visual** \u2192 LayoutLMv3 patch features [768-D]\n"
           "- **Spatial** \u2192 Bounding box layout encoding [768-D]\n\n")
    md += "### Step 3 \u2014 CAFP+REINFORCE Weight Prediction\n"
    md += "| Modality | Weight |\n|----------|--------|\n"
    md += f"| \u03b1 Text | **{alpha:.3f}** |\n"
    md += f"| \u03b2 Visual | **{beta:.3f}** |\n"
    md += f"| \u03b3 Spatial | **{gamma:.3f}** |\n\n"
    md += f"\u2192 **Dominant: {dom}**\n\n"
    md += "### Step 4 \u2014 Adaptive Fusion \u2192 Answer\n"
    md += "| Method | \u03b1 | \u03b2 | \u03b3 | Answer | ANLS | F1 |\n"
    md += "|--------|---|---|---|--------|------|----|\n"
    for name, d in res.items():
        a, b, g = d["w"]
        star = " \u2b50" if name == best else ""
        md += (f"| {name}{star} | {a:.2f} | {b:.2f} | {g:.2f}"
               f" | `{d['pred']}` | **{d['anls']:.4f}** | {d['f1']:.4f} |\n")
    sign = "+" if rl_vs_fixed >= 0 else ""
    md += (f"\n---\n**Best:** {best} (ANLS: {res[best]['anls']:.4f})\n\n"
           f"**CAFP+REINFORCE answer:** `{res['CAFP+REINFORCE']['pred']}`\n\n"
           f"**\u0394 over Fixed:** {sign}{rl_vs_fixed:.4f}\n")
    chart = make_weight_chart({k: v["w"] for k, v in res.items()})

    # ── Word Selection Visualizations ─────────────────────────────────
    fixed_vis = draw_selection(
        item_q, 0.5, 0.3, 0.2,
        "Fixed Weights (0.5, 0.3, 0.2)",
    )
    rl_vis = draw_selection(
        item_q, alpha, beta, gamma,
        f"CAFP+REINFORCE (\u03b1={alpha:.2f} \u03b2={beta:.2f} \u03b3={gamma:.2f})",
    )
    comp_md = make_compression_md(item, cfgs)
    return item.get("image", None), md, chart, fixed_vis, rl_vis, comp_md
def show_dashboard():
    """Build the Results Dashboard outputs.

    Returns:
        (markdown results table, bar chart of val ANLS per method,
         REINFORCE training-curve figure).

    ANLS/F1 come from ``results`` (loaded from JSON), except two rows: the
    CAFP+REINFORCE row uses ``rl_val_anls`` from the checkpoint, and the
    oracle upper bound is hard-coded. Rows with no F1 value show an em dash
    instead of a misleading 0.0000. Figures use the object-oriented
    ``matplotlib.figure.Figure`` API instead of ``pyplot``, so repeated
    clicks don't accumulate open figures and concurrent requests don't share
    a "current" figure.
    """
    from matplotlib.figure import Figure

    def sg(k):
        return results.get(k, {}).get("mean_anls", 0.0)

    def sf(k):
        return results.get(k, {}).get("mean_f1", 0.0)

    fixed = sg("Proposed Fixed")
    oracle = 0.8377  # oracle upper-bound val ANLS (hard-coded, not loaded)
    rv = rl_val_anls
    # (label, ANLS, F1 -- None where no F1 value is available)
    rows = [
        ("Equal Fusion", sg("Equal Fusion"), sf("Equal Fusion")),
        ("Proposed Fixed (paper)", sg("Proposed Fixed"), sf("Proposed Fixed")),
        ("Text-Only", sg("Text-Only"), sf("Text-Only")),
        ("LLMLingua-style [NEW]", sg("LLMLingua-style"), sf("LLMLingua-style")),
        ("Selective Context [NEW]", sg("Selective Context-style"), sf("Selective Context-style")),
        ("CAFP paper checkpoint", sg("CAFP (paper checkpoint)"), sf("CAFP (paper checkpoint)")),
        ("CAFP Hard Oracle [NEW]", sg("CAFP-Hard Oracle"), sf("CAFP-Hard Oracle")),
        ("CAFP Soft Oracle [NEW]", sg("CAFP-Soft Oracle"), sf("CAFP-Soft Oracle")),
        ("CAFP+REINFORCE [NEW][BEST]", rv, None),
        ("Oracle Upper Bound", oracle, None),
    ]
    md = "## Full Experiment Results\n\n"
    md += "| Method | ANLS | F1 | \u0394 Fixed | % Oracle |\n"
    md += "|--------|------|----|----------|----------|\n"
    for name, anls, f1 in rows:
        is_oracle = "Oracle Upper" in name
        d = f"{anls - fixed:+.4f}" if not is_oracle else "\u2014"
        pct = f"{anls / oracle * 100:.1f}%" if anls > 0 else "\u2014"
        f1_s = f"{f1:.4f}" if f1 is not None else "\u2014"
        md += f"| {name} | {anls:.4f} | {f1_s} | {d} | {pct} |\n"
    md += (f"\n**CAFP+REINFORCE: {rv/oracle*100:.1f}% of Oracle ANLS**\n"
           f"**Improvement over Fixed: {rv - fixed:+.4f} ANLS**\n")

    # Bar chart of all methods
    fig1 = Figure(figsize=(11, 5))
    ax1 = fig1.subplots()
    bv = [r[1] for r in rows]
    bc = ["#bbb", "#999", "#bbb", "#2196F3", "#2196F3",
          "#777", "#4CAF50", "#4CAF50", "#FF5722", "#d32f2f"]
    bars = ax1.barh([r[0] for r in rows], bv,
                    color=bc, edgecolor="white", height=0.65)
    ax1.axvline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle {oracle:.4f}")
    ax1.axvline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed {fixed:.4f}")
    for bar, val in zip(bars, bv):
        if val > 0:
            ax1.text(val + 0.003, bar.get_y() + bar.get_height() / 2,
                     f"{val:.4f}", va="center", fontsize=8)
    ax1.set_xlabel("Val ANLS", fontsize=11)
    ax1.set_title("All Methods \u2014 Val ANLS", fontsize=13, fontweight="bold")
    ax1.legend(fontsize=9)
    ax1.invert_yaxis()
    ax1.set_xlim(0, oracle * 1.1)
    ax1.grid(axis="x", alpha=0.3)
    fig1.tight_layout()

    # REINFORCE training curve
    fig2 = Figure(figsize=(10, 3.5))
    ax2 = fig2.subplots()
    eps = list(range(1, len(rl_train_anls) + 1))
    ax2.plot(eps, rl_train_anls, "o-", color="#FF5722",
             lw=2.5, ms=7, label="Train ANLS")
    ax2.axhline(rv, color="#FF5722", linestyle=":", lw=2,
                label=f"Val ANLS = {rv:.4f}")
    ax2.axhline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle = {oracle:.4f}")
    ax2.axhline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed = {fixed:.4f}")
    ax2.fill_between(eps, rl_train_anls, fixed, alpha=0.15, color="#FF5722")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("ANLS")
    ax2.set_title("REINFORCE Fine-tuning Progress",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3)
    ax2.set_xticks(eps)
    fig2.tight_layout()
    return md, fig1, fig2
# ══════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════
# Non-ASCII UI text is written as \u / \U escapes so it survives any later
# mis-decoding of this file's encoding. Markdown bodies are kept flush-left
# so their lines are never indented into markdown code blocks.
_fixed_anls = results.get("Proposed Fixed", {}).get("mean_anls", 0.0)
_best_label = ("Best docs (REINFORCE wins most): "
               + ", ".join([f"#{x[0]}" for x in demo_scores[:5]]))
CSS = ".tab-nav button { font-size: 15px !important; font-weight: 600 !important; }"

with gr.Blocks(
    title="Adaptive Multimodal Fusion \u2014 DocVQA Demo",
    theme=gr.themes.Soft(primary_hue="blue"),
    css=CSS,
) as demo_app:
    gr.Markdown("""
# Adaptive Multimodal Fusion for Document VQA
### Cross-Attention Fusion Predictor (CAFP) + REINFORCE Fine-tuning
""")
    with gr.Tabs():
        # ── Tab 1: Live Demo ──────────────────────────────────────────
        with gr.TabItem("\U0001f3af Live Demo"):
            gr.Markdown(f"**{_best_label}**")
            with gr.Row():
                with gr.Column(scale=1):
                    doc_slider = gr.Slider(
                        0, len(val_items) - 1,
                        value=best_idx, step=1,
                        label=f"Document Index (0\u2013{len(val_items)-1})",
                    )
                    custom_q = gr.Textbox(
                        label="Custom Question (optional)",
                        placeholder="Leave blank to use original question",
                    )
                    run_btn = gr.Button(
                        "\u25b6 Run Adaptive Fusion",
                        variant="primary", size="lg",
                    )
                    gr.Markdown(
                        "*Compares: Equal \u00b7 Fixed \u00b7 "
                        "Text-Only \u00b7 CAFP+REINFORCE*"
                    )
                with gr.Column(scale=2):
                    doc_image = gr.Image(label="Document Image", height=400)
            step_md = gr.Markdown()
            weight_chart = gr.Plot(label="Fusion Weights Comparison")

            # ── Word Selection Visualizer ─────────────────────────────
            gr.Markdown("""
---
### \U0001f3a8 Word Selection Visualization
*See **exactly** which OCR words each method keeps vs discards.*

\U0001f7e2 **Green** = kept and fed to the VQA model \u00b7 \U0001f534 **Red** = compressed out
""")
            with gr.Row():
                fixed_vis_img = gr.Image(
                    label="\U0001f4cc Fixed Weights (\u03b1=0.5 \u03b2=0.3 \u03b3=0.2)",
                    height=520,
                )
                rl_vis_img = gr.Image(
                    label="\U0001f916 CAFP+REINFORCE (Adaptive Weights)",
                    height=520,
                )
            comp_md_out = gr.Markdown()

            run_btn.click(
                fn=run_demo,
                inputs=[doc_slider, custom_q],
                outputs=[doc_image, step_md, weight_chart,
                         fixed_vis_img, rl_vis_img, comp_md_out],
            )

        # ── Tab 2: Results Dashboard ──────────────────────────────────
        with gr.TabItem("\U0001f4ca Results Dashboard"):
            gr.Markdown("### All methods compared + REINFORCE training curve")
            load_btn = gr.Button("Load Results", variant="secondary")
            res_md = gr.Markdown()
            with gr.Row():
                bar_chart = gr.Plot(label="ANLS \u2014 All Methods")
                rl_curve = gr.Plot(label="REINFORCE Training Curve")
            load_btn.click(
                fn=show_dashboard,
                inputs=[],
                outputs=[res_md, bar_chart, rl_curve],
            )

        # ── Tab 3: About ──────────────────────────────────────────────
        with gr.TabItem("\u2139\ufe0f About"):
            gr.Markdown(f"""
## Adaptive Multimodal Fusion for DocVQA

### Problem
DocVQA requires reasoning over three modalities simultaneously:

- **Text** \u2014 OCR words and their semantics
- **Visual** \u2014 Document appearance and image patches
- **Spatial** \u2014 Bounding box positions and layout structure

Fixed weights (\u03b1=0.5, \u03b2=0.3, \u03b3=0.2) cannot adapt to different document types.

### Architecture: CAFP (428K params)
1. Projects each modality embedding to 128-D
2. Cross-attention: question attends to all modality representations
3. Predicts per-document (\u03b1, \u03b2, \u03b3) fusion weights

### Training Pipeline
1. **Hard Oracle** (MSE) \u2192 argmax weights from 20-combo grid search
2. **Soft Oracle** (KL-div) \u2192 temperature-smoothed ANLS-weighted targets
3. **REINFORCE** \u2192 Policy gradient on direct ANLS reward (K=3 samples/step)

### Novel Contributions
1. Soft Oracle training eliminates hard-oracle label noise
2. REINFORCE fine-tuning directly maximises DocVQA metric
3. LLMLingua-style and Selective Context baselines for fair comparison

### Key Result
**CAFP+REINFORCE achieves {rl_val_anls/0.8377*100:.1f}% of Oracle ANLS**

Improvement over fixed-weight baseline: {rl_val_anls - _fixed_anls:+.4f} ANLS
""")
# ══════════════════════════════════════════════════════════════════════
# LAUNCH
# ══════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # "0.0.0.0" listens on every network interface, so the server is
    # reachable from outside the host/container. Port and public-share
    # behaviour come from the --port / --share CLI flags.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": args.port,
        "share": args.share,
        "show_error": True,
    }
    demo_app.launch(**launch_options)