ravindranv commited on
Commit
5dabb85
Β·
verified Β·
1 Parent(s): 0660539

Upload folder using huggingface_hub

Browse files
.DS_Store ADDED
Binary file (10.2 kB). View file
 
aoo2.poy ADDED
@@ -0,0 +1,890 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adaptive Multimodal Fusion for DocVQA β€” Gradio Demo
3
+ ====================================================
4
+ Run locally:
5
+ python app.py
6
+
7
+ Run with public URL (72hr):
8
+ python app.py --share
9
+
10
+ Deploy to HuggingFace Spaces:
11
+ - Push this file + requirements.txt + checkpoints/ folder to a Space repo
12
+ - HF Spaces auto-launches on port 7860
13
+ """
14
+
15
"""Module setup: imports, CLI arguments, checkpoint paths, and constants."""
import argparse, os, sys, copy, json, warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")  # headless backend: figures are rendered for Gradio, never shown
import matplotlib.pyplot as plt
import gradio as gr
import editdistance
from PIL import Image as PILImage, ImageDraw as PILDraw, ImageFont as PILFont
warnings.filterwarnings("ignore")

# ── CLI args ──────────────────────────────────────────────────────────
# parse_known_args so unrelated flags (e.g. injected by HF Spaces) don't crash startup
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true", help="Create public Gradio URL")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--ckpt_dir", type=str, default="./checkpoints",
                    help="Folder containing all saved files")
args, _ = parser.parse_known_args()

# ══════════════════════════════════════════════════════════════════════
# CONFIGURATION — edit paths here if needed
# ══════════════════════════════════════════════════════════════════════
CKPT_DIR = args.ckpt_dir
ORACLE_CACHE = os.path.join(CKPT_DIR, "oracle_cache.json")
FEAT_PATH = os.path.join(CKPT_DIR, "feature_tensors.pt")
RESULTS_PATH = os.path.join(CKPT_DIR, "final_results.json")

# Try final checkpoint first, fall back to intermediate
CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint_final.pt")
if not os.path.exists(CKPT_PATH):
    CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint.pt")

# Dataset / model IDs
DATASET_NAME = "nielsr/docvqa_1200_examples"
FEAT_MODEL_ID = "microsoft/layoutlmv3-base"
VQA_MODEL_ID = "rubentito/layoutlmv3-base-mpdocvqa"
SBERT_ID = "all-MiniLM-L6-v2"

# Field names used to read records of DATASET_NAME
WORD_FIELD = "words"
BOX_FIELD = "bounding_boxes"
QUERY_FIELD = "query"
ANSWER_FIELD = "answers"

# Architecture / data-size constants
MAX_WORDS = 64    # OCR words kept per document for feature extraction
N_PATCHES = 49    # LayoutLMv3-base visual patch tokens (7x7 grid)
N_VAL = 100       # validation subset size
N_TRAIN = 100     # train subset size
FEAT_DIM = 2701   # 768*3 modality embeddings + 10 spatial + 384 question + 3 scores
PROJ_DIM = 128    # shared projection width inside CAFP

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
71
+
72
+ # ══════════════════════════════════════════════════════════════════════
73
+ # MODEL CLASSES
74
+ # ══════════════════════════════════════════════════════════════════════
75
class CrossAttentionFusionPredictor(nn.Module):
    """Predicts per-example fusion weights (alpha, beta, gamma) for DocVQA.

    Input is a flat 2701-D feature vector whose layout (built by
    build_feature_vector) is hard-coded in the slices below:
        [0:768]      mean text-token embedding   (LayoutLMv3)
        [768:1536]   mean visual-patch embedding (LayoutLMv3)
        [1536:2304]  projected spatial embedding
        [2304:2314]  raw 10-D spatial statistics (not read by _logits)
        [2314:2698]  384-D SBERT question embedding
        [2698:2701]  3 heuristic modality scores
    The question embedding acts as the attention query over the three
    modality embeddings; the attended context plus the heuristic scores
    are mapped to 3 logits.
    """

    # NOTE(review): feat_dim is accepted but never used in the body; the
    # slice offsets in _logits assume the fixed 2701-D layout above.
    def __init__(self, feat_dim=FEAT_DIM, proj_dim=PROJ_DIM,
                 n_heads=4, dropout=0.15):
        super().__init__()
        # One linear projection per modality into a shared proj_dim space.
        self.text_proj = nn.Linear(768, proj_dim)
        self.visual_proj = nn.Linear(768, proj_dim)
        self.spatial_proj_lyr = nn.Linear(768, proj_dim)
        # Question embedding (384-D) becomes the cross-attention query.
        self.q_proj = nn.Sequential(
            nn.Linear(384, proj_dim), nn.LayerNorm(proj_dim), nn.GELU()
        )
        self.cross_attn = nn.MultiheadAttention(
            proj_dim, n_heads, dropout=dropout, batch_first=True
        )
        self.attn_norm = nn.LayerNorm(proj_dim)
        # Head consumes attended context concatenated with the 3 scores.
        self.head = nn.Sequential(
            nn.Linear(proj_dim + 3, proj_dim), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(proj_dim, 3)
        )

    def _logits(self, x):
        """Return raw (unnormalised) 3-way logits for a batch of vectors."""
        h_t = self.text_proj(x[:, 0:768])
        h_v = self.visual_proj(x[:, 768:1536])
        h_s = self.spatial_proj_lyr(x[:, 1536:2304])
        q = self.q_proj(x[:, 2314:2698]).unsqueeze(1)
        # The three modalities form a length-3 key/value "sequence".
        kv = torch.stack([h_t, h_v, h_s], dim=1)
        ctx, _ = self.cross_attn(q, kv, kv)
        ctx = self.attn_norm(ctx.squeeze(1))
        return self.head(torch.cat([ctx, x[:, 2698:2701]], dim=-1))

    def forward(self, x):
        """Return softmax-normalised fusion weights (alpha, beta, gamma)."""
        return F.softmax(self._logits(x), dim=-1)
106
+
107
+
108
+ # ══════════════════════════════════════════════════════════════════════
109
+ # LOAD BASE MODELS
110
+ # ══════════════════════════════════════════════════════════════════════
111
print("Loading base models (takes ~5 min on first run, cached after)...")

from transformers import AutoProcessor, AutoModel, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer

# Frozen LayoutLMv3 backbone used only for feature extraction.
# apply_ocr=False: words and boxes come from the dataset, not Tesseract.
feat_processor = AutoProcessor.from_pretrained(FEAT_MODEL_ID, apply_ocr=False)
feat_model = AutoModel.from_pretrained(FEAT_MODEL_ID).to(device).eval()
for p in feat_model.parameters(): p.requires_grad_(False)
print(" βœ… LayoutLMv3 feature model")

# Frozen extractive-QA model (LayoutLMv3 fine-tuned on MP-DocVQA).
vqa_processor = AutoProcessor.from_pretrained(VQA_MODEL_ID, apply_ocr=False)
vqa_model = AutoModelForQuestionAnswering.from_pretrained(
    VQA_MODEL_ID).to(device).eval()
for p in vqa_model.parameters(): p.requires_grad_(False)
print(" βœ… VQA model")

# Sentence embedder for questions (384-D output, see q_proj / FEAT layout).
sbert = SentenceTransformer(SBERT_ID)
sbert.to(device)
print(" βœ… SBERT")

# Projects the 10-D spatial statistics vector up to 768-D.
# NOTE(review): its trained weights are loaded later from the RL
# checkpoint (spatial_proj.load_state_dict) -- random until then.
spatial_proj = nn.Sequential(
    nn.Linear(10, 256), nn.ReLU(), nn.Linear(256, 768)
).to(device)
134
+
135
+ # ══════════════════════════════════════════════════════════════════════
136
+ # HELPER FUNCTIONS
137
+ # ══════════════════════════════════════════════════════════════════════
138
def get_question(item):
    """Extract the question string from a dataset record.

    Looks up QUERY_FIELD first, then a plain "question" key. A dict value
    is treated as a multilingual mapping: English is preferred, otherwise
    any available translation, otherwise "".
    """
    raw = item.get(QUERY_FIELD, item.get("question", ""))
    if isinstance(raw, dict):
        # Fall back to an arbitrary language if "en" is absent.
        any_value = next(iter(raw.values()), "")
        raw = raw.get("en", any_value)
    return str(raw).strip()
143
+
144
+
145
def normalize_boxes(boxes, w, h):
    """Scale pixel boxes [x0, y0, x1, y1] into LayoutLM's 0-1000 space.

    Each coordinate is divided by the image extent (guarded against
    zero-sized images), clamped into [0, 1], and scaled to an int in
    [0, 1000]. Returns a new list of 4-int lists.
    """
    def _to_1000(value, extent):
        # max(extent, 1) avoids division by zero on degenerate images.
        ratio = value / max(extent, 1)
        return int(max(0, min(ratio, 1)) * 1000)

    scaled = []
    for box in boxes:
        scaled.append([
            _to_1000(box[0], w),
            _to_1000(box[1], h),
            _to_1000(box[2], w),
            _to_1000(box[3], h),
        ])
    return scaled
155
+
156
+
157
def extract_rich_features(item):
    """Extract per-document multimodal features for the CAFP predictor.

    Runs the frozen LayoutLMv3 over the image + OCR words, pools token and
    patch embeddings, derives 10 spatial statistics from the boxes, embeds
    the question with SBERT, and returns everything as a dict. On any
    failure, returns a dict of zero tensors with default scores so the
    caller's pipeline keeps running.
    """
    try:
        img = item["image"].convert("RGB")
        W, H = img.size
        # Truncate to MAX_WORDS; fall back to a dummy token/box so the
        # processor never receives an empty word list.
        words = list(item.get(WORD_FIELD, []))[:MAX_WORDS] or ["[PAD]"]
        boxes = list(item.get(BOX_FIELD, []))[:MAX_WORDS] or [[0, 0, 1, 1]]
        question = get_question(item)
        bn = normalize_boxes(boxes, W, H)
        enc = feat_processor(img, text=words, boxes=bn,
                             return_tensors="pt", truncation=True,
                             max_length=512, padding="max_length")
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            hidden = feat_model(**enc).last_hidden_state[0]
        # Sequence layout: [CLS] + text tokens + N_PATCHES visual patches.
        n_txt = max(2, hidden.shape[0] - N_PATCHES)
        # Mean-pool text tokens (excluding [CLS]); fall back to [CLS] if
        # there are effectively no text tokens.
        H_text = hidden[1:n_txt-1].mean(0) if n_txt > 2 else hidden[0]
        H_visual = hidden[-N_PATCHES:].mean(0)
        # 10-D layout statistics from the normalised boxes (all roughly
        # in [0, 1]): box centres, spreads, and embedding-norm proxies.
        bx = np.array(bn, dtype=np.float32)
        cx = ((bx[:, 0] + bx[:, 2]) / 2) / 1000.0
        cy = ((bx[:, 1] + bx[:, 3]) / 2) / 1000.0
        sp = np.array([
            W / 1000.0, H / 1000.0, min(W, H) / max(W, H),
            len(words) / MAX_WORDS,
            cx.mean(), cy.mean(), cx.std() + 1e-6, cy.std() + 1e-6,
            H_text.norm().item() / 10.0,
            H_visual.norm().item() / 10.0,
        ], dtype=np.float32)
        sp10 = torch.tensor(sp).to(device)
        # Lift the 10-D stats to 768-D so all modalities share a width.
        H_spat = spatial_proj(sp10.unsqueeze(0)).squeeze(0)
        q_emb = torch.tensor(sbert.encode(question),
                             dtype=torch.float32).to(device)
        return {
            "H_text": H_text, "H_visual": H_visual, "H_spatial": H_spat,
            "spatial_10": sp10, "question_emb": q_emb,
            # Heuristic modality scores: clipped embedding norms / spread.
            "text_score": float(np.clip(sp[8], 0, 1)),
            "visual_score": float(np.clip(sp[9], 0, 1)),
            "spatial_score": float(np.clip(sp[6], 0, 1)),
            "n_tokens": len(words),
        }
    except Exception as e:
        # Degrade gracefully: emit zero features with neutral scores.
        print(f" extract_rich_features error: {e}")
        dummy = torch.zeros(768, device=device)
        return {
            "H_text": dummy, "H_visual": dummy, "H_spatial": dummy,
            "spatial_10": torch.zeros(10, device=device),
            "question_emb": torch.zeros(384, device=device),
            "text_score": 0.5, "visual_score": 0.3, "spatial_score": 0.2,
            "n_tokens": 0,
        }
206
+
207
+
208
def build_feature_vector(feat):
    """Flatten a rich-feature dict into the 2701-D vector CAFP consumes.

    Concatenation order must match the slice offsets inside
    CrossAttentionFusionPredictor._logits:
    text(768) | visual(768) | spatial(768) | stats(10) | question(384) | scores(3).
    """
    score_vec = torch.tensor(
        [feat["text_score"], feat["visual_score"], feat["spatial_score"]],
        dtype=torch.float32, device=device,
    )
    segments = [
        feat["H_text"],        # 768
        feat["H_visual"],      # 768
        feat["H_spatial"],     # 768
        feat["spatial_10"],    # 10
        feat["question_emb"],  # 384
        score_vec,             # 3
    ]
    return torch.cat(segments)
217
+
218
+
219
def vqa_infer(item, alpha, beta, gamma):
    """Run extractive VQA with context compressed by the fusion weights.

    alpha controls the fraction of OCR words kept (floored at 30%, and at
    least 5 words when available); when gamma dominates both alpha and
    beta, words are chosen in spatial reading order (top-to-bottom,
    left-to-right) instead of document order. Returns the decoded answer
    span, or "" on empty input or any failure.
    """
    try:
        img = item["image"].convert("RGB")
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        question = get_question(item)
        if not words:
            return ""
        W, H = img.size
        n = len(words)
        # Keep at least 30% of the words and at least min(5, n) of them.
        n_keep = max(int(n * max(float(alpha), 0.30)), min(5, n))
        if float(gamma) > max(float(alpha), float(beta)) and boxes:
            # Spatial-dominant: rank by (y, x) reading order, then restore
            # document order among the selected indices.
            order = sorted(range(n), key=lambda i: (boxes[i][1], boxes[i][0]))
            sel_idx = sorted(order[:n_keep])
        else:
            # Otherwise simply keep the first n_keep words.
            sel_idx = list(range(n_keep))
        sw = [words[i] for i in sel_idx]
        # If the record has no boxes, use one full-page box per word.
        sb = ([boxes[i] for i in sel_idx]
              if boxes else [[0, 0, W, H]] * len(sw))
        enc = vqa_processor(
            img, text=question, text_pair=sw,
            boxes=normalize_boxes(sb, W, H),
            return_tensors="pt", truncation=True,
            max_length=512, padding=True
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            out = vqa_model(**enc)
        # Greedy span: argmax start/end; clamp end >= start.
        s = int(out.start_logits.argmax())
        e = int(out.end_logits.argmax())
        if e < s: e = s
        return vqa_processor.tokenizer.decode(
            enc["input_ids"][0][s:e+1], skip_special_tokens=True
        ).strip()
    except Exception:
        # Best-effort: any failure yields an empty answer (scored 0).
        return ""
255
+
256
+
257
def compute_anls(pred, gts, threshold=0.5):
    """ANLS: best normalised-Levenshtein similarity over the ground truths.

    Strings are lower-cased and stripped before comparison. A similarity
    below *threshold* counts as 0 (standard ANLS truncation). Returns a
    float in [0, 1]; 0.0 for empty pred or empty ground-truth list.
    """
    if isinstance(gts, str):
        gts = [gts]
    if not pred or not gts:
        return 0.0
    norm_pred = str(pred).lower().strip()
    best = 0.0
    for candidate in gts:
        norm_gt = str(candidate).lower().strip()
        longest = max(len(norm_pred), len(norm_gt))
        if longest == 0:
            # Both strings normalise to "" -> treat as a perfect match.
            best = max(best, 1.0)
            continue
        similarity = 1.0 - editdistance.eval(norm_pred, norm_gt) / longest
        # Sub-threshold similarities are truncated to zero, so they can
        # never raise `best` (which is always >= 0).
        if similarity >= threshold:
            best = max(best, similarity)
    return best
270
+
271
+
272
def compute_f1(pred, gts):
    """Best token-overlap F1 between *pred* and any ground-truth answer.

    Tokens are the lower-cased whitespace-split words, deduplicated as a
    set. Returns a float in [0, 1]; 0.0 when either side has no tokens.
    """
    if isinstance(gts, str):
        gts = [gts]
    if not pred or not gts:
        return 0.0
    pred_tokens = set(str(pred).lower().split())
    if not pred_tokens:
        return 0.0
    f1_scores = []
    for candidate in gts:
        gt_tokens = set(str(candidate).lower().split())
        overlap = pred_tokens & gt_tokens
        # Skip empty ground truths and zero-overlap candidates (F1 = 0).
        if not gt_tokens or not overlap:
            continue
        precision = len(overlap) / len(pred_tokens)
        recall = len(overlap) / len(gt_tokens)
        f1_scores.append(2 * precision * recall / (precision + recall))
    return max(f1_scores, default=0.0)
287
+
288
+
289
+ # ══════════════════════════════════════════════════════════════════════
290
+ # WORD SELECTION VISUALIZER HELPERS
291
+ # ══════════════════════════════════════════════════════════════════════
292
+
293
def get_sel_idx(item, alpha, beta, gamma):
    """Return the SET of word indices kept by this (alpha, beta, gamma) config.

    Mirrors the selection logic in vqa_infer exactly, so visualised boxes
    always match what the model actually sees: keep at least 30% of the
    words (and at least min(5, n)), choosing by spatial reading order when
    gamma dominates both other weights, else by document order.
    """
    words = list(item.get(WORD_FIELD, []))
    boxes = list(item.get(BOX_FIELD, []))
    total = len(words)
    if total == 0:
        return set()
    keep = max(int(total * max(float(alpha), 0.30)), min(5, total))
    keep = min(keep, total)
    spatial_dominant = float(gamma) > max(float(alpha), float(beta))
    if spatial_dominant and boxes:
        # Reading order: top-to-bottom, then left-to-right.
        reading_order = sorted(range(total),
                               key=lambda i: (boxes[i][1], boxes[i][0]))
        return set(reading_order[:keep])
    return set(range(keep))
312
+
313
+
314
def draw_selection(item, alpha, beta, gamma, title=""):
    """Return a PIL Image with coloured bounding boxes overlaid.

    🟒 Green fill + outline β†’ word KEPT (used for VQA)
    πŸ”΄ Red fill + outline β†’ word DROPPED (compressed out)

    An info strip (dark) and a colour legend strip are appended below the
    document image so the panel is self-explanatory at a glance. On any
    failure, falls back to returning the raw image (or None).
    """
    try:
        img = item["image"].convert("RGB").copy()
        W, H = img.size
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        n = min(len(words), len(boxes))
        if n == 0:
            return img

        # Selection mirrors vqa_infer (via get_sel_idx).
        sel_idx = get_sel_idx(item, alpha, beta, gamma)
        n_keep = len(sel_idx)
        pct = 100 * n_keep / max(n, 1)

        # ── Draw semi-transparent coloured overlays ───────────────────
        overlay = PILImage.new("RGBA", img.size, (0, 0, 0, 0))
        od = PILDraw.Draw(overlay)
        for i in range(n):
            try:
                x0, y0, x1, y1 = (int(boxes[i][0]), int(boxes[i][1]),
                                  int(boxes[i][2]), int(boxes[i][3]))
                # Clamp to image bounds; skip degenerate boxes.
                x0, x1 = max(0, x0), min(W - 1, x1)
                y0, y1 = max(0, y0), min(H - 1, y1)
                if x1 <= x0 or y1 <= y0:
                    continue
                if i in sel_idx:
                    # Kept word: green, stronger outline.
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(0, 210, 0, 55),
                                 outline=(0, 160, 0, 230), width=2)
                else:
                    # Dropped word: faint red.
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(220, 30, 30, 40),
                                 outline=(200, 0, 0, 170), width=1)
            except Exception:
                # A single malformed box must not abort the whole render.
                continue
        img = PILImage.alpha_composite(img.convert("RGBA"), overlay).convert("RGB")

        # ── Load font (graceful fallback to PIL's builtin bitmap font) ─
        font_sm = PILFont.load_default()
        for _fp in [
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "/Windows/Fonts/arial.ttf",
        ]:
            try:
                font_sm = PILFont.truetype(_fp, 13)
                break
            except Exception:
                continue

        # ── Info strip (dark bar showing title + stats) ───────────────
        strip_h = 36
        strip = PILImage.new("RGB", (W, strip_h), (22, 22, 32))
        sd = PILDraw.Draw(strip)
        info_text = (f"{title} | βœ“ Kept: {n_keep}/{n} ({pct:.0f}%)"
                     f" | Ξ±={alpha:.2f} Ξ²={beta:.2f} Ξ³={gamma:.2f}")
        sd.text((8, 11), info_text, fill=(220, 220, 220), font=font_sm)

        # ── Legend strip (light bar explaining colours) ───────────────
        leg_h = 28
        leg = PILImage.new("RGB", (W, leg_h), (246, 246, 246))
        ld = PILDraw.Draw(leg)
        ld.rectangle([8, 7, 24, 21], fill=(0, 180, 0), outline=(0, 130, 0, 255))
        ld.text( [30, 8], "= Kept (used for VQA)", fill=(0, 110, 0), font=font_sm)
        ld.rectangle([210, 7, 226, 21], fill=(220, 30, 30), outline=(170, 0, 0, 255))
        ld.text( [232, 8], "= Dropped (compressed out)", fill=(140, 0, 0), font=font_sm)

        # ── Stack: image β†’ dark strip β†’ legend ────────────────────────
        final = PILImage.new("RGB", (W, H + strip_h + leg_h), (255, 255, 255))
        final.paste(img, (0, 0))
        final.paste(strip, (0, H))
        final.paste(leg, (0, H + strip_h))
        return final

    except Exception as e:
        print(f" draw_selection error: {e}")
        # Fall back to the unannotated image (may be None).
        return item.get("image", None)
400
+
401
+
402
def make_compression_md(item, cfgs):
    """Build a markdown table showing kept / dropped word statistics and
    a sample of the words that each method discards.

    cfgs – OrderedDict/dict {method_name: (alpha, beta, gamma)}
    """
    words = list(item.get(WORD_FIELD, []))
    total = len(words)
    if total == 0:
        return "*No OCR words available for this document.*"

    pieces = [
        "### πŸ” What Gets Compressed?\n\n",
        f"**Total OCR words in document:** {total}\n\n",
        "| Method | Ξ± | Ξ² | Ξ³ | Words Kept | % Context | Sample Dropped Words |\n",
        "|--------|---|---|---|:----------:|:---------:|----------------------|\n",
    ]

    for name, (a, b, g) in cfgs.items():
        kept = get_sel_idx(item, a, b, g)
        kept_count = len(kept)
        pct = 100 * kept_count / max(total, 1)
        dropped = [words[i] for i in range(total) if i not in kept]
        # Preview the first 8 dropped words; summarise the remainder.
        preview = " Β· ".join(dropped[:8])
        if len(dropped) > 8:
            preview += f" … (+{len(dropped) - 8} more)"
        pieces.append(f"| **{name}** | {a:.2f} | {b:.2f} | {g:.2f}"
                      f" | {kept_count} / {total} | {pct:.0f}% | `{preview}` |\n")

    # Show the actual kept words for the CAFP+REINFORCE method.
    if "CAFP+REINFORCE" in cfgs:
        a, b, g = cfgs["CAFP+REINFORCE"]
        kept = get_sel_idx(item, a, b, g)
        shown = [words[i] for i in sorted(kept)[:25]]
        pieces.append(f"\n**CAFP+REINFORCE β€” kept words (first 25 shown):** \n"
                      f"`{' Β· '.join(shown)}`\n")

    return "".join(pieces)
440
+
441
+
442
+ # ══════════════════════════════════════════════════════════════════════
443
+ # LOAD CHECKPOINTS & DATA
444
+ # ══════════════════════════════════════════════════════════════════════
445
print("\nLoading checkpoints and data...")

# ── RL checkpoint ─────────────────────────────────────────────────────
# Hard requirement: without the checkpoint nothing downstream works.
if not os.path.exists(CKPT_PATH):
    sys.exit(f"❌ Checkpoint not found: {CKPT_PATH}\n"
             f" Copy cafp_rl_checkpoint_final.pt into {CKPT_DIR}/")

# weights_only=False: the checkpoint stores plain Python objects
# (lists/floats) alongside state dicts. Only load trusted files.
ck = torch.load(CKPT_PATH, map_location=device, weights_only=False)
spatial_proj.load_state_dict(ck["spatial_proj_state"])

# Supervised ("soft") CAFP, then its REINFORCE-fine-tuned copy.
cafp_soft = CrossAttentionFusionPredictor().to(device)
cafp_soft.load_state_dict(ck["cafp_soft_state"]); cafp_soft.eval()

cafp_rl = copy.deepcopy(cafp_soft)
cafp_rl.load_state_dict(ck["cafp_rl_state"]); cafp_rl.eval()

rl_train_anls = ck["rl_train_anls"]
# Older checkpoints lack the val metric; approximate with best train ANLS.
rl_val_anls = ck.get("rl_val_anls",
                     max(rl_train_anls) if rl_train_anls else 0.0)
print(f" βœ… CAFP+REINFORCE: {len(rl_train_anls)} epochs | "
      f"best_train={max(rl_train_anls):.4f} | val={rl_val_anls:.4f}")

# ── Dataset ───────────────────────────────────────────────────────────
# Deterministic split + subsampling (seed 42) so indices match training.
print(" Loading dataset (~30s)...")
from datasets import load_dataset
_ds = load_dataset(DATASET_NAME, split="train")
_split = _ds.train_test_split(test_size=0.2, seed=42)
rng = np.random.RandomState(42)
val_idx = rng.permutation(len(_split["test"])).tolist()[:N_VAL]
train_idx = rng.permutation(len(_split["train"])).tolist()[:N_TRAIN]
val_items = [_split["test"][i] for i in val_idx]
train_items = [_split["train"][i] for i in train_idx]
val_gts = [item[ANSWER_FIELD] for item in val_items]
train_gts = [item[ANSWER_FIELD] for item in train_items]
print(f" βœ… Dataset: {len(val_items)} val, {len(train_items)} train")

# ── Feature tensors ───────────────────────────────────────────────────
# Use cached CAFP input vectors when available; otherwise recompute from
# scratch and cache them for the next launch.
if os.path.exists(FEAT_PATH):
    t = torch.load(FEAT_PATH, map_location=device, weights_only=False)
    val_feats = t["val_feats"]
    train_feats = t["train_feats"]
    print(f" βœ… Features: {tuple(val_feats.shape)}")
else:
    print(" ⚠️ feature_tensors.pt not found β€” recomputing (~2 min)...")
    def _feats(items, tag):
        # Stack one 2701-D feature vector per item, with progress output.
        out = []
        for i, item in enumerate(items):
            out.append(build_feature_vector(
                extract_rich_features(item)).unsqueeze(0))
            if (i + 1) % 10 == 0:
                print(f" {tag}: {i+1}/{len(items)}", end="\r")
        print()
        return torch.cat(out).to(device)
    val_feats = _feats(val_items, "val")
    train_feats = _feats(train_items, "train")
    torch.save({"val_feats": val_feats, "train_feats": train_feats,
                "val_gts": val_gts, "train_gts": train_gts}, FEAT_PATH)
    print(f" βœ… Features computed and saved to {FEAT_PATH}")

# ── Oracle cache ──────────────────────────────────────────────────────
# NOTE(review): chained assignment aliases both names to ONE list;
# harmless here because both are only rebound (never mutated) below.
val_oracle = train_oracle = []
if os.path.exists(ORACLE_CACHE):
    _oc = json.load(open(ORACLE_CACHE))
    train_oracle = _oc.get("train", [])
    val_oracle = _oc.get("val", [])
    print(f" βœ… Oracle cache: {len(train_oracle)} train, {len(val_oracle)} val")
else:
    print(" ⚠️ oracle_cache.json not found β€” demo works without it")

# ── Results from JSON ─────────────────────────────────────────────────
# First readable results file wins; missing/corrupt files are skipped.
results = {}
_RKEYS = [
    "Equal Fusion", "Proposed Fixed", "Text-Only",
    "LLMLingua-style", "Selective Context-style",
    "CAFP (paper checkpoint)", "CAFP-Hard Oracle", "CAFP-Soft Oracle",
]
for _rpath in [RESULTS_PATH, "./final_results.json", "./results_condensed.json"]:
    try:
        _raw = json.load(open(_rpath))
        for k in _RKEYS:
            if k in _raw and isinstance(_raw[k], dict):
                r = _raw[k]
                # Accept both long ("mean_anls") and short ("anls") keys.
                results[k] = {
                    "mean_anls": float(r.get("mean_anls", r.get("anls", 0.0))),
                    "mean_f1": float(r.get("mean_f1", r.get("f1", 0.0))),
                }
        if results:
            print(f" βœ… Results: {len(results)} methods from {_rpath}")
            break
    except Exception:
        continue
if not results:
    print(" ⚠️ Results JSON not found β€” dashboard will show partial data")

# ── Find best demo documents ──────────────────────────────────────────
# Score every val doc with RL weights vs the fixed (0.5, 0.3, 0.2)
# baseline so the UI can surface documents where REINFORCE wins most.
print("\nPre-scoring documents for demo (this takes ~2 min)...")
demo_scores = []
cafp_rl.eval()
with torch.no_grad():
    for i in range(len(val_items)):
        fv = val_feats[i].unsqueeze(0)
        # softplus(+0.1) then normalise: same weight rule as get_rl_weights.
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        w = (conc / conc.sum()).squeeze(0).cpu().tolist()
        rl_s = compute_anls(vqa_infer(val_items[i], w[0], w[1], w[2]),
                            val_gts[i])
        fx_s = compute_anls(vqa_infer(val_items[i], 0.5, 0.3, 0.2),
                            val_gts[i])
        demo_scores.append((i, round(rl_s - fx_s, 4),
                            round(rl_s, 4), round(fx_s, 4)))
        if (i + 1) % 20 == 0:
            # NOTE(review): "/100" is hard-coded; assumes N_VAL == 100.
            print(f" {i+1}/100", end="\r")
# Sort by RL-minus-fixed improvement, best first.
demo_scores.sort(key=lambda x: -x[1])
best_idx = demo_scores[0][0]
top5_str = ", ".join([f"#{x[0]}(+{x[1]:.2f})" for x in demo_scores[:5]])
print(f"\n βœ… Best docs: {top5_str}")
print(f"\n{'='*55}")
print("ALL MODELS LOADED β€” ready to demo")
print(f"{'='*55}\n")
563
+
564
+
565
+ # ══════════════════════════════════════════════════════════════════════
566
+ # GRADIO FUNCTIONS
567
+ # ══════════════════════════════════════════════════════════════════════
568
def get_rl_weights(idx, custom_q=None):
    """Predict (alpha, beta, gamma) for validation document *idx*.

    With a non-empty custom question, features are re-extracted for a copy
    of the item carrying that question; otherwise the precomputed feature
    vector is reused. Weights are softplus(logits) + 0.1, normalised to
    sum to 1, returned as a plain 3-element list.
    """
    if custom_q and custom_q.strip():
        # Copy the record so the cached dataset item is not mutated.
        _item = dict(val_items[idx])
        _item[QUERY_FIELD] = custom_q.strip()
        fv = build_feature_vector(extract_rich_features(_item)).unsqueeze(0)
    else:
        fv = val_feats[idx].unsqueeze(0)
    with torch.no_grad():
        # +0.1 keeps every weight strictly positive before normalising.
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        w = (conc / conc.sum()).squeeze(0).cpu().tolist()
    return w
579
+
580
+
581
def make_weight_chart(mw):
    """Grouped bar chart of the (alpha, beta, gamma) weights per method.

    mw – mapping {method_name: (alpha, beta, gamma)}. Returns the
    matplotlib Figure (Agg backend; the caller hands it to Gradio).
    """
    fig, ax = plt.subplots(figsize=(9, 3.5))
    methods = list(mw.keys())
    weight_rows = list(mw.values())
    positions = np.arange(len(methods))
    bar_w = 0.25
    series = [
        ("\u03b1 Text", "#2196F3"),
        ("\u03b2 Visual", "#4CAF50"),
        ("\u03b3 Spatial", "#FF9800"),
    ]
    for j, (lbl, col) in enumerate(series):
        heights = [row[j] for row in weight_rows]
        # (j - 1) offsets the three series to -bw / 0 / +bw around each tick.
        bars = ax.bar(positions + (j - 1) * bar_w, heights, bar_w,
                      label=lbl, color=col, alpha=0.85)
        for bar in bars:
            top = bar.get_height()
            # Numeric label just above each bar.
            ax.text(bar.get_x() + bar.get_width() / 2, top + 0.01,
                    f"{top:.2f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(positions)
    ax.set_xticklabels(methods, fontsize=10)
    ax.set_ylabel("Weight")
    ax.set_ylim(0, 1.2)
    ax.set_title("Fusion Weights (\u03b1, \u03b2, \u03b3) per Method",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=9)
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    return fig
604
+
605
+
606
def run_demo(doc_idx, custom_q):
    """Gradio callback: run all fusion configs on one validation document.

    Returns (document image, step-by-step markdown, weight chart, fixed
    selection visual, RL selection visual, compression markdown) matching
    the six output components wired up in the UI.
    """
    doc_idx = int(doc_idx)
    item = val_items[doc_idx]
    gt = val_gts[doc_idx]
    # Prefer a non-empty custom question; otherwise the dataset's own.
    q = (custom_q.strip()
         if custom_q and custom_q.strip()
         else get_question(item))
    gt_str = (", ".join(str(g) for g in gt[:2])
              if isinstance(gt, list) else str(gt))
    n_words = len(list(item.get(WORD_FIELD, [])))
    # Coarse heuristic label purely for display.
    doc_type = "Text-dominant" if n_words > 40 else "Visual-dominant"

    alpha, beta, gamma = get_rl_weights(doc_idx, custom_q)
    dom = ("Text" if alpha > 0.65 else
           "Visual" if beta > 0.40 else "Balanced")

    # Baselines compared against the RL-predicted weights.
    cfgs = {
        "Equal Fusion": (1/3, 1/3, 1/3),
        "Fixed (0.5,0.3,0.2)": (0.5, 0.3, 0.2),
        "Text-Only": (1.0, 0.0, 0.0),
        "CAFP+REINFORCE": (alpha, beta, gamma),
    }
    res = {}
    for name, (a, b, g) in cfgs.items():
        # Copy the record so the custom question never leaks into the cache.
        demo_item = dict(item); demo_item[QUERY_FIELD] = q
        pred = vqa_infer(demo_item, a, b, g)
        res[name] = {
            "pred": pred,
            "anls": compute_anls(pred, gt),
            "f1": compute_f1(pred, gt),
            "w": (a, b, g),
        }

    best = max(res, key=lambda k: res[k]["anls"])
    rl_vs_fixed = res["CAFP+REINFORCE"]["anls"] - res["Fixed (0.5,0.3,0.2)"]["anls"]

    # Markdown walkthrough of the pipeline steps and the results table.
    md = f"## Document #{doc_idx} \u2014 {doc_type} ({n_words} words)\n\n"
    md += f"**Question:** {q}\n\n"
    md += f"**Ground Truth:** `{gt_str}`\n\n---\n"
    md += "### Step 1 \u2014 Text Extraction\n"
    md += f"`{n_words}` OCR words extracted via LayoutLMv3\n\n"
    md += "### Step 2 \u2014 Multimodal Feature Extraction\n"
    md += ("- **Text** \u2192 LayoutLMv3 token embeddings [768-D]\n"
           "- **Visual** \u2192 LayoutLMv3 patch features [768-D]\n"
           "- **Spatial** \u2192 Bounding box layout encoding [768-D]\n\n")
    md += "### Step 3 \u2014 CAFP+REINFORCE Weight Prediction\n"
    md += "| Modality | Weight |\n|----------|--------|\n"
    md += f"| \u03b1 Text | **{alpha:.3f}** |\n"
    md += f"| \u03b2 Visual | **{beta:.3f}** |\n"
    md += f"| \u03b3 Spatial | **{gamma:.3f}** |\n\n"
    md += f"\u2192 **Dominant: {dom}**\n\n"
    md += "### Step 4 \u2014 Adaptive Fusion \u2192 Answer\n"
    md += "| Method | \u03b1 | \u03b2 | \u03b3 | Answer | ANLS | F1 |\n"
    md += "|--------|---|---|---|--------|------|----|\n"
    for name, d in res.items():
        a, b, g = d["w"]
        # Star marks the highest-ANLS method.
        star = " \u2b50" if name == best else ""
        md += (f"| {name}{star} | {a:.2f} | {b:.2f} | {g:.2f}"
               f" | `{d['pred']}` | **{d['anls']:.4f}** | {d['f1']:.4f} |\n")
    sign = "+" if rl_vs_fixed >= 0 else ""
    md += (f"\n---\n**Best:** {best} (ANLS: {res[best]['anls']:.4f})\n\n"
           f"**CAFP+REINFORCE answer:** `{res['CAFP+REINFORCE']['pred']}`\n\n"
           f"**\u0394 over Fixed:** {sign}{rl_vs_fixed:.4f}\n")

    chart = make_weight_chart({k: v["w"] for k, v in res.items()})

    # ── Word Selection Visualizations ─────────────────────────────────
    _item_q = dict(item); _item_q[QUERY_FIELD] = q
    fixed_vis = draw_selection(
        _item_q, 0.5, 0.3, 0.2,
        "Fixed Weights (0.5, 0.3, 0.2)"
    )
    rl_vis = draw_selection(
        _item_q, alpha, beta, gamma,
        f"CAFP+REINFORCE (Ξ±={alpha:.2f} Ξ²={beta:.2f} Ξ³={gamma:.2f})"
    )
    comp_md = make_compression_md(item, cfgs)

    return item.get("image", None), md, chart, fixed_vis, rl_vis, comp_md
685
+
686
+
687
def show_dashboard():
    """Gradio callback: results table, method bar chart, training curve.

    Reads the module-level `results` dict (loaded from JSON at startup),
    `rl_val_anls`, and `rl_train_anls`. Methods missing from the JSON show
    as 0.0. Returns (markdown, bar-chart Figure, training-curve Figure).
    """
    # Small accessors defaulting to 0.0 for methods absent from `results`.
    def sg(k): return results.get(k, {}).get("mean_anls", 0.0)
    def sf(k): return results.get(k, {}).get("mean_f1", 0.0)
    # NOTE(review): the oracle upper bound is a hard-coded constant here,
    # not read from the results file.
    fixed = sg("Proposed Fixed"); oracle = 0.8377
    rv = rl_val_anls

    rows = [
        ("Equal Fusion", sg("Equal Fusion"), sf("Equal Fusion")),
        ("Proposed Fixed (paper)", sg("Proposed Fixed"), sf("Proposed Fixed")),
        ("Text-Only", sg("Text-Only"), sf("Text-Only")),
        ("LLMLingua-style [NEW]", sg("LLMLingua-style"), sf("LLMLingua-style")),
        ("Selective Context [NEW]", sg("Selective Context-style"), sf("Selective Context-style")),
        ("CAFP paper checkpoint", sg("CAFP (paper checkpoint)"), sf("CAFP (paper checkpoint)")),
        ("CAFP Hard Oracle [NEW]", sg("CAFP-Hard Oracle"), sf("CAFP-Hard Oracle")),
        ("CAFP Soft Oracle [NEW]", sg("CAFP-Soft Oracle"), sf("CAFP-Soft Oracle")),
        ("CAFP+REINFORCE [NEW][BEST]", rv, 0.0),
        ("Oracle Upper Bound", oracle, 0.0),
    ]

    # Markdown table with delta-vs-fixed and percent-of-oracle columns.
    md = "## Full Experiment Results\n\n"
    md += "| Method | ANLS | F1 | \u0394 Fixed | % Oracle |\n"
    md += "|--------|------|----|----------|----------|\n"
    for name, anls, f1 in rows:
        is_oracle = "Oracle Upper" in name
        d = f"{anls - fixed:+.4f}" if not is_oracle else "\u2014"
        pct = f"{anls / oracle * 100:.1f}%" if anls > 0 else "\u2014"
        md += f"| {name} | {anls:.4f} | {f1:.4f} | {d} | {pct} |\n"
    md += (f"\n**CAFP+REINFORCE: {rv/oracle*100:.1f}% of Oracle ANLS**\n"
           f"**Improvement over Fixed: {rv - fixed:+.4f} ANLS**\n")

    # Bar chart — one colour per method family (order matches `rows`).
    fig1, ax1 = plt.subplots(figsize=(11, 5))
    bv = [r[1] for r in rows]
    bc = ["#bbb","#999","#bbb","#2196F3","#2196F3",
          "#777","#4CAF50","#4CAF50","#FF5722","#d32f2f"]
    bars = ax1.barh([r[0] for r in rows], bv,
                    color=bc, edgecolor="white", height=0.65)
    ax1.axvline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle {oracle:.4f}")
    ax1.axvline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed {fixed:.4f}")
    for bar, val in zip(bars, bv):
        if val > 0:
            # Annotate each nonzero bar with its exact ANLS.
            ax1.text(val + 0.003, bar.get_y() + bar.get_height() / 2,
                     f"{val:.4f}", va="center", fontsize=8)
    ax1.set_xlabel("Val ANLS", fontsize=11)
    ax1.set_title("All Methods \u2014 Val ANLS", fontsize=13, fontweight="bold")
    ax1.legend(fontsize=9); ax1.invert_yaxis()
    ax1.set_xlim(0, oracle * 1.1); ax1.grid(axis="x", alpha=0.3)
    plt.tight_layout()

    # Training curve — per-epoch REINFORCE train ANLS vs reference lines.
    fig2, ax2 = plt.subplots(figsize=(10, 3.5))
    eps = list(range(1, len(rl_train_anls) + 1))
    ax2.plot(eps, rl_train_anls, "o-", color="#FF5722",
             lw=2.5, ms=7, label="Train ANLS")
    ax2.axhline(rv, color="#FF5722", linestyle=":", lw=2,
                label=f"Val ANLS = {rv:.4f}")
    ax2.axhline(oracle, color="red", linestyle="--", lw=1.5,
                label=f"Oracle = {oracle:.4f}")
    ax2.axhline(fixed, color="gray", linestyle=":", lw=1.2,
                label=f"Fixed = {fixed:.4f}")
    # Shade the gap between the training curve and the fixed baseline.
    ax2.fill_between(eps, rl_train_anls, fixed, alpha=0.15, color="#FF5722")
    ax2.set_xlabel("Epoch"); ax2.set_ylabel("ANLS")
    ax2.set_title("REINFORCE Fine-tuning Progress",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=9); ax2.grid(True, alpha=0.3)
    ax2.set_xticks(eps); plt.tight_layout()

    return md, fig1, fig2
757
+
758
+
759
+ # ══════════════════════════════════════════════════════════════════════
760
+ # GRADIO UI
761
+ # ══════════════════════════════════════════════════════════════════════
762
+ _fixed_anls = results.get("Proposed Fixed", {}).get("mean_anls", 0.0)
763
+ _best_label = ("Best docs (REINFORCE wins most): "
764
+ + ", ".join([f"#{x[0]}" for x in demo_scores[:5]]))
765
+
766
+ CSS = ".tab-nav button { font-size: 15px !important; font-weight: 600 !important; }"
767
+
768
+ with gr.Blocks(
769
+ title="Adaptive Multimodal Fusion β€” DocVQA Demo",
770
+ theme=gr.themes.Soft(primary_hue="blue"),
771
+ css=CSS,
772
+ ) as demo_app:
773
+
774
+ gr.Markdown("""
775
+ # Adaptive Multimodal Fusion for Document VQA
776
+ ### Cross-Attention Fusion Predictor (CAFP) + REINFORCE Fine-tuning
777
+ """)
778
+
779
+ with gr.Tabs():
780
+
781
+ # ── Tab 1: Live Demo ──────────────────────────────────────────
782
+ with gr.TabItem("\U0001f3af Live Demo"):
783
+ gr.Markdown(f"**{_best_label}**")
784
+ with gr.Row():
785
+ with gr.Column(scale=1):
786
+ doc_slider = gr.Slider(
787
+ 0, len(val_items) - 1,
788
+ value=best_idx, step=1,
789
+ label=f"Document Index (0\u2013{len(val_items)-1})"
790
+ )
791
+ custom_q = gr.Textbox(
792
+ label="Custom Question (optional)",
793
+ placeholder="Leave blank to use original question"
794
+ )
795
+ run_btn = gr.Button(
796
+ "\u25b6 Run Adaptive Fusion",
797
+ variant="primary", size="lg"
798
+ )
799
+ gr.Markdown(
800
+ "*Compares: Equal \u00b7 Fixed \u00b7 "
801
+ "Text-Only \u00b7 CAFP+REINFORCE*"
802
+ )
803
+ with gr.Column(scale=2):
804
+ doc_image = gr.Image(label="Document Image", height=400)
805
+ step_md = gr.Markdown()
806
+ weight_chart = gr.Plot(label="Fusion Weights Comparison")
807
+
808
+ # ── Word Selection Visualizer ─────────────────────────────
809
+ gr.Markdown("""
810
+ ---
811
+ ### 🎨 Word Selection Visualization
812
+ *See **exactly** which OCR words each method keeps vs discards.*
813
+ 🟒 **Green** = kept and fed to the VQA model Β· πŸ”΄ **Red** = compressed out
814
+ """)
815
+ with gr.Row():
816
+ fixed_vis_img = gr.Image(
817
+ label="πŸ“Œ Fixed Weights (Ξ±=0.5 Ξ²=0.3 Ξ³=0.2)",
818
+ height=520, show_download_button=True
819
+ )
820
+ rl_vis_img = gr.Image(
821
+ label="πŸ€– CAFP+REINFORCE (Adaptive Weights)",
822
+ height=520, show_download_button=True
823
+ )
824
+ comp_md_out = gr.Markdown()
825
+
826
+ run_btn.click(
827
+ fn=run_demo,
828
+ inputs=[doc_slider, custom_q],
829
+ outputs=[doc_image, step_md, weight_chart,
830
+ fixed_vis_img, rl_vis_img, comp_md_out],
831
+ )
832
+
833
+ # ── Tab 2: Results Dashboard ──────────────────────────────────
834
+ with gr.TabItem("\U0001f4ca Results Dashboard"):
835
+ gr.Markdown("### All methods compared + REINFORCE training curve")
836
+ load_btn = gr.Button("Load Results", variant="secondary")
837
+ res_md = gr.Markdown()
838
+ with gr.Row():
839
+ bar_chart = gr.Plot(label="ANLS \u2014 All Methods")
840
+ rl_curve = gr.Plot(label="REINFORCE Training Curve")
841
+ load_btn.click(
842
+ fn=show_dashboard,
843
+ inputs=[],
844
+ outputs=[res_md, bar_chart, rl_curve],
845
+ )
846
+
847
+ # ── Tab 3: About ──────────────────────────────────────────────
848
+ with gr.TabItem("\u2139\ufe0f About"):
849
+ gr.Markdown(f"""
850
+ ## Adaptive Multimodal Fusion for DocVQA
851
+
852
+ ### Problem
853
+ DocVQA requires reasoning over three modalities simultaneously:
854
+ - **Text** β€” OCR words and their semantics
855
+ - **Visual** β€” Document appearance and image patches
856
+ - **Spatial** β€” Bounding box positions and layout structure
857
+
858
+ Fixed weights (Ξ±=0.5, Ξ²=0.3, Ξ³=0.2) cannot adapt to different document types.
859
+
860
+ ### Architecture: CAFP (428K params)
861
+ 1. Projects each modality embedding to 128-D
862
+ 2. Cross-attention: question attends to all modality representations
863
+ 3. Predicts per-document (Ξ±, Ξ², Ξ³) fusion weights
864
+
865
+ ### Training Pipeline
866
+ 1. **Hard Oracle** (MSE) β†’ argmax weights from 20-combo grid search
867
+ 2. **Soft Oracle** (KL-div) β†’ temperature-smoothed ANLS-weighted targets
868
+ 3. **REINFORCE** β†’ Policy gradient on direct ANLS reward (K=3 samples/step)
869
+
870
+ ### Novel Contributions
871
+ 1. Soft Oracle training eliminates hard-oracle label noise
872
+ 2. REINFORCE fine-tuning directly maximises DocVQA metric
873
+ 3. LLMLingua-style and Selective Context baselines for fair comparison
874
+
875
+ ### Key Result
876
+ **CAFP+REINFORCE achieves {rl_val_anls/0.8377*100:.1f}% of Oracle ANLS**
877
+ Improvement over fixed-weight baseline: {rl_val_anls - _fixed_anls:+.4f} ANLS
878
+ """)
879
+
880
+
881
+ # ══════════════════════════════════════════════════════════════════════
882
+ # LAUNCH
883
+ # ══════════════════════════════════════════════════════════════════════
884
+ if __name__ == "__main__":
885
+ demo_app.launch(
886
+ server_name="0.0.0.0",
887
+ server_port=args.port,
888
+ share=args.share,
889
+ show_error=True,
890
+ )
app.py ADDED
@@ -0,0 +1,890 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adaptive Multimodal Fusion for DocVQA β€” Gradio Demo
3
+ ====================================================
4
+ Run locally:
5
+ python app.py
6
+
7
+ Run with public URL (72hr):
8
+ python app.py --share
9
+
10
+ Deploy to HuggingFace Spaces:
11
+ - Push this file + requirements.txt + checkpoints/ folder to a Space repo
12
+ - HF Spaces auto-launches on port 7860
13
+ """
14
+
15
+ import argparse, os, sys, copy, json, warnings
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import matplotlib
21
+ matplotlib.use("Agg")
22
+ import matplotlib.pyplot as plt
23
+ import gradio as gr
24
+ import editdistance
25
+ from PIL import Image as PILImage, ImageDraw as PILDraw, ImageFont as PILFont
26
+ warnings.filterwarnings("ignore")
27
+
28
+ # ── CLI args ──────────────────────────────────────────────────────────
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument("--share", action="store_true", help="Create public Gradio URL")
31
+ parser.add_argument("--port", type=int, default=7860)
32
+ parser.add_argument("--ckpt_dir", type=str, default="./checkpoints",
33
+ help="Folder containing all saved files")
34
+ args, _ = parser.parse_known_args()
35
+
36
+ # ══════════════════════════════════════════════════════════════════════
37
+ # CONFIGURATION β€” edit paths here if needed
38
+ # ══════════════════════════════════════════════════════════════════════
39
+ CKPT_DIR = args.ckpt_dir
40
+ ORACLE_CACHE = os.path.join(CKPT_DIR, "oracle_cache.json")
41
+ FEAT_PATH = os.path.join(CKPT_DIR, "feature_tensors.pt")
42
+ RESULTS_PATH = os.path.join(CKPT_DIR, "final_results.json")
43
+
44
+ # Try final checkpoint first, fall back to intermediate
45
+ CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint_final.pt")
46
+ if not os.path.exists(CKPT_PATH):
47
+ CKPT_PATH = os.path.join(CKPT_DIR, "cafp_rl_checkpoint.pt")
48
+
49
+ # Dataset / model IDs
50
+ DATASET_NAME = "nielsr/docvqa_1200_examples"
51
+ FEAT_MODEL_ID = "microsoft/layoutlmv3-base"
52
+ VQA_MODEL_ID = "rubentito/layoutlmv3-base-mpdocvqa"
53
+ SBERT_ID = "all-MiniLM-L6-v2"
54
+
55
+ # Field names
56
+ WORD_FIELD = "words"
57
+ BOX_FIELD = "bounding_boxes"
58
+ QUERY_FIELD = "query"
59
+ ANSWER_FIELD = "answers"
60
+
61
+ # Architecture
62
+ MAX_WORDS = 64
63
+ N_PATCHES = 49
64
+ N_VAL = 100
65
+ N_TRAIN = 100
66
+ FEAT_DIM = 2701
67
+ PROJ_DIM = 128
68
+
69
+ device = "cuda" if torch.cuda.is_available() else "cpu"
70
+ print(f"Device: {device}")
71
+
72
+ # ══════════════════════════════════════════════════════════════════════
73
+ # MODEL CLASSES
74
+ # ══════════════════════════════════════════════════════════════════════
75
class CrossAttentionFusionPredictor(nn.Module):
    """Predict per-document fusion weights (alpha, beta, gamma) for DocVQA.

    Input is a flat 2701-D feature vector whose layout matches
    ``build_feature_vector``:

        [0:768)      text embedding        (H_text)
        [768:1536)   visual embedding      (H_visual)
        [1536:2304)  spatial embedding     (H_spatial)
        [2304:2314)  10-D raw spatial stats (not consumed here)
        [2314:2698)  384-D question embedding (SBERT)
        [2698:2701)  3 prior modality scores

    The question embedding attends (cross-attention) over the three
    projected modality embeddings; the attended context plus the 3 prior
    scores is mapped to 3 logits, softmaxed into fusion weights.
    """

    def __init__(self, feat_dim=FEAT_DIM, proj_dim=PROJ_DIM,
                 n_heads=4, dropout=0.15):
        super().__init__()
        # One linear projection per modality, all into the shared proj_dim space.
        self.text_proj = nn.Linear(768, proj_dim)
        self.visual_proj = nn.Linear(768, proj_dim)
        self.spatial_proj_lyr = nn.Linear(768, proj_dim)
        # Question (384-D SBERT) projected + normalised; acts as attention query.
        self.q_proj = nn.Sequential(
            nn.Linear(384, proj_dim), nn.LayerNorm(proj_dim), nn.GELU()
        )
        self.cross_attn = nn.MultiheadAttention(
            proj_dim, n_heads, dropout=dropout, batch_first=True
        )
        self.attn_norm = nn.LayerNorm(proj_dim)
        # Head sees attended context + the 3 raw prior scores (hence +3).
        self.head = nn.Sequential(
            nn.Linear(proj_dim + 3, proj_dim), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(proj_dim, 3)
        )

    def _logits(self, x):
        """Return raw 3-way logits for a batch of feature vectors.

        Args:
            x: tensor of shape (batch, 2701) in the layout documented on
               the class.
        """
        h_t = self.text_proj(x[:, 0:768])
        h_v = self.visual_proj(x[:, 768:1536])
        h_s = self.spatial_proj_lyr(x[:, 1536:2304])
        # NOTE: slice [2304:2314) (raw spatial stats) is intentionally skipped;
        # only its projected form (H_spatial) participates in attention.
        q = self.q_proj(x[:, 2314:2698]).unsqueeze(1)
        kv = torch.stack([h_t, h_v, h_s], dim=1)   # (batch, 3, proj_dim)
        ctx, _ = self.cross_attn(q, kv, kv)
        ctx = self.attn_norm(ctx.squeeze(1))
        # Concatenate the 3 prior scores back in before the final head.
        return self.head(torch.cat([ctx, x[:, 2698:2701]], dim=-1))

    def forward(self, x):
        """Return normalised fusion weights (softmax over the 3 logits)."""
        return F.softmax(self._logits(x), dim=-1)
106
+
107
+
108
+ # ══════════════════════════════════════════════════════════════════════
109
+ # LOAD BASE MODELS
110
+ # ══════════════════════════════════════════════════════════════════════
111
+ print("Loading base models (takes ~5 min on first run, cached after)...")
112
+
113
+ from transformers import AutoProcessor, AutoModel, AutoModelForQuestionAnswering
114
+ from sentence_transformers import SentenceTransformer
115
+
116
+ feat_processor = AutoProcessor.from_pretrained(FEAT_MODEL_ID, apply_ocr=False)
117
+ feat_model = AutoModel.from_pretrained(FEAT_MODEL_ID).to(device).eval()
118
+ for p in feat_model.parameters(): p.requires_grad_(False)
119
+ print(" βœ… LayoutLMv3 feature model")
120
+
121
+ vqa_processor = AutoProcessor.from_pretrained(VQA_MODEL_ID, apply_ocr=False)
122
+ vqa_model = AutoModelForQuestionAnswering.from_pretrained(
123
+ VQA_MODEL_ID).to(device).eval()
124
+ for p in vqa_model.parameters(): p.requires_grad_(False)
125
+ print(" βœ… VQA model")
126
+
127
+ sbert = SentenceTransformer(SBERT_ID)
128
+ sbert.to(device)
129
+ print(" βœ… SBERT")
130
+
131
+ spatial_proj = nn.Sequential(
132
+ nn.Linear(10, 256), nn.ReLU(), nn.Linear(256, 768)
133
+ ).to(device)
134
+
135
+ # ══════════════════════════════════════════════════════════════════════
136
+ # HELPER FUNCTIONS
137
+ # ══════════════════════════════════════════════════════════════════════
138
def get_question(item):
    """Return the question text of a dataset record as a stripped string.

    Looks up ``QUERY_FIELD`` first and falls back to a plain "question"
    key.  Multilingual records store the question as a dict of
    language -> text; the English entry is preferred, otherwise the
    first available translation is used.
    """
    raw = item[QUERY_FIELD] if QUERY_FIELD in item else item.get("question", "")
    if isinstance(raw, dict):
        # Prefer English, else grab any available language (or "" if empty).
        fallback = next(iter(raw.values()), "")
        raw = raw.get("en", fallback)
    return str(raw).strip()
143
+
144
+
145
def normalize_boxes(boxes, w, h):
    """Rescale pixel-space boxes [x0, y0, x1, y1] onto LayoutLM's 0-1000 grid.

    Each coordinate is divided by the page dimension (guarded against
    zero via ``max(dim, 1)``), clamped into [0, 1], then scaled to an
    integer in 0..1000.

    Args:
        boxes: iterable of 4-element sequences in pixel coordinates.
        w, h:  page width and height in pixels.

    Returns:
        A new list of 4-int lists on the 0-1000 grid.
    """
    safe_w, safe_h = max(w, 1), max(h, 1)

    def scale(value, dim):
        # Clamp the ratio before expanding so out-of-page boxes stay legal.
        ratio = value / dim
        if ratio < 0:
            ratio = 0
        elif ratio > 1:
            ratio = 1
        return int(ratio * 1000)

    normalized = []
    for b in boxes:
        normalized.append([scale(b[0], safe_w), scale(b[1], safe_h),
                           scale(b[2], safe_w), scale(b[3], safe_h)])
    return normalized
155
+
156
+
157
def extract_rich_features(item):
    """Extract the per-document feature dict consumed by CAFP.

    Runs the frozen LayoutLMv3 feature model over the page image + OCR
    words, derives pooled text/visual embeddings, builds a 10-D spatial
    statistics vector, projects it to 768-D via ``spatial_proj``, and
    encodes the question with SBERT.

    Returns a dict with keys: H_text, H_visual, H_spatial (768-D each),
    spatial_10 (10-D), question_emb (384-D), three scalar prior scores,
    and n_tokens.  On any failure a zero-filled dict with default prior
    scores (0.5/0.3/0.2) is returned so the demo keeps running.
    """
    try:
        img = item["image"].convert("RGB")
        W, H = img.size
        # Cap at MAX_WORDS; fall back to a single pad word/box so the
        # processor never sees an empty sequence.
        words = list(item.get(WORD_FIELD, []))[:MAX_WORDS] or ["[PAD]"]
        boxes = list(item.get(BOX_FIELD, []))[:MAX_WORDS] or [[0, 0, 1, 1]]
        question = get_question(item)
        bn = normalize_boxes(boxes, W, H)
        enc = feat_processor(img, text=words, boxes=bn,
                             return_tensors="pt", truncation=True,
                             max_length=512, padding="max_length")
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            hidden = feat_model(**enc).last_hidden_state[0]
        # Text tokens occupy the front of the sequence, the last
        # N_PATCHES positions are image patches.
        n_txt = max(2, hidden.shape[0] - N_PATCHES)
        # Mean-pool text tokens (skipping the first special token);
        # degenerate sequences fall back to position 0.
        H_text = hidden[1:n_txt-1].mean(0) if n_txt > 2 else hidden[0]
        H_visual = hidden[-N_PATCHES:].mean(0)
        # Box centroids on the normalised 0-1000 grid, rescaled to [0, 1].
        bx = np.array(bn, dtype=np.float32)
        cx = ((bx[:, 0] + bx[:, 2]) / 2) / 1000.0
        cy = ((bx[:, 1] + bx[:, 3]) / 2) / 1000.0
        # 10-D spatial statistics vector; 1e-6 keeps the stds non-zero.
        sp = np.array([
            W / 1000.0, H / 1000.0, min(W, H) / max(W, H),
            len(words) / MAX_WORDS,
            cx.mean(), cy.mean(), cx.std() + 1e-6, cy.std() + 1e-6,
            H_text.norm().item() / 10.0,
            H_visual.norm().item() / 10.0,
        ], dtype=np.float32)
        sp10 = torch.tensor(sp).to(device)
        H_spat = spatial_proj(sp10.unsqueeze(0)).squeeze(0)
        q_emb = torch.tensor(sbert.encode(question),
                             dtype=torch.float32).to(device)
        return {
            "H_text": H_text, "H_visual": H_visual, "H_spatial": H_spat,
            "spatial_10": sp10, "question_emb": q_emb,
            # Prior scores reuse sp[8] (text norm), sp[9] (visual norm)
            # and sp[6] (x-centroid spread), clipped into [0, 1].
            "text_score": float(np.clip(sp[8], 0, 1)),
            "visual_score": float(np.clip(sp[9], 0, 1)),
            "spatial_score": float(np.clip(sp[6], 0, 1)),
            "n_tokens": len(words),
        }
    except Exception as e:
        # Best-effort fallback: log and return a neutral zero feature set
        # so downstream code never crashes on a single bad document.
        print(f" extract_rich_features error: {e}")
        dummy = torch.zeros(768, device=device)
        return {
            "H_text": dummy, "H_visual": dummy, "H_spatial": dummy,
            "spatial_10": torch.zeros(10, device=device),
            "question_emb": torch.zeros(384, device=device),
            "text_score": 0.5, "visual_score": 0.3, "spatial_score": 0.2,
            "n_tokens": 0,
        }
206
+
207
+
208
def build_feature_vector(feat):
    """Flatten a feature dict into the 2701-D input vector for CAFP.

    Layout (must stay in sync with CrossAttentionFusionPredictor):
    text(768) | visual(768) | spatial(768) | spatial_10(10)
    | question_emb(384) | 3 prior scores.
    """
    prior_scores = torch.tensor(
        [feat["text_score"], feat["visual_score"], feat["spatial_score"]],
        dtype=torch.float32, device=device,
    )
    pieces = [feat[key] for key in
              ("H_text", "H_visual", "H_spatial", "spatial_10", "question_emb")]
    pieces.append(prior_scores)
    return torch.cat(pieces)
217
+
218
+
219
def vqa_infer(item, alpha, beta, gamma):
    """Run extractive VQA on one document using fusion-weight-driven
    context compression.

    The alpha weight controls how many OCR words are kept (at least 30%
    of the words and at least min(5, n)).  When gamma dominates both
    alpha and beta, words are selected in reading order (sorted by y
    then x of their boxes) instead of original OCR order — this must
    stay in sync with ``get_sel_idx``.

    Returns the decoded answer span as a string, or "" on any failure
    (deliberately best-effort: the demo compares many configs per click).
    """
    try:
        img = item["image"].convert("RGB")
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        question = get_question(item)
        if not words:
            return ""
        W, H = img.size
        n = len(words)
        # Keep at least 30% of words (alpha floor) and at least min(5, n).
        n_keep = max(int(n * max(float(alpha), 0.30)), min(5, n))
        if float(gamma) > max(float(alpha), float(beta)) and boxes:
            # Spatial-dominant: pick words in reading order (top-left first).
            order = sorted(range(n), key=lambda i: (boxes[i][1], boxes[i][0]))
            sel_idx = sorted(order[:n_keep])
        else:
            sel_idx = list(range(n_keep))
        sw = [words[i] for i in sel_idx]
        # Fall back to a full-page box per word when no boxes exist.
        sb = ([boxes[i] for i in sel_idx]
              if boxes else [[0, 0, W, H]] * len(sw))
        enc = vqa_processor(
            img, text=question, text_pair=sw,
            boxes=normalize_boxes(sb, W, H),
            return_tensors="pt", truncation=True,
            max_length=512, padding=True
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            out = vqa_model(**enc)
        # Greedy span extraction from the start/end logits; clamp so the
        # span is never empty/backwards.
        s = int(out.start_logits.argmax())
        e = int(out.end_logits.argmax())
        if e < s: e = s
        return vqa_processor.tokenizer.decode(
            enc["input_ids"][0][s:e+1], skip_special_tokens=True
        ).strip()
    except Exception:
        # Best-effort by design: any failure counts as "no answer".
        return ""
255
+
256
+
257
def compute_anls(pred, gts, threshold=0.5):
    """Best thresholded Normalized Levenshtein Similarity over all GTs.

    Similarity is 1 - edit_distance / max_len on lower-cased, stripped
    strings; values below *threshold* are zeroed (standard ANLS rule).
    Returns 0.0 for an empty prediction or empty ground-truth list.
    """
    if isinstance(gts, str):
        gts = [gts]
    if not gts or not pred:
        return 0.0
    candidate = str(pred).lower().strip()
    best = 0.0
    for gt in gts:
        reference = str(gt).lower().strip()
        longest = max(len(candidate), len(reference))
        if longest == 0:
            # Both strings normalise to empty: treat as a perfect match.
            best = max(best, 1.0)
            continue
        sim = 1.0 - editdistance.eval(candidate, reference) / longest
        best = max(best, sim if sim >= threshold else 0.0)
    return best
270
+
271
+
272
def compute_f1(pred, gts):
    """Best token-level F1 between *pred* and any ground-truth answer.

    Tokens are lower-cased whitespace splits compared as sets.  Returns
    0.0 when either side is empty or there is no overlap.
    """
    if isinstance(gts, str):
        gts = [gts]
    if not pred or not gts:
        return 0.0
    pred_tokens = set(str(pred).lower().split())
    if not pred_tokens:
        return 0.0

    def f1_against(gt):
        # F1 of the prediction against a single ground truth string.
        gt_tokens = set(str(gt).lower().split())
        overlap = pred_tokens & gt_tokens
        if not gt_tokens or not overlap:
            return 0.0
        precision = len(overlap) / len(pred_tokens)
        recall = len(overlap) / len(gt_tokens)
        return 2 * precision * recall / (precision + recall)

    return max(f1_against(gt) for gt in gts)
287
+
288
+
289
+ # ══════════════════════════════════════════════════════════════════════
290
+ # WORD SELECTION VISUALIZER HELPERS
291
+ # ══════════════════════════════════════════════════════════════════════
292
+
293
def get_sel_idx(item, alpha, beta, gamma):
    """Return the SET of word indices kept by this (alpha, beta, gamma).

    Mirrors the exact selection logic in ``vqa_infer`` so visualisations
    always match what the VQA model actually sees: at least
    max(30% of the words, min(5, n)) are kept, and when the spatial
    weight gamma dominates both alpha and beta the words are chosen in
    reading order (top-to-bottom, left-to-right).
    """
    words = list(item.get(WORD_FIELD, []))
    boxes = list(item.get(BOX_FIELD, []))
    total = len(words)
    if total == 0:
        return set()

    a, b, g = float(alpha), float(beta), float(gamma)
    keep = max(int(total * max(a, 0.30)), min(5, total))
    keep = min(keep, total)

    if g > max(a, b) and boxes:
        # Spatial-dominant: rank words by (y, x) of their bounding box.
        reading_order = sorted(
            range(total), key=lambda i: (boxes[i][1], boxes[i][0])
        )
        return set(reading_order[:keep])
    return set(range(keep))
312
+
313
+
314
def draw_selection(item, alpha, beta, gamma, title=""):
    """Return a PIL Image with coloured bounding boxes overlaid.

    🟒 Green fill + outline β†’ word KEPT (used for VQA)
    πŸ”΄ Red fill + outline β†’ word DROPPED (compressed out)

    An info strip (dark) and a colour legend strip are appended below the
    document image so the panel is self-explanatory at a glance.

    On any failure the raw document image (or None) is returned so the
    demo never crashes on a single bad document.
    """
    try:
        img = item["image"].convert("RGB").copy()
        W, H = img.size
        words = list(item.get(WORD_FIELD, []))
        boxes = list(item.get(BOX_FIELD, []))
        # Only pair up words that actually have a box.
        n = min(len(words), len(boxes))
        if n == 0:
            return img

        # Selection must match vqa_infer exactly (same helper).
        sel_idx = get_sel_idx(item, alpha, beta, gamma)
        n_keep = len(sel_idx)
        pct = 100 * n_keep / max(n, 1)

        # ── Draw semi-transparent coloured overlays ───────────────────
        overlay = PILImage.new("RGBA", img.size, (0, 0, 0, 0))
        od = PILDraw.Draw(overlay)
        for i in range(n):
            try:
                x0, y0, x1, y1 = (int(boxes[i][0]), int(boxes[i][1]),
                                  int(boxes[i][2]), int(boxes[i][3]))
                # Clamp to image bounds
                x0, x1 = max(0, x0), min(W - 1, x1)
                y0, y1 = max(0, y0), min(H - 1, y1)
                if x1 <= x0 or y1 <= y0:
                    continue
                if i in sel_idx:
                    # Kept word: translucent green with a strong outline.
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(0, 210, 0, 55),
                                 outline=(0, 160, 0, 230), width=2)
                else:
                    # Dropped word: faint red so kept words stand out.
                    od.rectangle([x0, y0, x1, y1],
                                 fill=(220, 30, 30, 40),
                                 outline=(200, 0, 0, 170), width=1)
            except Exception:
                # Skip malformed boxes rather than abort the whole panel.
                continue
        img = PILImage.alpha_composite(img.convert("RGBA"), overlay).convert("RGB")

        # ── Load font (graceful fallback to PIL's built-in default) ───
        font_sm = PILFont.load_default()
        for _fp in [
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "/Windows/Fonts/arial.ttf",
        ]:
            try:
                font_sm = PILFont.truetype(_fp, 13)
                break
            except Exception:
                continue

        # ── Info strip (dark bar showing title + stats) ───────────────
        strip_h = 36
        strip = PILImage.new("RGB", (W, strip_h), (22, 22, 32))
        sd = PILDraw.Draw(strip)
        info_text = (f"{title} | βœ“ Kept: {n_keep}/{n} ({pct:.0f}%)"
                     f" | Ξ±={alpha:.2f} Ξ²={beta:.2f} Ξ³={gamma:.2f}")
        sd.text((8, 11), info_text, fill=(220, 220, 220), font=font_sm)

        # ── Legend strip (light bar explaining colours) ───────────────
        leg_h = 28
        leg = PILImage.new("RGB", (W, leg_h), (246, 246, 246))
        ld = PILDraw.Draw(leg)
        ld.rectangle([8, 7, 24, 21], fill=(0, 180, 0), outline=(0, 130, 0, 255))
        ld.text( [30, 8], "= Kept (used for VQA)", fill=(0, 110, 0), font=font_sm)
        ld.rectangle([210, 7, 226, 21], fill=(220, 30, 30), outline=(170, 0, 0, 255))
        ld.text( [232, 8], "= Dropped (compressed out)", fill=(140, 0, 0), font=font_sm)

        # ── Stack: image β†’ dark strip β†’ legend ────────────────────────
        final = PILImage.new("RGB", (W, H + strip_h + leg_h), (255, 255, 255))
        final.paste(img, (0, 0))
        final.paste(strip, (0, H))
        final.paste(leg, (0, H + strip_h))
        return final

    except Exception as e:
        # Best-effort: log and hand back whatever image we have (or None).
        print(f" draw_selection error: {e}")
        return item.get("image", None)
400
+
401
+
402
def make_compression_md(item, cfgs):
    """Build a markdown table showing kept / dropped word statistics and
    a sample of the words that each method discards.

    Args:
        item: dataset record with OCR words under WORD_FIELD.
        cfgs: OrderedDict/dict {method_name: (alpha, beta, gamma)};
              selection per method is delegated to ``get_sel_idx`` so it
              matches what ``vqa_infer`` actually uses.

    Returns:
        A markdown string (or a placeholder note when the document has
        no OCR words).
    """
    words = list(item.get(WORD_FIELD, []))
    n = len(words)
    if n == 0:
        return "*No OCR words available for this document.*"

    md = "### πŸ” What Gets Compressed?\n\n"
    md += f"**Total OCR words in document:** {n}\n\n"
    md += ("| Method | Ξ± | Ξ² | Ξ³ | Words Kept | % Context |"
           " Sample Dropped Words |\n")
    md += ("|--------|---|---|---|:----------:|:---------:|"
           "----------------------|\n")

    for name, (a, b, g) in cfgs.items():
        sel = get_sel_idx(item, a, b, g)
        n_keep = len(sel)
        pct = 100 * n_keep / max(n, 1)
        # Preview at most 8 dropped words, then summarise the remainder.
        dropped = [words[i] for i in range(n) if i not in sel]
        d_preview = " Β· ".join(dropped[:8])
        if len(dropped) > 8:
            d_preview += f" … (+{len(dropped) - 8} more)"
        md += (f"| **{name}** | {a:.2f} | {b:.2f} | {g:.2f}"
               f" | {n_keep} / {n} | {pct:.0f}% | `{d_preview}` |\n")

    # Show the actual kept words for the CAFP+REINFORCE method
    if "CAFP+REINFORCE" in cfgs:
        a, b, g = cfgs["CAFP+REINFORCE"]
        sel = get_sel_idx(item, a, b, g)
        kept_w = [words[i] for i in sorted(sel)[:25]]
        md += (f"\n**CAFP+REINFORCE β€” kept words (first 25 shown):** \n"
               f"`{' Β· '.join(kept_w)}`\n")

    return md
440
+
441
+
442
+ # ══════════════════════════════════════════════════════════════════════
443
+ # LOAD CHECKPOINTS & DATA
444
+ # ══════════════════════════════════════════════════════════════════════
445
+ print("\nLoading checkpoints and data...")
446
+
447
+ # ── RL checkpoint ─────────────────────────────────────────────────────
448
+ if not os.path.exists(CKPT_PATH):
449
+ sys.exit(f"❌ Checkpoint not found: {CKPT_PATH}\n"
450
+ f" Copy cafp_rl_checkpoint_final.pt into {CKPT_DIR}/")
451
+
452
+ ck = torch.load(CKPT_PATH, map_location=device, weights_only=False)
453
+ spatial_proj.load_state_dict(ck["spatial_proj_state"])
454
+
455
+ cafp_soft = CrossAttentionFusionPredictor().to(device)
456
+ cafp_soft.load_state_dict(ck["cafp_soft_state"]); cafp_soft.eval()
457
+
458
+ cafp_rl = copy.deepcopy(cafp_soft)
459
+ cafp_rl.load_state_dict(ck["cafp_rl_state"]); cafp_rl.eval()
460
+
461
+ rl_train_anls = ck["rl_train_anls"]
462
+ rl_val_anls = ck.get("rl_val_anls",
463
+ max(rl_train_anls) if rl_train_anls else 0.0)
464
+ print(f" βœ… CAFP+REINFORCE: {len(rl_train_anls)} epochs | "
465
+ f"best_train={max(rl_train_anls):.4f} | val={rl_val_anls:.4f}")
466
+
467
+ # ── Dataset ───────────────────────────────────────────────────────────
468
+ print(" Loading dataset (~30s)...")
469
+ from datasets import load_dataset
470
+ _ds = load_dataset(DATASET_NAME, split="train")
471
+ _split = _ds.train_test_split(test_size=0.2, seed=42)
472
+ rng = np.random.RandomState(42)
473
+ val_idx = rng.permutation(len(_split["test"])).tolist()[:N_VAL]
474
+ train_idx = rng.permutation(len(_split["train"])).tolist()[:N_TRAIN]
475
+ val_items = [_split["test"][i] for i in val_idx]
476
+ train_items = [_split["train"][i] for i in train_idx]
477
+ val_gts = [item[ANSWER_FIELD] for item in val_items]
478
+ train_gts = [item[ANSWER_FIELD] for item in train_items]
479
+ print(f" βœ… Dataset: {len(val_items)} val, {len(train_items)} train")
480
+
481
+ # ── Feature tensors ───────────────────────────────────────────────────
482
+ if os.path.exists(FEAT_PATH):
483
+ t = torch.load(FEAT_PATH, map_location=device, weights_only=False)
484
+ val_feats = t["val_feats"]
485
+ train_feats = t["train_feats"]
486
+ print(f" βœ… Features: {tuple(val_feats.shape)}")
487
+ else:
488
+ print(" ⚠️ feature_tensors.pt not found β€” recomputing (~2 min)...")
489
+ def _feats(items, tag):
490
+ out = []
491
+ for i, item in enumerate(items):
492
+ out.append(build_feature_vector(
493
+ extract_rich_features(item)).unsqueeze(0))
494
+ if (i + 1) % 10 == 0:
495
+ print(f" {tag}: {i+1}/{len(items)}", end="\r")
496
+ print()
497
+ return torch.cat(out).to(device)
498
+ val_feats = _feats(val_items, "val")
499
+ train_feats = _feats(train_items, "train")
500
+ torch.save({"val_feats": val_feats, "train_feats": train_feats,
501
+ "val_gts": val_gts, "train_gts": train_gts}, FEAT_PATH)
502
+ print(f" βœ… Features computed and saved to {FEAT_PATH}")
503
+
504
+ # ── Oracle cache ──────────────────────────────────────────────────────
505
+ val_oracle = train_oracle = []
506
+ if os.path.exists(ORACLE_CACHE):
507
+ _oc = json.load(open(ORACLE_CACHE))
508
+ train_oracle = _oc.get("train", [])
509
+ val_oracle = _oc.get("val", [])
510
+ print(f" βœ… Oracle cache: {len(train_oracle)} train, {len(val_oracle)} val")
511
+ else:
512
+ print(" ⚠️ oracle_cache.json not found β€” demo works without it")
513
+
514
+ # ── Results from JSON ─────────────────────────────────────────────────
515
+ results = {}
516
+ _RKEYS = [
517
+ "Equal Fusion", "Proposed Fixed", "Text-Only",
518
+ "LLMLingua-style", "Selective Context-style",
519
+ "CAFP (paper checkpoint)", "CAFP-Hard Oracle", "CAFP-Soft Oracle",
520
+ ]
521
+ for _rpath in [RESULTS_PATH, "./final_results.json", "./results_condensed.json"]:
522
+ try:
523
+ _raw = json.load(open(_rpath))
524
+ for k in _RKEYS:
525
+ if k in _raw and isinstance(_raw[k], dict):
526
+ r = _raw[k]
527
+ results[k] = {
528
+ "mean_anls": float(r.get("mean_anls", r.get("anls", 0.0))),
529
+ "mean_f1": float(r.get("mean_f1", r.get("f1", 0.0))),
530
+ }
531
+ if results:
532
+ print(f" βœ… Results: {len(results)} methods from {_rpath}")
533
+ break
534
+ except Exception:
535
+ continue
536
+ if not results:
537
+ print(" ⚠️ Results JSON not found β€” dashboard will show partial data")
538
+
539
+ # ── Find best demo documents ──────────────────────────────────────────
540
+ print("\nPre-scoring documents for demo (this takes ~2 min)...")
541
+ demo_scores = []
542
+ cafp_rl.eval()
543
+ with torch.no_grad():
544
+ for i in range(len(val_items)):
545
+ fv = val_feats[i].unsqueeze(0)
546
+ conc = F.softplus(cafp_rl._logits(fv)) + 0.1
547
+ w = (conc / conc.sum()).squeeze(0).cpu().tolist()
548
+ rl_s = compute_anls(vqa_infer(val_items[i], w[0], w[1], w[2]),
549
+ val_gts[i])
550
+ fx_s = compute_anls(vqa_infer(val_items[i], 0.5, 0.3, 0.2),
551
+ val_gts[i])
552
+ demo_scores.append((i, round(rl_s - fx_s, 4),
553
+ round(rl_s, 4), round(fx_s, 4)))
554
+ if (i + 1) % 20 == 0:
555
+ print(f" {i+1}/100", end="\r")
556
+ demo_scores.sort(key=lambda x: -x[1])
557
+ best_idx = demo_scores[0][0]
558
+ top5_str = ", ".join([f"#{x[0]}(+{x[1]:.2f})" for x in demo_scores[:5]])
559
+ print(f"\n βœ… Best docs: {top5_str}")
560
+ print(f"\n{'='*55}")
561
+ print("ALL MODELS LOADED β€” ready to demo")
562
+ print(f"{'='*55}\n")
563
+
564
+
565
+ # ══════════════════════════════════════════════════════════════════════
566
+ # GRADIO FUNCTIONS
567
+ # ══════════════════════════════════════════════════════════════════════
568
def get_rl_weights(idx, custom_q=None):
    """Predict adaptive (alpha, beta, gamma) for validation document *idx*.

    With a non-blank *custom_q* the features are recomputed for a copy of
    the record with the question overridden; otherwise the precomputed
    feature tensor is reused.  Weights are softplus-normalised CAFP
    logits (+0.1 floor), matching the REINFORCE concentration
    parameterisation.
    """
    override = custom_q.strip() if custom_q else ""
    if override:
        patched = dict(val_items[idx])
        patched[QUERY_FIELD] = override
        fv = build_feature_vector(extract_rich_features(patched)).unsqueeze(0)
    else:
        fv = val_feats[idx].unsqueeze(0)
    with torch.no_grad():
        conc = F.softplus(cafp_rl._logits(fv)) + 0.1
        weights = (conc / conc.sum()).squeeze(0).cpu().tolist()
    return weights
579
+
580
+
581
def make_weight_chart(mw):
    """Render a grouped bar chart of fusion weights per method.

    Args:
        mw: dict {method_name: (alpha, beta, gamma)}; insertion order
            determines the x-axis order.

    Returns:
        A matplotlib Figure with three bars (text/visual/spatial) per
        method, each annotated with its numeric value.
    """
    fig, ax = plt.subplots(figsize=(9, 3.5))
    labels = list(mw.keys())
    x, bw = np.arange(len(labels)), 0.25
    # One bar series per modality, offset by (j - 1) * bw around each tick.
    for j, (lbl, col) in enumerate([
        ("\u03b1 Text", "#2196F3"),
        ("\u03b2 Visual", "#4CAF50"),
        ("\u03b3 Spatial", "#FF9800"),
    ]):
        vals = [list(mw.values())[i][j] for i in range(len(labels))]
        bars = ax.bar(x + (j - 1) * bw, vals, bw,
                      label=lbl, color=col, alpha=0.85)
        # Annotate each bar with its weight, just above the top edge.
        for bar in bars:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, h + 0.01,
                    f"{h:.2f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=10)
    # Headroom above 1.0 so the annotations stay inside the axes.
    ax.set_ylabel("Weight"); ax.set_ylim(0, 1.2)
    ax.set_title("Fusion Weights (\u03b1, \u03b2, \u03b3) per Method",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=9); ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    return fig
604
+
605
+
606
def run_demo(doc_idx, custom_q):
    """Gradio callback for the Live Demo tab.

    Runs four fusion configurations (Equal, Fixed, Text-Only,
    CAFP+REINFORCE) on one validation document, scores each against the
    ground truth, and builds all UI outputs.

    Args:
        doc_idx:  slider value (cast to int) indexing val_items.
        custom_q: optional question override; blank keeps the original.

    Returns:
        (document image, step-by-step markdown, weight bar chart,
         fixed-weights selection image, RL selection image,
         compression markdown) — matching the outputs wired in the UI.
    """
    doc_idx = int(doc_idx)
    item = val_items[doc_idx]
    gt = val_gts[doc_idx]
    q = (custom_q.strip()
         if custom_q and custom_q.strip()
         else get_question(item))
    # Show at most two ground-truth variants in the header.
    gt_str = (", ".join(str(g) for g in gt[:2])
              if isinstance(gt, list) else str(gt))
    n_words = len(list(item.get(WORD_FIELD, [])))
    # Heuristic label only (threshold 40 words) — cosmetic, not used for scoring.
    doc_type = "Text-dominant" if n_words > 40 else "Visual-dominant"

    # Adaptive weights from the REINFORCE-tuned CAFP.
    alpha, beta, gamma = get_rl_weights(doc_idx, custom_q)
    dom = ("Text" if alpha > 0.65 else
           "Visual" if beta > 0.40 else "Balanced")

    # The four configurations compared in the demo.
    cfgs = {
        "Equal Fusion": (1/3, 1/3, 1/3),
        "Fixed (0.5,0.3,0.2)": (0.5, 0.3, 0.2),
        "Text-Only": (1.0, 0.0, 0.0),
        "CAFP+REINFORCE": (alpha, beta, gamma),
    }
    res = {}
    for name, (a, b, g) in cfgs.items():
        # Copy the record so the question override never mutates val_items.
        demo_item = dict(item); demo_item[QUERY_FIELD] = q
        pred = vqa_infer(demo_item, a, b, g)
        res[name] = {
            "pred": pred,
            "anls": compute_anls(pred, gt),
            "f1": compute_f1(pred, gt),
            "w": (a, b, g),
        }

    best = max(res, key=lambda k: res[k]["anls"])
    rl_vs_fixed = res["CAFP+REINFORCE"]["anls"] - res["Fixed (0.5,0.3,0.2)"]["anls"]

    # ── Step-by-step markdown report ──────────────────────────────────
    md = f"## Document #{doc_idx} \u2014 {doc_type} ({n_words} words)\n\n"
    md += f"**Question:** {q}\n\n"
    md += f"**Ground Truth:** `{gt_str}`\n\n---\n"
    md += "### Step 1 \u2014 Text Extraction\n"
    md += f"`{n_words}` OCR words extracted via LayoutLMv3\n\n"
    md += "### Step 2 \u2014 Multimodal Feature Extraction\n"
    md += ("- **Text** \u2192 LayoutLMv3 token embeddings [768-D]\n"
           "- **Visual** \u2192 LayoutLMv3 patch features [768-D]\n"
           "- **Spatial** \u2192 Bounding box layout encoding [768-D]\n\n")
    md += "### Step 3 \u2014 CAFP+REINFORCE Weight Prediction\n"
    md += "| Modality | Weight |\n|----------|--------|\n"
    md += f"| \u03b1 Text | **{alpha:.3f}** |\n"
    md += f"| \u03b2 Visual | **{beta:.3f}** |\n"
    md += f"| \u03b3 Spatial | **{gamma:.3f}** |\n\n"
    md += f"\u2192 **Dominant: {dom}**\n\n"
    md += "### Step 4 \u2014 Adaptive Fusion \u2192 Answer\n"
    md += "| Method | \u03b1 | \u03b2 | \u03b3 | Answer | ANLS | F1 |\n"
    md += "|--------|---|---|---|--------|------|----|\n"
    for name, d in res.items():
        a, b, g = d["w"]
        # Star the best-scoring method in the comparison table.
        star = " \u2b50" if name == best else ""
        md += (f"| {name}{star} | {a:.2f} | {b:.2f} | {g:.2f}"
               f" | `{d['pred']}` | **{d['anls']:.4f}** | {d['f1']:.4f} |\n")
    sign = "+" if rl_vs_fixed >= 0 else ""
    md += (f"\n---\n**Best:** {best} (ANLS: {res[best]['anls']:.4f})\n\n"
           f"**CAFP+REINFORCE answer:** `{res['CAFP+REINFORCE']['pred']}`\n\n"
           f"**\u0394 over Fixed:** {sign}{rl_vs_fixed:.4f}\n")

    chart = make_weight_chart({k: v["w"] for k, v in res.items()})

    # ── Word Selection Visualizations ─────────────────────────────────
    _item_q = dict(item); _item_q[QUERY_FIELD] = q
    fixed_vis = draw_selection(
        _item_q, 0.5, 0.3, 0.2,
        "Fixed Weights (0.5, 0.3, 0.2)"
    )
    rl_vis = draw_selection(
        _item_q, alpha, beta, gamma,
        f"CAFP+REINFORCE (Ξ±={alpha:.2f} Ξ²={beta:.2f} Ξ³={gamma:.2f})"
    )
    comp_md = make_compression_md(item, cfgs)

    return item.get("image", None), md, chart, fixed_vis, rl_vis, comp_md
685
+
686
+
687
+ def show_dashboard():
688
+ def sg(k): return results.get(k, {}).get("mean_anls", 0.0)
689
+ def sf(k): return results.get(k, {}).get("mean_f1", 0.0)
690
+ fixed = sg("Proposed Fixed"); oracle = 0.8377
691
+ rv = rl_val_anls
692
+
693
+ rows = [
694
+ ("Equal Fusion", sg("Equal Fusion"), sf("Equal Fusion")),
695
+ ("Proposed Fixed (paper)", sg("Proposed Fixed"), sf("Proposed Fixed")),
696
+ ("Text-Only", sg("Text-Only"), sf("Text-Only")),
697
+ ("LLMLingua-style [NEW]", sg("LLMLingua-style"), sf("LLMLingua-style")),
698
+ ("Selective Context [NEW]", sg("Selective Context-style"), sf("Selective Context-style")),
699
+ ("CAFP paper checkpoint", sg("CAFP (paper checkpoint)"), sf("CAFP (paper checkpoint)")),
700
+ ("CAFP Hard Oracle [NEW]", sg("CAFP-Hard Oracle"), sf("CAFP-Hard Oracle")),
701
+ ("CAFP Soft Oracle [NEW]", sg("CAFP-Soft Oracle"), sf("CAFP-Soft Oracle")),
702
+ ("CAFP+REINFORCE [NEW][BEST]", rv, 0.0),
703
+ ("Oracle Upper Bound", oracle, 0.0),
704
+ ]
705
+
706
+ md = "## Full Experiment Results\n\n"
707
+ md += "| Method | ANLS | F1 | \u0394 Fixed | % Oracle |\n"
708
+ md += "|--------|------|----|----------|----------|\n"
709
+ for name, anls, f1 in rows:
710
+ is_oracle = "Oracle Upper" in name
711
+ d = f"{anls - fixed:+.4f}" if not is_oracle else "\u2014"
712
+ pct = f"{anls / oracle * 100:.1f}%" if anls > 0 else "\u2014"
713
+ md += f"| {name} | {anls:.4f} | {f1:.4f} | {d} | {pct} |\n"
714
+ md += (f"\n**CAFP+REINFORCE: {rv/oracle*100:.1f}% of Oracle ANLS**\n"
715
+ f"**Improvement over Fixed: {rv - fixed:+.4f} ANLS**\n")
716
+
717
+ # Bar chart
718
+ fig1, ax1 = plt.subplots(figsize=(11, 5))
719
+ bv = [r[1] for r in rows]
720
+ bc = ["#bbb","#999","#bbb","#2196F3","#2196F3",
721
+ "#777","#4CAF50","#4CAF50","#FF5722","#d32f2f"]
722
+ bars = ax1.barh([r[0] for r in rows], bv,
723
+ color=bc, edgecolor="white", height=0.65)
724
+ ax1.axvline(oracle, color="red", linestyle="--", lw=1.5,
725
+ label=f"Oracle {oracle:.4f}")
726
+ ax1.axvline(fixed, color="gray", linestyle=":", lw=1.2,
727
+ label=f"Fixed {fixed:.4f}")
728
+ for bar, val in zip(bars, bv):
729
+ if val > 0:
730
+ ax1.text(val + 0.003, bar.get_y() + bar.get_height() / 2,
731
+ f"{val:.4f}", va="center", fontsize=8)
732
+ ax1.set_xlabel("Val ANLS", fontsize=11)
733
+ ax1.set_title("All Methods \u2014 Val ANLS", fontsize=13, fontweight="bold")
734
+ ax1.legend(fontsize=9); ax1.invert_yaxis()
735
+ ax1.set_xlim(0, oracle * 1.1); ax1.grid(axis="x", alpha=0.3)
736
+ plt.tight_layout()
737
+
738
+ # Training curve
739
+ fig2, ax2 = plt.subplots(figsize=(10, 3.5))
740
+ eps = list(range(1, len(rl_train_anls) + 1))
741
+ ax2.plot(eps, rl_train_anls, "o-", color="#FF5722",
742
+ lw=2.5, ms=7, label="Train ANLS")
743
+ ax2.axhline(rv, color="#FF5722", linestyle=":", lw=2,
744
+ label=f"Val ANLS = {rv:.4f}")
745
+ ax2.axhline(oracle, color="red", linestyle="--", lw=1.5,
746
+ label=f"Oracle = {oracle:.4f}")
747
+ ax2.axhline(fixed, color="gray", linestyle=":", lw=1.2,
748
+ label=f"Fixed = {fixed:.4f}")
749
+ ax2.fill_between(eps, rl_train_anls, fixed, alpha=0.15, color="#FF5722")
750
+ ax2.set_xlabel("Epoch"); ax2.set_ylabel("ANLS")
751
+ ax2.set_title("REINFORCE Fine-tuning Progress",
752
+ fontsize=12, fontweight="bold")
753
+ ax2.legend(fontsize=9); ax2.grid(True, alpha=0.3)
754
+ ax2.set_xticks(eps); plt.tight_layout()
755
+
756
+ return md, fig1, fig2
757
+
758
+
759
+ # ══════════════════════════════════════════════════════════════════════
760
+ # GRADIO UI
761
+ # ══════════════════════════════════════════════════════════════════════
762
+ _fixed_anls = results.get("Proposed Fixed", {}).get("mean_anls", 0.0)
763
+ _best_label = ("Best docs (REINFORCE wins most): "
764
+ + ", ".join([f"#{x[0]}" for x in demo_scores[:5]]))
765
+
766
+ CSS = ".tab-nav button { font-size: 15px !important; font-weight: 600 !important; }"
767
+
768
+ with gr.Blocks(
769
+ title="Adaptive Multimodal Fusion β€” DocVQA Demo",
770
+ theme=gr.themes.Soft(primary_hue="blue"),
771
+ css=CSS,
772
+ ) as demo_app:
773
+
774
+ gr.Markdown("""
775
+ # Adaptive Multimodal Fusion for Document VQA
776
+ ### Cross-Attention Fusion Predictor (CAFP) + REINFORCE Fine-tuning
777
+ """)
778
+
779
+ with gr.Tabs():
780
+
781
+ # ── Tab 1: Live Demo ──────────────────────────────────────────
782
+ with gr.TabItem("\U0001f3af Live Demo"):
783
+ gr.Markdown(f"**{_best_label}**")
784
+ with gr.Row():
785
+ with gr.Column(scale=1):
786
+ doc_slider = gr.Slider(
787
+ 0, len(val_items) - 1,
788
+ value=best_idx, step=1,
789
+ label=f"Document Index (0\u2013{len(val_items)-1})"
790
+ )
791
+ custom_q = gr.Textbox(
792
+ label="Custom Question (optional)",
793
+ placeholder="Leave blank to use original question"
794
+ )
795
+ run_btn = gr.Button(
796
+ "\u25b6 Run Adaptive Fusion",
797
+ variant="primary", size="lg"
798
+ )
799
+ gr.Markdown(
800
+ "*Compares: Equal \u00b7 Fixed \u00b7 "
801
+ "Text-Only \u00b7 CAFP+REINFORCE*"
802
+ )
803
+ with gr.Column(scale=2):
804
+ doc_image = gr.Image(label="Document Image", height=400)
805
+ step_md = gr.Markdown()
806
+ weight_chart = gr.Plot(label="Fusion Weights Comparison")
807
+
808
+ # ── Word Selection Visualizer ─────────────────────────────
809
+ gr.Markdown("""
810
+ ---
811
+ ### 🎨 Word Selection Visualization
812
+ *See **exactly** which OCR words each method keeps vs discards.*
813
+ 🟒 **Green** = kept and fed to the VQA model Β· πŸ”΄ **Red** = compressed out
814
+ """)
815
+ with gr.Row():
816
+ fixed_vis_img = gr.Image(
817
+ label="πŸ“Œ Fixed Weights (Ξ±=0.5 Ξ²=0.3 Ξ³=0.2)",
818
+ height=520, show_download_button=True
819
+ )
820
+ rl_vis_img = gr.Image(
821
+ label="πŸ€– CAFP+REINFORCE (Adaptive Weights)",
822
+ height=520, show_download_button=True
823
+ )
824
+ comp_md_out = gr.Markdown()
825
+
826
+ run_btn.click(
827
+ fn=run_demo,
828
+ inputs=[doc_slider, custom_q],
829
+ outputs=[doc_image, step_md, weight_chart,
830
+ fixed_vis_img, rl_vis_img, comp_md_out],
831
+ )
832
+
833
+ # ── Tab 2: Results Dashboard ──────────────────────────────────
834
+ with gr.TabItem("\U0001f4ca Results Dashboard"):
835
+ gr.Markdown("### All methods compared + REINFORCE training curve")
836
+ load_btn = gr.Button("Load Results", variant="secondary")
837
+ res_md = gr.Markdown()
838
+ with gr.Row():
839
+ bar_chart = gr.Plot(label="ANLS \u2014 All Methods")
840
+ rl_curve = gr.Plot(label="REINFORCE Training Curve")
841
+ load_btn.click(
842
+ fn=show_dashboard,
843
+ inputs=[],
844
+ outputs=[res_md, bar_chart, rl_curve],
845
+ )
846
+
847
+ # ── Tab 3: About ──────────────────────────────────────────────
848
+ with gr.TabItem("\u2139\ufe0f About"):
849
+ gr.Markdown(f"""
850
+ ## Adaptive Multimodal Fusion for DocVQA
851
+
852
+ ### Problem
853
+ DocVQA requires reasoning over three modalities simultaneously:
854
+ - **Text** β€” OCR words and their semantics
855
+ - **Visual** β€” Document appearance and image patches
856
+ - **Spatial** β€” Bounding box positions and layout structure
857
+
858
+ Fixed weights (Ξ±=0.5, Ξ²=0.3, Ξ³=0.2) cannot adapt to different document types.
859
+
860
+ ### Architecture: CAFP (428K params)
861
+ 1. Projects each modality embedding to 128-D
862
+ 2. Cross-attention: question attends to all modality representations
863
+ 3. Predicts per-document (Ξ±, Ξ², Ξ³) fusion weights
864
+
865
+ ### Training Pipeline
866
+ 1. **Hard Oracle** (MSE) β†’ argmax weights from 20-combo grid search
867
+ 2. **Soft Oracle** (KL-div) β†’ temperature-smoothed ANLS-weighted targets
868
+ 3. **REINFORCE** β†’ Policy gradient on direct ANLS reward (K=3 samples/step)
869
+
870
+ ### Novel Contributions
871
+ 1. Soft Oracle training eliminates hard-oracle label noise
872
+ 2. REINFORCE fine-tuning directly maximises DocVQA metric
873
+ 3. LLMLingua-style and Selective Context baselines for fair comparison
874
+
875
+ ### Key Result
876
+ **CAFP+REINFORCE achieves {rl_val_anls/0.8377*100:.1f}% of Oracle ANLS**
877
+ Improvement over fixed-weight baseline: {rl_val_anls - _fixed_anls:+.4f} ANLS
878
+ """)
879
+
880
+
881
+ # ══════════════════════════════════════════════════════════════════════
882
+ # LAUNCH
883
+ # ══════════════════════════════════════════════════════════════════════
884
+ if __name__ == "__main__":
885
+ demo_app.launch(
886
+ server_name="0.0.0.0",
887
+ server_port=args.port,
888
+ share=args.share,
889
+ show_error=True,
890
+ )
checkpoints/.DS_Store ADDED
Binary file (6.15 kB). View file
 
checkpoints/cafp_rl_checkpoint_final.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a736aac8aa4ed8cb5ac34d53aa2ef6a896aad4f54ec241ab787d37147d62cce
3
+ size 7688445
checkpoints/feature_tensors.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:487d11bbbab8ee1e26ddd6791dbef4f14d4f649b58ef6623d2da3edb6a3e7cbe
3
+ size 2172325
checkpoints/final_results.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Equal Fusion": {
3
+ "mean_anls": 0.5136354004166895,
4
+ "mean_f1": 0.5237649141912695,
5
+ "mean_em": 0.45
6
+ },
7
+ "Proposed Fixed": {
8
+ "mean_anls": 0.5451455242797546,
9
+ "mean_f1": 0.5509710563188824,
10
+ "mean_em": 0.48
11
+ },
12
+ "Text-Only": {
13
+ "mean_anls": 0.7642224473566777,
14
+ "mean_f1": 0.7935144736858293,
15
+ "mean_em": 0.7
16
+ },
17
+ "LLMLingua-style": {
18
+ "mean_anls": 0.1892322383498854,
19
+ "mean_f1": 0.210276221599751,
20
+ "mean_em": 0.17
21
+ },
22
+ "Selective Context-style": {
23
+ "mean_anls": 0.3046072383498854,
24
+ "mean_f1": 0.3247206660441955,
25
+ "mean_em": 0.27
26
+ },
27
+ "CAFP (paper checkpoint)": {
28
+ "mean_anls": 0.7531033997376301,
29
+ "mean_f1": 0.7755144736858292,
30
+ "mean_em": 0.68
31
+ },
32
+ "CAFP-Hard Oracle": {
33
+ "mean_anls": 0.7542224473566776,
34
+ "mean_f1": 0.7721811403524959,
35
+ "mean_em": 0.69
36
+ },
37
+ "CAFP-Soft Oracle": {
38
+ "mean_anls": 0.5610757568378942,
39
+ "mean_f1": 0.5647372900851162,
40
+ "mean_em": 0.5
41
+ },
42
+ "CAFP+REINFORCE": {
43
+ "mean_anls": 0.7084701316595168,
44
+ "rl_curve": [
45
+ 0.5759529617062865,
46
+ 0.5927631990609283,
47
+ 0.6183613555884967,
48
+ 0.655312951977503,
49
+ 0.6726479882862236,
50
+ 0.6637648504900422,
51
+ 0.6849515365259237,
52
+ 0.6891216978125647,
53
+ 0.7084701316595168
54
+ ]
55
+ }
56
+ }
checkpoints/oracle_cache.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets
2
+ transformers
3
+ torch
4
+ torchvision
5
+ sentencepiece
6
+ editdistance
7
+ sentence-transformers
8
+ accelerate
9
+ gradio
10
+ pillow
11
+ numpy
12
+ matplotlib