"""
LayoutLMv3 Document Classifier Demo

Classifies document pages from PDF files using a fine-tuned LayoutLMv3 model.
"""

import os
import tempfile
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import fitz  # PyMuPDF
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3Processor

# Model configuration
MODEL_NAME = "jmparejaz/layoutlmv3-base-expedientes"
DEFAULT_DPI = 150


class DocumentClassifier:
    """Handles model loading and inference for document classification.

    The model and processor are loaded lazily on the first call to
    ``load_model`` (which ``classify_image`` triggers automatically), so
    constructing an instance is cheap.
    """

    def __init__(self, model_name: str):
        """Store the model identifier; nothing is loaded yet.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                processor and the classification model.
        """
        self.model_name = model_name
        self._model = None
        self._processor = None
        # Label names ordered by class index (filled in by load_model).
        self._label_names: Optional[List[str]] = None
        # Class index -> label name (filled in by load_model).
        self._id2label: Dict[int, str] = {}
        self._device = "cpu"  # Use CPU for HuggingFace Spaces free tier

    def load_model(self) -> None:
        """Load the model and processor on first use; later calls are no-ops."""
        if self._model is not None:
            return

        print(f"[Loading model] {self.model_name}")
        self._processor = LayoutLMv3Processor.from_pretrained(self.model_name)
        self._model = LayoutLMv3ForSequenceClassification.from_pretrained(
            self.model_name
        )
        self._model.to(self._device)
        self._model.eval()

        # Resolve labels through id2label rather than the key order of
        # label2id: the predicted index is a class id, and the insertion
        # order of label2id is not guaranteed to follow class ids.
        raw_id2label = self._model.config.id2label or {}
        self._id2label = {int(idx): name for idx, name in raw_id2label.items()}
        self._label_names = [name for _, name in sorted(self._id2label.items())]
        print(f"[Model loaded] Labels: {self._label_names}")

    def classify_image(self, image: Image.Image) -> Tuple[str, float]:
        """Classify a single image.

        Args:
            image: PIL Image; converted to RGB if it is in another mode.

        Returns:
            Tuple of (label, confidence). The label falls back to the
            stringified class index when the model config has no name for it.
        """
        self.load_model()

        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        encoding = self._processor(
            images=image,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        # Move every tensor to the inference device
        encoding = {key: value.to(self._device) for key, value in encoding.items()}

        # Inference only: no gradients needed
        with torch.no_grad():
            logits = self._model(**encoding).logits

        probs = torch.softmax(logits, dim=-1)
        pred_idx = torch.argmax(probs, dim=-1).item()
        confidence = probs[0, pred_idx].item()
        label = self._id2label.get(pred_idx, str(pred_idx))
        return label, confidence

    def classify_batch(self, images: List[Image.Image]) -> List[Tuple[str, float]]:
        """Classify multiple images, one at a time, preserving input order."""
        return [self.classify_image(img) for img in images]


def pdf_to_images(pdf_path: str, dpi: int = DEFAULT_DPI) -> List[Image.Image]:
    """Convert PDF pages to images.

    Args:
        pdf_path: Path to PDF file
        dpi: Resolution for rendering

    Returns:
        List of PIL Images (one per page)
    """
    # PDF user space is 72 units per inch, so this zoom yields `dpi`.
    zoom = dpi / 72
    matrix = fitz.Matrix(zoom, zoom)

    images = []
    # The context manager closes the document even if rendering a page fails.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            # alpha=False keeps the sample buffer at 3 bytes per pixel, which
            # is the layout Image.frombytes("RGB", ...) expects.
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            images.append(
                Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            )
    return images


def _build_summary(labels: List[str]) -> str:
    """Build the Markdown summary (page total and per-class distribution)."""
    total = len(labels)
    summary_lines = ["## 📊 Resumen de Clasificación\n"]
    summary_lines.append(f"**Total de páginas:** {total}\n")
    summary_lines.append("### Distribución por clase:\n")
    for label, count in sorted(Counter(labels).items()):
        percentage = (count / total) * 100
        summary_lines.append(f"- **{label}**: {count} páginas ({percentage:.1f}%)")
    return "\n".join(summary_lines)


def process_pdf(
    pdf_file: str,
    classifier: DocumentClassifier,
    progress: Optional[gr.Progress] = None,
) -> Tuple[List[Image.Image], List[str], List[float], str]:
    """Process a PDF file and classify each page.

    Args:
        pdf_file: Path to uploaded PDF
        classifier: DocumentClassifier instance
        progress: Gradio progress tracker, or None to skip progress reporting

    Returns:
        Tuple of (images, labels, confidences, summary_text). On failure the
        three lists are empty and summary_text carries the error message.
    """
    if not pdf_file:
        return [], [], [], "❌ Por favor sube un archivo PDF."

    # Compare against None explicitly instead of relying on truthiness.
    # NOTE(review): gr.Progress appears to define __len__ over its tracked
    # iterables, so `if progress:` is not a reliable "was a tracker supplied"
    # test — confirm against the installed Gradio version.
    report = progress if progress is not None else None

    try:
        # Convert PDF to images
        if report is not None:
            report(0.1, desc="📄 Convirtiendo PDF a imágenes...")
        images = pdf_to_images(pdf_file)
        if not images:
            return [], [], [], "❌ No se pudieron extraer páginas del PDF."

        # Classify each page
        total = len(images)
        labels = []
        confidences = []
        for page_no, img in enumerate(images, start=1):
            if report is not None:
                report(
                    0.1 + (0.8 * page_no / total),
                    desc=f"🔍 Clasificando página {page_no}/{total}...",
                )
            label, conf = classifier.classify_image(img)
            labels.append(label)
            confidences.append(conf)

        summary_text = _build_summary(labels)

        if report is not None:
            report(1.0, desc="✅ Procesamiento completado")
        return images, labels, confidences, summary_text
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing
        # the Gradio callback.
        return [], [], [], f"❌ Error procesando el PDF: {str(e)}"


# Global classifier instance
classifier = DocumentClassifier(MODEL_NAME)


def create_interface():
    """Create the Gradio interface.

    Returns:
        The ``gr.Blocks`` app with the upload control, the summary, the page
        gallery and the results table wired to the classification callback.
    """
    with gr.Blocks(
        title="LayoutLMv3 Document Classifier",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown("""
        # 📄 LayoutLMv3 Document Classifier

        Clasificador de páginas de documentos legales colombianos usando un modelo **LayoutLMv3** fine-tuned.

        **Instrucciones:**
        1. Sube un archivo PDF
        2. El modelo procesará cada página
        3. Verás la imagen, clase predicha y nivel de confianza
        """)

        with gr.Row():
            with gr.Column(scale=1):
                pdf_input = gr.File(
                    label="📄 Sube tu PDF",
                    file_types=[".pdf"],
                    type="filepath",
                )
                process_btn = gr.Button(
                    "🔍 Clasificar Documento",
                    variant="primary",
                    size="lg",
                )
                summary_output = gr.Markdown(
                    label="📊 Resumen",
                    visible=True,
                )

        with gr.Row():
            gallery_output = gr.Gallery(
                label="📑 Páginas Clasificadas",
                show_label=True,
                columns=3,
                rows=2,
                height="auto",
                object_fit="contain",
            )

        # Results table. "Confianza" is filled with formatted percentage
        # strings (e.g. "97.31%"), so its column type is "str".
        results_table = gr.Dataframe(
            headers=["Página", "Clase", "Confianza"],
            datatype=["number", "str", "str"],
            label="📋 Detalles de Clasificación",
            interactive=False,
        )

        gr.Markdown("""
        ---
        **Modelo:** [jmparejaz/layoutlmv3-base-expedientes](https://huggingface.co/jmparejaz/layoutlmv3-base-expedientes)

        **Notas:**
        - Procesamiento optimizado para CPU
        - Versión 0.1.0 (en desarrollo activo)
        """)

        def on_process(pdf_file, progress=gr.Progress()):
            """Run classification and shape the results for the UI outputs."""
            images, labels, confidences, summary = process_pdf(
                pdf_file, classifier, progress
            )
            if not images:
                # Nothing to show: clear gallery and table, keep the message.
                return None, [], summary

            # Prepare gallery items with captions and the matching table rows
            gallery_items = []
            table_data = []
            for page_no, (img, label, conf) in enumerate(
                zip(images, labels, confidences), start=1
            ):
                caption = f"Pág. {page_no}: {label}\n({conf*100:.1f}%)"
                gallery_items.append((img, caption))
                table_data.append([page_no, label, f"{conf*100:.2f}%"])

            return gallery_items, table_data, summary

        process_btn.click(
            fn=on_process,
            inputs=[pdf_input],
            outputs=[gallery_output, results_table, summary_output],
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()