""" DEXTR with LiLT - 3-Step Document Extraction. Architecture: - Document encoder: LiLT (XLM-RoBERTa + Layout Transformer), extended to 1024 tokens - Query encoder: Sentence Transformer (frozen, for semantic similarity) - Step 1: Token Classification (Q/A/T/H/O) - TRAINED - Step 2: Query-Question Matching (ZERO-SHOT) - NO TRAINING - Step 3: Table Head (hierarchical with attention-based column assignment) - TRAINED Key insight: Query→Question matching is semantic similarity. Sentence transformers provide proper semantic embeddings for matching queries to question text. """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import LiltModel, LiltConfig from sentence_transformers import SentenceTransformer from typing import Optional, Dict, List, Tuple # Label mappings for Step 1 LABEL2ID = { "O": 0, "B-Q": 1, "I-Q": 2, # Question "B-A": 3, "I-A": 4, # Answer "B-H": 5, "I-H": 6, # Header "B-TABLE": 7, "I-TABLE": 8, # Table } ID2LABEL = {v: k for k, v in LABEL2ID.items()} NUM_LABELS = len(LABEL2ID) class TokenClassificationHead(nn.Module): """ Step 1: Token Classification Head. Predicts Q/A/H/T/O labels for each token (FUNSD-style). """ def __init__(self, hidden_size: int = 768, num_labels: int = NUM_LABELS, dropout: float = 0.1): super().__init__() self.classifier = nn.Sequential( nn.Dropout(dropout), nn.Linear(hidden_size, 256), nn.GELU(), nn.Dropout(dropout), nn.Linear(256, num_labels), ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.classifier(hidden_states) class QALinker(nn.Module): """ Step 1.5: Q-A Linker. Predicts which Answer span links to which Question span. Uses both semantic similarity and spatial features. """ def __init__(self, hidden_size: int = 768, dropout: float = 0.1): super().__init__() self.hidden_size = hidden_size # Project Q and A to same space for semantic matching self.q_proj = nn.Sequential( nn.Linear(hidden_size, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout), ) self.a_proj = nn.Sequential( nn.Linear(hidden_size, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout), ) # Spatial feature scorer # Features: x_dist, y_dist, is_right, is_below, is_same_line, width_ratio, height_ratio self.spatial_scorer = nn.Sequential( nn.Linear(7, 32), nn.GELU(), nn.Linear(32, 1), ) # Learnable temperature self.log_temp = nn.Parameter(torch.tensor(1.0)) def compute_spatial_features( self, q_bboxes: torch.Tensor, # (num_q, 4) - [x1, y1, x2, y2] a_bboxes: torch.Tensor, # (num_a, 4) ) -> torch.Tensor: """ Compute spatial features for all Q-A pairs. Returns: (num_q, num_a, 7) feature tensor """ num_q = q_bboxes.shape[0] num_a = a_bboxes.shape[0] device = q_bboxes.device # Compute centers and sizes q_cx = (q_bboxes[:, 0] + q_bboxes[:, 2]) / 2 # (num_q,) q_cy = (q_bboxes[:, 1] + q_bboxes[:, 3]) / 2 q_w = q_bboxes[:, 2] - q_bboxes[:, 0] q_h = q_bboxes[:, 3] - q_bboxes[:, 1] a_cx = (a_bboxes[:, 0] + a_bboxes[:, 2]) / 2 # (num_a,) a_cy = (a_bboxes[:, 1] + a_bboxes[:, 3]) / 2 a_w = a_bboxes[:, 2] - a_bboxes[:, 0] a_h = a_bboxes[:, 3] - a_bboxes[:, 1] # Expand for pairwise computation q_cx = q_cx.unsqueeze(1).expand(num_q, num_a) # (num_q, num_a) q_cy = q_cy.unsqueeze(1).expand(num_q, num_a) q_x2 = q_bboxes[:, 2].unsqueeze(1).expand(num_q, num_a) q_y2 = q_bboxes[:, 3].unsqueeze(1).expand(num_q, num_a) q_w = q_w.unsqueeze(1).expand(num_q, num_a) q_h = q_h.unsqueeze(1).expand(num_q, num_a) a_cx = a_cx.unsqueeze(0).expand(num_q, num_a) a_cy = a_cy.unsqueeze(0).expand(num_q, num_a) a_x1 = a_bboxes[:, 0].unsqueeze(0).expand(num_q, num_a) a_y1 = a_bboxes[:, 1].unsqueeze(0).expand(num_q, num_a) a_w = a_w.unsqueeze(0).expand(num_q, num_a) a_h = a_h.unsqueeze(0).expand(num_q, num_a) # Compute features (all normalized to [0, 1] range roughly) x_dist = (a_x1 - q_x2) / 1000.0 # Horizontal distance (positive = A right of Q) y_dist = (a_cy - q_cy).abs() / 1000.0 # Vertical distance is_right = (a_x1 > q_x2).float() # A is to the right of Q is_below = (a_y1 > q_y2).float() # A is below Q is_same_line = (y_dist < 0.03).float() # Same line (small y distance) width_ratio = (a_w / (q_w + 1e-6)).clamp(0, 5) / 5.0 # Relative width height_ratio = (a_h / (q_h + 1e-6)).clamp(0, 5) / 5.0 # Relative height # Stack features: (num_q, num_a, 7) features = torch.stack([ x_dist, y_dist, is_right, is_below, is_same_line, width_ratio, height_ratio ], dim=-1) return features def forward( self, q_embeds: torch.Tensor, # (num_q, hidden) a_embeds: torch.Tensor, # (num_a, hidden) q_bboxes: torch.Tensor, # (num_q, 4) a_bboxes: torch.Tensor, # (num_a, 4) ) -> torch.Tensor: """ Compute Q-A link scores. Returns: (num_q, num_a) score matrix """ # Semantic similarity q_proj = self.q_proj(q_embeds) # (num_q, 256) a_proj = self.a_proj(a_embeds) # (num_a, 256) q_proj = F.normalize(q_proj, dim=-1) a_proj = F.normalize(a_proj, dim=-1) semantic_scores = torch.matmul(q_proj, a_proj.t()) # (num_q, num_a) # Spatial scores spatial_feats = self.compute_spatial_features(q_bboxes, a_bboxes) # (num_q, num_a, 7) spatial_scores = self.spatial_scorer(spatial_feats).squeeze(-1) # (num_q, num_a) # Combine with learnable temperature temperature = self.log_temp.exp().clamp(min=0.1, max=10.0) combined_scores = (semantic_scores + spatial_scores) * temperature return combined_scores class QuestionPredictor(nn.Module): """ Predicts Question embedding from Answer embedding. Used when explicit Question tokens are missing in the document. Input: LiLT A embedding (768 dim) Output: Predicted Q embedding in sentence transformer space (384 dim) Training: - For Q-A pairs with Q: randomly mask Q and train to predict it - Learns what Q "should look like" given the answer """ def __init__(self, input_dim: int = 768, output_dim: int = 384, dropout: float = 0.1): super().__init__() # Input: LiLT answer embedding (768) # Output: predicted question embedding in sentence transformer space (384) self.predictor = nn.Sequential( nn.Linear(input_dim, input_dim), nn.LayerNorm(input_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(input_dim, output_dim), nn.LayerNorm(output_dim), ) def forward(self, answer_embed: torch.Tensor) -> torch.Tensor: """ Predict question embedding from answer embedding. Args: answer_embed: (batch, hidden) or (hidden,) - LiLT answer span embedding Returns: predicted_q_embed: (batch, output_dim) or (output_dim,) - in sentence transformer space """ return self.predictor(answer_embed) class QueryAnswerMatcher(nn.Module): """ Step 2: Query-Answer Matching. Given answer span embeddings and query embedding, scores which span matches. Features: - Start+End pooling: spans use [h_start; h_end] instead of mean (captures boundaries) - Bbox features: spatial position helps distinguish similar text - Learnable temperature for score scaling - Cosine similarity with projection to lower dim """ def __init__( self, hidden_size: int = 768, proj_dim: int = 256, dropout: float = 0.1, use_bbox_features: bool = True, num_bbox_features: int = 6, # x1, y1, x2, y2, width, height ): super().__init__() self.proj_dim = proj_dim self.use_bbox_features = use_bbox_features self.hidden_size = hidden_size # Span input: [start; end] = 2*hidden, optionally + bbox span_input_dim = 2 * hidden_size if use_bbox_features: span_input_dim += num_bbox_features self.span_proj = nn.Sequential( nn.Linear(span_input_dim, proj_dim), nn.LayerNorm(proj_dim), nn.GELU(), nn.Dropout(dropout), ) self.query_proj = nn.Sequential( nn.Linear(hidden_size, proj_dim), nn.LayerNorm(proj_dim), nn.GELU(), nn.Dropout(dropout), ) # Learnable temperature (initialized to sqrt(proj_dim) like standard attention) self.log_temp = nn.Parameter(torch.tensor(proj_dim ** 0.5).log()) def forward( self, span_embeddings: torch.Tensor, # (batch, num_spans, span_input_dim) query_embedding: torch.Tensor, # (batch, hidden) span_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: span_proj = self.span_proj(span_embeddings) query_proj = self.query_proj(query_embedding).unsqueeze(1) # Normalize for cosine similarity-like behavior span_proj = F.normalize(span_proj, dim=-1) query_proj = F.normalize(query_proj, dim=-1) # Scaled dot product with learnable temperature temperature = self.log_temp.exp().clamp(min=0.1, max=100.0) scores = torch.bmm(query_proj, span_proj.transpose(1, 2)).squeeze(1) * temperature if span_mask is not None: scores = scores.masked_fill(~span_mask.bool(), float('-inf')) return scores class TableHead(nn.Module): """ Step 3: Hierarchical Table Extraction Head. Sub-steps: 1. Table detection (TABLE/O per token) 2. Header detection (HEADER/CELL per token) 3. Row segmentation (B-ROW/I-ROW/O) 4. Column assignment (cell → header via attention) """ def __init__(self, hidden_size: int = 768, dropout: float = 0.1): super().__init__() self.table_detector = nn.Sequential( nn.Dropout(dropout), nn.Linear(hidden_size, 256), nn.GELU(), nn.Linear(256, 2), ) self.header_detector = nn.Sequential( nn.Dropout(dropout), nn.Linear(hidden_size, 256), nn.GELU(), nn.Linear(256, 2), ) self.row_tagger = nn.Sequential( nn.Dropout(dropout), nn.Linear(hidden_size, 256), nn.GELU(), nn.Linear(256, 3), ) self.col_key_proj = nn.Linear(hidden_size, 128) self.col_query_proj = nn.Linear(hidden_size, 128) def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]: table_logits = self.table_detector(hidden_states) header_logits = self.header_detector(hidden_states) row_logits = self.row_tagger(hidden_states) col_keys = self.col_key_proj(hidden_states) col_queries = self.col_query_proj(hidden_states) col_scores = torch.bmm(col_queries, col_keys.transpose(1, 2)) / (128 ** 0.5) return { "table_logits": table_logits, "header_logits": header_logits, "row_logits": row_logits, "col_scores": col_scores, } class DEXTRLiLT(nn.Module): """ DEXTR with LiLT: Multi-Step Document Extraction with Zero-Shot Query Matching. Architecture: - Step 1: Token Classification (Q/A/H/T/O) - TRAINED - Step 1.5: Q-A Linker (links Question spans to Answer spans) - TRAINED - Step 2: Query → Question Matching (ZERO-SHOT with Sentence Transformer) - Step 3: Table Head - TRAINED Uses SEPARATE encoders: - Document encoder: LiLT (XLM-RoBERTa + Layout Transformer) for layout-aware encoding - Query encoder: Sentence Transformer (frozen) for semantic similarity Zero-shot matching: - Query text encoded with frozen Sentence Transformer - Question text (from predicted Q regions) encoded with frozen Sentence Transformer - Direct cosine similarity (no trainable projections) - Fallback: if no Q regions, use QuestionPredictor on A text """ def __init__( self, model_name: str = "nielsr/lilt-xlm-roberta-base", query_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", max_seq_len: int = 1024, hidden_size: int = 768, dropout: float = 0.1, q_mask_prob: float = 0.3, # Probability of masking Q during training (for Q-A linker) ): super().__init__() self.max_seq_len = max_seq_len self.hidden_size = hidden_size self.q_mask_prob = q_mask_prob # Document encoder: LiLT (layout-aware) if max_seq_len > 512: self.encoder = self._load_lilt_extended(model_name, max_seq_len) else: self.encoder = LiltModel.from_pretrained(model_name) # Query encoder: Sentence Transformer (frozen, zero-shot) # Provides proper semantic similarity for query→question matching self.query_encoder = SentenceTransformer(query_model_name) self.query_encoder.requires_grad_(False) # Freeze for zero-shot self.query_embed_dim = self.query_encoder.get_sentence_embedding_dimension() print(f"Loaded frozen query encoder: {query_model_name} (dim={self.query_embed_dim})") # Step 1: Token Classification Head self.token_classifier = TokenClassificationHead(hidden_size, NUM_LABELS, dropout) # Step 1.5: Q-A Linker (links Q spans to A spans) self.qa_linker = QALinker(hidden_size, dropout) # Step 2: QuestionPredictor for documents without explicit Q labels (e.g., receipts) # Input: LiLT A embedding (768 dim), Output: Q embedding in sentence transformer space (384 dim) self.question_predictor = QuestionPredictor(hidden_size, self.query_embed_dim, dropout) # Step 3: Table Head self.table_head = TableHead(hidden_size, dropout) def _load_lilt_extended(self, model_name: str, max_seq_len: int) -> LiltModel: """Load LiLT with extended position embeddings.""" model = LiltModel.from_pretrained(model_name) original_max_pos = model.config.max_position_embeddings required_positions = max_seq_len + 2 if required_positions <= original_max_pos: return model # 1. Extend text position embeddings old_pos_emb = model.embeddings.position_embeddings.weight.data hidden_size = old_pos_emb.shape[1] padding_idx = model.embeddings.position_embeddings.padding_idx new_pos_emb = nn.Embedding(required_positions, hidden_size, padding_idx=padding_idx) new_pos_emb.weight.data[:original_max_pos] = old_pos_emb new_pos_emb.weight.data[original_max_pos:].normal_(mean=0.0, std=0.02) model.embeddings.position_embeddings = new_pos_emb # 2. Extend layout box_position_embeddings (same sequence length limit) old_box_pos = model.layout_embeddings.box_position_embeddings.weight.data box_hidden = old_box_pos.shape[1] # 192 new_box_pos = nn.Embedding(required_positions, box_hidden) new_box_pos.weight.data[:original_max_pos] = old_box_pos new_box_pos.weight.data[original_max_pos:].normal_(mean=0.0, std=0.02) model.layout_embeddings.box_position_embeddings = new_box_pos # 3. Update config model.config.max_position_embeddings = required_positions # 4. Extend position_ids buffer new_position_ids = torch.arange(required_positions).unsqueeze(0) model.embeddings.register_buffer("position_ids", new_position_ids, persistent=False) print(f"Extended LiLT positions: {original_max_pos} → {required_positions}") return model def encode( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, bbox: torch.Tensor, ) -> torch.Tensor: """Encode tokens with layout using shared LiLT encoder.""" outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, bbox=bbox, ) return outputs.last_hidden_state def encode_texts( self, texts: List[str], device: torch.device, ) -> torch.Tensor: """ Encode text strings using frozen Sentence Transformer (zero-shot). Used for encoding Q text extracted from predicted regions. Args: texts: List of text strings to encode device: Device to put tensors on Returns: embeddings: (num_texts, embed_dim) tensor """ if not texts: return torch.zeros(0, self.query_embed_dim, device=device) with torch.no_grad(): embeddings = self.query_encoder.encode( texts, convert_to_tensor=True, device=device, ) # Clone to exit inference mode (needed for autograd compatibility) embeddings = embeddings.clone() return embeddings # (num_texts, embed_dim) def pool_spans( self, hidden_states: torch.Tensor, span_indices: List[List[Tuple[int, int]]], bbox: Optional[torch.Tensor] = None, use_bbox_features: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Pool hidden states for answer spans using start+end tokens. Returns [h_start; h_end; bbox_features] for each span instead of mean pooling. This preserves boundary information which is crucial for extraction tasks. Args: hidden_states: (batch, seq, hidden) - document encodings span_indices: List of (start, end) tuples per batch item bbox: (batch, seq, 4) - bounding boxes [x1, y1, x2, y2] normalized to [0, 1000] use_bbox_features: whether to include bbox in span representation Returns: span_embeddings: (batch, max_spans, span_dim) where span_dim = 2*hidden + 6 (if bbox) span_mask: (batch, max_spans) bool mask """ batch_size = hidden_states.shape[0] hidden_size = hidden_states.shape[2] device = hidden_states.device max_spans = max(len(spans) for spans in span_indices) if span_indices else 1 max_spans = max(max_spans, 1) # Output dimension: [h_start; h_end] + optionally [x1, y1, x2, y2, width, height] span_dim = 2 * hidden_size if use_bbox_features and bbox is not None: span_dim += 6 # normalized bbox features span_embeddings = torch.zeros(batch_size, max_spans, span_dim, device=device) span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool, device=device) for b, spans in enumerate(span_indices): for s, (start, end) in enumerate(spans): if s >= max_spans: break if start >= hidden_states.shape[1] or end > hidden_states.shape[1]: continue # Start+End pooling: [h_start; h_end] h_start = hidden_states[b, start, :] h_end = hidden_states[b, end - 1, :] # end is exclusive span_repr = torch.cat([h_start, h_end], dim=0) # Add bbox features if available if use_bbox_features and bbox is not None: # Get bbox of span (union of start and end token bboxes) bbox_start = bbox[b, start, :] # [x1, y1, x2, y2] bbox_end = bbox[b, end - 1, :] # Compute span bounding box (min x1/y1, max x2/y2) span_x1 = torch.min(bbox_start[0], bbox_end[0]) span_y1 = torch.min(bbox_start[1], bbox_end[1]) span_x2 = torch.max(bbox_start[2], bbox_end[2]) span_y2 = torch.max(bbox_start[3], bbox_end[3]) # Normalize to [0, 1] and add width/height bbox_feat = torch.tensor([ span_x1 / 1000.0, span_y1 / 1000.0, span_x2 / 1000.0, span_y2 / 1000.0, (span_x2 - span_x1) / 1000.0, # width (span_y2 - span_y1) / 1000.0, # height ], device=device) span_repr = torch.cat([span_repr, bbox_feat], dim=0) span_embeddings[b, s] = span_repr span_mask[b, s] = True return span_embeddings, span_mask def pool_single_span( self, hidden_states: torch.Tensor, # (seq, hidden) - single sample span: Tuple[int, int], bbox: Optional[torch.Tensor] = None, # (seq, 4) ) -> torch.Tensor: """ Pool a single span to get its embedding. Returns mean of start and end tokens (hidden_size,). """ start, end = span if start >= hidden_states.shape[0] or end > hidden_states.shape[0]: return torch.zeros(self.hidden_size, device=hidden_states.device) h_start = hidden_states[start, :] h_end = hidden_states[end - 1, :] # Mean of start and end for Q embedding (simpler than concat for matching) return (h_start + h_end) / 2 def get_span_bbox( self, bbox: torch.Tensor, # (seq, 4) span: Tuple[int, int], ) -> torch.Tensor: """Get bounding box for a span (union of start and end token bboxes).""" start, end = span bbox_start = bbox[start, :] bbox_end = bbox[end - 1, :] span_bbox = torch.stack([ torch.min(bbox_start[0], bbox_end[0]), # x1 torch.min(bbox_start[1], bbox_end[1]), # y1 torch.max(bbox_start[2], bbox_end[2]), # x2 torch.max(bbox_start[3], bbox_end[3]), # y2 ]) return span_bbox def extract_span_text( self, span: Tuple[int, int], tokens: List[str], subword_to_word: Dict[int, int], ) -> str: """ Extract text from a span using tokens and subword-to-word mapping. Args: span: (start, end) subword indices (exclusive end) tokens: List of word tokens subword_to_word: Dict mapping subword idx -> word idx Returns: Text string for the span """ start, end = span # Get word indices for this span word_indices = set() for subword_idx in range(start, end): if subword_idx in subword_to_word: word_indices.add(subword_to_word[subword_idx]) if not word_indices: return "" # Get contiguous word range min_word = min(word_indices) max_word = max(word_indices) # Extract and join tokens span_tokens = tokens[min_word:max_word + 1] return " ".join(span_tokens) def extract_qa_regions( self, token_logits: torch.Tensor, # (batch, seq, num_labels) attention_mask: torch.Tensor, # (batch, seq) ) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]: """ Extract Q and A regions from Step 1 token predictions. This enables the cascading architecture where Step 2 uses Step 1's predictions instead of ground truth spans. Returns: q_regions: List of Q spans per batch item [(start, end), ...] a_regions: List of A spans per batch item [(start, end), ...] """ batch_size = token_logits.shape[0] preds = token_logits.argmax(dim=-1) # (batch, seq) q_regions = [] a_regions = [] for b in range(batch_size): sample_q_regions = [] sample_a_regions = [] seq_len = attention_mask[b].sum().item() pred_seq = preds[b, :int(seq_len)].cpu().tolist() # Extract Q spans (B-Q=1, I-Q=2) current_span = None for i, label in enumerate(pred_seq): if label == 1: # B-Q if current_span is not None: sample_q_regions.append(current_span) current_span = (i, i + 1) elif label == 2: # I-Q if current_span is not None: current_span = (current_span[0], i + 1) else: if current_span is not None: sample_q_regions.append(current_span) current_span = None if current_span is not None: sample_q_regions.append(current_span) # Extract A spans (B-A=3, I-A=4) current_span = None for i, label in enumerate(pred_seq): if label == 3: # B-A if current_span is not None: sample_a_regions.append(current_span) current_span = (i, i + 1) elif label == 4: # I-A if current_span is not None: current_span = (current_span[0], i + 1) else: if current_span is not None: sample_a_regions.append(current_span) current_span = None if current_span is not None: sample_a_regions.append(current_span) q_regions.append(sample_q_regions) a_regions.append(sample_a_regions) return q_regions, a_regions def match_regions_to_gt( self, pred_regions: List[Tuple[int, int]], gt_regions: List[Tuple[int, int]], ) -> Tuple[List[int], List[int]]: """ Match predicted regions to GT regions by overlap. Returns: gt_to_pred: For each GT region, index of best matching pred region (-1 if none) pred_to_gt: For each pred region, index of best matching GT region (-1 if none) """ gt_to_pred = [] for gt_start, gt_end in gt_regions: best_pred_idx = -1 best_overlap = 0 for pred_idx, (pred_start, pred_end) in enumerate(pred_regions): overlap_start = max(gt_start, pred_start) overlap_end = min(gt_end, pred_end) overlap = max(0, overlap_end - overlap_start) if overlap > best_overlap: best_overlap = overlap best_pred_idx = pred_idx gt_to_pred.append(best_pred_idx) pred_to_gt = [] for pred_start, pred_end in pred_regions: best_gt_idx = -1 best_overlap = 0 for gt_idx, (gt_start, gt_end) in enumerate(gt_regions): overlap_start = max(gt_start, pred_start) overlap_end = min(gt_end, pred_end) overlap = max(0, overlap_end - overlap_start) if overlap > best_overlap: best_overlap = overlap best_gt_idx = gt_idx pred_to_gt.append(best_gt_idx) return gt_to_pred, pred_to_gt def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, bbox: torch.Tensor, tokens: Optional[List[List[str]]] = None, subword_to_word: Optional[List[Dict[int, int]]] = None, query_texts: Optional[List[str]] = None, gt_answer_spans: Optional[List[List[Tuple[int, int]]]] = None, gt_question_spans: Optional[List[List[Optional[Tuple[int, int]]]]] = None, target_field_idx: Optional[List[int]] = None, training: bool = True, ) -> Dict[str, torch.Tensor]: """ Forward pass with CASCADING architecture and ZERO-SHOT query matching. Key features: - Step 1: Token classification to predict Q/A regions - Step 1.5: Q-A linking using LiLT embeddings + spatial features - Step 2: ZERO-SHOT query→question matching using Sentence Transformer - GT spans are only used to compute labels (which predicted region is correct) Args: input_ids: (batch, seq) token IDs attention_mask: (batch, seq) attention mask bbox: (batch, seq, 4) bounding boxes tokens: List of word tokens per batch item (for zero-shot Q text encoding) subword_to_word: List of dicts mapping subword idx to word idx query_texts: List of raw query strings (for sentence transformer) gt_answer_spans: GT answer spans per batch item (for loss labels only) gt_question_spans: GT question spans per batch item (for loss labels only) target_field_idx: Index of GT field that query should match (for loss labels) training: whether in training mode (affects Q masking for Q-A linker) """ batch_size = input_ids.shape[0] device = input_ids.device # Encode document hidden_states = self.encode(input_ids, attention_mask, bbox) # Step 1: Token Classification token_logits = self.token_classifier(hidden_states) # Step 3: Table Head table_outputs = self.table_head(hidden_states) outputs = { "hidden_states": hidden_states, "token_logits": token_logits, **table_outputs, } # Step 1.5 + Step 2: Q-A Linking and Query-Question Matching # CASCADING: Extract Q/A regions from Step 1 PREDICTIONS if query_texts is not None: # Extract predicted Q and A regions from Step 1 output pred_q_regions, pred_a_regions = self.extract_qa_regions(token_logits, attention_mask) outputs["pred_q_regions"] = pred_q_regions outputs["pred_a_regions"] = pred_a_regions # Encode query using sentence transformer query_emb = self.encode_texts(query_texts, device) outputs["query_embedding"] = query_emb # Process each sample in batch using PREDICTED regions all_q_embeds_lilt = [] # Q embeddings from LiLT (768 dim) for QA Linker all_q_embeds_st = [] # Q embeddings from sentence transformer (384 dim) for query matching all_a_embeds = [] # A embeddings from predicted regions (768 dim) all_q_bboxes = [] # Q bboxes for QA linker all_a_bboxes = [] # A bboxes for QA linker all_pred_q_from_a = [] # Predicted Q from A (for QuestionPredictor loss) all_real_q_from_pred = [] # Real Q from predicted regions (for Q masking loss) all_real_q_embeds = [] # Real Q embeds from GT (for aux loss) all_gt_to_pred_a_idx = [] # Maps GT field idx to predicted A region idx for b in range(batch_size): sample_pred_q = pred_q_regions[b] # Predicted Q spans from Step 1 sample_pred_a = pred_a_regions[b] # Predicted A spans from Step 1 # Skip if no predicted regions if len(sample_pred_a) == 0: all_q_embeds_lilt.append(None) all_q_embeds_st.append(None) all_a_embeds.append(None) all_q_bboxes.append(None) all_a_bboxes.append(None) all_pred_q_from_a.append(None) all_real_q_from_pred.append(None) all_real_q_embeds.append(None) all_gt_to_pred_a_idx.append(None) continue # Pool A embeddings from predicted regions sample_a_embeds = [] sample_a_bboxes = [] for a_span in sample_pred_a: a_emb = self.pool_single_span(hidden_states[b], a_span, bbox[b]) sample_a_embeds.append(a_emb) sample_a_bboxes.append(self.get_span_bbox(bbox[b], a_span)) # Two Q representations: # 1. q_embeds_lilt (768) - LiLT pooled, for QA Linker # 2. q_embeds_st (384) - Sentence transformer, for query matching sample_q_embeds_lilt = [] # For QA Linker (768 dim) sample_q_embeds_st = [] # For query matching (384 dim) sample_q_bboxes = [] sample_real_q_from_pred = [] # Real Q embeds (ST) for QuestionPredictor loss sample_pred_q_from_a = [] # Predicted Q embeds from A (for loss) if len(sample_pred_q) > 0: # We have predicted Q regions # Get Q texts for sentence transformer encoding q_texts = [] if tokens is not None and subword_to_word is not None: for q_span in sample_pred_q: q_text = self.extract_span_text(q_span, tokens[b], subword_to_word[b]) q_texts.append(q_text if q_text else "unknown") # Batch encode all Q texts with sentence transformer real_q_embeds_st = self.encode_texts(q_texts, device) # (num_q, 384) else: real_q_embeds_st = None for i, q_span in enumerate(sample_pred_q): # LiLT embedding for QA Linker q_emb_lilt = self.pool_single_span(hidden_states[b], q_span, bbox[b]) sample_q_embeds_lilt.append(q_emb_lilt) q_bbox = self.get_span_bbox(bbox[b], q_span) sample_q_bboxes.append(q_bbox) # Sentence transformer embedding for query matching if real_q_embeds_st is not None: real_q_emb_st = real_q_embeds_st[i] # 384 dim else: real_q_emb_st = None # Q MASKING: during training, randomly mask Q and use predictor should_mask = training and (torch.rand(1).item() < self.q_mask_prob) if should_mask and i < len(sample_a_embeds): # Mask: use QuestionPredictor instead of real Q for query matching pred_q_emb = self.question_predictor(sample_a_embeds[i]) # 384 dim sample_q_embeds_st.append(pred_q_emb) sample_real_q_from_pred.append(real_q_emb_st) # Real Q for loss sample_pred_q_from_a.append(pred_q_emb) # Predicted Q for loss else: # No mask: use real Q (sentence transformer encoded) sample_q_embeds_st.append(real_q_emb_st if real_q_emb_st is not None else torch.zeros(self.query_embed_dim, device=device)) sample_real_q_from_pred.append(real_q_emb_st) sample_pred_q_from_a.append(None) # No prediction, no loss else: # No Q regions predicted - use QuestionPredictor on all A for a_emb in sample_a_embeds: pred_q = self.question_predictor(a_emb) # 384 dim sample_q_embeds_st.append(pred_q) sample_pred_q_from_a.append(pred_q) sample_real_q_from_pred.append(None) # No real Q from prediction # No LiLT Q embeds when no Q predicted sample_q_embeds_lilt = [] # Use A bboxes as proxy for Q bboxes sample_q_bboxes = sample_a_bboxes.copy() # Store real Q embeds from GT (for aux loss when Q was masked) # Encode GT Q text with sentence transformer (384 dim) sample_real_q = [] if gt_question_spans is not None and b < len(gt_question_spans) and tokens is not None and subword_to_word is not None: gt_q_texts = [] gt_q_valid_indices = [] for idx, gt_q_span in enumerate(gt_question_spans[b]): if gt_q_span is not None: q_text = self.extract_span_text(gt_q_span, tokens[b], subword_to_word[b]) gt_q_texts.append(q_text if q_text else "unknown") gt_q_valid_indices.append(idx) # Batch encode GT Q texts if gt_q_texts: gt_q_embeds = self.encode_texts(gt_q_texts, device) # (num_valid, 384) embed_idx = 0 for idx, gt_q_span in enumerate(gt_question_spans[b]): if gt_q_span is not None: sample_real_q.append(gt_q_embeds[embed_idx]) embed_idx += 1 else: sample_real_q.append(None) else: sample_real_q = [None] * len(gt_question_spans[b]) # Match GT A spans to predicted A regions (for label computation) gt_to_pred_a = None if gt_answer_spans is not None and b < len(gt_answer_spans): gt_a_spans = gt_answer_spans[b] gt_to_pred_a, _ = self.match_regions_to_gt(sample_pred_a, gt_a_spans) all_q_embeds_lilt.append(torch.stack(sample_q_embeds_lilt) if sample_q_embeds_lilt else None) all_q_embeds_st.append(torch.stack(sample_q_embeds_st) if sample_q_embeds_st else None) all_a_embeds.append(torch.stack(sample_a_embeds) if sample_a_embeds else None) all_q_bboxes.append(torch.stack(sample_q_bboxes) if sample_q_bboxes else None) all_a_bboxes.append(torch.stack(sample_a_bboxes) if sample_a_bboxes else None) # For Q predictor loss: parallel lists with None for non-masked positions all_pred_q_from_a.append(sample_pred_q_from_a if sample_pred_q_from_a else None) all_real_q_from_pred.append(sample_real_q_from_pred if sample_real_q_from_pred else None) all_real_q_embeds.append(sample_real_q if sample_real_q else None) all_gt_to_pred_a_idx.append(gt_to_pred_a) # Filter out None entries and compute outputs # Use LiLT Q embeds for valid check (QA Linker needs them) valid_indices_lilt = [i for i, q in enumerate(all_q_embeds_lilt) if q is not None] # Also track valid ST Q embeds for query matching valid_indices_st = [i for i, q in enumerate(all_q_embeds_st) if q is not None] if valid_indices_lilt: # Get max spans for padding (use LiLT Q for QA Linker) max_q_spans_lilt = max(all_q_embeds_lilt[i].shape[0] for i in valid_indices_lilt) max_a_spans = max(all_a_embeds[i].shape[0] for i in valid_indices_lilt) # Create padded tensors for Q (LiLT, 768 dim) - for QA Linker q_embeds_lilt_padded = torch.zeros(batch_size, max_q_spans_lilt, self.hidden_size, device=device) q_bboxes_padded = torch.zeros(batch_size, max_q_spans_lilt, 4, device=device) q_span_mask = torch.zeros(batch_size, max_q_spans_lilt, dtype=torch.bool, device=device) # Create padded tensors for A (LiLT, 768 dim) a_embeds_padded = torch.zeros(batch_size, max_a_spans, self.hidden_size, device=device) a_bboxes_padded = torch.zeros(batch_size, max_a_spans, 4, device=device) a_span_mask = torch.zeros(batch_size, max_a_spans, dtype=torch.bool, device=device) for b in valid_indices_lilt: nq = all_q_embeds_lilt[b].shape[0] q_embeds_lilt_padded[b, :nq] = all_q_embeds_lilt[b] q_bboxes_padded[b, :nq] = all_q_bboxes[b] q_span_mask[b, :nq] = True na = all_a_embeds[b].shape[0] a_embeds_padded[b, :na] = all_a_embeds[b] a_bboxes_padded[b, :na] = all_a_bboxes[b] a_span_mask[b, :na] = True # Step 1.5: Q-A Linker - predict which Q links to which A (uses LiLT embeds) qa_link_scores_list = [] for b in valid_indices_lilt: nq = all_q_embeds_lilt[b].shape[0] na = all_a_embeds[b].shape[0] if nq > 0 and na > 0: link_scores = self.qa_linker( all_q_embeds_lilt[b], # (num_q, 768) LiLT all_a_embeds[b], # (num_a, 768) LiLT all_q_bboxes[b], # (num_q, 4) all_a_bboxes[b], # (num_a, 4) ) qa_link_scores_list.append(link_scores) else: qa_link_scores_list.append(None) outputs["qa_link_scores"] = qa_link_scores_list # Create padded tensors for Q (sentence transformer, 384 dim) - for query matching if valid_indices_st: max_q_spans_st = max(all_q_embeds_st[i].shape[0] for i in valid_indices_st) q_embeds_st_padded = torch.zeros(batch_size, max_q_spans_st, self.query_embed_dim, device=device) q_span_mask_st = torch.zeros(batch_size, max_q_spans_st, dtype=torch.bool, device=device) for b in valid_indices_st: nq = all_q_embeds_st[b].shape[0] q_embeds_st_padded[b, :nq] = all_q_embeds_st[b] q_span_mask_st[b, :nq] = True outputs["q_embeds_st"] = q_embeds_st_padded # For query matching outputs["q_span_mask_st"] = q_span_mask_st # GT-based QA link scores (for training - uses GT spans directly, LiLT embeddings) if training and gt_question_spans is not None and gt_answer_spans is not None: gt_qa_link_scores_list = [] gt_valid_q_indices_list = [] # Track which Q indices are valid (not None) for b in range(batch_size): gt_q_spans = gt_question_spans[b] if b < len(gt_question_spans) else [] gt_a_spans = gt_answer_spans[b] if b < len(gt_answer_spans) else [] # Filter valid Q spans (not None) and track indices valid_q_indices = [i for i, q in enumerate(gt_q_spans) if q is not None] valid_q_spans = [gt_q_spans[i] for i in valid_q_indices] if len(valid_q_spans) > 0 and len(gt_a_spans) > 0: # Pool embeddings from GT spans (LiLT, 768 dim) gt_q_embeds = torch.stack([ self.pool_single_span(hidden_states[b], q_span, bbox[b]) for q_span in valid_q_spans ]) gt_a_embeds = torch.stack([ self.pool_single_span(hidden_states[b], a_span, bbox[b]) for a_span in gt_a_spans ]) gt_q_bboxes = torch.stack([ self.get_span_bbox(bbox[b], q_span) for q_span in valid_q_spans ]) gt_a_bboxes = torch.stack([ self.get_span_bbox(bbox[b], a_span) for a_span in gt_a_spans ]) # Compute QA link scores on GT gt_link_scores = self.qa_linker( gt_q_embeds, gt_a_embeds, gt_q_bboxes, gt_a_bboxes ) gt_qa_link_scores_list.append(gt_link_scores) gt_valid_q_indices_list.append(valid_q_indices) else: gt_qa_link_scores_list.append(None) gt_valid_q_indices_list.append(None) outputs["gt_qa_link_scores"] = gt_qa_link_scores_list outputs["gt_valid_q_indices"] = gt_valid_q_indices_list # Step 2: Query-Question Matching (ZERO-SHOT with Sentence Transformer) # Use pre-computed q_embeds_st for query matching if valid_indices_st and "q_embeds_st" in outputs: q_embeds_st_padded = outputs["q_embeds_st"] q_span_mask_st = outputs["q_span_mask_st"] # Compute cosine similarity between query and Q embeddings query_norm = F.normalize(query_emb, dim=-1) q_norm = F.normalize(q_embeds_st_padded, dim=-1) match_scores = torch.bmm(q_norm, query_norm.unsqueeze(-1)).squeeze(-1) match_scores = match_scores.masked_fill(~q_span_mask_st, float('-inf')) outputs["match_scores"] = match_scores outputs["match_span_mask"] = q_span_mask_st # Store LiLT embeddings for QA Linker if valid_indices_lilt: outputs["q_span_mask"] = q_span_mask outputs["a_span_mask"] = a_span_mask outputs["q_embeds"] = q_embeds_lilt_padded outputs["a_embeds"] = a_embeds_padded outputs["q_bboxes"] = q_bboxes_padded outputs["a_bboxes"] = a_bboxes_padded # Store for loss computation (Q masking / QuestionPredictor) outputs["pred_q_from_a"] = all_pred_q_from_a outputs["real_q_from_pred"] = all_real_q_from_pred # Real Q (ST) when masked outputs["real_q_embeds"] = all_real_q_embeds # GT Q embeds (ST) outputs["gt_to_pred_a_idx"] = all_gt_to_pred_a_idx outputs["valid_batch_indices_lilt"] = valid_indices_lilt outputs["valid_batch_indices_st"] = valid_indices_st return outputs def compute_loss( outputs: Dict[str, torch.Tensor], token_labels: torch.Tensor, attention_mask: torch.Tensor, table_labels: Optional[torch.Tensor] = None, row_labels: Optional[torch.Tensor] = None, header_labels: Optional[torch.Tensor] = None, match_labels: Optional[torch.Tensor] = None, col_labels: Optional[torch.Tensor] = None, class_weights: Optional[torch.Tensor] = None, qa_link_labels: Optional[List[torch.Tensor]] = None, # NEW: Q-A link labels ) -> Dict[str, torch.Tensor]: """Compute joint loss for all steps including Q-A linking.""" device = token_labels.device losses = {} # Step 1: Token Classification token_logits = outputs["token_logits"] token_logits_flat = token_logits.view(-1, token_logits.shape[-1]) token_labels_flat = token_labels.view(-1) attn_flat = attention_mask.view(-1).bool() if class_weights is not None: step1_loss = F.cross_entropy( token_logits_flat[attn_flat], token_labels_flat[attn_flat], weight=class_weights, ) else: step1_loss = F.cross_entropy( token_logits_flat[attn_flat], token_labels_flat[attn_flat], ) losses["step1_loss"] = step1_loss # Step 1.5: Q-A Linker Loss (uses GT spans directly - no dependency on Step 1 predictions) if "gt_qa_link_scores" in outputs and qa_link_labels is not None: qa_link_losses = [] gt_scores_list = outputs["gt_qa_link_scores"] gt_valid_q_indices = outputs.get("gt_valid_q_indices", [None] * len(gt_scores_list)) for scores, labels, valid_indices in zip(gt_scores_list, qa_link_labels, gt_valid_q_indices): if scores is None or labels is None or valid_indices is None: continue # Filter labels to only include valid Q indices (matching scores shape) filtered_labels = labels[valid_indices] if len(valid_indices) > 0 else labels if scores.numel() > 0 and filtered_labels.numel() > 0: # scores: (num_valid_q, num_gt_a), filtered_labels: (num_valid_q,) num_a = scores.shape[1] valid_mask = (filtered_labels >= 0) & (filtered_labels < num_a) if valid_mask.any(): link_loss = F.cross_entropy(scores[valid_mask], filtered_labels[valid_mask]) qa_link_losses.append(link_loss) if qa_link_losses: losses["qa_link_loss"] = torch.stack(qa_link_losses).mean() else: losses["qa_link_loss"] = torch.tensor(0.0, device=device) else: losses["qa_link_loss"] = torch.tensor(0.0, device=device) # Question Predictor Loss (MSE between predicted Q and real Q in Sentence Transformer space) # QuestionPredictor learns: A_lilt_embed (768) → Q_st_embed (384) # Training: when Q is masked, compare predicted Q with real Q (sentence transformer) if "pred_q_from_a" in outputs and "real_q_from_pred" in outputs: pred_q_from_a = outputs["pred_q_from_a"] # List of lists real_q_from_pred = outputs["real_q_from_pred"] # List of lists q_predict_losses = [] for batch_pred, batch_real in zip(pred_q_from_a, real_q_from_pred): if batch_pred is None or batch_real is None: continue for pred_q, real_q in zip(batch_pred, batch_real): if pred_q is not None and real_q is not None: # MSE loss to make predicted Q similar to real Q mse = F.mse_loss(pred_q, real_q.detach()) q_predict_losses.append(mse) if q_predict_losses: losses["q_predict_loss"] = torch.stack(q_predict_losses).mean() else: losses["q_predict_loss"] = torch.tensor(0.0, device=device) else: losses["q_predict_loss"] = torch.tensor(0.0, device=device) # Step 2: Query-Question Matching - ZERO-SHOT (no loss) # Zero-shot Q text matching - no training needed # QuestionPredictor IS trained (for A→Q prediction when no Q regions) losses["step2_loss"] = torch.tensor(0.0, device=device) # Step 3: Table Losses if table_labels is not None: table_logits = outputs["table_logits"] table_logits_flat = table_logits.view(-1, 2) table_labels_flat = table_labels.view(-1) table_det_loss = F.cross_entropy( table_logits_flat[attn_flat], table_labels_flat[attn_flat], ) losses["table_det_loss"] = table_det_loss else: losses["table_det_loss"] = torch.tensor(0.0, device=device) if header_labels is not None and table_labels is not None: header_logits = outputs["header_logits"] table_mask = (table_labels == 1) & attention_mask.bool() if table_mask.any(): header_logits_flat = header_logits.view(-1, 2) header_labels_flat = header_labels.view(-1) table_mask_flat = table_mask.view(-1) header_loss = F.cross_entropy( header_logits_flat[table_mask_flat], header_labels_flat[table_mask_flat], ) losses["header_loss"] = header_loss else: losses["header_loss"] = torch.tensor(0.0, device=device) else: losses["header_loss"] = torch.tensor(0.0, device=device) if row_labels is not None and table_labels is not None: row_logits = outputs["row_logits"] table_mask = (table_labels == 1) & attention_mask.bool() if table_mask.any(): row_logits_flat = row_logits.view(-1, 3) row_labels_flat = row_labels.view(-1) table_mask_flat = table_mask.view(-1) row_loss = F.cross_entropy( row_logits_flat[table_mask_flat], row_labels_flat[table_mask_flat], ) losses["row_loss"] = row_loss else: losses["row_loss"] = torch.tensor(0.0, device=device) else: losses["row_loss"] = torch.tensor(0.0, device=device) if col_labels is not None and table_labels is not None and header_labels is not None: col_scores = outputs["col_scores"] cell_mask = (table_labels == 1) & (header_labels == 0) & attention_mask.bool() if cell_mask.any(): col_scores_flat = col_scores.view(-1, col_scores.shape[-1]) col_labels_flat = col_labels.view(-1) cell_mask_flat = cell_mask.view(-1) col_loss = F.cross_entropy( col_scores_flat[cell_mask_flat], col_labels_flat[cell_mask_flat], ) losses["col_loss"] = col_loss else: losses["col_loss"] = torch.tensor(0.0, device=device) else: losses["col_loss"] = torch.tensor(0.0, device=device) # Total loss with weighted components total_loss = ( losses["step1_loss"] + # Token classification (main task) losses["qa_link_loss"] * 0.3 + # Q-A linker (reduced to not compete with step1) losses["q_predict_loss"] * 0.2 + # Q prediction aux (reduced) losses["step2_loss"] * 0.5 + # Query-Q matching (reduced) losses["table_det_loss"] * 0.5 + # Table detection losses["header_loss"] * 0.3 + # Header detection losses["row_loss"] * 0.3 + # Row segmentation losses["col_loss"] * 0.3 # Column assignment ) losses["loss"] = total_loss return losses def compute_contrastive_loss( query_embedding: torch.Tensor, # (batch, hidden) span_embeddings: torch.Tensor, # (batch, num_spans, span_dim) span_mask: torch.Tensor, # (batch, num_spans) match_labels: torch.Tensor, # (batch,) - index of correct span temperature: float = 0.07, hard_negative_weight: float = 2.0, ) -> torch.Tensor: """ Contrastive loss for query-answer matching. InfoNCE-style loss that: - Pulls query towards correct answer span - Pushes query away from all other spans (in-batch negatives) - Applies extra weight to hard negatives (high-scoring wrong answers) Args: query_embedding: Query [CLS] embeddings span_embeddings: Pooled span embeddings [h_start; h_end; bbox] span_mask: Valid span mask match_labels: Index of correct span for each query temperature: Softmax temperature (lower = harder) hard_negative_weight: Extra weight for hard negatives Returns: Contrastive loss scalar """ batch_size = query_embedding.shape[0] device = query_embedding.device if batch_size == 0: return torch.tensor(0.0, device=device) # Project spans to same dimension as query if needed # (This is already done by QueryAnswerMatcher, so we use the projected scores) # Here we compute a simpler version using the match_scores from forward # For now, use a margin-based contrastive approach # We want: sim(q, correct_span) > sim(q, wrong_span) + margin losses = [] for b in range(batch_size): valid_spans = span_mask[b].sum().item() if valid_spans <= 1: continue correct_idx = match_labels[b].item() if correct_idx >= valid_spans: continue # Get span embeddings for this sample spans = span_embeddings[b, :int(valid_spans), :] # (num_valid, dim) query = query_embedding[b] # (hidden,) # Simple approach: normalize and compute similarities # The projection happens in QueryAnswerMatcher, so we compute raw similarities here # This loss is auxiliary to the CE loss # Normalize for cosine similarity spans_norm = F.normalize(spans[:, :query.shape[0]], dim=-1) # Use first hidden_size dims query_norm = F.normalize(query, dim=-1) # Compute similarities sims = torch.matmul(spans_norm, query_norm) / temperature # (num_valid,) # Create target (one-hot) target = torch.zeros(int(valid_spans), device=device) target[correct_idx] = 1.0 # InfoNCE: -log(exp(sim_pos) / sum(exp(sim_all))) # With hard negative weighting weights = torch.ones(int(valid_spans), device=device) weights[correct_idx] = 1.0 # Hard negatives: spans with high similarity but wrong with torch.no_grad(): hard_neg_mask = (sims > sims[correct_idx] - 0.5) & (torch.arange(int(valid_spans), device=device) != correct_idx) weights[hard_neg_mask] = hard_negative_weight # Weighted softmax cross-entropy log_probs = F.log_softmax(sims, dim=0) loss = -log_probs[correct_idx] # Add margin loss for hard negatives for i in range(int(valid_spans)): if i != correct_idx and hard_neg_mask[i] if hard_neg_mask.any() else False: margin_loss = F.relu(sims[i] - sims[correct_idx] + 0.3) loss = loss + 0.1 * margin_loss losses.append(loss) if not losses: return torch.tensor(0.0, device=device) return torch.stack(losses).mean() def compute_loss_with_contrastive( outputs: Dict[str, torch.Tensor], token_labels: torch.Tensor, attention_mask: torch.Tensor, table_labels: Optional[torch.Tensor] = None, row_labels: Optional[torch.Tensor] = None, header_labels: Optional[torch.Tensor] = None, match_labels: Optional[torch.Tensor] = None, col_labels: Optional[torch.Tensor] = None, class_weights: Optional[torch.Tensor] = None, qa_link_labels: Optional[List[torch.Tensor]] = None, contrastive_weight: float = 0.5, ) -> Dict[str, torch.Tensor]: """ Compute joint loss with contrastive learning for Step 2. Adds InfoNCE-style contrastive loss to help learn better span representations. """ # Get base losses losses = compute_loss( outputs=outputs, token_labels=token_labels, attention_mask=attention_mask, table_labels=table_labels, row_labels=row_labels, header_labels=header_labels, match_labels=match_labels, col_labels=col_labels, class_weights=class_weights, qa_link_labels=qa_link_labels, ) # Contrastive loss for Step 2 is DISABLED for zero-shot matching # Zero-shot matching - no contrastive loss needed losses["contrastive_loss"] = torch.tensor(0.0, device=token_labels.device) return losses if __name__ == "__main__": print("Testing DEXTR LiLT model with CASCADING architecture...") print("=" * 60) model = DEXTRLiLT( model_name="nielsr/lilt-xlm-roberta-base", max_seq_len=1024, ) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Total parameters: {total_params:,}") print(f"Trainable parameters: {trainable_params:,}") # Test forward batch_size = 2 seq_len = 128 query_len = 16 input_ids = torch.randint(0, 1000, (batch_size, seq_len)) attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long) # Create valid bboxes: [x1, y1, x2, y2] where x1 0: match_labels = torch.zeros(batch_size, dtype=torch.long) losses = compute_loss( outputs, token_labels=token_labels, attention_mask=attention_mask, table_labels=table_labels, row_labels=row_labels, header_labels=header_labels, match_labels=match_labels, ) print(f"\nLosses:") for name, value in losses.items(): print(f" {name}: {value.item():.4f}") print("\n" + "=" * 60) print("DEXTR LiLT CASCADING architecture test passed!")