"""Document classification helpers built on Hugging Face document models.

Text is extracted from the input file (PDF, Word, Excel, plain text, or OCR for
images) and scored by a weighted mix of the DiT image classifier (image files
only), keyword heuristics, and a fallback text-classification pipeline.
"""
import io
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import magic
import numpy as np
import openpyxl
import pandas as pd
import PyPDF2
import pytesseract
import torch
from docx import Document
from PIL import Image
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3ForTokenClassification,
    AutoImageProcessor,
    AutoModelForImageClassification,
)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DocumentClassifier:
    """
    A document classifier that uses Hugging Face models to classify different types of documents.
    """

    # Extensions routed to OCR for text and to DiT for image classification.
    _IMAGE_EXTENSIONS = frozenset({'jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff'})

    # Words dropped when turning a free-form label into matching keywords.
    _LABEL_STOPWORDS = frozenset({
        'the', 'a', 'an', 'and', 'or', 'of', 'with', 'for', 'to', 'by', 'on', 'in',
        'this', 'that', 'valid', 'expired', 'less', 'than', 'one', 'two', 'years',
        'year', 'more', 'not', 'certificate', 'document', 'card', 'form', 'report',
        'record', 'statement', 'results', 'order', 'stamp', 'authority',
        'authorization', 'affidavit', 'evaluation',
    })

    # Splits a label on anything that is not a letter, digit or '+'.
    _LABEL_SPLIT_RE = re.compile(r"[^a-zA-Z0-9+]+")

    # Relative weight of each scoring source when they are combined.
    _DIT_WEIGHT = 0.5  # DiT sees the page image, so it gets the highest weight
    _LAYOUTLMV3_WEIGHT = 0.3
    _KEYWORD_WEIGHT = 0.2

    # Character cap applied before the text-based scorers and the fallback pipeline.
    _MAX_TEXT_CHARS = 512

    # Combined scores whose maximum is below this trigger the fallback pipeline.
    _FALLBACK_THRESHOLD = 0.1

    def __init__(self):
        """
        Initialize the document classifier with Microsoft models.

        Each model is loaded independently; a model that fails to load is set
        to None and the corresponding scorer degrades to its "unknown" result.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Initialize LayoutLMv3 for document understanding
        try:
            logger.info("Loading LayoutLMv3 model...")
            self.layoutlmv3_processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
            self.layoutlmv3_model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
            self.layoutlmv3_model.to(self.device)
            logger.info("✅ LayoutLMv3 model loaded successfully")
        except Exception as e:
            logger.warning(f"Failed to load LayoutLMv3 model: {e}")
            self.layoutlmv3_processor = None
            self.layoutlmv3_model = None

        # Initialize DIT model for document classification
        try:
            logger.info("Loading DIT model...")
            self.dit_processor = AutoImageProcessor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
            self.dit_model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
            self.dit_model.to(self.device)
            logger.info("✅ DIT model loaded successfully")
        except Exception as e:
            logger.warning(f"Failed to load DIT model: {e}")
            self.dit_processor = None
            self.dit_model = None

        # Fallback text classifier
        try:
            self.fallback_classifier = pipeline(
                "text-classification",
                model="distilbert-base-uncased-finetuned-sst-2-english",
                device=0 if self.device == "cuda" else -1,
            )
        except Exception as e:
            logger.warning(f"Failed to load fallback classifier: {e}")
            self.fallback_classifier = None

        # Document type mappings
        self.document_types = {
            'pdf': 'PDF Document',
            'docx': 'Word Document',
            'doc': 'Word Document',
            'txt': 'Text Document',
            'xlsx': 'Excel Spreadsheet',
            'xls': 'Excel Spreadsheet',
            'csv': 'CSV File',
            'jpg': 'Image',
            'jpeg': 'Image',
            'png': 'Image',
            'gif': 'Image',
            'bmp': 'Image',
            'tiff': 'Image',
            'ppt': 'PowerPoint Presentation',
            'pptx': 'PowerPoint Presentation',
        }

        # RVL-CDIP document classes (used by DIT model); the list index is
        # taken as the model's output index in classify_with_dit_model.
        self.rvlcdip_classes = [
            'letter', 'form', 'email', 'handwritten', 'advertisement',
            'scientific report', 'scientific publication', 'specification',
            'file folder', 'news article', 'budget', 'invoice',
            'presentation', 'questionnaire', 'resume', 'memo',
        ]

        # Content-based classification keywords.
        # NOTE(review): keywords are matched as plain substrings, so short
        # entries such as 'ad', 'cv' or 're:' also match inside longer words.
        self.content_keywords = {
            'letter': ['dear', 'sincerely', 'regards', 'yours truly', 'to whom it may concern'],
            'form': ['form', 'application', 'registration', 'signature', 'date', 'name', 'address'],
            'email': ['subject:', 'from:', 'to:', 'cc:', 'bcc:', 'sent:', 'received:'],
            'handwritten': ['handwritten', 'hand written', 'manuscript', 'notes'],
            'advertisement': ['advertisement', 'ad', 'promotion', 'sale', 'offer', 'discount'],
            'scientific report': ['abstract', 'introduction', 'methodology', 'results', 'conclusion', 'references'],
            'scientific publication': ['journal', 'publication', 'peer reviewed', 'doi:', 'issn:', 'volume'],
            'specification': ['specification', 'requirements', 'technical', 'system', 'software', 'hardware'],
            'file folder': ['folder', 'directory', 'file', 'document'],
            'news article': ['news', 'article', 'breaking', 'reporter', 'journalist', 'headline'],
            'budget': ['budget', 'financial', 'revenue', 'expense', 'profit', 'loss', 'balance'],
            'invoice': ['invoice', 'bill', 'payment', 'amount due', 'total', 'subtotal', 'tax'],
            'presentation': ['presentation', 'slide', 'powerpoint', 'agenda', 'meeting'],
            'questionnaire': ['questionnaire', 'survey', 'question', 'answer', 'response'],
            'resume': ['resume', 'cv', 'curriculum vitae', 'experience', 'education', 'skills'],
            'memo': ['memo', 'memorandum', 'to:', 'from:', 'date:', 'subject:', 're:'],
        }

    # ------------------------------------------------------------------ #
    # Text extraction
    # ------------------------------------------------------------------ #

    @staticmethod
    def _get_extension(file_path: str) -> str:
        """Return the lower-cased extension of *file_path* without the dot."""
        return os.path.splitext(file_path)[1].lower().lstrip('.')

    @staticmethod
    def _read_text_with_fallback(file_path: str) -> str:
        """Read *file_path* as UTF-8, retrying as Latin-1 if it is not valid UTF-8.

        Raises:
            OSError: If the file cannot be opened or read.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except UnicodeDecodeError:
            # Latin-1 maps every byte to a code point, so this decode cannot fail.
            with open(file_path, 'r', encoding='latin-1') as f:
                return f.read()

    def extract_text_from_file(self, file_path: str) -> str:
        """
        Extract text content from various file types.

        The extractor is chosen from the file extension; unknown extensions
        are read as plain text.

        Args:
            file_path: Path to the file

        Returns:
            Extracted text content, or "" if extraction fails.
        """
        try:
            file_extension = self._get_extension(file_path)

            if file_extension == 'pdf':
                return self._extract_pdf_text(file_path)
            if file_extension in ('docx', 'doc'):
                return self._extract_word_text(file_path)
            if file_extension in ('xlsx', 'xls'):
                return self._extract_excel_text(file_path)
            if file_extension == 'txt':
                return self._extract_txt_text(file_path)
            if file_extension in self._IMAGE_EXTENSIONS:
                return self._extract_image_text(file_path)

            # Try to read as text file
            return self._read_text_with_fallback(file_path)
        except Exception as e:
            logger.error(f"Error extracting text from {file_path}: {e}")
            return ""

    def _extract_pdf_text(self, file_path: str) -> str:
        """Extract text from PDF files, one newline-terminated chunk per page."""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # `or ""` defensively guards against a None page result so one
                # empty page cannot abort extraction of the whole document.
                return "".join(
                    (page.extract_text() or "") + "\n" for page in pdf_reader.pages
                )
        except Exception as e:
            logger.error(f"Error extracting PDF text: {e}")
            return ""

    def _extract_word_text(self, file_path: str) -> str:
        """Extract paragraph text from Word documents, one paragraph per line."""
        try:
            doc = Document(file_path)
            return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
        except Exception as e:
            logger.error(f"Error extracting Word text: {e}")
            return ""

    def _extract_excel_text(self, file_path: str) -> str:
        """Extract cell values from every sheet, one space-joined row per line."""
        try:
            workbook = openpyxl.load_workbook(file_path)
            lines = []
            for sheet_name in workbook.sheetnames:
                sheet = workbook[sheet_name]
                for row in sheet.iter_rows(values_only=True):
                    lines.append(" ".join(str(cell) for cell in row if cell is not None) + "\n")
            return "".join(lines)
        except Exception as e:
            logger.error(f"Error extracting Excel text: {e}")
            return ""

    def _extract_txt_text(self, file_path: str) -> str:
        """Extract text from plain text files (UTF-8, falling back to Latin-1)."""
        try:
            return self._read_text_with_fallback(file_path)
        except Exception as e:
            logger.error(f"Error extracting text file: {e}")
            return ""

    def _extract_image_text(self, file_path: str) -> str:
        """Extract text from images using OCR via pytesseract."""
        try:
            # The context manager closes the underlying file handle.
            with Image.open(file_path) as img:
                image = img.convert("RGB")
            text = pytesseract.image_to_string(image)
            return text or ""
        except Exception as e:
            logger.error(f"Error extracting image text (OCR): {e}")
            return ""

    # ------------------------------------------------------------------ #
    # Keyword scoring
    # ------------------------------------------------------------------ #

    @staticmethod
    def _keyword_hit_ratio(text_lower: str, keywords: List[str]) -> float:
        """Return the fraction of *keywords* found as substrings of *text_lower* (0 if none given)."""
        if not keywords:
            return 0
        hits = sum(1 for keyword in keywords if keyword in text_lower)
        return hits / len(keywords)

    def _content_keyword_scores(self, text_lower: str) -> Dict[str, float]:
        """Score *text_lower* against every entry of self.content_keywords (unnormalized)."""
        return {
            doc_type: self._keyword_hit_ratio(text_lower, keywords)
            for doc_type, keywords in self.content_keywords.items()
        }

    def _tokenize_label(self, label: str) -> List[str]:
        """Tokenize a label into meaningful keywords for matching.

        Tokens are lower-cased, split on characters other than letters, digits
        and '+', and dropped if they are stopwords or shorter than 3 characters.
        """
        tokens = self._LABEL_SPLIT_RE.split(label.lower())
        return [t for t in tokens if t and t not in self._LABEL_STOPWORDS and len(t) > 2]

    def classify_against_labels(self, text: str, labels: List[str]) -> Dict[str, float]:
        """Score OCR text against a provided list of labels using simple keyword overlap.

        Args:
            text: Extracted document text.
            labels: Candidate labels; labels with no usable keywords are skipped.

        Returns:
            Label -> score, normalized to sum to 1 when any keyword matched.
            If nothing matched, the scores are all 0.0. Returns {} for blank
            text or an empty label list.
        """
        if not text.strip() or not labels:
            return {}

        text_lower = text.lower()
        scores: Dict[str, float] = {}
        for label in labels:
            keywords = self._tokenize_label(label)
            if not keywords:
                continue
            # simple ratio over keywords
            scores[label] = self._keyword_hit_ratio(text_lower, keywords)

        # normalize
        total = sum(scores.values())
        if total > 0:
            scores = {k: v / total for k, v in scores.items()}
        return scores

    # ------------------------------------------------------------------ #
    # Model-based scoring
    # ------------------------------------------------------------------ #

    def classify_with_dit_model(self, image_path: str) -> Dict[str, float]:
        """
        Classify document using DIT model (Document Image Transformer).

        Args:
            image_path: Path to the document image

        Returns:
            Dictionary mapping each RVL-CDIP class to its softmax probability,
            or {"unknown": 1.0} if the model is unavailable or inference fails.
        """
        if self.dit_model is None or self.dit_processor is None:
            return {"unknown": 1.0}

        try:
            # Load and preprocess image; the context manager closes the file.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = self.dit_processor(images=image, return_tensors="pt").to(self.device)

            # Get predictions
            with torch.no_grad():
                outputs = self.dit_model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().tolist()

            # Map predictions to document types (assumes rvlcdip_classes is in
            # the model's label-index order).
            return {
                class_name: float(probabilities[i])
                for i, class_name in enumerate(self.rvlcdip_classes)
            }
        except Exception as e:
            logger.error(f"DIT model classification failed: {e}")
            return {"unknown": 1.0}

    def classify_with_layoutlmv3(self, text: str, image_path: Optional[str] = None) -> Dict[str, float]:
        """
        Classify document using LayoutLMv3 model.

        NOTE(review): the LayoutLMv3 model is only used as an availability
        gate; the scores come from keyword matching on the first
        _MAX_TEXT_CHARS characters, and *image_path* is ignored.

        Args:
            text: Text content of the document
            image_path: Optional path to document image (currently unused)

        Returns:
            Keyword scores normalized to sum to 1, or {"unknown": 1.0} when the
            model is not loaded, the text is blank, or no keyword matched.
        """
        if self.layoutlmv3_model is None or self.layoutlmv3_processor is None:
            return {"unknown": 1.0}

        try:
            if not text.strip():
                return {"unknown": 1.0}

            # Truncate text if too long
            text_lower = text[:self._MAX_TEXT_CHARS].lower()
            scores = self._content_keyword_scores(text_lower)

            # Normalize scores
            total_score = sum(scores.values())
            if total_score > 0:
                return {k: v / total_score for k, v in scores.items()}
            return {"unknown": 1.0}
        except Exception as e:
            logger.error(f"LayoutLMv3 classification failed: {e}")
            return {"unknown": 1.0}

    def _fallback_sentiment_scores(self, text: str) -> Dict[str, float]:
        """Map the fallback pipeline's prediction onto a few document types.

        NOTE(review): the fallback model is an SST-2 sentiment classifier, so
        this mapping is a heuristic rather than a document-type prediction.

        Returns:
            A (possibly empty) dict of document type -> unnormalized score.
        """
        scores: Dict[str, float] = {}
        if self.fallback_classifier is None or not text.strip():
            return scores
        try:
            hf_result = self.fallback_classifier(text[:self._MAX_TEXT_CHARS])
            if hf_result:
                # Map sentiment to document types
                sentiment = hf_result[0]['label'].lower()
                confidence = hf_result[0]['score']
                if 'positive' in sentiment:
                    scores['letter'] = confidence * 0.3
                    scores['email'] = confidence * 0.2
                elif 'negative' in sentiment:
                    scores['memo'] = confidence * 0.3
                    scores['form'] = confidence * 0.2
                else:
                    scores['report'] = confidence * 0.2
        except Exception as e:
            logger.warning(f"Fallback classifier failed: {e}")
        return scores

    def classify_by_content(self, text: str, image_path: Optional[str] = None) -> Dict[str, float]:
        """
        Classify document based on content analysis using Microsoft models.

        DiT (when an existing image path is given), the LayoutLMv3 keyword
        scorer and plain keyword scoring are combined with weights 0.5 / 0.3 /
        0.2. If the best combined score is below _FALLBACK_THRESHOLD, the
        fallback pipeline's scores are merged in.

        Args:
            text: Text content to analyze
            image_path: Optional path to document image

        Returns:
            Dictionary with document type probabilities summing to 1, or
            {"unknown": 1.0} when nothing produced a positive score.
        """
        if not text.strip() and not image_path:
            return {"unknown": 1.0}

        has_text = bool(text.strip())

        # Try DIT model first if we have an image
        dit_scores: Dict[str, float] = {}
        if image_path and os.path.exists(image_path):
            try:
                dit_scores = self.classify_with_dit_model(image_path)
                logger.info(f"DIT model classification: {dit_scores}")
            except Exception as e:
                logger.warning(f"DIT model failed: {e}")

        # Try LayoutLMv3 model
        layoutlmv3_scores: Dict[str, float] = {}
        if has_text:
            try:
                layoutlmv3_scores = self.classify_with_layoutlmv3(text, image_path)
                logger.info(f"LayoutLMv3 classification: {layoutlmv3_scores}")
            except Exception as e:
                logger.warning(f"LayoutLMv3 model failed: {e}")

        # "unknown" is the scorers' failure sentinel, not a real class; if it
        # were combined it would be weighted like a genuine prediction and
        # could win purely because a model failed to load.
        dit_scores.pop("unknown", None)
        layoutlmv3_scores.pop("unknown", None)

        # Keyword-based classification over the full text
        keyword_scores = self._content_keyword_scores(text.lower()) if has_text else {}

        # Combine scores from different methods; only positive scores count.
        combined_scores: Dict[str, float] = {}
        weighted_sources = (
            (dit_scores, self._DIT_WEIGHT),
            (layoutlmv3_scores, self._LAYOUTLMV3_WEIGHT),
            (keyword_scores, self._KEYWORD_WEIGHT),
        )
        for source_scores, weight in weighted_sources:
            for doc_type, value in source_scores.items():
                if value > 0:
                    combined_scores[doc_type] = combined_scores.get(doc_type, 0) + value * weight

        # Fallback to fallback classifier if no good scores
        if not combined_scores or max(combined_scores.values()) < self._FALLBACK_THRESHOLD:
            combined_scores.update(self._fallback_sentiment_scores(text))

        # Normalize scores
        total_score = sum(combined_scores.values())
        if total_score > 0:
            return {k: v / total_score for k, v in combined_scores.items()}
        return {"unknown": 1.0}

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #

    def classify_document(self, file_path: str, allowed_labels: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Classify a document and return comprehensive results.

        Args:
            file_path: Path to the document file
            allowed_labels: Optional custom labels. When given, the extracted
                text is scored against them first. The generic classifier is
                used only if that scoring yields an empty result.

        Returns:
            Dictionary containing classification results. On success it
            includes 'classification', 'confidence', 'all_scores' and
            'success': True. On failure it includes 'error' and 'success': False.
        """
        try:
            # Get file extension
            file_extension = self._get_extension(file_path)
            file_type = self.document_types.get(file_extension, 'Unknown')

            # Extract text content
            text_content = self.extract_text_from_file(file_path)

            # Only image files can be fed to the DiT image classifier.
            image_path = file_path if file_extension in self._IMAGE_EXTENSIONS else None

            # If a custom label list is provided, score against it using OCR text
            if allowed_labels:
                label_scores = self.classify_against_labels(text_content, allowed_labels)
                # Fallback to generic method if scores are empty
                content_classification = (
                    label_scores if label_scores
                    else self.classify_by_content(text_content, image_path)
                )
            else:
                # Generic method (legacy)
                content_classification = self.classify_by_content(text_content, image_path)

            # Get the most likely document type
            best_label, best_score = max(content_classification.items(), key=lambda item: item[1])

            return {
                'file_path': file_path,
                'file_name': os.path.basename(file_path),
                'file_type': file_type,
                'file_extension': file_extension,
                'content_length': len(text_content),
                'text_preview': text_content[:200] + "..." if len(text_content) > 200 else text_content,
                'classification': best_label,
                'confidence': best_score,
                'all_scores': content_classification,
                'success': True,
            }
        except Exception as e:
            logger.error(f"Error classifying document {file_path}: {e}")
            return {
                'file_path': file_path,
                'file_name': os.path.basename(file_path),
                'error': str(e),
                'success': False,
            }

    def classify_multiple_documents(
        self,
        file_paths: List[str],
        allowed_labels: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Classify multiple documents.

        Args:
            file_paths: List of file paths to classify
            allowed_labels: Optional custom labels forwarded to classify_document

        Returns:
            List of classification results, in the same order as file_paths
        """
        return [self.classify_document(path, allowed_labels) for path in file_paths]