#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The Footscray Coding Collective. All rights reserved.
"""
General Document Processing Tool for Smolagents

This tool processes various types of documents with domain-specific models,
optimizing for intelligent document parsing, entity extraction, and
customized retrieval tasks.

Author: Zhou Wang
"""
import os
import re
import tempfile
import time
from typing import Any, Dict, List, Optional

import numpy as np

# Import Smolagents Tool class
from smolagents import Tool

# Import NLP components
try:
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from llama_index.core import Document, SimpleDirectoryReader, VectorStoreIndex
    from llama_index.core.ingestion import IngestionPipeline
    from llama_index.core.node_parser import LangchainNodeParser, MarkdownNodeParser
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    from sklearn.metrics.pairwise import cosine_similarity
except ImportError:
    raise ImportError(
        "Required dependencies not found. Please install with: "
        "pip install llama-index llama-index-embeddings-huggingface "
        "langchain scikit-learn"
    )
# Model configurations based on domain specialization
DOMAIN_MODELS = {
    "legal": {
        "name": "joelito/legal-xlm-roberta-base",
        "description": "Specialized for legal documents with citation preservation",
        "max_length": 512,
        "requires_gpu": True,
    },
    "financial": {
        "name": "thenlper/finetuned-finbert-slot-filling",
        "description": "Financial document analysis with entity extraction",
        "max_length": 512,
        "requires_gpu": False,
    },
    "medical": {
        "name": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "description": "Medical text processing optimized for clinical terms",
        "max_length": 512,
        "requires_gpu": True,
    },
    "technical": {
        "name": "allenai/scibert_scivocab_uncased",
        "description": "Scientific and technical document processing",
        "max_length": 512,
        "requires_gpu": True,
    },
    "general": {
        "name": "sentence-transformers/all-mpnet-base-v2",
        "description": "General purpose embedding model for all document types",
        "max_length": 512,
        "requires_gpu": False,
    },
}

class DocumentProcessor:
    """
    Processor for documents with domain-specific models,
    entity preservation, and customizable processing capabilities.
    """

    def __init__(
        self,
        domain: str = "general",
        model_key: Optional[str] = None,
        use_gpu: bool = False,
        chunk_size: int = 512,
        chunk_overlap: int = 100,
        custom_patterns: Optional[List[str]] = None,
    ):
        """
        Initialize the document processor.

        Args:
            domain: Domain specialization ('legal', 'financial', 'medical', 'technical', 'general')
            model_key: Specific model to use (overrides domain selection)
            use_gpu: Whether to use GPU for embeddings (if available)
            chunk_size: Size of text chunks for processing
            chunk_overlap: Overlap between chunks to preserve context
            custom_patterns: Additional regex patterns for text cleaning
        """
        # Store domain
        self.domain = domain
        # If model_key provided, use it directly
        if model_key:
            model_name = model_key
            device = "cuda" if use_gpu else "cpu"
        else:
            # Otherwise select model based on domain, falling back to general
            if domain not in DOMAIN_MODELS:
                print(
                    f"Warning: Domain '{domain}' not found. Using 'general' as default."
                )
                domain = "general"
                # Keep the stored domain consistent with the selected model
                self.domain = domain
            model_config = DOMAIN_MODELS[domain]
            model_name = model_config["name"]
            # Use GPU only when requested and the model is flagged as benefiting from it
            device = "cuda" if use_gpu and model_config["requires_gpu"] else "cpu"
        # Initialize embedding model
        try:
            self.embed_model = HuggingFaceEmbedding(
                model_name=model_name,
                device=device,
                tokenizer_kwargs={
                    "trust_remote_code": True,
                    "max_length": 512,
                    "truncation": True,
                },
            )
            # Store model information for reference
            self.model_name = model_name
            self.device = device
        except Exception as e:
            raise RuntimeError(f"Failed to initialize embedding model: {str(e)}")
        # Domain-optimized text splitter
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=[
                "\n## ",
                "\n### ",
                "\n#### ",  # Headers
                "\n\n",
                "\n",  # Paragraphs
                ". ",
                "! ",
                "? ",  # Sentences
                ";",
                ":",  # Clause boundaries
                " ",  # Last resort
            ],
        )
        # Base cleaning patterns
        self.cleaning_patterns = [
            r"^Page\s\d+(\s+of\s+\d+)?$",  # Page numbers
            r"^©.*\b(Company|Inc|Ltd)\b.*$",  # Copyright lines
            r"^All rights reserved.*?$",  # Legal boilerplate
            r"^-+$",  # Separator lines
            r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}(:\d{2})?$",  # Timestamps
            r"(?i)^(confidential|proprietary|internal use only)",  # Security tags
        ]
        # Add custom patterns if provided
        if custom_patterns:
            self.cleaning_patterns.extend(custom_patterns)
        # Join all patterns with the OR operator
        combined_pattern = "|".join(
            f"({pattern})" for pattern in self.cleaning_patterns
        )
        # Compile the combined pattern
        self.cleaning_pattern = re.compile(
            combined_pattern, flags=re.MULTILINE | re.IGNORECASE
        )
        # Initialize domain-specific processors
        self._init_domain_processors()

    def _init_domain_processors(self):
        """Initialize domain-specific processors based on selected domain."""
        # Domain-specific entity patterns
        self.entity_patterns = {}
        # Set up domain-specific patterns and processors; the branches must be
        # mutually exclusive (elif), otherwise the general fallback would
        # overwrite every specialized configuration except "technical"
        if self.domain == "legal":
            self.entity_patterns = {
                "case_citation": r"\[\d{4}\]\s+[A-Z]+\s+\d+",  # [2019] UKSC 20
                "statute": r"\b(?:Art\.|Section)\s+\d+(\.\d+)?",  # Art. 5, Section 3.1
                "legal_ref": r"\b[A-Za-z]+\s+v\.?\s+[A-Za-z]+",  # Smith v. Jones
            }
            self.process_entities = self._process_legal_entities
        elif self.domain == "financial":
            self.entity_patterns = {
                "monetary": r"\$\s*\d+(?:\.\d+)?(?:\s*(?:million|billion|trillion))?",  # $5.2 million
                "percentage": r"\d+(?:\.\d+)?\s*%",  # 10.5%
                "date_range": r"(?:Q[1-4]|FY)\s+\d{4}",  # Q2 2023, FY 2022
            }
            self.process_entities = self._process_financial_entities
        elif self.domain == "medical":
            self.entity_patterns = {
                "dosage": r"\d+(?:\.\d+)?\s*(?:mg|mcg|g|ml|oz)",  # 10mg, 5.5ml
                "medical_code": r"[A-Z]\d{2}(?:\.\d+)?",  # ICD codes like E11.9
                "vital_sign": r"\d+(?:\.\d+)?\s*(?:bpm|mmHg|°[CF])",  # 120 bpm, 98.6°F
            }
            self.process_entities = self._process_medical_entities
        elif self.domain == "technical":
            self.entity_patterns = {
                "version": r"v\d+(?:\.\d+){1,3}",  # v1.2.3
                "code_ref": r"(?:\w+\.)+\w+\(\)",  # function calls like math.sqrt()
                "tech_standard": r"(?:RFC|ISO|IEEE)\s*\d+",  # RFC 1918, ISO 9001
            }
            self.process_entities = self._process_technical_entities
        else:  # General domain or fallback
            self.entity_patterns = {
                "url": r"https?://\S+",  # URLs
                "email": r"\S+@\S+\.\S+",  # Email addresses
                "date": r"\d{1,2}[/-]\d{1,2}[/-]\d{2,4}",  # Dates
            }
            self.process_entities = self._process_general_entities

    def _process_legal_entities(self, text: str) -> str:
        """Process legal document entities."""
        # Pattern 1: Case citations such as [2019] UKSC 20 are already
        # well-structured, so no changes are needed
        # Pattern 2: Standardize section references (§3.1 -> Section 3.1)
        processed = re.sub(r"§(\d+(\.\d+)?)", r"Section \1", text)
        # Pattern 3: Expand legal abbreviations (Art. 5 -> Article 5)
        processed = re.sub(r"\bArt\.\s+(\d+)", r"Article \1", processed)
        # Pattern 4: Standardize case names written with v, vs, or vs.
        processed = re.sub(r"\bv\s+", r"v. ", processed)
        processed = re.sub(r"\bvs\.?\s+", r"v. ", processed)
        return processed
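
    # Illustrative transformation under the rules above (a sketch, not output
    # from a real run):
    #   "Smith vs Jones, Art. 12, §4.2"
    #     -> "Smith v. Jones, Article 12, Section 4.2"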

    def _process_financial_entities(self, text: str) -> str:
        """Process financial document entities."""
        # Pattern 1: Normalize monetary values by stripping thousands
        # separators; the full amount is captured in one group so digits
        # after a comma are not dropped
        processed = re.sub(
            r"\$\s*(\d+(?:,\d{3})*(?:\.\d+)?)",
            lambda m: f"${m.group(1).replace(',', '')}",
            text,
        )
        # Pattern 2: Standardize percentage representations
        processed = re.sub(r"(\d+(?:\.\d+)?)\s*(?:percent|pct)", r"\1%", processed)
        # Pattern 3: Standardize fiscal periods
        processed = re.sub(r"(?i)\b(?:fiscal year|fy)\s+(\d{4})", r"FY \1", processed)
        # Pattern 4: Standardize quarterly references (quarter 2 2023 -> Q2 2023)
        processed = re.sub(
            r"(?i)\b(?:quarter|q)\s*([1-4])\s+(\d{4})", r"Q\1 \2", processed
        )
        return processed
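
    # Illustrative transformation under the rules above (a sketch):
    #   "revenue of $1,250,000, up 7.5 percent in quarter 2 2023"
    #     -> "revenue of $1250000, up 7.5% in Q2 2023"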

    def _process_medical_entities(self, text: str) -> str:
        """Process medical document entities."""
        # Pattern 1: Standardize dosage units to their abbreviations via an
        # explicit map (slicing the unit name to two characters would turn
        # "milligrams" into "mi" rather than "mg")
        unit_map = {
            "milligram": "mg",
            "mcg": "mcg",
            "gram": "g",
            "milliliter": "ml",
        }
        processed = re.sub(
            r"(\d+(?:\.\d+)?)\s*(milligrams?|mcgs?|grams?|milliliters?)",
            lambda m: f"{m.group(1)} {unit_map[m.group(2).rstrip('s')]}",
            text,
        )
        # Pattern 2: Standardize temperature format
        processed = re.sub(r"(\d+(?:\.\d+)?)\s*degrees?\s*([CF])", r"\1°\2", processed)
        # Pattern 3: Standardize vital signs
        processed = re.sub(
            r"(\d+(?:\.\d+)?)\s*(?:beats per minute|BPM)", r"\1 bpm", processed
        )
        return processed
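
    # Illustrative transformation under the rules above (a sketch):
    #   "10 milligrams every 6 hours; heart rate 72 beats per minute"
    #     -> "10 mg every 6 hours; heart rate 72 bpm"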

    def _process_technical_entities(self, text: str) -> str:
        """Process technical document entities."""
        # Pattern 1: Standardize version numbers (version 1.2.3 -> v1.2.3)
        processed = re.sub(r"version\s+(\d+(?:\.\d+){1,3})", r"v\1", text)
        # Pattern 2: RFC/ISO/IEEE pattern standardization (RFC: 1918 -> RFC 1918)
        processed = re.sub(r"\b(RFC|ISO|IEEE)\s*[:#]?\s*(\d+)", r"\1 \2", processed)
        # Pattern 3: Standardize code references (a simplified example)
        processed = re.sub(r"function\s+(\w+)\s*\(", r"\1(", processed)
        return processed
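
    # Illustrative transformation under the rules above (a sketch):
    #   "See RFC: 1918 and version 1.2.3" -> "See RFC 1918 and v1.2.3"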

    def _process_general_entities(self, text: str) -> str:
        """Process general document entities."""
        # General cleaning and standardization; URLs are preserved as-is
        processed = text
        # Simple date standardization (assume 2-digit years are 2000s)
        processed = re.sub(
            r"(\d{1,2})/(\d{1,2})/(\d{2})(?!\d)",
            r"\1/\2/20\3",
            processed,
        )
        return processed

    def remove_boilerplate(self, text: str) -> str:
        """
        Remove common document boilerplate patterns from text.

        Args:
            text: The input text to process

        Returns:
            Text with boilerplate patterns removed
        """
        return self.cleaning_pattern.sub("", text)

    def clean_text(self, text: str) -> str:
        """
        Clean text while preserving domain-specific entities.

        Args:
            text: The input text to clean

        Returns:
            Cleaned text with domain entities preserved
        """
        # First remove boilerplate, then normalize domain-specific entities
        cleaned = self.remove_boilerplate(text)
        cleaned = self.process_entities(cleaned)
        return cleaned

    def create_pipeline(self) -> IngestionPipeline:
        """
        Create a document processing pipeline.

        IngestionPipeline transformations must be llama_index components, so
        the LangChain splitter is wrapped in LangchainNodeParser and the
        domain-specific cleaning is applied to the raw documents in
        process_documents() rather than inside the pipeline.

        Returns:
            Configured IngestionPipeline object
        """
        return IngestionPipeline(
            transformations=[
                MarkdownNodeParser(),
                LangchainNodeParser(self.splitter),
                self.embed_model,
            ]
        )

    def validate_entity_retention(self, documents: List[Document]) -> Dict[str, float]:
        """
        Estimate entity retention by measuring the semantic similarity of
        document text before and after cleaning.

        Args:
            documents: List of Document objects to validate

        Returns:
            Dictionary with validation metrics
        """
        if not documents:
            return {"entity_retention": 0.0, "processing_time": 0.0}
        start_time = time.time()
        # Extract original texts (sample for performance)
        original_texts = [doc.text for doc in documents[:5]]
        # Apply cleaning
        processed_texts = [self.clean_text(text) for text in original_texts]
        # Calculate embeddings
        try:
            # Use the public embedding API rather than reaching into the
            # model's private attributes
            orig_embeds = np.array(
                self.embed_model.get_text_embedding_batch(original_texts)
            )
            proc_embeds = np.array(
                self.embed_model.get_text_embedding_batch(processed_texts)
            )
            # Calculate similarity of each original/processed pair
            similarities = cosine_similarity(orig_embeds, proc_embeds).diagonal()
            avg_similarity = float(np.mean(similarities))
            processing_time = time.time() - start_time
            return {
                "entity_retention": avg_similarity * 100,  # As percentage
                "processing_time": processing_time,
                "sample_size": len(original_texts),
            }
        except Exception as e:
            return {"entity_retention": 0.0, "processing_time": 0.0, "error": str(e)}

    def process_documents(self, documents: List[Document]) -> Dict[str, Any]:
        """
        Process a list of documents.

        Args:
            documents: List of Document objects to process

        Returns:
            Dictionary with processing results and stats
        """
        if not documents:
            return {"status": "error", "message": "No documents provided"}
        try:
            # Clean document text up front, since plain functions cannot be
            # used as pipeline transformations
            cleaned_docs = [
                Document(text=self.clean_text(doc.text), metadata=doc.metadata)
                for doc in documents
            ]
            # Create pipeline and process documents
            pipeline = self.create_pipeline()
            nodes = pipeline.run(documents=cleaned_docs)
            # Create vector index, reusing the same embedding model so queries
            # are encoded consistently with the nodes
            index = VectorStoreIndex(nodes=nodes, embed_model=self.embed_model)
            query_engine = index.as_query_engine()
            # Return success with stats
            return {
                "status": "success",
                "nodes_count": len(nodes),
                "documents_count": len(documents),
                "domain": self.domain,
                "model_name": self.model_name,
                "query_engine": query_engine,  # This will be used for querying
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}

class DocumentProcessorTool(Tool):
    """
    General-purpose document processing tool with domain specialization.
    """

    name = "document_processor"
    description = (
        "Processes documents with domain-specific models optimized for "
        "entity preservation and retrieval performance. Supports legal, "
        "financial, medical, technical and general document types."
    )
    # Smolagents marks optional inputs with "nullable"; defaults live in the
    # forward() signature and are noted in the descriptions
    inputs = {
        "text": {
            "type": "string",
            "description": "Document text to process. Provide either text or file_paths.",
            "nullable": True,
        },
        "file_paths": {
            "type": "string",
            "description": "Comma-separated list of file paths or a directory path containing documents. Provide either text or file_paths.",
            "nullable": True,
        },
        "domain": {
            "type": "string",
            "description": "Document domain for specialized processing: legal, financial, medical, technical, or general. Defaults to 'general'.",
            "nullable": True,
        },
        "model_name": {
            "type": "string",
            "description": "Specific embedding model name to use (optional, overrides domain selection).",
            "nullable": True,
        },
        "query": {
            "type": "string",
            "description": "Optional query to run against the processed documents.",
            "nullable": True,
        },
        "validate_entities": {
            "type": "boolean",
            "description": "Whether to validate entity retention in the processed documents. Defaults to False.",
            "nullable": True,
        },
        "use_gpu": {
            "type": "boolean",
            "description": "Whether to use GPU for embedding calculations if available. Defaults to False.",
            "nullable": True,
        },
    }
    output_type = "string"

    def _load_documents(self, input_path: str) -> List[Document]:
        """
        Load documents from a file path or directory.

        Args:
            input_path: Path to a file or directory

        Returns:
            List of Document objects
        """
        if os.path.isfile(input_path):
            # Load only this file; filtering the parent directory by extension
            # would also pick up unrelated files with the same extension
            return SimpleDirectoryReader(
                input_files=[input_path],
                filename_as_id=True,
            ).load_data()
        elif os.path.isdir(input_path):
            return SimpleDirectoryReader(
                input_dir=input_path,
                filename_as_id=True,
            ).load_data()
        else:
            raise ValueError(f"Path not found: {input_path}")

    def _create_document_from_text(self, text: str) -> List[Document]:
        """
        Create a Document object from text.

        Args:
            text: Text content

        Returns:
            List containing a single Document object
        """
        # Write the text to a temporary Markdown file so it is loaded through
        # the same reader as file-based input
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".md", delete=False
        ) as temp_file:
            temp_file.write(text)
            temp_path = temp_file.name
        try:
            # Load the document from the temporary file
            return self._load_documents(temp_path)
        finally:
            # Clean up the temporary file
            os.remove(temp_path)

    def forward(
        self,
        text: Optional[str] = None,
        file_paths: Optional[str] = None,
        domain: str = "general",
        model_name: Optional[str] = None,
        query: Optional[str] = None,
        validate_entities: bool = False,
        use_gpu: bool = False,
    ) -> str:
        """
        Process documents and optionally run a query.

        Args:
            text: Document text to process
            file_paths: Comma-separated list of file paths or a directory path
            domain: Document domain specialization
            model_name: Specific embedding model to use
            query: Optional query to run against the processed documents
            validate_entities: Whether to validate entity retention
            use_gpu: Whether to use GPU for embeddings

        Returns:
            Processing results or query response as a string
        """
        # Validate inputs
        if not text and not file_paths:
            return "Error: Either text or file_paths must be provided."
        try:
            # Initialize processor
            processor = DocumentProcessor(
                domain=domain,
                model_key=model_name,
                use_gpu=use_gpu,
            )
            # Load documents
            documents = []
            if text:
                documents.extend(self._create_document_from_text(text))
            if file_paths:
                # Handle comma-separated paths
                paths = [path.strip() for path in file_paths.split(",")]
                for path in paths:
                    try:
                        documents.extend(self._load_documents(path))
                    except Exception as e:
                        return f"Error loading documents from {path}: {str(e)}"
            # Check if we have documents to process
            if not documents:
                return "Error: No valid documents found."
            # Validate entity retention if requested
            validation_results = {}
            if validate_entities:
                validation_results = processor.validate_entity_retention(documents)
            # Process documents
            result = processor.process_documents(documents)
            if result["status"] != "success":
                return f"Processing error: {result['message']}"
            # Build the stats and validation blocks once; both output paths
            # share them
            stats = (
                f"Documents processed: {result['documents_count']}\n"
                f"Text chunks: {result['nodes_count']}\n"
                f"Domain: {result['domain']}\n"
                f"Model: {result['model_name']}\n"
            )
            validation_block = ""
            if validation_results:
                validation_block = (
                    "\n=== Entity Retention Validation ===\n"
                    f"Entity retention: {validation_results.get('entity_retention', 0):.2f}%\n"
                    f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"
                )
            # Run query if provided
            if query and "query_engine" in result:
                response = result["query_engine"].query(query)
                return (
                    f"Query: {query}\n\nResponse: {response}\n\n"
                    + stats
                    + validation_block
                )
            # If no query, return processing stats
            return (
                "Document processing complete.\n\n"
                + stats
                + validation_block
                + "\nThe documents are now ready for querying. Use the 'query' parameter to run a query."
            )
        except Exception as e:
            return f"Error: {str(e)}"