Leonardo committed on
Commit fdb59f7 · verified · 1 Parent(s): c1ce83b

Update scripts/document_tool.py

Files changed (1)
  1. scripts/document_tool.py +278 -131
scripts/document_tool.py CHANGED
@@ -1,33 +1,35 @@
 """
-Legal Document Processing Tool for Smolagents

-This tool processes legal documents with specialized models for legal text,
-optimizing for citation retention, multilingual support, and performance on
-legal-specific retrieval tasks.

-Author: Dr. Zhou Wang
 """

-from typing import Dict, List, Any, Optional, Union
 import os
 import re
-import time
 import tempfile
-import spaces
 import numpy as np
-from tqdm import tqdm

 # Import Smolagents Tool class
 from smolagents import Tool

 # Import NLP components
 try:
-    from sklearn.metrics.pairwise import cosine_similarity
-    from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, Document
     from llama_index.core.node_parser import MarkdownNodeParser
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-    from llama_index.core.ingestion import IngestionPipeline
-    from langchain.text_splitter import RecursiveCharacterTextSplitter
 except ImportError:
     raise ImportError(
         "Required dependencies not found. Please install with: "
@@ -35,89 +37,105 @@ except ImportError:
     )


-# Model configurations based on research findings
-LEGAL_MODELS = {
-    "legal-bert": {
-        "name": "nlp-jurisprudence/legal-bert-base-uncased",
-        "description": "Trained on ECtHR legal documents, specialized in human rights law",
         "max_length": 512,
         "requires_gpu": True,
     },
-    "multi-qa-mpnet": {
-        "name": "sentence-transformers/multi-qa-mpnet-base-dot-v1",
-        "description": "Optimized for legal Q&A retrieval with cross-lingual support",
         "max_length": 512,
         "requires_gpu": False,
     },
-    "legal-xlm-roberta": {
-        "name": "joelito/legal-xlm-roberta-base",
-        "description": "Multilingual legal model with EU legislation and RFC/ISO pattern awareness",
         "max_length": 512,
         "requires_gpu": True,
     },
-    "multilingual-e5": {
-        "name": "intfloat/multilingual-e5-base",
-        "description": "Dense retrieval optimized with citation context preservation",
         "max_length": 512,
         "requires_gpu": True,
     },
-    "all-mpnet": {
         "name": "sentence-transformers/all-mpnet-base-v2",
-        "description": "General purpose embedding model, good baseline for legal text",
         "max_length": 512,
         "requires_gpu": False,
     },
 }


-class LegalDocumentProcessor:
     """
-    Processor for legal documents with specialized models,
-    citation preservation, and benchmarking capabilities.
     """

     def __init__(
         self,
-        model_key: str = "legal-xlm-roberta",
         use_gpu: bool = False,
         chunk_size: int = 512,
         chunk_overlap: int = 100,
     ):
         """
-        Initialize the legal document processor.

         Args:
-            model_key: Key for the model to use from LEGAL_MODELS dictionary
             use_gpu: Whether to use GPU for embeddings (if available)
             chunk_size: Size of text chunks for processing
             chunk_overlap: Overlap between chunks to preserve context
         """
-        # Validate and set up model
-        if model_key not in LEGAL_MODELS:
-            print(
-                f"Warning: Model '{model_key}' not found. Using legal-xlm-roberta as default."
-            )
-            model_key = "legal-xlm-roberta"

-        model_config = LEGAL_MODELS[model_key]
-        device = "cuda" if use_gpu and model_config["requires_gpu"] else "cpu"

         # Initialize embedding model
-        self.embed_model = HuggingFaceEmbedding(
-            model_name=model_config["name"],
-            device=device,
-            tokenizer_kwargs={
-                "trust_remote_code": True,
-                "max_length": model_config["max_length"],
-                "truncation": True,
-            },
-        )

-        # Store model information for reference
-        self.model_info = model_config
-        self.model_key = model_key

-        # Legal document-optimized text splitter with improved chunk size
         self.splitter = RecursiveCharacterTextSplitter(
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap,
@@ -136,9 +154,8 @@ class LegalDocumentProcessor:
             ],
         )

-        # Pattern for removing footers from legal documents
-        # Separated into individual patterns for better maintainability
-        self.footer_patterns = [
             r"^Page\s\d+(\s+of\s+\d+)?$",  # Page numbers
             r"^©.*\b(Company|Inc|Ltd)\b.*$",  # Copyright lines
             r"^All rights reserved.*?$",  # Legal boilerplate
@@ -147,61 +164,183 @@ class LegalDocumentProcessor:
             r"(?i)^(confidential|proprietary|internal use only)",  # Security tags
         ]

         # Join all patterns with the OR operator
-        combined_pattern = "|".join(f"({pattern})" for pattern in self.footer_patterns)

         # Compile the combined pattern
-        self.footer_pattern = re.compile(
             combined_pattern, flags=re.MULTILINE | re.IGNORECASE
         )

-    def remove_footers(self, text: str) -> str:
         """
-        Remove common document footer patterns from text.

         Args:
             text: The input text to process

         Returns:
-            Text with footer patterns removed
         """
-        return self.footer_pattern.sub("", text)

     def clean_text(self, text: str) -> str:
         """
-        Preserve legal citations while cleaning artifacts.

         Args:
             text: The input text to clean

         Returns:
-            Cleaned text with citations preserved
         """
-        # First remove footers
-        text = self.remove_footers(text)
-
-        # Preserve citation patterns
-        # Pattern 1: Footnote numbers (e.g., 98, 99, 100)
-        cleaned = re.sub(r"(?<=\D)(\d{2,3})(?=\D)", r"[\1]", text)

-        # Pattern 2: Case citations [2019] UKSC 20
-        # Already well-structured, so no changes needed
-
-        # Pattern 3: Standardize quotation marks
-        cleaned = cleaned.replace("''", '"').replace("``", '"')
-
-        # Pattern 4: Handle section references (§3.1, §123)
-        cleaned = re.sub(r"§(\d+(\.\d+)?)", r"Section \1", cleaned)
-
-        # Pattern 5: Handle legal abbreviations (e.g., Art. -> Article)
-        cleaned = re.sub(r"\bArt\.\s+(\d+)", r"Article \1", cleaned)
-
-        # Pattern 6: Standardize case names with v. and vs.
-        cleaned = re.sub(r"\bv\s+", r"v. ", cleaned)
-        cleaned = re.sub(r"\bvs\s+", r"v. ", cleaned)
-
-        # Pattern 7: RFC/ISO pattern standardization (RFC 1234, ISO 9001)
-        cleaned = re.sub(r"\b(RFC|ISO)\s*[:#]?\s*(\d+)", r"\1 \2", cleaned)

         return cleaned
@@ -221,11 +360,9 @@ class LegalDocumentProcessor:
             ]
         )

-    def validate_citation_retention(
-        self, documents: List[Document]
-    ) -> Dict[str, float]:
         """
-        Measure semantic similarity of citations before/after text cleaning.

         Args:
             documents: List of Document objects to validate
@@ -234,7 +371,7 @@ class LegalDocumentProcessor:
             Dictionary with validation metrics
         """
         if not documents:
-            return {"citation_retention": 0.0, "processing_time": 0.0}

         start_time = time.time()
@@ -257,16 +394,16 @@ class LegalDocumentProcessor:
             processing_time = time.time() - start_time

             return {
-                "citation_retention": avg_similarity * 100,  # As percentage
                 "processing_time": processing_time,
                 "sample_size": len(original_texts),
             }
         except Exception as e:
-            return {"citation_retention": 0.0, "processing_time": 0.0, "error": str(e)}

     def process_documents(self, documents: List[Document]) -> Dict[str, Any]:
         """
-        Process a list of legal documents.

         Args:
             documents: List of Document objects to process
@@ -291,48 +428,54 @@ class LegalDocumentProcessor:
                 "status": "success",
                 "nodes_count": len(nodes),
                 "documents_count": len(documents),
-                "model_used": self.model_key,
                 "query_engine": query_engine,  # This will be used for querying
             }
         except Exception as e:
             return {"status": "error", "message": str(e)}


-class LegalDocumentTool(Tool):
     """
-    Tool for processing legal documents with specialized models and querying capabilities.
     """

-    name = "legal_document_processor"
     description = (
-        "Processes legal documents with specialized models for legal text, optimizing for "
-        "citation retention, multilingual support, and performance on legal-specific retrieval tasks. "
-        "Can process text or file inputs and provide enhanced query capabilities."
     )
     inputs = {
         "text": {
             "type": "string",
-            "description": "Legal document text to process. Provide either text or file_paths.",
             "optional": True,
         },
         "file_paths": {
             "type": "string",
-            "description": "Comma-separated list of file paths or a directory path containing legal documents. Provide either text or file_paths.",
             "optional": True,
         },
-        "model_key": {
             "type": "string",
-            "description": "Legal embedding model to use. Options: legal-bert, multi-qa-mpnet, legal-xlm-roberta, multilingual-e5, all-mpnet",
-            "default": "legal-xlm-roberta",
         },
         "query": {
             "type": "string",
             "description": "Optional query to run against the processed documents.",
             "optional": True,
         },
-        "validate_citations": {
             "type": "boolean",
-            "description": "Whether to validate citation retention in the processed documents.",
             "default": False,
         },
         "use_gpu": {
@@ -401,25 +544,26 @@ class LegalDocumentTool(Tool):
         # Clean up the temporary file
         os.remove(temp_path)

-    @spaces.GPU
     def forward(
         self,
         text: Optional[str] = None,
         file_paths: Optional[str] = None,
-        model_key: str = "legal-xlm-roberta",
         query: Optional[str] = None,
-        validate_citations: bool = False,
         use_gpu: bool = False,
     ) -> str:
         """
-        Process legal documents and optionally run a query.

         Args:
-            text: Legal document text to process
             file_paths: Comma-separated list of file paths or a directory path
-            model_key: Legal embedding model to use
             query: Optional query to run against the processed documents
-            validate_citations: Whether to validate citation retention
             use_gpu: Whether to use GPU for embeddings

         Returns:
@@ -431,8 +575,9 @@ class LegalDocumentTool(Tool):

         try:
             # Initialize processor
-            processor = LegalDocumentProcessor(
-                model_key=model_key,
                 use_gpu=use_gpu,
             )
@@ -457,10 +602,10 @@ class LegalDocumentTool(Tool):
             if not documents:
                 return "Error: No valid documents found."

-            # Validate citations if requested
             validation_results = {}
-            if validate_citations:
-                validation_results = processor.validate_citation_retention(documents)

             # Process documents
             result = processor.process_documents(documents)
@@ -477,12 +622,13 @@ class LegalDocumentTool(Tool):
                 output = f"Query: {query}\n\nResponse: {response}\n\n"
                 output += f"Documents processed: {result['documents_count']}\n"
                 output += f"Text chunks: {result['nodes_count']}\n"
-                output += f"Model used: {result['model_used']}\n"

                 # Add validation results if available
                 if validation_results:
-                    output += "\n=== Citation Retention Validation ===\n"
-                    output += f"Citation retention: {validation_results.get('citation_retention', 0):.2f}%\n"
                     output += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"

                 return output
@@ -491,12 +637,13 @@ class LegalDocumentTool(Tool):
             output = "Document processing complete.\n\n"
             output += f"Documents processed: {result['documents_count']}\n"
             output += f"Text chunks: {result['nodes_count']}\n"
-            output += f"Model used: {result['model_used']}\n"

             # Add validation results if available
             if validation_results:
-                output += "\n=== Citation Retention Validation ===\n"
-                output += f"Citation retention: {validation_results.get('citation_retention', 0):.2f}%\n"
                 output += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"

             output += "\nThe documents are now ready for querying. Use the 'query' parameter to run a query."
 
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The Footscray Coding Collective. All rights reserved.
 """
+General Document Processing Tool for Smolagents

+This tool processes various types of documents with domain-specific models,
+optimizing for intelligent document parsing, entity extraction, and
+customized retrieval tasks.

+Author: Zhou Wang
 """

 import os
 import re
 import tempfile
+import time
+from typing import Any, Dict, List, Optional, Union
+
 import numpy as np

 # Import Smolagents Tool class
 from smolagents import Tool

 # Import NLP components
 try:
+    from langchain.text_splitter import RecursiveCharacterTextSplitter
+    from llama_index.core import Document, SimpleDirectoryReader, VectorStoreIndex
+    from llama_index.core.ingestion import IngestionPipeline
     from llama_index.core.node_parser import MarkdownNodeParser
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    from sklearn.metrics.pairwise import cosine_similarity
 except ImportError:
     raise ImportError(
         "Required dependencies not found. Please install with: "
     )


+# Model configurations based on domain specialization
+DOMAIN_MODELS = {
+    "legal": {
+        "name": "joelito/legal-xlm-roberta-base",
+        "description": "Specialized for legal documents with citation preservation",
         "max_length": 512,
         "requires_gpu": True,
     },
+    "financial": {
+        "name": "thenlper/finetuned-finbert-slot-filling",
+        "description": "Financial document analysis with entity extraction",
         "max_length": 512,
         "requires_gpu": False,
     },
+    "medical": {
+        "name": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
+        "description": "Medical text processing optimized for clinical terms",
         "max_length": 512,
         "requires_gpu": True,
     },
+    "technical": {
+        "name": "allenai/scibert_scivocab_uncased",
+        "description": "Scientific and technical document processing",
         "max_length": 512,
         "requires_gpu": True,
     },
+    "general": {
         "name": "sentence-transformers/all-mpnet-base-v2",
+        "description": "General purpose embedding model for all document types",
         "max_length": 512,
         "requires_gpu": False,
     },
 }

+class DocumentProcessor:
     """
+    Processor for documents with domain-specific models,
+    entity preservation, and customizable processing capabilities.
     """

     def __init__(
         self,
+        domain: str = "general",
+        model_key: Optional[str] = None,
         use_gpu: bool = False,
         chunk_size: int = 512,
         chunk_overlap: int = 100,
+        custom_patterns: Optional[List[str]] = None,
     ):
         """
+        Initialize the document processor.

         Args:
+            domain: Domain specialization ('legal', 'financial', 'medical', 'technical', 'general')
+            model_key: Specific model to use (overrides domain selection)
             use_gpu: Whether to use GPU for embeddings (if available)
             chunk_size: Size of text chunks for processing
             chunk_overlap: Overlap between chunks to preserve context
+            custom_patterns: Additional regex patterns for text cleaning
         """
+        # Store domain
+        self.domain = domain
+
+        # If model_key provided, use it directly
+        if model_key:
+            model_name = model_key
+            device = "cuda" if use_gpu else "cpu"
+        else:
+            # Otherwise select model based on domain
+            if domain not in DOMAIN_MODELS:
+                print(
+                    f"Warning: Domain '{domain}' not found. Using 'general' as default."
+                )
+                domain = "general"

+            model_config = DOMAIN_MODELS[domain]
+            model_name = model_config["name"]
+            device = "cuda" if use_gpu and model_config["requires_gpu"] else "cpu"

         # Initialize embedding model
+        try:
+            self.embed_model = HuggingFaceEmbedding(
+                model_name=model_name,
+                device=device,
+                tokenizer_kwargs={
+                    "trust_remote_code": True,
+                    "max_length": 512,
+                    "truncation": True,
+                },
+            )

+            # Store model information for reference
+            self.model_name = model_name
+            self.device = device
+        except Exception as e:
+            raise RuntimeError(f"Failed to initialize embedding model: {str(e)}")

+        # Domain-optimized text splitter
         self.splitter = RecursiveCharacterTextSplitter(
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap,
             ],
         )

+        # Base cleaning patterns
+        self.cleaning_patterns = [
             r"^Page\s\d+(\s+of\s+\d+)?$",  # Page numbers
             r"^©.*\b(Company|Inc|Ltd)\b.*$",  # Copyright lines
             r"^All rights reserved.*?$",  # Legal boilerplate
             r"(?i)^(confidential|proprietary|internal use only)",  # Security tags
         ]

+        # Add custom patterns if provided
+        if custom_patterns:
+            self.cleaning_patterns.extend(custom_patterns)
+
         # Join all patterns with the OR operator
+        combined_pattern = "|".join(
+            f"({pattern})" for pattern in self.cleaning_patterns
+        )

         # Compile the combined pattern
+        self.cleaning_pattern = re.compile(
             combined_pattern, flags=re.MULTILINE | re.IGNORECASE
         )

+        # Initialize domain-specific processors
+        self._init_domain_processors()
+
+    def _init_domain_processors(self):
+        """Initialize domain-specific processors based on selected domain."""
+        # Domain-specific entity patterns
+        self.entity_patterns = {}
+
+        # Set up domain-specific patterns and processors
+        if self.domain == "legal":
+            self.entity_patterns = {
+                "case_citation": r"\[\d{4}\]\s+[A-Z]+\s+\d+",  # [2019] UKSC 20
+                "statute": r"\b(?:Art\.|Section)\s+\d+(\.\d+)?",  # Art. 5, Section 3.1
+                "legal_ref": r"\b[A-Za-z]+\s+v\.?\s+[A-Za-z]+",  # Smith v. Jones
+            }
+            self.process_entities = self._process_legal_entities
+
+        elif self.domain == "financial":
+            self.entity_patterns = {
+                "monetary": r"\$\s*\d+(?:\.\d+)?(?:\s*(?:million|billion|trillion))?",  # $5.2 million
+                "percentage": r"\d+(?:\.\d+)?\s*%",  # 10.5%
+                "date_range": r"(?:Q[1-4]|FY)\s+\d{4}",  # Q2 2023, FY 2022
+            }
+            self.process_entities = self._process_financial_entities
+
+        elif self.domain == "medical":
+            self.entity_patterns = {
+                "dosage": r"\d+(?:\.\d+)?\s*(?:mg|mcg|g|ml|oz)",  # 10mg, 5.5ml
+                "medical_code": r"[A-Z]\d{2}(?:\.\d+)?",  # ICD codes like E11.9
+                "vital_sign": r"\d+(?:\.\d+)?\s*(?:bpm|mmHg|°[CF])",  # 120 bpm, 98.6°F
+            }
+            self.process_entities = self._process_medical_entities
+
+        elif self.domain == "technical":
+            self.entity_patterns = {
+                "version": r"v\d+(?:\.\d+){1,3}",  # v1.2.3
+                "code_ref": r"(?:\w+\.)+\w+\(\)",  # function calls like math.sqrt()
+                "tech_standard": r"(?:RFC|ISO|IEEE)\s*\d+",  # RFC 1918, ISO 9001
+            }
+            self.process_entities = self._process_technical_entities
+
+        else:  # General domain or fallback
+            self.entity_patterns = {
+                "url": r"https?://\S+",  # URLs
+                "email": r"\S+@\S+\.\S+",  # Email addresses
+                "date": r"\d{1,2}[/-]\d{1,2}[/-]\d{2,4}",  # Dates
+            }
+            self.process_entities = self._process_general_entities
+
+    def _process_legal_entities(self, text: str) -> str:
+        """Process legal document entities."""
+        # Preserve citation patterns
+        # Pattern 1: Case citations [2019] UKSC 20
+        # Already well-structured, so no changes needed
+
+        # Pattern 2: Standardize section references (§3.1, §123)
+        processed = re.sub(r"§(\d+(\.\d+)?)", r"Section \1", text)
+
+        # Pattern 3: Handle legal abbreviations (e.g., Art. -> Article)
+        processed = re.sub(r"\bArt\.\s+(\d+)", r"Article \1", processed)
+
+        # Pattern 4: Standardize case names with v. and vs.
+        processed = re.sub(r"\bv\s+", r"v. ", processed)
+        processed = re.sub(r"\bvs\s+", r"v. ", processed)
+
+        return processed
+
+    def _process_financial_entities(self, text: str) -> str:
+        """Process financial document entities."""
+        # Pattern 1: Standardize monetary values (capture the full amount,
+        # commas included, so they can be stripped before conversion)
+        processed = re.sub(
+            r"\$\s*(\d+(?:,\d{3})*(?:\.\d+)?)",
+            lambda m: f"${float(m.group(1).replace(',', ''))}",
+            text,
+        )
+
+        # Pattern 2: Standardize percentage representations
+        processed = re.sub(r"(\d+(?:\.\d+)?)\s*(?:percent|pct)", r"\1%", processed)
+
+        # Pattern 3: Standardize fiscal periods
+        processed = re.sub(r"(?:fiscal year|FY)\s+(\d{4})", r"FY \1", processed)
+
+        # Pattern 4: Standardize quarterly references
+        processed = re.sub(r"(?:quarter|Q)(\d)\s+(\d{4})", r"Q\1 \2", processed)
+
+        return processed
+
+    def _process_medical_entities(self, text: str) -> str:
+        """Process medical document entities."""
+        # Pattern 1: Standardize dosage units to their abbreviations
+        unit_map = {"milligram": "mg", "mcg": "mcg", "gram": "g", "milliliter": "ml"}
+        processed = re.sub(
+            r"(\d+(?:\.\d+)?)\s*(milligrams?|mcgs?|grams?|milliliters?)",
+            lambda m: f"{m.group(1)} {unit_map[m.group(2).rstrip('s')]}",
+            text,
+        )
+
+        # Pattern 2: Standardize temperature format
+        processed = re.sub(r"(\d+(?:\.\d+)?)\s*degrees?\s*([CF])", r"\1°\2", processed)
+
+        # Pattern 3: Standardize vital signs
+        processed = re.sub(
+            r"(\d+(?:\.\d+)?)\s*(?:beats per minute|BPM)", r"\1 bpm", processed
+        )
+
+        return processed
+
+    def _process_technical_entities(self, text: str) -> str:
+        """Process technical document entities."""
+        # Pattern 1: Standardize version numbers
+        processed = re.sub(r"version\s+(\d+(?:\.\d+){1,3})", r"v\1", text)
+
+        # Pattern 2: RFC/ISO pattern standardization
+        processed = re.sub(r"\b(RFC|ISO|IEEE)\s*[:#]?\s*(\d+)", r"\1 \2", processed)
+
+        # Pattern 3: Standardize code references
+        # This is a simplified example
+        processed = re.sub(r"function\s+(\w+)\s*\(", r"\1(", processed)
+
+        return processed
+
+    def _process_general_entities(self, text: str) -> str:
+        """Process general document entities."""
+        # General cleaning and standardization
+        processed = text
+
+        # URLs preserved as-is
+
+        # Simple date standardization
+        processed = re.sub(
+            r"(\d{1,2})/(\d{1,2})/(\d{2})(?!\d)",
+            r"\1/\2/20\3",  # Assume 2-digit years are 2000s
+            processed,
+        )
+
+        return processed
+
+    def remove_boilerplate(self, text: str) -> str:
         """
+        Remove common document boilerplate patterns from text.

         Args:
             text: The input text to process

         Returns:
+            Text with boilerplate patterns removed
         """
+        return self.cleaning_pattern.sub("", text)

     def clean_text(self, text: str) -> str:
         """
+        Clean text while preserving domain-specific entities.

         Args:
             text: The input text to clean

         Returns:
+            Cleaned text with domain entities preserved
         """
+        # First remove boilerplate
+        cleaned = self.remove_boilerplate(text)

+        # Then process domain-specific entities
+        cleaned = self.process_entities(cleaned)

         return cleaned

             ]
         )

+    def validate_entity_retention(self, documents: List[Document]) -> Dict[str, float]:
         """
+        Measure semantic similarity of entities before/after text cleaning.

         Args:
             documents: List of Document objects to validate
             Dictionary with validation metrics
         """
         if not documents:
+            return {"entity_retention": 0.0, "processing_time": 0.0}

         start_time = time.time()

             processing_time = time.time() - start_time

             return {
+                "entity_retention": avg_similarity * 100,  # As percentage
                 "processing_time": processing_time,
                 "sample_size": len(original_texts),
             }
         except Exception as e:
+            return {"entity_retention": 0.0, "processing_time": 0.0, "error": str(e)}

     def process_documents(self, documents: List[Document]) -> Dict[str, Any]:
         """
+        Process a list of documents.

         Args:
             documents: List of Document objects to process
                 "status": "success",
                 "nodes_count": len(nodes),
                 "documents_count": len(documents),
+                "domain": self.domain,
+                "model_name": self.model_name,
                 "query_engine": query_engine,  # This will be used for querying
             }
         except Exception as e:
             return {"status": "error", "message": str(e)}


+class DocumentProcessorTool(Tool):
     """
+    General-purpose document processing tool with domain specialization.
     """

+    name = "document_processor"
     description = (
+        "Processes documents with domain-specific models optimized for "
+        "entity preservation and retrieval performance. Supports legal, "
+        "financial, medical, technical and general document types."
     )
     inputs = {
         "text": {
             "type": "string",
+            "description": "Document text to process. Provide either text or file_paths.",
             "optional": True,
         },
         "file_paths": {
             "type": "string",
+            "description": "Comma-separated list of file paths or a directory path containing documents. Provide either text or file_paths.",
             "optional": True,
         },
+        "domain": {
+            "type": "string",
+            "description": "Document domain for specialized processing: legal, financial, medical, technical, or general.",
+            "default": "general",
+        },
+        "model_name": {
             "type": "string",
+            "description": "Specific embedding model name to use (optional, overrides domain selection).",
+            "optional": True,
         },
         "query": {
             "type": "string",
             "description": "Optional query to run against the processed documents.",
             "optional": True,
         },
+        "validate_entities": {
             "type": "boolean",
+            "description": "Whether to validate entity retention in the processed documents.",
             "default": False,
         },
         "use_gpu": {
 
         # Clean up the temporary file
         os.remove(temp_path)

     def forward(
         self,
         text: Optional[str] = None,
         file_paths: Optional[str] = None,
+        domain: str = "general",
+        model_name: Optional[str] = None,
         query: Optional[str] = None,
+        validate_entities: bool = False,
         use_gpu: bool = False,
     ) -> str:
         """
+        Process documents and optionally run a query.

         Args:
+            text: Document text to process
             file_paths: Comma-separated list of file paths or a directory path
+            domain: Document domain specialization
+            model_name: Specific embedding model to use
             query: Optional query to run against the processed documents
+            validate_entities: Whether to validate entity retention
             use_gpu: Whether to use GPU for embeddings

         Returns:

         try:
             # Initialize processor
+            processor = DocumentProcessor(
+                domain=domain,
+                model_key=model_name,
                 use_gpu=use_gpu,
             )

             if not documents:
                 return "Error: No valid documents found."

+            # Validate entity retention if requested
             validation_results = {}
+            if validate_entities:
+                validation_results = processor.validate_entity_retention(documents)

             # Process documents
             result = processor.process_documents(documents)

                 output = f"Query: {query}\n\nResponse: {response}\n\n"
                 output += f"Documents processed: {result['documents_count']}\n"
                 output += f"Text chunks: {result['nodes_count']}\n"
+                output += f"Domain: {result['domain']}\n"
+                output += f"Model: {result['model_name']}\n"

                 # Add validation results if available
                 if validation_results:
+                    output += "\n=== Entity Retention Validation ===\n"
+                    output += f"Entity retention: {validation_results.get('entity_retention', 0):.2f}%\n"
                     output += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"

                 return output

             output = "Document processing complete.\n\n"
             output += f"Documents processed: {result['documents_count']}\n"
             output += f"Text chunks: {result['nodes_count']}\n"
+            output += f"Domain: {result['domain']}\n"
+            output += f"Model: {result['model_name']}\n"

             # Add validation results if available
             if validation_results:
+                output += "\n=== Entity Retention Validation ===\n"
+                output += f"Entity retention: {validation_results.get('entity_retention', 0):.2f}%\n"
                 output += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"

             output += "\nThe documents are now ready for querying. Use the 'query' parameter to run a query."