| """ | |
| Legal Document Processing Tool for Smolagents | |
| This tool processes legal documents with specialized models for legal text, | |
| optimizing for citation retention, multilingual support, and performance on | |
| legal-specific retrieval tasks. | |
| Author: Dr. Zhou Wang | |
| """ | |

from typing import Dict, List, Any, Optional, Union
import os
import re
import time
import tempfile

import numpy as np
from tqdm import tqdm

# Import Smolagents Tool class
from smolagents import Tool

# Import NLP components
try:
    from sklearn.metrics.pairwise import cosine_similarity
    from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, Document
    from llama_index.core.node_parser import MarkdownNodeParser, LangchainNodeParser
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    from llama_index.core.ingestion import IngestionPipeline
    from langchain.text_splitter import RecursiveCharacterTextSplitter
except ImportError:
    raise ImportError(
        "Required dependencies not found. Please install with: "
        "pip install llama-index langchain scikit-learn tqdm"
    )

# Model configurations based on research findings
LEGAL_MODELS = {
    "legal-bert": {
        "name": "nlp-jurisprudence/legal-bert-base-uncased",
        "description": "Trained on ECtHR legal documents, specialized in human rights law",
        "max_length": 512,
        "requires_gpu": True,
    },
    "multi-qa-mpnet": {
        "name": "sentence-transformers/multi-qa-mpnet-base-dot-v1",
        "description": "Optimized for legal Q&A retrieval with cross-lingual support",
        "max_length": 512,
        "requires_gpu": False,
    },
    "legal-xlm-roberta": {
        "name": "joelito/legal-xlm-roberta-base",
        "description": "Multilingual legal model with EU legislation and RFC/ISO pattern awareness",
        "max_length": 512,
        "requires_gpu": True,
    },
    "multilingual-e5": {
        "name": "intfloat/multilingual-e5-base",
        "description": "Dense retrieval optimized with citation context preservation",
        "max_length": 512,
        "requires_gpu": True,
    },
    "all-mpnet": {
        "name": "sentence-transformers/all-mpnet-base-v2",
        "description": "General purpose embedding model, good baseline for legal text",
        "max_length": 512,
        "requires_gpu": False,
    },
}
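
# Example (sketch): how a caller might select a model configuration, falling
# back to the multilingual default when an unknown key is requested.
#
#   key = "legal-bert"
#   config = LEGAL_MODELS.get(key, LEGAL_MODELS["legal-xlm-roberta"])
#   print(config["name"], config["max_length"])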


class LegalDocumentProcessor:
    """
    Processor for legal documents with specialized models,
    citation preservation, and benchmarking capabilities.
    """

    def __init__(
        self,
        model_key: str = "legal-xlm-roberta",
        use_gpu: bool = False,
        chunk_size: int = 512,
        chunk_overlap: int = 100,
    ):
        """
        Initialize the legal document processor.

        Args:
            model_key: Key for the model to use from the LEGAL_MODELS dictionary
            use_gpu: Whether to use GPU for embeddings (if available)
            chunk_size: Size of text chunks for processing
            chunk_overlap: Overlap between chunks to preserve context
        """
        # Validate and set up model
        if model_key not in LEGAL_MODELS:
            print(
                f"Warning: Model '{model_key}' not found. Using legal-xlm-roberta as default."
            )
            model_key = "legal-xlm-roberta"

        model_config = LEGAL_MODELS[model_key]
        # Use the GPU only when it is requested and the model benefits from it
        device = "cuda" if use_gpu and model_config["requires_gpu"] else "cpu"

        # Initialize embedding model
        self.embed_model = HuggingFaceEmbedding(
            model_name=model_config["name"],
            device=device,
            tokenizer_kwargs={
                "trust_remote_code": True,
                "max_length": model_config["max_length"],
                "truncation": True,
            },
        )

        # Store model information for reference
        self.model_info = model_config
        self.model_key = model_key

        # Text splitter tuned for legal documents: split on headers first,
        # then paragraphs, sentences, and clause boundaries
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=[
                "\n## ",
                "\n### ",
                "\n#### ",  # Headers
                "\n\n",
                "\n",  # Paragraphs
                ". ",
                "! ",
                "? ",  # Sentences
                ";",
                ":",  # Clause boundaries
                " ",  # Last resort
            ],
        )

        # Patterns for removing footers from legal documents.
        # Separated into individual patterns for better maintainability.
        self.footer_patterns = [
            r"^Page\s\d+(\s+of\s+\d+)?$",  # Page numbers
            r"^©.*\b(Company|Inc|Ltd)\b.*$",  # Copyright lines
            r"^All rights reserved.*?$",  # Legal boilerplate
            r"^-+$",  # Separator lines
            r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}(:\d{2})?$",  # Timestamps
            r"(?i)^(confidential|proprietary|internal use only)",  # Security tags
        ]
        # Join all patterns with the OR operator
        combined_pattern = "|".join(f"({pattern})" for pattern in self.footer_patterns)
        # Compile the combined pattern
        self.footer_pattern = re.compile(
            combined_pattern, flags=re.MULTILINE | re.IGNORECASE
        )

    def remove_footers(self, text: str) -> str:
        """
        Remove common document footer patterns from text.

        Args:
            text: The input text to process

        Returns:
            Text with footer patterns removed
        """
        return self.footer_pattern.sub("", text)
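
    # Illustrative example (not executed): given the patterns above, a line such
    # as "Page 3 of 12" or a line starting with "CONFIDENTIAL" has the matched
    # footer text removed, leaving an empty line behind:
    #   processor.remove_footers("Clause 4.2 applies.\nPage 3 of 12\n")
    #   -> "Clause 4.2 applies.\n\n"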

    def clean_text(self, text: str) -> str:
        """
        Preserve legal citations while cleaning artifacts.

        Args:
            text: The input text to clean

        Returns:
            Cleaned text with citations preserved
        """
        # First remove footers
        text = self.remove_footers(text)

        # Preserve citation patterns
        # Pattern 1: wrap bare two- or three-digit numbers (typical footnote
        # markers, e.g. 98, 99, 100) in square brackets so they survive chunking
        cleaned = re.sub(r"(?<=\D)(\d{2,3})(?=\D)", r"[\1]", text)

        # Pattern 2: case citations such as [2019] UKSC 20 are already
        # well-structured, so no changes are needed

        # Pattern 3: standardize quotation marks
        cleaned = cleaned.replace("''", '"').replace("``", '"')

        # Pattern 4: expand section references (§3.1, §123)
        cleaned = re.sub(r"§(\d+(\.\d+)?)", r"Section \1", cleaned)

        # Pattern 5: expand legal abbreviations (e.g., Art. -> Article)
        cleaned = re.sub(r"\bArt\.\s+(\d+)", r"Article \1", cleaned)

        # Pattern 6: standardize case names with v. and vs.
        cleaned = re.sub(r"\bv\s+", r"v. ", cleaned)
        cleaned = re.sub(r"\bvs\s+", r"v. ", cleaned)

        # Pattern 7: RFC/ISO pattern standardization (RFC 1234, ISO 9001)
        cleaned = re.sub(r"\b(RFC|ISO)\s*[:#]?\s*(\d+)", r"\1 \2", cleaned)

        return cleaned
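
    # Illustrative example (not executed): the substitutions above transform
    #   "Art. 6 of the Convention, §3.1; see RFC1234 and footnote 98."
    # into
    #   "Article 6 of the Convention, Section 3.1; see RFC 1234 and footnote [98]."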

    def create_pipeline(self) -> IngestionPipeline:
        """
        Create a document processing pipeline.

        Text cleaning is applied to the documents before this pipeline runs
        (see process_documents), because IngestionPipeline expects
        TransformComponent objects rather than plain callables.

        Returns:
            Configured IngestionPipeline object
        """
        return IngestionPipeline(
            transformations=[
                MarkdownNodeParser(),
                # Wrap the LangChain splitter so llama-index can use it as a
                # node parser
                LangchainNodeParser(self.splitter),
                self.embed_model,
            ]
        )

    def validate_citation_retention(
        self, documents: List[Document]
    ) -> Dict[str, float]:
        """
        Measure semantic similarity of citations before/after text cleaning.

        Args:
            documents: List of Document objects to validate

        Returns:
            Dictionary with validation metrics
        """
        if not documents:
            return {"citation_retention": 0.0, "processing_time": 0.0}

        start_time = time.time()

        # Extract original texts (sample a few documents for performance)
        original_texts = [doc.text for doc in documents[:5]]
        # Apply cleaning
        processed_texts = [self.clean_text(text) for text in original_texts]

        # Calculate embeddings
        try:
            # Use the public embedding API instead of reaching into the
            # underlying HuggingFace model object
            orig_embeds = np.array(
                self.embed_model.get_text_embedding_batch(original_texts)
            )
            proc_embeds = np.array(
                self.embed_model.get_text_embedding_batch(processed_texts)
            )

            # Pairwise similarity between each original/processed text
            similarities = cosine_similarity(orig_embeds, proc_embeds).diagonal()
            avg_similarity = float(np.mean(similarities))

            processing_time = time.time() - start_time
            return {
                "citation_retention": avg_similarity * 100,  # As percentage
                "processing_time": processing_time,
                "sample_size": len(original_texts),
            }
        except Exception as e:
            return {"citation_retention": 0.0, "processing_time": 0.0, "error": str(e)}

    def process_documents(self, documents: List[Document]) -> Dict[str, Any]:
        """
        Process a list of legal documents.

        Args:
            documents: List of Document objects to process

        Returns:
            Dictionary with processing results and stats
        """
        if not documents:
            return {"status": "error", "message": "No documents provided"}

        try:
            # Clean document text (footer removal, citation normalization)
            # before chunking and embedding
            cleaned_docs = [
                Document(text=self.clean_text(doc.text), metadata=doc.metadata)
                for doc in documents
            ]

            # Create pipeline and process documents
            pipeline = self.create_pipeline()
            nodes = pipeline.run(documents=cleaned_docs)

            # Create vector index with the same embedding model used above
            index = VectorStoreIndex(nodes, embed_model=self.embed_model)
            query_engine = index.as_query_engine()

            # Return success with stats
            return {
                "status": "success",
                "nodes_count": len(nodes),
                "documents_count": len(documents),
                "model_used": self.model_key,
                "query_engine": query_engine,  # This will be used for querying
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}


class LegalDocumentTool(Tool):
    """
    Tool for processing legal documents with specialized models and querying capabilities.
    """

    name = "legal_document_processor"
    description = (
        "Processes legal documents with specialized models for legal text, optimizing for "
        "citation retention, multilingual support, and performance on legal-specific retrieval tasks. "
        "Can process text or file inputs and provide enhanced query capabilities."
    )

    # Arguments with defaults in forward() are marked nullable so that
    # smolagents' tool-input validation accepts them as optional.
    inputs = {
        "text": {
            "type": "string",
            "description": "Legal document text to process. Provide either text or file_paths.",
            "nullable": True,
        },
        "file_paths": {
            "type": "string",
            "description": "Comma-separated list of file paths or a directory path containing legal documents. Provide either text or file_paths.",
            "nullable": True,
        },
        "model_key": {
            "type": "string",
            "description": "Legal embedding model to use. Options: legal-bert, multi-qa-mpnet, legal-xlm-roberta, multilingual-e5, all-mpnet",
            "default": "legal-xlm-roberta",
            "nullable": True,
        },
        "query": {
            "type": "string",
            "description": "Optional query to run against the processed documents.",
            "nullable": True,
        },
        "validate_citations": {
            "type": "boolean",
            "description": "Whether to validate citation retention in the processed documents.",
            "default": False,
            "nullable": True,
        },
        "use_gpu": {
            "type": "boolean",
            "description": "Whether to use GPU for embedding calculations if available.",
            "default": False,
            "nullable": True,
        },
    }
    output_type = "string"
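
    # Sketch of how the tool could be handed to a smolagents agent (assumes a
    # configured `model` object, e.g. an InferenceClientModel):
    #   from smolagents import CodeAgent
    #   agent = CodeAgent(tools=[LegalDocumentTool()], model=model)
    #   agent.run("Summarise the case law cited in ./contracts")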

    def _load_documents(self, input_path: str) -> List[Document]:
        """
        Load documents from a file path or directory.

        Args:
            input_path: Path to a file or directory

        Returns:
            List of Document objects
        """
        if os.path.isfile(input_path):
            # Load just this file; input_files avoids extension-matching issues
            return SimpleDirectoryReader(
                input_files=[input_path],
                filename_as_id=True,
            ).load_data()
        elif os.path.isdir(input_path):
            return SimpleDirectoryReader(
                input_dir=input_path,
                filename_as_id=True,
            ).load_data()
        else:
            raise ValueError(f"Path not found: {input_path}")

    def _create_document_from_text(self, text: str) -> List[Document]:
        """
        Create a Document object from text.

        Args:
            text: Text content

        Returns:
            List containing a single Document object
        """
        # Create a temporary file to store the text
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".md", delete=False
        ) as temp_file:
            temp_file.write(text)
            temp_path = temp_file.name

        try:
            # Load the document from the temporary file
            documents = self._load_documents(temp_path)
            return documents
        finally:
            # Clean up the temporary file
            os.remove(temp_path)

    def forward(
        self,
        text: Optional[str] = None,
        file_paths: Optional[str] = None,
        model_key: str = "legal-xlm-roberta",
        query: Optional[str] = None,
        validate_citations: bool = False,
        use_gpu: bool = False,
    ) -> str:
        """
        Process legal documents and optionally run a query.

        Args:
            text: Legal document text to process
            file_paths: Comma-separated list of file paths or a directory path
            model_key: Legal embedding model to use
            query: Optional query to run against the processed documents
            validate_citations: Whether to validate citation retention
            use_gpu: Whether to use GPU for embeddings

        Returns:
            Processing results or query response as a string
        """
        # Validate inputs
        if not text and not file_paths:
            return "Error: Either text or file_paths must be provided."

        try:
            # Initialize processor
            processor = LegalDocumentProcessor(
                model_key=model_key,
                use_gpu=use_gpu,
            )

            # Load documents
            documents = []
            if text:
                documents.extend(self._create_document_from_text(text))
            if file_paths:
                # Handle comma-separated paths
                paths = [path.strip() for path in file_paths.split(",")]
                for path in paths:
                    try:
                        docs = self._load_documents(path)
                        documents.extend(docs)
                    except Exception as e:
                        return f"Error loading documents from {path}: {str(e)}"

            # Check if we have documents to process
            if not documents:
                return "Error: No valid documents found."

            # Validate citations if requested
            validation_results = {}
            if validate_citations:
                validation_results = processor.validate_citation_retention(documents)

            # Process documents
            result = processor.process_documents(documents)
            if result["status"] != "success":
                return f"Processing error: {result['message']}"

            # Run query if provided
            if query and "query_engine" in result:
                query_engine = result["query_engine"]
                response = query_engine.query(query)

                # Format the response
                output = f"Query: {query}\n\nResponse: {response}\n\n"
                output += f"Documents processed: {result['documents_count']}\n"
                output += f"Text chunks: {result['nodes_count']}\n"
                output += f"Model used: {result['model_used']}\n"

                # Add validation results if available
                if validation_results:
                    output += "\n=== Citation Retention Validation ===\n"
                    output += f"Citation retention: {validation_results.get('citation_retention', 0):.2f}%\n"
                    output += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"

                return output

            # If no query, return processing stats
            output = "Document processing complete.\n\n"
            output += f"Documents processed: {result['documents_count']}\n"
            output += f"Text chunks: {result['nodes_count']}\n"
            output += f"Model used: {result['model_used']}\n"

            # Add validation results if available
            if validation_results:
                output += "\n=== Citation Retention Validation ===\n"
                output += f"Citation retention: {validation_results.get('citation_retention', 0):.2f}%\n"
                output += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"

            output += "\nThe documents are now ready for querying. Use the 'query' parameter to run a query."
            return output
        except Exception as e:
            return f"Error: {str(e)}"