---
language:
- en
license: mit
library_name: pytorch
tags:
- computer-vision
- image-classification
- document-classification
- rvl-cdip
- resnet50
datasets:
- chainyo/rvl-cdip
metrics:
- accuracy
- top-k-accuracy
pipeline_tag: image-classification
model-index:
- name: ResNet-50 Document Classifier
  results:
  - task:
      type: image-classification
    dataset:
      type: rvl_cdip
      name: RVL-CDIP (Test Split)
    metrics:
    - name: Accuracy
      type: accuracy
      value: 88.46
    - name: Top-3 Accuracy
      type: top-k-accuracy
      value: 95.62
---

# Model Card for ResNet-50 Document Classifier

This model is a **ResNet-50** Convolutional Neural Network (CNN) fine-tuned to classify scanned document images into **16 categories** (e.g., Emails, Invoices, Resumes, Scientific Reports). It achieves **88.46% overall accuracy** and a strong **95.62% Top-3 Accuracy** on the RVL-CDIP test set, making it well suited to automated document triage and organization pipelines.

## Model Details

![Model Architecture](aechitecture.png)

### Model Description

This model utilizes the standard ResNet-50 architecture designed for image classification. Instead of "reading" the text like an OCR system, it analyzes the visual layout, structure, and low-level texture features of a whole document page to determine its category (e.g., recognizing the block layout of a resume versus the dense, two-column text of a scientific report).

It was trained using **Transfer Learning**, starting from weights pre-trained on ImageNet, fine-tuning the backbone, and retraining the classification head for the 16 document classes.

- **Developed by:** Arpit ([@arpit-gour02](https://huggingface.co/arpit-gour02))
- **Model type:** Computer Vision (Image Classification / CNN)
- **Language(s) (NLP):** English (implicitly, via the text present in the RVL-CDIP document images)
- **License:** MIT

## Why ResNet-50

ResNet-50 offers a favorable balance of size and compute: it has the fewest parameters of the three architectures below and needs roughly a quarter of VGG-16's FLOPs per image.

| Model      | Year Released | Layers | Approximate Parameters | FLOPs (Billions) | Efficiency              |
|------------|---------------|--------|------------------------|------------------|-------------------------|
| AlexNet    | 2012          | 8      | 61.1 Million           | 0.7 GFLOPs       | Low cost / low accuracy |
| VGG-16     | 2014          | 16     | 138.4 Million          | 15.5 GFLOPs      | Very inefficient        |
| ResNet-50  | 2015          | 50     | 25.6 Million           | 3.8 GFLOPs       | High efficiency         |

### Model Sources

- **Demo (Gradio App):** [https://huggingface.co/spaces/arpit-gour02/document-classification-demo](https://huggingface.co/spaces/arpit-gour02/document-classification-demo)
- **Repository:** [https://huggingface.co/spaces/arpit-gour02/document-classification-demo/tree/main](https://huggingface.co/spaces/arpit-gour02/document-classification-demo/tree/main)

## Uses

### Direct Use

This model is specifically designed for **Document Triage and Automation Pipelines**. It is best used as an initial sorting mechanism:

1.  **Office Automation:** Automatically routing incoming scans to the correct department folder (e.g., sending "Invoices" to Accounting, "Resumes" to HR).
2.  **Archive Digitization:** Rapidly tagging metadata for large legacy paper archives.
3.  **Preprocessing Filter:** Acting as a cheap, fast gatekeeper before sending documents to expensive, specialized downstream systems (e.g., only sending confirmed "Forms" to a dedicated form-extraction model).
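A triage pipeline like the ones listed above usually needs a mapping from predicted class to destination, plus a fallback for low-confidence predictions. The sketch below is purely illustrative: the `ROUTES` table, the folder names, the `route_document` helper, and the 0.80 confidence threshold are all hypothetical choices, not part of this model.

```python
# Hypothetical routing table: predicted class label -> destination queue/folder.
ROUTES = {
    "invoice": "accounting",
    "budget": "accounting",
    "resume": "hr",
    "form": "form-extraction",  # e.g. a downstream form-extraction model
}
DEFAULT_ROUTE = "manual-review"


def route_document(label: str, confidence: float, min_confidence: float = 0.80) -> str:
    """Pick a destination for a document; send uncertain or unmapped ones to a human."""
    if confidence < min_confidence:
        return DEFAULT_ROUTE
    return ROUTES.get(label, DEFAULT_ROUTE)


print(route_document("invoice", 0.95))  # accounting
print(route_document("invoice", 0.40))  # manual-review (low confidence)
print(route_document("letter", 0.99))   # manual-review (no route defined)
```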

### Out-of-Scope Use

This model is **not** suitable for:

* **Text Extraction (OCR):** It classifies the *type* of document; it does not output the text written on it.
* **Handwriting Recognition:** While it has a class for "Handwritten" documents, it only detects the *presence* of handwriting; it cannot read what is written.
* **Non-Document Images:** The model will perform poorly on natural images (photos of objects, people, landscapes).

## Bias, Risks, and Limitations

Users should be aware of the following technical limitations based on evaluation analysis:

* **Resolution Sensitivity (The "Blur" Problem):** The model inputs are resized to `224x224`. At this low resolution, dense text pages look like blurry gray blocks. This causes significant confusion between classes defined by dense text, specifically distinguishing **Scientific Reports** from generic **File Folders**.
* **Visual Similarity:** The model sometimes struggles to differentiate between **Forms** and **Questionnaires**, as they share very similar visual structures (checkboxes, lines, header fields).
* **Dataset Bias:** The model was trained on the RVL-CDIP dataset, which consists primarily of older, grayscale, lower-quality scans. It may have lower accuracy on modern, born-digital, color PDF documents.
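The resolution problem can be quantified with simple arithmetic. RVL-CDIP images are stored with their largest dimension at most 1000 pixels, so resizing a full-height page to 224 pixels shrinks every vertical feature by a factor of about 4.5. The 12-pixel text-line height below is a hypothetical example value, and `downscaled_size` is an illustrative helper.

```python
def downscaled_size(feature_px: float, source_px: int, target_px: int = 224) -> float:
    """Size of a feature after resizing an image dimension from source_px to target_px."""
    return feature_px * target_px / source_px


# Assumption: a 1000-px-tall scan with text lines roughly 12 px tall (hypothetical value).
line_height_after_resize = downscaled_size(12, 1000)
print(f"{line_height_after_resize:.2f} px")  # 2.69 px
```

At under 3 pixels per text line, individual characters cannot be resolved, which is consistent with dense-text classes collapsing into similar gray textures.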

## How to Get Started with the Model

Use the code block below to load the model architecture, load your trained weights, preprocess an image, and run inference.

```python
import torch
from torchvision import models, transforms
from PIL import Image

# --- Setup ---
# 1. Define the 16 distinct classes
class_names = [
    'advertisement', 'budget', 'email', 'file folder', 'form', 'handwritten', 
    'invoice', 'letter', 'memo', 'news article', 'presentation', 'questionnaire', 
    'resume', 'scientific publication', 'scientific report', 'specification'
]

# 2. Define the preprocessing transformation (Must match training!)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet normalization
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# --- Model Loading ---
# 3. Load the ResNet-50 architecture and replace the final layer
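# weights=None builds the architecture without ImageNet weights (replaces the deprecated pretrained=False)
model = models.resnet50(weights=None)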
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, len(class_names))

# 4. Load your trained weights (ensure path is correct)
# Note: map_location='cpu' ensures it loads even without a GPU
checkpoint = torch.load("resnet50_epoch_4.pth", map_location=torch.device('cpu'))
# Handle potential differences in how state_dict was saved
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
model.load_state_dict(state_dict)

model.eval() # Set to evaluation mode

# --- Inference ---
# 5. Load and preprocess an image
image_path = "path_to_your_test_document.jpg" 
image = Image.open(image_path).convert('RGB') # Ensure 3 channels
input_tensor = transform(image).unsqueeze(0) # Add batch dimension (B, C, H, W)

# 6. Predict
with torch.no_grad():
    outputs = model(input_tensor)
    probabilities = torch.nn.functional.softmax(outputs, dim=1)
    top_prob, top_catid = torch.topk(probabilities, 1)
    
    print(f"Prediction: {class_names[top_catid.item()]}")
    print(f"Confidence: {top_prob.item()*100:.2f}%")
```
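Because the card highlights Top-3 Accuracy for triage, it can be useful to print the three most likely classes instead of only the top one. The helper below is a hypothetical addition; it expects logits of shape `(1, num_classes)`, such as `outputs` from the script above.

```python
import torch


def top_k_predictions(logits: torch.Tensor, class_names: list[str], k: int = 3) -> list[tuple[str, float]]:
    """Return the k most likely (label, probability) pairs for a single image's logits of shape (1, C)."""
    probs = torch.softmax(logits, dim=1)[0]
    top_probs, top_ids = torch.topk(probs, k)
    return [(class_names[i], p) for i, p in zip(top_ids.tolist(), top_probs.tolist())]


# Usage with the inference script above:
# for label, p in top_k_predictions(outputs, class_names):
#     print(f"{label}: {p * 100:.2f}%")
```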



## Training Details

### Training Data

The model was trained on the **RVL-CDIP (Ryerson Vision Lab Complex Document Information Processing)** dataset.

* **Total Size:** 400,000 grayscale images.
* **Classes:** 16 perfectly balanced classes.
* **Split:** The standard split is 320k train, 40k validation, and 40k test images. This model uses that split, which corresponds to 20,000 training, 2,500 validation, and 2,500 test images per class.
* **Data Handling:** The original grayscale images were converted to 3-channel RGB to match the input expected by the ImageNet-pretrained ResNet backbone.
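Converting a grayscale ("L" mode) image to RGB with Pillow simply copies the single luminance channel into all three channels, so no new information is introduced. The tiny synthetic image below stands in for an RVL-CDIP scan.

```python
from PIL import Image

# Synthetic 4x4 grayscale image standing in for an RVL-CDIP scan.
gray = Image.new("L", (4, 4), color=200)

# The same call used in the inference script: .convert("RGB")
rgb = gray.convert("RGB")

print(gray.mode, rgb.mode)      # L RGB
print(rgb.getpixel((0, 0)))     # (200, 200, 200): the gray value is replicated
```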

### Training Procedure

#### Preprocessing
All images were resized to `224x224` pixels and normalized using standard ImageNet statistics (`mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]`).
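`ToTensor` scales 8-bit pixel values to `[0, 1]`, and `Normalize` then computes `(x - mean) / std` per channel. The sketch below applies the same arithmetic by hand to a white and a black pixel, the two values that dominate scanned documents. `normalize_pixel` is an illustrative helper, not a torchvision function.

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means (R, G, B)
STD = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations


def normalize_pixel(rgb_0_255: tuple[int, int, int]) -> tuple[float, ...]:
    """Mimic ToTensor (divide by 255) followed by Normalize for one RGB pixel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb_0_255, MEAN, STD))


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([round(v, 3) for v in white])  # [2.249, 2.429, 2.64]
print([round(v, 3) for v in black])  # [-2.118, -2.036, -1.804]
```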

#### Training Hyperparameters
The following hyperparameters, typical for fine-tuning CNNs, were used:

* **Optimizer:** SGD (Stochastic Gradient Descent)
* **Learning Rate:** 0.01
* **Momentum:** 0.9
* **Batch Size:** 64
* **Epochs:** 5
* **Training Regime:** Mixed precision, using PyTorch Automatic Mixed Precision (AMP) for speed and memory efficiency.

## Evaluation

### Testing Data, Factors & Metrics

The model was evaluated on the standard, unseen **RVL-CDIP Test Split** containing 40,000 images.

**Metrics used:**
* **Accuracy:** The percentage of predictions that exactly matched the ground truth.
* **Top-3 Accuracy:** The percentage of times the correct label appeared in the model's top three highest-probability predictions. This is often the most relevant metric for human-in-the-loop triage systems.
* **Precision/Recall/F1-Score:** Evaluated on a per-class basis to identify specific strengths and weaknesses in the model's performance.
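Top-k accuracy counts a prediction as correct when the true label is anywhere among the k highest-scoring classes. A minimal PyTorch sketch of that definition follows; `top_k_accuracy` is an illustrative helper, not necessarily the function used to produce the reported numbers.

```python
import torch


def top_k_accuracy(logits: torch.Tensor, labels: torch.Tensor, k: int = 3) -> float:
    """Fraction of samples whose true label is among the k highest logits."""
    top_k = logits.topk(k, dim=1).indices            # (N, k) predicted class ids
    hits = (top_k == labels.unsqueeze(1)).any(dim=1)  # (N,) True if label is in the top k
    return hits.float().mean().item()


logits = torch.tensor([[0.1, 0.5, 0.4],
                       [0.9, 0.06, 0.04],
                       [0.2, 0.3, 0.5]])
labels = torch.tensor([2, 1, 0])
print(top_k_accuracy(logits, labels, k=1))  # 0.0
print(top_k_accuracy(logits, labels, k=3))  # 1.0
```

With 16 classes, `k=1` gives ordinary accuracy and `k=3` gives the Top-3 Accuracy reported below.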

### Results

| Metric | Result | Notes |
| :--- | :--- | :--- |
| **Overall Accuracy** | **88.46%** | Solid baseline performance. |
| **Top-3 Accuracy** | **95.62%** | Excellent reliability for triage tasks. |

![Loss and Accuracy Curves](results/loss_and_acc_curve.png)

#### Confusion Matrix
![Confusion Matrix](results/cm.png)

#### Detailed Classification Report
![Detailed Classification Report](results/detailed_classification_report.png)


#### Detailed Performance Analysis (The "Traffic Light" Report)

An analysis of per-class F1-scores reveals distinct tiers of performance:

* 🟢 **Excellent (>90% F1):** `Email`, `Resume`, `Memo`, `Handwritten`, `Specification`. The model is highly reliable for core administrative documents with distinct visual structures.
* 🟡 **Reliable (~85-89% F1):** `Invoice`, `Advertisement`, `News Article`, `Budget`.
* 🔴 **Challenging (<75% Precision):** `Scientific Report`, `Form`, `File Folder`. The major weakness is misclassifying Scientific Reports as File Folders due to resolution constraints blurring the dense text.
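Per-class precision, recall, and F1 can all be read off a confusion matrix like the one above (rows are true classes, columns are predictions). The NumPy sketch below shows the standard formulas on a toy 2-class matrix; `per_class_prf` is an illustrative helper and the matrix values are made up.

```python
import numpy as np


def _safe_div(num: np.ndarray, den: np.ndarray) -> np.ndarray:
    """Element-wise division that returns 0 where the denominator is 0."""
    return np.divide(num, den, out=np.zeros_like(num), where=den > 0)


def per_class_prf(cm) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Precision, recall, and F1 per class from a confusion matrix (rows = true, cols = predicted)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    precision = _safe_div(tp, cm.sum(axis=0))  # column sums: how often each class was predicted
    recall = _safe_div(tp, cm.sum(axis=1))     # row sums: how often each class truly occurs
    f1 = _safe_div(2 * precision * recall, precision + recall)
    return precision, recall, f1


# Toy 2-class example (made-up counts).
p, r, f1 = per_class_prf([[8, 2],
                          [1, 9]])
print(p.round(3), r.round(3), f1.round(3))  # [0.889 0.818] [0.8 0.9] [0.842 0.857]
```

A class can have high recall but low precision (or the reverse) when another class is frequently mistaken for it, which is why the tiers above are based on per-class scores rather than overall accuracy.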

## Environmental Impact

The training was conducted locally on consumer-grade hardware, resulting in negligible environmental impact compared to large-scale language model training.

* **Hardware Type:** Apple M-Series Chip / single NVIDIA GPU
* **Hours used:** Approximately 5 hours (1 hour per epoch)
* **Carbon Emitted:** Not measured; expected to be negligible for a roughly 5-hour run on a single local device.

## Technical Specifications

### Model Architecture and Objective

The model consists of the **ResNet-50 backbone** (a 50-layer deep Convolutional Neural Network using residual connections and bottleneck blocks) followed by a custom classification head.

* **Input Shape:** `(Batch_Size, 3, 224, 224)` (RGB Images)
* **Backbone Output:** 2048 feature maps of size `7x7`.
* **Pooling:** Global Average Pooling reduces dimensions to `(Batch_Size, 2048)`.
* **Classification Head:** A single fully connected linear layer mapping 2048 features to 16 class logits.
* **Objective:** Minimize Cross-Entropy Loss between predicted logits and ground truth class labels.

## Citation

If you use this model or the RVL-CDIP dataset, please cite the original paper:

**BibTeX:**

```bibtex
@inproceedings{harley2015icdar,
  title = {Evaluation of Deep Convolutional Nets for Document Image Classification and Retrieval},
  author = {Adam W. Harley and Alex Ufkes and Konstantinos G. Derpanis},
  booktitle = {International Conference on Document Analysis and Recognition (ICDAR)},
  year = {2015}
}
```