# dextr-lilt / model_lilt.py
# Uploaded by satya007 with huggingface_hub (commit f34d690, verified).
"""
DEXTR with LiLT - 3-Step Document Extraction.
Architecture:
- Document encoder: LiLT (XLM-RoBERTa + Layout Transformer), extended to 1024 tokens
- Query encoder: Sentence Transformer (frozen, for semantic similarity)
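- Step 1: Token Classification (Q/A/T/H/O) - TRAINED
- Step 1.5: Q-A Linking (links Question spans to Answer spans) - TRAINED
- Step 2: Query-Question Matching (ZERO-SHOT) - NO TRAINING, except for the QuestionPredictor fallback used when no Question regions are found
- Step 3: Table Head (hierarchical with attention-based column assignment) - TRAINED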
Key insight: Query→Question matching is semantic similarity. Sentence transformers
provide proper semantic embeddings for matching queries to question text.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LiltModel, LiltConfig
from sentence_transformers import SentenceTransformer
from typing import Optional, Dict, List, Tuple
# Step 1 label vocabulary (BIO scheme): "O" marks tokens outside any entity,
# and each entity type has a B- (begin) and an I- (inside) tag.
LABEL2ID = {
    "O": 0,
    # Question
    "B-Q": 1,
    "I-Q": 2,
    # Answer
    "B-A": 3,
    "I-A": 4,
    # Header
    "B-H": 5,
    "I-H": 6,
    # Table
    "B-TABLE": 7,
    "I-TABLE": 8,
}
# Reverse lookup: integer id -> label string.
ID2LABEL = dict(zip(LABEL2ID.values(), LABEL2ID.keys()))
NUM_LABELS = len(LABEL2ID)
class TokenClassificationHead(nn.Module):
    """
    Step 1 head: per-token classifier over the Q/A/H/TABLE/O BIO label set.

    A two-layer MLP (hidden_size -> 256 -> num_labels) with dropout in front of
    each linear layer.
    """

    def __init__(self, hidden_size: int = 768, num_labels: int = NUM_LABELS, dropout: float = 0.1):
        super().__init__()
        # All layers live in one nn.Sequential named `classifier` so the
        # state_dict keys (classifier.1.*, classifier.4.*) stay checkpoint-compatible.
        layers = [
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_labels),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map (..., hidden_size) features to (..., num_labels) logits."""
        logits = self.classifier(hidden_states)
        return logits
class QALinker(nn.Module):
    """
    Step 1.5: Q-A Linker.

    Scores every (Question span, Answer span) pair in a document. The score is
    the sum of two terms:
      * a semantic term: cosine similarity of the projected Q and A embeddings;
      * a spatial term: a small MLP over 7 hand-crafted layout features.
    The sum is then multiplied by a learnable temperature.
    """

    def __init__(self, hidden_size: int = 768, dropout: float = 0.1):
        super().__init__()
        self.hidden_size = hidden_size
        # Project Q and A into a shared 256-d space for semantic matching.
        self.q_proj = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.a_proj = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        # Spatial feature scorer.
        # Features: x_dist, y_dist, is_right, is_below, is_same_line, width_ratio, height_ratio
        self.spatial_scorer = nn.Sequential(
            nn.Linear(7, 32),
            nn.GELU(),
            nn.Linear(32, 1),
        )
        # Learnable temperature stored in log space (exp(1.0) at init).
        self.log_temp = nn.Parameter(torch.tensor(1.0))

    def compute_spatial_features(
        self,
        q_bboxes: torch.Tensor,  # (num_q, 4) - [x1, y1, x2, y2]
        a_bboxes: torch.Tensor,  # (num_a, 4)
    ) -> torch.Tensor:
        """
        Compute layout features for all Q-A pairs.

        Coordinates are divided by 1000; NOTE(review): this assumes boxes use the
        0-1000 normalised range used elsewhere in this file.

        Args:
            q_bboxes: (num_q, 4) question boxes [x1, y1, x2, y2]; any numeric dtype.
            a_bboxes: (num_a, 4) answer boxes [x1, y1, x2, y2]; any numeric dtype.

        Returns:
            (num_q, num_a, 7) float32 tensor with features
            [x_dist, y_dist, is_right, is_below, is_same_line, width_ratio, height_ratio].
        """
        # Work in float32 regardless of the incoming dtype (long bboxes or float64
        # inputs) so the features always match spatial_scorer's parameters.
        q = q_bboxes.float()
        a = a_bboxes.float()

        # Q-side quantities get shape (num_q, 1), A-side (1, num_a); arithmetic
        # between them broadcasts to (num_q, num_a).
        q_x2 = q[:, 2:3]
        q_y2 = q[:, 3:4]
        q_cy = (q[:, 1:2] + q[:, 3:4]) / 2
        q_w = q[:, 2:3] - q[:, 0:1]
        q_h = q[:, 3:4] - q[:, 1:2]

        a_x1 = a[:, 0].unsqueeze(0)
        a_y1 = a[:, 1].unsqueeze(0)
        a_cy = ((a[:, 1] + a[:, 3]) / 2).unsqueeze(0)
        a_w = (a[:, 2] - a[:, 0]).unsqueeze(0)
        a_h = (a[:, 3] - a[:, 1]).unsqueeze(0)

        x_dist = (a_x1 - q_x2) / 1000.0  # Horizontal gap (positive = A right of Q)
        y_dist = (a_cy - q_cy).abs() / 1000.0  # Vertical centre distance
        is_right = (a_x1 > q_x2).float()  # A starts right of Q's right edge
        is_below = (a_y1 > q_y2).float()  # A starts below Q's bottom edge
        is_same_line = (y_dist < 0.03).float()  # Centres within 30 units vertically
        width_ratio = (a_w / (q_w + 1e-6)).clamp(0, 5) / 5.0  # Relative width, in [0, 1]
        height_ratio = (a_h / (q_h + 1e-6)).clamp(0, 5) / 5.0  # Relative height, in [0, 1]

        return torch.stack(
            [x_dist, y_dist, is_right, is_below, is_same_line, width_ratio, height_ratio],
            dim=-1,
        )

    def forward(
        self,
        q_embeds: torch.Tensor,  # (num_q, hidden)
        a_embeds: torch.Tensor,  # (num_a, hidden)
        q_bboxes: torch.Tensor,  # (num_q, 4)
        a_bboxes: torch.Tensor,  # (num_a, 4)
    ) -> torch.Tensor:
        """
        Compute Q-A link scores.

        Returns:
            (num_q, num_a) score matrix; row i holds the link logits of question i
            against every answer.
        """
        # Semantic similarity (cosine in the shared projection space).
        q_proj = F.normalize(self.q_proj(q_embeds), dim=-1)  # (num_q, 256)
        a_proj = F.normalize(self.a_proj(a_embeds), dim=-1)  # (num_a, 256)
        semantic_scores = torch.matmul(q_proj, a_proj.t())  # (num_q, num_a)

        # Spatial scores.
        spatial_feats = self.compute_spatial_features(q_bboxes, a_bboxes)  # (num_q, num_a, 7)
        spatial_scores = self.spatial_scorer(spatial_feats).squeeze(-1)  # (num_q, num_a)

        # Combine and scale by the learnable (clamped) temperature.
        temperature = self.log_temp.exp().clamp(min=0.1, max=10.0)
        return (semantic_scores + spatial_scores) * temperature
class QuestionPredictor(nn.Module):
    """
    Maps an Answer span embedding to a predicted Question embedding.

    It is used when a document has no explicit Question tokens. The input is a
    LiLT answer embedding (``input_dim``, 768 by default). The output lives in
    the sentence-transformer space (``output_dim``, 384 by default).

    Training idea: for Q-A pairs that do have a Question, the Question is
    randomly masked and this module is trained to predict it, so it learns
    what a Question "should look like" for a given Answer.
    """

    def __init__(self, input_dim: int = 768, output_dim: int = 384, dropout: float = 0.1):
        super().__init__()
        # Module order inside `predictor` is kept fixed (indices 0..5) so that
        # state_dict keys remain compatible with saved checkpoints.
        hidden_block = [
            nn.Linear(input_dim, input_dim),
            nn.LayerNorm(input_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        output_block = [
            nn.Linear(input_dim, output_dim),
            nn.LayerNorm(output_dim),
        ]
        self.predictor = nn.Sequential(*hidden_block, *output_block)

    def forward(self, answer_embed: torch.Tensor) -> torch.Tensor:
        """
        Predict a question embedding from an answer embedding.

        Args:
            answer_embed: (batch, input_dim) or (input_dim,) LiLT answer span embedding.

        Returns:
            (batch, output_dim) or (output_dim,) tensor in sentence-transformer space.
        """
        predicted = self.predictor(answer_embed)
        return predicted
class QueryAnswerMatcher(nn.Module):
    """
    Step 2 (trainable variant): score answer spans against a query embedding.

    - Spans are represented as [h_start; h_end], which keeps boundary
      information, optionally followed by ``num_bbox_features`` layout values.
    - Spans and query are projected to ``proj_dim`` and L2-normalised, so the
      raw score is a cosine similarity.
    - The cosine is multiplied by a learnable temperature. It is initialised
      to sqrt(proj_dim) and clamped to [0.1, 100] at use.
    - Padding spans (``span_mask`` False) receive -inf.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        proj_dim: int = 256,
        dropout: float = 0.1,
        use_bbox_features: bool = True,
        num_bbox_features: int = 6,  # x1, y1, x2, y2, width, height
    ):
        super().__init__()
        self.proj_dim = proj_dim
        self.use_bbox_features = use_bbox_features
        self.hidden_size = hidden_size

        # [h_start; h_end] contributes 2 * hidden_size; bbox features are optional.
        extra = num_bbox_features if use_bbox_features else 0
        span_input_dim = 2 * hidden_size + extra

        self.span_proj = self._projection(span_input_dim, proj_dim, dropout)
        self.query_proj = self._projection(hidden_size, proj_dim, dropout)

        initial_temperature = torch.tensor(proj_dim ** 0.5)
        self.log_temp = nn.Parameter(initial_temperature.log())

    @staticmethod
    def _projection(in_dim: int, out_dim: int, dropout: float) -> nn.Sequential:
        """Linear -> LayerNorm -> GELU -> Dropout projection block."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        span_embeddings: torch.Tensor,  # (batch, num_spans, span_input_dim)
        query_embedding: torch.Tensor,  # (batch, hidden)
        span_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Return (batch, num_spans) scores; entries where ``span_mask`` is falsy are -inf.
        """
        spans = F.normalize(self.span_proj(span_embeddings), dim=-1)  # (batch, num_spans, proj)
        query = F.normalize(self.query_proj(query_embedding), dim=-1).unsqueeze(1)  # (batch, 1, proj)

        temperature = self.log_temp.exp().clamp(min=0.1, max=100.0)
        similarity = torch.bmm(query, spans.transpose(1, 2)).squeeze(1)
        scores = similarity * temperature

        if span_mask is None:
            return scores
        return scores.masked_fill(~span_mask.bool(), float('-inf'))
class TableHead(nn.Module):
    """
    Step 3: Hierarchical Table Extraction Head.

    All sub-heads read the same per-token hidden states:
      1. Table detection   - 2-way logits per token (TABLE / O).
      2. Header detection  - 2-way logits per token (HEADER / CELL).
      3. Row segmentation  - 3-way logits per token (B-ROW / I-ROW / O).
      4. Column assignment - token-to-token scaled dot-product scores
         (cell -> header via attention).

    Args:
        hidden_size: Width of the incoming hidden states.
        dropout: Dropout applied in front of each classifier sub-head.
        col_dim: Width of the query/key projections used for column assignment.
            Scores are scaled by 1/sqrt(col_dim). Defaults to 128, the previously
            hard-coded value, so existing checkpoints still load.
    """

    def __init__(self, hidden_size: int = 768, dropout: float = 0.1, col_dim: int = 128):
        super().__init__()
        self.col_dim = col_dim
        self.table_detector = self._token_classifier(hidden_size, 2, dropout)
        self.header_detector = self._token_classifier(hidden_size, 2, dropout)
        self.row_tagger = self._token_classifier(hidden_size, 3, dropout)
        self.col_key_proj = nn.Linear(hidden_size, col_dim)
        self.col_query_proj = nn.Linear(hidden_size, col_dim)

    @staticmethod
    def _token_classifier(hidden_size: int, num_classes: int, dropout: float) -> nn.Sequential:
        """Dropout -> Linear(hidden, 256) -> GELU -> Linear(256, num_classes)."""
        return nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            hidden_states: (batch, seq, hidden_size) token states (3-D, required by bmm).

        Returns:
            Dict with ``table_logits`` (batch, seq, 2), ``header_logits``
            (batch, seq, 2), ``row_logits`` (batch, seq, 3) and ``col_scores``
            (batch, seq, seq). Row i of ``col_scores`` scores token i (as query)
            against every token j (as key). No padding mask is applied here.
        """
        table_logits = self.table_detector(hidden_states)
        header_logits = self.header_detector(hidden_states)
        row_logits = self.row_tagger(hidden_states)

        col_keys = self.col_key_proj(hidden_states)
        col_queries = self.col_query_proj(hidden_states)
        # Scale by the projection width actually in use, so the scale can never
        # drift out of sync with the Linear layers above.
        col_scores = torch.bmm(col_queries, col_keys.transpose(1, 2)) / (self.col_dim ** 0.5)

        return {
            "table_logits": table_logits,
            "header_logits": header_logits,
            "row_logits": row_logits,
            "col_scores": col_scores,
        }
class DEXTRLiLT(nn.Module):
    """
    DEXTR with LiLT: Multi-Step Document Extraction with Zero-Shot Query Matching.

    Architecture:
    - Step 1: Token Classification (Q/A/H/TABLE/O, BIO tags) - TRAINED
    - Step 1.5: Q-A Linker (links Question spans to Answer spans) - TRAINED
    - Step 2: Query -> Question Matching (ZERO-SHOT with Sentence Transformer).
      The QuestionPredictor fallback, used when no Q regions are found, is TRAINED.
    - Step 3: Table Head - TRAINED

    Uses SEPARATE encoders:
    - Document encoder: LiLT (XLM-RoBERTa + Layout Transformer) for layout-aware encoding
    - Query encoder: Sentence Transformer (frozen) for semantic similarity

    Zero-shot matching:
    - Query text is encoded with the frozen Sentence Transformer.
    - Question text (from predicted Q regions) is encoded with the same frozen
      Sentence Transformer.
    - Scores are direct cosine similarities (no trainable projections).
    - Fallback: if no Q regions are predicted, QuestionPredictor maps each LiLT
      Answer embedding into the sentence-transformer space.
    """

    def __init__(
        self,
        model_name: str = "nielsr/lilt-xlm-roberta-base",
        query_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
        max_seq_len: int = 1024,
        hidden_size: int = 768,
        dropout: float = 0.1,
        q_mask_prob: float = 0.3,  # Probability of masking Q during training (for Q-A linker)
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.hidden_size = hidden_size
        self.q_mask_prob = q_mask_prob

        # Document encoder: LiLT (layout-aware)
        if max_seq_len > 512:
            self.encoder = self._load_lilt_extended(model_name, max_seq_len)
        else:
            self.encoder = LiltModel.from_pretrained(model_name)

        # Query encoder: Sentence Transformer (frozen, zero-shot)
        # Provides proper semantic similarity for query->question matching
        self.query_encoder = SentenceTransformer(query_model_name)
        self.query_encoder.requires_grad_(False)  # Freeze for zero-shot
        self.query_embed_dim = self.query_encoder.get_sentence_embedding_dimension()
        print(f"Loaded frozen query encoder: {query_model_name} (dim={self.query_embed_dim})")

        # Step 1: Token Classification Head
        self.token_classifier = TokenClassificationHead(hidden_size, NUM_LABELS, dropout)
        # Step 1.5: Q-A Linker (links Q spans to A spans)
        self.qa_linker = QALinker(hidden_size, dropout)
        # Step 2 fallback for documents without explicit Q labels (e.g. receipts):
        # maps a LiLT A embedding (hidden_size) into sentence-transformer space.
        self.question_predictor = QuestionPredictor(hidden_size, self.query_embed_dim, dropout)
        # Step 3: Table Head
        self.table_head = TableHead(hidden_size, dropout)

    # ------------------------------------------------------------------ #
    # Encoder construction / encoding
    # ------------------------------------------------------------------ #
    def _load_lilt_extended(self, model_name: str, max_seq_len: int) -> LiltModel:
        """
        Load LiLT and grow its sequence-position tables to ``max_seq_len + 2`` rows.

        Pretrained rows are copied unchanged and new rows are drawn from
        N(0, 0.02). Both the text ``position_embeddings`` and the layout
        ``box_position_embeddings`` are extended, and each keeps its original
        ``padding_idx``. NOTE(review): the "+ 2" presumably accounts for the
        RoBERTa-style position offset; confirm against the tokenizer setup.
        """
        model = LiltModel.from_pretrained(model_name)
        original_max_pos = model.config.max_position_embeddings
        required_positions = max_seq_len + 2
        if required_positions <= original_max_pos:
            return model

        def _extend(old: nn.Embedding) -> nn.Embedding:
            """Return a larger copy of ``old`` (same width, padding_idx, device, dtype)."""
            old_weight = old.weight.data
            n_old = old_weight.shape[0]
            new = nn.Embedding(
                required_positions,
                old_weight.shape[1],
                padding_idx=old.padding_idx,
                device=old_weight.device,
                dtype=old_weight.dtype,
            )
            new.weight.data[:n_old] = old_weight
            new.weight.data[n_old:].normal_(mean=0.0, std=0.02)
            return new

        # 1. Text position embeddings.
        model.embeddings.position_embeddings = _extend(model.embeddings.position_embeddings)
        # 2. Layout box position embeddings (indexed by the same sequence positions).
        model.layout_embeddings.box_position_embeddings = _extend(
            model.layout_embeddings.box_position_embeddings
        )
        # 3. Keep the config consistent with the new table size.
        model.config.max_position_embeddings = required_positions
        # 4. Extend the position_ids buffer.
        new_position_ids = torch.arange(required_positions).unsqueeze(0)
        model.embeddings.register_buffer("position_ids", new_position_ids, persistent=False)
        print(f"Extended LiLT positions: {original_max_pos} → {required_positions}")
        return model

    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        bbox: torch.Tensor,
    ) -> torch.Tensor:
        """Encode tokens with layout using the LiLT encoder; returns last_hidden_state."""
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )
        return outputs.last_hidden_state

    def encode_texts(
        self,
        texts: List[str],
        device: torch.device,
    ) -> torch.Tensor:
        """
        Encode text strings using the frozen Sentence Transformer (zero-shot).

        Used for queries and for Q text extracted from predicted regions.

        Args:
            texts: List of text strings to encode.
            device: Device to put tensors on.

        Returns:
            (num_texts, embed_dim) tensor; a (0, embed_dim) tensor for an empty list.
        """
        if not texts:
            return torch.zeros(0, self.query_embed_dim, device=device)
        with torch.no_grad():
            embeddings = self.query_encoder.encode(
                texts,
                convert_to_tensor=True,
                device=device,
            )
        # Clone to exit inference mode (needed for autograd compatibility)
        return embeddings.clone()

    # ------------------------------------------------------------------ #
    # Span pooling / geometry / text
    # ------------------------------------------------------------------ #
    def pool_spans(
        self,
        hidden_states: torch.Tensor,
        span_indices: List[List[Tuple[int, int]]],
        bbox: Optional[torch.Tensor] = None,
        use_bbox_features: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Pool answer spans as [h_start; h_end] (+ optional bbox features).

        Start+end pooling preserves boundary information, which matters for
        extraction tasks.

        Args:
            hidden_states: (batch, seq, hidden) document encodings.
            span_indices: Per batch item, a list of (start, end) spans (end exclusive).
            bbox: (batch, seq, 4) token boxes [x1, y1, x2, y2] in the 0-1000 range.
            use_bbox_features: Whether to append bbox features (needs ``bbox``).

        Returns:
            span_embeddings: (batch, max_spans, span_dim), where span_dim is
                2*hidden, plus 6 when bbox features are used.
            span_mask: (batch, max_spans) bool mask. Spans that are empty,
                inverted, or outside the encoded sequence are left masked out.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        device = hidden_states.device
        max_spans = max(max((len(spans) for spans in span_indices), default=1), 1)

        with_bbox = use_bbox_features and bbox is not None
        # Output: [h_start; h_end] + optionally [x1, y1, x2, y2, width, height]
        span_dim = 2 * hidden_size + (6 if with_bbox else 0)
        span_embeddings = torch.zeros(batch_size, max_spans, span_dim, device=device)
        span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool, device=device)

        for b, spans in enumerate(span_indices):
            for s, (start, end) in enumerate(spans):
                # end is exclusive; skip empty/inverted or out-of-range spans
                # (otherwise end - 1 would read the wrong token).
                if start < 0 or end <= start or end > seq_len:
                    continue
                span_repr = torch.cat([hidden_states[b, start], hidden_states[b, end - 1]], dim=0)
                if with_bbox:
                    # Union box of the start and end tokens, normalised to [0, 1].
                    x1, y1, x2, y2 = self.get_span_bbox(bbox[b], (start, end)).float()
                    bbox_feat = torch.stack([x1, y1, x2, y2, x2 - x1, y2 - y1]) / 1000.0
                    span_repr = torch.cat([span_repr, bbox_feat.to(device)], dim=0)
                span_embeddings[b, s] = span_repr
                span_mask[b, s] = True

        return span_embeddings, span_mask

    def pool_single_span(
        self,
        hidden_states: torch.Tensor,  # (seq, hidden) - single sample
        span: Tuple[int, int],
        bbox: Optional[torch.Tensor] = None,  # (seq, 4); unused, kept for API compatibility
    ) -> torch.Tensor:
        """
        Return the mean of a span's start and end token states, shape (hidden_size,).

        Empty, inverted, or out-of-range spans return zeros.
        """
        start, end = span
        if start < 0 or end <= start or end > hidden_states.shape[0]:
            return hidden_states.new_zeros(self.hidden_size)
        # Mean of start and end (simpler than concat for matching).
        return (hidden_states[start] + hidden_states[end - 1]) / 2

    def get_span_bbox(
        self,
        bbox: torch.Tensor,  # (seq, 4)
        span: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Bounding box [x1, y1, x2, y2] of a span: the union of its start and end token boxes.

        Spans rejected by ``pool_single_span`` (empty, inverted, or past the encoded
        sequence, e.g. GT spans in truncated documents) get an all-zero box
        instead of raising IndexError.
        """
        start, end = span
        if start < 0 or end <= start or end > bbox.shape[0]:
            return bbox.new_zeros(4)
        first, last = bbox[start], bbox[end - 1]
        return torch.stack([
            torch.min(first[0], last[0]),  # x1
            torch.min(first[1], last[1]),  # y1
            torch.max(first[2], last[2]),  # x2
            torch.max(first[3], last[3]),  # y2
        ])

    def extract_span_text(
        self,
        span: Tuple[int, int],
        tokens: List[str],
        subword_to_word: Dict[int, int],
    ) -> str:
        """
        Extract the text of a span using the word tokens and subword->word map.

        Args:
            span: (start, end) subword indices (exclusive end).
            tokens: List of word tokens.
            subword_to_word: Dict mapping subword idx -> word idx.

        Returns:
            Space-joined words from the first to the last covered word, inclusive.
            Returns "" when no subword in the span is mapped.
        """
        start, end = span
        word_indices = {subword_to_word[i] for i in range(start, end) if i in subword_to_word}
        if not word_indices:
            return ""
        return " ".join(tokens[min(word_indices):max(word_indices) + 1])

    # ------------------------------------------------------------------ #
    # Region decoding / matching
    # ------------------------------------------------------------------ #
    @staticmethod
    def _decode_bio_spans(labels: List[int], begin_id: int, inside_id: int) -> List[Tuple[int, int]]:
        """
        Decode (start, end) spans (end exclusive) for one entity type from BIO ids.

        A B tag always opens a new span. An I tag extends the open span and is
        ignored when no span is open. Any other tag closes the open span.
        """
        spans: List[Tuple[int, int]] = []
        current: Optional[Tuple[int, int]] = None
        for i, label in enumerate(labels):
            if label == begin_id:
                if current is not None:
                    spans.append(current)
                current = (i, i + 1)
            elif label == inside_id:
                if current is not None:
                    current = (current[0], i + 1)
            else:
                if current is not None:
                    spans.append(current)
                current = None
        if current is not None:
            spans.append(current)
        return spans

    def extract_qa_regions(
        self,
        token_logits: torch.Tensor,  # (batch, seq, num_labels)
        attention_mask: torch.Tensor,  # (batch, seq)
    ) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]:
        """
        Extract Q and A regions from Step 1 token predictions (argmax).

        This is what makes the architecture cascading: Step 2 uses Step 1's
        predictions instead of ground-truth spans. Only the first
        ``attention_mask[b].sum()`` positions are decoded (assumes right padding).

        Returns:
            q_regions: List of Q spans per batch item [(start, end), ...]
            a_regions: List of A spans per batch item [(start, end), ...]
        """
        preds = token_logits.argmax(dim=-1).cpu()  # (batch, seq)
        lengths = attention_mask.sum(dim=-1).tolist()

        q_regions = []
        a_regions = []
        for b, seq_len in enumerate(lengths):
            pred_seq = preds[b, :int(seq_len)].tolist()
            q_regions.append(self._decode_bio_spans(pred_seq, 1, 2))  # B-Q=1, I-Q=2
            a_regions.append(self._decode_bio_spans(pred_seq, 3, 4))  # B-A=3, I-A=4
        return q_regions, a_regions

    def match_regions_to_gt(
        self,
        pred_regions: List[Tuple[int, int]],
        gt_regions: List[Tuple[int, int]],
    ) -> Tuple[List[int], List[int]]:
        """
        Match predicted regions to GT regions by token overlap.

        Ties keep the earliest candidate; zero overlap means no match.

        Returns:
            gt_to_pred: For each GT region, index of best matching pred region (-1 if none)
            pred_to_gt: For each pred region, index of best matching GT region (-1 if none)
        """

        def best_match(region: Tuple[int, int], candidates: List[Tuple[int, int]]) -> int:
            best_idx, best_overlap = -1, 0
            for idx, (c_start, c_end) in enumerate(candidates):
                overlap = max(0, min(region[1], c_end) - max(region[0], c_start))
                if overlap > best_overlap:
                    best_idx, best_overlap = idx, overlap
            return best_idx

        gt_to_pred = [best_match(gt, pred_regions) for gt in gt_regions]
        pred_to_gt = [best_match(pred, gt_regions) for pred in pred_regions]
        return gt_to_pred, pred_to_gt

    # ------------------------------------------------------------------ #
    # Forward helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _pad_and_stack(
        per_sample: List[Optional[torch.Tensor]],
        indices: List[int],
        batch_size: int,
        feat_dim: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Zero-pad the (n_b, feat_dim) tensors at ``indices`` into (batch, max_n, feat_dim) plus a bool mask."""
        max_len = max(per_sample[i].shape[0] for i in indices)
        padded = torch.zeros(batch_size, max_len, feat_dim, device=device)
        mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device=device)
        for b in indices:
            n = per_sample[b].shape[0]
            padded[b, :n] = per_sample[b]
            mask[b, :n] = True
        return padded, mask

    def _encode_gt_questions(
        self,
        b: int,
        gt_question_spans: Optional[List[List[Optional[Tuple[int, int]]]]],
        tokens: Optional[List[List[str]]],
        subword_to_word: Optional[List[Dict[int, int]]],
        device: torch.device,
    ) -> List[Optional[torch.Tensor]]:
        """
        Encode the GT question texts of sample ``b`` with the sentence transformer.

        Returns a list parallel to ``gt_question_spans[b]``: an embedding for each
        non-None span and None elsewhere. Returns [] when GT spans, tokens or the
        subword map are unavailable for this sample.
        """
        if gt_question_spans is None or b >= len(gt_question_spans) or tokens is None or subword_to_word is None:
            return []
        spans = gt_question_spans[b]
        texts = [
            self.extract_span_text(span, tokens[b], subword_to_word[b]) or "unknown"
            for span in spans
            if span is not None
        ]
        if not texts:
            return [None] * len(spans)
        embeds = iter(self.encode_texts(texts, device))  # (num_valid, query_embed_dim)
        return [next(embeds) if span is not None else None for span in spans]

    def _collect_sample_regions(
        self,
        b: int,
        sample_hidden: torch.Tensor,  # (seq, hidden)
        sample_bbox: torch.Tensor,  # (seq, 4)
        sample_pred_q: List[Tuple[int, int]],
        sample_pred_a: List[Tuple[int, int]],
        tokens: Optional[List[List[str]]],
        subword_to_word: Optional[List[Dict[int, int]]],
        gt_answer_spans: Optional[List[List[Tuple[int, int]]]],
        gt_question_spans: Optional[List[List[Optional[Tuple[int, int]]]]],
        training: bool,
        device: torch.device,
    ) -> Dict[str, object]:
        """
        Build the per-sample Q/A representations consumed by Steps 1.5 and 2.

        Keys:
            q_lilt (num_q, hidden): LiLT Q embeddings for the QA linker; None when
                no Q region was predicted.
            q_st (n, query_dim): sentence-transformer-space Q embeddings for
                query matching (real, masked/predicted, or predicted from every A).
            a, a_bbox: stacked A embeddings / boxes.
            q_bbox: Q boxes, or A boxes as a proxy when no Q was predicted.
            pred_q_from_a, real_q_from_pred: parallel lists for the
                QuestionPredictor loss (None where there is no pair).
            real_q: GT Q embeddings from ``_encode_gt_questions``.
            gt_to_pred_a: GT answer index -> predicted A region index.
        All values are None when the sample has no predicted A region.
        """
        keys = ("q_lilt", "q_st", "a", "q_bbox", "a_bbox",
                "pred_q_from_a", "real_q_from_pred", "real_q", "gt_to_pred_a")
        if not sample_pred_a:
            return dict.fromkeys(keys)

        a_embeds = [self.pool_single_span(sample_hidden, span, sample_bbox) for span in sample_pred_a]
        a_bboxes = [self.get_span_bbox(sample_bbox, span) for span in sample_pred_a]

        q_embeds_lilt: List[torch.Tensor] = []  # QA Linker input (hidden dim)
        q_embeds_st: List[torch.Tensor] = []  # Query-matching input (query_embed_dim)
        q_bboxes: List[torch.Tensor] = []
        real_q_from_pred: List[Optional[torch.Tensor]] = []
        pred_q_from_a: List[Optional[torch.Tensor]] = []

        if sample_pred_q:
            real_q_embeds_st = None
            if tokens is not None and subword_to_word is not None:
                q_texts = [
                    self.extract_span_text(span, tokens[b], subword_to_word[b]) or "unknown"
                    for span in sample_pred_q
                ]
                real_q_embeds_st = self.encode_texts(q_texts, device)  # (num_q, query_dim)

            for i, q_span in enumerate(sample_pred_q):
                q_embeds_lilt.append(self.pool_single_span(sample_hidden, q_span, sample_bbox))
                q_bboxes.append(self.get_span_bbox(sample_bbox, q_span))
                real_q_emb_st = real_q_embeds_st[i] if real_q_embeds_st is not None else None

                # Q MASKING: during training, randomly replace the real Q with a
                # QuestionPredictor output so the predictor learns to stand in for it.
                # NOTE(review): this pairs the i-th predicted Q with the i-th
                # predicted A by index, which assumes the two lists are aligned.
                should_mask = training and (torch.rand(1).item() < self.q_mask_prob)
                if should_mask and i < len(a_embeds):
                    pred_q_emb = self.question_predictor(a_embeds[i])
                    q_embeds_st.append(pred_q_emb)
                    real_q_from_pred.append(real_q_emb_st)  # Real Q for loss
                    pred_q_from_a.append(pred_q_emb)  # Predicted Q for loss
                else:
                    q_embeds_st.append(
                        real_q_emb_st
                        if real_q_emb_st is not None
                        else torch.zeros(self.query_embed_dim, device=device)
                    )
                    real_q_from_pred.append(real_q_emb_st)
                    pred_q_from_a.append(None)  # No prediction, no loss
        else:
            # No Q regions predicted: predict a Q for every A.
            for a_emb in a_embeds:
                pred_q = self.question_predictor(a_emb)
                q_embeds_st.append(pred_q)
                pred_q_from_a.append(pred_q)
                real_q_from_pred.append(None)
            # A boxes serve as a proxy for the missing Q boxes.
            q_bboxes = list(a_bboxes)

        real_q = self._encode_gt_questions(b, gt_question_spans, tokens, subword_to_word, device)

        gt_to_pred_a = None
        if gt_answer_spans is not None and b < len(gt_answer_spans):
            gt_to_pred_a, _ = self.match_regions_to_gt(sample_pred_a, gt_answer_spans[b])

        return {
            "q_lilt": torch.stack(q_embeds_lilt) if q_embeds_lilt else None,
            "q_st": torch.stack(q_embeds_st) if q_embeds_st else None,
            "a": torch.stack(a_embeds),
            "q_bbox": torch.stack(q_bboxes) if q_bboxes else None,
            "a_bbox": torch.stack(a_bboxes),
            "pred_q_from_a": pred_q_from_a or None,
            "real_q_from_pred": real_q_from_pred or None,
            "real_q": real_q or None,
            "gt_to_pred_a": gt_to_pred_a,
        }

    def _gt_qa_link_scores(
        self,
        hidden_states: torch.Tensor,
        bbox: torch.Tensor,
        gt_question_spans: List[List[Optional[Tuple[int, int]]]],
        gt_answer_spans: List[List[Tuple[int, int]]],
    ) -> Tuple[List[Optional[torch.Tensor]], List[Optional[List[int]]]]:
        """
        Run the QA linker on GT spans; this does not depend on Step 1 predictions.

        Returns, per batch item, a (num_valid_q, num_gt_a) score matrix and the
        indices of the non-None GT Q spans it covers (both None when the item
        has no valid Q or no A).
        """
        scores_list: List[Optional[torch.Tensor]] = []
        valid_list: List[Optional[List[int]]] = []
        for b in range(hidden_states.shape[0]):
            gt_q_spans = gt_question_spans[b] if b < len(gt_question_spans) else []
            gt_a_spans = gt_answer_spans[b] if b < len(gt_answer_spans) else []
            valid_q_indices = [i for i, q in enumerate(gt_q_spans) if q is not None]
            if not valid_q_indices or not gt_a_spans:
                scores_list.append(None)
                valid_list.append(None)
                continue
            valid_q_spans = [gt_q_spans[i] for i in valid_q_indices]
            q_embeds = torch.stack([self.pool_single_span(hidden_states[b], s, bbox[b]) for s in valid_q_spans])
            a_embeds = torch.stack([self.pool_single_span(hidden_states[b], s, bbox[b]) for s in gt_a_spans])
            q_bboxes = torch.stack([self.get_span_bbox(bbox[b], s) for s in valid_q_spans])
            a_bboxes = torch.stack([self.get_span_bbox(bbox[b], s) for s in gt_a_spans])
            scores_list.append(self.qa_linker(q_embeds, a_embeds, q_bboxes, a_bboxes))
            valid_list.append(valid_q_indices)
        return scores_list, valid_list

    # ------------------------------------------------------------------ #
    # Forward
    # ------------------------------------------------------------------ #
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        bbox: torch.Tensor,
        tokens: Optional[List[List[str]]] = None,
        subword_to_word: Optional[List[Dict[int, int]]] = None,
        query_texts: Optional[List[str]] = None,
        gt_answer_spans: Optional[List[List[Tuple[int, int]]]] = None,
        gt_question_spans: Optional[List[List[Optional[Tuple[int, int]]]]] = None,
        target_field_idx: Optional[List[int]] = None,
        training: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass with a CASCADING architecture and ZERO-SHOT query matching.

        Step 1 (token classification) and Step 3 (table head) always run. When
        ``query_texts`` is given:
          - Q/A regions are decoded from Step 1 PREDICTIONS.
          - Step 1.5 scores predicted Q regions against predicted A regions.
          - Step 2 takes the cosine similarity between each query and the Q
            embeddings in sentence-transformer space; padding positions get -inf.
          - GT spans are used only for label-side outputs, plus (when
            ``training``) QA-linker scores on GT spans for the Step 1.5 loss.

        Args:
            input_ids: (batch, seq) token IDs
            attention_mask: (batch, seq) attention mask
            bbox: (batch, seq, 4) bounding boxes
            tokens: List of word tokens per batch item (for zero-shot Q text encoding)
            subword_to_word: List of dicts mapping subword idx to word idx
            query_texts: List of raw query strings, one per batch item
            gt_answer_spans: GT answer spans per batch item (for loss labels only)
            gt_question_spans: GT question spans per batch item (for loss labels only)
            target_field_idx: Accepted for API compatibility; not used here.
            training: Enables random Q masking and GT link scoring. NOTE: it
                defaults to True independently of ``self.training``, so pass
                False at inference.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        hidden_states = self.encode(input_ids, attention_mask, bbox)
        token_logits = self.token_classifier(hidden_states)  # Step 1
        table_outputs = self.table_head(hidden_states)  # Step 3
        outputs = {
            "hidden_states": hidden_states,
            "token_logits": token_logits,
            **table_outputs,
        }

        if query_texts is None:
            return outputs

        # CASCADING: Q/A regions come from Step 1 predictions.
        pred_q_regions, pred_a_regions = self.extract_qa_regions(token_logits, attention_mask)
        outputs["pred_q_regions"] = pred_q_regions
        outputs["pred_a_regions"] = pred_a_regions

        query_emb = self.encode_texts(query_texts, device)  # (batch, query_dim)
        outputs["query_embedding"] = query_emb

        samples = [
            self._collect_sample_regions(
                b, hidden_states[b], bbox[b], pred_q_regions[b], pred_a_regions[b],
                tokens, subword_to_word, gt_answer_spans, gt_question_spans, training, device,
            )
            for b in range(batch_size)
        ]
        all_q_embeds_lilt = [s["q_lilt"] for s in samples]
        all_q_embeds_st = [s["q_st"] for s in samples]
        all_a_embeds = [s["a"] for s in samples]
        all_q_bboxes = [s["q_bbox"] for s in samples]
        all_a_bboxes = [s["a_bbox"] for s in samples]

        valid_indices_lilt = [i for i, q in enumerate(all_q_embeds_lilt) if q is not None]
        valid_indices_st = [i for i, q in enumerate(all_q_embeds_st) if q is not None]

        # Step 1.5: Q-A Linker on predicted regions (LiLT embeddings).
        if valid_indices_lilt:
            q_embeds_lilt_padded, q_span_mask = self._pad_and_stack(
                all_q_embeds_lilt, valid_indices_lilt, batch_size, self.hidden_size, device)
            q_bboxes_padded, _ = self._pad_and_stack(
                all_q_bboxes, valid_indices_lilt, batch_size, 4, device)
            a_embeds_padded, a_span_mask = self._pad_and_stack(
                all_a_embeds, valid_indices_lilt, batch_size, self.hidden_size, device)
            a_bboxes_padded, _ = self._pad_and_stack(
                all_a_bboxes, valid_indices_lilt, batch_size, 4, device)

            qa_link_scores_list = []
            for b in valid_indices_lilt:
                if all_q_embeds_lilt[b].shape[0] > 0 and all_a_embeds[b].shape[0] > 0:
                    qa_link_scores_list.append(self.qa_linker(
                        all_q_embeds_lilt[b], all_a_embeds[b], all_q_bboxes[b], all_a_bboxes[b],
                    ))
                else:
                    qa_link_scores_list.append(None)
            # One entry per index in valid_batch_indices_lilt (not per batch item).
            outputs["qa_link_scores"] = qa_link_scores_list

            outputs["q_span_mask"] = q_span_mask
            outputs["a_span_mask"] = a_span_mask
            outputs["q_embeds"] = q_embeds_lilt_padded
            outputs["a_embeds"] = a_embeds_padded
            outputs["q_bboxes"] = q_bboxes_padded
            outputs["a_bboxes"] = a_bboxes_padded

        # GT-based QA link scores for the Step 1.5 loss. Computed independently
        # of Step 1 predictions so the linker trains even when no Q is predicted.
        if training and gt_question_spans is not None and gt_answer_spans is not None:
            gt_scores, gt_valid = self._gt_qa_link_scores(
                hidden_states, bbox, gt_question_spans, gt_answer_spans)
            outputs["gt_qa_link_scores"] = gt_scores
            outputs["gt_valid_q_indices"] = gt_valid

        # Step 2: zero-shot Query -> Question matching. This also covers samples
        # whose Q embeddings all come from the QuestionPredictor fallback.
        if valid_indices_st:
            q_embeds_st_padded, q_span_mask_st = self._pad_and_stack(
                all_q_embeds_st, valid_indices_st, batch_size, self.query_embed_dim, device)
            outputs["q_embeds_st"] = q_embeds_st_padded
            outputs["q_span_mask_st"] = q_span_mask_st

            query_norm = F.normalize(query_emb, dim=-1)
            q_norm = F.normalize(q_embeds_st_padded, dim=-1)
            match_scores = torch.bmm(q_norm, query_norm.unsqueeze(-1)).squeeze(-1)
            match_scores = match_scores.masked_fill(~q_span_mask_st, float('-inf'))
            outputs["match_scores"] = match_scores
            outputs["match_span_mask"] = q_span_mask_st

        # Store for loss computation (Q masking / QuestionPredictor)
        outputs["pred_q_from_a"] = [s["pred_q_from_a"] for s in samples]
        outputs["real_q_from_pred"] = [s["real_q_from_pred"] for s in samples]  # Real Q (ST) when masked
        outputs["real_q_embeds"] = [s["real_q"] for s in samples]  # GT Q embeds (ST)
        outputs["gt_to_pred_a_idx"] = [s["gt_to_pred_a"] for s in samples]
        outputs["valid_batch_indices_lilt"] = valid_indices_lilt
        outputs["valid_batch_indices_st"] = valid_indices_st
        return outputs
# Per-component weights used to combine the individual losses in compute_loss.
_DEFAULT_LOSS_WEIGHTS: Dict[str, float] = {
    "step1_loss": 1.0,  # Token classification (main task)
    "qa_link_loss": 0.3,  # Q-A linker (kept low so it does not compete with step 1)
    "q_predict_loss": 0.2,  # QuestionPredictor auxiliary loss
    "step2_loss": 0.5,  # Query-Q matching (zero-shot: the term itself is always 0)
    "table_det_loss": 0.5,  # Table detection
    "header_loss": 0.3,  # Header detection
    "row_loss": 0.3,  # Row segmentation
    "col_loss": 0.3,  # Column assignment
}


def _masked_cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    mask: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return mean cross-entropy over rows selected by ``mask``, or 0.0 if none are.

    Args:
        logits: (N, C) flattened logits.
        labels: (N,) flattened integer targets.
        mask: (N,) boolean selector of rows that contribute to the loss.
        weight: Optional per-class weights forwarded to ``F.cross_entropy``.
    """
    if not mask.any():
        # Mean-reduced cross-entropy over zero elements is NaN, which would
        # poison the weighted total; an empty selection contributes 0 instead.
        return torch.tensor(0.0, device=labels.device)
    return F.cross_entropy(logits[mask], labels[mask], weight=weight)


def _qa_link_loss(
    outputs: Dict[str, torch.Tensor],
    qa_link_labels: Optional[List[torch.Tensor]],
    device: torch.device,
) -> torch.Tensor:
    """Step 1.5: mean Q->A link cross-entropy over samples with usable GT links."""
    if "gt_qa_link_scores" not in outputs or qa_link_labels is None:
        return torch.tensor(0.0, device=device)
    scores_list = outputs["gt_qa_link_scores"]
    valid_q_list = outputs.get("gt_valid_q_indices", [None] * len(scores_list))

    per_sample = []
    for scores, labels, valid_q in zip(scores_list, qa_link_labels, valid_q_list):
        # No valid Q indices means there are no score rows to supervise.
        # Falling back to the unfiltered labels would misalign them with scores.
        if scores is None or labels is None or valid_q is None or len(valid_q) == 0:
            continue
        # scores: (num_valid_q, num_gt_a). labels hold one entry per GT question,
        # so keep only the entries for the Qs that produced score rows.
        q_labels = labels[valid_q]
        if scores.numel() == 0 or q_labels.numel() == 0:
            continue
        # Targets outside [0, num_gt_a) are skipped (presumably Qs with no
        # linked answer -- confirm against the label builder).
        linked = (q_labels >= 0) & (q_labels < scores.shape[1])
        if linked.any():
            per_sample.append(F.cross_entropy(scores[linked], q_labels[linked]))

    if not per_sample:
        return torch.tensor(0.0, device=device)
    return torch.stack(per_sample).mean()


def _question_prediction_loss(
    outputs: Dict[str, torch.Tensor], device: torch.device
) -> torch.Tensor:
    """Return mean MSE between QuestionPredictor outputs and the detached real Q embeddings."""
    if "pred_q_from_a" not in outputs or "real_q_from_pred" not in outputs:
        return torch.tensor(0.0, device=device)

    terms = []
    for batch_pred, batch_real in zip(outputs["pred_q_from_a"], outputs["real_q_from_pred"]):
        if batch_pred is None or batch_real is None:
            continue
        for pred_q, real_q in zip(batch_pred, batch_real):
            if pred_q is not None and real_q is not None:
                # The target is detached, so only the predictor is pulled toward it.
                terms.append(F.mse_loss(pred_q, real_q.detach()))

    if not terms:
        return torch.tensor(0.0, device=device)
    return torch.stack(terms).mean()


def compute_loss(
    outputs: Dict[str, torch.Tensor],
    token_labels: torch.Tensor,
    attention_mask: torch.Tensor,
    table_labels: Optional[torch.Tensor] = None,
    row_labels: Optional[torch.Tensor] = None,
    header_labels: Optional[torch.Tensor] = None,
    match_labels: Optional[torch.Tensor] = None,
    col_labels: Optional[torch.Tensor] = None,
    class_weights: Optional[torch.Tensor] = None,
    qa_link_labels: Optional[List[torch.Tensor]] = None,  # Q-A link labels
    loss_weights: Optional[Dict[str, float]] = None,
) -> Dict[str, torch.Tensor]:
    """Compute the joint training loss for all steps, including Q-A linking.

    Any component without supervision is reported as a 0.0 tensor rather than
    NaN. This covers missing labels or outputs and empty selection masks such
    as a fully padded batch.

    Args:
        outputs: Model forward outputs. Always reads ``token_logits``. Reads
            ``table_logits`` / ``header_logits`` / ``row_logits`` /
            ``col_scores`` when the matching labels are given. Reads the Q-A
            linker and QuestionPredictor entries when present.
        token_labels: Step 1 per-token label ids, with the same number of
            positions as ``token_logits`` (assumed (batch, seq_len)).
        attention_mask: Non-zero for real (non-padding) tokens.
        table_labels: Per-token table membership (1 = inside a table).
        row_labels: Per-token row-segmentation class in {0, 1, 2}.
        header_labels: Per-token header flag (0 = table body cell).
        match_labels: Unused, because Step 2 matching is zero-shot. Kept for
            API compatibility.
        col_labels: Per-token column index, scored on table body cells only.
        class_weights: Optional per-class weights for the Step 1 loss.
        qa_link_labels: Per-sample GT answer index for each GT question.
        loss_weights: Optional overrides for the per-component weights used
            in the total. Keys must be component names ("step1_loss",
            "qa_link_loss", ...).

    Returns:
        Dict with one entry per component plus ``"loss"``, the weighted total.

    Raises:
        ValueError: If ``loss_weights`` names an unknown component.
    """
    weights = dict(_DEFAULT_LOSS_WEIGHTS)
    if loss_weights:
        unknown = set(loss_weights) - set(weights)
        if unknown:
            raise ValueError(f"Unknown loss weight keys: {sorted(unknown)}")
        weights.update(loss_weights)

    device = token_labels.device
    losses: Dict[str, torch.Tensor] = {}
    attn_flat = attention_mask.reshape(-1).bool()

    # Step 1: token classification over non-padding tokens.
    token_logits = outputs["token_logits"]
    losses["step1_loss"] = _masked_cross_entropy(
        token_logits.reshape(-1, token_logits.shape[-1]),
        token_labels.reshape(-1),
        attn_flat,
        weight=class_weights,
    )

    # Step 1.5: Q-A linker on GT spans (independent of Step 1 predictions).
    losses["qa_link_loss"] = _qa_link_loss(outputs, qa_link_labels, device)

    # QuestionPredictor: regress answer (LiLT) embeddings onto real Q embeddings.
    losses["q_predict_loss"] = _question_prediction_loss(outputs, device)

    # Step 2: zero-shot query-question matching has no trainable loss.
    losses["step2_loss"] = torch.tensor(0.0, device=device)

    # Step 3: table detection over all non-padding tokens.
    if table_labels is not None:
        losses["table_det_loss"] = _masked_cross_entropy(
            outputs["table_logits"].reshape(-1, 2), table_labels.reshape(-1), attn_flat
        )
    else:
        losses["table_det_loss"] = torch.tensor(0.0, device=device)

    # The header / row / column heads only see non-padding tokens inside tables.
    in_table = None
    if table_labels is not None:
        in_table = ((table_labels == 1) & attention_mask.bool()).reshape(-1)
    has_table_tokens = in_table is not None and bool(in_table.any())

    if header_labels is not None and has_table_tokens:
        losses["header_loss"] = _masked_cross_entropy(
            outputs["header_logits"].reshape(-1, 2), header_labels.reshape(-1), in_table
        )
    else:
        losses["header_loss"] = torch.tensor(0.0, device=device)

    if row_labels is not None and has_table_tokens:
        losses["row_loss"] = _masked_cross_entropy(
            outputs["row_logits"].reshape(-1, 3), row_labels.reshape(-1), in_table
        )
    else:
        losses["row_loss"] = torch.tensor(0.0, device=device)

    if col_labels is not None and header_labels is not None and has_table_tokens:
        # Columns are assigned to body cells only: table tokens that are not headers.
        cell_mask = in_table & (header_labels.reshape(-1) == 0)
        col_scores = outputs["col_scores"]
        losses["col_loss"] = _masked_cross_entropy(
            col_scores.reshape(-1, col_scores.shape[-1]), col_labels.reshape(-1), cell_mask
        )
    else:
        losses["col_loss"] = torch.tensor(0.0, device=device)

    # Weighted total, summed in the same component order as the dict above.
    losses["loss"] = sum(losses[name] * w for name, w in weights.items())
    return losses
def compute_contrastive_loss(
    query_embedding: torch.Tensor,  # (batch, hidden)
    span_embeddings: torch.Tensor,  # (batch, num_spans, span_dim)
    span_mask: torch.Tensor,  # (batch, num_spans)
    match_labels: torch.Tensor,  # (batch,) - index of correct span
    temperature: float = 0.07,
    hard_negative_weight: float = 2.0,
    *,
    hard_negative_window: float = 0.5,
    margin: float = 0.3,
    margin_weight: float = 0.1,
) -> torch.Tensor:
    """
    Contrastive loss for query-span matching.

    For each sample the loss combines two terms:
    - InfoNCE: cross-entropy over the temperature-scaled cosine similarities
      between the query and that sample's valid spans. The labelled span is
      the positive and the sample's other valid spans are the negatives.
    - Hard-negative hinge: each wrong span whose scaled similarity exceeds
      ``sim_pos - hard_negative_window`` adds
      ``margin_weight * relu(sim_neg - sim_pos + margin)``.

    A sample is skipped when it has fewer than two valid spans, or when its
    label is negative (e.g. an ignore marker) or not below its number of
    valid spans.

    Args:
        query_embedding: Query embeddings, one row per sample.
        span_embeddings: Span embeddings. Only the first ``hidden`` features
            (the query's width) are compared, so span_dim must be >= hidden.
        span_mask: Valid-span mask. Valid spans are assumed to form a prefix
            of the span axis, because only their count is used.
        match_labels: Index of the correct span for each query.
        temperature: Softmax temperature (lower = sharper).
        hard_negative_weight: Not applied to the loss. Retained for backward
            compatibility; the hard-negative penalty is the hinge term above.
        hard_negative_window: How close (in scaled-similarity units) a wrong
            span must be to the positive to count as a hard negative.
        margin: Hinge margin, in scaled-similarity units.
        margin_weight: Multiplier for each hard-negative hinge term.

    Returns:
        Mean loss over contributing samples, or a 0.0 tensor if none contribute.
    """
    device = query_embedding.device
    if query_embedding.shape[0] == 0:
        return torch.tensor(0.0, device=device)

    per_sample = []
    for b in range(query_embedding.shape[0]):
        num_valid = span_mask[b].sum().item()
        if num_valid <= 1:
            continue
        correct_idx = int(match_labels[b].item())
        # A negative label would silently index from the end of the span list.
        if correct_idx < 0 or correct_idx >= num_valid:
            continue
        num_valid = int(num_valid)

        query = query_embedding[b]
        spans = span_embeddings[b, :num_valid, : query.shape[0]]
        sims = torch.matmul(F.normalize(spans, dim=-1), F.normalize(query, dim=-1)) / temperature

        loss = -F.log_softmax(sims, dim=0)[correct_idx]

        # Hard negatives are wrong spans scoring close to (or above) the positive.
        detached = sims.detach()
        is_hard = detached > detached[correct_idx] - hard_negative_window
        is_hard[correct_idx] = False
        if is_hard.any():
            hinge = F.relu(sims[is_hard] - sims[correct_idx] + margin)
            loss = loss + margin_weight * hinge.sum()

        per_sample.append(loss)

    if not per_sample:
        return torch.tensor(0.0, device=device)
    return torch.stack(per_sample).mean()
def compute_loss_with_contrastive(
    outputs: Dict[str, torch.Tensor],
    token_labels: torch.Tensor,
    attention_mask: torch.Tensor,
    table_labels: Optional[torch.Tensor] = None,
    row_labels: Optional[torch.Tensor] = None,
    header_labels: Optional[torch.Tensor] = None,
    match_labels: Optional[torch.Tensor] = None,
    col_labels: Optional[torch.Tensor] = None,
    class_weights: Optional[torch.Tensor] = None,
    qa_link_labels: Optional[List[torch.Tensor]] = None,
    contrastive_weight: float = 0.5,
) -> Dict[str, torch.Tensor]:
    """
    Compute the joint loss and report a (disabled) Step 2 contrastive term.

    All supervised terms, and the total ``"loss"``, come unchanged from
    ``compute_loss``. Step 2 query matching is zero-shot, so the extra
    ``"contrastive_loss"`` entry is always a 0.0 tensor on the labels' device
    and is not added to the total. ``contrastive_weight`` therefore has no
    effect; it is kept for API compatibility.
    """
    label_kwargs = dict(
        token_labels=token_labels,
        attention_mask=attention_mask,
        table_labels=table_labels,
        row_labels=row_labels,
        header_labels=header_labels,
        match_labels=match_labels,
        col_labels=col_labels,
        class_weights=class_weights,
        qa_link_labels=qa_link_labels,
    )
    result = compute_loss(outputs=outputs, **label_kwargs)

    # Zero-shot matching has nothing to contrast, so the term is a constant zero.
    result["contrastive_loss"] = torch.tensor(0.0, device=token_labels.device)
    return result
if __name__ == "__main__":
    # Smoke test: build the model, run one forward pass on random inputs and
    # compute the losses against random labels.
    banner = "=" * 60
    print("Testing DEXTR LiLT model with CASCADING architecture...")
    print(banner)

    model = DEXTRLiLT(model_name="nielsr/lilt-xlm-roberta-base", max_seq_len=1024)

    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        n_total += param.numel()
        if param.requires_grad:
            n_trainable += param.numel()
    print(f"Total parameters: {n_total:,}")
    print(f"Trainable parameters: {n_trainable:,}")

    # Synthetic document / query inputs.
    batch_size, seq_len, query_len = 2, 128, 16
    token_shape = (batch_size, seq_len)
    input_ids = torch.randint(0, 1000, token_shape)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    # Boxes are [x1, y1, x2, y2] with x1 < x2 and y1 < y2; x2/y2 are clamped to 1000.
    left = torch.randint(0, 500, token_shape)
    top = torch.randint(0, 500, token_shape)
    right = (left + torch.randint(10, 500, token_shape)).clamp(max=1000)
    bottom = (top + torch.randint(10, 500, token_shape)).clamp(max=1000)
    bbox = torch.stack([left, top, right, bottom], dim=-1)

    query_input_ids = torch.randint(0, 1000, (batch_size, query_len))
    query_attention_mask = torch.ones(batch_size, query_len, dtype=torch.long)

    # Ground-truth question/answer token spans passed to the forward pass.
    gt_answer_spans = [[(5, 10), (20, 25)], [(8, 12)]]
    gt_question_spans = [[(2, 5), (17, 20)], [(5, 8)]]

    print("\nRunning forward pass (CASCADING architecture)...")
    for stage in (
        " - Step 1: Token classification → predict Q/A/H/TABLE/O",
        " - Step 1.5: Extract predicted regions, QALinker pairs them",
        " - Step 2: Query matches to predicted Q embeddings",
    ):
        print(stage)

    model.eval()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
            query_input_ids=query_input_ids,
            query_attention_mask=query_attention_mask,
            gt_answer_spans=gt_answer_spans,
            gt_question_spans=gt_question_spans,
        )

    print("\nOutput shapes:")
    for key in (
        "token_logits",
        "table_logits",
        "header_logits",
        "row_logits",
        "col_scores",
        "query_embedding",
    ):
        print(f" {key}: {outputs[key].shape}")

    print("\nPredicted regions from Step 1:")
    for key in ("pred_q_regions", "pred_a_regions"):
        print(f" {key}: {outputs[key]}")

    # match_scores is absent when no question regions were predicted.
    match_line = (
        f" match_scores: {outputs['match_scores'].shape}"
        if "match_scores" in outputs
        else " match_scores: None (no regions predicted)"
    )
    print(match_line)

    print("\nTesting loss computation...")
    token_labels = torch.randint(0, NUM_LABELS, token_shape)
    table_labels = torch.randint(0, 2, token_shape)
    row_labels = torch.randint(0, 3, token_shape)
    header_labels = torch.randint(0, 2, token_shape)

    # Real match labels would come from aligning GT fields with predicted Q
    # regions; point at region 0 whenever any candidate exists.
    has_candidates = "match_scores" in outputs and outputs["match_scores"].shape[1] > 0
    match_labels = torch.zeros(batch_size, dtype=torch.long) if has_candidates else None

    losses = compute_loss(
        outputs,
        token_labels=token_labels,
        attention_mask=attention_mask,
        table_labels=table_labels,
        row_labels=row_labels,
        header_labels=header_labels,
        match_labels=match_labels,
    )
    print("\nLosses:")
    for name, value in losses.items():
        print(f" {name}: {value.item():.4f}")

    print("\n" + banner)
    print("DEXTR LiLT CASCADING architecture test passed!")