---
language:
- en
license: mit
library_name: pytorch
tags:
- computer-vision
- image-classification
- document-classification
- rvl-cdip
- resnet50
datasets:
- chainyo/rvl-cdip
metrics:
- accuracy
- top-k-accuracy
pipeline_tag: image-classification
model-index:
- name: ResNet-50 Document Classifier
  results:
  - task:
      type: image-classification
    dataset:
      type: rvl_cdip
      name: RVL-CDIP (Test Split)
    metrics:
    - name: Accuracy
      type: accuracy
      value: 88.46%
    - name: Top-3 Accuracy
      type: top-k-accuracy
      value: 95.62%
---

# Model Card for ResNet-50 Document Classifier

This model is a **ResNet-50** convolutional neural network (CNN) fine-tuned to classify scanned document images into **16 categories** (e.g., emails, invoices, resumes, scientific reports). It achieves **88.46% overall accuracy** and **95.62% top-3 accuracy** on the RVL-CDIP test set, which makes it well suited to automated document triage and organization pipelines.

## Model Details

![Model Architecture](aechitecture.png)

### Model Description

This model uses the standard ResNet-50 architecture for image classification. Rather than reading the text as an OCR system would, it analyzes the visual layout, structure, and low-level texture of a whole document page to determine its category (e.g., recognizing the block layout of a resume versus the dense, two-column text of a scientific report).

It was trained with **transfer learning**: starting from weights pre-trained on ImageNet, the backbone was fine-tuned and the classification head was retrained for the 16 document classes.

- **Developed by:** Arpit ([@arpit-gour02](https://huggingface.co/arpit-gour02))
- **Model type:** Computer Vision (Image Classification / CNN)
- **Language(s) (NLP):** English (implicitly, via the text present in the RVL-CDIP dataset images)
- **License:** MIT

## Why ResNet50

| Model      | Approximate Parameters | Year Released | Layers |
|------------|------------------------|---------------|--------|
| VGG-16     | 138.4 Million          | 2014          | 16     |
| AlexNet    | 61.1 Million           | 2012          | 8      |
| ResNet-50  | 25.6 Million           | 2015          | 50     |

| Model      | FLOPs (Billions) | Efficiency                |
|------------|------------------|---------------------------|
| AlexNet    | 0.7              | Low cost, low accuracy    |
| ResNet-50  | 3.8              | High efficiency           |
| VGG-16     | 15.5             | Very inefficient          |

### Model Sources

- **Demo (Gradio App):** [https://huggingface.co/spaces/arpit-gour02/document-classification-demo](https://huggingface.co/spaces/arpit-gour02/document-classification-demo)
- **Repository:** [https://huggingface.co/spaces/arpit-gour02/document-classification-demo/tree/main](https://huggingface.co/spaces/arpit-gour02/document-classification-demo/tree/main)

## Uses

### Direct Use

This model is designed for **document triage and automation pipelines**. It works best as an initial sorting mechanism:

1. **Office Automation:** Automatically routing incoming scans to the correct department folder (e.g., sending "Invoices" to Accounting, "Resumes" to HR).
2. **Archive Digitization:** Rapidly tagging metadata for large legacy paper archives.
3. **Preprocessing Filter:** Acting as a cheap, fast gatekeeper before sending documents to expensive, specialized downstream systems (e.g., only sending confirmed "Forms" to a dedicated form-extraction model).

### Out-of-Scope Use

This model is **not** suitable for:

* **Text Extraction (OCR):** It classifies the *type* of document; it does not output the text written on it.
* **Handwriting Recognition:** Although it has a class for "Handwritten" documents, it only detects the *presence* of handwriting; it cannot read what is written.
* **Non-Document Images:** The model will perform poorly on natural images (photos of objects, people, landscapes).

## Bias, Risks, and Limitations

Users should be aware of the following technical limitations, based on evaluation analysis:

* **Resolution Sensitivity (The "Blur" Problem):** Model inputs are resized to `224x224`. At this low resolution, dense text pages look like blurry gray blocks. This causes significant confusion between classes defined by dense text, most notably between **Scientific Reports** and generic **File Folders**.
* **Visual Similarity:** The model sometimes struggles to differentiate between **Forms** and **Questionnaires**, as they share very similar visual structures (checkboxes, lines, header fields).
* **Dataset Bias:** The model was trained on the RVL-CDIP dataset, which consists primarily of older, grayscale, lower-quality scans. It may have lower accuracy on modern, born-digital, color PDF documents.

## How to Get Started with the Model

Use the code block below to load the model architecture, load your trained weights, preprocess an image, and run inference.

```python
import torch
from torchvision import models, transforms
from PIL import Image

# --- Setup ---
# 1. Define the 16 distinct classes
class_names = [
    'advertisement', 'budget', 'email', 'file folder', 'form', 'handwritten',
    'invoice', 'letter', 'memo', 'news article', 'presentation', 'questionnaire',
    'resume', 'scientific publication', 'scientific report', 'specification'
]

# 2. Define the preprocessing transformation (must match training!)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet normalization
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# --- Model Loading ---
# 3. Load the ResNet-50 architecture and replace the final layer
# weights=None replaces the deprecated pretrained=False (torchvision >= 0.13)
model = models.resnet50(weights=None)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, len(class_names))

# 4. Load your trained weights (ensure the path is correct)
# Note: map_location='cpu' ensures it loads even without a GPU
checkpoint = torch.load("resnet50_epoch_4.pth", map_location=torch.device('cpu'))

# Handle potential differences in how the state_dict was saved
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
model.load_state_dict(state_dict)
model.eval()  # Set to evaluation mode

# --- Inference ---
# 5. Load and preprocess an image
image_path = "path_to_your_test_document.jpg"
image = Image.open(image_path).convert('RGB')  # Ensure 3 channels
input_tensor = transform(image).unsqueeze(0)   # Add batch dimension (B, C, H, W)

# 6. Predict
with torch.no_grad():
    outputs = model(input_tensor)
    probabilities = torch.nn.functional.softmax(outputs, dim=1)
    top_prob, top_catid = torch.topk(probabilities, 1)

print(f"Prediction: {class_names[top_catid.item()]}")
print(f"Confidence: {top_prob.item()*100:.2f}%")
```

## Training Details

### Training Data

The model was trained on the **RVL-CDIP (Ryerson Vision Lab Complex Document Information Processing)** dataset.

* **Total Size:** 400,000 grayscale images.
* **Classes:** 16 perfectly balanced classes.
* **Split:** The standard split is 320k train, 40k validation, and 40k test. This model used 20k training images per class, with 2.5k images per class for validation and 2.5k images per class for testing.
* **Data Handling:** Original grayscale images were converted to 3-channel RGB to match the input expectations of the pretrained ResNet backbone.

### Training Procedure

#### Preprocessing

All images were resized to `224x224` pixels and normalized using standard ImageNet statistics (`mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]`).

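As a small worked example of the normalization above, the sketch below applies the quoted ImageNet statistics to a single grayscale pixel that has been copied into all three channels. It uses plain Python only; the helper name `normalize_gray_pixel` is hypothetical and exists purely for illustration.

```python
# ImageNet channel statistics quoted in the preprocessing description.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_gray_pixel(value_0_255: int) -> tuple:
    """Normalize one grayscale pixel after it is copied into R, G and B."""
    scaled = value_0_255 / 255.0  # same [0, 1] scaling that ToTensor applies
    return tuple((scaled - m) / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD))


for name, value in (("black", 0), ("white", 255)):
    r, g, b = normalize_gray_pixel(value)
    print(f"{name}: R={r:.3f} G={g:.3f} B={b:.3f}")
# black: R=-2.118 G=-2.036 B=-1.804
# white: R=2.249 G=2.429 B=2.640
```

Because each channel has its own mean and standard deviation, a pixel that is identical in R, G, and B before normalization ends up with three slightly different values afterwards. This is expected and matches what the pretrained backbone saw during ImageNet training.
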
#### Training Hyperparameters

Training used standard, stable hyperparameters for fine-tuning CNNs:

* **Optimizer:** SGD (Stochastic Gradient Descent)
* **Learning Rate:** 0.01
* **Momentum:** 0.9
* **Batch Size:** 64
* **Epochs:** 5
* **Training Regime:** Mixed precision (PyTorch Automatic Mixed Precision, used for speed and memory efficiency).

## Evaluation

### Testing Data, Factors & Metrics

The model was evaluated on the standard, unseen **RVL-CDIP Test Split** containing 40,000 images.

**Metrics used:**

* **Accuracy:** The percentage of predictions that exactly matched the ground truth.
* **Top-3 Accuracy:** The percentage of samples for which the correct label appeared among the model's three highest-probability predictions. This is often the most relevant metric for human-in-the-loop triage systems.
* **Precision/Recall/F1-Score:** Evaluated on a per-class basis to identify specific strengths and weaknesses in the model's performance.

### Results

| Metric | Result | Notes |
| :--- | :--- | :--- |
| **Overall Accuracy** | **88.46%** | Solid baseline performance. |
| **Top-3 Accuracy** | **95.62%** | Excellent reliability for triage tasks. |

![Loss and Accuracy Curves](results/loss_and_acc_curve.png)

#### Confusion Matrix

![Confusion Matrix](results/cm.png)

#### Detailed Classification Report

![Detailed Classification Report](results/detailed_classification_report.png)

#### Detailed Performance Analysis (The "Traffic Light" Report)

An analysis of per-class F1-scores reveals distinct tiers of performance:

* 🟢 **Excellent (>90% F1):** `Email`, `Resume`, `Memo`, `Handwritten`, `Specification`. The model is highly reliable for core administrative documents with distinct visual structures.
* 🟡 **Reliable (~85-89% F1):** `Invoice`, `Advertisement`, `News Article`, `Budget`.
* 🔴 **Challenging (<75% Precision):** `Scientific Report`, `Form`, `File Folder`.
  The major weakness is the misclassification of Scientific Reports as File Folders, because the low input resolution blurs dense text.

## Environmental Impact

Training was conducted locally on consumer-grade hardware, resulting in negligible environmental impact compared to large-scale language model training.

* **Hardware Type:** Apple M-Series Chip / single NVIDIA GPU
* **Hours used:** Approximately 5 hours (1 hour per epoch)
* **Carbon Emitted:** Negligible (local usage).

## Technical Specifications

### Model Architecture and Objective

The model consists of the **ResNet-50 backbone** (a 50-layer deep convolutional neural network using residual connections and bottleneck blocks) followed by a custom classification head.

* **Input Shape:** `(Batch_Size, 3, 224, 224)` (RGB images)
* **Backbone Output:** 2048 feature maps of size `7x7`.
* **Pooling:** Global average pooling reduces dimensions to `(Batch_Size, 2048)`.
* **Classification Head:** A single fully connected linear layer mapping 2048 features to 16 class logits.
* **Objective:** Minimize cross-entropy loss between predicted logits and ground-truth class labels.

## Citation

If you use this model or the RVL-CDIP dataset, please cite the original RVL-CDIP paper:

**BibTeX:**

```bibtex
@inproceedings{harley2015icdar,
  title = {Evaluation of Deep Convolutional Nets for Document Image Classification and Retrieval},
  author = {Adam W. Harley and Alex Ufkes and Konstantinos G. Derpanis},
  booktitle = {International Conference on Document Analysis and Recognition (ICDAR)},
  year = {2015}
}