""" Legal Document Processing Tool for Smolagents This tool processes legal documents with specialized models for legal text, optimizing for citation retention, multilingual support, and performance on legal-specific retrieval tasks. Author: Dr. Zhou Wang """ from typing import Dict, List, Any, Optional, Union import os import re import time import tempfile import numpy as np from tqdm import tqdm # Import Smolagents Tool class from smolagents import Tool # Import NLP components try: from sklearn.metrics.pairwise import cosine_similarity from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, Document from llama_index.core.node_parser import MarkdownNodeParser from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.core.ingestion import IngestionPipeline from langchain.text_splitter import RecursiveCharacterTextSplitter except ImportError: raise ImportError( "Required dependencies not found. Please install with: " "pip install llama-index langchain scikit-learn tqdm" ) # Model configurations based on research findings LEGAL_MODELS = { "legal-bert": { "name": "nlp-jurisprudence/legal-bert-base-uncased", "description": "Trained on ECtHR legal documents, specialized in human rights law", "max_length": 512, "requires_gpu": True, }, "multi-qa-mpnet": { "name": "sentence-transformers/multi-qa-mpnet-base-dot-v1", "description": "Optimized for legal Q&A retrieval with cross-lingual support", "max_length": 512, "requires_gpu": False, }, "legal-xlm-roberta": { "name": "joelito/legal-xlm-roberta-base", "description": "Multilingual legal model with EU legislation and RFC/ISO pattern awareness", "max_length": 512, "requires_gpu": True, }, "multilingual-e5": { "name": "intfloat/multilingual-e5-base", "description": "Dense retrieval optimized with citation context preservation", "max_length": 512, "requires_gpu": True, }, "all-mpnet": { "name": "sentence-transformers/all-mpnet-base-v2", "description": "General purpose embedding model, good baseline for legal text", "max_length": 512, "requires_gpu": False, }, } class LegalDocumentProcessor: """ Processor for legal documents with specialized models, citation preservation, and benchmarking capabilities. """ def __init__( self, model_key: str = "legal-xlm-roberta", use_gpu: bool = False, chunk_size: int = 512, chunk_overlap: int = 100, ): """ Initialize the legal document processor. Args: model_key: Key for the model to use from LEGAL_MODELS dictionary use_gpu: Whether to use GPU for embeddings (if available) chunk_size: Size of text chunks for processing chunk_overlap: Overlap between chunks to preserve context """ # Validate and set up model if model_key not in LEGAL_MODELS: print( f"Warning: Model '{model_key}' not found. Using legal-xlm-roberta as default." ) model_key = "legal-xlm-roberta" model_config = LEGAL_MODELS[model_key] device = "cuda" if use_gpu and model_config["requires_gpu"] else "cpu" # Initialize embedding model self.embed_model = HuggingFaceEmbedding( model_name=model_config["name"], device=device, tokenizer_kwargs={ "trust_remote_code": True, "max_length": model_config["max_length"], "truncation": True, }, ) # Store model information for reference self.model_info = model_config self.model_key = model_key # Legal document-optimized text splitter with improved chunk size self.splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=[ "\n## ", "\n### ", "\n#### ", # Headers "\n\n", "\n", # Paragraphs ". ", "! ", "? 
", # Sentences ";", ":", # Clause boundaries " ", # Last resort ], ) # Pattern for removing footers from legal documents # Separated into individual patterns for better maintainability self.footer_patterns = [ r"^Page\s\d+(\s+of\s+\d+)?$", # Page numbers r"^©.*\b(Company|Inc|Ltd)\b.*$", # Copyright lines r"^All rights reserved.*?$", # Legal boilerplate r"^-+$", # Separator lines r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}(:\d{2})?$", # Timestamps r"(?i)^(confidential|proprietary|internal use only)", # Security tags ] # Join all patterns with the OR operator combined_pattern = "|".join(f"({pattern})" for pattern in self.footer_patterns) # Compile the combined pattern self.footer_pattern = re.compile( combined_pattern, flags=re.MULTILINE | re.IGNORECASE ) def remove_footers(self, text: str) -> str: """ Remove common document footer patterns from text. Args: text: The input text to process Returns: Text with footer patterns removed """ return self.footer_pattern.sub("", text) def clean_text(self, text: str) -> str: """ Preserve legal citations while cleaning artifacts. Args: text: The input text to clean Returns: Cleaned text with citations preserved """ # First remove footers text = self.remove_footers(text) # Preserve citation patterns # Pattern 1: Footnote numbers (e.g., 98, 99, 100) cleaned = re.sub(r"(?<=\D)(\d{2,3})(?=\D)", r"[\1]", text) # Pattern 2: Case citations [2019] UKSC 20 # Already well-structured, so no changes needed # Pattern 3: Standardize quotation marks cleaned = cleaned.replace("''", '"').replace("``", '"') # Pattern 4: Handle section references (§3.1, §123) cleaned = re.sub(r"§(\d+(\.\d+)?)", r"Section \1", cleaned) # Pattern 5: Handle legal abbreviations (e.g., Art. -> Article) cleaned = re.sub(r"\bArt\.\s+(\d+)", r"Article \1", cleaned) # Pattern 6: Standardize case names with v. and vs. cleaned = re.sub(r"\bv\s+", r"v. ", cleaned) cleaned = re.sub(r"\bvs\s+", r"v. ", cleaned) # Pattern 7: RFC/ISO pattern standardization (RFC 1234, ISO 9001) cleaned = re.sub(r"\b(RFC|ISO)\s*[:#]?\s*(\d+)", r"\1 \2", cleaned) return cleaned def create_pipeline(self) -> IngestionPipeline: """ Create a document processing pipeline. Returns: Configured IngestionPipeline object """ return IngestionPipeline( transformations=[ self.clean_text, MarkdownNodeParser(), self.splitter, self.embed_model, ] ) def validate_citation_retention( self, documents: List[Document] ) -> Dict[str, float]: """ Measure semantic similarity of citations before/after text cleaning. 

    def validate_citation_retention(
        self, documents: List[Document]
    ) -> Dict[str, float]:
        """
        Measure semantic similarity of citations before/after text cleaning.

        Args:
            documents: List of Document objects to validate

        Returns:
            Dictionary with validation metrics
        """
        if not documents:
            return {"citation_retention": 0.0, "processing_time": 0.0}

        start_time = time.time()

        # Extract original texts (sample for performance)
        original_texts = [doc.text for doc in documents[:5]]

        # Apply cleaning
        processed_texts = [self.clean_text(text) for text in original_texts]

        # Calculate embeddings
        try:
            # Use the public batch-embedding API rather than reaching into
            # the embedding model's internals
            orig_embeds = np.array(
                self.embed_model.get_text_embedding_batch(original_texts)
            )
            proc_embeds = np.array(
                self.embed_model.get_text_embedding_batch(processed_texts)
            )

            # Similarity between each original text and its cleaned version
            similarities = cosine_similarity(orig_embeds, proc_embeds).diagonal()
            avg_similarity = float(np.mean(similarities))

            processing_time = time.time() - start_time

            return {
                "citation_retention": avg_similarity * 100,  # As percentage
                "processing_time": processing_time,
                "sample_size": len(original_texts),
            }
        except Exception as e:
            return {"citation_retention": 0.0, "processing_time": 0.0, "error": str(e)}

    def process_documents(self, documents: List[Document]) -> Dict[str, Any]:
        """
        Process a list of legal documents.

        Args:
            documents: List of Document objects to process

        Returns:
            Dictionary with processing results and stats
        """
        if not documents:
            return {"status": "error", "message": "No documents provided"}

        try:
            # Clean text up front, then run the ingestion pipeline
            cleaned_docs = [
                Document(text=self.clean_text(doc.text), metadata=doc.metadata)
                for doc in documents
            ]
            pipeline = self.create_pipeline()
            nodes = pipeline.run(documents=cleaned_docs)

            # Create vector index with the same embedding model used for ingestion
            index = VectorStoreIndex(nodes, embed_model=self.embed_model)
            query_engine = index.as_query_engine()

            # Return success with stats
            return {
                "status": "success",
                "nodes_count": len(nodes),
                "documents_count": len(documents),
                "model_used": self.model_key,
                "query_engine": query_engine,  # This will be used for querying
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}
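

# Minimal standalone sketch of the processor, separate from the Smolagents
# tool below. Assumptions: network access to download the embedding model
# (instantiation loads it even though clean_text itself is pure regex), and
# an invented sample passage; "all-mpnet" is chosen only because it runs on CPU.
def _demo_clean_text() -> None:
    processor = LegalDocumentProcessor(model_key="all-mpnet")
    sample = "Smith v Jones [2019] UKSC 20; see Art. 6, §3.1 and RFC:2119."
    # Expands Art./§ references, normalizes "v" and RFC numbering, and
    # brackets bare footnote-style numbers
    print(processor.clean_text(sample))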


class LegalDocumentTool(Tool):
    """
    Tool for processing legal documents with specialized models and querying
    capabilities.
    """

    name = "legal_document_processor"
    description = (
        "Processes legal documents with specialized models for legal text, optimizing for "
        "citation retention, multilingual support, and performance on legal-specific retrieval tasks. "
        "Can process text or file inputs and provide enhanced query capabilities."
    )
    inputs = {
        "text": {
            "type": "string",
            "description": "Legal document text to process. Provide either text or file_paths.",
            "nullable": True,
        },
        "file_paths": {
            "type": "string",
            "description": "Comma-separated list of file paths or a directory path containing legal documents. Provide either text or file_paths.",
            "nullable": True,
        },
        "model_key": {
            "type": "string",
            "description": "Legal embedding model to use. Options: legal-bert, multi-qa-mpnet, legal-xlm-roberta, multilingual-e5, all-mpnet. Defaults to legal-xlm-roberta.",
            "nullable": True,
        },
        "query": {
            "type": "string",
            "description": "Optional query to run against the processed documents.",
            "nullable": True,
        },
        "validate_citations": {
            "type": "boolean",
            "description": "Whether to validate citation retention in the processed documents. Defaults to False.",
            "nullable": True,
        },
        "use_gpu": {
            "type": "boolean",
            "description": "Whether to use GPU for embedding calculations if available. Defaults to False.",
            "nullable": True,
        },
    }
    output_type = "string"

    def _load_documents(self, input_path: str) -> List[Document]:
        """
        Load documents from a file path or directory.

        Args:
            input_path: Path to a file or directory

        Returns:
            List of Document objects
        """
        if os.path.isfile(input_path):
            # Load exactly this file (input_files avoids picking up other
            # files with the same extension in the directory)
            return SimpleDirectoryReader(
                input_files=[input_path],
                filename_as_id=True,
            ).load_data()
        elif os.path.isdir(input_path):
            return SimpleDirectoryReader(
                input_dir=input_path,
                filename_as_id=True,
            ).load_data()
        else:
            raise ValueError(f"Path not found: {input_path}")

    def _create_document_from_text(self, text: str) -> List[Document]:
        """
        Create a Document object from text.

        Args:
            text: Text content

        Returns:
            List containing a single Document object
        """
        # Write the text to a temporary Markdown file so it goes through the
        # same reader path as file input
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".md", delete=False
        ) as temp_file:
            temp_file.write(text)
            temp_path = temp_file.name

        try:
            # Load the document from the temporary file
            return self._load_documents(temp_path)
        finally:
            # Clean up the temporary file
            os.remove(temp_path)
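
    # Illustrative (comment-only) call with invented text; see forward() below
    # for the full argument list:
    #
    #     tool = LegalDocumentTool()
    #     tool.forward(text="Smith v Jones [2019] UKSC 20 ...", model_key="all-mpnet")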

    def forward(
        self,
        text: Optional[str] = None,
        file_paths: Optional[str] = None,
        model_key: str = "legal-xlm-roberta",
        query: Optional[str] = None,
        validate_citations: bool = False,
        use_gpu: bool = False,
    ) -> str:
        """
        Process legal documents and optionally run a query.

        Args:
            text: Legal document text to process
            file_paths: Comma-separated list of file paths or a directory path
            model_key: Legal embedding model to use
            query: Optional query to run against the processed documents
            validate_citations: Whether to validate citation retention
            use_gpu: Whether to use GPU for embeddings

        Returns:
            Processing results or query response as a string
        """
        # Validate inputs
        if not text and not file_paths:
            return "Error: Either text or file_paths must be provided."

        try:
            # Initialize processor
            processor = LegalDocumentProcessor(
                model_key=model_key,
                use_gpu=use_gpu,
            )

            # Load documents
            documents = []
            if text:
                documents.extend(self._create_document_from_text(text))
            if file_paths:
                # Handle comma-separated paths
                paths = [path.strip() for path in file_paths.split(",")]
                for path in paths:
                    try:
                        docs = self._load_documents(path)
                        documents.extend(docs)
                    except Exception as e:
                        return f"Error loading documents from {path}: {str(e)}"

            # Check if we have documents to process
            if not documents:
                return "Error: No valid documents found."

            # Validate citations if requested
            validation_results = {}
            if validate_citations:
                validation_results = processor.validate_citation_retention(documents)

            # Process documents
            result = processor.process_documents(documents)

            if result["status"] != "success":
                return f"Processing error: {result['message']}"

            # Shared stats block for both output paths
            stats = (
                f"Documents processed: {result['documents_count']}\n"
                f"Text chunks: {result['nodes_count']}\n"
                f"Model used: {result['model_used']}\n"
            )
            if validation_results:
                stats += "\n=== Citation Retention Validation ===\n"
                stats += f"Citation retention: {validation_results.get('citation_retention', 0):.2f}%\n"
                stats += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"

            # Run query if provided
            if query and "query_engine" in result:
                query_engine = result["query_engine"]
                response = query_engine.query(query)
                return f"Query: {query}\n\nResponse: {response}\n\n" + stats

            # If no query, return processing stats
            output = "Document processing complete.\n\n" + stats
            output += "\nThe documents are now ready for querying. Use the 'query' parameter to run a query."
            return output

        except Exception as e:
            return f"Error: {str(e)}"
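

# Hedged usage example. Running it downloads the selected embedding model from
# HuggingFace, and building the query engine expects an LLM configured for
# llama_index (an OpenAI key by default); the document text is invented for
# demonstration.
if __name__ == "__main__":
    tool = LegalDocumentTool()
    sample_text = (
        "## Judgment\n\n"
        "In Smith v Jones [2019] UKSC 20, the court considered Art. 6 "
        "of the Convention and §3.1 of the statute.\n"
    )
    print(
        tool.forward(
            text=sample_text,
            model_key="all-mpnet",
            validate_citations=True,
        )
    )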