# securedocai / document_processor_hf.py
# Last commit 0591f7d ("Update document_processor_hf.py") by navid72m (verified).
import os
import re
import logging
import tempfile
from typing import List, Dict, Any, Tuple, Optional
import numpy as np
from sentence_transformers import SentenceTransformer
from dataclasses import dataclass
from collections import defaultdict
import functools
# Module-level logger; handlers and levels are left to the host application.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class DocumentChunk:
    """Represents a document chunk with metadata.

    ``entities`` defaults to ``None`` (not an empty list) for backward
    compatibility; the annotation is ``Optional`` to reflect that.
    """
    text: str  # chunk text
    chunk_id: int  # sequential id of the chunk within its document
    start_pos: int  # character offset where the chunk starts
    end_pos: int  # character offset just past the chunk's end
    entities: Optional[List[str]] = None  # entity strings found in the chunk, if computed
    chunk_type: str = "content"
    relevance_score: float = 0.0
@dataclass
class Entity:
    """A single extracted entity together with its category and location."""

    text: str  # the entity text
    label: str  # category, e.g. 'PERSON', 'EMAIL', 'PHONE', 'DATE', 'ORGANIZATION'
    confidence: float  # extractor confidence (regex matches in this module use 0.8)
    start_pos: int  # character offset where the match begins
    end_pos: int  # character offset just past the match
# Simple caching decorator to replace Streamlit's cache
def simple_cache(func):
    """Memoize ``func`` on its arguments, ignoring the first positional one.

    Intended for methods: ``args[0]`` (``self``) is left out of the key, so
    the cached result is shared by every instance of the class. The key is
    built from ``str()`` of the remaining positional args and the sorted
    keyword args, which tolerates unhashable arguments. The cache is never
    evicted, and a call that raises stores nothing.
    """
    memo = {}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        cache_key = f"{args[1:]}{sorted(kwargs.items())}"
        if cache_key in memo:
            return memo[cache_key]
        result = func(*args, **kwargs)
        memo[cache_key] = result
        return result

    return wrapper
class DocumentProcessor:
    """
    Streamlined document processor for Hugging Face Spaces deployment.
    Focuses on core functionality with minimal dependencies.

    Pipeline (see ``process_document``): extract text from a file, guess the
    document type from keyword indicators, extract regex-based entities,
    split the text into overlapping sentence-aligned chunks, and embed each
    chunk with a MiniLM sentence-transformer. ``query_document`` then ranks
    the chunks by cosine similarity to a query.

    Results for the most recently processed document are stored on the
    instance and replaced by the next ``process_document`` call.
    """

    # Labels whose patterns rely on capitalisation ([A-Z][a-z]...) to spot
    # proper nouns. Matching them with re.IGNORECASE would accept almost any
    # run of ordinary lowercase words, so they are matched case-sensitively.
    _CASE_SENSITIVE_LABELS = frozenset({'PERSON', 'ORGANIZATION'})

    def __init__(self):
        """Initialize empty document state, the embedding model and heuristics.

        Raises:
            Exception: re-raised by ``_load_embedding_model`` when the
                sentence-transformer model cannot be loaded.
        """
        self.chunks = []  # DocumentChunk objects of the current document
        self.embeddings = []  # one embedding (list of floats) per chunk, same order
        self.entities = []  # Entity objects of the current document
        self.document_text = ""
        self.document_type = "general"
        # simple_cache keys on the arguments after ``self``, so all instances
        # in the process share one loaded model.
        self.embed_model = self._load_embedding_model()
        # Simple entity patterns for basic extraction. When a pattern has
        # capture groups, group 1 is used as the entity text; otherwise the
        # whole match is used (see extract_entities).
        self.entity_patterns = {
            'PERSON': [
                r'\b([A-Z][a-z]{1,15}\s+[A-Z][a-z]{1,15}(?:\s+[A-Z][a-z]{1,15})?)\b',
                r'\b(?:Mr\.|Ms\.|Mrs\.|Dr\.)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,2})'
            ],
            'EMAIL': [
                # TLD is letters only ([A-Za-z]; a '|' inside [] would be a literal).
                r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
            ],
            'PHONE': [
                # Non-capturing groups so the entity is the full number rather
                # than only the first group (the area code).
                r'\+?1?[-.\s]?\(?[0-9]{3}\)?[-.\s]?[0-9]{3}[-.\s]?[0-9]{4}'
            ],
            'DATE': [
                r'\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b',
                r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b'
            ],
            'ORGANIZATION': [
                r'\b([A-Z][a-zA-Z\s&.,]+?)\s+(?:Inc|LLC|Corp|Company|Technologies|University|College|Institute)\b'
            ]
        }
        # Document type indicators: lowercase keywords searched for as
        # substrings of the lowercased document (see detect_document_type).
        self.doc_type_indicators = {
            'resume': ['objective', 'summary', 'experience', 'education', 'skills', 'employment'],
            'report': ['executive summary', 'methodology', 'findings', 'conclusion', 'analysis'],
            'contract': ['agreement', 'party', 'whereas', 'terms', 'conditions'],
            'manual': ['instructions', 'procedure', 'step', 'guide', 'tutorial'],
            'academic': ['abstract', 'introduction', 'literature review', 'methodology', 'results']
        }

    @simple_cache
    def _load_embedding_model(self):
        """Load the all-MiniLM-L6-v2 sentence-transformer (cached by ``simple_cache``).

        Raises:
            Exception: any error from ``SentenceTransformer`` is logged and re-raised.
        """
        try:
            logger.info("🔄 Loading embedding model...")
            model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            logger.info("✅ Embedding model loaded successfully")
            return model
        except Exception as e:
            logger.error(f"❌ Failed to load embedding model: {e}")
            raise

    def detect_document_type(self, text: str) -> str:
        """Guess the document type from keyword indicators.

        Each type scores one point per indicator found as a substring of the
        lowercased text. The best-scoring type wins (ties go to the earlier
        entry in ``doc_type_indicators``), but only if it has at least 2 hits.

        Returns:
            A key of ``doc_type_indicators``, or ``'general'``.
        """
        text_lower = text.lower()
        type_scores = {
            doc_type: sum(1 for indicator in indicators if indicator in text_lower)
            for doc_type, indicators in self.doc_type_indicators.items()
        }
        if type_scores:
            detected_type = max(type_scores, key=type_scores.get)
            if type_scores[detected_type] >= 2:
                return detected_type
        return 'general'

    def extract_text_from_file(self, file_path: str) -> str:
        """Extract text from a file, dispatching on its (case-insensitive) extension.

        Unknown extensions are read as text. Any exception from an extractor
        is logged and turned into a bracketed error string; nothing is raised.
        """
        _, ext = os.path.splitext(file_path.lower())
        try:
            if ext == '.pdf':
                return self._extract_from_pdf(file_path)
            elif ext in ('.txt', '.md'):
                return self._extract_from_text(file_path)
            elif ext in ('.jpg', '.jpeg', '.png', '.bmp', '.tiff'):
                return self._extract_from_image(file_path)
            elif ext == '.docx':
                return self._extract_from_docx(file_path)
            else:
                # Fallback: try to read as text
                return self._extract_from_text(file_path)
        except Exception as e:
            logger.error(f"Text extraction failed for {file_path}: {e}")
            return f"[Error extracting text from {os.path.basename(file_path)}: {str(e)}]"

    def _extract_from_pdf(self, file_path: str) -> str:
        """Extract text from a PDF using PyPDF2, falling back to pdfplumber.

        Returns a placeholder message when neither library can be imported.
        Errors opening or parsing the file itself propagate to the caller.
        """
        try:
            import PyPDF2
            with open(file_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                return self._join_pdf_pages(reader.pages)
        except ImportError:
            try:
                import pdfplumber
                with pdfplumber.open(file_path) as pdf:
                    return self._join_pdf_pages(pdf.pages)
            except ImportError:
                return f"[PDF file: {os.path.basename(file_path)}]\nPDF extraction libraries not available."

    @staticmethod
    def _join_pdf_pages(pages) -> str:
        """Concatenate page texts (each followed by a blank line), capped near 100,000 characters.

        A page whose extraction raises is logged and skipped. Once the
        collected text exceeds 100,000 characters a truncation note is
        appended and the remaining pages are ignored.
        """
        text = ""
        for page_num, page in enumerate(pages):
            try:
                # `or ""` keeps a page whose extractor returns None from
                # raising TypeError and being dropped.
                text += (page.extract_text() or "") + "\n\n"
            except Exception as e:
                logger.warning(f"Error extracting page {page_num}: {e}")
                continue
            # Limit to prevent memory issues (100,000 characters).
            if len(text) > 100000:
                text += "\n[Note: PDF truncated due to size]"
                break
        return text

    def _extract_from_text(self, file_path: str) -> str:
        """Read a text file, trying UTF-8, then Windows-1252, then Latin-1.

        cp1252 is tried before latin-1 because latin-1 decodes every byte
        sequence, so anything listed after it could never run. As a result
        the final placeholder is only a defensive default. I/O errors (for
        example a missing file) propagate to the caller.
        """
        for encoding in ('utf-8', 'cp1252', 'latin-1'):
            try:
                with open(file_path, 'r', encoding=encoding) as f:
                    return f.read()
            except UnicodeDecodeError:
                continue
        return f"[Text file: {os.path.basename(file_path)}]\nCould not decode file."

    def _extract_from_image(self, file_path: str) -> str:
        """Extract text from an image with pytesseract OCR.

        Always returns a string prefixed with ``[Image: <name>]``; missing
        libraries or OCR errors are reported in the text, not raised.
        """
        try:
            import pytesseract
            from PIL import Image
            # The context manager closes the file handle opened by Image.open.
            with Image.open(file_path) as image:
                text = pytesseract.image_to_string(image)
            if text.strip():
                return f"[Image: {os.path.basename(file_path)}]\n\nExtracted text:\n{text}"
            else:
                return f"[Image: {os.path.basename(file_path)}]\nNo text could be extracted from this image."
        except ImportError:
            return f"[Image: {os.path.basename(file_path)}]\nOCR library not available for text extraction."
        except Exception as e:
            return f"[Image: {os.path.basename(file_path)}]\nError extracting text: {str(e)}"

    def _extract_from_docx(self, file_path: str) -> str:
        """Extract paragraph text from a DOCX file, one paragraph per line.

        Only ``doc.paragraphs`` is read. Errors are reported in the returned
        string, not raised.
        """
        try:
            import docx
            doc = docx.Document(file_path)
            return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
        except ImportError:
            return f"[DOCX file: {os.path.basename(file_path)}]\nDOCX extraction library not available."
        except Exception as e:
            return f"[DOCX file: {os.path.basename(file_path)}]\nError extracting text: {str(e)}"

    def extract_entities(self, text: str) -> List[Entity]:
        """Extract entities from ``text`` with the regex patterns in ``entity_patterns``.

        PERSON and ORGANIZATION patterns are matched case-sensitively. All
        other labels are matched with ``re.IGNORECASE``. Every match gets
        confidence 0.8, and candidates are filtered by ``_is_valid_entity``
        and deduplicated by ``_deduplicate_entities``.

        Returns:
            Unique entities; ``start_pos``/``end_pos`` span the whole regex match.
        """
        entities = []
        for label, patterns in self.entity_patterns.items():
            flags = re.MULTILINE
            if label not in self._CASE_SENSITIVE_LABELS:
                flags |= re.IGNORECASE
            for pattern in patterns:
                for match in re.finditer(pattern, text, flags):
                    # Prefer the first capture group (e.g. the name after
                    # "Dr."); patterns without groups use the whole match.
                    raw_text = match.group(1) if match.groups() else match.group(0)
                    entity_text = raw_text.strip()
                    if self._is_valid_entity(entity_text, label):
                        entities.append(Entity(
                            text=entity_text,
                            label=label,
                            confidence=0.8,  # Default confidence for regex matches
                            start_pos=match.start(),
                            end_pos=match.end()
                        ))
        return self._deduplicate_entities(entities)

    def _is_valid_entity(self, text: str, label: str) -> bool:
        """Reject empty or one-character candidates and implausible PERSON names.

        A PERSON must have 2-3 whitespace-separated words, none of which is
        a common resume section heading.
        """
        if not text or len(text.strip()) < 2:
            return False
        if label == 'PERSON':
            words = text.split()
            if len(words) < 2 or len(words) > 3:
                return False
            # Should not contain common non-name words
            non_name_words = {'resume', 'objective', 'summary', 'experience', 'education', 'skills'}
            if any(word.lower() in non_name_words for word in words):
                return False
        return True

    def _deduplicate_entities(self, entities: List[Entity]) -> List[Entity]:
        """Keep the first entity per (lowercased text, label), sorted by confidence descending.

        The sort is stable, so entities with equal confidence keep their first-seen order.
        """
        seen = set()
        unique_entities = []
        for entity in entities:
            key = (entity.text.lower(), entity.label)
            if key not in seen:
                seen.add(key)
                unique_entities.append(entity)
        return sorted(unique_entities, key=lambda x: x.confidence, reverse=True)

    def create_chunks(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[DocumentChunk]:
        """Split ``text`` into sentence-aligned chunks with word overlap.

        Whitespace is collapsed to single spaces, then the text is split
        after ``.``, ``!`` or ``?``. Sentences are packed into a chunk until
        adding the next one would exceed ``chunk_size`` characters. A single
        sentence longer than ``chunk_size`` still becomes one oversized chunk.
        Each new chunk starts with the last ``ceil(overlap / 4)`` words of the
        previous chunk, provided that chunk has more words than that. No
        overlap is carried when ``overlap <= 0``.

        Args:
            text: Raw document text.
            chunk_size: Soft maximum chunk length in characters.
            overlap: Overlap budget, converted to a word count as ``ceil(overlap / 4)``.

        Returns:
            Chunks in document order. ``start_pos``/``end_pos`` are character
            offsets into the whitespace-normalised text, not into ``text``.
        """
        chunks = []
        normalized = re.sub(r'\s+', ' ', text).strip()
        sentences = re.split(r'(?<=[.!?])\s+', normalized)
        # ceil(overlap / 4) without importing math; 0 when overlap <= 0.
        overlap_word_count = max(0, -(-overlap // 4))

        current_chunk = ""
        current_start = 0
        chunk_id = 0

        for sentence in sentences:
            if current_chunk and len(current_chunk) + len(sentence) > chunk_size:
                chunks.append(self._make_chunk(current_chunk, chunk_id, current_start))
                chunk_id += 1
                words = current_chunk.split()
                if overlap_word_count and len(words) > overlap_word_count:
                    overlap_text = " ".join(words[-overlap_word_count:])
                    # The chunk is single-space separated, so it ends with
                    # exactly overlap_text; the next chunk starts there.
                    current_start += len(current_chunk) - len(overlap_text)
                    current_chunk = overlap_text + " " + sentence
                else:
                    # Skip the finished chunk plus the one separating space.
                    current_start += len(current_chunk) + 1
                    current_chunk = sentence
            elif current_chunk:
                current_chunk += " " + sentence
            else:
                current_chunk = sentence

        # Add final chunk (empty only when the normalised text is empty).
        if current_chunk:
            chunks.append(self._make_chunk(current_chunk, chunk_id, current_start))
        return chunks

    def _make_chunk(self, chunk_text: str, chunk_id: int, start_pos: int) -> DocumentChunk:
        """Build a DocumentChunk for ``chunk_text`` beginning at ``start_pos``."""
        return DocumentChunk(
            text=chunk_text,
            chunk_id=chunk_id,
            start_pos=start_pos,
            end_pos=start_pos + len(chunk_text),
            entities=self._extract_chunk_entities(chunk_text)
        )

    def _extract_chunk_entities(self, chunk_text: str) -> List[str]:
        """Return texts of ``self.entities`` that occur in ``chunk_text`` (case-insensitive substring)."""
        chunk_lower = chunk_text.lower()
        return [entity.text for entity in self.entities if entity.text.lower() in chunk_lower]

    def create_embeddings(self, chunks: List[DocumentChunk]) -> List[List[float]]:
        """Embed each chunk's text in batches of 32.

        On any encoding error the failure is logged and a zero vector of
        length 384 (the MiniLM dimension) is returned for every chunk, so the
        result always has one entry per chunk.
        """
        texts = [chunk.text for chunk in chunks]
        try:
            batch_size = 32
            embeddings = []
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i + batch_size]
                batch_embeddings = self.embed_model.encode(batch, show_progress_bar=False)
                embeddings.extend(batch_embeddings.tolist())
            return embeddings
        except Exception as e:
            logger.error(f"Failed to create embeddings: {e}")
            # Return zero embeddings as fallback
            return [[0.0] * 384 for _ in texts]  # 384 is MiniLM dimension

    def generate_suggestions(self, document_type: str, entities: List[Entity]) -> List[str]:
        """Generate up to 5 suggested questions from the document type and entities.

        Resume questions name the first PERSON entity when one exists.
        Contact and organization questions are appended when EMAIL or
        ORGANIZATION entities exist, but only survive if fewer than five
        type-based questions came first.
        """
        suggestions = []
        person_entities = [e for e in entities if e.label == 'PERSON']
        primary_person = person_entities[0] if person_entities else None

        if document_type == 'resume':
            if primary_person:
                suggestions.extend([
                    "Whose resume is this?",
                    f"What are {primary_person.text}'s qualifications?",
                    f"What skills does {primary_person.text} have?",
                    f"What is {primary_person.text}'s work experience?"
                ])
            else:
                suggestions.extend([
                    "Whose CV is this?",
                    "What are the candidate's qualifications?",
                    "What skills are mentioned?",
                    "What work experience is listed?"
                ])
        elif document_type == 'report':
            suggestions.extend([
                "What is the main topic of this report?",
                "What are the key findings?",
                "What methodology was used?",
                "What are the conclusions?"
            ])
        else:
            suggestions.extend([
                "What is this document about?",
                "What are the main topics discussed?",
                "Who are the key people mentioned?",
                "What important information is contained here?"
            ])

        # Add entity-specific suggestions
        if any(e.label == 'EMAIL' for e in entities):
            suggestions.append("What contact information is provided?")
        if any(e.label == 'ORGANIZATION' for e in entities):
            suggestions.append("What organizations are mentioned?")
        return suggestions[:5]

    def process_document(self, file_path: str, use_smart_processing: bool = True) -> Dict[str, Any]:
        """
        Process a document and extract all information.

        Previous results stored on the instance are cleared first, so a
        failed run never leaves an older document queryable.

        Args:
            file_path: Path to the document file
            use_smart_processing: Whether to use smart entity extraction

        Returns:
            Dictionary with processing results. On failure:
            ``{'success': False, 'error': <message>}``.
        """
        try:
            logger.info(f"📄 Processing document: {os.path.basename(file_path)}")
            # Reset per-document state before anything can return early.
            self.chunks = []
            self.embeddings = []
            self.entities = []
            self.document_type = "general"

            # NOTE(review): extract_text_from_file returns bracketed
            # placeholder/error strings instead of raising; those usually pass
            # the length check below and get processed as document text.
            self.document_text = self.extract_text_from_file(file_path)
            if not self.document_text or len(self.document_text.strip()) < 10:
                return {
                    'success': False,
                    'error': 'Could not extract meaningful text from document'
                }

            self.document_type = self.detect_document_type(self.document_text)
            # Entities must be set before chunking: each chunk records the
            # entities it contains via _extract_chunk_entities.
            self.entities = self.extract_entities(self.document_text) if use_smart_processing else []
            self.chunks = self.create_chunks(self.document_text)
            self.embeddings = self.create_embeddings(self.chunks)
            suggestions = self.generate_suggestions(self.document_type, self.entities)

            logger.info(f"✅ Processing complete: {len(self.chunks)} chunks, {len(self.entities)} entities")
            return {
                'success': True,
                'chunks': self.chunks,
                'entities': self.entities,
                'document_type': self.document_type,
                'entities_found': len(self.entities),
                'suggestions': suggestions,
                'text_length': len(self.document_text),
                'processing_stats': {
                    'chunks_created': len(self.chunks),
                    'entities_extracted': len(self.entities),
                    'document_type': self.document_type
                }
            }
        except Exception as e:
            logger.exception(f"❌ Document processing failed: {e}")
            return {
                'success': False,
                'error': str(e)
            }

    def query_document(
        self,
        query: str,
        top_k: int = 5,
        use_smart_retrieval: bool = True,
        use_prf: bool = False,
        use_variants: bool = False,
        use_reranking: bool = False
    ) -> Dict[str, Any]:
        """
        Query the processed document.

        Args:
            query: User's question
            top_k: Number of chunks to retrieve (negative values are treated as 0)
            use_smart_retrieval: Whether to use entity-aware retrieval
            use_prf: Whether to use pseudo relevance feedback
            use_variants: Whether to generate query variants
            use_reranking: Whether to apply reranking

        NOTE: ``use_prf``, ``use_variants`` and ``use_reranking`` are accepted
        for API compatibility and echoed in ``enhancement_info``; this method
        does not otherwise act on them.

        Returns:
            Dictionary with context and metadata, or ``context``/``chunks``/
            ``error`` keys when no document is loaded or querying fails.
        """
        if not self.chunks or not self.embeddings:
            return {
                'context': '',
                'chunks': [],
                'error': 'No document processed'
            }
        try:
            top_k = max(0, top_k)
            query_vector = np.asarray(self.embed_model.encode([query])[0], dtype=float)
            chunk_matrix = np.asarray(self.embeddings, dtype=float)

            # Cosine similarity per chunk. Zero-norm vectors (e.g. the
            # fallback embeddings from create_embeddings) score 0.0 instead of
            # NaN, which would otherwise make the ranking meaningless.
            norms = np.linalg.norm(chunk_matrix, axis=1) * np.linalg.norm(query_vector)
            dots = chunk_matrix @ query_vector
            scores = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)
            similarities = [(idx, float(score)) for idx, score in enumerate(scores)]
            similarities.sort(key=lambda pair: pair[1], reverse=True)

            if use_smart_retrieval:
                similarities = self._apply_smart_boosts(query, similarities)

            top = similarities[:top_k]
            selected_chunks = [self.chunks[idx] for idx, _ in top]
            context = "\n\n".join(
                f"[Chunk {i + 1}]\n{chunk.text}" for i, chunk in enumerate(selected_chunks)
            )

            return {
                'context': context,
                'chunks': selected_chunks,
                'similarities': [score for _, score in top],
                'query_analysis': {
                    'entity_matches': self._find_entity_matches(query),
                    'query_type': self._analyze_query_type(query)
                },
                'enhancement_info': {
                    'smart_retrieval_applied': use_smart_retrieval,
                    'prf_applied': use_prf,
                    'variants_generated': use_variants,
                    'reranking_applied': use_reranking
                }
            }
        except Exception as e:
            logger.error(f"❌ Query processing failed: {e}")
            return {
                'context': '',
                'chunks': [],
                'error': str(e)
            }

    def _apply_smart_boosts(self, query: str, similarities: List[Tuple[int, float]]) -> List[Tuple[int, float]]:
        """Boost chunk scores using entities, then re-sort descending.

        Boosts applied to each chunk:
        - +0.2 * confidence for each entity that appears in both the query and
          the chunk (case-insensitive substring).
        - +0.3 once if the query contains 'who', 'whose' or 'name' (as
          substrings) and the chunk contains any PERSON entity.
        Boosted scores are capped at 1.0.
        """
        query_lower = query.lower()
        # Loop-invariant work hoisted out of the per-chunk loop.
        query_entities = [e for e in self.entities if e.text.lower() in query_lower]
        is_identity_query = any(word in query_lower for word in ('who', 'whose', 'name'))
        person_names = [e.text.lower() for e in self.entities if e.label == 'PERSON']

        boosted_similarities = []
        for chunk_idx, similarity in similarities:
            chunk_lower = self.chunks[chunk_idx].text.lower()
            boost = 0.0
            for entity in query_entities:
                if entity.text.lower() in chunk_lower:
                    boost += 0.2 * entity.confidence
            if is_identity_query and any(name in chunk_lower for name in person_names):
                boost += 0.3
            boosted_similarities.append((chunk_idx, min(1.0, similarity + boost)))

        boosted_similarities.sort(key=lambda x: x[1], reverse=True)
        return boosted_similarities

    def _find_entity_matches(self, query: str) -> List[str]:
        """Return texts of current entities that occur in ``query`` (case-insensitive substring)."""
        query_lower = query.lower()
        return [entity.text for entity in self.entities if entity.text.lower() in query_lower]

    def _analyze_query_type(self, query: str) -> str:
        """Classify a query by the first matching keyword group.

        NOTE: keywords are matched as substrings, so e.g. 'how' also matches
        inside 'show', and 'time' inside 'sometimes'.
        """
        query_lower = query.lower()
        if any(word in query_lower for word in ['who', 'whose', 'name']):
            return 'identity'
        elif any(word in query_lower for word in ['what', 'describe', 'explain']):
            return 'descriptive'
        elif any(word in query_lower for word in ['when', 'date', 'time']):
            return 'temporal'
        elif any(word in query_lower for word in ['where', 'location']):
            return 'location'
        elif any(word in query_lower for word in ['how', 'process', 'method']):
            return 'procedural'
        else:
            return 'general'

    def get_document_stats(self) -> Dict[str, Any]:
        """Get statistics about the processed document (entity counts per label, mean chunk length)."""
        entities_by_type = defaultdict(int)
        for entity in self.entities:
            entities_by_type[entity.label] += 1
        return {
            'document_type': self.document_type,
            'text_length': len(self.document_text),
            'chunks_count': len(self.chunks),
            'entities_count': len(self.entities),
            'entities_by_type': dict(entities_by_type),
            'avg_chunk_length': np.mean([len(chunk.text) for chunk in self.chunks]) if self.chunks else 0
        }