#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The Footscray Coding Collective. All rights reserved.
"""
General Document Processing Tool for Smolagents

This tool processes various types of documents with domain-specific models,
optimizing for intelligent document parsing, entity extraction, and
customized retrieval tasks.

Author: Zhou Wang
"""

import os
import re
import tempfile
import time
from typing import Any, Dict, List, Optional

import numpy as np

# Import Smolagents Tool class
from smolagents import Tool

# Import NLP components
try:
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from llama_index.core import Document, SimpleDirectoryReader, VectorStoreIndex
    from llama_index.core.ingestion import IngestionPipeline
    from llama_index.core.node_parser import LangchainNodeParser, MarkdownNodeParser
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    from sklearn.metrics.pairwise import cosine_similarity
except ImportError:
    raise ImportError(
        "Required dependencies not found. Please install with: "
        "pip install llama-index langchain scikit-learn"
    )

# Model configurations based on domain specialization
DOMAIN_MODELS = {
    "legal": {
        "name": "joelito/legal-xlm-roberta-base",
        "description": "Specialized for legal documents with citation preservation",
        "max_length": 512,
        "requires_gpu": True,
    },
    "financial": {
        "name": "thenlper/finetuned-finbert-slot-filling",
        "description": "Financial document analysis with entity extraction",
        "max_length": 512,
        "requires_gpu": False,
    },
    "medical": {
        "name": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "description": "Medical text processing optimized for clinical terms",
        "max_length": 512,
        "requires_gpu": True,
    },
    "technical": {
        "name": "allenai/scibert_scivocab_uncased",
        "description": "Scientific and technical document processing",
        "max_length": 512,
        "requires_gpu": True,
    },
    "general": {
        "name": "sentence-transformers/all-mpnet-base-v2",
        "description": "General purpose embedding model for all document types",
        "max_length": 512,
        "requires_gpu": False,
    },
}
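
# Illustrative sketch (hypothetical `domain` and `use_gpu` variables) of how a
# config above is resolved to an embedding model; the actual resolution lives
# in DocumentProcessor.__init__ below:
#
#   cfg = DOMAIN_MODELS.get(domain, DOMAIN_MODELS["general"])
#   device = "cuda" if use_gpu and cfg["requires_gpu"] else "cpu"
#   embed = HuggingFaceEmbedding(model_name=cfg["name"], device=device)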

class DocumentProcessor:
    """
    Processor for documents with domain-specific models, entity preservation,
    and customizable processing capabilities.
    """

    def __init__(
        self,
        domain: str = "general",
        model_key: Optional[str] = None,
        use_gpu: bool = False,
        chunk_size: int = 512,
        chunk_overlap: int = 100,
        custom_patterns: Optional[List[str]] = None,
    ):
        """
        Initialize the document processor.

        Args:
            domain: Domain specialization ('legal', 'financial', 'medical',
                'technical', 'general')
            model_key: Specific model to use (overrides domain selection)
            use_gpu: Whether to use GPU for embeddings (if available)
            chunk_size: Size of text chunks for processing
            chunk_overlap: Overlap between chunks to preserve context
            custom_patterns: Additional regex patterns for text cleaning
        """
        # Normalize the domain before storing it so that the entity
        # processors are selected consistently.
        if domain not in DOMAIN_MODELS:
            print(f"Warning: Domain '{domain}' not found. Using 'general' as default.")
            domain = "general"
        self.domain = domain

        # If model_key is provided, use it directly; otherwise select the
        # model based on domain.
        if model_key:
            model_name = model_key
            device = "cuda" if use_gpu else "cpu"
        else:
            model_config = DOMAIN_MODELS[domain]
            model_name = model_config["name"]
            device = "cuda" if use_gpu and model_config["requires_gpu"] else "cpu"

        # Initialize embedding model
        try:
            self.embed_model = HuggingFaceEmbedding(
                model_name=model_name,
                device=device,
                tokenizer_kwargs={
                    "trust_remote_code": True,
                    "max_length": 512,
                    "truncation": True,
                },
            )
            # Store model information for reference
            self.model_name = model_name
            self.device = device
        except Exception as e:
            raise RuntimeError(f"Failed to initialize embedding model: {str(e)}")

        # Domain-optimized text splitter
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=[
                "\n## ", "\n### ", "\n#### ",  # Headers
                "\n\n", "\n",  # Paragraphs
                ". ", "! ", "? ",  # Sentences
                ";", ":",  # Clause boundaries
                " ",  # Last resort
            ],
        )

        # Base cleaning patterns
        self.cleaning_patterns = [
            r"^Page\s\d+(\s+of\s+\d+)?$",  # Page numbers
            r"^©.*\b(Company|Inc|Ltd)\b.*$",  # Copyright lines
            r"^All rights reserved.*?$",  # Legal boilerplate
            r"^-+$",  # Separator lines
            r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}(:\d{2})?$",  # Timestamps
            r"(?i)^(confidential|proprietary|internal use only)",  # Security tags
        ]

        # Add custom patterns if provided
        if custom_patterns:
            self.cleaning_patterns.extend(custom_patterns)

        # Join all patterns with the OR operator and compile once
        combined_pattern = "|".join(f"({pattern})" for pattern in self.cleaning_patterns)
        self.cleaning_pattern = re.compile(
            combined_pattern, flags=re.MULTILINE | re.IGNORECASE
        )

        # Initialize domain-specific processors
        self._init_domain_processors()
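
    # Illustrative sketch of the boilerplate cleaning configured above
    # (hypothetical input; remove_boilerplate is defined further below).
    # With the default patterns, page footers are stripped in MULTILINE mode
    # while body text is left intact:
    #
    #   >>> proc = DocumentProcessor(domain="general")
    #   >>> proc.remove_boilerplate("Quarterly summary\nPage 3 of 10\nDetails follow.")
    #   'Quarterly summary\n\nDetails follow.'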

    def _init_domain_processors(self):
        """Initialize domain-specific processors based on the selected domain."""
        # Domain-specific entity patterns
        self.entity_patterns = {}

        # Set up domain-specific patterns and processors. Note the elif
        # chain: exactly one branch runs, so the general fallback cannot
        # overwrite a specialized configuration.
        if self.domain == "legal":
            self.entity_patterns = {
                "case_citation": r"\[\d{4}\]\s+[A-Z]+\s+\d+",  # [2019] UKSC 20
                "statute": r"\b(?:Art\.|Section)\s+\d+(\.\d+)?",  # Art. 5, Section 3.1
                "legal_ref": r"\b[A-Za-z]+\s+v\.?\s+[A-Za-z]+",  # Smith v. Jones
            }
            self.process_entities = self._process_legal_entities
        elif self.domain == "financial":
            self.entity_patterns = {
                "monetary": r"\$\s*\d+(?:\.\d+)?(?:\s*(?:million|billion|trillion))?",  # $5.2 million
                "percentage": r"\d+(?:\.\d+)?\s*%",  # 10.5%
                "date_range": r"(?:Q[1-4]|FY)\s+\d{4}",  # Q2 2023, FY 2022
            }
            self.process_entities = self._process_financial_entities
        elif self.domain == "medical":
            self.entity_patterns = {
                "dosage": r"\d+(?:\.\d+)?\s*(?:mg|mcg|g|ml|oz)",  # 10mg, 5.5ml
                "medical_code": r"[A-Z]\d{2}(?:\.\d+)?",  # ICD codes like E11.9
                "vital_sign": r"\d+(?:\.\d+)?\s*(?:bpm|mmHg|°[CF])",  # 120 bpm, 98.6°F
            }
            self.process_entities = self._process_medical_entities
        elif self.domain == "technical":
            self.entity_patterns = {
                "version": r"v\d+(?:\.\d+){1,3}",  # v1.2.3
                "code_ref": r"(?:\w+\.)+\w+\(\)",  # function calls like math.sqrt()
                "tech_standard": r"(?:RFC|ISO|IEEE)\s*\d+",  # RFC 1918, ISO 9001
            }
            self.process_entities = self._process_technical_entities
        else:
            # General domain or fallback
            self.entity_patterns = {
                "url": r"https?://\S+",  # URLs
                "email": r"\S+@\S+\.\S+",  # Email addresses
                "date": r"\d{1,2}[/-]\d{1,2}[/-]\d{2,4}",  # Dates
            }
            self.process_entities = self._process_general_entities

    def _process_legal_entities(self, text: str) -> str:
        """Process legal document entities."""
        # Pattern 1: Case citations such as [2019] UKSC 20 are already
        # well-structured, so no changes are needed.

        # Pattern 2: Standardize section references (§3.1, §123)
        processed = re.sub(r"§(\d+(\.\d+)?)", r"Section \1", text)

        # Pattern 3: Expand legal abbreviations (Art. -> Article)
        processed = re.sub(r"\bArt\.\s+(\d+)", r"Article \1", processed)

        # Pattern 4: Standardize case names written with "v" or "vs." to "v."
        processed = re.sub(r"\bv\s+", r"v. ", processed)
        processed = re.sub(r"\bvs\.?\s+", r"v. ", processed)

        return processed
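
    # Illustrative sketch (hypothetical sentence) of the legal normalization
    # chain above:
    #
    #   >>> proc = DocumentProcessor(domain="legal")
    #   >>> proc.process_entities("Under §3.1 and Art. 5, see Smith vs Jones.")
    #   'Under Section 3.1 and Article 5, see Smith v. Jones.'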
", processed) return processed def _process_financial_entities(self, text: str) -> str: """Process financial document entities.""" # Pattern 1: Standardize monetary values processed = re.sub( r"\$\s*(\d+)(?:,\d{3})*(?:\.\d+)?", lambda m: f"${float(m.group(1).replace(',', ''))}", text, ) # Pattern 2: Standardize percentage representations processed = re.sub(r"(\d+(?:\.\d+)?)\s*(?:percent|pct)", r"\1%", processed) # Pattern 3: Standardize fiscal periods processed = re.sub(r"(?:fiscal year|FY)\s+(\d{4})", r"FY \1", processed) # Pattern 4: Standardize quarterly references processed = re.sub(r"(?:quarter|Q)(\d)\s+(\d{4})", r"Q\1 \2", processed) return processed def _process_medical_entities(self, text: str) -> str: """Process medical document entities.""" # Pattern 1: Standardize dosage format processed = re.sub( r"(\d+(?:\.\d+)?)\s*(milligrams?|mcgs?|grams?|milliliters?)", lambda m: f"{m.group(1)} {m.group(2)[0:2]}", text, ) # Pattern 2: Standardize temperature format processed = re.sub(r"(\d+(?:\.\d+)?)\s*degrees?\s*([CF])", r"\1°\2", processed) # Pattern 3: Standardize vital signs processed = re.sub( r"(\d+(?:\.\d+)?)\s*(?:beats per minute|BPM)", r"\1 bpm", processed ) return processed def _process_technical_entities(self, text: str) -> str: """Process technical document entities.""" # Pattern 1: Standardize version numbers processed = re.sub(r"version\s+(\d+(?:\.\d+){1,3})", r"v\1", text) # Pattern 2: RFC/ISO pattern standardization processed = re.sub(r"\b(RFC|ISO|IEEE)\s*[:#]?\s*(\d+)", r"\1 \2", processed) # Pattern 3: Standardize code references # This is a simplified example processed = re.sub(r"function\s+(\w+)\s*\(", r"\1(", processed) return processed def _process_general_entities(self, text: str) -> str: """Process general document entities.""" # General cleaning and standardization processed = text # URLs preserved as-is # Simple date standardization processed = re.sub( r"(\d{1,2})/(\d{1,2})/(\d{2})(?!\d)", r"\1/\2/20\3", # Assume 2-digit years are 2000s processed, ) return processed def remove_boilerplate(self, text: str) -> str: """ Remove common document boilerplate patterns from text. Args: text: The input text to process Returns: Text with boilerplate patterns removed """ return self.cleaning_pattern.sub("", text) def clean_text(self, text: str) -> str: """ Clean text while preserving domain-specific entities. Args: text: The input text to clean Returns: Cleaned text with domain entities preserved """ # First remove boilerplate cleaned = self.remove_boilerplate(text) # Then process domain-specific entities cleaned = self.process_entities(cleaned) return cleaned def create_pipeline(self) -> IngestionPipeline: """ Create a document processing pipeline. Returns: Configured IngestionPipeline object """ return IngestionPipeline( transformations=[ self.clean_text, MarkdownNodeParser(), self.splitter, self.embed_model, ] ) def validate_entity_retention(self, documents: List[Document]) -> Dict[str, float]: """ Measure semantic similarity of entities before/after text cleaning. 

    def validate_entity_retention(self, documents: List[Document]) -> Dict[str, float]:
        """
        Measure semantic similarity of entities before/after text cleaning.

        Args:
            documents: List of Document objects to validate

        Returns:
            Dictionary with validation metrics
        """
        if not documents:
            return {"entity_retention": 0.0, "processing_time": 0.0}

        start_time = time.time()

        # Extract original texts (sample for performance)
        original_texts = [doc.text for doc in documents[:5]]

        # Apply cleaning
        processed_texts = [self.clean_text(text) for text in original_texts]

        # Calculate embeddings
        try:
            # Use the public batch-embedding API rather than reaching into
            # the private underlying model.
            orig_embeds = np.array(
                self.embed_model.get_text_embedding_batch(original_texts)
            )
            proc_embeds = np.array(
                self.embed_model.get_text_embedding_batch(processed_texts)
            )

            # Similarity of each original/processed pair (diagonal of the
            # pairwise matrix)
            similarities = cosine_similarity(orig_embeds, proc_embeds).diagonal()
            avg_similarity = float(np.mean(similarities))

            processing_time = time.time() - start_time

            return {
                "entity_retention": avg_similarity * 100,  # As a percentage
                "processing_time": processing_time,
                "sample_size": len(original_texts),
            }
        except Exception as e:
            return {"entity_retention": 0.0, "processing_time": 0.0, "error": str(e)}

    def process_documents(self, documents: List[Document]) -> Dict[str, Any]:
        """
        Process a list of documents.

        Args:
            documents: List of Document objects to process

        Returns:
            Dictionary with processing results and stats
        """
        if not documents:
            return {"status": "error", "message": "No documents provided"}

        try:
            # Clean raw text up front; see create_pipeline() for why this
            # step lives outside the IngestionPipeline.
            for doc in documents:
                doc.set_content(self.clean_text(doc.text))

            # Create pipeline and process documents
            pipeline = self.create_pipeline()
            nodes = pipeline.run(documents=documents)

            # Create a vector index over the processed nodes, reusing the
            # domain embedding model for queries
            index = VectorStoreIndex(nodes, embed_model=self.embed_model)
            query_engine = index.as_query_engine()

            # Return success with stats
            return {
                "status": "success",
                "nodes_count": len(nodes),
                "documents_count": len(documents),
                "domain": self.domain,
                "model_name": self.model_name,
                "query_engine": query_engine,  # Used for querying
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}


class DocumentProcessorTool(Tool):
    """
    General-purpose document processing tool with domain specialization.
    """

    name = "document_processor"
    description = (
        "Processes documents with domain-specific models optimized for "
        "entity preservation and retrieval performance. Supports legal, "
        "financial, medical, technical and general document types."
    )
    # Every forward() parameter carries a default, so each input is marked
    # nullable (the smolagents convention for optional tool arguments).
    inputs = {
        "text": {
            "type": "string",
            "description": "Document text to process. Provide either text or file_paths.",
            "nullable": True,
        },
        "file_paths": {
            "type": "string",
            "description": "Comma-separated list of file paths or a directory path containing documents. Provide either text or file_paths.",
            "nullable": True,
        },
        "domain": {
            "type": "string",
            "description": "Document domain for specialized processing: legal, financial, medical, technical, or general. Defaults to 'general'.",
            "nullable": True,
        },
        "model_name": {
            "type": "string",
            "description": "Specific embedding model name to use (optional, overrides domain selection).",
            "nullable": True,
        },
        "query": {
            "type": "string",
            "description": "Optional query to run against the processed documents.",
            "nullable": True,
        },
        "validate_entities": {
            "type": "boolean",
            "description": "Whether to validate entity retention in the processed documents. Defaults to False.",
            "nullable": True,
        },
        "use_gpu": {
            "type": "boolean",
            "description": "Whether to use GPU for embedding calculations if available. Defaults to False.",
            "nullable": True,
        },
    }
    output_type = "string"
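
    # Illustrative call sketch (hypothetical paths and query): an agent would
    # invoke this tool with either inline text or file paths, e.g.
    #
    #   tool = DocumentProcessorTool()
    #   tool.forward(
    #       file_paths="reports/q2.md, notes/",
    #       domain="financial",
    #       query="What was the revenue growth?",
    #   )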

    def _load_documents(self, input_path: str) -> List[Document]:
        """
        Load documents from a file path or directory.

        Args:
            input_path: Path to a file or directory

        Returns:
            List of Document objects
        """
        if os.path.isfile(input_path):
            # Load just this file; using input_files avoids pulling in every
            # file with the same extension from the directory.
            return SimpleDirectoryReader(
                input_files=[input_path],
                filename_as_id=True,
            ).load_data()
        elif os.path.isdir(input_path):
            return SimpleDirectoryReader(
                input_dir=input_path,
                filename_as_id=True,
            ).load_data()
        else:
            raise ValueError(f"Path not found: {input_path}")

    def _create_document_from_text(self, text: str) -> List[Document]:
        """
        Create a Document object from text.

        Args:
            text: Text content

        Returns:
            List containing a single Document object
        """
        # Create a temporary file to store the text
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".md", delete=False
        ) as temp_file:
            temp_file.write(text)
            temp_path = temp_file.name

        try:
            # Load the document from the temporary file
            documents = self._load_documents(temp_path)
            return documents
        finally:
            # Clean up the temporary file
            os.remove(temp_path)
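
    # Illustrative sketch (hypothetical paths) of the two loading paths used
    # by forward() below:
    #
    #   docs = tool._load_documents("reports/q2_summary.md")  # single file
    #   docs = tool._load_documents("reports/")               # whole directory
    #   docs = tool._create_document_from_text("# Title\nBody text")  # inline text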

    def forward(
        self,
        text: Optional[str] = None,
        file_paths: Optional[str] = None,
        domain: str = "general",
        model_name: Optional[str] = None,
        query: Optional[str] = None,
        validate_entities: bool = False,
        use_gpu: bool = False,
    ) -> str:
        """
        Process documents and optionally run a query.

        Args:
            text: Document text to process
            file_paths: Comma-separated list of file paths or a directory path
            domain: Document domain specialization
            model_name: Specific embedding model to use
            query: Optional query to run against the processed documents
            validate_entities: Whether to validate entity retention
            use_gpu: Whether to use GPU for embeddings

        Returns:
            Processing results or query response as a string
        """
        # Validate inputs
        if not text and not file_paths:
            return "Error: Either text or file_paths must be provided."

        try:
            # Initialize processor
            processor = DocumentProcessor(
                domain=domain,
                model_key=model_name,
                use_gpu=use_gpu,
            )

            # Load documents
            documents = []
            if text:
                documents.extend(self._create_document_from_text(text))
            if file_paths:
                # Handle comma-separated paths
                paths = [path.strip() for path in file_paths.split(",")]
                for path in paths:
                    try:
                        docs = self._load_documents(path)
                        documents.extend(docs)
                    except Exception as e:
                        return f"Error loading documents from {path}: {str(e)}"

            # Check if we have documents to process
            if not documents:
                return "Error: No valid documents found."

            # Validate entity retention if requested
            validation_results = {}
            if validate_entities:
                validation_results = processor.validate_entity_retention(documents)

            # Process documents
            result = processor.process_documents(documents)
            if result["status"] != "success":
                return f"Processing error: {result['message']}"

            # Stats and validation blocks shared by both output paths
            stats = (
                f"Documents processed: {result['documents_count']}\n"
                f"Text chunks: {result['nodes_count']}\n"
                f"Domain: {result['domain']}\n"
                f"Model: {result['model_name']}\n"
            )
            validation_report = ""
            if validation_results:
                validation_report = (
                    "\n=== Entity Retention Validation ===\n"
                    f"Entity retention: {validation_results.get('entity_retention', 0):.2f}%\n"
                    f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"
                )

            # Run query if provided
            if query and "query_engine" in result:
                response = result["query_engine"].query(query)
                return (
                    f"Query: {query}\n\nResponse: {response}\n\n"
                    + stats
                    + validation_report
                )

            # If no query, return processing stats
            output = "Document processing complete.\n\n" + stats + validation_report
            output += "\nThe documents are now ready for querying. Use the 'query' parameter to run a query."
            return output

        except Exception as e:
            return f"Error: {str(e)}"
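

if __name__ == "__main__":
    # Minimal smoke-test sketch (hypothetical sample text). Note: this
    # downloads the embedding model on first run, and answering a query
    # additionally requires a configured llama-index LLM backend.
    tool = DocumentProcessorTool()
    sample = (
        "# Q2 2023 Earnings\n\n"
        "Revenue grew 12 percent to $5.2 million in Q2 2023.\n"
        "Page 1 of 1\n"
    )
    print(tool.forward(text=sample, domain="financial", validate_entities=True))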