# Spaces: Paused — Hugging Face Space status header (export artifact, not code)
# Standard library
import io
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple

# Third-party
import magic
import numpy as np
import openpyxl
import pandas as pd
import PyPDF2
import pytesseract
import torch
from docx import Document
from PIL import Image
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3ForTokenClassification,
    AutoImageProcessor,
    AutoModelForImageClassification
)
| # Set up logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class DocumentClassifier:
    """
    Classifies documents of many file types using Hugging Face models,
    falling back to keyword heuristics when models are unavailable.
    """

    def __init__(self):
        """
        Load the Microsoft models (LayoutLMv3, DIT) plus a sentiment
        pipeline used as a last-resort fallback. Each load is best-effort:
        a failure logs a warning and leaves the corresponding attributes
        set to None so the classifier degrades gracefully.
        """
        # Run on GPU when torch can see one, else CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # LayoutLMv3 — document-understanding model (optional).
        try:
            logger.info("Loading LayoutLMv3 model...")
            self.layoutlmv3_processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
            self.layoutlmv3_model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
            self.layoutlmv3_model.to(self.device)
            logger.info("✅ LayoutLMv3 model loaded successfully")
        except Exception as exc:
            logger.warning(f"Failed to load LayoutLMv3 model: {exc}")
            self.layoutlmv3_processor = None
            self.layoutlmv3_model = None

        # DIT — Document Image Transformer fine-tuned on RVL-CDIP (optional).
        try:
            logger.info("Loading DIT model...")
            self.dit_processor = AutoImageProcessor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
            self.dit_model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
            self.dit_model.to(self.device)
            logger.info("✅ DIT model loaded successfully")
        except Exception as exc:
            logger.warning(f"Failed to load DIT model: {exc}")
            self.dit_processor = None
            self.dit_model = None

        # Text-classification pipeline used only when the models above
        # produce no usable scores.
        try:
            pipeline_device = 0 if self.device == "cuda" else -1
            self.fallback_classifier = pipeline(
                "text-classification",
                model="distilbert-base-uncased-finetuned-sst-2-english",
                device=pipeline_device,
            )
        except Exception as exc:
            logger.warning(f"Failed to load fallback classifier: {exc}")
            self.fallback_classifier = None

        # Human-readable name for each supported file extension.
        self.document_types = {
            'pdf': 'PDF Document',
            'docx': 'Word Document',
            'doc': 'Word Document',
            'txt': 'Text Document',
            'xlsx': 'Excel Spreadsheet',
            'xls': 'Excel Spreadsheet',
            'csv': 'CSV File',
            'jpg': 'Image',
            'jpeg': 'Image',
            'png': 'Image',
            'gif': 'Image',
            'bmp': 'Image',
            'tiff': 'Image',
            'ppt': 'PowerPoint Presentation',
            'pptx': 'PowerPoint Presentation',
        }

        # RVL-CDIP document classes (the label set the DIT model was
        # fine-tuned on), in index order.
        self.rvlcdip_classes = [
            'letter', 'form', 'email', 'handwritten', 'advertisement',
            'scientific report', 'scientific publication', 'specification',
            'file folder', 'news article', 'budget', 'invoice', 'presentation',
            'questionnaire', 'resume', 'memo',
        ]

        # Keyword cues per document type for content-based scoring.
        self.content_keywords = {
            'letter': ['dear', 'sincerely', 'regards', 'yours truly', 'to whom it may concern'],
            'form': ['form', 'application', 'registration', 'signature', 'date', 'name', 'address'],
            'email': ['subject:', 'from:', 'to:', 'cc:', 'bcc:', 'sent:', 'received:'],
            'handwritten': ['handwritten', 'hand written', 'manuscript', 'notes'],
            'advertisement': ['advertisement', 'ad', 'promotion', 'sale', 'offer', 'discount'],
            'scientific report': ['abstract', 'introduction', 'methodology', 'results', 'conclusion', 'references'],
            'scientific publication': ['journal', 'publication', 'peer reviewed', 'doi:', 'issn:', 'volume'],
            'specification': ['specification', 'requirements', 'technical', 'system', 'software', 'hardware'],
            'file folder': ['folder', 'directory', 'file', 'document'],
            'news article': ['news', 'article', 'breaking', 'reporter', 'journalist', 'headline'],
            'budget': ['budget', 'financial', 'revenue', 'expense', 'profit', 'loss', 'balance'],
            'invoice': ['invoice', 'bill', 'payment', 'amount due', 'total', 'subtotal', 'tax'],
            'presentation': ['presentation', 'slide', 'powerpoint', 'agenda', 'meeting'],
            'questionnaire': ['questionnaire', 'survey', 'question', 'answer', 'response'],
            'resume': ['resume', 'cv', 'curriculum vitae', 'experience', 'education', 'skills'],
            'memo': ['memo', 'memorandum', 'to:', 'from:', 'date:', 'subject:', 're:'],
        }
| def extract_text_from_file(self, file_path: str) -> str: | |
| """ | |
| Extract text content from various file types. | |
| Args: | |
| file_path: Path to the file | |
| Returns: | |
| Extracted text content | |
| """ | |
| try: | |
| # Get file type using python-magic | |
| file_type = magic.from_file(file_path, mime=True) | |
| file_extension = os.path.splitext(file_path)[1].lower().lstrip('.') | |
| text_content = "" | |
| if file_extension == 'pdf': | |
| text_content = self._extract_pdf_text(file_path) | |
| elif file_extension in ['docx', 'doc']: | |
| text_content = self._extract_word_text(file_path) | |
| elif file_extension in ['xlsx', 'xls']: | |
| text_content = self._extract_excel_text(file_path) | |
| elif file_extension == 'txt': | |
| text_content = self._extract_txt_text(file_path) | |
| elif file_extension in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff']: | |
| text_content = self._extract_image_text(file_path) | |
| else: | |
| # Try to read as text file | |
| try: | |
| with open(file_path, 'r', encoding='utf-8') as f: | |
| text_content = f.read() | |
| except: | |
| with open(file_path, 'r', encoding='latin-1') as f: | |
| text_content = f.read() | |
| return text_content | |
| except Exception as e: | |
| logger.error(f"Error extracting text from {file_path}: {e}") | |
| return "" | |
| def _extract_pdf_text(self, file_path: str) -> str: | |
| """Extract text from PDF files.""" | |
| try: | |
| with open(file_path, 'rb') as file: | |
| pdf_reader = PyPDF2.PdfReader(file) | |
| text = "" | |
| for page in pdf_reader.pages: | |
| text += page.extract_text() + "\n" | |
| return text | |
| except Exception as e: | |
| logger.error(f"Error extracting PDF text: {e}") | |
| return "" | |
| def _extract_word_text(self, file_path: str) -> str: | |
| """Extract text from Word documents.""" | |
| try: | |
| doc = Document(file_path) | |
| text = "" | |
| for paragraph in doc.paragraphs: | |
| text += paragraph.text + "\n" | |
| return text | |
| except Exception as e: | |
| logger.error(f"Error extracting Word text: {e}") | |
| return "" | |
| def _extract_excel_text(self, file_path: str) -> str: | |
| """Extract text from Excel files.""" | |
| try: | |
| workbook = openpyxl.load_workbook(file_path) | |
| text = "" | |
| for sheet_name in workbook.sheetnames: | |
| sheet = workbook[sheet_name] | |
| for row in sheet.iter_rows(values_only=True): | |
| text += " ".join([str(cell) for cell in row if cell is not None]) + "\n" | |
| return text | |
| except Exception as e: | |
| logger.error(f"Error extracting Excel text: {e}") | |
| return "" | |
| def _extract_txt_text(self, file_path: str) -> str: | |
| """Extract text from plain text files.""" | |
| try: | |
| with open(file_path, 'r', encoding='utf-8') as f: | |
| return f.read() | |
| except: | |
| try: | |
| with open(file_path, 'r', encoding='latin-1') as f: | |
| return f.read() | |
| except Exception as e: | |
| logger.error(f"Error extracting text file: {e}") | |
| return "" | |
| def _extract_image_text(self, file_path: str) -> str: | |
| """Extract text from images using OCR via pytesseract.""" | |
| try: | |
| image = Image.open(file_path).convert("RGB") | |
| text = pytesseract.image_to_string(image) | |
| return text or "" | |
| except Exception as e: | |
| logger.error(f"Error extracting image text (OCR): {e}") | |
| return "" | |
| def _tokenize_label(self, label: str) -> List[str]: | |
| """Tokenize a label into meaningful keywords for matching.""" | |
| stopwords = { | |
| 'the','a','an','and','or','of','with','for','to','by','on','in','this','that','valid','expired','less','than','one','two','years','year','more','not','certificate','document','card','form','report','record','statement','results','order','stamp','authority','authorization','affidavit','evaluation' | |
| } | |
| tokens = re.split(r"[^a-zA-Z0-9+]+", label.lower()) | |
| tokens = [t for t in tokens if t and t not in stopwords and len(t) > 2] | |
| return tokens | |
| def classify_against_labels(self, text: str, labels: List[str]) -> Dict[str, float]: | |
| """Score OCR text against a provided list of labels using simple keyword overlap.""" | |
| if not text.strip() or not labels: | |
| return {} | |
| text_lower = text.lower() | |
| scores: Dict[str, float] = {} | |
| for label in labels: | |
| keywords = self._tokenize_label(label) | |
| if not keywords: | |
| continue | |
| hits = 0 | |
| for kw in keywords: | |
| if kw in text_lower: | |
| hits += 1 | |
| # simple ratio over keywords | |
| scores[label] = hits / len(keywords) | |
| # normalize | |
| total = sum(scores.values()) | |
| if total > 0: | |
| scores = {k: v / total for k, v in scores.items()} | |
| return scores | |
| def classify_with_dit_model(self, image_path: str) -> Dict[str, float]: | |
| """ | |
| Classify document using DIT model (Document Image Transformer). | |
| Args: | |
| image_path: Path to the document image | |
| Returns: | |
| Dictionary with document type probabilities | |
| """ | |
| if not self.dit_model or not self.dit_processor: | |
| return {"unknown": 1.0} | |
| try: | |
| # Load and preprocess image | |
| image = Image.open(image_path).convert("RGB") | |
| inputs = self.dit_processor(images=image, return_tensors="pt").to(self.device) | |
| # Get predictions | |
| with torch.no_grad(): | |
| outputs = self.dit_model(**inputs) | |
| predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
| # Map predictions to document types | |
| scores = {} | |
| for i, class_name in enumerate(self.rvlcdip_classes): | |
| scores[class_name] = float(predictions[0][i]) | |
| return scores | |
| except Exception as e: | |
| logger.error(f"DIT model classification failed: {e}") | |
| return {"unknown": 1.0} | |
| def classify_with_layoutlmv3(self, text: str, image_path: str = None) -> Dict[str, float]: | |
| """ | |
| Classify document using LayoutLMv3 model. | |
| Args: | |
| text: Text content of the document | |
| image_path: Optional path to document image | |
| Returns: | |
| Dictionary with document type probabilities | |
| """ | |
| if not self.layoutlmv3_model or not self.layoutlmv3_processor: | |
| return {"unknown": 1.0} | |
| try: | |
| # For now, we'll use text-only classification | |
| # In a full implementation, you'd also process the image/layout | |
| if not text.strip(): | |
| return {"unknown": 1.0} | |
| # Truncate text if too long | |
| max_length = 512 | |
| if len(text) > max_length: | |
| text = text[:max_length] | |
| # Simple text-based classification using keyword matching | |
| # LayoutLMv3 is primarily for token classification, so we'll use it differently | |
| text_lower = text.lower() | |
| scores = {} | |
| for doc_type, keywords in self.content_keywords.items(): | |
| score = 0 | |
| for keyword in keywords: | |
| if keyword in text_lower: | |
| score += 1 | |
| scores[doc_type] = score / len(keywords) if keywords else 0 | |
| # Normalize scores | |
| total_score = sum(scores.values()) | |
| if total_score > 0: | |
| scores = {k: v/total_score for k, v in scores.items()} | |
| else: | |
| scores = {"unknown": 1.0} | |
| return scores | |
| except Exception as e: | |
| logger.error(f"LayoutLMv3 classification failed: {e}") | |
| return {"unknown": 1.0} | |
    def classify_by_content(self, text: str, image_path: str = None) -> Dict[str, float]:
        """
        Classify document based on content analysis using Microsoft models.

        Combines up to three scorers — DIT (image), LayoutLMv3 (text), and
        plain keyword matching — with fixed weights (0.5 / 0.3 / 0.2), then
        falls back to the sentiment pipeline when nothing scores above 0.1.

        Args:
            text: Text content to analyze
            image_path: Optional path to document image

        Returns:
            Dictionary with document type probabilities, normalized to sum
            to 1, or {"unknown": 1.0} when nothing could be scored.
        """
        if not text.strip() and not image_path:
            return {"unknown": 1.0}
        # Try DIT model first if we have an image
        dit_scores = {}
        if image_path and os.path.exists(image_path):
            try:
                dit_scores = self.classify_with_dit_model(image_path)
                logger.info(f"DIT model classification: {dit_scores}")
            except Exception as e:
                logger.warning(f"DIT model failed: {e}")
        # Try LayoutLMv3 model
        layoutlmv3_scores = {}
        if text.strip():
            try:
                layoutlmv3_scores = self.classify_with_layoutlmv3(text, image_path)
                logger.info(f"LayoutLMv3 classification: {layoutlmv3_scores}")
            except Exception as e:
                logger.warning(f"LayoutLMv3 model failed: {e}")
        # Fallback to keyword-based classification
        keyword_scores = {}
        if text.strip():
            text_lower = text.lower()
            for doc_type, keywords in self.content_keywords.items():
                score = 0
                for keyword in keywords:
                    if keyword in text_lower:
                        score += 1
                keyword_scores[doc_type] = score / len(keywords) if keywords else 0
        # Combine scores from different methods. A doc type contributes a
        # method's weighted score only where that method scored it > 0.
        combined_scores = {}
        all_doc_types = set(list(dit_scores.keys()) + list(layoutlmv3_scores.keys()) + list(keyword_scores.keys()))
        for doc_type in all_doc_types:
            score = 0
            count = 0
            if doc_type in dit_scores and dit_scores[doc_type] > 0:
                score += dit_scores[doc_type] * 0.5  # DIT gets higher weight
                count += 1
            if doc_type in layoutlmv3_scores and layoutlmv3_scores[doc_type] > 0:
                score += layoutlmv3_scores[doc_type] * 0.3
                count += 1
            if doc_type in keyword_scores and keyword_scores[doc_type] > 0:
                score += keyword_scores[doc_type] * 0.2
                count += 1
            if count > 0:
                combined_scores[doc_type] = score
        # Fallback to fallback classifier if no good scores.
        # NOTE(review): this maps sentiment labels onto document types as a
        # last-resort heuristic; 'report' is not in the RVL-CDIP or keyword
        # label sets — confirm downstream consumers accept it.
        if not combined_scores or max(combined_scores.values()) < 0.1:
            if self.fallback_classifier and text.strip():
                try:
                    max_length = 512
                    if len(text) > max_length:
                        text = text[:max_length]
                    hf_result = self.fallback_classifier(text)
                    if hf_result:
                        # Map sentiment to document types
                        sentiment = hf_result[0]['label'].lower()
                        confidence = hf_result[0]['score']
                        if 'positive' in sentiment:
                            combined_scores['letter'] = confidence * 0.3
                            combined_scores['email'] = confidence * 0.2
                        elif 'negative' in sentiment:
                            combined_scores['memo'] = confidence * 0.3
                            combined_scores['form'] = confidence * 0.2
                        else:
                            combined_scores['report'] = confidence * 0.2
                except Exception as e:
                    logger.warning(f"Fallback classifier failed: {e}")
        # Normalize scores so they form a distribution.
        total_score = sum(combined_scores.values())
        if total_score > 0:
            combined_scores = {k: v/total_score for k, v in combined_scores.items()}
        else:
            combined_scores = {"unknown": 1.0}
        return combined_scores
| def classify_document(self, file_path: str, allowed_labels: Optional[List[str]] = None) -> Dict[str, any]: | |
| """ | |
| Classify a document and return comprehensive results. | |
| Args: | |
| file_path: Path to the document file | |
| Returns: | |
| Dictionary containing classification results | |
| """ | |
| try: | |
| # Get file extension | |
| file_extension = os.path.splitext(file_path)[1].lower().lstrip('.') | |
| file_type = self.document_types.get(file_extension, 'Unknown') | |
| # Extract text content | |
| text_content = self.extract_text_from_file(file_path) | |
| # If a custom label list is provided, score against it using OCR text | |
| if allowed_labels: | |
| label_scores = self.classify_against_labels(text_content, allowed_labels) | |
| # Fallback to generic method if scores are empty | |
| content_classification = label_scores if label_scores else self.classify_by_content(text_content) | |
| else: | |
| # Generic method (legacy) | |
| content_classification = self.classify_by_content(text_content) | |
| # Get the most likely document type | |
| most_likely_type = max(content_classification.items(), key=lambda x: x[1]) | |
| result = { | |
| 'file_path': file_path, | |
| 'file_name': os.path.basename(file_path), | |
| 'file_type': file_type, | |
| 'file_extension': file_extension, | |
| 'content_length': len(text_content), | |
| 'text_preview': text_content[:200] + "..." if len(text_content) > 200 else text_content, | |
| 'classification': most_likely_type[0], | |
| 'confidence': most_likely_type[1], | |
| 'all_scores': content_classification, | |
| 'success': True | |
| } | |
| return result | |
| except Exception as e: | |
| logger.error(f"Error classifying document {file_path}: {e}") | |
| return { | |
| 'file_path': file_path, | |
| 'file_name': os.path.basename(file_path), | |
| 'error': str(e), | |
| 'success': False | |
| } | |
| def classify_multiple_documents(self, file_paths: List[str]) -> List[Dict[str, any]]: | |
| """ | |
| Classify multiple documents. | |
| Args: | |
| file_paths: List of file paths to classify | |
| Returns: | |
| List of classification results | |
| """ | |
| results = [] | |
| for file_path in file_paths: | |
| result = self.classify_document(file_path) | |
| results.append(result) | |
| return results |