Spaces:
Build error
Build error
| """ | |
| LayoutLMv3 Document Classifier Demo | |
| Classifies document pages from PDF files using a fine-tuned LayoutLMv3 model. | |
| """ | |
| import gradio as gr | |
| import torch | |
| import fitz # PyMuPDF | |
| from PIL import Image | |
| import numpy as np | |
| from pathlib import Path | |
| import tempfile | |
| import os | |
| from typing import List, Tuple, Dict, Any | |
| from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification | |
# Model configuration
# Hugging Face Hub repository id of the fine-tuned checkpoint; it is handed to
# DocumentClassifier, which forwards it to `from_pretrained`.
MODEL_NAME = "jmparejaz/layoutlmv3-base-expedientes"
# Default resolution (dots per inch) used when rendering PDF pages to images.
DEFAULT_DPI = 150
class DocumentClassifier:
    """Handles model loading and inference for document classification.

    The processor and model are loaded on first use rather than at
    construction time, so creating an instance is cheap.
    """

    def __init__(self, model_name: str):
        """Store the configuration; nothing is loaded here.

        Args:
            model_name: Identifier passed to ``from_pretrained`` when the
                model is eventually loaded.
        """
        self.model_name = model_name
        self._model = None
        self._processor = None
        self._label_names = None  # class names ordered by class id
        self._device = "cpu"  # Use CPU for HuggingFace Spaces free tier

    def load_model(self) -> None:
        """Load the model and processor lazily (no-op once loaded)."""
        if self._model is not None:
            return

        print(f"[Loading model] {self.model_name}")
        processor = LayoutLMv3Processor.from_pretrained(self.model_name)
        model = LayoutLMv3ForSequenceClassification.from_pretrained(self.model_name)
        model.to(self._device)
        model.eval()

        # Order the label names by class id so that position i is the name of
        # class i. The key order of `label2id` is only the dict's insertion
        # order (e.g. alphabetical if the config was serialised with sorted
        # keys) and is not guaranteed to follow the ids, so indexing its keys
        # with an argmax index can return the wrong label.
        id2label = model.config.id2label or {}
        label_names = [
            name
            for _, name in sorted(id2label.items(), key=lambda item: int(item[0]))
        ]

        # Publish the state only once everything loaded successfully, so a
        # failed load leaves the instance ready for a clean retry.
        self._processor = processor
        self._label_names = label_names
        self._model = model
        print(f"[Model loaded] Labels: {self._label_names}")

    def classify_image(self, image: Image.Image) -> Tuple[str, float]:
        """
        Classify a single image.

        Only the image is passed to the processor (no words or boxes).
        NOTE(review): text extraction is therefore presumably done by the
        processor itself - confirm against the checkpoint's processor config.

        Args:
            image: PIL Image; converted to RGB if it is in another mode.

        Returns:
            Tuple of (label, confidence). ``confidence`` is the softmax
            probability of the predicted class. If no label name is known
            for the predicted index, the index is returned as a string.
        """
        self.load_model()

        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        with torch.no_grad():
            encoding = self._processor(
                images=image,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            # Move every tensor of the encoding to the inference device.
            encoding = {k: v.to(self._device) for k, v in encoding.items()}

            logits = self._model(**encoding).logits
            probs = torch.softmax(logits, dim=-1)
            # A single image is processed, so the batch holds one row.
            pred_idx = torch.argmax(probs, dim=-1).item()
            confidence = probs[0, pred_idx].item()

        if self._label_names and pred_idx < len(self._label_names):
            label = self._label_names[pred_idx]
        else:
            label = str(pred_idx)
        return label, confidence

    def classify_batch(self, images: List[Image.Image]) -> List[Tuple[str, float]]:
        """Classify multiple images one by one, preserving input order."""
        return [self.classify_image(img) for img in images]
def pdf_to_images(pdf_path: str, dpi: int = DEFAULT_DPI) -> List[Image.Image]:
    """
    Convert PDF pages to images.

    Args:
        pdf_path: Path to PDF file
        dpi: Resolution for rendering

    Returns:
        List of PIL Images (one per page, in page order)
    """
    # PDF page units are points (1/72 inch), so dpi / 72 is the scale factor.
    zoom = dpi / 72
    matrix = fitz.Matrix(zoom, zoom)

    images = []
    # The context manager closes the document even if rendering a page
    # raises, which a trailing doc.close() would not.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            # alpha=False keeps the sample buffer at 3 bytes per pixel, which
            # is the layout Image.frombytes("RGB", ...) expects.
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            images.append(
                Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            )
    return images
def process_pdf(
    pdf_file: str,
    classifier: DocumentClassifier,
    progress: gr.Progress = None
) -> Tuple[List[Image.Image], List[str], List[float], str]:
    """
    Process a PDF file and classify each page.

    Failures are not raised: any exception during conversion or
    classification is reported through the summary text, with the three
    lists returned empty.

    Args:
        pdf_file: Path to uploaded PDF (None when nothing was uploaded)
        classifier: DocumentClassifier instance
        progress: Gradio progress tracker, or None to disable reporting

    Returns:
        Tuple of (images, labels, confidences, summary_text)
    """
    from collections import Counter

    if pdf_file is None:
        return [], [], [], "❌ Por favor sube un archivo PDF."

    def report(fraction: float, desc: str) -> None:
        # Compare against None instead of relying on truthiness: the tracker
        # is an object whose truth value may be driven by __len__, so
        # `if progress:` is not a reliable "was a tracker supplied" check.
        if progress is not None:
            progress(fraction, desc=desc)

    try:
        # Convert PDF to images
        report(0.1, "📄 Convirtiendo PDF a imágenes...")
        images = pdf_to_images(pdf_file)
        if not images:
            return [], [], [], "❌ No se pudieron extraer páginas del PDF."

        # Classify each page; classification spans 10%..90% of the bar.
        total = len(images)
        labels = []
        confidences = []
        for page_number, img in enumerate(images, start=1):
            report(
                0.1 + (0.8 * page_number / total),
                f"🔍 Clasificando página {page_number}/{total}...",
            )
            label, conf = classifier.classify_image(img)
            labels.append(label)
            confidences.append(conf)

        # Create summary (Markdown), with per-class counts sorted by label.
        summary_lines = [
            "## 📊 Resumen de Clasificación\n",
            f"**Total de páginas:** {total}\n",
            "### Distribución por clase:\n",
        ]
        for label, count in sorted(Counter(labels).items()):
            percentage = (count / len(labels)) * 100
            summary_lines.append(
                f"- **{label}**: {count} páginas ({percentage:.1f}%)"
            )
        summary_text = "\n".join(summary_lines)

        report(1.0, "✅ Procesamiento completado")
        return images, labels, confidences, summary_text
    except Exception as e:
        # UI boundary: surface the problem to the user instead of crashing.
        return [], [], [], f"❌ Error procesando el PDF: {str(e)}"
# Global classifier instance, used by the UI callback for every request.
# Creating it is cheap: the model is only loaded on the first
# classify_image() call, via DocumentClassifier.load_model().
classifier = DocumentClassifier(MODEL_NAME)
def create_interface():
    """Create the Gradio interface.

    Lays out the upload control, the summary panel, a gallery of classified
    pages and a details table, and wires the button to the classification
    callback.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(
        title="LayoutLMv3 Document Classifier",
        theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("""
# 📄 LayoutLMv3 Document Classifier
Clasificador de páginas de documentos legales colombianos usando un modelo
**LayoutLMv3** fine-tuned.
**Instrucciones:**
1. Sube un archivo PDF
2. El modelo procesará cada página
3. Verás la imagen, clase predicha y nivel de confianza
""")
        with gr.Row():
            with gr.Column(scale=1):
                pdf_input = gr.File(
                    label="📄 Sube tu PDF",
                    file_types=[".pdf"],
                    type="filepath"
                )
                process_btn = gr.Button(
                    "🔍 Clasificar Documento",
                    variant="primary",
                    size="lg"
                )
                summary_output = gr.Markdown(
                    label="📊 Resumen",
                    visible=True
                )
        with gr.Row():
            gallery_output = gr.Gallery(
                label="📑 Páginas Clasificadas",
                show_label=True,
                columns=3,
                rows=2,
                height="auto",
                object_fit="contain"
            )
        # Results table. The confidence column is declared as "str" because
        # the callback fills it with formatted percentages such as "97.53%",
        # not with raw numbers.
        results_table = gr.Dataframe(
            headers=["Página", "Clase", "Confianza"],
            datatype=["number", "str", "str"],
            label="📋 Detalles de Clasificación",
            interactive=False
        )
        gr.Markdown("""
---
**Modelo:** [jmparejaz/layoutlmv3-base-expedientes](https://huggingface.co/jmparejaz/layoutlmv3-base-expedientes)
**Notas:**
- Procesamiento optimizado para CPU
- Versión 0.1.0 (en desarrollo activo)
""")

        def on_process(pdf_file, progress=gr.Progress()):
            """Classify the uploaded PDF and shape the results for the UI.

            The ``gr.Progress()`` default is how the callback obtains a
            progress tracker; it is forwarded to ``process_pdf``.

            Returns:
                (gallery items, table rows, summary markdown). On failure
                the gallery is cleared, the table is emptied and the summary
                carries the error message.
            """
            images, labels, confidences, summary = process_pdf(
                pdf_file, classifier, progress
            )
            if not images:
                return None, [], summary

            # Prepare gallery items with captions and the matching table rows.
            gallery_items = []
            table_data = []
            for page_number, (img, label, conf) in enumerate(
                zip(images, labels, confidences), start=1
            ):
                caption = f"Pág. {page_number}: {label}\n({conf*100:.1f}%)"
                gallery_items.append((img, caption))
                table_data.append([page_number, label, f"{conf*100:.2f}%"])
            return gallery_items, table_data, summary

        process_btn.click(
            fn=on_process,
            inputs=[pdf_input],
            outputs=[gallery_output, results_table, summary_output]
        )
    return demo
# Script entry point: build the interface and launch it. The app object is
# bound to the module-level name `demo`.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()