Leonardo commited on
Commit
9a5117b
·
verified ·
1 Parent(s): 66af00d

Create scripts/legal_document_tool.py

Browse files
Files changed (1) hide show
  1. scripts/legal_document_tool.py +505 -0
scripts/legal_document_tool.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Legal Document Processing Tool for Smolagents
3
+
4
+ This tool processes legal documents with specialized models for legal text,
5
+ optimizing for citation retention, multilingual support, and performance on
6
+ legal-specific retrieval tasks.
7
+
8
+ Author: Dr. Zhou Wang
9
+ """
10
+
11
+ from typing import Dict, List, Any, Optional, Union
12
+ import os
13
+ import re
14
+ import time
15
+ import tempfile
16
+ import numpy as np
17
+ from tqdm import tqdm
18
+
19
+ # Import Smolagents Tool class
20
+ from smolagents import Tool
21
+
22
+ # Import NLP components
23
+ try:
24
+ from sklearn.metrics.pairwise import cosine_similarity
25
+ from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, Document
26
+ from llama_index.core.node_parser import MarkdownNodeParser
27
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
28
+ from llama_index.core.ingestion import IngestionPipeline
29
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
30
+ except ImportError:
31
+ raise ImportError(
32
+ "Required dependencies not found. Please install with: "
33
+ "pip install llama-index langchain scikit-learn tqdm"
34
+ )
35
+
36
+
37
+ # Model configurations based on research findings
38
+ LEGAL_MODELS = {
39
+ "legal-bert": {
40
+ "name": "nlp-jurisprudence/legal-bert-base-uncased",
41
+ "description": "Trained on ECtHR legal documents, specialized in human rights law",
42
+ "max_length": 512,
43
+ "requires_gpu": True,
44
+ },
45
+ "multi-qa-mpnet": {
46
+ "name": "sentence-transformers/multi-qa-mpnet-base-dot-v1",
47
+ "description": "Optimized for legal Q&A retrieval with cross-lingual support",
48
+ "max_length": 512,
49
+ "requires_gpu": False,
50
+ },
51
+ "legal-xlm-roberta": {
52
+ "name": "joelito/legal-xlm-roberta-base",
53
+ "description": "Multilingual legal model with EU legislation and RFC/ISO pattern awareness",
54
+ "max_length": 512,
55
+ "requires_gpu": True,
56
+ },
57
+ "multilingual-e5": {
58
+ "name": "intfloat/multilingual-e5-base",
59
+ "description": "Dense retrieval optimized with citation context preservation",
60
+ "max_length": 512,
61
+ "requires_gpu": True,
62
+ },
63
+ "all-mpnet": {
64
+ "name": "sentence-transformers/all-mpnet-base-v2",
65
+ "description": "General purpose embedding model, good baseline for legal text",
66
+ "max_length": 512,
67
+ "requires_gpu": False,
68
+ },
69
+ }
70
+
71
+
72
+ class LegalDocumentProcessor:
73
+ """
74
+ Processor for legal documents with specialized models,
75
+ citation preservation, and benchmarking capabilities.
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ model_key: str = "legal-xlm-roberta",
81
+ use_gpu: bool = False,
82
+ chunk_size: int = 512,
83
+ chunk_overlap: int = 100,
84
+ ):
85
+ """
86
+ Initialize the legal document processor.
87
+
88
+ Args:
89
+ model_key: Key for the model to use from LEGAL_MODELS dictionary
90
+ use_gpu: Whether to use GPU for embeddings (if available)
91
+ chunk_size: Size of text chunks for processing
92
+ chunk_overlap: Overlap between chunks to preserve context
93
+ """
94
+ # Validate and set up model
95
+ if model_key not in LEGAL_MODELS:
96
+ print(
97
+ f"Warning: Model '{model_key}' not found. Using legal-xlm-roberta as default."
98
+ )
99
+ model_key = "legal-xlm-roberta"
100
+
101
+ model_config = LEGAL_MODELS[model_key]
102
+ device = "cuda" if use_gpu and model_config["requires_gpu"] else "cpu"
103
+
104
+ # Initialize embedding model
105
+ self.embed_model = HuggingFaceEmbedding(
106
+ model_name=model_config["name"],
107
+ device=device,
108
+ tokenizer_kwargs={
109
+ "trust_remote_code": True,
110
+ "max_length": model_config["max_length"],
111
+ "truncation": True,
112
+ },
113
+ )
114
+
115
+ # Store model information for reference
116
+ self.model_info = model_config
117
+ self.model_key = model_key
118
+
119
+ # Legal document-optimized text splitter with improved chunk size
120
+ self.splitter = RecursiveCharacterTextSplitter(
121
+ chunk_size=chunk_size,
122
+ chunk_overlap=chunk_overlap,
123
+ separators=[
124
+ "\n## ",
125
+ "\n### ",
126
+ "\n#### ", # Headers
127
+ "\n\n",
128
+ "\n", # Paragraphs
129
+ ". ",
130
+ "! ",
131
+ "? ", # Sentences
132
+ ";",
133
+ ":", # Clause boundaries
134
+ " ", # Last resort
135
+ ],
136
+ )
137
+
138
+ # Pattern for removing footers from legal documents
139
+ # Separated into individual patterns for better maintainability
140
+ self.footer_patterns = [
141
+ r"^Page\s\d+(\s+of\s+\d+)?$", # Page numbers
142
+ r"^©.*\b(Company|Inc|Ltd)\b.*$", # Copyright lines
143
+ r"^All rights reserved.*?$", # Legal boilerplate
144
+ r"^-+$", # Separator lines
145
+ r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}(:\d{2})?$", # Timestamps
146
+ r"(?i)^(confidential|proprietary|internal use only)", # Security tags
147
+ ]
148
+
149
+ # Join all patterns with the OR operator
150
+ combined_pattern = "|".join(f"({pattern})" for pattern in self.footer_patterns)
151
+
152
+ # Compile the combined pattern
153
+ self.footer_pattern = re.compile(
154
+ combined_pattern, flags=re.MULTILINE | re.IGNORECASE
155
+ )
156
+
157
+ def remove_footers(self, text: str) -> str:
158
+ """
159
+ Remove common document footer patterns from text.
160
+
161
+ Args:
162
+ text: The input text to process
163
+
164
+ Returns:
165
+ Text with footer patterns removed
166
+ """
167
+ return self.footer_pattern.sub("", text)
168
+
169
+ def clean_text(self, text: str) -> str:
170
+ """
171
+ Preserve legal citations while cleaning artifacts.
172
+
173
+ Args:
174
+ text: The input text to clean
175
+
176
+ Returns:
177
+ Cleaned text with citations preserved
178
+ """
179
+ # First remove footers
180
+ text = self.remove_footers(text)
181
+
182
+ # Preserve citation patterns
183
+ # Pattern 1: Footnote numbers (e.g., 98, 99, 100)
184
+ cleaned = re.sub(r"(?<=\D)(\d{2,3})(?=\D)", r"[\1]", text)
185
+
186
+ # Pattern 2: Case citations [2019] UKSC 20
187
+ # Already well-structured, so no changes needed
188
+
189
+ # Pattern 3: Standardize quotation marks
190
+ cleaned = cleaned.replace("''", '"').replace("``", '"')
191
+
192
+ # Pattern 4: Handle section references (§3.1, §123)
193
+ cleaned = re.sub(r"§(\d+(\.\d+)?)", r"Section \1", cleaned)
194
+
195
+ # Pattern 5: Handle legal abbreviations (e.g., Art. -> Article)
196
+ cleaned = re.sub(r"\bArt\.\s+(\d+)", r"Article \1", cleaned)
197
+
198
+ # Pattern 6: Standardize case names with v. and vs.
199
+ cleaned = re.sub(r"\bv\s+", r"v. ", cleaned)
200
+ cleaned = re.sub(r"\bvs\s+", r"v. ", cleaned)
201
+
202
+ # Pattern 7: RFC/ISO pattern standardization (RFC 1234, ISO 9001)
203
+ cleaned = re.sub(r"\b(RFC|ISO)\s*[:#]?\s*(\d+)", r"\1 \2", cleaned)
204
+
205
+ return cleaned
206
+
207
+ def create_pipeline(self) -> IngestionPipeline:
208
+ """
209
+ Create a document processing pipeline.
210
+
211
+ Returns:
212
+ Configured IngestionPipeline object
213
+ """
214
+ return IngestionPipeline(
215
+ transformations=[
216
+ self.clean_text,
217
+ MarkdownNodeParser(),
218
+ self.splitter,
219
+ self.embed_model,
220
+ ]
221
+ )
222
+
223
+ def validate_citation_retention(
224
+ self, documents: List[Document]
225
+ ) -> Dict[str, float]:
226
+ """
227
+ Measure semantic similarity of citations before/after text cleaning.
228
+
229
+ Args:
230
+ documents: List of Document objects to validate
231
+
232
+ Returns:
233
+ Dictionary with validation metrics
234
+ """
235
+ if not documents:
236
+ return {"citation_retention": 0.0, "processing_time": 0.0}
237
+
238
+ start_time = time.time()
239
+
240
+ # Extract original texts
241
+ original_texts = [doc.text for doc in documents[:5]] # Sample for performance
242
+
243
+ # Apply cleaning
244
+ processed_texts = [self.clean_text(text) for text in original_texts]
245
+
246
+ # Calculate embeddings
247
+ try:
248
+ # Direct access to the underlying HuggingFace model
249
+ orig_embeds = self.embed_model._model.encode(original_texts)
250
+ proc_embeds = self.embed_model._model.encode(processed_texts)
251
+
252
+ # Calculate similarity
253
+ similarities = cosine_similarity(orig_embeds, proc_embeds).diagonal()
254
+ avg_similarity = float(np.mean(similarities))
255
+
256
+ processing_time = time.time() - start_time
257
+
258
+ return {
259
+ "citation_retention": avg_similarity * 100, # As percentage
260
+ "processing_time": processing_time,
261
+ "sample_size": len(original_texts),
262
+ }
263
+ except Exception as e:
264
+ return {"citation_retention": 0.0, "processing_time": 0.0, "error": str(e)}
265
+
266
+ def process_documents(self, documents: List[Document]) -> Dict[str, Any]:
267
+ """
268
+ Process a list of legal documents.
269
+
270
+ Args:
271
+ documents: List of Document objects to process
272
+
273
+ Returns:
274
+ Dictionary with processing results and stats
275
+ """
276
+ if not documents:
277
+ return {"status": "error", "message": "No documents provided"}
278
+
279
+ try:
280
+ # Create pipeline and process documents
281
+ pipeline = self.create_pipeline()
282
+ nodes = pipeline.run(documents=documents)
283
+
284
+ # Create vector index
285
+ index = VectorStoreIndex(nodes)
286
+ query_engine = index.as_query_engine()
287
+
288
+ # Return success with stats
289
+ return {
290
+ "status": "success",
291
+ "nodes_count": len(nodes),
292
+ "documents_count": len(documents),
293
+ "model_used": self.model_key,
294
+ "query_engine": query_engine, # This will be used for querying
295
+ }
296
+ except Exception as e:
297
+ return {"status": "error", "message": str(e)}
298
+
299
+
300
+ class LegalDocumentTool(Tool):
301
+ """
302
+ Tool for processing legal documents with specialized models and querying capabilities.
303
+ """
304
+
305
+ name = "legal_document_processor"
306
+ description = (
307
+ "Processes legal documents with specialized models for legal text, optimizing for "
308
+ "citation retention, multilingual support, and performance on legal-specific retrieval tasks. "
309
+ "Can process text or file inputs and provide enhanced query capabilities."
310
+ )
311
+ inputs = {
312
+ "text": {
313
+ "type": "string",
314
+ "description": "Legal document text to process. Provide either text or file_paths.",
315
+ "optional": True,
316
+ },
317
+ "file_paths": {
318
+ "type": "string",
319
+ "description": "Comma-separated list of file paths or a directory path containing legal documents. Provide either text or file_paths.",
320
+ "optional": True,
321
+ },
322
+ "model_key": {
323
+ "type": "string",
324
+ "description": "Legal embedding model to use. Options: legal-bert, multi-qa-mpnet, legal-xlm-roberta, multilingual-e5, all-mpnet",
325
+ "default": "legal-xlm-roberta",
326
+ },
327
+ "query": {
328
+ "type": "string",
329
+ "description": "Optional query to run against the processed documents.",
330
+ "optional": True,
331
+ },
332
+ "validate_citations": {
333
+ "type": "boolean",
334
+ "description": "Whether to validate citation retention in the processed documents.",
335
+ "default": False,
336
+ },
337
+ "use_gpu": {
338
+ "type": "boolean",
339
+ "description": "Whether to use GPU for embedding calculations if available.",
340
+ "default": False,
341
+ },
342
+ }
343
+ output_type = "string"
344
+
345
+ def _load_documents(self, input_path: str) -> List[Document]:
346
+ """
347
+ Load documents from a file path or directory.
348
+
349
+ Args:
350
+ input_path: Path to a file or directory
351
+
352
+ Returns:
353
+ List of Document objects
354
+ """
355
+ if os.path.isfile(input_path):
356
+ # Create a SimpleDirectoryReader for the file's directory
357
+ # and filter to only include this file
358
+ directory = os.path.dirname(input_path)
359
+ filename = os.path.basename(input_path)
360
+
361
+ return SimpleDirectoryReader(
362
+ input_dir=directory,
363
+ required_exts=[
364
+ os.path.splitext(filename)[1][1:]
365
+ ], # Extension without dot
366
+ filename_as_id=True,
367
+ ).load_data()
368
+
369
+ elif os.path.isdir(input_path):
370
+ return SimpleDirectoryReader(
371
+ input_dir=input_path,
372
+ filename_as_id=True,
373
+ ).load_data()
374
+
375
+ else:
376
+ raise ValueError(f"Path not found: {input_path}")
377
+
378
+ def _create_document_from_text(self, text: str) -> List[Document]:
379
+ """
380
+ Create a Document object from text.
381
+
382
+ Args:
383
+ text: Text content
384
+
385
+ Returns:
386
+ List containing a single Document object
387
+ """
388
+ # Create a temporary file to store the text
389
+ with tempfile.NamedTemporaryFile(
390
+ mode="w", suffix=".md", delete=False
391
+ ) as temp_file:
392
+ temp_file.write(text)
393
+ temp_path = temp_file.name
394
+
395
+ try:
396
+ # Load the document from the temporary file
397
+ documents = self._load_documents(temp_path)
398
+ return documents
399
+ finally:
400
+ # Clean up the temporary file
401
+ os.remove(temp_path)
402
+
403
+ def forward(
404
+ self,
405
+ text: Optional[str] = None,
406
+ file_paths: Optional[str] = None,
407
+ model_key: str = "legal-xlm-roberta",
408
+ query: Optional[str] = None,
409
+ validate_citations: bool = False,
410
+ use_gpu: bool = False,
411
+ ) -> str:
412
+ """
413
+ Process legal documents and optionally run a query.
414
+
415
+ Args:
416
+ text: Legal document text to process
417
+ file_paths: Comma-separated list of file paths or a directory path
418
+ model_key: Legal embedding model to use
419
+ query: Optional query to run against the processed documents
420
+ validate_citations: Whether to validate citation retention
421
+ use_gpu: Whether to use GPU for embeddings
422
+
423
+ Returns:
424
+ Processing results or query response as a string
425
+ """
426
+ # Validate inputs
427
+ if not text and not file_paths:
428
+ return "Error: Either text or file_paths must be provided."
429
+
430
+ try:
431
+ # Initialize processor
432
+ processor = LegalDocumentProcessor(
433
+ model_key=model_key,
434
+ use_gpu=use_gpu,
435
+ )
436
+
437
+ # Load documents
438
+ documents = []
439
+
440
+ if text:
441
+ documents.extend(self._create_document_from_text(text))
442
+
443
+ if file_paths:
444
+ # Handle comma-separated paths
445
+ paths = [path.strip() for path in file_paths.split(",")]
446
+
447
+ for path in paths:
448
+ try:
449
+ docs = self._load_documents(path)
450
+ documents.extend(docs)
451
+ except Exception as e:
452
+ return f"Error loading documents from {path}: {str(e)}"
453
+
454
+ # Check if we have documents to process
455
+ if not documents:
456
+ return "Error: No valid documents found."
457
+
458
+ # Validate citations if requested
459
+ validation_results = {}
460
+ if validate_citations:
461
+ validation_results = processor.validate_citation_retention(documents)
462
+
463
+ # Process documents
464
+ result = processor.process_documents(documents)
465
+
466
+ if result["status"] != "success":
467
+ return f"Processing error: {result['message']}"
468
+
469
+ # Run query if provided
470
+ if query and "query_engine" in result:
471
+ query_engine = result["query_engine"]
472
+ response = query_engine.query(query)
473
+
474
+ # Format the response
475
+ output = f"Query: {query}\n\nResponse: {response}\n\n"
476
+ output += f"Documents processed: {result['documents_count']}\n"
477
+ output += f"Text chunks: {result['nodes_count']}\n"
478
+ output += f"Model used: {result['model_used']}\n"
479
+
480
+ # Add validation results if available
481
+ if validation_results:
482
+ output += "\n=== Citation Retention Validation ===\n"
483
+ output += f"Citation retention: {validation_results.get('citation_retention', 0):.2f}%\n"
484
+ output += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"
485
+
486
+ return output
487
+
488
+ # If no query, return processing stats
489
+ output = "Document processing complete.\n\n"
490
+ output += f"Documents processed: {result['documents_count']}\n"
491
+ output += f"Text chunks: {result['nodes_count']}\n"
492
+ output += f"Model used: {result['model_used']}\n"
493
+
494
+ # Add validation results if available
495
+ if validation_results:
496
+ output += "\n=== Citation Retention Validation ===\n"
497
+ output += f"Citation retention: {validation_results.get('citation_retention', 0):.2f}%\n"
498
+ output += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"
499
+
500
+ output += "\nThe documents are now ready for querying. Use the 'query' parameter to run a query."
501
+
502
+ return output
503
+
504
+ except Exception as e:
505
+ return f"Error: {str(e)}"