File size: 9,019 Bytes
5266382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
"""
LayoutLMv3 Document Classifier Demo
Classifies document pages from PDF files using a fine-tuned LayoutLMv3 model.
"""

import gradio as gr
import torch
import fitz  # PyMuPDF
from PIL import Image
import numpy as np
from pathlib import Path
import tempfile
import os
from typing import List, Tuple, Dict, Any
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification


# Model configuration
# Identifier handed to `from_pretrained` when the processor and model are loaded.
MODEL_NAME = "jmparejaz/layoutlmv3-base-expedientes"
# Default rendering resolution (dots per inch) used by `pdf_to_images`.
DEFAULT_DPI = 150


class DocumentClassifier:
    """Lazily load a LayoutLMv3 sequence classifier and classify page images.

    Nothing is downloaded or instantiated at construction time; the processor
    and model are created by :meth:`load_model` the first time an image is
    classified.
    """

    def __init__(self, model_name: str):
        """Remember the model identifier without loading anything.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.model_name = model_name
        self._model = None
        self._processor = None
        self._label_names = None
        self._device = "cpu"  # Use CPU for HuggingFace Spaces free tier

    @staticmethod
    def _ordered_label_names(config) -> List[str]:
        """Return the label names ordered by class index.

        ``id2label`` is preferred because it maps indices to names directly.
        The order of ``label2id`` keys is not guaranteed to follow the class
        indices (e.g. a serialized config may store the keys sorted by name),
        so it is only used as a fallback and is sorted by its id values.
        """
        id2label = getattr(config, "id2label", None) or {}
        if id2label:
            return [id2label[idx] for idx in sorted(id2label, key=int)]

        label2id = getattr(config, "label2id", None) or {}
        return [name for name, _ in sorted(label2id.items(), key=lambda item: item[1])]

    def _label_for(self, pred_idx: int) -> str:
        """Translate a predicted class index into its label name.

        Falls back to the id-ordered name list, and finally to the index
        rendered as a string, when the model config has no entry for the index.
        """
        id2label = getattr(self._model.config, "id2label", None) or {}
        if pred_idx in id2label:
            return id2label[pred_idx]

        names = self._label_names
        if names and 0 <= pred_idx < len(names):
            return names[pred_idx]
        return str(pred_idx)

    def load_model(self):
        """Load the processor and model on first call; later calls are no-ops."""
        if self._model is not None:
            return

        print(f"[Loading model] {self.model_name}")
        self._processor = LayoutLMv3Processor.from_pretrained(self.model_name)
        model = LayoutLMv3ForSequenceClassification.from_pretrained(self.model_name)
        model.to(self._device)
        model.eval()

        # Label names are indexed by class id, so they must be ordered by id
        # rather than by whatever order the config dictionary happens to have.
        self._label_names = self._ordered_label_names(model.config)

        # Publish the model last so a failure above leaves the instance in the
        # "not loaded" state and the next call retries.
        self._model = model
        print(f"[Model loaded] Labels: {self._label_names}")

    def classify_image(self, image: Image.Image) -> Tuple[str, float]:
        """
        Classify a single image.

        Args:
            image: PIL Image; converted to RGB when it is in another mode.

        Returns:
            Tuple of (label, confidence), where confidence is the softmax
            probability of the predicted class.
        """
        self.load_model()

        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Inference only: no gradients are needed.
        with torch.no_grad():
            encoding = self._processor(
                images=image,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )

            # Move every tensor of the encoding to the model's device
            encoding = {k: v.to(self._device) for k, v in encoding.items()}

            logits = self._model(**encoding).logits

            # A single image is passed, so row 0 holds its class probabilities
            probs = torch.softmax(logits, dim=-1)
            pred_idx = int(torch.argmax(probs, dim=-1).item())
            confidence = probs[0, pred_idx].item()

        return self._label_for(pred_idx), confidence

    def classify_batch(self, images: List[Image.Image]) -> List[Tuple[str, float]]:
        """Classify multiple images one by one, preserving input order."""
        return [self.classify_image(img) for img in images]


def pdf_to_images(pdf_path: str, dpi: int = DEFAULT_DPI) -> List[Image.Image]:
    """
    Convert PDF pages to images.

    Args:
        pdf_path: Path to PDF file
        dpi: Resolution for rendering

    Returns:
        List of PIL Images (one per page, in page order)

    Raises:
        Whatever PyMuPDF raises when the file cannot be opened or a page
        cannot be rendered; the document is closed before the error
        propagates.
    """
    # PDF coordinates are expressed at 72 units per inch, so dpi / 72 is the
    # scale factor that yields the requested resolution.
    zoom = dpi / 72
    matrix = fitz.Matrix(zoom, zoom)

    doc = fitz.open(pdf_path)
    try:
        images = []
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)
            # alpha=False keeps the samples at 3 bytes per pixel, matching the
            # "RGB" mode handed to Image.frombytes below.
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            images.append(
                Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            )
        return images
    finally:
        # Release the file handle even when rendering a page fails.
        doc.close()


def process_pdf(
    pdf_file: str,
    classifier: DocumentClassifier,
    progress: gr.Progress = None
) -> Tuple[List[Image.Image], List[str], List[float], str]:
    """
    Process a PDF file and classify each page.

    Args:
        pdf_file: Path to uploaded PDF, or None when nothing was uploaded
        classifier: DocumentClassifier instance
        progress: Optional Gradio progress tracker; skipped when None

    Returns:
        Tuple of (images, labels, confidences, summary_text). On any failure
        the three lists are empty and summary_text carries the error message,
        so this function does not raise to the UI layer.
    """
    if pdf_file is None:
        return [], [], [], "❌ Por favor sube un archivo PDF."

    def report(fraction: float, message: str) -> None:
        """Forward a progress update when a tracker was supplied."""
        # Test for None explicitly: the truthiness of a tracker object is not
        # a reliable "was it provided" check (container-like objects derive
        # it from their length).
        if progress is not None:
            progress(fraction, desc=message)

    try:
        # Convert PDF to images
        report(0.1, "📄 Convirtiendo PDF a imágenes...")
        images = pdf_to_images(pdf_file)

        if not images:
            return [], [], [], "❌ No se pudieron extraer páginas del PDF."

        # Classify each page; the loop spans the 0.1 -> 0.9 progress range
        total = len(images)
        labels = []
        confidences = []

        for page_no, img in enumerate(images, start=1):
            report(
                0.1 + (0.8 * page_no / total),
                f"🔍 Clasificando página {page_no}/{total}...",
            )

            label, conf = classifier.classify_image(img)
            labels.append(label)
            confidences.append(conf)

        # Count pages per predicted label
        label_counts = {}
        for label in labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        # Build the Markdown summary, with labels listed alphabetically
        summary_lines = [
            "## 📊 Resumen de Clasificación\n",
            f"**Total de páginas:** {total}\n",
            "### Distribución por clase:\n",
        ]
        for label, count in sorted(label_counts.items()):
            percentage = (count / len(labels)) * 100
            summary_lines.append(f"- **{label}**: {count} páginas ({percentage:.1f}%)")

        summary_text = "\n".join(summary_lines)

        report(1.0, "✅ Procesamiento completado")

        return images, labels, confidences, summary_text

    except Exception as e:
        # UI boundary: surface the failure as a message instead of a traceback.
        return [], [], [], f"❌ Error procesando el PDF: {str(e)}"


# Global classifier instance
# Shared by the UI callback. Constructing it only stores the model name; the
# processor and model are loaded by `load_model` on the first classification.
classifier = DocumentClassifier(MODEL_NAME)


def create_interface():
    """Build the Gradio Blocks UI and wire the classify button.

    Layout: an upload column (PDF input, button, Markdown summary), a gallery
    with one captioned image per page, and a read-only results table.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """

    with gr.Blocks(
        title="LayoutLMv3 Document Classifier",
        theme=gr.themes.Soft()
    ) as demo:

        gr.Markdown("""
        # 📄 LayoutLMv3 Document Classifier
        
        Clasificador de páginas de documentos legales colombianos usando un modelo 
        **LayoutLMv3** fine-tuned.
        
        **Instrucciones:**
        1. Sube un archivo PDF
        2. El modelo procesará cada página
        3. Verás la imagen, clase predicha y nivel de confianza
        """)

        with gr.Row():
            with gr.Column(scale=1):
                pdf_input = gr.File(
                    label="📄 Sube tu PDF",
                    file_types=[".pdf"],
                    type="filepath"
                )

                process_btn = gr.Button(
                    "🔍 Clasificar Documento",
                    variant="primary",
                    size="lg"
                )

                summary_output = gr.Markdown(
                    label="📊 Resumen",
                    visible=True
                )

        with gr.Row():
            gallery_output = gr.Gallery(
                label="📑 Páginas Clasificadas",
                show_label=True,
                columns=3,
                rows=2,
                height="auto",
                object_fit="contain"
            )

        # Results table. The confidence column is filled with preformatted
        # strings such as "95.12%", so it is declared as "str", not "number".
        results_table = gr.Dataframe(
            headers=["Página", "Clase", "Confianza"],
            datatype=["number", "str", "str"],
            label="📋 Detalles de Clasificación",
            interactive=False
        )

        gr.Markdown("""
        ---
        **Modelo:** [jmparejaz/layoutlmv3-base-expedientes](https://huggingface.co/jmparejaz/layoutlmv3-base-expedientes)
        
        **Notas:**
        - Procesamiento optimizado para CPU
        - Versión 0.1.0 (en desarrollo activo)
        """)

        def on_process(pdf_file, progress=gr.Progress()):
            """Classify the uploaded PDF and shape the results for the UI.

            Returns a (gallery items, table rows, summary) triple matching the
            ``outputs`` list of the click handler below.
            """
            images, labels, confidences, summary = process_pdf(
                pdf_file, classifier, progress
            )

            # Nothing to show: clear the gallery and table, keep the message
            if not images:
                return None, [], summary

            gallery_items = []
            table_data = []

            for page_no, (img, label, conf) in enumerate(
                zip(images, labels, confidences), start=1
            ):
                # Gallery entry: image plus a two-line caption
                caption = f"Pág. {page_no}: {label}\n({conf*100:.1f}%)"
                gallery_items.append((img, caption))

                # Table row: page number, label, formatted confidence
                table_data.append([page_no, label, f"{conf*100:.2f}%"])

            return gallery_items, table_data, summary

        process_btn.click(
            fn=on_process,
            inputs=[pdf_input],
            outputs=[gallery_output, results_table, summary_output]
        )

    return demo


if __name__ == "__main__":
    # Build the UI and start the Gradio server only when run as a script.
    demo = create_interface()
    demo.launch()