satya007 committed
Commit f34d690 · verified
1 Parent(s): 073f8f9

Upload model_lilt.py with huggingface_hub

Files changed (1)
  1. model_lilt.py +1491 -0
model_lilt.py ADDED
@@ -0,0 +1,1491 @@
+ """
+ DEXTR with LiLT - 3-Step Document Extraction.
+
+ Architecture:
+ - Document encoder: LiLT (XLM-RoBERTa + Layout Transformer), extended to 1024 tokens
+ - Query encoder: Sentence Transformer (frozen, for semantic similarity)
+ - Step 1: Token Classification (Q/A/H/T/O) - TRAINED
+ - Step 1.5: Q-A Linker (links Question spans to Answer spans) - TRAINED
+ - Step 2: Query-Question Matching (ZERO-SHOT) - NO TRAINING
+ - Step 3: Table Head (hierarchical with attention-based column assignment) - TRAINED
+
+ Key insight: Query→Question matching is semantic similarity. A frozen sentence
+ transformer provides the semantic embeddings used to match queries to question text.
+ """
+
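Reviewer note: a minimal end-to-end usage sketch, not part of the committed file. It assumes the tokenizer of nielsr/lilt-xlm-roberta-base and a toy word/box list standing in for OCR output; the query is passed as raw text so the frozen sentence transformer can score it against predicted question spans.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("nielsr/lilt-xlm-roberta-base")
words = ["Invoice", "No:", "12345", "Date:", "2024-01-01"]
boxes = [[50, 50, 150, 70], [160, 50, 210, 70], [220, 50, 300, 70],
         [50, 90, 110, 110], [120, 90, 240, 110]]

enc = tokenizer(words, is_split_into_words=True, return_tensors="pt")
bbox, sub2word = [], {}
for i, widx in enumerate(enc.word_ids(0)):
    bbox.append(boxes[widx] if widx is not None else [0, 0, 0, 0])
    if widx is not None:
        sub2word[i] = widx
bbox = torch.tensor([bbox])

model = DEXTRLiLT(max_seq_len=1024).eval()   # downloads the LiLT and MiniLM checkpoints
with torch.no_grad():
    out = model(
        input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], bbox=bbox,
        tokens=[words], subword_to_word=[sub2word],
        query_texts=["invoice number"], training=False,
    )
print(out["token_logits"].shape)   # (1, seq_len, 9) Step 1 labels
print(out.get("match_scores"))     # Step 2 scores, present only if Q/A spans were predicted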
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from transformers import LiltModel, LiltConfig
19
+ from sentence_transformers import SentenceTransformer
20
+ from typing import Optional, Dict, List, Tuple
21
+
22
+
23
+ # Label mappings for Step 1
24
+ LABEL2ID = {
25
+ "O": 0,
26
+ "B-Q": 1, "I-Q": 2, # Question
27
+ "B-A": 3, "I-A": 4, # Answer
28
+ "B-H": 5, "I-H": 6, # Header
29
+ "B-TABLE": 7, "I-TABLE": 8, # Table
30
+ }
31
+ ID2LABEL = {v: k for k, v in LABEL2ID.items()}
32
+ NUM_LABELS = len(LABEL2ID)
33
+
34
+
35
+ class TokenClassificationHead(nn.Module):
36
+ """
37
+ Step 1: Token Classification Head.
38
+ Predicts Q/A/H/T/O labels for each token (FUNSD-style).
39
+ """
40
+
41
+ def __init__(self, hidden_size: int = 768, num_labels: int = NUM_LABELS, dropout: float = 0.1):
42
+ super().__init__()
43
+ self.classifier = nn.Sequential(
44
+ nn.Dropout(dropout),
45
+ nn.Linear(hidden_size, 256),
46
+ nn.GELU(),
47
+ nn.Dropout(dropout),
48
+ nn.Linear(256, num_labels),
49
+ )
50
+
51
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
52
+ return self.classifier(hidden_states)
53
+
54
+
55
+ class QALinker(nn.Module):
56
+ """
57
+ Step 1.5: Q-A Linker.
58
+ Predicts which Answer span links to which Question span.
59
+ Uses both semantic similarity and spatial features.
60
+ """
61
+
62
+ def __init__(self, hidden_size: int = 768, dropout: float = 0.1):
63
+ super().__init__()
64
+ self.hidden_size = hidden_size
65
+
66
+ # Project Q and A to same space for semantic matching
67
+ self.q_proj = nn.Sequential(
68
+ nn.Linear(hidden_size, 256),
69
+ nn.LayerNorm(256),
70
+ nn.GELU(),
71
+ nn.Dropout(dropout),
72
+ )
73
+ self.a_proj = nn.Sequential(
74
+ nn.Linear(hidden_size, 256),
75
+ nn.LayerNorm(256),
76
+ nn.GELU(),
77
+ nn.Dropout(dropout),
78
+ )
79
+
80
+ # Spatial feature scorer
81
+ # Features: x_dist, y_dist, is_right, is_below, is_same_line, width_ratio, height_ratio
82
+ self.spatial_scorer = nn.Sequential(
83
+ nn.Linear(7, 32),
84
+ nn.GELU(),
85
+ nn.Linear(32, 1),
86
+ )
87
+
88
+ # Learnable temperature
89
+ self.log_temp = nn.Parameter(torch.tensor(1.0))
90
+
91
+ def compute_spatial_features(
92
+ self,
93
+ q_bboxes: torch.Tensor, # (num_q, 4) - [x1, y1, x2, y2]
94
+ a_bboxes: torch.Tensor, # (num_a, 4)
95
+ ) -> torch.Tensor:
96
+ """
97
+ Compute spatial features for all Q-A pairs.
98
+ Returns: (num_q, num_a, 7) feature tensor
99
+ """
100
+ num_q = q_bboxes.shape[0]
101
+ num_a = a_bboxes.shape[0]
102
+ device = q_bboxes.device
103
+
104
+ # Compute centers and sizes
105
+ q_cx = (q_bboxes[:, 0] + q_bboxes[:, 2]) / 2 # (num_q,)
106
+ q_cy = (q_bboxes[:, 1] + q_bboxes[:, 3]) / 2
107
+ q_w = q_bboxes[:, 2] - q_bboxes[:, 0]
108
+ q_h = q_bboxes[:, 3] - q_bboxes[:, 1]
109
+
110
+ a_cx = (a_bboxes[:, 0] + a_bboxes[:, 2]) / 2 # (num_a,)
111
+ a_cy = (a_bboxes[:, 1] + a_bboxes[:, 3]) / 2
112
+ a_w = a_bboxes[:, 2] - a_bboxes[:, 0]
113
+ a_h = a_bboxes[:, 3] - a_bboxes[:, 1]
114
+
115
+ # Expand for pairwise computation
116
+ q_cx = q_cx.unsqueeze(1).expand(num_q, num_a) # (num_q, num_a)
117
+ q_cy = q_cy.unsqueeze(1).expand(num_q, num_a)
118
+ q_x2 = q_bboxes[:, 2].unsqueeze(1).expand(num_q, num_a)
119
+ q_y2 = q_bboxes[:, 3].unsqueeze(1).expand(num_q, num_a)
120
+ q_w = q_w.unsqueeze(1).expand(num_q, num_a)
121
+ q_h = q_h.unsqueeze(1).expand(num_q, num_a)
122
+
123
+ a_cx = a_cx.unsqueeze(0).expand(num_q, num_a)
124
+ a_cy = a_cy.unsqueeze(0).expand(num_q, num_a)
125
+ a_x1 = a_bboxes[:, 0].unsqueeze(0).expand(num_q, num_a)
126
+ a_y1 = a_bboxes[:, 1].unsqueeze(0).expand(num_q, num_a)
127
+ a_w = a_w.unsqueeze(0).expand(num_q, num_a)
128
+ a_h = a_h.unsqueeze(0).expand(num_q, num_a)
129
+
130
+ # Compute features (all normalized to [0, 1] range roughly)
131
+ x_dist = (a_x1 - q_x2) / 1000.0 # Horizontal distance (positive = A right of Q)
132
+ y_dist = (a_cy - q_cy).abs() / 1000.0 # Vertical distance
133
+ is_right = (a_x1 > q_x2).float() # A is to the right of Q
134
+ is_below = (a_y1 > q_y2).float() # A is below Q
135
+ is_same_line = (y_dist < 0.03).float() # Same line (small y distance)
136
+ width_ratio = (a_w / (q_w + 1e-6)).clamp(0, 5) / 5.0 # Relative width
137
+ height_ratio = (a_h / (q_h + 1e-6)).clamp(0, 5) / 5.0 # Relative height
138
+
139
+ # Stack features: (num_q, num_a, 7)
140
+ features = torch.stack([
141
+ x_dist, y_dist, is_right, is_below, is_same_line, width_ratio, height_ratio
142
+ ], dim=-1)
143
+
144
+ return features
145
+
146
+ def forward(
147
+ self,
148
+ q_embeds: torch.Tensor, # (num_q, hidden)
149
+ a_embeds: torch.Tensor, # (num_a, hidden)
150
+ q_bboxes: torch.Tensor, # (num_q, 4)
151
+ a_bboxes: torch.Tensor, # (num_a, 4)
152
+ ) -> torch.Tensor:
153
+ """
154
+ Compute Q-A link scores.
155
+ Returns: (num_q, num_a) score matrix
156
+ """
157
+ # Semantic similarity
158
+ q_proj = self.q_proj(q_embeds) # (num_q, 256)
159
+ a_proj = self.a_proj(a_embeds) # (num_a, 256)
160
+
161
+ q_proj = F.normalize(q_proj, dim=-1)
162
+ a_proj = F.normalize(a_proj, dim=-1)
163
+
164
+ semantic_scores = torch.matmul(q_proj, a_proj.t()) # (num_q, num_a)
165
+
166
+ # Spatial scores
167
+ spatial_feats = self.compute_spatial_features(q_bboxes, a_bboxes) # (num_q, num_a, 7)
168
+ spatial_scores = self.spatial_scorer(spatial_feats).squeeze(-1) # (num_q, num_a)
169
+
170
+ # Combine with learnable temperature
171
+ temperature = self.log_temp.exp().clamp(min=0.1, max=10.0)
172
+ combined_scores = (semantic_scores + spatial_scores) * temperature
173
+
174
+ return combined_scores
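Reviewer note: a small sketch (assumed post-processing, not part of this commit) of turning the (num_q, num_a) score matrix into hard Q→A links with a confidence floor.

import torch
import torch.nn.functional as F

def decode_qa_links(link_scores: torch.Tensor, min_prob: float = 0.5):
    """For each question row, keep the best answer column if its softmax prob clears min_prob."""
    probs = F.softmax(link_scores, dim=-1)          # normalize over candidate answers
    best_prob, best_a = probs.max(dim=-1)           # both (num_q,)
    return [(q, int(a)) for q, (p, a) in enumerate(zip(best_prob, best_a)) if p >= min_prob]

# e.g. links = decode_qa_links(qa_linker(q_embeds, a_embeds, q_bboxes, a_bboxes))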
175
+
176
+
177
+ class QuestionPredictor(nn.Module):
178
+ """
179
+ Predicts Question embedding from Answer embedding.
180
+ Used when explicit Question tokens are missing in the document.
181
+
182
+ Input: LiLT A embedding (768 dim)
183
+ Output: Predicted Q embedding in sentence transformer space (384 dim)
184
+
185
+ Training:
186
+ - For Q-A pairs with Q: randomly mask Q and train to predict it
187
+ - Learns what Q "should look like" given the answer
188
+ """
189
+
190
+ def __init__(self, input_dim: int = 768, output_dim: int = 384, dropout: float = 0.1):
191
+ super().__init__()
192
+
193
+ # Input: LiLT answer embedding (768)
194
+ # Output: predicted question embedding in sentence transformer space (384)
195
+ self.predictor = nn.Sequential(
196
+ nn.Linear(input_dim, input_dim),
197
+ nn.LayerNorm(input_dim),
198
+ nn.GELU(),
199
+ nn.Dropout(dropout),
200
+ nn.Linear(input_dim, output_dim),
201
+ nn.LayerNorm(output_dim),
202
+ )
203
+
204
+ def forward(self, answer_embed: torch.Tensor) -> torch.Tensor:
205
+ """
206
+ Predict question embedding from answer embedding.
207
+ Args:
208
+ answer_embed: (batch, hidden) or (hidden,) - LiLT answer span embedding
209
+ Returns:
210
+ predicted_q_embed: (batch, output_dim) or (output_dim,) - in sentence transformer space
211
+ """
212
+ return self.predictor(answer_embed)
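Reviewer note: sketch of the masked-question training signal described in the class docstring (shapes assumed). The predictor output is regressed onto the frozen sentence-transformer embedding of the real question, so only the predictor receives gradients.

import torch
import torch.nn.functional as F

predictor = QuestionPredictor(input_dim=768, output_dim=384)
a_embed = torch.randn(4, 768)        # LiLT answer-span embeddings
real_q_st = torch.randn(4, 384)      # frozen ST embeddings of the masked questions
loss = F.mse_loss(predictor(a_embed), real_q_st.detach())
loss.backward()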
213
+
214
+
215
+ class QueryAnswerMatcher(nn.Module):
216
+ """
217
+ Step 2: Query-Answer Matching.
218
+ Given answer span embeddings and query embedding, scores which span matches.
219
+
220
+ Features:
221
+ - Start+End pooling: spans use [h_start; h_end] instead of mean (captures boundaries)
222
+ - Bbox features: spatial position helps distinguish similar text
223
+ - Learnable temperature for score scaling
224
+ - Cosine similarity with projection to lower dim
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ hidden_size: int = 768,
230
+ proj_dim: int = 256,
231
+ dropout: float = 0.1,
232
+ use_bbox_features: bool = True,
233
+ num_bbox_features: int = 6, # x1, y1, x2, y2, width, height
234
+ ):
235
+ super().__init__()
236
+ self.proj_dim = proj_dim
237
+ self.use_bbox_features = use_bbox_features
238
+ self.hidden_size = hidden_size
239
+
240
+ # Span input: [start; end] = 2*hidden, optionally + bbox
241
+ span_input_dim = 2 * hidden_size
242
+ if use_bbox_features:
243
+ span_input_dim += num_bbox_features
244
+
245
+ self.span_proj = nn.Sequential(
246
+ nn.Linear(span_input_dim, proj_dim),
247
+ nn.LayerNorm(proj_dim),
248
+ nn.GELU(),
249
+ nn.Dropout(dropout),
250
+ )
251
+ self.query_proj = nn.Sequential(
252
+ nn.Linear(hidden_size, proj_dim),
253
+ nn.LayerNorm(proj_dim),
254
+ nn.GELU(),
255
+ nn.Dropout(dropout),
256
+ )
257
+ # Learnable temperature (initialized to sqrt(proj_dim) like standard attention)
258
+ self.log_temp = nn.Parameter(torch.tensor(proj_dim ** 0.5).log())
259
+
260
+ def forward(
261
+ self,
262
+ span_embeddings: torch.Tensor, # (batch, num_spans, span_input_dim)
263
+ query_embedding: torch.Tensor, # (batch, hidden)
264
+ span_mask: Optional[torch.Tensor] = None,
265
+ ) -> torch.Tensor:
266
+ span_proj = self.span_proj(span_embeddings)
267
+ query_proj = self.query_proj(query_embedding).unsqueeze(1)
268
+
269
+ # Normalize for cosine similarity-like behavior
270
+ span_proj = F.normalize(span_proj, dim=-1)
271
+ query_proj = F.normalize(query_proj, dim=-1)
272
+
273
+ # Scaled dot product with learnable temperature
274
+ temperature = self.log_temp.exp().clamp(min=0.1, max=100.0)
275
+ scores = torch.bmm(query_proj, span_proj.transpose(1, 2)).squeeze(1) * temperature
276
+
277
+ if span_mask is not None:
278
+ scores = scores.masked_fill(~span_mask.bool(), float('-inf'))
279
+ return scores
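Reviewer note: standalone usage sketch with assumed shapes. Span inputs are [h_start; h_end; 6 bbox features] = 2*768 + 6 dims; the best span per query is the argmax over the masked scores.

import torch

matcher = QueryAnswerMatcher(hidden_size=768, proj_dim=256)
spans = torch.randn(2, 5, 2 * 768 + 6)            # (batch, num_spans, span_dim)
query = torch.randn(2, 768)                       # (batch, hidden)
mask = torch.ones(2, 5, dtype=torch.bool)
scores = matcher(spans, query, mask)              # (batch, num_spans)
best_span = scores.argmax(dim=-1)                 # predicted span index per query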
280
+
281
+
282
+ class TableHead(nn.Module):
283
+ """
284
+ Step 3: Hierarchical Table Extraction Head.
285
+
286
+ Sub-steps:
287
+ 1. Table detection (TABLE/O per token)
288
+ 2. Header detection (HEADER/CELL per token)
289
+ 3. Row segmentation (B-ROW/I-ROW/O)
290
+ 4. Column assignment (cell → header via attention)
291
+ """
292
+
293
+ def __init__(self, hidden_size: int = 768, dropout: float = 0.1):
294
+ super().__init__()
295
+
296
+ self.table_detector = nn.Sequential(
297
+ nn.Dropout(dropout),
298
+ nn.Linear(hidden_size, 256),
299
+ nn.GELU(),
300
+ nn.Linear(256, 2),
301
+ )
302
+
303
+ self.header_detector = nn.Sequential(
304
+ nn.Dropout(dropout),
305
+ nn.Linear(hidden_size, 256),
306
+ nn.GELU(),
307
+ nn.Linear(256, 2),
308
+ )
309
+
310
+ self.row_tagger = nn.Sequential(
311
+ nn.Dropout(dropout),
312
+ nn.Linear(hidden_size, 256),
313
+ nn.GELU(),
314
+ nn.Linear(256, 3),
315
+ )
316
+
317
+ self.col_key_proj = nn.Linear(hidden_size, 128)
318
+ self.col_query_proj = nn.Linear(hidden_size, 128)
319
+
320
+ def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
321
+ table_logits = self.table_detector(hidden_states)
322
+ header_logits = self.header_detector(hidden_states)
323
+ row_logits = self.row_tagger(hidden_states)
324
+
325
+ col_keys = self.col_key_proj(hidden_states)
326
+ col_queries = self.col_query_proj(hidden_states)
327
+ col_scores = torch.bmm(col_queries, col_keys.transpose(1, 2)) / (128 ** 0.5)
328
+
329
+ return {
330
+ "table_logits": table_logits,
331
+ "header_logits": header_logits,
332
+ "row_logits": row_logits,
333
+ "col_scores": col_scores,
334
+ }
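Reviewer note: a sketch (assumed decoding, not part of the commit) of turning col_scores into a column id per cell token for one sample: each non-header table token attends over header tokens and takes the argmax.

import torch

def assign_columns(col_scores, table_mask, header_mask):
    """col_scores: (seq, seq); masks: bool (seq,). Returns {cell_token_idx: header_token_idx}."""
    scores = col_scores.masked_fill(~header_mask.unsqueeze(0), float("-inf"))
    cols = scores.argmax(dim=-1)                   # best header per token
    cell_mask = table_mask & ~header_mask
    return {int(i): int(cols[i]) for i in torch.nonzero(cell_mask).flatten()}

# e.g. (per sample b): table_mask = out["table_logits"][b].argmax(-1).bool()
#                      header_mask = out["header_logits"][b].argmax(-1).bool() & table_mask
#                      columns = assign_columns(out["col_scores"][b], table_mask, header_mask)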
335
+
336
+
337
+ class DEXTRLiLT(nn.Module):
338
+ """
339
+ DEXTR with LiLT: Multi-Step Document Extraction with Zero-Shot Query Matching.
340
+
341
+ Architecture:
342
+ - Step 1: Token Classification (Q/A/H/T/O) - TRAINED
343
+ - Step 1.5: Q-A Linker (links Question spans to Answer spans) - TRAINED
344
+ - Step 2: Query → Question Matching (ZERO-SHOT with Sentence Transformer)
345
+ - Step 3: Table Head - TRAINED
346
+
347
+ Uses SEPARATE encoders:
348
+ - Document encoder: LiLT (XLM-RoBERTa + Layout Transformer) for layout-aware encoding
349
+ - Query encoder: Sentence Transformer (frozen) for semantic similarity
350
+
351
+ Zero-shot matching:
352
+ - Query text encoded with frozen Sentence Transformer
353
+ - Question text (from predicted Q regions) encoded with frozen Sentence Transformer
354
+ - Direct cosine similarity (no trainable projections)
355
+ - Fallback: if no Q regions, use QuestionPredictor on A text
356
+ """
357
+
358
+ def __init__(
359
+ self,
360
+ model_name: str = "nielsr/lilt-xlm-roberta-base",
361
+ query_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
362
+ max_seq_len: int = 1024,
363
+ hidden_size: int = 768,
364
+ dropout: float = 0.1,
365
+ q_mask_prob: float = 0.3, # Probability of masking Q during training (for Q-A linker)
366
+ ):
367
+ super().__init__()
368
+
369
+ self.max_seq_len = max_seq_len
370
+ self.hidden_size = hidden_size
371
+ self.q_mask_prob = q_mask_prob
372
+
373
+ # Document encoder: LiLT (layout-aware)
374
+ if max_seq_len > 512:
375
+ self.encoder = self._load_lilt_extended(model_name, max_seq_len)
376
+ else:
377
+ self.encoder = LiltModel.from_pretrained(model_name)
378
+
379
+ # Query encoder: Sentence Transformer (frozen, zero-shot)
380
+ # Provides proper semantic similarity for query→question matching
381
+ self.query_encoder = SentenceTransformer(query_model_name)
382
+ self.query_encoder.requires_grad_(False) # Freeze for zero-shot
383
+ self.query_embed_dim = self.query_encoder.get_sentence_embedding_dimension()
384
+ print(f"Loaded frozen query encoder: {query_model_name} (dim={self.query_embed_dim})")
385
+
386
+ # Step 1: Token Classification Head
387
+ self.token_classifier = TokenClassificationHead(hidden_size, NUM_LABELS, dropout)
388
+
389
+ # Step 1.5: Q-A Linker (links Q spans to A spans)
390
+ self.qa_linker = QALinker(hidden_size, dropout)
391
+
392
+ # Step 2: QuestionPredictor for documents without explicit Q labels (e.g., receipts)
393
+ # Input: LiLT A embedding (768 dim), Output: Q embedding in sentence transformer space (384 dim)
394
+ self.question_predictor = QuestionPredictor(hidden_size, self.query_embed_dim, dropout)
395
+
396
+ # Step 3: Table Head
397
+ self.table_head = TableHead(hidden_size, dropout)
398
+
399
+ def _load_lilt_extended(self, model_name: str, max_seq_len: int) -> LiltModel:
400
+ """Load LiLT with extended position embeddings."""
401
+ model = LiltModel.from_pretrained(model_name)
402
+ original_max_pos = model.config.max_position_embeddings
403
+ required_positions = max_seq_len + 2
404
+
405
+ if required_positions <= original_max_pos:
406
+ return model
407
+
408
+ # 1. Extend text position embeddings
409
+ old_pos_emb = model.embeddings.position_embeddings.weight.data
410
+ hidden_size = old_pos_emb.shape[1]
411
+ padding_idx = model.embeddings.position_embeddings.padding_idx
412
+
413
+ new_pos_emb = nn.Embedding(required_positions, hidden_size, padding_idx=padding_idx)
414
+ new_pos_emb.weight.data[:original_max_pos] = old_pos_emb
415
+ new_pos_emb.weight.data[original_max_pos:].normal_(mean=0.0, std=0.02)
416
+
417
+ model.embeddings.position_embeddings = new_pos_emb
418
+
419
+ # 2. Extend layout box_position_embeddings (same sequence length limit)
420
+ old_box_pos = model.layout_embeddings.box_position_embeddings.weight.data
421
+ box_hidden = old_box_pos.shape[1] # 192
422
+
423
+ new_box_pos = nn.Embedding(required_positions, box_hidden)
424
+ new_box_pos.weight.data[:original_max_pos] = old_box_pos
425
+ new_box_pos.weight.data[original_max_pos:].normal_(mean=0.0, std=0.02)
426
+
427
+ model.layout_embeddings.box_position_embeddings = new_box_pos
428
+
429
+ # 3. Update config
430
+ model.config.max_position_embeddings = required_positions
431
+
432
+ # 4. Extend position_ids buffer
433
+ new_position_ids = torch.arange(required_positions).unsqueeze(0)
434
+ model.embeddings.register_buffer("position_ids", new_position_ids, persistent=False)
435
+
436
+ print(f"Extended LiLT positions: {original_max_pos} → {required_positions}")
437
+ return model
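Reviewer note: a quick sanity check (sketch; assumes the checkpoints can be downloaded) that the extended encoder accepts sequences beyond the original 512-token limit.

import torch

m = DEXTRLiLT(max_seq_len=1024)
ids = torch.randint(5, 1000, (1, 1024))
mask = torch.ones(1, 1024, dtype=torch.long)
box = torch.zeros(1, 1024, 4, dtype=torch.long)
with torch.no_grad():
    h = m.encode(ids, mask, box)
assert h.shape == (1, 1024, 768)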
438
+
439
+ def encode(
440
+ self,
441
+ input_ids: torch.Tensor,
442
+ attention_mask: torch.Tensor,
443
+ bbox: torch.Tensor,
444
+ ) -> torch.Tensor:
445
+ """Encode tokens with layout using shared LiLT encoder."""
446
+ outputs = self.encoder(
447
+ input_ids=input_ids,
448
+ attention_mask=attention_mask,
449
+ bbox=bbox,
450
+ )
451
+ return outputs.last_hidden_state
452
+
453
+ def encode_texts(
454
+ self,
455
+ texts: List[str],
456
+ device: torch.device,
457
+ ) -> torch.Tensor:
458
+ """
459
+ Encode text strings using frozen Sentence Transformer (zero-shot).
460
+
461
+ Used for encoding Q text extracted from predicted regions.
462
+
463
+ Args:
464
+ texts: List of text strings to encode
465
+ device: Device to put tensors on
466
+
467
+ Returns:
468
+ embeddings: (num_texts, embed_dim) tensor
469
+ """
470
+ if not texts:
471
+ return torch.zeros(0, self.query_embed_dim, device=device)
472
+
473
+ with torch.no_grad():
474
+ embeddings = self.query_encoder.encode(
475
+ texts,
476
+ convert_to_tensor=True,
477
+ device=device,
478
+ )
479
+ # Clone to exit inference mode (needed for autograd compatibility)
480
+ embeddings = embeddings.clone()
481
+
482
+ return embeddings # (num_texts, embed_dim)
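Reviewer note: the zero-shot matching in Step 2 reduces to cosine similarity in this frozen embedding space. A minimal sketch using the same MiniLM checkpoint (the example texts are made up):

import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

st = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
query = st.encode(["invoice date"], convert_to_tensor=True)                        # (1, 384)
questions = st.encode(["Date:", "Total amount", "Customer name"], convert_to_tensor=True)
sims = F.cosine_similarity(query, questions)                                       # (3,)
print(sims.argmax().item())   # "Date:" is expected to score highest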
483
+
484
+ def pool_spans(
485
+ self,
486
+ hidden_states: torch.Tensor,
487
+ span_indices: List[List[Tuple[int, int]]],
488
+ bbox: Optional[torch.Tensor] = None,
489
+ use_bbox_features: bool = True,
490
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
491
+ """
492
+ Pool hidden states for answer spans using start+end tokens.
493
+
494
+ Returns [h_start; h_end; bbox_features] for each span instead of mean pooling.
495
+ This preserves boundary information which is crucial for extraction tasks.
496
+
497
+ Args:
498
+ hidden_states: (batch, seq, hidden) - document encodings
499
+ span_indices: List of (start, end) tuples per batch item
500
+ bbox: (batch, seq, 4) - bounding boxes [x1, y1, x2, y2] normalized to [0, 1000]
501
+ use_bbox_features: whether to include bbox in span representation
502
+
503
+ Returns:
504
+ span_embeddings: (batch, max_spans, span_dim) where span_dim = 2*hidden + 6 (if bbox)
505
+ span_mask: (batch, max_spans) bool mask
506
+ """
507
+ batch_size = hidden_states.shape[0]
508
+ hidden_size = hidden_states.shape[2]
509
+ device = hidden_states.device
510
+
511
+ max_spans = max(len(spans) for spans in span_indices) if span_indices else 1
512
+ max_spans = max(max_spans, 1)
513
+
514
+ # Output dimension: [h_start; h_end] + optionally [x1, y1, x2, y2, width, height]
515
+ span_dim = 2 * hidden_size
516
+ if use_bbox_features and bbox is not None:
517
+ span_dim += 6 # normalized bbox features
518
+
519
+ span_embeddings = torch.zeros(batch_size, max_spans, span_dim, device=device)
520
+ span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool, device=device)
521
+
522
+ for b, spans in enumerate(span_indices):
523
+ for s, (start, end) in enumerate(spans):
524
+ if s >= max_spans:
525
+ break
526
+ if start >= hidden_states.shape[1] or end > hidden_states.shape[1]:
527
+ continue
528
+
529
+ # Start+End pooling: [h_start; h_end]
530
+ h_start = hidden_states[b, start, :]
531
+ h_end = hidden_states[b, end - 1, :] # end is exclusive
532
+ span_repr = torch.cat([h_start, h_end], dim=0)
533
+
534
+ # Add bbox features if available
535
+ if use_bbox_features and bbox is not None:
536
+ # Get bbox of span (union of start and end token bboxes)
537
+ bbox_start = bbox[b, start, :] # [x1, y1, x2, y2]
538
+ bbox_end = bbox[b, end - 1, :]
539
+
540
+ # Compute span bounding box (min x1/y1, max x2/y2)
541
+ span_x1 = torch.min(bbox_start[0], bbox_end[0])
542
+ span_y1 = torch.min(bbox_start[1], bbox_end[1])
543
+ span_x2 = torch.max(bbox_start[2], bbox_end[2])
544
+ span_y2 = torch.max(bbox_start[3], bbox_end[3])
545
+
546
+ # Normalize to [0, 1] and add width/height
547
+ bbox_feat = torch.tensor([
548
+ span_x1 / 1000.0,
549
+ span_y1 / 1000.0,
550
+ span_x2 / 1000.0,
551
+ span_y2 / 1000.0,
552
+ (span_x2 - span_x1) / 1000.0, # width
553
+ (span_y2 - span_y1) / 1000.0, # height
554
+ ], device=device)
555
+
556
+ span_repr = torch.cat([span_repr, bbox_feat], dim=0)
557
+
558
+ span_embeddings[b, s] = span_repr
559
+ span_mask[b, s] = True
560
+
561
+ return span_embeddings, span_mask
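Reviewer note: pool_spans does not touch self, so the start+end pooling can be exercised without loading any checkpoints. Sketch with assumed shapes: two spans in a batch of one yield (1, 2, 2*768 + 6) embeddings plus a validity mask.

import torch

hidden = torch.randn(1, 32, 768)
boxes = torch.randint(0, 1000, (1, 32, 4))
emb, mask = DEXTRLiLT.pool_spans(None, hidden, [[(3, 6), (10, 14)]], bbox=boxes)
print(emb.shape, mask)   # torch.Size([1, 2, 1542]) tensor([[True, True]])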
562
+
563
+ def pool_single_span(
564
+ self,
565
+ hidden_states: torch.Tensor, # (seq, hidden) - single sample
566
+ span: Tuple[int, int],
567
+ bbox: Optional[torch.Tensor] = None, # (seq, 4)
568
+ ) -> torch.Tensor:
569
+ """
570
+ Pool a single span to get its embedding.
571
+ Returns mean of start and end tokens (hidden_size,).
572
+ """
573
+ start, end = span
574
+ if start >= hidden_states.shape[0] or end > hidden_states.shape[0]:
575
+ return torch.zeros(self.hidden_size, device=hidden_states.device)
576
+
577
+ h_start = hidden_states[start, :]
578
+ h_end = hidden_states[end - 1, :]
579
+
580
+ # Mean of start and end for Q embedding (simpler than concat for matching)
581
+ return (h_start + h_end) / 2
582
+
583
+ def get_span_bbox(
584
+ self,
585
+ bbox: torch.Tensor, # (seq, 4)
586
+ span: Tuple[int, int],
587
+ ) -> torch.Tensor:
588
+ """Get bounding box for a span (union of start and end token bboxes)."""
589
+ start, end = span
590
+ bbox_start = bbox[start, :]
591
+ bbox_end = bbox[end - 1, :]
592
+
593
+ span_bbox = torch.stack([
594
+ torch.min(bbox_start[0], bbox_end[0]), # x1
595
+ torch.min(bbox_start[1], bbox_end[1]), # y1
596
+ torch.max(bbox_start[2], bbox_end[2]), # x2
597
+ torch.max(bbox_start[3], bbox_end[3]), # y2
598
+ ])
599
+ return span_bbox
600
+
601
+ def extract_span_text(
602
+ self,
603
+ span: Tuple[int, int],
604
+ tokens: List[str],
605
+ subword_to_word: Dict[int, int],
606
+ ) -> str:
607
+ """
608
+ Extract text from a span using tokens and subword-to-word mapping.
609
+
610
+ Args:
611
+ span: (start, end) subword indices (exclusive end)
612
+ tokens: List of word tokens
613
+ subword_to_word: Dict mapping subword idx -> word idx
614
+
615
+ Returns:
616
+ Text string for the span
617
+ """
618
+ start, end = span
619
+ # Get word indices for this span
620
+ word_indices = set()
621
+ for subword_idx in range(start, end):
622
+ if subword_idx in subword_to_word:
623
+ word_indices.add(subword_to_word[subword_idx])
624
+
625
+ if not word_indices:
626
+ return ""
627
+
628
+ # Get contiguous word range
629
+ min_word = min(word_indices)
630
+ max_word = max(word_indices)
631
+
632
+ # Extract and join tokens
633
+ span_tokens = tokens[min_word:max_word + 1]
634
+ return " ".join(span_tokens)
635
+
636
+ def extract_qa_regions(
637
+ self,
638
+ token_logits: torch.Tensor, # (batch, seq, num_labels)
639
+ attention_mask: torch.Tensor, # (batch, seq)
640
+ ) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]:
641
+ """
642
+ Extract Q and A regions from Step 1 token predictions.
643
+
644
+ This enables the cascading architecture where Step 2 uses
645
+ Step 1's predictions instead of ground truth spans.
646
+
647
+ Returns:
648
+ q_regions: List of Q spans per batch item [(start, end), ...]
649
+ a_regions: List of A spans per batch item [(start, end), ...]
650
+ """
651
+ batch_size = token_logits.shape[0]
652
+ preds = token_logits.argmax(dim=-1) # (batch, seq)
653
+
654
+ q_regions = []
655
+ a_regions = []
656
+
657
+ for b in range(batch_size):
658
+ sample_q_regions = []
659
+ sample_a_regions = []
660
+
661
+ seq_len = attention_mask[b].sum().item()
662
+ pred_seq = preds[b, :int(seq_len)].cpu().tolist()
663
+
664
+ # Extract Q spans (B-Q=1, I-Q=2)
665
+ current_span = None
666
+ for i, label in enumerate(pred_seq):
667
+ if label == 1: # B-Q
668
+ if current_span is not None:
669
+ sample_q_regions.append(current_span)
670
+ current_span = (i, i + 1)
671
+ elif label == 2: # I-Q
672
+ if current_span is not None:
673
+ current_span = (current_span[0], i + 1)
674
+ else:
675
+ if current_span is not None:
676
+ sample_q_regions.append(current_span)
677
+ current_span = None
678
+ if current_span is not None:
679
+ sample_q_regions.append(current_span)
680
+
681
+ # Extract A spans (B-A=3, I-A=4)
682
+ current_span = None
683
+ for i, label in enumerate(pred_seq):
684
+ if label == 3: # B-A
685
+ if current_span is not None:
686
+ sample_a_regions.append(current_span)
687
+ current_span = (i, i + 1)
688
+ elif label == 4: # I-A
689
+ if current_span is not None:
690
+ current_span = (current_span[0], i + 1)
691
+ else:
692
+ if current_span is not None:
693
+ sample_a_regions.append(current_span)
694
+ current_span = None
695
+ if current_span is not None:
696
+ sample_a_regions.append(current_span)
697
+
698
+ q_regions.append(sample_q_regions)
699
+ a_regions.append(sample_a_regions)
700
+
701
+ return q_regions, a_regions
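Reviewer note: tiny illustration of the BIO decoding above (labels hand-crafted; extract_qa_regions does not use self, so it can be exercised without loading the encoders).

import torch

logits = torch.full((1, 5, NUM_LABELS), -10.0)
for i, lab in enumerate(["O", "B-Q", "I-Q", "O", "B-A"]):
    logits[0, i, LABEL2ID[lab]] = 10.0
mask = torch.ones(1, 5)
q, a = DEXTRLiLT.extract_qa_regions(None, logits, mask)
print(q, a)   # [[(1, 3)]] [[(4, 5)]]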
702
+
703
+ def match_regions_to_gt(
704
+ self,
705
+ pred_regions: List[Tuple[int, int]],
706
+ gt_regions: List[Tuple[int, int]],
707
+ ) -> Tuple[List[int], List[int]]:
708
+ """
709
+ Match predicted regions to GT regions by overlap.
710
+
711
+ Returns:
712
+ gt_to_pred: For each GT region, index of best matching pred region (-1 if none)
713
+ pred_to_gt: For each pred region, index of best matching GT region (-1 if none)
714
+ """
715
+ gt_to_pred = []
716
+ for gt_start, gt_end in gt_regions:
717
+ best_pred_idx = -1
718
+ best_overlap = 0
719
+ for pred_idx, (pred_start, pred_end) in enumerate(pred_regions):
720
+ overlap_start = max(gt_start, pred_start)
721
+ overlap_end = min(gt_end, pred_end)
722
+ overlap = max(0, overlap_end - overlap_start)
723
+ if overlap > best_overlap:
724
+ best_overlap = overlap
725
+ best_pred_idx = pred_idx
726
+ gt_to_pred.append(best_pred_idx)
727
+
728
+ pred_to_gt = []
729
+ for pred_start, pred_end in pred_regions:
730
+ best_gt_idx = -1
731
+ best_overlap = 0
732
+ for gt_idx, (gt_start, gt_end) in enumerate(gt_regions):
733
+ overlap_start = max(gt_start, pred_start)
734
+ overlap_end = min(gt_end, pred_end)
735
+ overlap = max(0, overlap_end - overlap_start)
736
+ if overlap > best_overlap:
737
+ best_overlap = overlap
738
+ best_gt_idx = gt_idx
739
+ pred_to_gt.append(best_gt_idx)
740
+
741
+ return gt_to_pred, pred_to_gt
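Reviewer note: worked micro-example of the overlap matching (spans are hypothetical). GT span (4, 9) overlaps predicted (3, 8) by 4 tokens and (10, 12) not at all, so it maps to index 0 and the second prediction stays unmatched.

gt_to_pred, pred_to_gt = DEXTRLiLT.match_regions_to_gt(None, [(3, 8), (10, 12)], [(4, 9)])
print(gt_to_pred, pred_to_gt)   # [0] [0, -1]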
742
+
743
+ def forward(
744
+ self,
745
+ input_ids: torch.Tensor,
746
+ attention_mask: torch.Tensor,
747
+ bbox: torch.Tensor,
748
+ tokens: Optional[List[List[str]]] = None,
749
+ subword_to_word: Optional[List[Dict[int, int]]] = None,
750
+ query_texts: Optional[List[str]] = None,
751
+ gt_answer_spans: Optional[List[List[Tuple[int, int]]]] = None,
752
+ gt_question_spans: Optional[List[List[Optional[Tuple[int, int]]]]] = None,
753
+ target_field_idx: Optional[List[int]] = None,
754
+ training: bool = True,
755
+ ) -> Dict[str, torch.Tensor]:
756
+ """
757
+ Forward pass with CASCADING architecture and ZERO-SHOT query matching.
758
+
759
+ Key features:
760
+ - Step 1: Token classification to predict Q/A regions
761
+ - Step 1.5: Q-A linking using LiLT embeddings + spatial features
762
+ - Step 2: ZERO-SHOT query→question matching using Sentence Transformer
763
+ - GT spans are only used to compute labels (which predicted region is correct)
764
+
765
+ Args:
766
+ input_ids: (batch, seq) token IDs
767
+ attention_mask: (batch, seq) attention mask
768
+ bbox: (batch, seq, 4) bounding boxes
769
+ tokens: List of word tokens per batch item (for zero-shot Q text encoding)
770
+ subword_to_word: List of dicts mapping subword idx to word idx
771
+ query_texts: List of raw query strings (for sentence transformer)
772
+ gt_answer_spans: GT answer spans per batch item (for loss labels only)
773
+ gt_question_spans: GT question spans per batch item (for loss labels only)
774
+ target_field_idx: Index of GT field that query should match (for loss labels)
775
+ training: whether in training mode (affects Q masking for Q-A linker)
776
+ """
777
+ batch_size = input_ids.shape[0]
778
+ device = input_ids.device
779
+
780
+ # Encode document
781
+ hidden_states = self.encode(input_ids, attention_mask, bbox)
782
+
783
+ # Step 1: Token Classification
784
+ token_logits = self.token_classifier(hidden_states)
785
+
786
+ # Step 3: Table Head
787
+ table_outputs = self.table_head(hidden_states)
788
+
789
+ outputs = {
790
+ "hidden_states": hidden_states,
791
+ "token_logits": token_logits,
792
+ **table_outputs,
793
+ }
794
+
795
+ # Step 1.5 + Step 2: Q-A Linking and Query-Question Matching
796
+ # CASCADING: Extract Q/A regions from Step 1 PREDICTIONS
797
+ if query_texts is not None:
798
+ # Extract predicted Q and A regions from Step 1 output
799
+ pred_q_regions, pred_a_regions = self.extract_qa_regions(token_logits, attention_mask)
800
+ outputs["pred_q_regions"] = pred_q_regions
801
+ outputs["pred_a_regions"] = pred_a_regions
802
+
803
+ # Encode query using sentence transformer
804
+ query_emb = self.encode_texts(query_texts, device)
805
+ outputs["query_embedding"] = query_emb
806
+
807
+ # Process each sample in batch using PREDICTED regions
808
+ all_q_embeds_lilt = [] # Q embeddings from LiLT (768 dim) for QA Linker
809
+ all_q_embeds_st = [] # Q embeddings from sentence transformer (384 dim) for query matching
810
+ all_a_embeds = [] # A embeddings from predicted regions (768 dim)
811
+ all_q_bboxes = [] # Q bboxes for QA linker
812
+ all_a_bboxes = [] # A bboxes for QA linker
813
+ all_pred_q_from_a = [] # Predicted Q from A (for QuestionPredictor loss)
814
+ all_real_q_from_pred = [] # Real Q from predicted regions (for Q masking loss)
815
+ all_real_q_embeds = [] # Real Q embeds from GT (for aux loss)
816
+ all_gt_to_pred_a_idx = [] # Maps GT field idx to predicted A region idx
817
+
818
+ for b in range(batch_size):
819
+ sample_pred_q = pred_q_regions[b] # Predicted Q spans from Step 1
820
+ sample_pred_a = pred_a_regions[b] # Predicted A spans from Step 1
821
+
822
+ # Skip if no predicted regions
823
+ if len(sample_pred_a) == 0:
824
+ all_q_embeds_lilt.append(None)
825
+ all_q_embeds_st.append(None)
826
+ all_a_embeds.append(None)
827
+ all_q_bboxes.append(None)
828
+ all_a_bboxes.append(None)
829
+ all_pred_q_from_a.append(None)
830
+ all_real_q_from_pred.append(None)
831
+ all_real_q_embeds.append(None)
832
+ all_gt_to_pred_a_idx.append(None)
833
+ continue
834
+
835
+ # Pool A embeddings from predicted regions
836
+ sample_a_embeds = []
837
+ sample_a_bboxes = []
838
+ for a_span in sample_pred_a:
839
+ a_emb = self.pool_single_span(hidden_states[b], a_span, bbox[b])
840
+ sample_a_embeds.append(a_emb)
841
+ sample_a_bboxes.append(self.get_span_bbox(bbox[b], a_span))
842
+
843
+ # Two Q representations:
844
+ # 1. q_embeds_lilt (768) - LiLT pooled, for QA Linker
845
+ # 2. q_embeds_st (384) - Sentence transformer, for query matching
846
+ sample_q_embeds_lilt = [] # For QA Linker (768 dim)
847
+ sample_q_embeds_st = [] # For query matching (384 dim)
848
+ sample_q_bboxes = []
849
+ sample_real_q_from_pred = [] # Real Q embeds (ST) for QuestionPredictor loss
850
+ sample_pred_q_from_a = [] # Predicted Q embeds from A (for loss)
851
+
852
+ if len(sample_pred_q) > 0:
853
+ # We have predicted Q regions
854
+ # Get Q texts for sentence transformer encoding
855
+ q_texts = []
856
+ if tokens is not None and subword_to_word is not None:
857
+ for q_span in sample_pred_q:
858
+ q_text = self.extract_span_text(q_span, tokens[b], subword_to_word[b])
859
+ q_texts.append(q_text if q_text else "unknown")
860
+ # Batch encode all Q texts with sentence transformer
861
+ real_q_embeds_st = self.encode_texts(q_texts, device) # (num_q, 384)
862
+ else:
863
+ real_q_embeds_st = None
864
+
865
+ for i, q_span in enumerate(sample_pred_q):
866
+ # LiLT embedding for QA Linker
867
+ q_emb_lilt = self.pool_single_span(hidden_states[b], q_span, bbox[b])
868
+ sample_q_embeds_lilt.append(q_emb_lilt)
869
+
870
+ q_bbox = self.get_span_bbox(bbox[b], q_span)
871
+ sample_q_bboxes.append(q_bbox)
872
+
873
+ # Sentence transformer embedding for query matching
874
+ if real_q_embeds_st is not None:
875
+ real_q_emb_st = real_q_embeds_st[i] # 384 dim
876
+ else:
877
+ real_q_emb_st = None
878
+
879
+ # Q MASKING: during training, randomly mask Q and use predictor
880
+ should_mask = training and (torch.rand(1).item() < self.q_mask_prob)
881
+
882
+ if should_mask and i < len(sample_a_embeds):
883
+ # Mask: use QuestionPredictor instead of real Q for query matching
884
+ pred_q_emb = self.question_predictor(sample_a_embeds[i]) # 384 dim
885
+ sample_q_embeds_st.append(pred_q_emb)
886
+ sample_real_q_from_pred.append(real_q_emb_st) # Real Q for loss
887
+ sample_pred_q_from_a.append(pred_q_emb) # Predicted Q for loss
888
+ else:
889
+ # No mask: use real Q (sentence transformer encoded)
890
+ sample_q_embeds_st.append(real_q_emb_st if real_q_emb_st is not None else torch.zeros(self.query_embed_dim, device=device))
891
+ sample_real_q_from_pred.append(real_q_emb_st)
892
+ sample_pred_q_from_a.append(None) # No prediction, no loss
893
+ else:
894
+ # No Q regions predicted - use QuestionPredictor on all A
895
+ for a_emb in sample_a_embeds:
896
+ pred_q = self.question_predictor(a_emb) # 384 dim
897
+ sample_q_embeds_st.append(pred_q)
898
+ sample_pred_q_from_a.append(pred_q)
899
+ sample_real_q_from_pred.append(None) # No real Q from prediction
900
+ # No LiLT Q embeds when no Q predicted
901
+ sample_q_embeds_lilt = []
902
+ # Use A bboxes as proxy for Q bboxes
903
+ sample_q_bboxes = sample_a_bboxes.copy()
904
+
905
+ # Store real Q embeds from GT (for aux loss when Q was masked)
906
+ # Encode GT Q text with sentence transformer (384 dim)
907
+ sample_real_q = []
908
+ if gt_question_spans is not None and b < len(gt_question_spans) and tokens is not None and subword_to_word is not None:
909
+ gt_q_texts = []
910
+ gt_q_valid_indices = []
911
+ for idx, gt_q_span in enumerate(gt_question_spans[b]):
912
+ if gt_q_span is not None:
913
+ q_text = self.extract_span_text(gt_q_span, tokens[b], subword_to_word[b])
914
+ gt_q_texts.append(q_text if q_text else "unknown")
915
+ gt_q_valid_indices.append(idx)
916
+
917
+ # Batch encode GT Q texts
918
+ if gt_q_texts:
919
+ gt_q_embeds = self.encode_texts(gt_q_texts, device) # (num_valid, 384)
920
+ embed_idx = 0
921
+ for idx, gt_q_span in enumerate(gt_question_spans[b]):
922
+ if gt_q_span is not None:
923
+ sample_real_q.append(gt_q_embeds[embed_idx])
924
+ embed_idx += 1
925
+ else:
926
+ sample_real_q.append(None)
927
+ else:
928
+ sample_real_q = [None] * len(gt_question_spans[b])
929
+
930
+ # Match GT A spans to predicted A regions (for label computation)
931
+ gt_to_pred_a = None
932
+ if gt_answer_spans is not None and b < len(gt_answer_spans):
933
+ gt_a_spans = gt_answer_spans[b]
934
+ gt_to_pred_a, _ = self.match_regions_to_gt(sample_pred_a, gt_a_spans)
935
+
936
+ all_q_embeds_lilt.append(torch.stack(sample_q_embeds_lilt) if sample_q_embeds_lilt else None)
937
+ all_q_embeds_st.append(torch.stack(sample_q_embeds_st) if sample_q_embeds_st else None)
938
+ all_a_embeds.append(torch.stack(sample_a_embeds) if sample_a_embeds else None)
939
+ all_q_bboxes.append(torch.stack(sample_q_bboxes) if sample_q_bboxes else None)
940
+ all_a_bboxes.append(torch.stack(sample_a_bboxes) if sample_a_bboxes else None)
941
+ # For Q predictor loss: parallel lists with None for non-masked positions
942
+ all_pred_q_from_a.append(sample_pred_q_from_a if sample_pred_q_from_a else None)
943
+ all_real_q_from_pred.append(sample_real_q_from_pred if sample_real_q_from_pred else None)
944
+ all_real_q_embeds.append(sample_real_q if sample_real_q else None)
945
+ all_gt_to_pred_a_idx.append(gt_to_pred_a)
946
+
947
+ # Filter out None entries and compute outputs
948
+ # Use LiLT Q embeds for valid check (QA Linker needs them)
949
+ valid_indices_lilt = [i for i, q in enumerate(all_q_embeds_lilt) if q is not None]
950
+ # Also track valid ST Q embeds for query matching
951
+ valid_indices_st = [i for i, q in enumerate(all_q_embeds_st) if q is not None]
952
+
953
+ if valid_indices_lilt:
954
+ # Get max spans for padding (use LiLT Q for QA Linker)
955
+ max_q_spans_lilt = max(all_q_embeds_lilt[i].shape[0] for i in valid_indices_lilt)
956
+ max_a_spans = max(all_a_embeds[i].shape[0] for i in valid_indices_lilt)
957
+
958
+ # Create padded tensors for Q (LiLT, 768 dim) - for QA Linker
959
+ q_embeds_lilt_padded = torch.zeros(batch_size, max_q_spans_lilt, self.hidden_size, device=device)
960
+ q_bboxes_padded = torch.zeros(batch_size, max_q_spans_lilt, 4, device=device)
961
+ q_span_mask = torch.zeros(batch_size, max_q_spans_lilt, dtype=torch.bool, device=device)
962
+
963
+ # Create padded tensors for A (LiLT, 768 dim)
964
+ a_embeds_padded = torch.zeros(batch_size, max_a_spans, self.hidden_size, device=device)
965
+ a_bboxes_padded = torch.zeros(batch_size, max_a_spans, 4, device=device)
966
+ a_span_mask = torch.zeros(batch_size, max_a_spans, dtype=torch.bool, device=device)
967
+
968
+ for b in valid_indices_lilt:
969
+ nq = all_q_embeds_lilt[b].shape[0]
970
+ q_embeds_lilt_padded[b, :nq] = all_q_embeds_lilt[b]
971
+ q_bboxes_padded[b, :nq] = all_q_bboxes[b]
972
+ q_span_mask[b, :nq] = True
973
+
974
+ na = all_a_embeds[b].shape[0]
975
+ a_embeds_padded[b, :na] = all_a_embeds[b]
976
+ a_bboxes_padded[b, :na] = all_a_bboxes[b]
977
+ a_span_mask[b, :na] = True
978
+
979
+ # Step 1.5: Q-A Linker - predict which Q links to which A (uses LiLT embeds)
980
+ qa_link_scores_list = []
981
+ for b in valid_indices_lilt:
982
+ nq = all_q_embeds_lilt[b].shape[0]
983
+ na = all_a_embeds[b].shape[0]
984
+ if nq > 0 and na > 0:
985
+ link_scores = self.qa_linker(
986
+ all_q_embeds_lilt[b], # (num_q, 768) LiLT
987
+ all_a_embeds[b], # (num_a, 768) LiLT
988
+ all_q_bboxes[b], # (num_q, 4)
989
+ all_a_bboxes[b], # (num_a, 4)
990
+ )
991
+ qa_link_scores_list.append(link_scores)
992
+ else:
993
+ qa_link_scores_list.append(None)
994
+
995
+ outputs["qa_link_scores"] = qa_link_scores_list
996
+
997
+ # Create padded tensors for Q (sentence transformer, 384 dim) - for query matching
998
+ if valid_indices_st:
999
+ max_q_spans_st = max(all_q_embeds_st[i].shape[0] for i in valid_indices_st)
1000
+ q_embeds_st_padded = torch.zeros(batch_size, max_q_spans_st, self.query_embed_dim, device=device)
1001
+ q_span_mask_st = torch.zeros(batch_size, max_q_spans_st, dtype=torch.bool, device=device)
1002
+
1003
+ for b in valid_indices_st:
1004
+ nq = all_q_embeds_st[b].shape[0]
1005
+ q_embeds_st_padded[b, :nq] = all_q_embeds_st[b]
1006
+ q_span_mask_st[b, :nq] = True
1007
+
1008
+ outputs["q_embeds_st"] = q_embeds_st_padded # For query matching
1009
+ outputs["q_span_mask_st"] = q_span_mask_st
1010
+
1011
+ # GT-based QA link scores (for training - uses GT spans directly, LiLT embeddings)
1012
+ if training and gt_question_spans is not None and gt_answer_spans is not None:
1013
+ gt_qa_link_scores_list = []
1014
+ gt_valid_q_indices_list = [] # Track which Q indices are valid (not None)
1015
+ for b in range(batch_size):
1016
+ gt_q_spans = gt_question_spans[b] if b < len(gt_question_spans) else []
1017
+ gt_a_spans = gt_answer_spans[b] if b < len(gt_answer_spans) else []
1018
+
1019
+ # Filter valid Q spans (not None) and track indices
1020
+ valid_q_indices = [i for i, q in enumerate(gt_q_spans) if q is not None]
1021
+ valid_q_spans = [gt_q_spans[i] for i in valid_q_indices]
1022
+
1023
+ if len(valid_q_spans) > 0 and len(gt_a_spans) > 0:
1024
+ # Pool embeddings from GT spans (LiLT, 768 dim)
1025
+ gt_q_embeds = torch.stack([
1026
+ self.pool_single_span(hidden_states[b], q_span, bbox[b])
1027
+ for q_span in valid_q_spans
1028
+ ])
1029
+ gt_a_embeds = torch.stack([
1030
+ self.pool_single_span(hidden_states[b], a_span, bbox[b])
1031
+ for a_span in gt_a_spans
1032
+ ])
1033
+ gt_q_bboxes = torch.stack([
1034
+ self.get_span_bbox(bbox[b], q_span)
1035
+ for q_span in valid_q_spans
1036
+ ])
1037
+ gt_a_bboxes = torch.stack([
1038
+ self.get_span_bbox(bbox[b], a_span)
1039
+ for a_span in gt_a_spans
1040
+ ])
1041
+
1042
+ # Compute QA link scores on GT
1043
+ gt_link_scores = self.qa_linker(
1044
+ gt_q_embeds, gt_a_embeds, gt_q_bboxes, gt_a_bboxes
1045
+ )
1046
+ gt_qa_link_scores_list.append(gt_link_scores)
1047
+ gt_valid_q_indices_list.append(valid_q_indices)
1048
+ else:
1049
+ gt_qa_link_scores_list.append(None)
1050
+ gt_valid_q_indices_list.append(None)
1051
+
1052
+ outputs["gt_qa_link_scores"] = gt_qa_link_scores_list
1053
+ outputs["gt_valid_q_indices"] = gt_valid_q_indices_list
1054
+
1055
+ # Step 2: Query-Question Matching (ZERO-SHOT with Sentence Transformer)
1056
+ # Use pre-computed q_embeds_st for query matching
1057
+ if valid_indices_st and "q_embeds_st" in outputs:
1058
+ q_embeds_st_padded = outputs["q_embeds_st"]
1059
+ q_span_mask_st = outputs["q_span_mask_st"]
1060
+
1061
+ # Compute cosine similarity between query and Q embeddings
1062
+ query_norm = F.normalize(query_emb, dim=-1)
1063
+ q_norm = F.normalize(q_embeds_st_padded, dim=-1)
1064
+ match_scores = torch.bmm(q_norm, query_norm.unsqueeze(-1)).squeeze(-1)
1065
+ match_scores = match_scores.masked_fill(~q_span_mask_st, float('-inf'))
1066
+ outputs["match_scores"] = match_scores
1067
+ outputs["match_span_mask"] = q_span_mask_st
1068
+
1069
+ # Store LiLT embeddings for QA Linker
1070
+ if valid_indices_lilt:
1071
+ outputs["q_span_mask"] = q_span_mask
1072
+ outputs["a_span_mask"] = a_span_mask
1073
+ outputs["q_embeds"] = q_embeds_lilt_padded
1074
+ outputs["a_embeds"] = a_embeds_padded
1075
+ outputs["q_bboxes"] = q_bboxes_padded
1076
+ outputs["a_bboxes"] = a_bboxes_padded
1077
+
1078
+ # Store for loss computation (Q masking / QuestionPredictor)
1079
+ outputs["pred_q_from_a"] = all_pred_q_from_a
1080
+ outputs["real_q_from_pred"] = all_real_q_from_pred # Real Q (ST) when masked
1081
+ outputs["real_q_embeds"] = all_real_q_embeds # GT Q embeds (ST)
1082
+ outputs["gt_to_pred_a_idx"] = all_gt_to_pred_a_idx
1083
+ outputs["valid_batch_indices_lilt"] = valid_indices_lilt
1084
+ outputs["valid_batch_indices_st"] = valid_indices_st
1085
+
1086
+ return outputs
1087
+
1088
+
1089
+ def compute_loss(
1090
+ outputs: Dict[str, torch.Tensor],
1091
+ token_labels: torch.Tensor,
1092
+ attention_mask: torch.Tensor,
1093
+ table_labels: Optional[torch.Tensor] = None,
1094
+ row_labels: Optional[torch.Tensor] = None,
1095
+ header_labels: Optional[torch.Tensor] = None,
1096
+ match_labels: Optional[torch.Tensor] = None,
1097
+ col_labels: Optional[torch.Tensor] = None,
1098
+ class_weights: Optional[torch.Tensor] = None,
1099
+ qa_link_labels: Optional[List[torch.Tensor]] = None, # NEW: Q-A link labels
1100
+ ) -> Dict[str, torch.Tensor]:
1101
+ """Compute joint loss for all steps including Q-A linking."""
1102
+ device = token_labels.device
1103
+ losses = {}
1104
+
1105
+ # Step 1: Token Classification
1106
+ token_logits = outputs["token_logits"]
1107
+ token_logits_flat = token_logits.view(-1, token_logits.shape[-1])
1108
+ token_labels_flat = token_labels.view(-1)
1109
+ attn_flat = attention_mask.view(-1).bool()
1110
+
1111
+ if class_weights is not None:
1112
+ step1_loss = F.cross_entropy(
1113
+ token_logits_flat[attn_flat],
1114
+ token_labels_flat[attn_flat],
1115
+ weight=class_weights,
1116
+ )
1117
+ else:
1118
+ step1_loss = F.cross_entropy(
1119
+ token_logits_flat[attn_flat],
1120
+ token_labels_flat[attn_flat],
1121
+ )
1122
+ losses["step1_loss"] = step1_loss
1123
+
1124
+ # Step 1.5: Q-A Linker Loss (uses GT spans directly - no dependency on Step 1 predictions)
1125
+ if "gt_qa_link_scores" in outputs and qa_link_labels is not None:
1126
+ qa_link_losses = []
1127
+ gt_scores_list = outputs["gt_qa_link_scores"]
1128
+ gt_valid_q_indices = outputs.get("gt_valid_q_indices", [None] * len(gt_scores_list))
1129
+ for scores, labels, valid_indices in zip(gt_scores_list, qa_link_labels, gt_valid_q_indices):
1130
+ if scores is None or labels is None or valid_indices is None:
1131
+ continue
1132
+ # Filter labels to only include valid Q indices (matching scores shape)
1133
+ filtered_labels = labels[valid_indices] if len(valid_indices) > 0 else labels
1134
+ if scores.numel() > 0 and filtered_labels.numel() > 0:
1135
+ # scores: (num_valid_q, num_gt_a), filtered_labels: (num_valid_q,)
1136
+ num_a = scores.shape[1]
1137
+ valid_mask = (filtered_labels >= 0) & (filtered_labels < num_a)
1138
+ if valid_mask.any():
1139
+ link_loss = F.cross_entropy(scores[valid_mask], filtered_labels[valid_mask])
1140
+ qa_link_losses.append(link_loss)
1141
+ if qa_link_losses:
1142
+ losses["qa_link_loss"] = torch.stack(qa_link_losses).mean()
1143
+ else:
1144
+ losses["qa_link_loss"] = torch.tensor(0.0, device=device)
1145
+ else:
1146
+ losses["qa_link_loss"] = torch.tensor(0.0, device=device)
1147
+
1148
+ # Question Predictor Loss (MSE between predicted Q and real Q in Sentence Transformer space)
1149
+ # QuestionPredictor learns: A_lilt_embed (768) → Q_st_embed (384)
1150
+ # Training: when Q is masked, compare predicted Q with real Q (sentence transformer)
1151
+ if "pred_q_from_a" in outputs and "real_q_from_pred" in outputs:
1152
+ pred_q_from_a = outputs["pred_q_from_a"] # List of lists
1153
+ real_q_from_pred = outputs["real_q_from_pred"] # List of lists
1154
+
1155
+ q_predict_losses = []
1156
+ for batch_pred, batch_real in zip(pred_q_from_a, real_q_from_pred):
1157
+ if batch_pred is None or batch_real is None:
1158
+ continue
1159
+ for pred_q, real_q in zip(batch_pred, batch_real):
1160
+ if pred_q is not None and real_q is not None:
1161
+ # MSE loss to make predicted Q similar to real Q
1162
+ mse = F.mse_loss(pred_q, real_q.detach())
1163
+ q_predict_losses.append(mse)
1164
+
1165
+ if q_predict_losses:
1166
+ losses["q_predict_loss"] = torch.stack(q_predict_losses).mean()
1167
+ else:
1168
+ losses["q_predict_loss"] = torch.tensor(0.0, device=device)
1169
+ else:
1170
+ losses["q_predict_loss"] = torch.tensor(0.0, device=device)
1171
+
1172
+ # Step 2: Query-Question Matching - ZERO-SHOT (no loss)
1173
+ # Zero-shot Q text matching - no training needed
1174
+ # QuestionPredictor IS trained (for A→Q prediction when no Q regions)
1175
+ losses["step2_loss"] = torch.tensor(0.0, device=device)
1176
+
1177
+ # Step 3: Table Losses
1178
+ if table_labels is not None:
1179
+ table_logits = outputs["table_logits"]
1180
+ table_logits_flat = table_logits.view(-1, 2)
1181
+ table_labels_flat = table_labels.view(-1)
1182
+ table_det_loss = F.cross_entropy(
1183
+ table_logits_flat[attn_flat],
1184
+ table_labels_flat[attn_flat],
1185
+ )
1186
+ losses["table_det_loss"] = table_det_loss
1187
+ else:
1188
+ losses["table_det_loss"] = torch.tensor(0.0, device=device)
1189
+
1190
+ if header_labels is not None and table_labels is not None:
1191
+ header_logits = outputs["header_logits"]
1192
+ table_mask = (table_labels == 1) & attention_mask.bool()
1193
+ if table_mask.any():
1194
+ header_logits_flat = header_logits.view(-1, 2)
1195
+ header_labels_flat = header_labels.view(-1)
1196
+ table_mask_flat = table_mask.view(-1)
1197
+ header_loss = F.cross_entropy(
1198
+ header_logits_flat[table_mask_flat],
1199
+ header_labels_flat[table_mask_flat],
1200
+ )
1201
+ losses["header_loss"] = header_loss
1202
+ else:
1203
+ losses["header_loss"] = torch.tensor(0.0, device=device)
1204
+ else:
1205
+ losses["header_loss"] = torch.tensor(0.0, device=device)
1206
+
1207
+ if row_labels is not None and table_labels is not None:
1208
+ row_logits = outputs["row_logits"]
1209
+ table_mask = (table_labels == 1) & attention_mask.bool()
1210
+ if table_mask.any():
1211
+ row_logits_flat = row_logits.view(-1, 3)
1212
+ row_labels_flat = row_labels.view(-1)
1213
+ table_mask_flat = table_mask.view(-1)
1214
+ row_loss = F.cross_entropy(
1215
+ row_logits_flat[table_mask_flat],
1216
+ row_labels_flat[table_mask_flat],
1217
+ )
1218
+ losses["row_loss"] = row_loss
1219
+ else:
1220
+ losses["row_loss"] = torch.tensor(0.0, device=device)
1221
+ else:
1222
+ losses["row_loss"] = torch.tensor(0.0, device=device)
1223
+
1224
+ if col_labels is not None and table_labels is not None and header_labels is not None:
1225
+ col_scores = outputs["col_scores"]
1226
+ cell_mask = (table_labels == 1) & (header_labels == 0) & attention_mask.bool()
1227
+ if cell_mask.any():
1228
+ col_scores_flat = col_scores.view(-1, col_scores.shape[-1])
1229
+ col_labels_flat = col_labels.view(-1)
1230
+ cell_mask_flat = cell_mask.view(-1)
1231
+ col_loss = F.cross_entropy(
1232
+ col_scores_flat[cell_mask_flat],
1233
+ col_labels_flat[cell_mask_flat],
1234
+ )
1235
+ losses["col_loss"] = col_loss
1236
+ else:
1237
+ losses["col_loss"] = torch.tensor(0.0, device=device)
1238
+ else:
1239
+ losses["col_loss"] = torch.tensor(0.0, device=device)
1240
+
1241
+ # Total loss with weighted components
1242
+ total_loss = (
1243
+ losses["step1_loss"] + # Token classification (main task)
1244
+ losses["qa_link_loss"] * 0.3 + # Q-A linker (reduced to not compete with step1)
1245
+ losses["q_predict_loss"] * 0.2 + # Q prediction aux (reduced)
1246
+ losses["step2_loss"] * 0.5 + # Query-Q matching (reduced)
1247
+ losses["table_det_loss"] * 0.5 + # Table detection
1248
+ losses["header_loss"] * 0.3 + # Header detection
1249
+ losses["row_loss"] * 0.3 + # Row segmentation
1250
+ losses["col_loss"] * 0.3 # Column assignment
1251
+ )
1252
+ losses["loss"] = total_loss
1253
+
1254
+ return losses
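Reviewer note: a sketch of one training step wiring the forward pass into compute_loss. The batch fields and the collator are assumptions; only keys actually consumed by this file are shown.

import torch

model = DEXTRLiLT(max_seq_len=1024)
optim = torch.optim.AdamW(model.parameters(), lr=2e-5)

def training_step(batch):
    out = model(
        input_ids=batch["input_ids"], attention_mask=batch["attention_mask"],
        bbox=batch["bbox"], tokens=batch["tokens"],
        subword_to_word=batch["subword_to_word"], query_texts=batch["query_texts"],
        gt_answer_spans=batch["gt_answer_spans"],
        gt_question_spans=batch["gt_question_spans"], training=True,
    )
    losses = compute_loss(
        out, token_labels=batch["token_labels"], attention_mask=batch["attention_mask"],
        table_labels=batch.get("table_labels"), qa_link_labels=batch.get("qa_link_labels"),
    )
    optim.zero_grad()
    losses["loss"].backward()
    optim.step()
    return {k: v.item() for k, v in losses.items()}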
1255
+
1256
+
1257
+ def compute_contrastive_loss(
1258
+ query_embedding: torch.Tensor, # (batch, hidden)
1259
+ span_embeddings: torch.Tensor, # (batch, num_spans, span_dim)
1260
+ span_mask: torch.Tensor, # (batch, num_spans)
1261
+ match_labels: torch.Tensor, # (batch,) - index of correct span
1262
+ temperature: float = 0.07,
1263
+ hard_negative_weight: float = 2.0,
1264
+ ) -> torch.Tensor:
1265
+ """
1266
+ Contrastive loss for query-answer matching.
1267
+
1268
+ InfoNCE-style loss that:
1269
+ - Pulls query towards correct answer span
1270
+ - Pushes query away from all other spans (in-batch negatives)
1271
+ - Applies extra weight to hard negatives (high-scoring wrong answers)
1272
+
1273
+ Args:
1274
+ query_embedding: Query [CLS] embeddings
1275
+ span_embeddings: Pooled span embeddings [h_start; h_end; bbox]
1276
+ span_mask: Valid span mask
1277
+ match_labels: Index of correct span for each query
1278
+ temperature: Softmax temperature (lower = harder)
1279
+ hard_negative_weight: Extra weight for hard negatives
1280
+
1281
+ Returns:
1282
+ Contrastive loss scalar
1283
+ """
1284
+ batch_size = query_embedding.shape[0]
1285
+ device = query_embedding.device
1286
+
1287
+ if batch_size == 0:
1288
+ return torch.tensor(0.0, device=device)
1289
+
1290
+ # Project spans to same dimension as query if needed
1291
+ # (This is already done by QueryAnswerMatcher, so we use the projected scores)
1292
+ # Here we compute a simpler version using the match_scores from forward
1293
+
1294
+ # For now, use a margin-based contrastive approach
1295
+ # We want: sim(q, correct_span) > sim(q, wrong_span) + margin
1296
+
1297
+ losses = []
1298
+ for b in range(batch_size):
1299
+ valid_spans = span_mask[b].sum().item()
1300
+ if valid_spans <= 1:
1301
+ continue
1302
+
1303
+ correct_idx = match_labels[b].item()
1304
+ if correct_idx >= valid_spans:
1305
+ continue
1306
+
1307
+ # Get span embeddings for this sample
1308
+ spans = span_embeddings[b, :int(valid_spans), :] # (num_valid, dim)
1309
+ query = query_embedding[b] # (hidden,)
1310
+
1311
+ # Simple approach: normalize and compute similarities
1312
+ # The projection happens in QueryAnswerMatcher, so we compute raw similarities here
1313
+ # This loss is auxiliary to the CE loss
1314
+
1315
+ # Normalize for cosine similarity
1316
+ spans_norm = F.normalize(spans[:, :query.shape[0]], dim=-1) # Use first hidden_size dims
1317
+ query_norm = F.normalize(query, dim=-1)
1318
+
1319
+ # Compute similarities
1320
+ sims = torch.matmul(spans_norm, query_norm) / temperature # (num_valid,)
1321
+
1322
+ # Create target (one-hot)
1323
+ target = torch.zeros(int(valid_spans), device=device)
1324
+ target[correct_idx] = 1.0
1325
+
1326
+ # InfoNCE: -log(exp(sim_pos) / sum(exp(sim_all)))
1327
+ # With hard negative weighting
1328
+ weights = torch.ones(int(valid_spans), device=device)
1329
+ weights[correct_idx] = 1.0
1330
+
1331
+ # Hard negatives: spans with high similarity but wrong
1332
+ with torch.no_grad():
1333
+ hard_neg_mask = (sims > sims[correct_idx] - 0.5) & (torch.arange(int(valid_spans), device=device) != correct_idx)
1334
+ weights[hard_neg_mask] = hard_negative_weight
1335
+
1336
+ # Weighted softmax cross-entropy
1337
+ log_probs = F.log_softmax(sims, dim=0)
1338
+ loss = -log_probs[correct_idx]
1339
+
1340
+ # Add margin loss for hard negatives
1341
+ for i in range(int(valid_spans)):
1342
+ if i != correct_idx and hard_neg_mask[i]:
1343
+ margin_loss = F.relu(sims[i] - sims[correct_idx] + 0.3)
1344
+ loss = loss + 0.1 * margin_loss
1345
+
1346
+ losses.append(loss)
1347
+
1348
+ if not losses:
1349
+ return torch.tensor(0.0, device=device)
1350
+
1351
+ return torch.stack(losses).mean()
1352
+
1353
+
1354
+ def compute_loss_with_contrastive(
1355
+ outputs: Dict[str, torch.Tensor],
1356
+ token_labels: torch.Tensor,
1357
+ attention_mask: torch.Tensor,
1358
+ table_labels: Optional[torch.Tensor] = None,
1359
+ row_labels: Optional[torch.Tensor] = None,
1360
+ header_labels: Optional[torch.Tensor] = None,
1361
+ match_labels: Optional[torch.Tensor] = None,
1362
+ col_labels: Optional[torch.Tensor] = None,
1363
+ class_weights: Optional[torch.Tensor] = None,
1364
+ qa_link_labels: Optional[List[torch.Tensor]] = None,
1365
+ contrastive_weight: float = 0.5,
1366
+ ) -> Dict[str, torch.Tensor]:
1367
+ """
1368
+ Compute joint loss with contrastive learning for Step 2.
1369
+
1370
+ Adds InfoNCE-style contrastive loss to help learn better span representations.
1371
+ """
1372
+ # Get base losses
1373
+ losses = compute_loss(
1374
+ outputs=outputs,
1375
+ token_labels=token_labels,
1376
+ attention_mask=attention_mask,
1377
+ table_labels=table_labels,
1378
+ row_labels=row_labels,
1379
+ header_labels=header_labels,
1380
+ match_labels=match_labels,
1381
+ col_labels=col_labels,
1382
+ class_weights=class_weights,
1383
+ qa_link_labels=qa_link_labels,
1384
+ )
1385
+
1386
+ # Contrastive loss for Step 2 is DISABLED for zero-shot matching
1387
+ # Zero-shot matching - no contrastive loss needed
1388
+ losses["contrastive_loss"] = torch.tensor(0.0, device=token_labels.device)
1389
+
1390
+ return losses
1391
+
1392
+
1393
+ if __name__ == "__main__":
1394
+ print("Testing DEXTR LiLT model with CASCADING architecture...")
1395
+ print("=" * 60)
1396
+
1397
+ model = DEXTRLiLT(
1398
+ model_name="nielsr/lilt-xlm-roberta-base",
1399
+ max_seq_len=1024,
1400
+ )
1401
+
1402
+ total_params = sum(p.numel() for p in model.parameters())
1403
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
1404
+ print(f"Total parameters: {total_params:,}")
1405
+ print(f"Trainable parameters: {trainable_params:,}")
1406
+
1407
+ # Test forward
1408
+ batch_size = 2
1409
+ seq_len = 128
1410
1411
+
1412
+ input_ids = torch.randint(0, 1000, (batch_size, seq_len))
1413
+ attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
1414
+ # Create valid bboxes: [x1, y1, x2, y2] where x1<x2 and y1<y2, values in [0, 1000]
1415
+ x1 = torch.randint(0, 500, (batch_size, seq_len))
1416
+ y1 = torch.randint(0, 500, (batch_size, seq_len))
1417
+ x2 = x1 + torch.randint(10, 500, (batch_size, seq_len))
1418
+ y2 = y1 + torch.randint(10, 500, (batch_size, seq_len))
1419
+ # Clamp to valid range [0, 1000]
1420
+ bbox = torch.stack([x1, y1, x2.clamp(max=1000), y2.clamp(max=1000)], dim=-1)
1421
+ # Queries are raw text for the frozen sentence-transformer encoder
+ query_texts = ["invoice number", "total amount"]
1423
+
1424
+ # GT spans (only used for loss labels in cascading architecture)
1425
+ gt_answer_spans = [[(5, 10), (20, 25)], [(8, 12)]]
1426
+ gt_question_spans = [[(2, 5), (17, 20)], [(5, 8)]]
1427
+
1428
+ print("\nRunning forward pass (CASCADING architecture)...")
1429
+ print(" - Step 1: Token classification → predict Q/A/H/TABLE/O")
1430
+ print(" - Step 1.5: Extract predicted regions, QALinker pairs them")
1431
+ print(" - Step 2: Query matches to predicted Q embeddings")
1432
+ model.eval()
1433
+ with torch.no_grad():
1434
+ outputs = model(
1435
+ input_ids=input_ids,
1436
+ attention_mask=attention_mask,
1437
+ bbox=bbox,
1438
+ query_texts=query_texts,
1440
+ gt_answer_spans=gt_answer_spans,
1441
+ gt_question_spans=gt_question_spans,
1442
+ )
1443
+
1444
+ print(f"\nOutput shapes:")
1445
+ print(f" token_logits: {outputs['token_logits'].shape}")
1446
+ print(f" table_logits: {outputs['table_logits'].shape}")
1447
+ print(f" header_logits: {outputs['header_logits'].shape}")
1448
+ print(f" row_logits: {outputs['row_logits'].shape}")
1449
+ print(f" col_scores: {outputs['col_scores'].shape}")
1450
+ print(f" query_embedding: {outputs['query_embedding'].shape}")
1451
+
1452
+ # Predicted regions from Step 1
1453
+ print(f"\nPredicted regions from Step 1:")
1454
+ print(f" pred_q_regions: {outputs['pred_q_regions']}")
1455
+ print(f" pred_a_regions: {outputs['pred_a_regions']}")
1456
+
1457
+ # Match scores may not exist if no regions predicted
1458
+ if "match_scores" in outputs:
1459
+ print(f" match_scores: {outputs['match_scores'].shape}")
1460
+ else:
1461
+ print(" match_scores: None (no regions predicted)")
1462
+
1463
+ # Test loss
1464
+ print("\nTesting loss computation...")
1465
+ token_labels = torch.randint(0, NUM_LABELS, (batch_size, seq_len))
1466
+ table_labels = torch.randint(0, 2, (batch_size, seq_len))
1467
+ row_labels = torch.randint(0, 3, (batch_size, seq_len))
1468
+ header_labels = torch.randint(0, 2, (batch_size, seq_len))
1469
+
1470
+ # For cascading: match_labels should point to predicted Q region index
1471
+ # This would normally be computed by matching GT fields to predicted regions
1472
+ match_labels = None
1473
+ if "match_scores" in outputs and outputs["match_scores"].shape[1] > 0:
1474
+ match_labels = torch.zeros(batch_size, dtype=torch.long)
1475
+
1476
+ losses = compute_loss(
1477
+ outputs,
1478
+ token_labels=token_labels,
1479
+ attention_mask=attention_mask,
1480
+ table_labels=table_labels,
1481
+ row_labels=row_labels,
1482
+ header_labels=header_labels,
1483
+ match_labels=match_labels,
1484
+ )
1485
+
1486
+ print(f"\nLosses:")
1487
+ for name, value in losses.items():
1488
+ print(f" {name}: {value.item():.4f}")
1489
+
1490
+ print("\n" + "=" * 60)
1491
+ print("DEXTR LiLT CASCADING architecture test passed!")