# State-of-the-Art Document Understanding Methods

## 🎯 You're Right - This IS Document Understanding!

Contract analysis = **Document Understanding** + Legal Domain Knowledge

---

## πŸ† Current SOTA Methods (2024-2025)

### **1. Long-Context Transformers** ⭐ BEST FOR YOUR USE CASE

#### **Longformer** (Allen AI, 2020)
- **Max Length**: 4,096+ tokens (vs BERT's 512)
- **Innovation**: Sliding window + global attention
- **Use Case**: Full contract documents without chunking

```python
import torch
from transformers import LongformerModel, LongformerTokenizer

model = LongformerModel.from_pretrained('allenai/longformer-base-4096')
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

# Can process an entire contract at once!
inputs = tokenizer(full_contract, return_tensors='pt', max_length=4096, truncation=True)

# Give the [CLS] token global attention so it can see (and be seen by) every token
global_attention_mask = torch.zeros_like(inputs['input_ids'])
global_attention_mask[:, 0] = 1

outputs = model(**inputs, global_attention_mask=global_attention_mask)
```

**Legal Variant**: [**Longformer-Legal**](https://huggingface.co/lexlms/legal-longformer-base)

```python
# Pre-trained on legal documents!
model = LongformerModel.from_pretrained('lexlms/legal-longformer-base')
```

---

#### **LED (Longformer Encoder-Decoder)** (Allen AI, 2020)
- **Max Length**: 16,384 tokens
- **Use Case**: Document summarization + Q&A

```python
from transformers import LEDForConditionalGeneration, LEDTokenizer

model = LEDForConditionalGeneration.from_pretrained('allenai/led-large-16384')
tokenizer = LEDTokenizer.from_pretrained('allenai/led-large-16384')

# Summarize an entire contract
contract_tokens = tokenizer(full_contract, return_tensors='pt',
                            max_length=16384, truncation=True)
summary_ids = model.generate(contract_tokens['input_ids'], max_length=512)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
```

---

#### **BigBird** (Google, 2020)
- **Max Length**: 4,096 tokens
- **Innovation**: Sparse attention (random + global + sliding window)
- **Efficiency**: O(n) instead of O(nΒ²)

```python
from transformers import BigBirdModel

# Block-sparse attention is the default; it falls back to full attention for short inputs
model = BigBirdModel.from_pretrained('google/bigbird-roberta-base')
```

---

### **2. Hierarchical Document Models** πŸ”₯ RECOMMENDED

#### **Hierarchical Attention Networks (HAN)** (Yang et al., 2016)
- **Structure**: Word β†’ Sentence β†’ Document
- **Perfect for**: Legal contracts with clause hierarchy

```
Document
β”œβ”€β”€ Section 1: SERVICES
β”‚   β”œβ”€β”€ Sentence 1 (attention weights)
β”‚   β”œβ”€β”€ Sentence 2 (attention weights)
β”‚   └── Sentence 3 (attention weights)
β”œβ”€β”€ Section 2: PAYMENT
β”‚   └── ...
└── Section 3: TERMINATION
```

**Implementation**:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HierarchicalContractModel(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, num_classes):
        super().__init__()
        # Word-level encoder
        self.word_encoder = nn.GRU(embedding_dim, hidden_dim, bidirectional=True)
        self.word_attention = nn.Linear(hidden_dim * 2, 1)

        # Sentence-level encoder
        self.sentence_encoder = nn.GRU(hidden_dim * 2, hidden_dim, bidirectional=True)
        self.sentence_attention = nn.Linear(hidden_dim * 2, 1)

        # Document-level classifier
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, document):
        # document: list of sentences, each a [num_words, embedding_dim] tensor

        # 1. Encode words in each sentence
        sentence_vectors = []
        for sentence in document:
            word_hidden, _ = self.word_encoder(sentence)
            word_attn = F.softmax(self.word_attention(word_hidden), dim=0)
            sentence_vec = (word_hidden * word_attn).sum(dim=0)
            sentence_vectors.append(sentence_vec)

        # 2. Encode sentences in document
        doc_hidden, _ = self.sentence_encoder(torch.stack(sentence_vectors))
        sent_attn = F.softmax(self.sentence_attention(doc_hidden), dim=0)
        doc_vec = (doc_hidden * sent_attn).sum(dim=0)

        # 3. Classify
        return self.classifier(doc_vec)
```
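A quick smoke test of the sketch above, using random tensors in place of pre-trained word embeddings (all dimensions are illustrative; unbatched RNN inputs require PyTorch >= 1.11):

```python
model = HierarchicalContractModel(embedding_dim=300, hidden_dim=128, num_classes=3)

# Three sentences of 12, 8, and 20 "words" respectively
document = [torch.randn(12, 300), torch.randn(8, 300), torch.randn(20, 300)]

logits = model(document)
print(logits.shape)  # torch.Size([3]) - one score per class
```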
---

#### **BERT-HAN** (Modern variant)

Combine BERT with hierarchical structure:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

class BERTHierarchical(nn.Module):
    def __init__(self):
        super().__init__()
        # Use BERT for clause encoding
        self.bert = AutoModel.from_pretrained('nlpaueb/legal-bert-base-uncased')
        self.tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')

        # Hierarchical aggregation
        self.clause_encoder = nn.LSTM(768, 256, bidirectional=True)
        self.section_encoder = nn.LSTM(512, 256, bidirectional=True)

        # Attention mechanisms
        self.clause_attention = nn.Linear(512, 1)
        self.section_attention = nn.Linear(512, 1)

    def forward(self, sections):
        # sections: List[List[clause_text]]
        section_vectors = []

        for section in sections:
            # Encode each clause with BERT ([CLS] token embedding)
            clause_embeddings = []
            for clause in section:
                inputs = self.tokenizer(clause, return_tensors='pt', truncation=True)
                bert_output = self.bert(**inputs)
                clause_embeddings.append(bert_output.last_hidden_state[0, 0, :])

            # Aggregate clauses -> section
            clause_hidden, _ = self.clause_encoder(torch.stack(clause_embeddings))
            clause_attn = F.softmax(self.clause_attention(clause_hidden), dim=0)
            section_vec = (clause_hidden * clause_attn).sum(dim=0)
            section_vectors.append(section_vec)

        # Aggregate sections -> document
        section_hidden, _ = self.section_encoder(torch.stack(section_vectors))
        section_attn = F.softmax(self.section_attention(section_hidden), dim=0)
        document_vec = (section_hidden * section_attn).sum(dim=0)

        return document_vec
```

---

### **3. Document Graph Neural Networks** 🌐

#### **Graph Transformer** (Microsoft, 2022)

Model the document as a graph: clauses = nodes, references = edges.

```
[Clause 1: "Services in Exhibit A"] ──references──> [Exhibit A]
        β”‚
     mentions
        ↓
[Clause 2: "Such Services..."] ──references──> [Section 5]
```

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
from transformers import AutoModel, AutoTokenizer

class DocumentGraphNN(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        # Node encoder (BERT for each clause)
        self.node_encoder = AutoModel.from_pretrained('nlpaueb/legal-bert-base-uncased')
        self.tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')

        # Graph convolution layers
        self.conv1 = pyg.nn.GCNConv(768, 256)
        self.conv2 = pyg.nn.GCNConv(256, 128)

        # Classifier
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, clauses, edge_index):
        # 1. Encode nodes (clauses) via the [CLS] embedding
        node_features = []
        for clause in clauses:
            inputs = self.tokenizer(clause, return_tensors='pt', truncation=True)
            bert_out = self.node_encoder(**inputs)
            node_features.append(bert_out.last_hidden_state[0, 0, :])
        x = torch.stack(node_features)  # [num_clauses, 768]

        # 2. Graph convolution (propagate context along edges)
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)

        # 3. Classify each clause node
        return self.classifier(x)
```

**Edge Types** (a construction sketch follows below):
- Sequential (Clause N β†’ Clause N+1)
- Reference ("Section 5" β†’ the actual Section 5)
- Semantic similarity (cosine > threshold)
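A hedged sketch of how the `edge_index` consumed above might be built; the `Section N` regex, the index arithmetic, and the 0.8 similarity threshold are illustrative assumptions, not part of any published method:

```python
import re
import torch
import torch.nn.functional as F

def build_edge_index(clauses, clause_embeddings, sim_threshold=0.8):
    """Build a [2, num_edges] edge_index tensor for torch_geometric.

    clause_embeddings: [num_clauses, dim] tensor (e.g., BERT [CLS] vectors).
    """
    edges = []
    n = len(clauses)

    # 1. Sequential edges: clause i <-> clause i+1
    for i in range(n - 1):
        edges += [(i, i + 1), (i + 1, i)]

    # 2. Reference edges: naive "Section N" pattern (a real parser would
    #    resolve the number against actual section headings)
    for i, clause in enumerate(clauses):
        for match in re.findall(r'Section (\d+)', clause):
            target = int(match) - 1  # assumes clause index == section number - 1
            if 0 <= target < n and target != i:
                edges.append((i, target))

    # 3. Semantic-similarity edges: cosine similarity above a threshold
    normed = F.normalize(clause_embeddings, dim=1)
    sim = normed @ normed.T
    for i in range(n):
        for j in range(i + 1, n):
            if sim[i, j] > sim_threshold:
                edges += [(i, j), (j, i)]

    return torch.tensor(edges, dtype=torch.long).T  # [2, num_edges]
```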
---

### **4. Retrieval-Augmented Models** πŸ”

#### **RAG (Retrieval-Augmented Generation)** (Facebook, 2020)

Retrieve relevant clauses before classification:

```python
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

# The stock retriever indexes Wikipedia passages; for contracts you would
# build your own index over clause passages instead
tokenizer = RagTokenizer.from_pretrained('facebook/rag-token-nq')
retriever = RagRetriever.from_pretrained(
    'facebook/rag-token-nq', index_name='exact', use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained('facebook/rag-token-nq', retriever=retriever)

def predict_with_retrieval(clause):
    # Retrieval of the top-k similar passages happens inside generate()
    inputs = tokenizer(clause, return_tensors='pt')
    output_ids = model.generate(input_ids=inputs['input_ids'])
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
```

---

### **5. Large Language Models (2023-2025)** πŸ€–

#### **GPT-4 / Claude** (OpenAI / Anthropic)
- **Context**: 128k tokens (GPT-4 Turbo), 200k (Claude 3)
- **Approach**: Few-shot learning + prompting

```python
import json
from openai import OpenAI

client = OpenAI()

def analyze_contract_llm(contract):
    prompt = f"""
    Analyze this contract for risk clauses. For each clause, provide:
    1. Risk severity (0-10)
    2. Risk category
    3. Explanation

    Contract: {contract}

    Format as JSON.
    """

    response = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,  # Low for consistency
        response_format={"type": "json_object"}  # force valid JSON output
    )

    return json.loads(response.choices[0].message.content)
```

**Pros**: No training needed, strong out-of-the-box understanding
**Cons**: Expensive, slower, privacy concerns

---

#### **Legal-Specific LLMs**
- **LegalBench** (Stanford, 2023) - strictly a benchmark for evaluating LLMs on legal reasoning, not a model
- **ChatLaw** (Peking University, 2023)
- **LawGPT** (2024)

---

### **6. Multi-Modal Document Understanding** πŸ“„

#### **LayoutLM** (Microsoft, 2020-2023)

Understands document **layout** + text.

```python
from transformers import LayoutLMv3Model

# Processes:
# 1. Text content
# 2. Bounding boxes (where each token appears on the page)
# 3. Page images (if the PDF has visual elements)
model = LayoutLMv3Model.from_pretrained('microsoft/layoutlmv3-base')

# Input includes position information
inputs = {
    'input_ids': text_tokens,
    'bbox': bounding_boxes,      # [x0, y0, x1, y1] per token, normalized to 0-1000
    'pixel_values': document_image
}
outputs = model(**inputs)
```

**Why this matters**: Contracts have visual structure (headers, indentation, tables).
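In practice the `input_ids`, `bbox`, and `pixel_values` above come from `LayoutLMv3Processor`, which runs OCR on the page image (requires `pytesseract` by default). A minimal sketch, assuming a scanned contract page saved as `page.png` (a hypothetical file):

```python
from PIL import Image
from transformers import LayoutLMv3Processor, LayoutLMv3Model

processor = LayoutLMv3Processor.from_pretrained('microsoft/layoutlmv3-base')  # apply_ocr=True by default
model = LayoutLMv3Model.from_pretrained('microsoft/layoutlmv3-base')

image = Image.open('page.png').convert('RGB')
inputs = processor(image, return_tensors='pt')  # yields input_ids, bbox, pixel_values

outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # [1, seq_len, 768]
```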
---

## πŸ“Š Comparison for Your Use Case

| Method | Context Length | Structure Aware | Training Cost | Inference Speed | Best For |
|--------|---------------|-----------------|---------------|-----------------|----------|
| **BERT** (current) | 512 tokens | ❌ | Low | Fast | Clause-level |
| **Longformer** | 4,096 tokens | ❌ | Medium | Medium | Full documents |
| **Hierarchical BERT** | Unlimited* | βœ…βœ…βœ… | Medium | Medium | **RECOMMENDED** |
| **Graph NN** | Unlimited* | βœ…βœ… | High | Slow | Complex references |
| **LLM (GPT-4)** | 128k tokens | βœ… | Zero! | Slow | No training data |
| **LayoutLM** | 512 tokens | βœ… (visual) | High | Medium | Scanned PDFs |

*Processes the document in chunks with hierarchical aggregation

---

## 🎯 RECOMMENDATION: Hierarchical BERT

### Why?
1. βœ… **Respects document structure** (clauses β†’ sections β†’ document)
2. βœ… **Handles any document length** (processes hierarchically)
3. βœ… **Better context modeling** than your current sliding window
4. βœ… **Interpretable** (attention weights show important sections)
5. βœ… **Moderate complexity** (not too hard to implement)

### Implementation Plan

```python
# hierarchical_bert.py
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class HierarchicalContractBERT(nn.Module):
    """
    Hierarchical document understanding for legal contracts
    Structure: Word β†’ Clause β†’ Section β†’ Document
    """

    def __init__(self, model_name='nlpaueb/legal-bert-base-uncased', num_labels=3):
        super().__init__()

        # Clause encoder (BERT)
        self.bert = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Hierarchical aggregation
        self.clause_to_section = nn.LSTM(
            input_size=768,      # BERT hidden size
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            dropout=0.1,
            batch_first=True
        )

        self.section_to_document = nn.LSTM(
            input_size=512,      # bidirectional * 256
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            dropout=0.1,
            batch_first=True
        )

        # Attention mechanisms
        self.clause_attention = nn.Sequential(
            nn.Linear(512, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        self.section_attention = nn.Sequential(
            nn.Linear(512, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        # Task heads (same as your current model)
        self.severity_head = nn.Linear(512, 1)
        self.category_head = nn.Linear(512, num_labels)

    def encode_clause(self, clause_text):
        """Encode a single clause with BERT"""
        inputs = self.tokenizer(
            clause_text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=128  # Shorter for clauses
        )
        outputs = self.bert(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token

    def aggregate_with_attention(self, hidden_states, attention_module):
        """Apply attention-based aggregation"""
        # hidden_states: [batch, seq_len, hidden_dim]

        # Compute attention weights
        attention_logits = attention_module(hidden_states)        # [batch, seq_len, 1]
        attention_weights = torch.softmax(attention_logits, dim=1)

        # Weighted sum
        context_vector = torch.sum(hidden_states * attention_weights, dim=1)

        return context_vector, attention_weights

    def forward(self, document_structure):
        """
        Args:
            document_structure: List of sections,
                each section a list of clause texts.
                Example:
                [
                    ['clause 1', 'clause 2'],              # Section 1
                    ['clause 3', 'clause 4', 'clause 5'],  # Section 2
                ]

        Returns:
            dict with document_embedding, clause_predictions, attention_weights
        """
        section_vectors = []
        all_clause_predictions = []
        attention_weights = {'clause': [], 'section': None}

        # Process each section
        for section_clauses in document_structure:
            clause_vectors = []

            # 1. Encode each clause
            for clause in section_clauses:
                clause_vec = self.encode_clause(clause)
                clause_vectors.append(clause_vec)

            # Stack: [num_clauses, 768]
            clause_hidden = torch.stack(clause_vectors).squeeze(1)

            # 2. LSTM over clauses (captures sequential context)
            clause_lstm_out, _ = self.clause_to_section(clause_hidden.unsqueeze(0))
            # clause_lstm_out: [1, num_clauses, 512]

            # 3. Attention over clauses
            section_vec, clause_attn = self.aggregate_with_attention(
                clause_lstm_out, self.clause_attention
            )

            section_vectors.append(section_vec)
            attention_weights['clause'].append(clause_attn)

            # 4. Predict for each clause (using the context-aware representation)
            for i in range(len(section_clauses)):
                clause_repr = clause_lstm_out[0, i, :]  # Context-aware!
                severity = self.severity_head(clause_repr)
                category = self.category_head(clause_repr)
                all_clause_predictions.append({
                    'severity': severity,
                    'category': category,
                    'text': section_clauses[i]
                })

        # 5. Concatenate sections: [num_sections, 512] (each section_vec is [1, 512])
        section_hidden = torch.cat(section_vectors, dim=0)

        # 6. LSTM over sections
        section_lstm_out, _ = self.section_to_document(section_hidden.unsqueeze(0))

        # 7. Attention over sections
        document_vec, section_attn = self.aggregate_with_attention(
            section_lstm_out, self.section_attention
        )
        attention_weights['section'] = section_attn

        return {
            'document_embedding': document_vec,
            'clause_predictions': all_clause_predictions,
            'attention_weights': attention_weights
        }

    def predict_document(self, document_structure):
        """Convenience method for inference"""
        self.eval()
        with torch.no_grad():
            outputs = self.forward(document_structure)

        # Extract predictions
        predictions = []
        for clause_pred in outputs['clause_predictions']:
            predictions.append({
                'text': clause_pred['text'],
                'severity': torch.sigmoid(clause_pred['severity']).item() * 10,
                'category': torch.softmax(clause_pred['category'], dim=-1).tolist()
            })

        return {
            'clauses': predictions,
            'attention_weights': outputs['attention_weights']
        }


# Usage example
if __name__ == '__main__':
    model = HierarchicalContractBERT()

    # Parse document into hierarchical structure
    document = [
        # Section 1: Services
        [
            "Provider shall deliver software services as described in Exhibit A.",
            "Such Services shall be performed in a professional manner.",
            "Services include maintenance and support."
        ],
        # Section 2: Payment
        [
            "Client shall pay within 30 days of invoice.",
            "Late payments incur 1.5% monthly interest.",
        ],
        # Section 3: Termination
        [
            "Either party may terminate with 30 days written notice.",
            "Upon termination, all obligations under Section 2 remain in effect."
        ]
    ]

    # Predict
    results = model.predict_document(document)

    print("Clause Predictions:")
    for i, clause in enumerate(results['clauses']):
        print(f"\n{i+1}. {clause['text']}")
        print(f"   Severity: {clause['severity']:.2f}/10")
        print(f"   Category: {clause['category']}")

    print("\n\nAttention Weights:")
    print("Clause-level attention shows which clauses are most important in each section")
    print("Section-level attention shows which sections are most critical overall")
```
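For training on CUAD, a minimal single-document training step might look like the sketch below. The target encoding (severity normalized to [0, 1], integer category ids) and the unweighted sum of the two losses are assumptions to adapt to your label format:

```python
import torch
import torch.nn as nn

def train_step(model, optimizer, document_structure, severity_targets, category_targets):
    """One gradient step. Targets are per-clause, in document order:
    severity_targets: [num_clauses] floats in [0, 1]
    category_targets: [num_clauses] class indices
    """
    model.train()
    outputs = model(document_structure)

    severities = torch.cat([p['severity'] for p in outputs['clause_predictions']])
    categories = torch.stack([p['category'] for p in outputs['clause_predictions']])

    loss = (
        nn.functional.mse_loss(torch.sigmoid(severities), severity_targets)
        + nn.functional.cross_entropy(categories, category_targets)
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

`torch.optim.AdamW(model.parameters(), lr=2e-5)` is a reasonable starting optimizer for BERT-based models.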
---

## πŸš€ Migration Path

### Phase 1: Quick Win (Current + Context)
```python
# Use your current model + sliding-window context
# Already implemented! βœ…
analyze_full_document(contract, model, use_context=True, context_window=2)
```

### Phase 2: Upgrade to Longformer (Easy)
```python
# Just swap BERT for Longformer
# Can process 8x longer context
from transformers import LongformerModel
model = LongformerModel.from_pretrained('lexlms/legal-longformer-base')
```

### Phase 3: Hierarchical BERT (Recommended)
```python
# Implement the hierarchical model (code above)
# Better document understanding + interpretability
model = HierarchicalContractBERT()
```

### Phase 4: Graph NN (Advanced)
```python
# Add graph connections for cross-references
# Build a clause dependency graph
model = DocumentGraphNN()
```
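Phases 3 and 4 both assume the contract text is already segmented into the nested `document_structure` format (sections of clauses). A minimal parsing sketch, assuming numbered ALL-CAPS headings like `1. SERVICES`; a production pipeline would use a proper sentence splitter and layout cues:

```python
import re

def parse_contract(text):
    """Split raw contract text into [[clause, ...], ...] per section.

    Assumes headings look like '1. SERVICES' on their own line; clauses
    are approximated by sentence-ending punctuation.
    """
    sections = []
    current = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if re.match(r'^\d+\.\s+[A-Z][A-Z ]+$', line):  # new section heading
            if current:
                sections.append(current)
            current = []
        else:
            # Naive clause split on sentence boundaries
            current.extend(s.strip() for s in re.split(r'(?<=[.;])\s+', line) if s.strip())
    if current:
        sections.append(current)
    return sections
```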
---

## πŸ“š Key Papers

1. **Longformer**: "Longformer: The Long-Document Transformer" (Beltagy et al., 2020)
2. **BigBird**: "Big Bird: Transformers for Longer Sequences" (Zaheer et al., 2020)
3. **HAN**: "Hierarchical Attention Networks for Document Classification" (Yang et al., 2016)
4. **LayoutLM**: "LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking" (Huang et al., 2022)
5. **Legal-BERT**: "LEGAL-BERT: The Muppets straight out of Law School" (Chalkidis et al., 2020)
6. **Document Understanding**: "A Survey on Document-level Neural Machine Translation" (2023)

---

## πŸ’‘ My Suggestion

**Short term** (Today):
- βœ… Keep using context-aware analysis (already done!)
- Test with `use_context=True, context_window=2`

**Medium term** (This week):
- πŸ”„ Implement Hierarchical BERT (code provided above)
- Train on your CUAD dataset with section structure
- Compare performance: BERT vs Hierarchical BERT

**Long term** (If needed):
- Consider Longformer for very long contracts (>2,000 words)
- Experiment with the Graph NN if contracts contain many cross-references
- Try GPT-4 for zero-shot analysis (if budget allows)

---

## 🎯 Bottom Line

**You're correct**: This is document understanding, not just text classification!

**Current approach**: Clause-by-clause analysis is limiting
**SOTA approach**: Hierarchical models that understand document structure

**Best ROI**: Implement Hierarchical BERT (code above)
- Moderate complexity
- Big performance gain
- Interpretable (attention weights)
- Handles full documents

Would you like me to integrate the Hierarchical BERT into your pipeline? πŸš€