|
|
""" |
|
|
DEXTR with LiLT - 3-Step Document Extraction. |
|
|
|
|
|
Architecture: |
|
|
- Document encoder: LiLT (XLM-RoBERTa + Layout Transformer), extended to 1024 tokens |
|
|
- Query encoder: Sentence Transformer (frozen, for semantic similarity) |
|
|
- Step 1: Token Classification (Q/A/H/TABLE/O) - TRAINED
- Step 1.5: Q-A Linker (links predicted Question spans to Answer spans) - TRAINED
|
|
- Step 2: Query-Question Matching (ZERO-SHOT) - NO TRAINING |
|
|
- Step 3: Table Head (hierarchical with attention-based column assignment) - TRAINED |
|
|
|
|
|
Key insight: Query→Question matching is semantic similarity. Sentence transformers |
|
|
provide proper semantic embeddings for matching queries to question text. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import LiltModel, LiltConfig |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from typing import Optional, Dict, List, Tuple |
|
|
|
|
|
|
|
|
|
|
|
LABEL2ID = { |
|
|
"O": 0, |
|
|
"B-Q": 1, "I-Q": 2, |
|
|
"B-A": 3, "I-A": 4, |
|
|
"B-H": 5, "I-H": 6, |
|
|
"B-TABLE": 7, "I-TABLE": 8, |
|
|
} |
|
|
ID2LABEL = {v: k for k, v in LABEL2ID.items()} |
|
|
NUM_LABELS = len(LABEL2ID) |
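
# Illustrative only (hypothetical snippet): "Invoice No : 12345" could be tagged
#   tokens: ["Invoice", "No", ":", "12345"]
#   labels: ["B-Q",     "I-Q", "O", "B-A"]  ->  ids [1, 2, 0, 3]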
|
|
|
|
|
|
|
|
class TokenClassificationHead(nn.Module): |
|
|
""" |
|
|
Step 1: Token Classification Head. |
|
|
    Predicts Q/A/H/TABLE/O labels for each token (FUNSD-style BIO tags).
|
|
""" |
|
|
|
|
|
def __init__(self, hidden_size: int = 768, num_labels: int = NUM_LABELS, dropout: float = 0.1): |
|
|
super().__init__() |
|
|
self.classifier = nn.Sequential( |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(hidden_size, 256), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(256, num_labels), |
|
|
) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
|
return self.classifier(hidden_states) |
|
|
|
|
|
|
|
|
class QALinker(nn.Module): |
|
|
""" |
|
|
Step 1.5: Q-A Linker. |
|
|
Predicts which Answer span links to which Question span. |
|
|
Uses both semantic similarity and spatial features. |
|
|
""" |
|
|
|
|
|
def __init__(self, hidden_size: int = 768, dropout: float = 0.1): |
|
|
super().__init__() |
|
|
self.hidden_size = hidden_size |
|
|
|
|
|
|
|
|
self.q_proj = nn.Sequential( |
|
|
nn.Linear(hidden_size, 256), |
|
|
nn.LayerNorm(256), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
) |
|
|
self.a_proj = nn.Sequential( |
|
|
nn.Linear(hidden_size, 256), |
|
|
nn.LayerNorm(256), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
self.spatial_scorer = nn.Sequential( |
|
|
nn.Linear(7, 32), |
|
|
nn.GELU(), |
|
|
nn.Linear(32, 1), |
|
|
) |
|
|
|
|
|
|
|
|
self.log_temp = nn.Parameter(torch.tensor(1.0)) |
|
|
|
|
|
def compute_spatial_features( |
|
|
self, |
|
|
q_bboxes: torch.Tensor, |
|
|
a_bboxes: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Compute spatial features for all Q-A pairs. |
|
|
Returns: (num_q, num_a, 7) feature tensor |
|
|
""" |
|
|
num_q = q_bboxes.shape[0] |
|
|
num_a = a_bboxes.shape[0] |
|
|
|
|
|
|
|
|
|
|
q_cx = (q_bboxes[:, 0] + q_bboxes[:, 2]) / 2 |
|
|
q_cy = (q_bboxes[:, 1] + q_bboxes[:, 3]) / 2 |
|
|
q_w = q_bboxes[:, 2] - q_bboxes[:, 0] |
|
|
q_h = q_bboxes[:, 3] - q_bboxes[:, 1] |
|
|
|
|
|
a_cx = (a_bboxes[:, 0] + a_bboxes[:, 2]) / 2 |
|
|
a_cy = (a_bboxes[:, 1] + a_bboxes[:, 3]) / 2 |
|
|
a_w = a_bboxes[:, 2] - a_bboxes[:, 0] |
|
|
a_h = a_bboxes[:, 3] - a_bboxes[:, 1] |
|
|
|
|
|
|
|
|
q_cx = q_cx.unsqueeze(1).expand(num_q, num_a) |
|
|
q_cy = q_cy.unsqueeze(1).expand(num_q, num_a) |
|
|
q_x2 = q_bboxes[:, 2].unsqueeze(1).expand(num_q, num_a) |
|
|
q_y2 = q_bboxes[:, 3].unsqueeze(1).expand(num_q, num_a) |
|
|
q_w = q_w.unsqueeze(1).expand(num_q, num_a) |
|
|
q_h = q_h.unsqueeze(1).expand(num_q, num_a) |
|
|
|
|
|
a_cx = a_cx.unsqueeze(0).expand(num_q, num_a) |
|
|
a_cy = a_cy.unsqueeze(0).expand(num_q, num_a) |
|
|
a_x1 = a_bboxes[:, 0].unsqueeze(0).expand(num_q, num_a) |
|
|
a_y1 = a_bboxes[:, 1].unsqueeze(0).expand(num_q, num_a) |
|
|
a_w = a_w.unsqueeze(0).expand(num_q, num_a) |
|
|
a_h = a_h.unsqueeze(0).expand(num_q, num_a) |
|
|
|
|
|
|
|
|
x_dist = (a_x1 - q_x2) / 1000.0 |
|
|
y_dist = (a_cy - q_cy).abs() / 1000.0 |
|
|
is_right = (a_x1 > q_x2).float() |
|
|
is_below = (a_y1 > q_y2).float() |
|
|
is_same_line = (y_dist < 0.03).float() |
|
|
width_ratio = (a_w / (q_w + 1e-6)).clamp(0, 5) / 5.0 |
|
|
height_ratio = (a_h / (q_h + 1e-6)).clamp(0, 5) / 5.0 |
|
|
|
|
|
|
|
|
features = torch.stack([ |
|
|
x_dist, y_dist, is_right, is_below, is_same_line, width_ratio, height_ratio |
|
|
], dim=-1) |
|
|
|
|
|
return features |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
q_embeds: torch.Tensor, |
|
|
a_embeds: torch.Tensor, |
|
|
q_bboxes: torch.Tensor, |
|
|
a_bboxes: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Compute Q-A link scores. |
|
|
Returns: (num_q, num_a) score matrix |
|
|
""" |
|
|
|
|
|
q_proj = self.q_proj(q_embeds) |
|
|
a_proj = self.a_proj(a_embeds) |
|
|
|
|
|
q_proj = F.normalize(q_proj, dim=-1) |
|
|
a_proj = F.normalize(a_proj, dim=-1) |
|
|
|
|
|
semantic_scores = torch.matmul(q_proj, a_proj.t()) |
|
|
|
|
|
|
|
|
spatial_feats = self.compute_spatial_features(q_bboxes, a_bboxes) |
|
|
spatial_scores = self.spatial_scorer(spatial_feats).squeeze(-1) |
|
|
|
|
|
|
|
|
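        # Learnable scale (an inverse temperature): multiplying sharpens or
        # flattens the score distribution; log-parameterization keeps it positive.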
temperature = self.log_temp.exp().clamp(min=0.1, max=10.0) |
|
|
combined_scores = (semantic_scores + spatial_scores) * temperature |
|
|
|
|
|
return combined_scores |
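

# Minimal usage sketch (illustrative; random placeholder tensors, not real data):
def _demo_qa_linker() -> torch.Tensor:
    """Link two question boxes to two answer boxes laid out to their right."""
    linker = QALinker()
    linker.eval()  # disable dropout for the sketch
    q_embeds = torch.randn(2, 768)
    a_embeds = torch.randn(2, 768)
    q_bboxes = torch.tensor([[100.0, 100.0, 200.0, 120.0],
                             [100.0, 200.0, 200.0, 220.0]])
    a_bboxes = torch.tensor([[220.0, 100.0, 320.0, 120.0],
                             [220.0, 200.0, 320.0, 220.0]])
    return linker(q_embeds, a_embeds, q_bboxes, a_bboxes)  # -> (2, 2) score matrix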
|
|
|
|
|
|
|
|
class QuestionPredictor(nn.Module): |
|
|
""" |
|
|
Predicts Question embedding from Answer embedding. |
|
|
Used when explicit Question tokens are missing in the document. |
|
|
|
|
|
Input: LiLT A embedding (768 dim) |
|
|
Output: Predicted Q embedding in sentence transformer space (384 dim) |
|
|
|
|
|
Training: |
|
|
- For Q-A pairs with Q: randomly mask Q and train to predict it |
|
|
- Learns what Q "should look like" given the answer |
|
|
""" |
|
|
|
|
|
def __init__(self, input_dim: int = 768, output_dim: int = 384, dropout: float = 0.1): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.predictor = nn.Sequential( |
|
|
nn.Linear(input_dim, input_dim), |
|
|
nn.LayerNorm(input_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(input_dim, output_dim), |
|
|
nn.LayerNorm(output_dim), |
|
|
) |
|
|
|
|
|
def forward(self, answer_embed: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Predict question embedding from answer embedding. |
|
|
Args: |
|
|
answer_embed: (batch, hidden) or (hidden,) - LiLT answer span embedding |
|
|
Returns: |
|
|
predicted_q_embed: (batch, output_dim) or (output_dim,) - in sentence transformer space |
|
|
""" |
|
|
return self.predictor(answer_embed) |
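

# Sketch (illustrative): map a 768-d LiLT answer-span embedding into the 384-d
# sentence-transformer space where queries live.
#   qp = QuestionPredictor()
#   q_hat = qp(torch.randn(768))   # -> (384,)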
|
|
|
|
|
|
|
|
class QueryAnswerMatcher(nn.Module): |
|
|
""" |
|
|
Step 2: Query-Answer Matching. |
|
|
Given answer span embeddings and query embedding, scores which span matches. |
|
|
|
|
|
Features: |
|
|
- Start+End pooling: spans use [h_start; h_end] instead of mean (captures boundaries) |
|
|
- Bbox features: spatial position helps distinguish similar text |
|
|
- Learnable temperature for score scaling |
|
|
- Cosine similarity with projection to lower dim |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
hidden_size: int = 768, |
|
|
proj_dim: int = 256, |
|
|
dropout: float = 0.1, |
|
|
use_bbox_features: bool = True, |
|
|
num_bbox_features: int = 6, |
|
|
): |
|
|
super().__init__() |
|
|
self.proj_dim = proj_dim |
|
|
self.use_bbox_features = use_bbox_features |
|
|
self.hidden_size = hidden_size |
|
|
|
|
|
|
|
|
span_input_dim = 2 * hidden_size |
|
|
if use_bbox_features: |
|
|
span_input_dim += num_bbox_features |
|
|
|
|
|
self.span_proj = nn.Sequential( |
|
|
nn.Linear(span_input_dim, proj_dim), |
|
|
nn.LayerNorm(proj_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
) |
|
|
self.query_proj = nn.Sequential( |
|
|
nn.Linear(hidden_size, proj_dim), |
|
|
nn.LayerNorm(proj_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
) |
|
|
|
|
|
self.log_temp = nn.Parameter(torch.tensor(proj_dim ** 0.5).log()) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
span_embeddings: torch.Tensor, |
|
|
query_embedding: torch.Tensor, |
|
|
span_mask: Optional[torch.Tensor] = None, |
|
|
) -> torch.Tensor: |
|
|
span_proj = self.span_proj(span_embeddings) |
|
|
query_proj = self.query_proj(query_embedding).unsqueeze(1) |
|
|
|
|
|
|
|
|
span_proj = F.normalize(span_proj, dim=-1) |
|
|
query_proj = F.normalize(query_proj, dim=-1) |
|
|
|
|
|
|
|
|
temperature = self.log_temp.exp().clamp(min=0.1, max=100.0) |
|
|
scores = torch.bmm(query_proj, span_proj.transpose(1, 2)).squeeze(1) * temperature |
|
|
|
|
|
if span_mask is not None: |
|
|
scores = scores.masked_fill(~span_mask.bool(), float('-inf')) |
|
|
return scores |
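

# QueryAnswerMatcher is a standalone, trainable alternative to the zero-shot
# matcher used by DEXTRLiLT below; a minimal sketch with the defaults, where a
# span vector is [h_start; h_end; 6 bbox feats] = 2 * 768 + 6 = 1542 dims:
def _demo_query_answer_matcher() -> torch.Tensor:
    matcher = QueryAnswerMatcher()
    matcher.eval()
    span_embeddings = torch.randn(1, 3, 2 * 768 + 6)   # 3 candidate spans
    query_embedding = torch.randn(1, 768)
    span_mask = torch.ones(1, 3)
    return matcher(span_embeddings, query_embedding, span_mask)  # -> (1, 3) scores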
|
|
|
|
|
|
|
|
class TableHead(nn.Module): |
|
|
""" |
|
|
Step 3: Hierarchical Table Extraction Head. |
|
|
|
|
|
Sub-steps: |
|
|
1. Table detection (TABLE/O per token) |
|
|
2. Header detection (HEADER/CELL per token) |
|
|
3. Row segmentation (B-ROW/I-ROW/O) |
|
|
4. Column assignment (cell → header via attention) |
|
|
""" |
|
|
|
|
|
def __init__(self, hidden_size: int = 768, dropout: float = 0.1): |
|
|
super().__init__() |
|
|
|
|
|
self.table_detector = nn.Sequential( |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(hidden_size, 256), |
|
|
nn.GELU(), |
|
|
nn.Linear(256, 2), |
|
|
) |
|
|
|
|
|
self.header_detector = nn.Sequential( |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(hidden_size, 256), |
|
|
nn.GELU(), |
|
|
nn.Linear(256, 2), |
|
|
) |
|
|
|
|
|
self.row_tagger = nn.Sequential( |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(hidden_size, 256), |
|
|
nn.GELU(), |
|
|
nn.Linear(256, 3), |
|
|
) |
|
|
|
|
|
self.col_key_proj = nn.Linear(hidden_size, 128) |
|
|
self.col_query_proj = nn.Linear(hidden_size, 128) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]: |
|
|
table_logits = self.table_detector(hidden_states) |
|
|
header_logits = self.header_detector(hidden_states) |
|
|
row_logits = self.row_tagger(hidden_states) |
|
|
|
|
|
col_keys = self.col_key_proj(hidden_states) |
|
|
col_queries = self.col_query_proj(hidden_states) |
|
|
col_scores = torch.bmm(col_queries, col_keys.transpose(1, 2)) / (128 ** 0.5) |
|
|
|
|
|
return { |
|
|
"table_logits": table_logits, |
|
|
"header_logits": header_logits, |
|
|
"row_logits": row_logits, |
|
|
"col_scores": col_scores, |
|
|
} |
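

# Sketch (illustrative): decode a hard column assignment from the attention
# scores by taking, for each token, the highest-scoring header token.
def _demo_table_head() -> torch.Tensor:
    head = TableHead()
    head.eval()
    out = head(torch.randn(1, 32, 768))       # batch of 1, 32 tokens
    return out["col_scores"].argmax(dim=-1)   # (1, 32): best header index per token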
|
|
|
|
|
|
|
|
class DEXTRLiLT(nn.Module): |
|
|
""" |
|
|
DEXTR with LiLT: Multi-Step Document Extraction with Zero-Shot Query Matching. |
|
|
|
|
|
Architecture: |
|
|
    - Step 1: Token Classification (Q/A/H/TABLE/O) - TRAINED
|
|
- Step 1.5: Q-A Linker (links Question spans to Answer spans) - TRAINED |
|
|
- Step 2: Query → Question Matching (ZERO-SHOT with Sentence Transformer) |
|
|
- Step 3: Table Head - TRAINED |
|
|
|
|
|
Uses SEPARATE encoders: |
|
|
- Document encoder: LiLT (XLM-RoBERTa + Layout Transformer) for layout-aware encoding |
|
|
- Query encoder: Sentence Transformer (frozen) for semantic similarity |
|
|
|
|
|
Zero-shot matching: |
|
|
- Query text encoded with frozen Sentence Transformer |
|
|
- Question text (from predicted Q regions) encoded with frozen Sentence Transformer |
|
|
- Direct cosine similarity (no trainable projections) |
|
|
- Fallback: if no Q regions, use QuestionPredictor on A text |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model_name: str = "nielsr/lilt-xlm-roberta-base", |
|
|
query_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", |
|
|
max_seq_len: int = 1024, |
|
|
hidden_size: int = 768, |
|
|
dropout: float = 0.1, |
|
|
q_mask_prob: float = 0.3, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.max_seq_len = max_seq_len |
|
|
self.hidden_size = hidden_size |
|
|
self.q_mask_prob = q_mask_prob |
|
|
|
|
|
|
|
|
if max_seq_len > 512: |
|
|
self.encoder = self._load_lilt_extended(model_name, max_seq_len) |
|
|
else: |
|
|
self.encoder = LiltModel.from_pretrained(model_name) |
|
|
|
|
|
|
|
|
|
|
|
self.query_encoder = SentenceTransformer(query_model_name) |
|
|
self.query_encoder.requires_grad_(False) |
|
|
self.query_embed_dim = self.query_encoder.get_sentence_embedding_dimension() |
|
|
print(f"Loaded frozen query encoder: {query_model_name} (dim={self.query_embed_dim})") |
|
|
|
|
|
|
|
|
self.token_classifier = TokenClassificationHead(hidden_size, NUM_LABELS, dropout) |
|
|
|
|
|
|
|
|
self.qa_linker = QALinker(hidden_size, dropout) |
|
|
|
|
|
|
|
|
|
|
|
self.question_predictor = QuestionPredictor(hidden_size, self.query_embed_dim, dropout) |
|
|
|
|
|
|
|
|
self.table_head = TableHead(hidden_size, dropout) |
|
|
|
|
|
def _load_lilt_extended(self, model_name: str, max_seq_len: int) -> LiltModel: |
|
|
"""Load LiLT with extended position embeddings.""" |
|
|
model = LiltModel.from_pretrained(model_name) |
|
|
original_max_pos = model.config.max_position_embeddings |
|
|
        required_positions = max_seq_len + 2  # +2 for the RoBERTa-style offset (positions start at padding_idx + 1)
|
|
|
|
|
if required_positions <= original_max_pos: |
|
|
return model |
|
|
|
|
|
|
|
|
old_pos_emb = model.embeddings.position_embeddings.weight.data |
|
|
hidden_size = old_pos_emb.shape[1] |
|
|
padding_idx = model.embeddings.position_embeddings.padding_idx |
|
|
|
|
|
new_pos_emb = nn.Embedding(required_positions, hidden_size, padding_idx=padding_idx) |
|
|
new_pos_emb.weight.data[:original_max_pos] = old_pos_emb |
|
|
new_pos_emb.weight.data[original_max_pos:].normal_(mean=0.0, std=0.02) |
|
|
|
|
|
model.embeddings.position_embeddings = new_pos_emb |
|
|
|
|
|
|
|
|
old_box_pos = model.layout_embeddings.box_position_embeddings.weight.data |
|
|
box_hidden = old_box_pos.shape[1] |
|
|
|
|
|
new_box_pos = nn.Embedding(required_positions, box_hidden) |
|
|
new_box_pos.weight.data[:original_max_pos] = old_box_pos |
|
|
new_box_pos.weight.data[original_max_pos:].normal_(mean=0.0, std=0.02) |
|
|
|
|
|
model.layout_embeddings.box_position_embeddings = new_box_pos |
|
|
|
|
|
|
|
|
model.config.max_position_embeddings = required_positions |
|
|
|
|
|
|
|
|
new_position_ids = torch.arange(required_positions).unsqueeze(0) |
|
|
model.embeddings.register_buffer("position_ids", new_position_ids, persistent=False) |
|
|
|
|
|
print(f"Extended LiLT positions: {original_max_pos} → {required_positions}") |
|
|
return model |
|
|
|
|
|
def encode( |
|
|
self, |
|
|
input_ids: torch.Tensor, |
|
|
attention_mask: torch.Tensor, |
|
|
bbox: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
"""Encode tokens with layout using shared LiLT encoder.""" |
|
|
outputs = self.encoder( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
bbox=bbox, |
|
|
) |
|
|
return outputs.last_hidden_state |
|
|
|
|
|
def encode_texts( |
|
|
self, |
|
|
texts: List[str], |
|
|
device: torch.device, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Encode text strings using frozen Sentence Transformer (zero-shot). |
|
|
|
|
|
Used for encoding Q text extracted from predicted regions. |
|
|
|
|
|
Args: |
|
|
texts: List of text strings to encode |
|
|
device: Device to put tensors on |
|
|
|
|
|
Returns: |
|
|
embeddings: (num_texts, embed_dim) tensor |
|
|
""" |
|
|
if not texts: |
|
|
return torch.zeros(0, self.query_embed_dim, device=device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
embeddings = self.query_encoder.encode( |
|
|
texts, |
|
|
convert_to_tensor=True, |
|
|
device=device, |
|
|
) |
|
|
|
|
|
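        # encode() above ran under no_grad; clone() returns an independent tensor
        # that downstream code can freely index and copy into padded buffers.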
embeddings = embeddings.clone() |
|
|
|
|
|
return embeddings |
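
    # Illustrative (not executed here): with the default MiniLM query encoder,
    #   self.encode_texts(["invoice number", "total amount"], device)
    # returns a (2, 384) tensor; an empty list yields a (0, 384) tensor.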
|
|
|
|
|
def pool_spans( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
span_indices: List[List[Tuple[int, int]]], |
|
|
bbox: Optional[torch.Tensor] = None, |
|
|
use_bbox_features: bool = True, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Pool hidden states for answer spans using start+end tokens. |
|
|
|
|
|
Returns [h_start; h_end; bbox_features] for each span instead of mean pooling. |
|
|
This preserves boundary information which is crucial for extraction tasks. |
|
|
|
|
|
Args: |
|
|
hidden_states: (batch, seq, hidden) - document encodings |
|
|
span_indices: List of (start, end) tuples per batch item |
|
|
bbox: (batch, seq, 4) - bounding boxes [x1, y1, x2, y2] normalized to [0, 1000] |
|
|
use_bbox_features: whether to include bbox in span representation |
|
|
|
|
|
Returns: |
|
|
span_embeddings: (batch, max_spans, span_dim) where span_dim = 2*hidden + 6 (if bbox) |
|
|
span_mask: (batch, max_spans) bool mask |
|
|
""" |
|
|
batch_size = hidden_states.shape[0] |
|
|
hidden_size = hidden_states.shape[2] |
|
|
device = hidden_states.device |
|
|
|
|
|
max_spans = max(len(spans) for spans in span_indices) if span_indices else 1 |
|
|
max_spans = max(max_spans, 1) |
|
|
|
|
|
|
|
|
span_dim = 2 * hidden_size |
|
|
if use_bbox_features and bbox is not None: |
|
|
span_dim += 6 |
|
|
|
|
|
span_embeddings = torch.zeros(batch_size, max_spans, span_dim, device=device) |
|
|
span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool, device=device) |
|
|
|
|
|
for b, spans in enumerate(span_indices): |
|
|
for s, (start, end) in enumerate(spans): |
|
|
if s >= max_spans: |
|
|
break |
|
|
if start >= hidden_states.shape[1] or end > hidden_states.shape[1]: |
|
|
continue |
|
|
|
|
|
|
|
|
h_start = hidden_states[b, start, :] |
|
|
h_end = hidden_states[b, end - 1, :] |
|
|
span_repr = torch.cat([h_start, h_end], dim=0) |
|
|
|
|
|
|
|
|
if use_bbox_features and bbox is not None: |
|
|
|
|
|
bbox_start = bbox[b, start, :] |
|
|
bbox_end = bbox[b, end - 1, :] |
|
|
|
|
|
|
|
|
span_x1 = torch.min(bbox_start[0], bbox_end[0]) |
|
|
span_y1 = torch.min(bbox_start[1], bbox_end[1]) |
|
|
span_x2 = torch.max(bbox_start[2], bbox_end[2]) |
|
|
span_y2 = torch.max(bbox_start[3], bbox_end[3]) |
|
|
|
|
|
|
|
|
bbox_feat = torch.tensor([ |
|
|
span_x1 / 1000.0, |
|
|
span_y1 / 1000.0, |
|
|
span_x2 / 1000.0, |
|
|
span_y2 / 1000.0, |
|
|
(span_x2 - span_x1) / 1000.0, |
|
|
(span_y2 - span_y1) / 1000.0, |
|
|
], device=device) |
|
|
|
|
|
span_repr = torch.cat([span_repr, bbox_feat], dim=0) |
|
|
|
|
|
span_embeddings[b, s] = span_repr |
|
|
span_mask[b, s] = True |
|
|
|
|
|
return span_embeddings, span_mask |
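
    # Illustrative (hypothetical spans): with hidden_size=768 and bbox features,
    #   pool_spans(h, [[(5, 10)], [(8, 12)]], bbox) -> (2, 1, 1542) embeds, (2, 1) mask
    # where 1542 = 2*768 boundary states + 6 normalized bbox features.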
|
|
|
|
|
def pool_single_span( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
span: Tuple[int, int], |
|
|
bbox: Optional[torch.Tensor] = None, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Pool a single span to get its embedding. |
|
|
Returns mean of start and end tokens (hidden_size,). |
|
|
""" |
|
|
start, end = span |
|
|
if start >= hidden_states.shape[0] or end > hidden_states.shape[0]: |
|
|
return torch.zeros(self.hidden_size, device=hidden_states.device) |
|
|
|
|
|
h_start = hidden_states[start, :] |
|
|
h_end = hidden_states[end - 1, :] |
|
|
|
|
|
|
|
|
return (h_start + h_end) / 2 |
|
|
|
|
|
def get_span_bbox( |
|
|
self, |
|
|
bbox: torch.Tensor, |
|
|
span: Tuple[int, int], |
|
|
) -> torch.Tensor: |
|
|
"""Get bounding box for a span (union of start and end token bboxes).""" |
|
|
start, end = span |
|
|
bbox_start = bbox[start, :] |
|
|
bbox_end = bbox[end - 1, :] |
|
|
|
|
|
span_bbox = torch.stack([ |
|
|
torch.min(bbox_start[0], bbox_end[0]), |
|
|
torch.min(bbox_start[1], bbox_end[1]), |
|
|
torch.max(bbox_start[2], bbox_end[2]), |
|
|
torch.max(bbox_start[3], bbox_end[3]), |
|
|
]) |
|
|
return span_bbox |
|
|
|
|
|
def extract_span_text( |
|
|
self, |
|
|
span: Tuple[int, int], |
|
|
tokens: List[str], |
|
|
subword_to_word: Dict[int, int], |
|
|
) -> str: |
|
|
""" |
|
|
Extract text from a span using tokens and subword-to-word mapping. |
|
|
|
|
|
Args: |
|
|
span: (start, end) subword indices (exclusive end) |
|
|
tokens: List of word tokens |
|
|
subword_to_word: Dict mapping subword idx -> word idx |
|
|
|
|
|
Returns: |
|
|
Text string for the span |
|
|
""" |
|
|
start, end = span |
|
|
|
|
|
word_indices = set() |
|
|
for subword_idx in range(start, end): |
|
|
if subword_idx in subword_to_word: |
|
|
word_indices.add(subword_to_word[subword_idx]) |
|
|
|
|
|
if not word_indices: |
|
|
return "" |
|
|
|
|
|
|
|
|
min_word = min(word_indices) |
|
|
max_word = max(word_indices) |
|
|
|
|
|
|
|
|
span_tokens = tokens[min_word:max_word + 1] |
|
|
return " ".join(span_tokens) |
|
|
|
|
|
def extract_qa_regions( |
|
|
self, |
|
|
token_logits: torch.Tensor, |
|
|
attention_mask: torch.Tensor, |
|
|
) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]: |
|
|
""" |
|
|
Extract Q and A regions from Step 1 token predictions. |
|
|
|
|
|
This enables the cascading architecture where Step 2 uses |
|
|
Step 1's predictions instead of ground truth spans. |
|
|
|
|
|
Returns: |
|
|
q_regions: List of Q spans per batch item [(start, end), ...] |
|
|
a_regions: List of A spans per batch item [(start, end), ...] |
|
|
""" |
|
|
batch_size = token_logits.shape[0] |
|
|
preds = token_logits.argmax(dim=-1) |
|
|
|
|
|
q_regions = [] |
|
|
a_regions = [] |
|
|
|
|
|
for b in range(batch_size): |
|
|
sample_q_regions = [] |
|
|
sample_a_regions = [] |
|
|
|
|
|
seq_len = attention_mask[b].sum().item() |
|
|
pred_seq = preds[b, :int(seq_len)].cpu().tolist() |
|
|
|
|
|
|
|
|
current_span = None |
|
|
for i, label in enumerate(pred_seq): |
|
|
if label == 1: |
|
|
if current_span is not None: |
|
|
sample_q_regions.append(current_span) |
|
|
current_span = (i, i + 1) |
|
|
elif label == 2: |
|
|
if current_span is not None: |
|
|
current_span = (current_span[0], i + 1) |
|
|
else: |
|
|
if current_span is not None: |
|
|
sample_q_regions.append(current_span) |
|
|
current_span = None |
|
|
if current_span is not None: |
|
|
sample_q_regions.append(current_span) |
|
|
|
|
|
|
|
|
current_span = None |
|
|
for i, label in enumerate(pred_seq): |
|
|
if label == 3: |
|
|
if current_span is not None: |
|
|
sample_a_regions.append(current_span) |
|
|
current_span = (i, i + 1) |
|
|
elif label == 4: |
|
|
if current_span is not None: |
|
|
current_span = (current_span[0], i + 1) |
|
|
else: |
|
|
if current_span is not None: |
|
|
sample_a_regions.append(current_span) |
|
|
current_span = None |
|
|
if current_span is not None: |
|
|
sample_a_regions.append(current_span) |
|
|
|
|
|
q_regions.append(sample_q_regions) |
|
|
a_regions.append(sample_a_regions) |
|
|
|
|
|
return q_regions, a_regions |
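
    # Illustrative: predicted ids [1, 2, 0, 3, 4, 4] (B-Q I-Q O B-A I-A I-A)
    # decode to q_regions [(0, 2)] and a_regions [(3, 6)] for that sample;
    # orphan I- tags with no preceding B- are ignored.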
|
|
|
|
|
def match_regions_to_gt( |
|
|
self, |
|
|
pred_regions: List[Tuple[int, int]], |
|
|
gt_regions: List[Tuple[int, int]], |
|
|
) -> Tuple[List[int], List[int]]: |
|
|
""" |
|
|
Match predicted regions to GT regions by overlap. |
|
|
|
|
|
Returns: |
|
|
gt_to_pred: For each GT region, index of best matching pred region (-1 if none) |
|
|
pred_to_gt: For each pred region, index of best matching GT region (-1 if none) |
|
|
""" |
|
|
gt_to_pred = [] |
|
|
for gt_start, gt_end in gt_regions: |
|
|
best_pred_idx = -1 |
|
|
best_overlap = 0 |
|
|
for pred_idx, (pred_start, pred_end) in enumerate(pred_regions): |
|
|
overlap_start = max(gt_start, pred_start) |
|
|
overlap_end = min(gt_end, pred_end) |
|
|
overlap = max(0, overlap_end - overlap_start) |
|
|
if overlap > best_overlap: |
|
|
best_overlap = overlap |
|
|
best_pred_idx = pred_idx |
|
|
gt_to_pred.append(best_pred_idx) |
|
|
|
|
|
pred_to_gt = [] |
|
|
for pred_start, pred_end in pred_regions: |
|
|
best_gt_idx = -1 |
|
|
best_overlap = 0 |
|
|
for gt_idx, (gt_start, gt_end) in enumerate(gt_regions): |
|
|
overlap_start = max(gt_start, pred_start) |
|
|
overlap_end = min(gt_end, pred_end) |
|
|
overlap = max(0, overlap_end - overlap_start) |
|
|
if overlap > best_overlap: |
|
|
best_overlap = overlap |
|
|
best_gt_idx = gt_idx |
|
|
pred_to_gt.append(best_gt_idx) |
|
|
|
|
|
return gt_to_pred, pred_to_gt |
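
    # Illustrative: pred=[(0, 4), (10, 14)] and gt=[(1, 4)] give
    #   gt_to_pred=[0] and pred_to_gt=[0, -1] (pred 1 overlaps no GT region).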
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.Tensor, |
|
|
attention_mask: torch.Tensor, |
|
|
bbox: torch.Tensor, |
|
|
tokens: Optional[List[List[str]]] = None, |
|
|
subword_to_word: Optional[List[Dict[int, int]]] = None, |
|
|
query_texts: Optional[List[str]] = None, |
|
|
gt_answer_spans: Optional[List[List[Tuple[int, int]]]] = None, |
|
|
gt_question_spans: Optional[List[List[Optional[Tuple[int, int]]]]] = None, |
|
|
target_field_idx: Optional[List[int]] = None, |
|
|
training: bool = True, |
|
|
) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Forward pass with CASCADING architecture and ZERO-SHOT query matching. |
|
|
|
|
|
Key features: |
|
|
- Step 1: Token classification to predict Q/A regions |
|
|
- Step 1.5: Q-A linking using LiLT embeddings + spatial features |
|
|
- Step 2: ZERO-SHOT query→question matching using Sentence Transformer |
|
|
- GT spans are only used to compute labels (which predicted region is correct) |
|
|
|
|
|
Args: |
|
|
input_ids: (batch, seq) token IDs |
|
|
attention_mask: (batch, seq) attention mask |
|
|
bbox: (batch, seq, 4) bounding boxes |
|
|
tokens: List of word tokens per batch item (for zero-shot Q text encoding) |
|
|
subword_to_word: List of dicts mapping subword idx to word idx |
|
|
query_texts: List of raw query strings (for sentence transformer) |
|
|
gt_answer_spans: GT answer spans per batch item (for loss labels only) |
|
|
gt_question_spans: GT question spans per batch item (for loss labels only) |
|
|
target_field_idx: Index of GT field that query should match (for loss labels) |
|
|
training: whether in training mode (affects Q masking for Q-A linker) |
|
|
""" |
|
|
batch_size = input_ids.shape[0] |
|
|
device = input_ids.device |
|
|
|
|
|
|
|
|
hidden_states = self.encode(input_ids, attention_mask, bbox) |
|
|
|
|
|
|
|
|
token_logits = self.token_classifier(hidden_states) |
|
|
|
|
|
|
|
|
table_outputs = self.table_head(hidden_states) |
|
|
|
|
|
outputs = { |
|
|
"hidden_states": hidden_states, |
|
|
"token_logits": token_logits, |
|
|
**table_outputs, |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if query_texts is not None: |
|
|
|
|
|
pred_q_regions, pred_a_regions = self.extract_qa_regions(token_logits, attention_mask) |
|
|
outputs["pred_q_regions"] = pred_q_regions |
|
|
outputs["pred_a_regions"] = pred_a_regions |
|
|
|
|
|
|
|
|
query_emb = self.encode_texts(query_texts, device) |
|
|
outputs["query_embedding"] = query_emb |
|
|
|
|
|
|
|
|
all_q_embeds_lilt = [] |
|
|
all_q_embeds_st = [] |
|
|
all_a_embeds = [] |
|
|
all_q_bboxes = [] |
|
|
all_a_bboxes = [] |
|
|
all_pred_q_from_a = [] |
|
|
all_real_q_from_pred = [] |
|
|
all_real_q_embeds = [] |
|
|
all_gt_to_pred_a_idx = [] |
|
|
|
|
|
for b in range(batch_size): |
|
|
sample_pred_q = pred_q_regions[b] |
|
|
sample_pred_a = pred_a_regions[b] |
|
|
|
|
|
|
|
|
if len(sample_pred_a) == 0: |
|
|
all_q_embeds_lilt.append(None) |
|
|
all_q_embeds_st.append(None) |
|
|
all_a_embeds.append(None) |
|
|
all_q_bboxes.append(None) |
|
|
all_a_bboxes.append(None) |
|
|
all_pred_q_from_a.append(None) |
|
|
all_real_q_from_pred.append(None) |
|
|
all_real_q_embeds.append(None) |
|
|
all_gt_to_pred_a_idx.append(None) |
|
|
continue |
|
|
|
|
|
|
|
|
sample_a_embeds = [] |
|
|
sample_a_bboxes = [] |
|
|
for a_span in sample_pred_a: |
|
|
a_emb = self.pool_single_span(hidden_states[b], a_span, bbox[b]) |
|
|
sample_a_embeds.append(a_emb) |
|
|
sample_a_bboxes.append(self.get_span_bbox(bbox[b], a_span)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sample_q_embeds_lilt = [] |
|
|
sample_q_embeds_st = [] |
|
|
sample_q_bboxes = [] |
|
|
sample_real_q_from_pred = [] |
|
|
sample_pred_q_from_a = [] |
|
|
|
|
|
if len(sample_pred_q) > 0: |
|
|
|
|
|
|
|
|
q_texts = [] |
|
|
if tokens is not None and subword_to_word is not None: |
|
|
for q_span in sample_pred_q: |
|
|
q_text = self.extract_span_text(q_span, tokens[b], subword_to_word[b]) |
|
|
q_texts.append(q_text if q_text else "unknown") |
|
|
|
|
|
real_q_embeds_st = self.encode_texts(q_texts, device) |
|
|
else: |
|
|
real_q_embeds_st = None |
|
|
|
|
|
for i, q_span in enumerate(sample_pred_q): |
|
|
|
|
|
q_emb_lilt = self.pool_single_span(hidden_states[b], q_span, bbox[b]) |
|
|
sample_q_embeds_lilt.append(q_emb_lilt) |
|
|
|
|
|
q_bbox = self.get_span_bbox(bbox[b], q_span) |
|
|
sample_q_bboxes.append(q_bbox) |
|
|
|
|
|
|
|
|
if real_q_embeds_st is not None: |
|
|
real_q_emb_st = real_q_embeds_st[i] |
|
|
else: |
|
|
real_q_emb_st = None |
|
|
|
|
|
|
|
|
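                        # Randomly hide this Q during training so the QuestionPredictor
                        # learns to reconstruct it; note the answer embedding is picked
                        # positionally (index i), a heuristic Q_i/A_i pairing.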
should_mask = training and (torch.rand(1).item() < self.q_mask_prob) |
|
|
|
|
|
if should_mask and i < len(sample_a_embeds): |
|
|
|
|
|
pred_q_emb = self.question_predictor(sample_a_embeds[i]) |
|
|
sample_q_embeds_st.append(pred_q_emb) |
|
|
sample_real_q_from_pred.append(real_q_emb_st) |
|
|
sample_pred_q_from_a.append(pred_q_emb) |
|
|
else: |
|
|
|
|
|
sample_q_embeds_st.append(real_q_emb_st if real_q_emb_st is not None else torch.zeros(self.query_embed_dim, device=device)) |
|
|
sample_real_q_from_pred.append(real_q_emb_st) |
|
|
sample_pred_q_from_a.append(None) |
|
|
else: |
|
|
|
|
|
for a_emb in sample_a_embeds: |
|
|
pred_q = self.question_predictor(a_emb) |
|
|
sample_q_embeds_st.append(pred_q) |
|
|
sample_pred_q_from_a.append(pred_q) |
|
|
sample_real_q_from_pred.append(None) |
|
|
|
|
|
                    # No predicted Q regions: leave the LiLT Q embeds empty (the
                    # QALinker is skipped for this sample) and let answer bboxes
                    # stand in for the missing question boxes.
                    sample_q_embeds_lilt = []
                    sample_q_bboxes = sample_a_bboxes.copy()
|
|
|
|
|
|
|
|
|
|
|
sample_real_q = [] |
|
|
if gt_question_spans is not None and b < len(gt_question_spans) and tokens is not None and subword_to_word is not None: |
|
|
gt_q_texts = [] |
|
|
gt_q_valid_indices = [] |
|
|
for idx, gt_q_span in enumerate(gt_question_spans[b]): |
|
|
if gt_q_span is not None: |
|
|
q_text = self.extract_span_text(gt_q_span, tokens[b], subword_to_word[b]) |
|
|
gt_q_texts.append(q_text if q_text else "unknown") |
|
|
gt_q_valid_indices.append(idx) |
|
|
|
|
|
|
|
|
if gt_q_texts: |
|
|
gt_q_embeds = self.encode_texts(gt_q_texts, device) |
|
|
embed_idx = 0 |
|
|
for idx, gt_q_span in enumerate(gt_question_spans[b]): |
|
|
if gt_q_span is not None: |
|
|
sample_real_q.append(gt_q_embeds[embed_idx]) |
|
|
embed_idx += 1 |
|
|
else: |
|
|
sample_real_q.append(None) |
|
|
else: |
|
|
sample_real_q = [None] * len(gt_question_spans[b]) |
|
|
|
|
|
|
|
|
gt_to_pred_a = None |
|
|
if gt_answer_spans is not None and b < len(gt_answer_spans): |
|
|
gt_a_spans = gt_answer_spans[b] |
|
|
gt_to_pred_a, _ = self.match_regions_to_gt(sample_pred_a, gt_a_spans) |
|
|
|
|
|
all_q_embeds_lilt.append(torch.stack(sample_q_embeds_lilt) if sample_q_embeds_lilt else None) |
|
|
all_q_embeds_st.append(torch.stack(sample_q_embeds_st) if sample_q_embeds_st else None) |
|
|
all_a_embeds.append(torch.stack(sample_a_embeds) if sample_a_embeds else None) |
|
|
all_q_bboxes.append(torch.stack(sample_q_bboxes) if sample_q_bboxes else None) |
|
|
all_a_bboxes.append(torch.stack(sample_a_bboxes) if sample_a_bboxes else None) |
|
|
|
|
|
all_pred_q_from_a.append(sample_pred_q_from_a if sample_pred_q_from_a else None) |
|
|
all_real_q_from_pred.append(sample_real_q_from_pred if sample_real_q_from_pred else None) |
|
|
all_real_q_embeds.append(sample_real_q if sample_real_q else None) |
|
|
all_gt_to_pred_a_idx.append(gt_to_pred_a) |
|
|
|
|
|
|
|
|
|
|
|
valid_indices_lilt = [i for i, q in enumerate(all_q_embeds_lilt) if q is not None] |
|
|
|
|
|
valid_indices_st = [i for i, q in enumerate(all_q_embeds_st) if q is not None] |
|
|
|
|
|
if valid_indices_lilt: |
|
|
|
|
|
max_q_spans_lilt = max(all_q_embeds_lilt[i].shape[0] for i in valid_indices_lilt) |
|
|
max_a_spans = max(all_a_embeds[i].shape[0] for i in valid_indices_lilt) |
|
|
|
|
|
|
|
|
q_embeds_lilt_padded = torch.zeros(batch_size, max_q_spans_lilt, self.hidden_size, device=device) |
|
|
q_bboxes_padded = torch.zeros(batch_size, max_q_spans_lilt, 4, device=device) |
|
|
q_span_mask = torch.zeros(batch_size, max_q_spans_lilt, dtype=torch.bool, device=device) |
|
|
|
|
|
|
|
|
a_embeds_padded = torch.zeros(batch_size, max_a_spans, self.hidden_size, device=device) |
|
|
a_bboxes_padded = torch.zeros(batch_size, max_a_spans, 4, device=device) |
|
|
a_span_mask = torch.zeros(batch_size, max_a_spans, dtype=torch.bool, device=device) |
|
|
|
|
|
for b in valid_indices_lilt: |
|
|
nq = all_q_embeds_lilt[b].shape[0] |
|
|
q_embeds_lilt_padded[b, :nq] = all_q_embeds_lilt[b] |
|
|
q_bboxes_padded[b, :nq] = all_q_bboxes[b] |
|
|
q_span_mask[b, :nq] = True |
|
|
|
|
|
na = all_a_embeds[b].shape[0] |
|
|
a_embeds_padded[b, :na] = all_a_embeds[b] |
|
|
a_bboxes_padded[b, :na] = all_a_bboxes[b] |
|
|
a_span_mask[b, :na] = True |
|
|
|
|
|
|
|
|
qa_link_scores_list = [] |
|
|
for b in valid_indices_lilt: |
|
|
nq = all_q_embeds_lilt[b].shape[0] |
|
|
na = all_a_embeds[b].shape[0] |
|
|
if nq > 0 and na > 0: |
|
|
link_scores = self.qa_linker( |
|
|
all_q_embeds_lilt[b], |
|
|
all_a_embeds[b], |
|
|
all_q_bboxes[b], |
|
|
all_a_bboxes[b], |
|
|
) |
|
|
qa_link_scores_list.append(link_scores) |
|
|
else: |
|
|
qa_link_scores_list.append(None) |
|
|
|
|
|
outputs["qa_link_scores"] = qa_link_scores_list |
|
|
|
|
|
|
|
|
if valid_indices_st: |
|
|
max_q_spans_st = max(all_q_embeds_st[i].shape[0] for i in valid_indices_st) |
|
|
q_embeds_st_padded = torch.zeros(batch_size, max_q_spans_st, self.query_embed_dim, device=device) |
|
|
q_span_mask_st = torch.zeros(batch_size, max_q_spans_st, dtype=torch.bool, device=device) |
|
|
|
|
|
for b in valid_indices_st: |
|
|
nq = all_q_embeds_st[b].shape[0] |
|
|
q_embeds_st_padded[b, :nq] = all_q_embeds_st[b] |
|
|
q_span_mask_st[b, :nq] = True |
|
|
|
|
|
outputs["q_embeds_st"] = q_embeds_st_padded |
|
|
outputs["q_span_mask_st"] = q_span_mask_st |
|
|
|
|
|
|
|
|
if training and gt_question_spans is not None and gt_answer_spans is not None: |
|
|
gt_qa_link_scores_list = [] |
|
|
gt_valid_q_indices_list = [] |
|
|
for b in range(batch_size): |
|
|
gt_q_spans = gt_question_spans[b] if b < len(gt_question_spans) else [] |
|
|
gt_a_spans = gt_answer_spans[b] if b < len(gt_answer_spans) else [] |
|
|
|
|
|
|
|
|
valid_q_indices = [i for i, q in enumerate(gt_q_spans) if q is not None] |
|
|
valid_q_spans = [gt_q_spans[i] for i in valid_q_indices] |
|
|
|
|
|
if len(valid_q_spans) > 0 and len(gt_a_spans) > 0: |
|
|
|
|
|
gt_q_embeds = torch.stack([ |
|
|
self.pool_single_span(hidden_states[b], q_span, bbox[b]) |
|
|
for q_span in valid_q_spans |
|
|
]) |
|
|
gt_a_embeds = torch.stack([ |
|
|
self.pool_single_span(hidden_states[b], a_span, bbox[b]) |
|
|
for a_span in gt_a_spans |
|
|
]) |
|
|
gt_q_bboxes = torch.stack([ |
|
|
self.get_span_bbox(bbox[b], q_span) |
|
|
for q_span in valid_q_spans |
|
|
]) |
|
|
gt_a_bboxes = torch.stack([ |
|
|
self.get_span_bbox(bbox[b], a_span) |
|
|
for a_span in gt_a_spans |
|
|
]) |
|
|
|
|
|
|
|
|
gt_link_scores = self.qa_linker( |
|
|
gt_q_embeds, gt_a_embeds, gt_q_bboxes, gt_a_bboxes |
|
|
) |
|
|
gt_qa_link_scores_list.append(gt_link_scores) |
|
|
gt_valid_q_indices_list.append(valid_q_indices) |
|
|
else: |
|
|
gt_qa_link_scores_list.append(None) |
|
|
gt_valid_q_indices_list.append(None) |
|
|
|
|
|
outputs["gt_qa_link_scores"] = gt_qa_link_scores_list |
|
|
outputs["gt_valid_q_indices"] = gt_valid_q_indices_list |
|
|
|
|
|
|
|
|
|
|
|
if valid_indices_st and "q_embeds_st" in outputs: |
|
|
q_embeds_st_padded = outputs["q_embeds_st"] |
|
|
q_span_mask_st = outputs["q_span_mask_st"] |
|
|
|
|
|
|
|
|
query_norm = F.normalize(query_emb, dim=-1) |
|
|
q_norm = F.normalize(q_embeds_st_padded, dim=-1) |
|
|
match_scores = torch.bmm(q_norm, query_norm.unsqueeze(-1)).squeeze(-1) |
|
|
match_scores = match_scores.masked_fill(~q_span_mask_st, float('-inf')) |
|
|
outputs["match_scores"] = match_scores |
|
|
outputs["match_span_mask"] = q_span_mask_st |
|
|
|
|
|
|
|
|
if valid_indices_lilt: |
|
|
outputs["q_span_mask"] = q_span_mask |
|
|
outputs["a_span_mask"] = a_span_mask |
|
|
outputs["q_embeds"] = q_embeds_lilt_padded |
|
|
outputs["a_embeds"] = a_embeds_padded |
|
|
outputs["q_bboxes"] = q_bboxes_padded |
|
|
outputs["a_bboxes"] = a_bboxes_padded |
|
|
|
|
|
|
|
|
outputs["pred_q_from_a"] = all_pred_q_from_a |
|
|
outputs["real_q_from_pred"] = all_real_q_from_pred |
|
|
outputs["real_q_embeds"] = all_real_q_embeds |
|
|
outputs["gt_to_pred_a_idx"] = all_gt_to_pred_a_idx |
|
|
outputs["valid_batch_indices_lilt"] = valid_indices_lilt |
|
|
outputs["valid_batch_indices_st"] = valid_indices_st |
|
|
|
|
|
return outputs |
|
|
|
|
|
|
|
|
def compute_loss( |
|
|
outputs: Dict[str, torch.Tensor], |
|
|
token_labels: torch.Tensor, |
|
|
attention_mask: torch.Tensor, |
|
|
table_labels: Optional[torch.Tensor] = None, |
|
|
row_labels: Optional[torch.Tensor] = None, |
|
|
header_labels: Optional[torch.Tensor] = None, |
|
|
match_labels: Optional[torch.Tensor] = None, |
|
|
col_labels: Optional[torch.Tensor] = None, |
|
|
class_weights: Optional[torch.Tensor] = None, |
|
|
qa_link_labels: Optional[List[torch.Tensor]] = None, |
|
|
) -> Dict[str, torch.Tensor]: |
|
|
"""Compute joint loss for all steps including Q-A linking.""" |
|
|
device = token_labels.device |
|
|
losses = {} |
|
|
|
|
|
|
|
|
token_logits = outputs["token_logits"] |
|
|
token_logits_flat = token_logits.view(-1, token_logits.shape[-1]) |
|
|
token_labels_flat = token_labels.view(-1) |
|
|
attn_flat = attention_mask.view(-1).bool() |
|
|
|
|
|
if class_weights is not None: |
|
|
step1_loss = F.cross_entropy( |
|
|
token_logits_flat[attn_flat], |
|
|
token_labels_flat[attn_flat], |
|
|
weight=class_weights, |
|
|
) |
|
|
else: |
|
|
step1_loss = F.cross_entropy( |
|
|
token_logits_flat[attn_flat], |
|
|
token_labels_flat[attn_flat], |
|
|
) |
|
|
losses["step1_loss"] = step1_loss |
|
|
|
|
|
|
|
|
if "gt_qa_link_scores" in outputs and qa_link_labels is not None: |
|
|
qa_link_losses = [] |
|
|
gt_scores_list = outputs["gt_qa_link_scores"] |
|
|
gt_valid_q_indices = outputs.get("gt_valid_q_indices", [None] * len(gt_scores_list)) |
|
|
for scores, labels, valid_indices in zip(gt_scores_list, qa_link_labels, gt_valid_q_indices): |
|
|
if scores is None or labels is None or valid_indices is None: |
|
|
continue |
|
|
|
|
|
filtered_labels = labels[valid_indices] if len(valid_indices) > 0 else labels |
|
|
if scores.numel() > 0 and filtered_labels.numel() > 0: |
|
|
|
|
|
num_a = scores.shape[1] |
|
|
valid_mask = (filtered_labels >= 0) & (filtered_labels < num_a) |
|
|
if valid_mask.any(): |
|
|
link_loss = F.cross_entropy(scores[valid_mask], filtered_labels[valid_mask]) |
|
|
qa_link_losses.append(link_loss) |
|
|
if qa_link_losses: |
|
|
losses["qa_link_loss"] = torch.stack(qa_link_losses).mean() |
|
|
else: |
|
|
losses["qa_link_loss"] = torch.tensor(0.0, device=device) |
|
|
else: |
|
|
losses["qa_link_loss"] = torch.tensor(0.0, device=device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "pred_q_from_a" in outputs and "real_q_from_pred" in outputs: |
|
|
pred_q_from_a = outputs["pred_q_from_a"] |
|
|
real_q_from_pred = outputs["real_q_from_pred"] |
|
|
|
|
|
q_predict_losses = [] |
|
|
for batch_pred, batch_real in zip(pred_q_from_a, real_q_from_pred): |
|
|
if batch_pred is None or batch_real is None: |
|
|
continue |
|
|
for pred_q, real_q in zip(batch_pred, batch_real): |
|
|
if pred_q is not None and real_q is not None: |
|
|
|
|
|
mse = F.mse_loss(pred_q, real_q.detach()) |
|
|
q_predict_losses.append(mse) |
|
|
|
|
|
if q_predict_losses: |
|
|
losses["q_predict_loss"] = torch.stack(q_predict_losses).mean() |
|
|
else: |
|
|
losses["q_predict_loss"] = torch.tensor(0.0, device=device) |
|
|
else: |
|
|
losses["q_predict_loss"] = torch.tensor(0.0, device=device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
losses["step2_loss"] = torch.tensor(0.0, device=device) |
|
|
|
|
|
|
|
|
if table_labels is not None: |
|
|
table_logits = outputs["table_logits"] |
|
|
table_logits_flat = table_logits.view(-1, 2) |
|
|
table_labels_flat = table_labels.view(-1) |
|
|
table_det_loss = F.cross_entropy( |
|
|
table_logits_flat[attn_flat], |
|
|
table_labels_flat[attn_flat], |
|
|
) |
|
|
losses["table_det_loss"] = table_det_loss |
|
|
else: |
|
|
losses["table_det_loss"] = torch.tensor(0.0, device=device) |
|
|
|
|
|
if header_labels is not None and table_labels is not None: |
|
|
header_logits = outputs["header_logits"] |
|
|
table_mask = (table_labels == 1) & attention_mask.bool() |
|
|
if table_mask.any(): |
|
|
header_logits_flat = header_logits.view(-1, 2) |
|
|
header_labels_flat = header_labels.view(-1) |
|
|
table_mask_flat = table_mask.view(-1) |
|
|
header_loss = F.cross_entropy( |
|
|
header_logits_flat[table_mask_flat], |
|
|
header_labels_flat[table_mask_flat], |
|
|
) |
|
|
losses["header_loss"] = header_loss |
|
|
else: |
|
|
losses["header_loss"] = torch.tensor(0.0, device=device) |
|
|
else: |
|
|
losses["header_loss"] = torch.tensor(0.0, device=device) |
|
|
|
|
|
if row_labels is not None and table_labels is not None: |
|
|
row_logits = outputs["row_logits"] |
|
|
table_mask = (table_labels == 1) & attention_mask.bool() |
|
|
if table_mask.any(): |
|
|
row_logits_flat = row_logits.view(-1, 3) |
|
|
row_labels_flat = row_labels.view(-1) |
|
|
table_mask_flat = table_mask.view(-1) |
|
|
row_loss = F.cross_entropy( |
|
|
row_logits_flat[table_mask_flat], |
|
|
row_labels_flat[table_mask_flat], |
|
|
) |
|
|
losses["row_loss"] = row_loss |
|
|
else: |
|
|
losses["row_loss"] = torch.tensor(0.0, device=device) |
|
|
else: |
|
|
losses["row_loss"] = torch.tensor(0.0, device=device) |
|
|
|
|
|
if col_labels is not None and table_labels is not None and header_labels is not None: |
|
|
col_scores = outputs["col_scores"] |
|
|
cell_mask = (table_labels == 1) & (header_labels == 0) & attention_mask.bool() |
|
|
if cell_mask.any(): |
|
|
col_scores_flat = col_scores.view(-1, col_scores.shape[-1]) |
|
|
col_labels_flat = col_labels.view(-1) |
|
|
cell_mask_flat = cell_mask.view(-1) |
|
|
col_loss = F.cross_entropy( |
|
|
col_scores_flat[cell_mask_flat], |
|
|
col_labels_flat[cell_mask_flat], |
|
|
) |
|
|
losses["col_loss"] = col_loss |
|
|
else: |
|
|
losses["col_loss"] = torch.tensor(0.0, device=device) |
|
|
else: |
|
|
losses["col_loss"] = torch.tensor(0.0, device=device) |
|
|
|
|
|
|
|
|
total_loss = ( |
|
|
losses["step1_loss"] + |
|
|
losses["qa_link_loss"] * 0.3 + |
|
|
losses["q_predict_loss"] * 0.2 + |
|
|
losses["step2_loss"] * 0.5 + |
|
|
losses["table_det_loss"] * 0.5 + |
|
|
losses["header_loss"] * 0.3 + |
|
|
losses["row_loss"] * 0.3 + |
|
|
losses["col_loss"] * 0.3 |
|
|
) |
|
|
losses["loss"] = total_loss |
|
|
|
|
|
return losses |
|
|
|
|
|
|
|
|
def compute_contrastive_loss( |
|
|
query_embedding: torch.Tensor, |
|
|
span_embeddings: torch.Tensor, |
|
|
span_mask: torch.Tensor, |
|
|
match_labels: torch.Tensor, |
|
|
temperature: float = 0.07, |
|
|
hard_negative_weight: float = 2.0, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Contrastive loss for query-answer matching. |
|
|
|
|
|
InfoNCE-style loss that: |
|
|
- Pulls query towards correct answer span |
|
|
- Pushes query away from all other spans (in-batch negatives) |
|
|
- Applies extra weight to hard negatives (high-scoring wrong answers) |
|
|
|
|
|
Args: |
|
|
query_embedding: Query [CLS] embeddings |
|
|
span_embeddings: Pooled span embeddings [h_start; h_end; bbox] |
|
|
span_mask: Valid span mask |
|
|
match_labels: Index of correct span for each query |
|
|
temperature: Softmax temperature (lower = harder) |
|
|
hard_negative_weight: Extra weight for hard negatives |
|
|
|
|
|
Returns: |
|
|
Contrastive loss scalar |
|
|
""" |
|
|
batch_size = query_embedding.shape[0] |
|
|
device = query_embedding.device |
|
|
|
|
|
if batch_size == 0: |
|
|
return torch.tensor(0.0, device=device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
losses = [] |
|
|
for b in range(batch_size): |
|
|
valid_spans = span_mask[b].sum().item() |
|
|
if valid_spans <= 1: |
|
|
continue |
|
|
|
|
|
correct_idx = match_labels[b].item() |
|
|
if correct_idx >= valid_spans: |
|
|
continue |
|
|
|
|
|
|
|
|
spans = span_embeddings[b, :int(valid_spans), :] |
|
|
query = query_embedding[b] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
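        # Spans and query can differ in width (spans may carry boundary + bbox
        # features), so compare only the first query.shape[0] span dims.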
spans_norm = F.normalize(spans[:, :query.shape[0]], dim=-1) |
|
|
query_norm = F.normalize(query, dim=-1) |
|
|
|
|
|
|
|
|
sims = torch.matmul(spans_norm, query_norm) / temperature |
|
|
|
|
|
|
|
|
        # Per-span weights: 1.0 by default, hard_negative_weight on hard negatives.
        weights = torch.ones(int(valid_spans), device=device)
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
hard_neg_mask = (sims > sims[correct_idx] - 0.5) & (torch.arange(int(valid_spans), device=device) != correct_idx) |
|
|
weights[hard_neg_mask] = hard_negative_weight |
|
|
|
|
|
|
|
|
log_probs = F.log_softmax(sims, dim=0) |
|
|
loss = -log_probs[correct_idx] |
|
|
|
|
|
|
|
|
        # Margin penalty on hard negatives, scaled by their weight so that
        # hard_negative_weight influences the loss as documented.
        for i in range(int(valid_spans)):
            if i != correct_idx and hard_neg_mask[i]:
                margin_loss = F.relu(sims[i] - sims[correct_idx] + 0.3)
                loss = loss + 0.1 * weights[i] * margin_loss
|
|
|
|
|
losses.append(loss) |
|
|
|
|
|
if not losses: |
|
|
return torch.tensor(0.0, device=device) |
|
|
|
|
|
return torch.stack(losses).mean() |
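

# Sketch (illustrative, hypothetical shapes): one query against 3 candidate spans.
def _demo_contrastive_loss() -> torch.Tensor:
    query = torch.randn(1, 256)                  # (batch, dim)
    spans = torch.randn(1, 3, 256)               # (batch, num_spans, dim)
    mask = torch.ones(1, 3, dtype=torch.bool)    # all spans valid
    labels = torch.tensor([0])                   # span 0 is the correct match
    return compute_contrastive_loss(query, spans, mask, labels)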
|
|
|
|
|
|
|
|
def compute_loss_with_contrastive( |
|
|
outputs: Dict[str, torch.Tensor], |
|
|
token_labels: torch.Tensor, |
|
|
attention_mask: torch.Tensor, |
|
|
table_labels: Optional[torch.Tensor] = None, |
|
|
row_labels: Optional[torch.Tensor] = None, |
|
|
header_labels: Optional[torch.Tensor] = None, |
|
|
match_labels: Optional[torch.Tensor] = None, |
|
|
col_labels: Optional[torch.Tensor] = None, |
|
|
class_weights: Optional[torch.Tensor] = None, |
|
|
qa_link_labels: Optional[List[torch.Tensor]] = None, |
|
|
contrastive_weight: float = 0.5, |
|
|
) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Compute joint loss with contrastive learning for Step 2. |
|
|
|
|
|
Adds InfoNCE-style contrastive loss to help learn better span representations. |
|
|
""" |
|
|
|
|
|
losses = compute_loss( |
|
|
outputs=outputs, |
|
|
token_labels=token_labels, |
|
|
attention_mask=attention_mask, |
|
|
table_labels=table_labels, |
|
|
row_labels=row_labels, |
|
|
header_labels=header_labels, |
|
|
match_labels=match_labels, |
|
|
col_labels=col_labels, |
|
|
class_weights=class_weights, |
|
|
qa_link_labels=qa_link_labels, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
losses["contrastive_loss"] = torch.tensor(0.0, device=token_labels.device) |
|
|
|
|
|
return losses |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
print("Testing DEXTR LiLT model with CASCADING architecture...") |
|
|
print("=" * 60) |
|
|
|
|
|
model = DEXTRLiLT( |
|
|
model_name="nielsr/lilt-xlm-roberta-base", |
|
|
max_seq_len=1024, |
|
|
) |
|
|
|
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
print(f"Total parameters: {total_params:,}") |
|
|
print(f"Trainable parameters: {trainable_params:,}") |
|
|
|
|
|
|
|
|
batch_size = 2 |
|
|
seq_len = 128 |
|
|
|
|
|
|
|
input_ids = torch.randint(0, 1000, (batch_size, seq_len)) |
|
|
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long) |
|
|
|
|
|
x1 = torch.randint(0, 500, (batch_size, seq_len)) |
|
|
y1 = torch.randint(0, 500, (batch_size, seq_len)) |
|
|
x2 = x1 + torch.randint(10, 500, (batch_size, seq_len)) |
|
|
y2 = y1 + torch.randint(10, 500, (batch_size, seq_len)) |
|
|
|
|
|
bbox = torch.stack([x1, y1, x2.clamp(max=1000), y2.clamp(max=1000)], dim=-1) |
|
|
    # Raw query strings; the model encodes them with its frozen sentence transformer.
    query_texts = ["invoice number", "total amount"]
|
|
|
|
|
|
|
|
gt_answer_spans = [[(5, 10), (20, 25)], [(8, 12)]] |
|
|
gt_question_spans = [[(2, 5), (17, 20)], [(5, 8)]] |
|
|
|
|
|
print("\nRunning forward pass (CASCADING architecture)...") |
|
|
print(" - Step 1: Token classification → predict Q/A/H/TABLE/O") |
|
|
print(" - Step 1.5: Extract predicted regions, QALinker pairs them") |
|
|
print(" - Step 2: Query matches to predicted Q embeddings") |
|
|
model.eval() |
|
|
with torch.no_grad(): |
|
|
outputs = model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
bbox=bbox, |
|
|
            query_texts=query_texts,
            training=False,  # eval pass: disable random Q-masking in the cascade
|
|
gt_answer_spans=gt_answer_spans, |
|
|
gt_question_spans=gt_question_spans, |
|
|
) |
|
|
|
|
|
print(f"\nOutput shapes:") |
|
|
print(f" token_logits: {outputs['token_logits'].shape}") |
|
|
print(f" table_logits: {outputs['table_logits'].shape}") |
|
|
print(f" header_logits: {outputs['header_logits'].shape}") |
|
|
print(f" row_logits: {outputs['row_logits'].shape}") |
|
|
print(f" col_scores: {outputs['col_scores'].shape}") |
|
|
print(f" query_embedding: {outputs['query_embedding'].shape}") |
|
|
|
|
|
|
|
|
print(f"\nPredicted regions from Step 1:") |
|
|
print(f" pred_q_regions: {outputs['pred_q_regions']}") |
|
|
print(f" pred_a_regions: {outputs['pred_a_regions']}") |
|
|
|
|
|
|
|
|
if "match_scores" in outputs: |
|
|
print(f" match_scores: {outputs['match_scores'].shape}") |
|
|
else: |
|
|
print(" match_scores: None (no regions predicted)") |
|
|
|
|
|
|
|
|
print("\nTesting loss computation...") |
|
|
token_labels = torch.randint(0, NUM_LABELS, (batch_size, seq_len)) |
|
|
table_labels = torch.randint(0, 2, (batch_size, seq_len)) |
|
|
row_labels = torch.randint(0, 3, (batch_size, seq_len)) |
|
|
header_labels = torch.randint(0, 2, (batch_size, seq_len)) |
|
|
|
|
|
|
|
|
|
|
|
match_labels = None |
|
|
if "match_scores" in outputs and outputs["match_scores"].shape[1] > 0: |
|
|
match_labels = torch.zeros(batch_size, dtype=torch.long) |
|
|
|
|
|
losses = compute_loss( |
|
|
outputs, |
|
|
token_labels=token_labels, |
|
|
attention_mask=attention_mask, |
|
|
table_labels=table_labels, |
|
|
row_labels=row_labels, |
|
|
header_labels=header_labels, |
|
|
match_labels=match_labels, |
|
|
) |
|
|
|
|
|
print(f"\nLosses:") |
|
|
for name, value in losses.items(): |
|
|
print(f" {name}: {value.item():.4f}") |
|
|
|
|
|
print("\n" + "=" * 60) |
|
|
print("DEXTR LiLT CASCADING architecture test passed!") |
|
|
|