# State-of-the-Art Document Understanding Methods
## 🎯 You're Right - This IS Document Understanding!
Contract analysis = **Document Understanding** + Legal Domain Knowledge
---
## 🏆 Current SOTA Methods (2024-2025)
### **1. Long-Context Transformers** ⭐ BEST FOR YOUR USE CASE
#### **Longformer** (Allen AI, 2020)
- **Max Length**: 4,096+ tokens (vs BERT's 512)
- **Innovation**: Sliding window + global attention
- **Use Case**: Full contract documents without chunking
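```python
import torch
from transformers import LongformerModel, LongformerTokenizer
model = LongformerModel.from_pretrained('allenai/longformer-base-4096')
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
# Up to 4,096 tokens in one pass (full_contract: the contract text; longer inputs are truncated)
inputs = tokenizer(full_contract, return_tensors='pt', max_length=4096, truncation=True)
# Sliding-window attention everywhere, plus global attention on the first (<s>) token
global_attention_mask = torch.zeros_like(inputs['input_ids'])
global_attention_mask[:, 0] = 1
outputs = model(**inputs, global_attention_mask=global_attention_mask)
```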
**Legal Variant**: [**Longformer-Legal**](https://huggingface.co/lexlms/legal-longformer-base)
```python
# Pre-trained on legal documents!
model = LongformerModel.from_pretrained('lexlms/legal-longformer-base')
```
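To make "sliding window + global attention" concrete, here is a toy, pure-Python sketch (not Longformer's actual kernel) that builds the boolean attention pattern for a short sequence. It assumes the window is split evenly, `window // 2` tokens on each side, and that tokens listed in `global_positions` (e.g. the first `<s>` token) attend to and are attended by every position. The function names are hypothetical.

```python
def longformer_pattern(n, window, global_positions=(0,)):
    """Toy boolean attention pattern: allowed[i][j] is True if token i may attend to token j.

    Each token sees `window // 2` neighbours on each side (sliding window);
    tokens in global_positions attend to, and are attended by, every token.
    """
    half = window // 2
    global_set = set(global_positions)
    return [
        [abs(i - j) <= half or i in global_set or j in global_set for j in range(n)]
        for i in range(n)
    ]


def count_pairs(allowed):
    """Number of (query, key) pairs that are actually computed."""
    return sum(row.count(True) for row in allowed)


pattern = longformer_pattern(n=8, window=2)
for row in pattern:
    print(''.join('x' if a else '.' for a in row))
print(count_pairs(pattern), 'of', 8 * 8, 'pairs')  # 34 of 64 pairs
```

Without global tokens the count is about n × (window + 1), linear in n, compared with n² for full attention; a real implementation gets this effect without materializing an n × n mask.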
---
#### **LED (Longformer Encoder-Decoder)** (Allen AI, 2020)
- **Max Length**: 16,384 tokens
- **Use Case**: Document summarization + Q&A
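```python
from transformers import LEDForConditionalGeneration, LEDTokenizer
tokenizer = LEDTokenizer.from_pretrained('allenai/led-large-16384')
model = LEDForConditionalGeneration.from_pretrained('allenai/led-large-16384')
# Summarize entire contract (contract_text: the contract as a string).
# This base checkpoint is only pretrained: fine-tune it for summarization first.
inputs = tokenizer(contract_text, return_tensors='pt', max_length=16384, truncation=True)
summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_length=512)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
```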
---
#### **BigBird** (Google, 2020)
- **Max Length**: 4,096 tokens
- **Innovation**: Sparse attention (random + global + sliding)
- **Efficiency**: O(n) instead of O(n²)
```python
from transformers import BigBirdModel
model = BigBirdModel.from_pretrained('google/bigbird-roberta-base')
```
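A back-of-the-envelope way to see the O(n) vs O(n²) claim is to count the query-key pairs each scheme computes. This is a rough estimate under stated assumptions: each ordinary token attends to `window` local keys, `num_global` global keys and `num_random` random keys, global tokens attend to all `n` keys, and overlaps between those sets are ignored. The sizes below (blocks of 64 tokens: 3 local, 2 global, 3 random) are assumptions chosen to resemble BigBird's block-sparse layout, not values read from the checkpoint config.

```python
def full_attention_pairs(n):
    """Dense self-attention: every token attends to every token."""
    return n * n


def sparse_attention_pairs(n, window, num_global, num_random):
    """Rough BigBird-style count (overlaps ignored).

    Ordinary tokens attend to window + num_global + num_random keys;
    global tokens attend to all n keys.
    """
    ordinary = (n - num_global) * (window + num_global + num_random)
    return ordinary + num_global * n


# Assumed sizes: 3 local, 2 global and 3 random blocks of 64 tokens each
block = 64
for n in (1024, 4096, 16384):
    dense = full_attention_pairs(n)
    sparse = sparse_attention_pairs(n, window=3 * block, num_global=2 * block, num_random=3 * block)
    print(f"n={n:>6}: dense={dense:>12,} sparse={sparse:>11,} ratio={dense / sparse:5.1f}x")
```

Going from n = 4,096 to 16,384 (4x) multiplies the dense count by 16 but the sparse estimate by only about 4, which is the linear scaling the O(n) claim refers to.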
---
### **2. Hierarchical Document Models** 🔥 RECOMMENDED
#### **Hierarchical Attention Networks (HAN)** (Yang et al., 2016)
- **Structure**: Word → Sentence → Document
- **Perfect for**: Legal contracts with clause hierarchy
```
Document
├── Section 1: SERVICES
│   ├── Sentence 1 (attention weights)
│   ├── Sentence 2 (attention weights)
│   └── Sentence 3 (attention weights)
├── Section 2: PAYMENT
│   └── ...
└── Section 3: TERMINATION
```
**Implementation**:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HierarchicalContractModel(nn.Module):
    # Default sizes are illustrative assumptions
    def __init__(self, embedding_dim=300, hidden_dim=128, num_classes=3):
        super().__init__()
        # Word-level encoder
        self.word_encoder = nn.GRU(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.word_attention = nn.Linear(hidden_dim * 2, 1)
        # Sentence-level encoder
        self.sentence_encoder = nn.GRU(hidden_dim * 2, hidden_dim, bidirectional=True, batch_first=True)
        self.sentence_attention = nn.Linear(hidden_dim * 2, 1)
        # Document-level classifier
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, document):
        # document shape: [batch, num_sentences, num_words, embedding_dim]
        batch, num_sentences, num_words, emb_dim = document.shape
        # 1. Encode words in each sentence (all sentences at once, folded into the batch dim)
        words = document.reshape(batch * num_sentences, num_words, emb_dim)
        word_hidden, _ = self.word_encoder(words)                          # [B*S, W, 2H]
        word_attn = F.softmax(self.word_attention(word_hidden), dim=1)     # weights over words
        sentence_vectors = (word_hidden * word_attn).sum(dim=1)            # [B*S, 2H]
        sentence_vectors = sentence_vectors.reshape(batch, num_sentences, -1)
        # 2. Encode sentences in document
        doc_hidden, _ = self.sentence_encoder(sentence_vectors)            # [B, S, 2H]
        sent_attn = F.softmax(self.sentence_attention(doc_hidden), dim=1)  # weights over sentences
        doc_vec = (doc_hidden * sent_attn).sum(dim=1)                      # [B, 2H]
        # 3. Classify
        return self.classifier(doc_vec)                                    # [B, num_classes]
```
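Both attention layers in the HAN above do the same three things: score each vector, softmax the scores, and take the weighted sum. A dependency-free toy version (plain Python lists; the `scores` stand in for the output of `word_attention` or `sentence_attention`) makes the arithmetic visible:

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(vectors, scores):
    """Weighted sum of vectors, with weights = softmax(scores)."""
    weights = softmax(scores)
    dim = len(vectors[0])
    pooled = [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]
    return pooled, weights


# Three "sentence" vectors; the second gets the highest attention score
vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [0.0, 2.0, 0.0]
pooled, weights = attention_pool(vectors, scores)
print([round(w, 3) for w in weights])
print([round(x, 3) for x in pooled])
```

The weights are what makes the model interpretable: they show how much each sentence (or clause) contributed to the pooled vector.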
---
#### **BERT-HAN** (Modern variant)
Combine BERT with hierarchical structure:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class BERTHierarchical(nn.Module):
    def __init__(self, model_name='nlpaueb/legal-bert-base-uncased'):
        super().__init__()
        # Use BERT for clause encoding
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)
        # Hierarchical aggregation (sequence-first LSTMs: input is [seq_len, batch=1, features])
        self.clause_encoder = nn.LSTM(768, 256, bidirectional=True)
        self.section_encoder = nn.LSTM(512, 256, bidirectional=True)
        # Attention mechanisms
        self.clause_attention = nn.Linear(512, 1)
        self.section_attention = nn.Linear(512, 1)

    def forward(self, sections):
        # sections: List[List[clause_text]]
        section_vectors = []
        for section in sections:
            # Encode each clause with BERT, keeping the [CLS] vector: [1, 768]
            clause_embeddings = []
            for clause in section:
                inputs = self.tokenizer(clause, return_tensors='pt', truncation=True)
                bert_output = self.bert(**inputs)
                clause_embeddings.append(bert_output.last_hidden_state[:, 0, :])
            # Aggregate clauses -> section: [num_clauses, 1, 768] -> [1, 512]
            clause_hidden, _ = self.clause_encoder(torch.stack(clause_embeddings))
            clause_attn = F.softmax(self.clause_attention(clause_hidden), dim=0)
            section_vec = (clause_hidden * clause_attn).sum(dim=0)
            section_vectors.append(section_vec)
        # Aggregate sections -> document: [num_sections, 1, 512] -> [1, 512]
        section_hidden, _ = self.section_encoder(torch.stack(section_vectors))
        section_attn = F.softmax(self.section_attention(section_hidden), dim=0)
        document_vec = (section_hidden * section_attn).sum(dim=0)
        return document_vec
```
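`BERTHierarchical` (like `HierarchicalContractBERT` further down) expects the contract already split into sections of clauses. Below is a minimal, hypothetical parser under strong assumptions: section headings are lines like `1. SERVICES` (a number, a period, an upper-case title), and clauses are sentences ending in `.`, `;`, `!` or `?` followed by whitespace. Real contract text (including CUAD's extracted text) will need a sturdier splitter, since abbreviations such as "e.g." break this rule.

```python
import re

# Assumed heading format: "<number>. <UPPER-CASE TITLE>" alone on a line
HEADING = re.compile(r'^\s*\d+\.\s+[A-Z][A-Z &/-]*$')
# Naive clause boundary: . ; ! or ? followed by whitespace
CLAUSE_SPLIT = re.compile(r'(?<=[.;!?])\s+')


def parse_contract(text):
    """Split contract text into (section titles, [[clause, ...], ...])."""
    titles, sections = [], []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if HEADING.match(line):
            titles.append(line)
            sections.append([])
            continue
        if not sections:  # text before the first heading becomes its own section
            titles.append('PREAMBLE')
            sections.append([])
        sections[-1].extend(c for c in CLAUSE_SPLIT.split(line) if c)
    return titles, sections


contract = """1. SERVICES
Provider shall deliver software services as described in Exhibit A. Such Services shall be performed in a professional manner.
2. PAYMENT
Client shall pay within 30 days of invoice. Late payments incur 1.5% monthly interest.
"""
titles, sections = parse_contract(contract)
print(titles)    # ['1. SERVICES', '2. PAYMENT']
print(sections)  # list of clause lists, one per section
```

`sections` has the `List[List[clause_text]]` shape that `BERTHierarchical.forward` and `HierarchicalContractBERT.predict_document` take.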
---
### **3. Document Graph Neural Networks** 🌐
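#### **Graph Neural Networks** (GCN / Graph Transformer variants)
Model document as graph: clauses = nodes, references = edges. The example below uses graph convolution (GCN) layers; graph-transformer variants replace them with attention-based layers.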
```
[Clause 1: "Services in Exhibit A"] ──references──> [Exhibit A]
        │
     mentions
        │
        ↓
[Clause 2: "Such Services..."] ──references──> [Section 5]
```
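```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from transformers import AutoModel, AutoTokenizer


class DocumentGraphNN(nn.Module):
    def __init__(self, model_name='nlpaueb/legal-bert-base-uncased', num_classes=3):
        super().__init__()
        # Node encoder (BERT for each clause)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.node_encoder = AutoModel.from_pretrained(model_name)
        # Graph convolution layers
        self.conv1 = GCNConv(768, 256)
        self.conv2 = GCNConv(256, 128)
        # Classifier
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, clauses, edges):
        # clauses: List[str]; edges: LongTensor [2, num_edges] of (source, target) clause indices
        # 1. Encode nodes (clauses) -> x: [num_clauses, 768]
        node_features = []
        for clause in clauses:
            inputs = self.tokenizer(clause, return_tensors='pt', truncation=True)
            bert_out = self.node_encoder(**inputs)
            node_features.append(bert_out.last_hidden_state[:, 0, :])
        x = torch.cat(node_features, dim=0)
        # 2. Graph convolution (propagate context along edges)
        x = F.relu(self.conv1(x, edges))
        x = self.conv2(x, edges)
        # 3. Classify each clause (node-level predictions): [num_clauses, num_classes]
        return self.classifier(x)
```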
**Edge Types**:
- Sequential (Clause N → Clause N+1)
- Reference ("Section 5" β†’ actual Section 5)
- Semantic similarity (cosine > threshold)
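The three edge types above can be turned into the edge list a GCN layer consumes. This is a hypothetical sketch with deliberately simple heuristics: sequential edges between neighbouring clauses, reference edges when a clause mentions `Section N` (mapped to the first clause of that section via an assumed `section_starts` lookup), and similarity edges when word-set Jaccard overlap exceeds a threshold (a cheap stand-in for the embedding cosine similarity listed above). Edges come back as `(source, target)` pairs; `torch.tensor(edges, dtype=torch.long).t().contiguous()` gives the `[2, num_edges]` layout PyTorch Geometric expects.

```python
import re


def jaccard(a, b):
    """Word-set overlap between two clauses (0.0 to 1.0)."""
    wa, wb = set(a.lower().split()), set(b.lower().split())
    return len(wa & wb) / len(wa | wb) if wa | wb else 0.0


def build_clause_edges(clauses, section_starts, sim_threshold=0.5):
    """Return sorted (source, target) index pairs for a clause graph.

    clauses: clause strings in document order
    section_starts: {section_number: index of that section's first clause} (assumed input)
    """
    edges = set()
    # 1. Sequential: clause i <-> clause i+1, so context flows both ways
    for i in range(len(clauses) - 1):
        edges.update({(i, i + 1), (i + 1, i)})
    # 2. Reference: "Section 5" -> first clause of Section 5
    for i, clause in enumerate(clauses):
        for num in re.findall(r'\bSection\s+(\d+)', clause):
            target = section_starts.get(int(num))
            if target is not None and target != i:
                edges.add((i, target))
    # 3. Semantic similarity above a threshold (symmetric)
    for i in range(len(clauses)):
        for j in range(i + 1, len(clauses)):
            if jaccard(clauses[i], clauses[j]) > sim_threshold:
                edges.update({(i, j), (j, i)})
    return sorted(edges)


clauses = [
    "Provider shall deliver software services.",                            # Section 1
    "Client shall pay within 30 days of invoice.",                          # Section 2
    "Either party may terminate with 30 days written notice.",              # Section 3
    "Upon termination, all obligations under Section 2 remain in effect.",  # Section 3
]
edges = build_clause_edges(clauses, section_starts={1: 0, 2: 1, 3: 2})
print(edges)  # [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 1), (3, 2)]
```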
---
### **4. Retrieval-Augmented Models** 🔍
#### **RAG (Retrieval-Augmented Generation)** (Facebook, 2020)
Retrieve relevant clauses before classification
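```python
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
tokenizer = RagTokenizer.from_pretrained('facebook/rag-token-base')
# The stock retriever searches a Wikipedia index (dummy subset here; needs `datasets` and `faiss`).
# To retrieve from your own contract clauses, build a FAISS index and load it with
# index_name='custom', passages_path=..., index_path=...
retriever = RagRetriever.from_pretrained('facebook/rag-token-base', index_name='exact', use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained('facebook/rag-token-base', retriever=retriever)

# For each clause: retrieve the top-k passages, then generate conditioned on them
def predict_with_retrieval(clause, n_docs=5):
    inputs = tokenizer(clause, return_tensors='pt')
    output_ids = model.generate(input_ids=inputs['input_ids'], n_docs=n_docs)
    # Output is free text, not a class label; the base checkpoint needs fine-tuning
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
```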
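The retrieve-then-predict idea does not require the full RAG stack. A lighter, hypothetical alternative: retrieve the top-k most similar clauses from your own (labelled) clause collection and prepend them as context to the input of your existing classifier. The sketch uses bag-of-words cosine similarity as a swapped-in stand-in for a dense retriever, and a `[SEP]`-joined input format; both are assumptions, and the naive whitespace tokenization keeps punctuation attached to words.

```python
import math
from collections import Counter


def bag_of_words(text):
    return Counter(text.lower().split())


def cosine(a, b):
    """Cosine similarity between two bag-of-words Counters."""
    dot = sum(a[w] * b[w] for w in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(query, corpus, k=2):
    """Return the k corpus entries most similar to query, best first."""
    q = bag_of_words(query)
    return sorted(corpus, key=lambda c: cosine(q, bag_of_words(c)), reverse=True)[:k]


def build_augmented_input(clause, corpus, k=2):
    """Clause followed by its retrieved neighbours, ready for a clause classifier."""
    return ' [SEP] '.join([clause] + retrieve(clause, corpus, k))


corpus = [
    "Client shall pay within 30 days of invoice.",
    "Late payments incur 1.5% monthly interest.",
    "Either party may terminate with 30 days written notice.",
]
print(build_augmented_input("Client shall pay all invoices within 45 days.", corpus, k=1))
```

In practice the query clause itself should be excluded from the corpus during training, otherwise it retrieves itself.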
---
### **5. Large Language Models (2023-2025)** 🤖
#### **GPT-4 / Claude** (OpenAI / Anthropic)
- **Context**: 128k tokens (GPT-4 Turbo), 200k (Claude 3)
- **Approach**: Few-shot learning + prompting
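```python
import json
from openai import OpenAI  # openai>=1.0 client; openai.ChatCompletion was removed in 1.0

client = OpenAI()  # reads OPENAI_API_KEY from the environment


def analyze_contract_llm(contract):
    prompt = f"""
Analyze this contract for risk clauses. For each clause, provide:
1. Risk severity (0-10)
2. Risk category
3. Explanation
Contract:
{contract}
Format as JSON.
"""
    response = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,  # Low for consistency
        response_format={"type": "json_object"},  # JSON mode: reply parses as a JSON object
    )
    return json.loads(response.choices[0].message.content)
```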
**Pros**: No training needed; strong zero-/few-shot understanding of legal language
**Cons**: Per-token API cost, higher latency, and confidential contract text is sent to a third-party API
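Even at low temperature, LLM output should be validated before it feeds anything downstream. A minimal sketch, assuming the model returns an object with a `clauses` list whose items carry `severity`, `category` and `explanation` keys; that schema is an assumption (the prompt above does not pin it down, so you would spell it out there). Items with a non-numeric severity or an empty category are dropped, and severities are clamped to 0-10.

```python
import json


def parse_llm_risks(raw_json):
    """Validate an LLM risk report; returns a list of clean clause dicts.

    Assumed schema: {"clauses": [{"severity": number, "category": str, "explanation": str}, ...]}
    """
    try:
        data = json.loads(raw_json)
    except json.JSONDecodeError:
        return []
    items = data.get('clauses', []) if isinstance(data, dict) else []
    clean = []
    for item in items:
        if not isinstance(item, dict):
            continue
        severity, category = item.get('severity'), item.get('category')
        # bool is a subclass of int in Python, so reject it explicitly
        if not isinstance(severity, (int, float)) or isinstance(severity, bool):
            continue
        if not isinstance(category, str) or not category.strip():
            continue
        clean.append({
            'severity': min(max(float(severity), 0.0), 10.0),  # clamp to the 0-10 scale
            'category': category.strip(),
            'explanation': str(item.get('explanation', '')),
        })
    return clean


raw = '{"clauses": [{"severity": 12, "category": "Payment", "explanation": "High late fee"}, {"severity": "high", "category": "Termination"}]}'
print(parse_llm_risks(raw))
# [{'severity': 10.0, 'category': 'Payment', 'explanation': 'High late fee'}]
```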
---
#### **Legal-Specific LLMs**
- **LegalBench** (Stanford, 2023): a benchmark for evaluating legal reasoning in LLMs, not a model itself
- **ChatLaw** (Peking University, 2023)
- **LawGPT** (2024)
---
### **6. Multi-Modal Document Understanding** 📄
#### **LayoutLM** (Microsoft, 2020-2022)
Understands document **layout** + text
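```python
from PIL import Image
from transformers import LayoutLMv3Processor, LayoutLMv3Model
# Processes:
# 1. Text content
# 2. Bounding boxes (where text appears)
# 3. Images (if PDF has visual elements)
processor = LayoutLMv3Processor.from_pretrained('microsoft/layoutlmv3-base', apply_ocr=False)
model = LayoutLMv3Model.from_pretrained('microsoft/layoutlmv3-base')
# Inputs from your PDF pipeline (illustrative values): a page image plus words and
# their boxes [x0, y0, x1, y1], one box per word, normalized to a 0-1000 scale
page_image = Image.open('page_1.png').convert('RGB')  # hypothetical rendered page
words = ['3.', 'TERMINATION']
boxes = [[80, 120, 100, 135], [105, 120, 260, 135]]
# The processor tokenizes the words, repeats each box for its sub-tokens, and builds pixel_values
inputs = processor(page_image, words, boxes=boxes, return_tensors='pt')
outputs = model(**inputs)
```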
**Why this matters**: Contracts have structure (headers, indentation, tables)
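LayoutLM-family models expect each box on a 0-1000 scale relative to the page size, rather than in raw PDF points or pixels. A small helper for that conversion, assuming a top-left origin with y growing downward (as in rendered page images; raw PDF coordinates have a bottom-left origin and would need flipping first). The example page is US Letter in PDF points, 612 × 792.

```python
def normalize_box(box, page_width, page_height):
    """Scale an absolute [x0, y0, x1, y1] box to LayoutLM's 0-1000 coordinate range."""
    x0, y0, x1, y1 = box
    return [
        int(1000 * x0 / page_width),
        int(1000 * y0 / page_height),
        int(1000 * x1 / page_width),
        int(1000 * y1 / page_height),
    ]


# A heading box on a US Letter page measured in PDF points (612 x 792)
print(normalize_box([72, 90, 300, 110], page_width=612, page_height=792))  # [117, 113, 490, 138]
```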
---
## 📊 Comparison for Your Use Case
| Method | Context Length | Structure Aware | Training Cost | Inference Speed | Best For |
|--------|---------------|-----------------|---------------|-----------------|----------|
| **BERT** (current) | 512 tokens | ❌ | Low | Fast | Clause-level |
| **Longformer** | 4,096 tokens | ❌ | Medium | Medium | Full documents |
| **Hierarchical BERT** | Unlimited* | ✅✅✅ | Medium | Medium | **RECOMMENDED** |
| **Graph NN** | Unlimited* | ✅✅ | High | Slow | Complex references |
| **LLM (GPT-4)** | 128k tokens | ✅ | Zero! | Slow | No training data |
| **LayoutLM** | 512 tokens | ✅ (visual) | High | Medium | Scanned PDFs |
*Processes document in chunks with aggregation
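For the chunk-based rows, the usual recipe is: split the token sequence into overlapping windows that fit the encoder (512 for BERT), encode each window, then aggregate. A hypothetical sketch of the splitting step plus the simplest possible aggregation, a mean over per-chunk scores (a real model would aggregate embeddings, e.g. with the LSTM + attention used by the hierarchical models in this document):

```python
def chunk_tokens(tokens, max_len=512, stride=128):
    """Split tokens into windows of at most max_len, each overlapping the previous by `stride`."""
    if max_len <= stride:
        raise ValueError("max_len must be larger than stride")
    chunks, start = [], 0
    while True:
        chunks.append(tokens[start:start + max_len])
        if start + max_len >= len(tokens):
            break
        start += max_len - stride
    return chunks


def mean_aggregate(chunk_scores):
    """Average per-chunk scores into one document-level score."""
    return sum(chunk_scores) / len(chunk_scores)


tokens = list(range(1200))  # stand-in for 1,200 token ids
chunks = chunk_tokens(tokens, max_len=512, stride=128)
print([(c[0], c[-1]) for c in chunks])  # [(0, 511), (384, 895), (768, 1199)]
```

Hugging Face fast tokenizers can produce the same kind of overlapping windows directly via `return_overflowing_tokens=True` with a `stride` argument.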
---
## 🎯 RECOMMENDATION: Hierarchical BERT
### Why?
1. ✅ **Respects document structure** (clauses → sections → document)
2. ✅ **Handles any document length** (processes hierarchically)
3. ✅ **Better context modeling** than your current sliding window
4. ✅ **Interpretable** (attention weights show important sections)
5. ✅ **Moderate complexity** (not too hard to implement)
### Implementation Plan
```python
# hierarchical_bert.py
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class HierarchicalContractBERT(nn.Module):
    """
    Hierarchical document understanding for legal contracts
    Structure:
        Word → Clause → Section → Document
    """

    def __init__(self, model_name='nlpaueb/legal-bert-base-uncased', num_labels=3):
        super().__init__()
        # Clause encoder (BERT)
        self.bert = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Hierarchical aggregation
        self.clause_to_section = nn.LSTM(
            input_size=768,  # BERT hidden size
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            dropout=0.1,
            batch_first=True
        )
        self.section_to_document = nn.LSTM(
            input_size=512,  # bidirectional * 256
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            dropout=0.1,
            batch_first=True
        )
        # Attention mechanisms
        self.clause_attention = nn.Sequential(
            nn.Linear(512, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        self.section_attention = nn.Sequential(
            nn.Linear(512, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        # Task heads (same as your current model)
        self.severity_head = nn.Linear(512, 1)
        self.category_head = nn.Linear(512, num_labels)

    def encode_clause(self, clause_text):
        """Encode single clause with BERT"""
        inputs = self.tokenizer(
            clause_text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=128  # Shorter for clauses
        )
        outputs = self.bert(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token: [1, 768]

    def aggregate_with_attention(self, hidden_states, attention_module):
        """Apply attention-based aggregation"""
        # hidden_states: [batch, seq_len, hidden_dim]
        # Compute attention weights
        attention_logits = attention_module(hidden_states)  # [batch, seq_len, 1]
        attention_weights = torch.softmax(attention_logits, dim=1)
        # Weighted sum: [batch, hidden_dim]
        context_vector = torch.sum(hidden_states * attention_weights, dim=1)
        return context_vector, attention_weights

    def forward(self, document_structure):
        """
        Args:
            document_structure: List of sections
                Each section is a list of clause texts
                Example: [
                    ['clause 1', 'clause 2'],  # Section 1
                    ['clause 3', 'clause 4', 'clause 5'],  # Section 2
                ]
        Returns:
            dict with document_embedding, clause_predictions, attention_weights
        """
        section_vectors = []
        all_clause_predictions = []
        attention_weights = {'clause': [], 'section': None}
        # Process each section
        for section_clauses in document_structure:
            clause_vectors = []
            # 1. Encode each clause
            for clause in section_clauses:
                clause_vec = self.encode_clause(clause)
                clause_vectors.append(clause_vec)
            # Stack: [num_clauses, 768]
            clause_hidden = torch.stack(clause_vectors).squeeze(1)
            # 2. LSTM over clauses (captures sequential context)
            clause_lstm_out, _ = self.clause_to_section(clause_hidden.unsqueeze(0))
            # clause_lstm_out: [1, num_clauses, 512]
            # 3. Attention over clauses -> section_vec: [1, 512]
            section_vec, clause_attn = self.aggregate_with_attention(
                clause_lstm_out,
                self.clause_attention
            )
            section_vectors.append(section_vec)
            attention_weights['clause'].append(clause_attn)
            # 4. Predict for each clause (using context-aware representation)
            for i in range(len(section_clauses)):
                clause_repr = clause_lstm_out[0, i, :]  # Context-aware!
                severity = self.severity_head(clause_repr)
                category = self.category_head(clause_repr)
                all_clause_predictions.append({
                    'severity': severity,
                    'category': category,
                    'text': section_clauses[i]
                })
        # 5. Concatenate sections: [num_sections, 512]
        #    (torch.cat, not torch.stack: each section_vec is already [1, 512])
        section_hidden = torch.cat(section_vectors, dim=0)
        # 6. LSTM over sections
        section_lstm_out, _ = self.section_to_document(section_hidden.unsqueeze(0))
        # 7. Attention over sections
        document_vec, section_attn = self.aggregate_with_attention(
            section_lstm_out,
            self.section_attention
        )
        attention_weights['section'] = section_attn
        return {
            'document_embedding': document_vec,
            'clause_predictions': all_clause_predictions,
            'attention_weights': attention_weights
        }

    @torch.no_grad()
    def predict_document(self, document_structure):
        """Convenience method for inference (switches to eval mode, no gradients)"""
        self.eval()  # disables LSTM dropout
        outputs = self.forward(document_structure)
        # Extract predictions
        predictions = []
        for clause_pred in outputs['clause_predictions']:
            predictions.append({
                'text': clause_pred['text'],
                'severity': torch.sigmoid(clause_pred['severity']).item() * 10,
                'category': torch.softmax(clause_pred['category'], dim=-1).tolist()
            })
        return {
            'clauses': predictions,
            'attention_weights': outputs['attention_weights']
        }


# Usage example
if __name__ == '__main__':
    # The severity/category heads are randomly initialized here:
    # fine-tune on labeled data before trusting the scores.
    model = HierarchicalContractBERT()
    # Parse document into hierarchical structure
    document = [
        # Section 1: Services
        [
            "Provider shall deliver software services as described in Exhibit A.",
            "Such Services shall be performed in a professional manner.",
            "Services include maintenance and support."
        ],
        # Section 2: Payment
        [
            "Client shall pay within 30 days of invoice.",
            "Late payments incur 1.5% monthly interest.",
        ],
        # Section 3: Termination
        [
            "Either party may terminate with 30 days written notice.",
            "Upon termination, all obligations under Section 2 remain in effect."
        ]
    ]
    # Predict
    results = model.predict_document(document)
    print("Clause Predictions:")
    for i, clause in enumerate(results['clauses']):
        print(f"\n{i+1}. {clause['text']}")
        print(f"   Severity: {clause['severity']:.2f}/10")
        print(f"   Category: {clause['category']}")
    print("\n\nAttention Weights:")
    print("Clause-level attention shows which clauses are most important in each section")
    print("Section-level attention shows which sections are most critical overall")
```
---
## 🚀 Migration Path
### Phase 1: Quick Win (Current + Context)
```python
# Use your current model + sliding window context
# Already implemented! ✅
analyze_full_document(contract, model, use_context=True, context_window=2)
```
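For comparison with the hierarchical model, this is roughly what a ±2 clause context window means. The helper below is a hypothetical illustration of the idea, not the actual `analyze_full_document` implementation (which is not shown in this document): each clause is classified together with up to two neighbouring clauses on either side, joined with an assumed `[SEP]` separator.

```python
def with_context(clauses, i, context_window=2, sep=' [SEP] '):
    """Join clause i with up to `context_window` neighbours on each side (hypothetical helper)."""
    lo = max(0, i - context_window)
    hi = min(len(clauses), i + context_window + 1)
    return sep.join(clauses[lo:hi])


clauses = ["C1", "C2", "C3", "C4", "C5", "C6"]
print(with_context(clauses, 0))  # C1 [SEP] C2 [SEP] C3
print(with_context(clauses, 3))  # C2 [SEP] C3 [SEP] C4 [SEP] C5 [SEP] C6
```

The window is fixed and blind to section boundaries, so it can pull in clauses from an unrelated section; grouping clauses per section is what the hierarchical model changes.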
### Phase 2: Upgrade to Longformer (Easy)
```python
# Swap the BERT encoder for Longformer (load its tokenizer and re-fine-tune the task heads)
# 4,096 vs 512 tokens: 8x longer context
from transformers import LongformerModel
model = LongformerModel.from_pretrained('lexlms/legal-longformer-base')
```
### Phase 3: Hierarchical BERT (Recommended)
```python
# Implement hierarchical model (code above)
# Better document understanding + interpretability
model = HierarchicalContractBERT()
```
### Phase 4: Graph NN (Advanced)
```python
# Add graph connections for references
# Build clause dependency graph
model = DocumentGraphNN()
```
---
## 📚 Key Papers
1. **Longformer**: "Longformer: The Long-Document Transformer" (Beltagy et al., 2020)
2. **BigBird**: "Big Bird: Transformers for Longer Sequences" (Zaheer et al., 2020)
3. **HAN**: "Hierarchical Attention Networks for Document Classification" (Yang et al., 2016)
4. **LayoutLM**: "LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking" (Huang et al., 2022)
5. **Legal-BERT**: "LEGAL-BERT: The Muppets straight out of Law School" (Chalkidis et al., 2020)
6. **Document-level context (from machine translation)**: "A Survey on Document-level Neural Machine Translation: Methods and Evaluation" (Maruf et al., 2021)
---
## 💡 My Suggestion
**Short term** (Today):
- ✅ Keep using context-aware analysis (already done!)
- Test with `use_context=True, context_window=2`
**Medium term** (This week):
- 🔄 Implement Hierarchical BERT (code provided above)
- Train on your CUAD dataset with section structure
- Compare performance: BERT vs Hierarchical BERT
**Long term** (If needed):
- Consider Longformer for very long contracts (>2000 words)
- Experiment with Graph NN if many cross-references
- Try GPT-4 for zero-shot (if budget allows)
---
## 🎯 Bottom Line
**You're correct**: This is document understanding, not just text classification!
**Current approach**: Clause-by-clause is limiting
**SOTA approach**: Hierarchical models that understand document structure
**Best ROI**: Implement Hierarchical BERT (code above)
- Moderate complexity
- Expected performance gain (to be confirmed against the BERT baseline)
- Interpretable (attention weights)
- Handles full documents
Would you like me to integrate the Hierarchical BERT into your pipeline? 🚀