State-of-the-Art Document Understanding Methods
You're Right - This IS Document Understanding!
Contract analysis = Document Understanding + Legal Domain Knowledge
Current SOTA Methods (2024-2025)
1. Long-Context Transformers - BEST FOR YOUR USE CASE
Longformer (Allen AI, 2020)
- Max Length: 4,096+ tokens (vs BERT's 512)
- Innovation: Sliding window + global attention
- Use Case: Full contract documents without chunking
import torch
from transformers import LongformerModel, LongformerTokenizer

model = LongformerModel.from_pretrained('allenai/longformer-base-4096')
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

# Can process an entire contract at once
inputs = tokenizer(full_contract, return_tensors='pt', truncation=True, max_length=4096)
# Give the [CLS] token global attention (standard Longformer practice for document-level tasks)
global_attention_mask = torch.zeros_like(inputs['input_ids'])
global_attention_mask[:, 0] = 1
outputs = model(**inputs, global_attention_mask=global_attention_mask)
Legal Variant: Longformer-Legal
# Pre-trained on legal documents!
model = LongformerModel.from_pretrained('lexlms/legal-longformer-base')
LED (Longformer Encoder-Decoder) (Allen AI, 2020)
- Max Length: 16,384 tokens
- Use Case: Document summarization + Q&A
from transformers import LEDForConditionalGeneration, LEDTokenizer
model = LEDForConditionalGeneration.from_pretrained('allenai/led-large-16384')
tokenizer = LEDTokenizer.from_pretrained('allenai/led-large-16384')
# Summarize the entire contract
inputs = tokenizer(full_contract, return_tensors='pt', truncation=True, max_length=16384)
summary_ids = model.generate(inputs['input_ids'], max_length=512)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
BigBird (Google, 2020)
- Max Length: 4,096 tokens
- Innovation: Sparse attention (random + global + sliding)
- Efficiency: O(n) instead of O(n²)
from transformers import BigBirdModel
model = BigBirdModel.from_pretrained('google/bigbird-roberta-base')
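If you want to tune the sparse-attention pattern, the standard BigBirdConfig knobs can be passed straight to from_pretrained; a minimal sketch (the parameter values below are illustrative defaults, not recommendations):

from transformers import BigBirdModel, BigBirdTokenizer

model = BigBirdModel.from_pretrained(
    'google/bigbird-roberta-base',
    attention_type='block_sparse',   # sparse pattern; 'original_full' falls back to O(n^2)
    block_size=64,                   # tokens per attention block
    num_random_blocks=3              # random blocks each query block attends to
)
tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
inputs = tokenizer(full_contract, return_tensors='pt', truncation=True, max_length=4096)
outputs = model(**inputs)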
2. Hierarchical Document Models - RECOMMENDED
Hierarchical Attention Networks (HAN) (Yang et al., 2016)
- Structure: Word → Sentence → Document
- Perfect for: Legal contracts with clause hierarchy
Document
├── Section 1: SERVICES
│   ├── Sentence 1 (attention weights)
│   ├── Sentence 2 (attention weights)
│   └── Sentence 3 (attention weights)
├── Section 2: PAYMENT
│   └── ...
└── Section 3: TERMINATION
Implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HierarchicalContractModel(nn.Module):
    def __init__(self, embedding_dim=300, hidden_dim=128, num_classes=3):
        super().__init__()
        # Word-level encoder
        self.word_encoder = nn.GRU(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.word_attention = nn.Linear(hidden_dim * 2, 1)
        # Sentence-level encoder
        self.sentence_encoder = nn.GRU(hidden_dim * 2, hidden_dim, bidirectional=True, batch_first=True)
        self.sentence_attention = nn.Linear(hidden_dim * 2, 1)
        # Document-level classifier
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, document):
        # document: [num_sentences, num_words, embedding_dim] (one document)
        # 1. Encode the words of each sentence, then pool with attention
        word_hidden, _ = self.word_encoder(document)                  # [num_sent, num_words, 2*hidden]
        word_attn = F.softmax(self.word_attention(word_hidden), dim=1)
        sentence_vectors = (word_hidden * word_attn).sum(dim=1)       # [num_sent, 2*hidden]
        # 2. Encode the sentence sequence, then pool with attention
        doc_hidden, _ = self.sentence_encoder(sentence_vectors.unsqueeze(0))
        sent_attn = F.softmax(self.sentence_attention(doc_hidden), dim=1)
        doc_vec = (doc_hidden * sent_attn).sum(dim=1)                 # [1, 2*hidden]
        # 3. Classify the document
        return self.classifier(doc_vec)
BERT-HAN (Modern variant)
Combine BERT with hierarchical structure:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

class BERTHierarchical(nn.Module):
    def __init__(self):
        super().__init__()
        # Use a legal-domain BERT for clause encoding
        self.bert = AutoModel.from_pretrained('nlpaueb/legal-bert-base-uncased')
        self.tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
        # Hierarchical aggregation: clauses -> section, sections -> document
        self.clause_encoder = nn.LSTM(768, 256, bidirectional=True, batch_first=True)
        self.section_encoder = nn.LSTM(512, 256, bidirectional=True, batch_first=True)
        # Attention mechanisms for pooling
        self.clause_attention = nn.Linear(512, 1)
        self.section_attention = nn.Linear(512, 1)

    def forward(self, sections):
        # sections: List[List[clause_text]]
        section_vectors = []
        for section in sections:
            # Encode each clause with BERT ([CLS] embedding)
            clause_embeddings = []
            for clause in section:
                inputs = self.tokenizer(clause, return_tensors='pt', truncation=True, max_length=128)
                bert_output = self.bert(**inputs)
                clause_embeddings.append(bert_output.last_hidden_state[:, 0, :])   # [1, 768]
            clause_stack = torch.cat(clause_embeddings, dim=0).unsqueeze(0)         # [1, n_clauses, 768]
            # Aggregate clauses -> section
            clause_hidden, _ = self.clause_encoder(clause_stack)                    # [1, n_clauses, 512]
            clause_attn = F.softmax(self.clause_attention(clause_hidden), dim=1)
            section_vectors.append((clause_hidden * clause_attn).sum(dim=1))        # [1, 512]
        # Aggregate sections -> document
        section_stack = torch.cat(section_vectors, dim=0).unsqueeze(0)              # [1, n_sections, 512]
        section_hidden, _ = self.section_encoder(section_stack)                     # [1, n_sections, 512]
        section_attn = F.softmax(self.section_attention(section_hidden), dim=1)
        document_vec = (section_hidden * section_attn).sum(dim=1)                   # [1, 512]
        return document_vec
3. Document Graph Neural Networks
Graph Transformer (Microsoft, 2022)
Model the document as a graph: clauses = nodes, references = edges
[Clause 1: "Services in Exhibit A"] ──references──> [Exhibit A]
        │
     mentions
        │
        ▼
[Clause 2: "Such Services..."] ──references──> [Section 5]
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch_geometric.nn import GCNConv

class DocumentGraphNN(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        # Node encoder (BERT for each clause)
        self.node_encoder = AutoModel.from_pretrained('nlpaueb/legal-bert-base-uncased')
        self.tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
        # Graph convolution layers
        self.conv1 = GCNConv(768, 256)
        self.conv2 = GCNConv(256, 128)
        # Classifier (per clause / node)
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, clauses, edge_index):
        # 1. Encode nodes: one [CLS] vector per clause
        node_features = []
        for clause in clauses:
            inputs = self.tokenizer(clause, return_tensors='pt', truncation=True, max_length=128)
            bert_out = self.node_encoder(**inputs)
            node_features.append(bert_out.last_hidden_state[:, 0, :])
        x = torch.cat(node_features, dim=0)           # [num_clauses, 768]
        # 2. Graph convolution (propagate context along edges)
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        # 3. Classify each clause with graph-aware features
        return self.classifier(x)
Edge Types (a construction sketch follows the list):
- Sequential (Clause N → Clause N+1)
- Reference ("Section 5" → actual Section 5)
- Semantic similarity (cosine > threshold)
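A minimal sketch of building an edge_index for the graph model above from these three edge types, assuming the clauses are already split and embedded; build_edge_index is a hypothetical helper, and the "Section N" regex and similarity threshold are illustrative heuristics:

import re
import torch
import torch.nn.functional as F

def build_edge_index(clauses, clause_embeddings, sim_threshold=0.8):
    """Return a [2, num_edges] edge_index combining the three edge types."""
    edges = []
    # 1. Sequential edges: Clause N <-> Clause N+1
    for i in range(len(clauses) - 1):
        edges += [(i, i + 1), (i + 1, i)]
    # 2. Reference edges: "Section 5" -> the clause that opens Section 5 (toy heuristic)
    section_starts = {m.group(1): i for i, c in enumerate(clauses)
                      for m in [re.match(r'Section (\d+)', c)] if m}
    for i, clause in enumerate(clauses):
        for ref in re.findall(r'Section (\d+)', clause):
            if ref in section_starts and section_starts[ref] != i:
                edges.append((i, section_starts[ref]))
    # 3. Semantic-similarity edges: cosine similarity above a threshold
    sims = F.cosine_similarity(clause_embeddings.unsqueeze(1), clause_embeddings.unsqueeze(0), dim=-1)
    for i in range(len(clauses)):
        for j in range(len(clauses)):
            if i != j and sims[i, j] > sim_threshold:
                edges.append((i, j))
    return torch.tensor(edges, dtype=torch.long).t().contiguous()

The resulting [2, num_edges] tensor is what the DocumentGraphNN forward pass above expects as edge_index.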
4. Retrieval-Augmented Models
RAG (Retrieval-Augmented Generation) (Facebook, 2020)
Retrieve relevant clauses before classification
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

# RAG ships with a Wikipedia index; for contracts you would build a custom index
# over your own clause corpus (the dummy index here is just for illustration)
tokenizer = RagTokenizer.from_pretrained('facebook/rag-token-base')
retriever = RagRetriever.from_pretrained(
    'facebook/rag-token-base', index_name='exact', use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained('facebook/rag-token-base', retriever=retriever)

def predict_with_retrieval(clause):
    # Retrieval of the top supporting passages happens inside generate()
    inputs = tokenizer(clause, return_tensors='pt')
    output_ids = model.generate(input_ids=inputs['input_ids'], num_beams=4, max_length=64)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
5. Large Language Models (2023-2025)
GPT-4 / Claude (OpenAI / Anthropic)
- Context: 128k tokens (GPT-4 Turbo), 200k (Claude 3)
- Approach: Few-shot learning + prompting
import json
from openai import OpenAI

client = OpenAI()

def analyze_contract_llm(contract):
    prompt = f"""
    Analyze this contract for risk clauses. For each clause, provide:
    1. Risk severity (0-10)
    2. Risk category
    3. Explanation
    Contract:
    {contract}
    Format as JSON.
    """
    response = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1  # Low for consistency
    )
    return json.loads(response.choices[0].message.content)
Pros: No training needed, strong out-of-the-box understanding
Cons: Expensive, slower, privacy concerns
Legal-Specific LLMs
- LegalBench (Stanford, 2023)
- ChatLaw (Peking University, 2023)
- LawGPT (2024)
6. Multi-Modal Document Understanding
LayoutLM (Microsoft, 2020-2023)
Understands document layout + text
from transformers import LayoutLMv3Model, LayoutLMv3Processor

# Processes:
# 1. Text content
# 2. Bounding boxes (where each token appears on the page)
# 3. Images (if the PDF has visual elements)
model = LayoutLMv3Model.from_pretrained('microsoft/layoutlmv3-base')

# In practice, LayoutLMv3Processor builds these inputs from a page image
# (it can run OCR and normalize the bounding boxes for you):
# processor = LayoutLMv3Processor.from_pretrained('microsoft/layoutlmv3-base')
# inputs = processor(document_image, return_tensors='pt')

# Input includes position information
inputs = {
    'input_ids': text_tokens,
    'bbox': bounding_boxes,       # [x0, y0, x1, y1] for each token
    'pixel_values': document_image
}
outputs = model(**inputs)
Why this matters: Contracts have structure (headers, indentation, tables)
Comparison for Your Use Case
| Method | Context Length | Structure Aware | Training Cost | Inference Speed | Best For |
|---|---|---|---|---|---|
| BERT (current) | 512 tokens | ❌ | Low | Fast | Clause-level |
| Longformer | 4,096 tokens | ❌ | Medium | Medium | Full documents |
| Hierarchical BERT | Unlimited* | ✅✅✅ | Medium | Medium | RECOMMENDED |
| Graph NN | Unlimited* | ✅✅ | High | Slow | Complex references |
| LLM (GPT-4) | 128k tokens | ✅ | Zero! | Slow | No training data |
| LayoutLM | 512 tokens | ✅ (visual) | High | Medium | Scanned PDFs |
*Processes document in chunks with aggregation
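For reference, a minimal sketch of the chunk-and-aggregate idea behind "Unlimited*", assuming a plain Legal-BERT encoder and simple mean pooling (the hierarchical model below replaces mean pooling with learned attention); embed_long_document is a hypothetical helper:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
encoder = AutoModel.from_pretrained('nlpaueb/legal-bert-base-uncased')

def embed_long_document(text, chunk_size=510, stride=128):
    """Encode an arbitrarily long contract in overlapping chunks, then aggregate."""
    ids = tokenizer(text, add_special_tokens=False)['input_ids']
    cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id
    chunk_vecs = []
    for start in range(0, max(len(ids), 1), chunk_size - stride):
        chunk = [cls_id] + ids[start:start + chunk_size] + [sep_id]
        with torch.no_grad():
            out = encoder(input_ids=torch.tensor([chunk]))
        chunk_vecs.append(out.last_hidden_state[:, 0, :])   # [CLS] vector per chunk
    # Aggregate chunk embeddings into one document vector (mean pooling here)
    return torch.cat(chunk_vecs, dim=0).mean(dim=0)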
RECOMMENDATION: Hierarchical BERT
Why?
- ✅ Respects document structure (clauses → sections → document)
- ✅ Handles any document length (processes hierarchically)
- ✅ Better context modeling than your current sliding window
- ✅ Interpretable (attention weights show important sections)
- ✅ Moderate complexity (not too hard to implement)
Implementation Plan
# hierarchical_bert.py
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
class HierarchicalContractBERT(nn.Module):
"""
Hierarchical document understanding for legal contracts
Structure:
Word → Clause → Section → Document
"""
def __init__(self, model_name='nlpaueb/legal-bert-base-uncased', num_labels=3):
super().__init__()
# Clause encoder (BERT)
self.bert = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Hierarchical aggregation
self.clause_to_section = nn.LSTM(
input_size=768, # BERT hidden size
hidden_size=256,
num_layers=2,
bidirectional=True,
dropout=0.1,
batch_first=True
)
self.section_to_document = nn.LSTM(
input_size=512, # bidirectional * 256
hidden_size=256,
num_layers=2,
bidirectional=True,
dropout=0.1,
batch_first=True
)
# Attention mechanisms
self.clause_attention = nn.Sequential(
nn.Linear(512, 128),
nn.Tanh(),
nn.Linear(128, 1)
)
self.section_attention = nn.Sequential(
nn.Linear(512, 128),
nn.Tanh(),
nn.Linear(128, 1)
)
# Task heads (same as your current model)
self.severity_head = nn.Linear(512, 1)
self.category_head = nn.Linear(512, num_labels)
def encode_clause(self, clause_text):
"""Encode single clause with BERT"""
inputs = self.tokenizer(
clause_text,
return_tensors='pt',
padding='max_length',
truncation=True,
max_length=128 # Shorter for clauses
)
outputs = self.bert(**inputs)
return outputs.last_hidden_state[:, 0, :] # [CLS] token
def aggregate_with_attention(self, hidden_states, attention_module):
"""Apply attention-based aggregation"""
# hidden_states: [batch, seq_len, hidden_dim]
# Compute attention weights
attention_logits = attention_module(hidden_states) # [batch, seq_len, 1]
attention_weights = torch.softmax(attention_logits, dim=1)
# Weighted sum
context_vector = torch.sum(hidden_states * attention_weights, dim=1)
return context_vector, attention_weights
def forward(self, document_structure):
"""
Args:
document_structure: List of sections
Each section is a list of clause texts
Example: [
['clause 1', 'clause 2'], # Section 1
['clause 3', 'clause 4', 'clause 5'], # Section 2
]
Returns:
document_embedding, clause_predictions, attention_weights
"""
section_vectors = []
all_clause_predictions = []
attention_weights = {'clause': [], 'section': None}
# Process each section
for section_clauses in document_structure:
clause_vectors = []
# 1. Encode each clause
for clause in section_clauses:
clause_vec = self.encode_clause(clause)
clause_vectors.append(clause_vec)
# Stack: [num_clauses, 768]
clause_hidden = torch.stack(clause_vectors).squeeze(1)
# 2. LSTM over clauses (captures sequential context)
clause_lstm_out, _ = self.clause_to_section(clause_hidden.unsqueeze(0))
# clause_lstm_out: [1, num_clauses, 512]
# 3. Attention over clauses
section_vec, clause_attn = self.aggregate_with_attention(
clause_lstm_out,
self.clause_attention
)
section_vectors.append(section_vec)
attention_weights['clause'].append(clause_attn)
# 4. Predict for each clause (using context-aware representation)
for i in range(len(section_clauses)):
clause_repr = clause_lstm_out[0, i, :] # Context-aware!
severity = self.severity_head(clause_repr)
category = self.category_head(clause_repr)
all_clause_predictions.append({
'severity': severity,
'category': category,
'text': section_clauses[i]
})
# 5. Stack sections: [num_sections, 512]
section_hidden = torch.stack(section_vectors)
# 6. LSTM over sections
section_lstm_out, _ = self.section_to_document(section_hidden.unsqueeze(0))
# 7. Attention over sections
document_vec, section_attn = self.aggregate_with_attention(
section_lstm_out,
self.section_attention
)
attention_weights['section'] = section_attn
return {
'document_embedding': document_vec,
'clause_predictions': all_clause_predictions,
'attention_weights': attention_weights
}
def predict_document(self, document_structure):
"""Convenience method for inference"""
outputs = self.forward(document_structure)
# Extract predictions
predictions = []
for clause_pred in outputs['clause_predictions']:
predictions.append({
'text': clause_pred['text'],
'severity': torch.sigmoid(clause_pred['severity']).item() * 10,
'category': torch.softmax(clause_pred['category'], dim=-1).tolist()
})
return {
'clauses': predictions,
'attention_weights': outputs['attention_weights']
}
# Usage example
if __name__ == '__main__':
model = HierarchicalContractBERT()
# Parse document into hierarchical structure
document = [
# Section 1: Services
[
"Provider shall deliver software services as described in Exhibit A.",
"Such Services shall be performed in a professional manner.",
"Services include maintenance and support."
],
# Section 2: Payment
[
"Client shall pay within 30 days of invoice.",
"Late payments incur 1.5% monthly interest.",
],
# Section 3: Termination
[
"Either party may terminate with 30 days written notice.",
"Upon termination, all obligations under Section 2 remain in effect."
]
]
# Predict
results = model.predict_document(document)
print("Clause Predictions:")
for i, clause in enumerate(results['clauses']):
print(f"\n{i+1}. {clause['text']}")
print(f" Severity: {clause['severity']:.2f}/10")
print(f" Category: {clause['category']}")
print("\n\nAttention Weights:")
print("Clause-level attention shows which clauses are most important in each section")
print("Section-level attention shows which sections are most critical overall")
Migration Path
Phase 1: Quick Win (Current + Context)
# Use your current model + sliding window context
# Already implemented! ✅
analyze_full_document(contract, model, use_context=True, context_window=2)
Phase 2: Upgrade to Longformer (Easy)
# Just swap BERT for Longformer (and its tokenizer; raise max_length to 4096)
# Can process 8x longer context
from transformers import LongformerModel, AutoTokenizer
model = LongformerModel.from_pretrained('lexlms/legal-longformer-base')
tokenizer = AutoTokenizer.from_pretrained('lexlms/legal-longformer-base')
Phase 3: Hierarchical BERT (Recommended)
# Implement hierarchical model (code above)
# Better document understanding + interpretability
model = HierarchicalContractBERT()
Phase 4: Graph NN (Advanced)
# Add graph connections for references
# Build clause dependency graph
model = DocumentGraphNN()
Key Papers
- Longformer: "Longformer: The Long-Document Transformer" (Beltagy et al., 2020)
- BigBird: "Big Bird: Transformers for Longer Sequences" (Zaheer et al., 2020)
- HAN: "Hierarchical Attention Networks for Document Classification" (Yang et al., 2016)
- LayoutLM: "LayoutLMv3: Pre-training for Document AI" (Huang et al., 2022)
- Legal-BERT: "LEGAL-BERT: The Muppets straight out of Law School" (Chalkidis et al., 2020)
- Document Understanding: "A Survey on Document-level Neural Machine Translation" (2023)
My Suggestion
Short term (Today):
- ✅ Keep using context-aware analysis (already done!)
- Test with use_context=True, context_window=2
Medium term (This week):
- Implement Hierarchical BERT (code provided above)
- Train on your CUAD dataset with section structure (see the training sketch after this list)
- Compare performance: BERT vs Hierarchical BERT
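A rough sketch of what that training step could look like, assuming a hypothetical load_cuad_sections() helper that yields the nested section/clause structure used above together with per-clause labels; the optimizer settings and loss weighting are placeholders to adapt:

import torch
import torch.nn as nn

model = HierarchicalContractBERT()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
severity_loss = nn.MSELoss()
category_loss = nn.CrossEntropyLoss()

# load_cuad_sections() is a hypothetical helper: it should yield the nested
# [[clause, ...], ...] structure used above plus per-clause labels
# (severity as a float tensor on a 0-10 scale, category as a long tensor).
for document_structure, labels in load_cuad_sections(split='train'):
    outputs = model(document_structure)
    sev_pred = torch.cat([p['severity'] for p in outputs['clause_predictions']])
    cat_pred = torch.stack([p['category'] for p in outputs['clause_predictions']])
    # Match the 0-10 scaling used in predict_document()
    loss = (severity_loss(torch.sigmoid(sev_pred) * 10, labels['severity'])
            + category_loss(cat_pred, labels['category']))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()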
Long term (If needed):
- Consider Longformer for very long contracts (>2000 words)
- Experiment with Graph NN if many cross-references
- Try GPT-4 for zero-shot (if budget allows)
Bottom Line
You're correct: This is document understanding, not just text classification!
Current approach: clause-by-clause analysis is limiting.
SOTA approach: hierarchical models that understand document structure.
Best ROI: Implement Hierarchical BERT (code above)
- Moderate complexity
- Big performance gain
- Interpretable (attention weights)
- Handles full documents
Would you like me to integrate the Hierarchical BERT into your pipeline?