# jmparejaz's picture
# Upload folder using huggingface_hub
# 5266382 verified
"""
LayoutLMv3 Document Classifier Demo
Classifies document pages from PDF files using a fine-tuned LayoutLMv3 model.
"""
import gradio as gr
import torch
import fitz # PyMuPDF
from PIL import Image
import numpy as np
from pathlib import Path
import tempfile
import os
from typing import List, Tuple, Dict, Any
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
# Model configuration
# Hugging Face Hub identifier of the fine-tuned checkpoint to load.
MODEL_NAME: str = "jmparejaz/layoutlmv3-base-expedientes"
# Default resolution (dots per inch) used when rendering PDF pages.
DEFAULT_DPI: int = 150
class DocumentClassifier:
    """Handles model loading and inference for document classification.

    The processor and model are loaded lazily on the first classification
    call, so constructing an instance only records the model name.
    """

    def __init__(self, model_name: str):
        """Store the model identifier without loading anything.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.model_name = model_name
        self._model = None
        self._processor = None
        self._label_names = None
        self._device = "cpu"  # Use CPU for HuggingFace Spaces free tier

    def load_model(self):
        """Load the model and processor lazily; later calls are no-ops."""
        if self._model is not None:
            return

        print(f"[Loading model] {self.model_name}")
        processor = LayoutLMv3Processor.from_pretrained(self.model_name)
        model = LayoutLMv3ForSequenceClassification.from_pretrained(self.model_name)
        model.to(self._device)
        model.eval()

        # Order the names by class id. The key order of ``label2id`` is not
        # guaranteed to follow the ids, so it must not be used positionally.
        id2label = getattr(model.config, "id2label", None) or {}
        label_names = [id2label[idx] for idx in sorted(id2label)]

        # Publish on self only once everything succeeded, so a failed load is
        # retried in full on the next call instead of leaving partial state.
        self._processor = processor
        self._label_names = label_names
        self._model = model
        print(f"[Model loaded] Labels: {self._label_names}")

    def _label_for(self, index: int) -> str:
        """Return the class name for ``index`` from the model's ``id2label``.

        Falls back to the stringified index when the config has no entry.
        """
        id2label = getattr(self._model.config, "id2label", None) or {}
        return id2label.get(index, str(index))

    def classify_image(self, image: "Image.Image") -> Tuple[str, float]:
        """
        Classify a single image.

        Args:
            image: PIL Image; converted to RGB when it is in another mode.

        Returns:
            Tuple of (label, confidence), where confidence is the softmax
            probability of the predicted class.
        """
        self.load_model()

        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        with torch.no_grad():
            encoding = self._processor(
                images=image,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            # Move every tensor to the inference device
            encoding = {k: v.to(self._device) for k, v in encoding.items()}
            logits = self._model(**encoding).logits

            probs = torch.softmax(logits, dim=-1)
            pred_idx = torch.argmax(probs, dim=-1).item()
            confidence = probs[0, pred_idx].item()

        return self._label_for(pred_idx), confidence

    def classify_batch(self, images: List["Image.Image"]) -> List[Tuple[str, float]]:
        """Classify multiple images, one at a time, preserving input order."""
        return [self.classify_image(img) for img in images]
def pdf_to_images(pdf_path: str, dpi: int = DEFAULT_DPI) -> List[Image.Image]:
    """
    Convert PDF pages to images.

    Args:
        pdf_path: Path to PDF file
        dpi: Resolution for rendering; the zoom factor applied to each page
            is ``dpi / 72``.

    Returns:
        List of PIL Images (one per page, in page order)

    Raises:
        ValueError: If ``dpi`` is not positive.
    """
    if dpi <= 0:
        raise ValueError(f"dpi must be positive, got {dpi}")

    zoom = dpi / 72
    matrix = fitz.Matrix(zoom, zoom)

    images = []
    # The context manager closes the document even when rendering a page
    # raises, which the previous explicit close() call did not guarantee.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            # alpha=False keeps the samples at three bytes per pixel, which
            # is what the "RGB" mode passed to frombytes expects.
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            images.append(
                Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            )
    return images
def _build_summary(labels: List[str]) -> str:
    """Render the Markdown summary: page total plus per-class distribution."""
    total = len(labels)

    # Count by label
    label_counts = {}
    for label in labels:
        label_counts[label] = label_counts.get(label, 0) + 1

    lines = [
        "## 📊 Resumen de Clasificación\n",
        f"**Total de páginas:** {total}\n",
        "### Distribución por clase:\n",
    ]
    for label, count in sorted(label_counts.items()):
        percentage = (count / total) * 100
        lines.append(f"- **{label}**: {count} páginas ({percentage:.1f}%)")
    return "\n".join(lines)


def process_pdf(
    pdf_file: str,
    classifier: DocumentClassifier,
    progress: gr.Progress = None
) -> Tuple[List[Image.Image], List[str], List[float], str]:
    """
    Process a PDF file and classify each page.

    Args:
        pdf_file: Path to uploaded PDF, or None when nothing was uploaded
        classifier: DocumentClassifier instance
        progress: Optional Gradio progress tracker; skipped when None

    Returns:
        Tuple of (images, labels, confidences, summary_text). On any failure
        the three lists are empty and summary_text carries the error message.
    """
    if pdf_file is None:
        return [], [], [], "❌ Por favor sube un archivo PDF."

    def report(fraction: float, desc: str) -> None:
        # Compare against None explicitly: truth-testing a gr.Progress goes
        # through its __len__, which does not answer "was a tracker supplied?"
        # and can raise on a tracker that is not wrapping an iterable.
        if progress is not None:
            progress(fraction, desc=desc)

    try:
        # Convert PDF to images
        report(0.1, "📄 Convirtiendo PDF a imágenes...")
        images = pdf_to_images(pdf_file)
        if not images:
            return [], [], [], "❌ No se pudieron extraer páginas del PDF."

        # Classify each page
        total = len(images)
        labels = []
        confidences = []
        for page_no, img in enumerate(images, start=1):
            report(
                0.1 + (0.8 * page_no / total),
                f"🔍 Clasificando página {page_no}/{total}...",
            )
            label, conf = classifier.classify_image(img)
            labels.append(label)
            confidences.append(conf)

        summary_text = _build_summary(labels)
        report(1.0, "✅ Procesamiento completado")
        return images, labels, confidences, summary_text
    except Exception as e:
        # UI boundary: surface any failure as a message instead of raising.
        return [], [], [], f"❌ Error procesando el PDF: {str(e)}"
# Module-level classifier shared by every request; the underlying model is
# only loaded the first time an image is classified.
classifier = DocumentClassifier(model_name=MODEL_NAME)
def create_interface():
    """Create the Gradio interface.

    Builds a Blocks app with a PDF upload, a trigger button, a Markdown
    summary, a gallery of classified pages and a per-page results table, and
    wires the button to the classification handler.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    # NOTE(review): component nesting below was reconstructed; confirm that
    # the summary and the table are meant to sit outside the rows.
    with gr.Blocks(
        title="LayoutLMv3 Document Classifier",
        theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("""
# 📄 LayoutLMv3 Document Classifier
Clasificador de páginas de documentos legales colombianos usando un modelo
**LayoutLMv3** fine-tuned.
**Instrucciones:**
1. Sube un archivo PDF
2. El modelo procesará cada página
3. Verás la imagen, clase predicha y nivel de confianza
""")

        with gr.Row():
            with gr.Column(scale=1):
                pdf_input = gr.File(
                    label="📄 Sube tu PDF",
                    file_types=[".pdf"],
                    type="filepath"
                )
                process_btn = gr.Button(
                    "🔍 Clasificar Documento",
                    variant="primary",
                    size="lg"
                )

        summary_output = gr.Markdown(
            label="📊 Resumen",
            visible=True
        )

        with gr.Row():
            gallery_output = gr.Gallery(
                label="📑 Páginas Clasificadas",
                show_label=True,
                columns=3,
                rows=2,
                height="auto",
                object_fit="contain"
            )

        # Results table. The confidence column holds pre-formatted strings
        # such as "97.12%", so it is declared as text rather than number.
        results_table = gr.Dataframe(
            headers=["Página", "Clase", "Confianza"],
            datatype=["number", "str", "str"],
            label="📋 Detalles de Clasificación",
            interactive=False
        )

        gr.Markdown("""
---
**Modelo:** [jmparejaz/layoutlmv3-base-expedientes](https://huggingface.co/jmparejaz/layoutlmv3-base-expedientes)
**Notas:**
- Procesamiento optimizado para CPU
- Versión 0.1.0 (en desarrollo activo)
""")

        def on_process(pdf_file, progress=gr.Progress()):
            """Classify the uploaded PDF and shape the results for the UI.

            Returns values for the gallery, the results table and the
            summary, in that order.
            """
            images, labels, confidences, summary = process_pdf(
                pdf_file, classifier, progress
            )
            if not images:
                # Nothing to show: clear gallery and table, keep the message.
                return None, [], summary

            # Prepare gallery items with captions and the matching table rows
            gallery_items = []
            table_data = []
            for i, (img, label, conf) in enumerate(zip(images, labels, confidences)):
                caption = f"Pág. {i+1}: {label}\n({conf*100:.1f}%)"
                gallery_items.append((img, caption))
                table_data.append([i + 1, label, f"{conf*100:.2f}%"])
            return gallery_items, table_data, summary

        process_btn.click(
            fn=on_process,
            inputs=[pdf_input],
            outputs=[gallery_output, results_table, summary_output]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then start the Gradio server.
    demo = create_interface()
    demo.launch()