Leonardo committed on
Commit 24730c9 · verified · 1 Parent(s): fdb59f7

Delete scripts/document_tool.py

Files changed (1)
  1. scripts/document_tool.py +0 -654
scripts/document_tool.py DELETED
@@ -1,654 +0,0 @@
- #!/usr/bin/env python
- # coding=utf-8
- # Copyright 2025 The Footscray Coding Collective. All rights reserved.
- """
- General Document Processing Tool for Smolagents
-
- This tool processes various types of documents with domain-specific models,
- optimizing for intelligent document parsing, entity extraction, and
- customized retrieval tasks.
-
- Author: Zhou Wang
- """
-
- import os
- import re
- import tempfile
- import time
- from typing import Any, Dict, List, Optional
-
- import numpy as np
-
- # Import Smolagents Tool class
- from smolagents import Tool
-
- # Import NLP components
- try:
-     from langchain.text_splitter import RecursiveCharacterTextSplitter
-     from llama_index.core import Document, SimpleDirectoryReader, VectorStoreIndex
-     from llama_index.core.ingestion import IngestionPipeline
-     from llama_index.core.node_parser import LangchainNodeParser, MarkdownNodeParser
-     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-     from sklearn.metrics.pairwise import cosine_similarity
- except ImportError:
-     raise ImportError(
-         "Required dependencies not found. Please install with: "
-         "pip install llama-index llama-index-embeddings-huggingface "
-         "langchain scikit-learn"
-     )
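-
- # Illustrative agent-side usage (a sketch; the agent model and file path
- # below are placeholders, not part of this module):
- #
- #     from smolagents import CodeAgent
- #     agent = CodeAgent(tools=[DocumentProcessorTool()], model=my_model)
- #     agent.run("Summarize the key clauses in ./contracts/lease.md")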
-
-
- # Model configurations based on domain specialization
- DOMAIN_MODELS = {
-     "legal": {
-         "name": "joelito/legal-xlm-roberta-base",
-         "description": "Specialized for legal documents with citation preservation",
-         "max_length": 512,
-         "requires_gpu": True,
-     },
-     "financial": {
-         "name": "thenlper/finetuned-finbert-slot-filling",
-         "description": "Financial document analysis with entity extraction",
-         "max_length": 512,
-         "requires_gpu": False,
-     },
-     "medical": {
-         "name": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
-         "description": "Medical text processing optimized for clinical terms",
-         "max_length": 512,
-         "requires_gpu": True,
-     },
-     "technical": {
-         "name": "allenai/scibert_scivocab_uncased",
-         "description": "Scientific and technical document processing",
-         "max_length": 512,
-         "requires_gpu": True,
-     },
-     "general": {
-         "name": "sentence-transformers/all-mpnet-base-v2",
-         "description": "General purpose embedding model for all document types",
-         "max_length": 512,
-         "requires_gpu": False,
-     },
- }
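-
- # e.g. (illustrative lookups):
- #     DOMAIN_MODELS["legal"]["name"]        -> "joelito/legal-xlm-roberta-base"
- #     DOMAIN_MODELS["general"]["requires_gpu"] -> False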
-
-
- class DocumentProcessor:
-     """
-     Processor for documents with domain-specific models,
-     entity preservation, and customizable processing capabilities.
-     """
-
-     def __init__(
-         self,
-         domain: str = "general",
-         model_key: Optional[str] = None,
-         use_gpu: bool = False,
-         chunk_size: int = 512,
-         chunk_overlap: int = 100,
-         custom_patterns: Optional[List[str]] = None,
-     ):
-         """
-         Initialize the document processor.
-
-         Args:
-             domain: Domain specialization ('legal', 'financial', 'medical', 'technical', 'general')
-             model_key: Specific model to use (overrides domain selection)
-             use_gpu: Whether to use GPU for embeddings (if available)
-             chunk_size: Size of text chunks for processing
-             chunk_overlap: Overlap between chunks to preserve context
-             custom_patterns: Additional regex patterns for text cleaning
-         """
-         # If model_key is provided, use it directly
-         if model_key:
-             model_name = model_key
-             device = "cuda" if use_gpu else "cpu"
-         else:
-             # Otherwise select the model based on domain
-             if domain not in DOMAIN_MODELS:
-                 print(
-                     f"Warning: Domain '{domain}' not found. Using 'general' as default."
-                 )
-                 domain = "general"
-
-             model_config = DOMAIN_MODELS[domain]
-             model_name = model_config["name"]
-             device = "cuda" if use_gpu and model_config["requires_gpu"] else "cpu"
-
-         # Store the domain after the fallback above, so self.domain never
-         # holds an unknown domain name
-         self.domain = domain
-
-         # Initialize embedding model
-         try:
-             self.embed_model = HuggingFaceEmbedding(
-                 model_name=model_name,
-                 device=device,
-                 tokenizer_kwargs={
-                     "trust_remote_code": True,
-                     "max_length": 512,
-                     "truncation": True,
-                 },
-             )
-
-             # Store model information for reference
-             self.model_name = model_name
-             self.device = device
-         except Exception as e:
-             raise RuntimeError(f"Failed to initialize embedding model: {str(e)}")
-
-         # Domain-optimized text splitter
-         self.splitter = RecursiveCharacterTextSplitter(
-             chunk_size=chunk_size,
-             chunk_overlap=chunk_overlap,
-             separators=[
-                 "\n## ", "\n### ", "\n#### ",  # Headers
-                 "\n\n", "\n",  # Paragraphs
-                 ". ", "! ", "? ",  # Sentences
-                 ";", ":",  # Clause boundaries
-                 " ",  # Last resort
-             ],
-         )
-
-         # Base cleaning patterns
-         self.cleaning_patterns = [
-             r"^Page\s\d+(\s+of\s+\d+)?$",  # Page numbers
-             r"^©.*\b(Company|Inc|Ltd)\b.*$",  # Copyright lines
-             r"^All rights reserved.*?$",  # Legal boilerplate
-             r"^-+$",  # Separator lines
-             r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}(:\d{2})?$",  # Timestamps
-             r"(?i)^(confidential|proprietary|internal use only)",  # Security tags
-         ]
-
-         # Add custom patterns if provided
-         if custom_patterns:
-             self.cleaning_patterns.extend(custom_patterns)
-
-         # Join all patterns with the OR operator and compile once
-         combined_pattern = "|".join(
-             f"({pattern})" for pattern in self.cleaning_patterns
-         )
-         self.cleaning_pattern = re.compile(
-             combined_pattern, flags=re.MULTILINE | re.IGNORECASE
-         )
-
-         # Initialize domain-specific processors
-         self._init_domain_processors()
-
-     def _init_domain_processors(self):
-         """Initialize domain-specific processors based on the selected domain."""
-         # Domain-specific entity patterns
-         self.entity_patterns = {}
-
-         # Set up domain-specific patterns and processors. These must be
-         # elif branches: with independent `if` statements, the trailing
-         # `else` (paired only with the last `if`) would overwrite the
-         # legal, financial, and medical settings.
-         if self.domain == "legal":
-             self.entity_patterns = {
-                 "case_citation": r"\[\d{4}\]\s+[A-Z]+\s+\d+",  # [2019] UKSC 20
-                 "statute": r"\b(?:Art\.|Section)\s+\d+(\.\d+)?",  # Art. 5, Section 3.1
-                 "legal_ref": r"\b[A-Za-z]+\s+v\.?\s+[A-Za-z]+",  # Smith v. Jones
-             }
-             self.process_entities = self._process_legal_entities
-
-         elif self.domain == "financial":
-             self.entity_patterns = {
-                 "monetary": r"\$\s*\d+(?:\.\d+)?(?:\s*(?:million|billion|trillion))?",  # $5.2 million
-                 "percentage": r"\d+(?:\.\d+)?\s*%",  # 10.5%
-                 "date_range": r"(?:Q[1-4]|FY)\s+\d{4}",  # Q2 2023, FY 2022
-             }
-             self.process_entities = self._process_financial_entities
-
-         elif self.domain == "medical":
-             self.entity_patterns = {
-                 "dosage": r"\d+(?:\.\d+)?\s*(?:mg|mcg|g|ml|oz)",  # 10mg, 5.5ml
-                 "medical_code": r"[A-Z]\d{2}(?:\.\d+)?",  # ICD codes like E11.9
-                 "vital_sign": r"\d+(?:\.\d+)?\s*(?:bpm|mmHg|°[CF])",  # 120 bpm, 98.6°F
-             }
-             self.process_entities = self._process_medical_entities
-
-         elif self.domain == "technical":
-             self.entity_patterns = {
-                 "version": r"v\d+(?:\.\d+){1,3}",  # v1.2.3
-                 "code_ref": r"(?:\w+\.)+\w+\(\)",  # function calls like math.sqrt()
-                 "tech_standard": r"(?:RFC|ISO|IEEE)\s*\d+",  # RFC 1918, ISO 9001
-             }
-             self.process_entities = self._process_technical_entities
-
-         else:  # General domain or fallback
-             self.entity_patterns = {
-                 "url": r"https?://\S+",  # URLs
-                 "email": r"\S+@\S+\.\S+",  # Email addresses
-                 "date": r"\d{1,2}[/-]\d{1,2}[/-]\d{2,4}",  # Dates
-             }
-             self.process_entities = self._process_general_entities
-
-     def _process_legal_entities(self, text: str) -> str:
-         """Process legal document entities."""
-         # Pattern 1: Case citations like [2019] UKSC 20 are already
-         # well-structured, so no changes are needed
-
-         # Pattern 2: Standardize section references (§3.1, §123)
-         processed = re.sub(r"§(\d+(\.\d+)?)", r"Section \1", text)
-
-         # Pattern 3: Expand legal abbreviations (e.g., Art. -> Article)
-         processed = re.sub(r"\bArt\.\s+(\d+)", r"Article \1", processed)
-
-         # Pattern 4: Standardize case names with v. and vs.
-         processed = re.sub(r"\bv\s+", r"v. ", processed)
-         processed = re.sub(r"\bvs\s+", r"v. ", processed)
-
-         return processed
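-
-     # e.g. (illustrative): _process_legal_entities("Under Art. 5 and §3.1, see Smith vs Jones")
-     #     -> "Under Article 5 and Section 3.1, see Smith v. Jones"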
-
-     def _process_financial_entities(self, text: str) -> str:
-         """Process financial document entities."""
-         # Pattern 1: Standardize monetary values. The whole amount is
-         # captured (including thousands separators and decimals) so the
-         # commas can be stripped without dropping digits.
-         processed = re.sub(
-             r"\$\s*(\d+(?:,\d{3})*(?:\.\d+)?)",
-             lambda m: f"${float(m.group(1).replace(',', ''))}",
-             text,
-         )
-
-         # Pattern 2: Standardize percentage representations
-         processed = re.sub(r"(\d+(?:\.\d+)?)\s*(?:percent|pct)", r"\1%", processed)
-
-         # Pattern 3: Standardize fiscal periods
-         processed = re.sub(r"(?:fiscal year|FY)\s+(\d{4})", r"FY \1", processed)
-
-         # Pattern 4: Standardize quarterly references
-         processed = re.sub(r"(?:quarter|Q)(\d)\s+(\d{4})", r"Q\1 \2", processed)
-
-         return processed
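-
-     # e.g. (illustrative): _process_financial_entities("revenue of $1,250,000 grew 10 percent in fiscal year 2022")
-     #     -> "revenue of $1250000.0 grew 10% in FY 2022"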
-
-     def _process_medical_entities(self, text: str) -> str:
-         """Process medical document entities."""
-         # Pattern 1: Standardize dosage units to their short forms via an
-         # explicit unit map (the first two letters of a spelled-out unit
-         # are not a reliable abbreviation: "milligrams" would give "mi")
-         unit_map = {"milligram": "mg", "mcg": "mcg", "gram": "g", "milliliter": "ml"}
-         processed = re.sub(
-             r"(\d+(?:\.\d+)?)\s*(milligrams?|mcgs?|grams?|milliliters?)",
-             lambda m: f"{m.group(1)} {unit_map[m.group(2).rstrip('s')]}",
-             text,
-         )
-
-         # Pattern 2: Standardize temperature format
-         processed = re.sub(r"(\d+(?:\.\d+)?)\s*degrees?\s*([CF])", r"\1°\2", processed)
-
-         # Pattern 3: Standardize vital signs
-         processed = re.sub(
-             r"(\d+(?:\.\d+)?)\s*(?:beats per minute|BPM)", r"\1 bpm", processed
-         )
-
-         return processed
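-
-     # e.g. (illustrative): _process_medical_entities("500 milligrams at 98.6 degrees F, 72 beats per minute")
-     #     -> "500 mg at 98.6°F, 72 bpm"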
-
-     def _process_technical_entities(self, text: str) -> str:
-         """Process technical document entities."""
-         # Pattern 1: Standardize version numbers
-         processed = re.sub(r"version\s+(\d+(?:\.\d+){1,3})", r"v\1", text)
-
-         # Pattern 2: RFC/ISO/IEEE standard reference standardization
-         processed = re.sub(r"\b(RFC|ISO|IEEE)\s*[:#]?\s*(\d+)", r"\1 \2", processed)
-
-         # Pattern 3: Standardize code references (a simplified example)
-         processed = re.sub(r"function\s+(\w+)\s*\(", r"\1(", processed)
-
-         return processed
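-
-     # e.g. (illustrative): _process_technical_entities("version 1.2.3 implements RFC: 1918")
-     #     -> "v1.2.3 implements RFC 1918"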
-
-     def _process_general_entities(self, text: str) -> str:
-         """Process general document entities."""
-         # General cleaning and standardization; URLs are preserved as-is
-         processed = text
-
-         # Simple date standardization
-         processed = re.sub(
-             r"(\d{1,2})/(\d{1,2})/(\d{2})(?!\d)",
-             r"\1/\2/20\3",  # Assume 2-digit years are 2000s
-             processed,
-         )
-
-         return processed
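-
-     # e.g. (illustrative): _process_general_entities("signed on 5/6/23")
-     #     -> "signed on 5/6/2023"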
-
-     def remove_boilerplate(self, text: str) -> str:
-         """
-         Remove common document boilerplate patterns from text.
-
-         Args:
-             text: The input text to process
-
-         Returns:
-             Text with boilerplate patterns removed
-         """
-         return self.cleaning_pattern.sub("", text)
-
-     def clean_text(self, text: str) -> str:
-         """
-         Clean text while preserving domain-specific entities.
-
-         Args:
-             text: The input text to clean
-
-         Returns:
-             Cleaned text with domain entities preserved
-         """
-         # First remove boilerplate, then process domain-specific entities
-         cleaned = self.remove_boilerplate(text)
-         cleaned = self.process_entities(cleaned)
-         return cleaned
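-
-     # e.g. (illustrative): remove_boilerplate("Page 3 of 10\nThe parties agree...")
-     # strips the matched "Page 3 of 10" line, leaving "\nThe parties agree..."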
-
-     def create_pipeline(self) -> IngestionPipeline:
-         """
-         Create a document processing pipeline.
-
-         Returns:
-             Configured IngestionPipeline object
-         """
-         # IngestionPipeline transformations must be llama_index transform
-         # components, so the LangChain splitter is wrapped in a
-         # LangchainNodeParser; plain-text cleaning (clean_text) is applied
-         # to documents before they enter the pipeline (see process_documents)
-         return IngestionPipeline(
-             transformations=[
-                 MarkdownNodeParser(),
-                 LangchainNodeParser(self.splitter),
-                 self.embed_model,
-             ]
-         )
-
-     def validate_entity_retention(self, documents: List[Document]) -> Dict[str, float]:
-         """
-         Measure semantic similarity of entities before/after text cleaning.
-
-         Args:
-             documents: List of Document objects to validate
-
-         Returns:
-             Dictionary with validation metrics
-         """
-         if not documents:
-             return {"entity_retention": 0.0, "processing_time": 0.0}
-
-         start_time = time.time()
-
-         # Extract original texts (sampled for performance)
-         original_texts = [doc.text for doc in documents[:5]]
-
-         # Apply cleaning
-         processed_texts = [self.clean_text(text) for text in original_texts]
-
-         # Calculate embeddings via the public embedding API rather than
-         # reaching into the private ._model attribute
-         try:
-             orig_embeds = np.array(self.embed_model.get_text_embedding_batch(original_texts))
-             proc_embeds = np.array(self.embed_model.get_text_embedding_batch(processed_texts))
-
-             # Pairwise similarity of each original text with its cleaned version
-             similarities = cosine_similarity(orig_embeds, proc_embeds).diagonal()
-             avg_similarity = float(np.mean(similarities))
-
-             processing_time = time.time() - start_time
-
-             return {
-                 "entity_retention": avg_similarity * 100,  # As percentage
-                 "processing_time": processing_time,
-                 "sample_size": len(original_texts),
-             }
-         except Exception as e:
-             return {"entity_retention": 0.0, "processing_time": 0.0, "error": str(e)}
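-
-     # Illustrative return value (the numbers are hypothetical):
-     #     {"entity_retention": 97.3, "processing_time": 0.42, "sample_size": 5}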
-
-     def process_documents(self, documents: List[Document]) -> Dict[str, Any]:
-         """
-         Process a list of documents.
-
-         Args:
-             documents: List of Document objects to process
-
-         Returns:
-             Dictionary with processing results and stats
-         """
-         if not documents:
-             return {"status": "error", "message": "No documents provided"}
-
-         try:
-             # Clean document text up front (cleaning is not part of the
-             # ingestion pipeline; see create_pipeline)
-             documents = [
-                 Document(text=self.clean_text(doc.text), metadata=doc.metadata)
-                 for doc in documents
-             ]
-
-             # Create pipeline and process documents
-             pipeline = self.create_pipeline()
-             nodes = pipeline.run(documents=documents)
-
-             # Create vector index with the same embedding model so queries
-             # are embedded consistently
-             index = VectorStoreIndex(nodes, embed_model=self.embed_model)
-             query_engine = index.as_query_engine()
-
-             # Return success with stats
-             return {
-                 "status": "success",
-                 "nodes_count": len(nodes),
-                 "documents_count": len(documents),
-                 "domain": self.domain,
-                 "model_name": self.model_name,
-                 "query_engine": query_engine,  # Used for querying
-             }
-         except Exception as e:
-             return {"status": "error", "message": str(e)}
-
-
- class DocumentProcessorTool(Tool):
-     """
-     General-purpose document processing tool with domain specialization.
-     """
-
-     name = "document_processor"
-     description = (
-         "Processes documents with domain-specific models optimized for "
-         "entity preservation and retrieval performance. Supports legal, "
-         "financial, medical, technical and general document types."
-     )
-     # smolagents marks optional inputs with "nullable"; default values live
-     # in forward()'s signature
-     inputs = {
-         "text": {
-             "type": "string",
-             "description": "Document text to process. Provide either text or file_paths.",
-             "nullable": True,
-         },
-         "file_paths": {
-             "type": "string",
-             "description": "Comma-separated list of file paths or a directory path containing documents. Provide either text or file_paths.",
-             "nullable": True,
-         },
-         "domain": {
-             "type": "string",
-             "description": "Document domain for specialized processing: legal, financial, medical, technical, or general. Defaults to 'general'.",
-             "nullable": True,
-         },
-         "model_name": {
-             "type": "string",
-             "description": "Specific embedding model name to use (optional, overrides domain selection).",
-             "nullable": True,
-         },
-         "query": {
-             "type": "string",
-             "description": "Optional query to run against the processed documents.",
-             "nullable": True,
-         },
-         "validate_entities": {
-             "type": "boolean",
-             "description": "Whether to validate entity retention in the processed documents. Defaults to False.",
-             "nullable": True,
-         },
-         "use_gpu": {
-             "type": "boolean",
-             "description": "Whether to use GPU for embedding calculations if available. Defaults to False.",
-             "nullable": True,
-         },
-     }
-     output_type = "string"
-
-     def _load_documents(self, input_path: str) -> List[Document]:
-         """
-         Load documents from a file path or directory.
-
-         Args:
-             input_path: Path to a file or directory
-
-         Returns:
-             List of Document objects
-         """
-         if os.path.isfile(input_path):
-             # Load exactly this file; filtering a directory read by
-             # extension would also pull in every sibling file that happens
-             # to share the extension
-             return SimpleDirectoryReader(
-                 input_files=[input_path],
-                 filename_as_id=True,
-             ).load_data()
-
-         elif os.path.isdir(input_path):
-             return SimpleDirectoryReader(
-                 input_dir=input_path,
-                 filename_as_id=True,
-             ).load_data()
-
-         else:
-             raise ValueError(f"Path not found: {input_path}")
-
-     def _create_document_from_text(self, text: str) -> List[Document]:
-         """
-         Create a Document object from text.
-
-         Args:
-             text: Text content
-
-         Returns:
-             List containing a single Document object
-         """
-         # Write the text to a temporary markdown file so it can be loaded
-         # through the same reader path as file inputs
-         with tempfile.NamedTemporaryFile(
-             mode="w", suffix=".md", delete=False, encoding="utf-8"
-         ) as temp_file:
-             temp_file.write(text)
-             temp_path = temp_file.name
-
-         try:
-             # Load the document from the temporary file
-             return self._load_documents(temp_path)
-         finally:
-             # Clean up the temporary file
-             os.remove(temp_path)
-
-     def forward(
-         self,
-         text: Optional[str] = None,
-         file_paths: Optional[str] = None,
-         domain: str = "general",
-         model_name: Optional[str] = None,
-         query: Optional[str] = None,
-         validate_entities: bool = False,
-         use_gpu: bool = False,
-     ) -> str:
-         """
-         Process documents and optionally run a query.
-
-         Args:
-             text: Document text to process
-             file_paths: Comma-separated list of file paths or a directory path
-             domain: Document domain specialization
-             model_name: Specific embedding model to use
-             query: Optional query to run against the processed documents
-             validate_entities: Whether to validate entity retention
-             use_gpu: Whether to use GPU for embeddings
-
-         Returns:
-             Processing results or query response as a string
-         """
-         # Validate inputs
-         if not text and not file_paths:
-             return "Error: Either text or file_paths must be provided."
-
-         try:
-             # Initialize processor
-             processor = DocumentProcessor(
-                 domain=domain,
-                 model_key=model_name,
-                 use_gpu=use_gpu,
-             )
-
-             # Load documents
-             documents = []
-
-             if text:
-                 documents.extend(self._create_document_from_text(text))
-
-             if file_paths:
-                 # Handle comma-separated paths
-                 paths = [path.strip() for path in file_paths.split(",")]
-                 for path in paths:
-                     try:
-                         documents.extend(self._load_documents(path))
-                     except Exception as e:
-                         return f"Error loading documents from {path}: {str(e)}"
-
-             # Check if we have documents to process
-             if not documents:
-                 return "Error: No valid documents found."
-
-             # Validate entity retention if requested
-             validation_results = {}
-             if validate_entities:
-                 validation_results = processor.validate_entity_retention(documents)
-
-             # Process documents
-             result = processor.process_documents(documents)
-             if result["status"] != "success":
-                 return f"Processing error: {result['message']}"
-
-             # Shared processing stats, used by both output paths below
-             stats = (
-                 f"Documents processed: {result['documents_count']}\n"
-                 f"Text chunks: {result['nodes_count']}\n"
-                 f"Domain: {result['domain']}\n"
-                 f"Model: {result['model_name']}\n"
-             )
-             if validation_results:
-                 stats += "\n=== Entity Retention Validation ===\n"
-                 stats += f"Entity retention: {validation_results.get('entity_retention', 0):.2f}%\n"
-                 stats += f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"
-
-             # Run query if provided
-             if query and "query_engine" in result:
-                 response = result["query_engine"].query(query)
-                 return f"Query: {query}\n\nResponse: {response}\n\n{stats}"
-
-             # If no query, return processing stats
-             output = "Document processing complete.\n\n" + stats
-             output += "\nThe documents are now ready for querying. Use the 'query' parameter to run a query."
-             return output
-
-         except Exception as e:
-             return f"Error: {str(e)}"
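-
-
- # Illustrative direct invocation (a sketch; the file path and query are
- # hypothetical, and answering the query requires an LLM configured in
- # llama_index Settings):
- #
- #     tool = DocumentProcessorTool()
- #     print(tool.forward(file_paths="./contracts/lease.md", domain="legal",
- #                        query="What is the notice period?"))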