#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The Footscray Coding Collective. All rights reserved.
"""
General Document Processing Tool for Smolagents
This tool processes various types of documents with domain-specific models,
optimizing for intelligent document parsing, entity extraction, and
customized retrieval tasks.
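Example (illustrative; assumes the optional dependencies below are installed):
    tool = DocumentProcessorTool()
    print(tool.forward(text="Revenue grew 12 percent.", domain="financial"))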
Author: Zhou Wang
"""
import os
import re
import time
from typing import Any, Dict, List, Optional
import numpy as np
# Import Smolagents Tool class
from smolagents import Tool
# Import NLP components
try:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llama_index.core import Document, SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.ingestion import IngestionPipeline
    from llama_index.core.node_parser import LangchainNodeParser, MarkdownNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sklearn.metrics.pairwise import cosine_similarity
except ImportError as e:
    raise ImportError(
        "Required dependencies not found. Please install with: "
        "pip install llama-index llama-index-embeddings-huggingface "
        "langchain scikit-learn"
    ) from e
# Model configurations based on domain specialization
DOMAIN_MODELS = {
"legal": {
"name": "joelito/legal-xlm-roberta-base",
"description": "Specialized for legal documents with citation preservation",
"max_length": 512,
"requires_gpu": True,
},
"financial": {
"name": "thenlper/finetuned-finbert-slot-filling",
"description": "Financial document analysis with entity extraction",
"max_length": 512,
"requires_gpu": False,
},
"medical": {
"name": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
"description": "Medical text processing optimized for clinical terms",
"max_length": 512,
"requires_gpu": True,
},
"technical": {
"name": "allenai/scibert_scivocab_uncased",
"description": "Scientific and technical document processing",
"max_length": 512,
"requires_gpu": True,
},
"general": {
"name": "sentence-transformers/all-mpnet-base-v2",
"description": "General purpose embedding model for all document types",
"max_length": 512,
"requires_gpu": False,
},
}
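

def resolve_domain_model(domain: str) -> Dict[str, Any]:
    """Illustrative helper (a sketch, not used elsewhere in this module):
    resolve a domain key to its model config, falling back to 'general' for
    unknown domains, mirroring the selection in DocumentProcessor.__init__.
    """
    return DOMAIN_MODELS.get(domain, DOMAIN_MODELS["general"])
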
class DocumentProcessor:
"""
Processor for documents with domain-specific models,
entity preservation, and customizable processing capabilities.
"""
def __init__(
self,
domain: str = "general",
model_key: Optional[str] = None,
use_gpu: bool = False,
chunk_size: int = 512,
chunk_overlap: int = 100,
custom_patterns: Optional[List[str]] = None,
):
"""
Initialize the document processor.
Args:
domain: Domain specialization ('legal', 'financial', 'medical', 'technical', 'general')
model_key: Specific model to use (overrides domain selection)
use_gpu: Whether to use GPU for embeddings (if available)
chunk_size: Size of text chunks for processing
chunk_overlap: Overlap between chunks to preserve context
custom_patterns: Additional regex patterns for text cleaning
"""
# Store domain
self.domain = domain
# If model_key provided, use it directly
if model_key:
model_name = model_key
device = "cuda" if use_gpu else "cpu"
else:
# Otherwise select model based on domain
if domain not in DOMAIN_MODELS:
print(
f"Warning: Domain '{domain}' not found. Using 'general' as default."
)
domain = "general"
model_config = DOMAIN_MODELS[domain]
model_name = model_config["name"]
device = "cuda" if use_gpu and model_config["requires_gpu"] else "cpu"
# Initialize embedding model
try:
self.embed_model = HuggingFaceEmbedding(
model_name=model_name,
device=device,
tokenizer_kwargs={
"trust_remote_code": True,
"max_length": 512,
"truncation": True,
},
)
# Store model information for reference
self.model_name = model_name
self.device = device
except Exception as e:
            raise RuntimeError(f"Failed to initialize embedding model: {str(e)}") from e
# Domain-optimized text splitter
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=[
"\n## ",
"\n### ",
"\n#### ", # Headers
"\n\n",
"\n", # Paragraphs
". ",
"! ",
"? ", # Sentences
";",
":", # Clause boundaries
" ", # Last resort
],
)
# Base cleaning patterns
self.cleaning_patterns = [
r"^Page\s\d+(\s+of\s+\d+)?$", # Page numbers
r"^©.*\b(Company|Inc|Ltd)\b.*$", # Copyright lines
r"^All rights reserved.*?$", # Legal boilerplate
r"^-+$", # Separator lines
r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}(:\d{2})?$", # Timestamps
r"(?i)^(confidential|proprietary|internal use only)", # Security tags
]
# Add custom patterns if provided
if custom_patterns:
self.cleaning_patterns.extend(custom_patterns)
# Join all patterns with the OR operator
combined_pattern = "|".join(
f"({pattern})" for pattern in self.cleaning_patterns
)
# Compile the combined pattern
self.cleaning_pattern = re.compile(
combined_pattern, flags=re.MULTILINE | re.IGNORECASE
)
# Initialize domain-specific processors
self._init_domain_processors()
def _init_domain_processors(self):
"""Initialize domain-specific processors based on selected domain."""
# Domain-specific entity patterns
self.entity_patterns = {}
# Set up domain-specific patterns and processors
        # Use an if/elif chain so the general fallback below cannot override
        # a domain match (independent if-statements would let the final else
        # clobber the legal/financial/medical settings)
        if self.domain == "legal":
            self.entity_patterns = {
                "case_citation": r"\[\d{4}\]\s+[A-Z]+\s+\d+",  # [2019] UKSC 20
                "statute": r"\b(?:Art\.|Section)\s+\d+(\.\d+)?",  # Art. 5, Section 3.1
                "legal_ref": r"\b[A-Za-z]+\s+v\.?\s+[A-Za-z]+",  # Smith v. Jones
            }
            self.process_entities = self._process_legal_entities
        elif self.domain == "financial":
            self.entity_patterns = {
                "monetary": r"\$\s*\d+(?:\.\d+)?(?:\s*(?:million|billion|trillion))?",  # $5.2 million
                "percentage": r"\d+(?:\.\d+)?\s*%",  # 10.5%
                "date_range": r"(?:Q[1-4]|FY)\s+\d{4}",  # Q2 2023, FY 2022
            }
            self.process_entities = self._process_financial_entities
        elif self.domain == "medical":
            self.entity_patterns = {
                "dosage": r"\d+(?:\.\d+)?\s*(?:mg|mcg|g|ml|oz)",  # 10mg, 5.5ml
                "medical_code": r"[A-Z]\d{2}(?:\.\d+)?",  # ICD codes like E11.9
                "vital_sign": r"\d+(?:\.\d+)?\s*(?:bpm|mmHg|°[CF])",  # 120 bpm, 98.6°F
            }
            self.process_entities = self._process_medical_entities
        elif self.domain == "technical":
            self.entity_patterns = {
                "version": r"v\d+(?:\.\d+){1,3}",  # v1.2.3
                "code_ref": r"(?:\w+\.)+\w+\(\)",  # function calls like math.sqrt()
                "tech_standard": r"(?:RFC|ISO|IEEE)\s*\d+",  # RFC 1918, ISO 9001
            }
            self.process_entities = self._process_technical_entities
        else:  # General domain or fallback
            self.entity_patterns = {
                "url": r"https?://\S+",  # URLs
                "email": r"\S+@\S+\.\S+",  # Email addresses
                "date": r"\d{1,2}[/-]\d{1,2}[/-]\d{2,4}",  # Dates
            }
            self.process_entities = self._process_general_entities
def _process_legal_entities(self, text: str) -> str:
"""Process legal document entities."""
# Preserve citation patterns
# Pattern 1: Case citations [2019] UKSC 20
# Already well-structured, so no changes needed
# Pattern 2: Standardize section references (§3.1, §123)
processed = re.sub(r"§(\d+(\.\d+)?)", r"Section \1", text)
# Pattern 3: Handle legal abbreviations (e.g., Art. -> Article)
processed = re.sub(r"\bArt\.\s+(\d+)", r"Article \1", processed)
        # Pattern 4: Standardize case-name separators (v, vs, vs.) to "v."
        processed = re.sub(r"\bvs?\.?\s+", "v. ", processed)
return processed
def _process_financial_entities(self, text: str) -> str:
"""Process financial document entities."""
        # Pattern 1: Standardize monetary values by stripping thousands separators
        # (capture the full amount, including comma groups and decimals)
        processed = re.sub(
            r"\$\s*(\d{1,3}(?:,\d{3})*(?:\.\d+)?|\d+(?:\.\d+)?)",
            lambda m: f"${m.group(1).replace(',', '')}",
            text,
        )
# Pattern 2: Standardize percentage representations
processed = re.sub(r"(\d+(?:\.\d+)?)\s*(?:percent|pct)", r"\1%", processed)
# Pattern 3: Standardize fiscal periods
processed = re.sub(r"(?:fiscal year|FY)\s+(\d{4})", r"FY \1", processed)
# Pattern 4: Standardize quarterly references
processed = re.sub(r"(?:quarter|Q)(\d)\s+(\d{4})", r"Q\1 \2", processed)
return processed
def _process_medical_entities(self, text: str) -> str:
"""Process medical document entities."""
        # Pattern 1: Standardize dosage units to their abbreviations
        unit_abbrev = {"milligram": "mg", "milliliter": "ml", "mcg": "mcg", "gram": "g"}
        processed = re.sub(
            r"(\d+(?:\.\d+)?)\s*(milligram|milliliter|mcg|gram)s?",
            lambda m: f"{m.group(1)} {unit_abbrev[m.group(2).lower()]}",
            text,
        )
# Pattern 2: Standardize temperature format
processed = re.sub(r"(\d+(?:\.\d+)?)\s*degrees?\s*([CF])", r"\1°\2", processed)
# Pattern 3: Standardize vital signs
processed = re.sub(
r"(\d+(?:\.\d+)?)\s*(?:beats per minute|BPM)", r"\1 bpm", processed
)
return processed
def _process_technical_entities(self, text: str) -> str:
"""Process technical document entities."""
# Pattern 1: Standardize version numbers
processed = re.sub(r"version\s+(\d+(?:\.\d+){1,3})", r"v\1", text)
# Pattern 2: RFC/ISO pattern standardization
processed = re.sub(r"\b(RFC|ISO|IEEE)\s*[:#]?\s*(\d+)", r"\1 \2", processed)
# Pattern 3: Standardize code references
# This is a simplified example
processed = re.sub(r"function\s+(\w+)\s*\(", r"\1(", processed)
return processed
def _process_general_entities(self, text: str) -> str:
"""Process general document entities."""
# General cleaning and standardization
processed = text
# URLs preserved as-is
# Simple date standardization
processed = re.sub(
r"(\d{1,2})/(\d{1,2})/(\d{2})(?!\d)",
r"\1/\2/20\3", # Assume 2-digit years are 2000s
processed,
)
return processed
def remove_boilerplate(self, text: str) -> str:
"""
Remove common document boilerplate patterns from text.
Args:
text: The input text to process
Returns:
Text with boilerplate patterns removed
"""
return self.cleaning_pattern.sub("", text)
def clean_text(self, text: str) -> str:
"""
Clean text while preserving domain-specific entities.
Args:
text: The input text to clean
Returns:
Cleaned text with domain entities preserved
"""
# First remove boilerplate
cleaned = self.remove_boilerplate(text)
# Then process domain-specific entities
cleaned = self.process_entities(cleaned)
return cleaned
    def create_pipeline(self) -> IngestionPipeline:
        """
        Create a document processing pipeline.
        Note: IngestionPipeline transformations must be TransformComponent
        instances, so the plain clean_text() callable is applied to the
        documents in process_documents() before ingestion, and the LangChain
        splitter is wrapped with LangchainNodeParser.
        Returns:
            Configured IngestionPipeline object
        """
        return IngestionPipeline(
            transformations=[
                MarkdownNodeParser(),
                LangchainNodeParser(self.splitter),
                self.embed_model,
            ]
        )
def validate_entity_retention(self, documents: List[Document]) -> Dict[str, float]:
"""
        Estimate entity retention by measuring the semantic similarity of
        document text before and after cleaning.
Args:
documents: List of Document objects to validate
Returns:
Dictionary with validation metrics
"""
if not documents:
return {"entity_retention": 0.0, "processing_time": 0.0}
start_time = time.time()
# Extract original texts
original_texts = [doc.text for doc in documents[:5]] # Sample for performance
# Apply cleaning
processed_texts = [self.clean_text(text) for text in original_texts]
        # Calculate embeddings via the public embedding API (avoids relying
        # on the private _model attribute of HuggingFaceEmbedding)
        try:
            orig_embeds = np.array(
                self.embed_model.get_text_embedding_batch(original_texts)
            )
            proc_embeds = np.array(
                self.embed_model.get_text_embedding_batch(processed_texts)
            )
            # Cosine similarity of each original/processed pair: the diagonal
            # pairs the i-th original text with its processed counterpart
            similarities = cosine_similarity(orig_embeds, proc_embeds).diagonal()
            avg_similarity = float(np.mean(similarities))
processing_time = time.time() - start_time
return {
"entity_retention": avg_similarity * 100, # As percentage
"processing_time": processing_time,
"sample_size": len(original_texts),
}
except Exception as e:
return {"entity_retention": 0.0, "processing_time": 0.0, "error": str(e)}
def process_documents(self, documents: List[Document]) -> Dict[str, Any]:
"""
Process a list of documents.
Args:
documents: List of Document objects to process
Returns:
Dictionary with processing results and stats
"""
if not documents:
return {"status": "error", "message": "No documents provided"}
        try:
            # Clean document text up front: plain callables are not valid
            # IngestionPipeline transformations (see create_pipeline)
            documents = [
                Document(text=self.clean_text(doc.text), metadata=doc.metadata)
                for doc in documents
            ]
            # Create pipeline and process documents
            pipeline = self.create_pipeline()
            nodes = pipeline.run(documents=documents)
            # Create vector index, reusing this processor's embedding model
            # so query-time embeddings match the indexed ones
            index = VectorStoreIndex(nodes, embed_model=self.embed_model)
            query_engine = index.as_query_engine()
# Return success with stats
return {
"status": "success",
"nodes_count": len(nodes),
"documents_count": len(documents),
"domain": self.domain,
"model_name": self.model_name,
"query_engine": query_engine, # This will be used for querying
}
except Exception as e:
return {"status": "error", "message": str(e)}
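

def _demo_clean_text() -> None:
    """Minimal sketch of the cleaning step in isolation. Illustrative only:
    it is not called anywhere in this module and assumes the 'financial'
    embedding model above is downloadable.
    """
    processor = DocumentProcessor(domain="financial")
    sample = "Revenue grew 12 percent to $5,200,000 in fiscal year 2023."
    # Per the financial patterns above: "12 percent" -> "12%",
    # "fiscal year 2023" -> "FY 2023", "$5,200,000" -> "$5200000"
    print(processor.clean_text(sample))
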
class DocumentProcessorTool(Tool):
"""
General-purpose document processing tool with domain specialization.
"""
name = "document_processor"
description = (
"Processes documents with domain-specific models optimized for "
"entity preservation and retrieval performance. Supports legal, "
"financial, medical, technical and general document types."
)
    inputs = {
        "text": {
            "type": "string",
            "description": "Document text to process. Provide either text or file_paths.",
            "nullable": True,
        },
        "file_paths": {
            "type": "string",
            "description": "Comma-separated list of file paths or a directory path containing documents. Provide either text or file_paths.",
            "nullable": True,
        },
        "domain": {
            "type": "string",
            "description": "Document domain for specialized processing: legal, financial, medical, technical, or general. Defaults to 'general'.",
            "nullable": True,
        },
        "model_name": {
            "type": "string",
            "description": "Specific embedding model name to use (optional, overrides domain selection).",
            "nullable": True,
        },
        "query": {
            "type": "string",
            "description": "Optional query to run against the processed documents.",
            "nullable": True,
        },
        "validate_entities": {
            "type": "boolean",
            "description": "Whether to validate entity retention in the processed documents. Defaults to False.",
            "nullable": True,
        },
        "use_gpu": {
            "type": "boolean",
            "description": "Whether to use GPU for embedding calculations if available. Defaults to False.",
            "nullable": True,
        },
    }
output_type = "string"
def _load_documents(self, input_path: str) -> List[Document]:
"""
Load documents from a file path or directory.
Args:
input_path: Path to a file or directory
Returns:
List of Document objects
"""
        if os.path.isfile(input_path):
            # Load only this file; filtering the parent directory by
            # extension would also pick up unrelated files
            return SimpleDirectoryReader(
                input_files=[input_path],
                filename_as_id=True,
            ).load_data()
elif os.path.isdir(input_path):
return SimpleDirectoryReader(
input_dir=input_path,
filename_as_id=True,
).load_data()
else:
raise ValueError(f"Path not found: {input_path}")
    def _create_document_from_text(self, text: str) -> List[Document]:
        """
        Create a Document object from text.
        Args:
            text: Text content
        Returns:
            List containing a single Document object
        """
        # Build the Document directly; round-tripping the text through a
        # temporary file adds I/O without changing the result
        return [Document(text=text)]
def forward(
self,
text: Optional[str] = None,
file_paths: Optional[str] = None,
domain: str = "general",
model_name: Optional[str] = None,
query: Optional[str] = None,
validate_entities: bool = False,
use_gpu: bool = False,
) -> str:
"""
Process documents and optionally run a query.
Args:
text: Document text to process
file_paths: Comma-separated list of file paths or a directory path
domain: Document domain specialization
model_name: Specific embedding model to use
query: Optional query to run against the processed documents
validate_entities: Whether to validate entity retention
use_gpu: Whether to use GPU for embeddings
Returns:
Processing results or query response as a string
"""
# Validate inputs
if not text and not file_paths:
return "Error: Either text or file_paths must be provided."
try:
# Initialize processor
processor = DocumentProcessor(
domain=domain,
model_key=model_name,
use_gpu=use_gpu,
)
# Load documents
documents = []
if text:
documents.extend(self._create_document_from_text(text))
if file_paths:
# Handle comma-separated paths
paths = [path.strip() for path in file_paths.split(",")]
for path in paths:
try:
docs = self._load_documents(path)
documents.extend(docs)
except Exception as e:
return f"Error loading documents from {path}: {str(e)}"
# Check if we have documents to process
if not documents:
return "Error: No valid documents found."
# Validate entity retention if requested
validation_results = {}
if validate_entities:
validation_results = processor.validate_entity_retention(documents)
# Process documents
result = processor.process_documents(documents)
if result["status"] != "success":
return f"Processing error: {result['message']}"
            # Build the shared stats block once (used with or without a query)
            stats = (
                f"Documents processed: {result['documents_count']}\n"
                f"Text chunks: {result['nodes_count']}\n"
                f"Domain: {result['domain']}\n"
                f"Model: {result['model_name']}\n"
            )
            # Add validation results if available
            if validation_results:
                stats += (
                    "\n=== Entity Retention Validation ===\n"
                    f"Entity retention: {validation_results.get('entity_retention', 0):.2f}%\n"
                    f"Processing time: {validation_results.get('processing_time', 0):.2f} seconds\n"
                )
            # Run query if provided
            if query and "query_engine" in result:
                response = result["query_engine"].query(query)
                return f"Query: {query}\n\nResponse: {response}\n\n{stats}"
            # If no query, return processing stats
            output = "Document processing complete.\n\n" + stats
            output += "\nThe documents are now ready for querying. Use the 'query' parameter to run a query."
            return output
except Exception as e:
return f"Error: {str(e)}"