# State-of-the-Art Document Understanding Methods

## 🎯 You're Right - This IS Document Understanding!

Contract analysis = **Document Understanding** + Legal Domain Knowledge

---

## Current SOTA Methods (2024-2025)
### **1. Long-Context Transformers** ⭐ BEST FOR YOUR USE CASE

#### **Longformer** (Allen AI, 2020)
- **Max Length**: 4,096+ tokens (vs BERT's 512)
- **Innovation**: Sliding window + global attention
- **Use Case**: Full contract documents without chunking
```python
from transformers import LongformerModel, LongformerTokenizer

model = LongformerModel.from_pretrained('allenai/longformer-base-4096')
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

# Can process an entire contract at once
inputs = tokenizer(full_contract, return_tensors='pt', truncation=True, max_length=4096)
outputs = model(**inputs)
```
**Legal Variant**: [**Longformer-Legal**](https://huggingface.co/lexlms/legal-longformer-base)

```python
# Pre-trained on legal documents!
model = LongformerModel.from_pretrained('lexlms/legal-longformer-base')
```
---

#### **LED (Longformer Encoder-Decoder)** (Allen AI, 2020)
- **Max Length**: 16,384 tokens
- **Use Case**: Document summarization + Q&A
```python
from transformers import LEDForConditionalGeneration, LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained('allenai/led-large-16384')
model = LEDForConditionalGeneration.from_pretrained('allenai/led-large-16384')

# Summarize an entire contract in one pass
inputs = tokenizer(full_contract, return_tensors='pt', truncation=True, max_length=16384)
summary_ids = model.generate(inputs['input_ids'], max_length=512)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
```
---

#### **BigBird** (Google, 2020)
- **Max Length**: 4,096 tokens
- **Innovation**: Sparse attention (random + global + sliding)
- **Efficiency**: O(n) instead of O(n²)
```python
from transformers import BigBirdModel

model = BigBirdModel.from_pretrained('google/bigbird-roberta-base')
```
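A slightly fuller usage sketch, assuming a `long_contract` string is available; `attention_type='block_sparse'` selects the O(n) sparse-attention mode:

```python
import torch
from transformers import AutoTokenizer, BigBirdModel

tokenizer = AutoTokenizer.from_pretrained('google/bigbird-roberta-base')
model = BigBirdModel.from_pretrained('google/bigbird-roberta-base',
                                     attention_type='block_sparse')

# long_contract is an assumed variable holding the full contract text
inputs = tokenizer(long_contract, return_tensors='pt',
                   truncation=True, max_length=4096)
with torch.no_grad():
    outputs = model(**inputs)
doc_embedding = outputs.last_hidden_state[:, 0, :]  # first-token embedding
```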
---

### **2. Hierarchical Document Models** 🔥 RECOMMENDED

#### **Hierarchical Attention Networks (HAN)** (Yang et al., 2016)
- **Structure**: Word → Sentence → Document
- **Perfect for**: Legal contracts with clause hierarchy
```
Document
├── Section 1: SERVICES
│   ├── Sentence 1 (attention weights)
│   ├── Sentence 2 (attention weights)
│   └── Sentence 3 (attention weights)
├── Section 2: PAYMENT
│   └── ...
└── Section 3: TERMINATION
```
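At each level, attention-weighted pooling follows the formulation from Yang et al. (2016):

$$
u_i = \tanh(W h_i + b), \qquad
\alpha_i = \frac{\exp(u_i^\top u_w)}{\sum_j \exp(u_j^\top u_w)}, \qquad
s = \sum_i \alpha_i h_i
$$

where $h_i$ is the encoder hidden state for word $i$, $u_w$ is a learned context vector, and $s$ is the resulting sentence vector; the same scheme then pools sentence vectors into a document vector.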
**Implementation** (a sketch; the `embedding_dim`, `hidden_dim`, and `num_classes` defaults below are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HierarchicalContractModel(nn.Module):
    def __init__(self, embedding_dim=300, hidden_dim=128, num_classes=3):
        super().__init__()
        # Word-level encoder
        self.word_encoder = nn.GRU(embedding_dim, hidden_dim, bidirectional=True)
        self.word_attention = nn.Linear(hidden_dim * 2, 1)
        # Sentence-level encoder
        self.sentence_encoder = nn.GRU(hidden_dim * 2, hidden_dim, bidirectional=True)
        self.sentence_attention = nn.Linear(hidden_dim * 2, 1)
        # Document-level classifier
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, document):
        # document: [num_sentences, num_words, embedding_dim]
        # 1. Encode words in each sentence
        sentence_vectors = []
        for sentence in document:
            # GRU returns (output, h_n); add a batch dim of 1
            word_hidden, _ = self.word_encoder(sentence.unsqueeze(1))
            word_attn = F.softmax(self.word_attention(word_hidden), dim=0)
            sentence_vectors.append((word_hidden * word_attn).sum(dim=0))
        # 2. Encode sentences in the document
        doc_hidden, _ = self.sentence_encoder(torch.stack(sentence_vectors))
        sent_attn = F.softmax(self.sentence_attention(doc_hidden), dim=0)
        doc_vec = (doc_hidden * sent_attn).sum(dim=0)
        # 3. Classify
        return self.classifier(doc_vec)
```
---

#### **BERT-HAN** (Modern variant)
Combine BERT with hierarchical structure:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

class BERTHierarchical(nn.Module):
    def __init__(self):
        super().__init__()
        # Use BERT for clause encoding
        self.bert = AutoModel.from_pretrained('nlpaueb/legal-bert-base-uncased')
        self.tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
        # Hierarchical aggregation
        self.clause_encoder = nn.LSTM(768, 256, bidirectional=True)
        self.section_encoder = nn.LSTM(512, 256, bidirectional=True)
        # Attention mechanisms
        self.clause_attention = nn.Linear(512, 1)
        self.section_attention = nn.Linear(512, 1)

    def forward(self, sections):
        # sections: List[List[clause_text]]
        section_vectors = []
        for section in sections:
            # Encode each clause with BERT ([CLS] embedding)
            clause_embeddings = []
            for clause in section:
                inputs = self.tokenizer(clause, return_tensors='pt', truncation=True)
                bert_output = self.bert(**inputs)
                clause_embeddings.append(bert_output.last_hidden_state[:, 0, :])
            # Aggregate clauses -> section
            clause_hidden, _ = self.clause_encoder(torch.stack(clause_embeddings))
            clause_attn = F.softmax(self.clause_attention(clause_hidden), dim=0)
            section_vectors.append((clause_hidden * clause_attn).sum(dim=0))
        # Aggregate sections -> document
        section_hidden, _ = self.section_encoder(torch.stack(section_vectors))
        section_attn = F.softmax(self.section_attention(section_hidden), dim=0)
        document_vec = (section_hidden * section_attn).sum(dim=0)
        return document_vec
```
---

### **3. Document Graph Neural Networks**

#### **Graph Transformer** (Microsoft, 2022)
Model the document as a graph: clauses = nodes, references = edges
```
[Clause 1: "Services in Exhibit A"] ──references──> [Exhibit A]
        │
     mentions
        │
        ▼
[Clause 2: "Such Services..."] ──references──> [Section 5]
```
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
from transformers import AutoModel, AutoTokenizer

class DocumentGraphNN(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        # Node encoder (BERT for each clause)
        self.node_encoder = AutoModel.from_pretrained('nlpaueb/legal-bert-base-uncased')
        self.tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
        # Graph convolution layers
        self.conv1 = pyg.nn.GCNConv(768, 256)
        self.conv2 = pyg.nn.GCNConv(256, 128)
        # Classifier
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, clauses, edge_index):
        # edge_index: [2, num_edges] long tensor of (source, target) pairs
        # 1. Encode nodes (clauses) as BERT [CLS] embeddings
        node_features = []
        for clause in clauses:
            inputs = self.tokenizer(clause, return_tensors='pt', truncation=True)
            bert_out = self.node_encoder(**inputs)
            node_features.append(bert_out.last_hidden_state[:, 0, :])
        x = torch.cat(node_features, dim=0)  # [num_nodes, 768]
        # 2. Graph convolution (propagate context along edges)
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        # 3. Classify each node
        return self.classifier(x)
```
**Edge Types** (a sketch for building them follows below):
- Sequential (Clause N → Clause N+1)
- Reference ("Section 5" → actual Section 5)
- Semantic similarity (cosine > threshold)
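A hedged sketch of edge construction (`build_edges` is a hypothetical helper; reference edges, which need a regex or NER pass over "Section 5"-style mentions, are omitted):

```python
import torch
import torch.nn.functional as F

def build_edges(clause_embeddings, sim_threshold=0.8):
    """Build sequential + similarity edges as a [2, num_edges] edge_index.

    clause_embeddings: [num_clauses, dim] tensor (e.g., BERT [CLS] vectors).
    """
    n = clause_embeddings.size(0)
    edges = []
    # Sequential edges: Clause i <-> Clause i+1
    for i in range(n - 1):
        edges += [(i, i + 1), (i + 1, i)]
    # Similarity edges: cosine similarity above threshold
    normed = F.normalize(clause_embeddings, dim=1)
    sim = normed @ normed.t()
    for i in range(n):
        for j in range(i + 1, n):
            if sim[i, j] > sim_threshold:
                edges += [(i, j), (j, i)]
    return torch.tensor(edges, dtype=torch.long).t()
```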
---

### **4. Retrieval-Augmented Models**

#### **RAG (Retrieval-Augmented Generation)** (Facebook, 2020)
Retrieve relevant clauses before classification
```python
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

# The demo index below is a dummy; to retrieve from your own contract
# clauses you would build a custom index and pass passages_path/index_path
tokenizer = RagTokenizer.from_pretrained('facebook/rag-token-base')
retriever = RagRetriever.from_pretrained(
    'facebook/rag-token-base', index_name='exact', use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained(
    'facebook/rag-token-base', retriever=retriever
)

def predict_with_retrieval(clause):
    # Retrieval happens inside generate(): the question encoder embeds the
    # clause, the retriever fetches the top-k passages, and generation
    # conditions on them
    inputs = tokenizer(clause, return_tensors='pt')
    output_ids = model.generate(input_ids=inputs['input_ids'])
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
```
---

### **5. Large Language Models (2023-2025)** 🤖

#### **GPT-4 / Claude** (OpenAI / Anthropic)
- **Context**: 128k tokens (GPT-4 Turbo), 200k (Claude 3)
- **Approach**: Few-shot learning + prompting
```python
import json
import openai  # openai<1.0 SDK style

def analyze_contract_llm(contract):
    prompt = f"""
Analyze this contract for risk clauses. For each clause, provide:
1. Risk severity (0-10)
2. Risk category
3. Explanation

Contract:
{contract}

Format as JSON.
"""
    response = openai.ChatCompletion.create(
        model="gpt-4-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1  # low temperature for consistency
    )
    return json.loads(response.choices[0].message.content)
```
**Pros**: No training needed, strong out-of-the-box understanding
**Cons**: Expensive, slower, privacy concerns

---

#### **Legal-Specific LLMs & Benchmarks**
- **LegalBench** (Stanford, 2023) - a benchmark for evaluating legal reasoning in LLMs
- **ChatLaw** (Peking University, 2023)
- **LawGPT** (2024)
---

### **6. Multi-Modal Document Understanding**

#### **LayoutLM** (Microsoft, 2020-2023)
Understands document **layout** + text
```python
from transformers import LayoutLMv3Processor, LayoutLMv3Model

# LayoutLMv3 jointly processes:
# 1. Text content
# 2. Bounding boxes (where each token appears on the page)
# 3. The page image itself
processor = LayoutLMv3Processor.from_pretrained('microsoft/layoutlmv3-base')
model = LayoutLMv3Model.from_pretrained('microsoft/layoutlmv3-base')

# By default the processor runs OCR on the page image, then packs
# input_ids, bbox ([x0, y0, x1, y1] per token), and pixel_values
inputs = processor(page_image, return_tensors='pt')
outputs = model(**inputs)
```
**Why this matters**: Contracts have structure (headers, indentation, tables)

---

## Comparison for Your Use Case
| Method | Context Length | Structure Aware | Training Cost | Inference Speed | Best For |
|--------|---------------|-----------------|---------------|-----------------|----------|
| **BERT** (current) | 512 tokens | ❌ | Low | Fast | Clause-level |
| **Longformer** | 4,096 tokens | ❌ | Medium | Medium | Full documents |
| **Hierarchical BERT** | Unlimited* | ✅✅✅ | Medium | Medium | **RECOMMENDED** |
| **Graph NN** | Unlimited* | ✅✅ | High | Slow | Complex references |
| **LLM (GPT-4)** | 128k tokens | ✅ | Zero! | Slow | No training data |
| **LayoutLM** | 512 tokens | ✅ (visual) | High | Medium | Scanned PDFs |

*Processes the document in chunks with aggregation (see the sketch below)
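A sketch of that chunk-and-aggregate pattern, assuming a standard BERT `model`/`tokenizer` pair; mean-pooling the per-chunk [CLS] vectors is one simple aggregation choice:

```python
import torch

def embed_long_document(text, model, tokenizer, chunk_tokens=510):
    """Embed a document longer than 512 tokens by chunking + mean-pooling."""
    token_ids = tokenizer(text, add_special_tokens=False)['input_ids']
    chunk_embeddings = []
    for start in range(0, len(token_ids), chunk_tokens):
        chunk = token_ids[start:start + chunk_tokens]
        # Re-add [CLS]/[SEP] so each chunk is a well-formed input
        ids = torch.tensor([[tokenizer.cls_token_id] + chunk + [tokenizer.sep_token_id]])
        with torch.no_grad():
            out = model(input_ids=ids)
        chunk_embeddings.append(out.last_hidden_state[:, 0, :])
    return torch.cat(chunk_embeddings).mean(dim=0)  # mean-pool chunk [CLS] vectors
```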
---

## 🎯 RECOMMENDATION: Hierarchical BERT

### Why?
1. ✅ **Respects document structure** (clauses → sections → document)
2. ✅ **Handles any document length** (processes hierarchically)
3. ✅ **Better context modeling** than your current sliding window
4. ✅ **Interpretable** (attention weights show important sections)
5. ✅ **Moderate complexity** (not too hard to implement)
### Implementation Plan

```python
# hierarchical_bert.py
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class HierarchicalContractBERT(nn.Module):
    """
    Hierarchical document understanding for legal contracts

    Structure:
        Word → Clause → Section → Document
    """
    def __init__(self, model_name='nlpaueb/legal-bert-base-uncased', num_labels=3):
        super().__init__()
        # Clause encoder (BERT)
        self.bert = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Hierarchical aggregation
        self.clause_to_section = nn.LSTM(
            input_size=768,   # BERT hidden size
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            dropout=0.1,
            batch_first=True
        )
        self.section_to_document = nn.LSTM(
            input_size=512,   # bidirectional * 256
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            dropout=0.1,
            batch_first=True
        )
        # Attention mechanisms
        self.clause_attention = nn.Sequential(
            nn.Linear(512, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        self.section_attention = nn.Sequential(
            nn.Linear(512, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        # Task heads (same as your current model)
        self.severity_head = nn.Linear(512, 1)
        self.category_head = nn.Linear(512, num_labels)

    def encode_clause(self, clause_text):
        """Encode a single clause with BERT"""
        inputs = self.tokenizer(
            clause_text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=128  # shorter for clauses
        )
        outputs = self.bert(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token

    def aggregate_with_attention(self, hidden_states, attention_module):
        """Apply attention-based aggregation"""
        # hidden_states: [batch, seq_len, hidden_dim]
        # Compute attention weights
        attention_logits = attention_module(hidden_states)  # [batch, seq_len, 1]
        attention_weights = torch.softmax(attention_logits, dim=1)
        # Weighted sum
        context_vector = torch.sum(hidden_states * attention_weights, dim=1)
        return context_vector, attention_weights
    def forward(self, document_structure):
        """
        Args:
            document_structure: List of sections,
                each section a list of clause texts.
                Example: [
                    ['clause 1', 'clause 2'],             # Section 1
                    ['clause 3', 'clause 4', 'clause 5'], # Section 2
                ]
        Returns:
            document_embedding, clause_predictions, attention_weights
        """
        section_vectors = []
        all_clause_predictions = []
        attention_weights = {'clause': [], 'section': None}

        # Process each section
        for section_clauses in document_structure:
            # 1. Encode each clause
            clause_vectors = [self.encode_clause(c) for c in section_clauses]
            # Stack: [num_clauses, 768]
            clause_hidden = torch.stack(clause_vectors).squeeze(1)
            # 2. LSTM over clauses (captures sequential context)
            clause_lstm_out, _ = self.clause_to_section(clause_hidden.unsqueeze(0))
            # clause_lstm_out: [1, num_clauses, 512]
            # 3. Attention over clauses
            section_vec, clause_attn = self.aggregate_with_attention(
                clause_lstm_out,
                self.clause_attention
            )
            section_vectors.append(section_vec)
            attention_weights['clause'].append(clause_attn)
            # 4. Predict for each clause (using its context-aware representation)
            for i in range(len(section_clauses)):
                clause_repr = clause_lstm_out[0, i, :]  # context-aware!
                all_clause_predictions.append({
                    'severity': self.severity_head(clause_repr),
                    'category': self.category_head(clause_repr),
                    'text': section_clauses[i]
                })

        # 5. Stack sections along the sequence dim: [1, num_sections, 512]
        section_hidden = torch.stack(section_vectors, dim=1)
        # 6. LSTM over sections
        section_lstm_out, _ = self.section_to_document(section_hidden)
        # 7. Attention over sections
        document_vec, section_attn = self.aggregate_with_attention(
            section_lstm_out,
            self.section_attention
        )
        attention_weights['section'] = section_attn

        return {
            'document_embedding': document_vec,
            'clause_predictions': all_clause_predictions,
            'attention_weights': attention_weights
        }
    def predict_document(self, document_structure):
        """Convenience method for inference"""
        outputs = self.forward(document_structure)
        # Extract predictions
        predictions = []
        for clause_pred in outputs['clause_predictions']:
            predictions.append({
                'text': clause_pred['text'],
                'severity': torch.sigmoid(clause_pred['severity']).item() * 10,
                'category': torch.softmax(clause_pred['category'], dim=-1).tolist()
            })
        return {
            'clauses': predictions,
            'attention_weights': outputs['attention_weights']
        }
# Usage example
if __name__ == '__main__':
    model = HierarchicalContractBERT()

    # Parse the document into a hierarchical structure
    document = [
        # Section 1: Services
        [
            "Provider shall deliver software services as described in Exhibit A.",
            "Such Services shall be performed in a professional manner.",
            "Services include maintenance and support."
        ],
        # Section 2: Payment
        [
            "Client shall pay within 30 days of invoice.",
            "Late payments incur 1.5% monthly interest.",
        ],
        # Section 3: Termination
        [
            "Either party may terminate with 30 days written notice.",
            "Upon termination, all obligations under Section 2 remain in effect."
        ]
    ]

    # Predict
    results = model.predict_document(document)

    print("Clause Predictions:")
    for i, clause in enumerate(results['clauses']):
        print(f"\n{i+1}. {clause['text']}")
        print(f"   Severity: {clause['severity']:.2f}/10")
        print(f"   Category: {clause['category']}")

    print("\n\nAttention Weights:")
    print("Clause-level attention shows which clauses matter most in each section")
    print("Section-level attention shows which sections are most critical overall")
```
---

## Migration Path

### Phase 1: Quick Win (Current + Context)
```python
# Use your current model + sliding-window context
# Already implemented! ✅
analyze_full_document(contract, model, use_context=True, context_window=2)
```
### Phase 2: Upgrade to Longformer (Easy)
```python
# Just swap BERT for Longformer
# Can process 8x longer context (4,096 vs 512 tokens)
from transformers import LongformerModel

model = LongformerModel.from_pretrained('lexlms/legal-longformer-base')
```
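A minimal sketch of what the swap involves in practice, assuming a `full_contract` string. Longformer takes a `global_attention_mask`; putting global attention on the first token is the common default for classification:

```python
import torch
from transformers import AutoTokenizer, LongformerModel

tokenizer = AutoTokenizer.from_pretrained('lexlms/legal-longformer-base')
model = LongformerModel.from_pretrained('lexlms/legal-longformer-base')

inputs = tokenizer(full_contract, return_tensors='pt',
                   truncation=True, max_length=4096)
# Global attention on the first token; all others use local sliding windows
global_attention_mask = torch.zeros_like(inputs['input_ids'])
global_attention_mask[:, 0] = 1
outputs = model(**inputs, global_attention_mask=global_attention_mask)
```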
### Phase 3: Hierarchical BERT (Recommended)
```python
# Implement the hierarchical model (code above)
# Better document understanding + interpretability
model = HierarchicalContractBERT()
```
### Phase 4: Graph NN (Advanced)
```python
# Add graph connections for cross-references
# Build a clause dependency graph
model = DocumentGraphNN()
```
---

## Key Papers
1. **Longformer**: "Longformer: The Long-Document Transformer" (Beltagy et al., 2020)
2. **BigBird**: "Big Bird: Transformers for Longer Sequences" (Zaheer et al., 2020)
3. **HAN**: "Hierarchical Attention Networks for Document Classification" (Yang et al., 2016)
4. **LayoutLM**: "LayoutLMv3: Pre-training for Document AI" (Huang et al., 2022)
5. **Legal-BERT**: "LEGAL-BERT: The Muppets straight out of Law School" (Chalkidis et al., 2020)
6. **Document Understanding**: "A Survey on Document-level Neural Machine Translation" (Maruf et al., 2021)
---

## 💡 My Suggestion

**Short term** (Today):
- ✅ Keep using context-aware analysis (already done!)
- Test with `use_context=True, context_window=2`

**Medium term** (This week):
- Implement Hierarchical BERT (code provided above)
- Train on your CUAD dataset with section structure (see the training sketch below)
- Compare performance: BERT vs Hierarchical BERT
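A hedged training sketch: the `contracts` iterable and its `sections`/`labels` fields are illustrative assumptions, not CUAD's native schema, and would come from your own preprocessing:

```python
import torch
import torch.nn as nn

model = HierarchicalContractBERT()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
severity_loss = nn.MSELoss()
category_loss = nn.CrossEntropyLoss()

for epoch in range(3):
    for contract in contracts:  # assumed: {'sections': [...], 'labels': {...}}
        outputs = model(contract['sections'])
        preds = outputs['clause_predictions']
        sev_pred = torch.cat([p['severity'] for p in preds])        # [num_clauses]
        cat_pred = torch.stack([p['category'] for p in preds])      # [num_clauses, num_labels]
        labels = contract['labels']  # assumed: severity floats (0-10), category ids
        # Scale sigmoid output to 0-10 to match predict_document()
        loss = severity_loss(torch.sigmoid(sev_pred) * 10, labels['severity']) \
             + category_loss(cat_pred, labels['category'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```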
**Long term** (If needed):
- Consider Longformer for very long contracts (>2,000 words)
- Experiment with Graph NN if there are many cross-references
- Try GPT-4 for zero-shot (if budget allows)
---

## 🎯 Bottom Line

**You're correct**: This is document understanding, not just text classification!

**Current approach**: Clause-by-clause is limiting
**SOTA approach**: Hierarchical models that understand document structure
**Best ROI**: Implement Hierarchical BERT (code above)
- Moderate complexity
- Big performance gain
- Interpretable (attention weights)
- Handles full documents

Would you like me to integrate the Hierarchical BERT into your pipeline?