Commit 1321f6f by arpit-gour02 · verified · 1 Parent(s): 950e52a

Create README.md

Files changed (1)
  1. README.md +228 -0
README.md ADDED
@@ -0,0 +1,228 @@
+ ---
+ language:
+ - en
+ license: mit
+ library_name: pytorch
+ tags:
+ - computer-vision
+ - image-classification
+ - document-classification
+ - rvl-cdip
+ - resnet50
+ datasets:
+ - rvl_cdip
+ metrics:
+ - accuracy
+ - top-k-accuracy
+ pipeline_tag: image-classification
+ model-index:
+ - name: ResNet-50 Document Classifier
+   results:
+   - task:
+       type: image-classification
+     dataset:
+       type: rvl_cdip
+       name: RVL-CDIP (Test Split)
+     metrics:
+     - name: Accuracy
+       type: accuracy
+       value: 0.8846
+     - name: Top-3 Accuracy
+       type: top-k-accuracy
+       value: 0.9562
+ ---
+
+ # Model Card for ResNet-50 Document Classifier
+
+ ## Quick Summary
+
+ This model is a **ResNet-50** convolutional neural network (CNN) fine-tuned to classify scanned document images into **16 categories** (e.g., emails, invoices, resumes, scientific reports). It achieves **88.46% overall accuracy** and **95.62% Top-3 accuracy** on the RVL-CDIP test set, which makes it well suited to automated document triage and organization pipelines.
+
+ ## Model Details
+
+ ### Model Description
+
+ This model uses the standard ResNet-50 image-classification architecture. Rather than reading the text as an OCR system would, it analyzes the visual layout, structure, and low-level texture of a whole document page to determine its category (e.g., recognizing the block layout of a resume versus the dense, two-column text of a scientific report).
+
+ It was trained with **transfer learning**: starting from weights pre-trained on ImageNet, the backbone was fine-tuned and the classification head was retrained for the 16 document classes.
+
+ - **Developed by:** Arpit ([@arpit-gour02](https://huggingface.co/arpit-gour02))
+ - **Model type:** Computer Vision (Image Classification / CNN)
+ - **Language(s) (NLP):** English (implicitly, via the text present in the RVL-CDIP dataset images)
+ - **License:** MIT
+ - **Finetuned from model:** ResNet-50 (ImageNet weights)
+
+ ### Model Sources
+
+ - **Demo (Gradio App):** [https://huggingface.co/spaces/arpit-gour02/document-classification-demo](https://huggingface.co/spaces/arpit-gour02/document-classification-demo)
+ - **Repository:** [https://huggingface.co/spaces/arpit-gour02/document-classification-demo/tree/main](https://huggingface.co/spaces/arpit-gour02/document-classification-demo/tree/main)
+
+ ## Uses
+
+ ### Direct Use
+
+ This model is specifically designed for **Document Triage and Automation Pipelines**. It is best used as an initial sorting mechanism:
+
+ 1. **Office Automation:** Automatically routing incoming scans to the correct department folder (e.g., sending "Invoices" to Accounting, "Resumes" to HR).
+ 2. **Archive Digitization:** Rapidly tagging metadata for large legacy paper archives.
+ 3. **Preprocessing Filter:** Acting as a cheap, fast gatekeeper before sending documents to expensive, specialized downstream systems (e.g., only sending confirmed "Forms" to a dedicated form-extraction model).
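As a rough illustration of the routing use case above, the following sketch maps a predicted class to a destination folder and falls back to manual review when the model is not confident. The `ROUTES` table, the `manual_review` fallback, and the 0.80 threshold are hypothetical and not part of this model; they would need tuning for a real pipeline.

```python
# Hypothetical routing table: class label -> destination folder.
# The folder names and the default threshold are illustrative assumptions.
ROUTES = {
    "invoice": "accounting",
    "budget": "accounting",
    "resume": "hr",
    "form": "form_extraction",
}
FALLBACK = "manual_review"


def route_document(label: str, confidence: float, threshold: float = 0.80) -> str:
    """Return a destination folder for a classified document.

    Low-confidence predictions and classes without a route go to manual review.
    """
    if confidence < threshold:
        return FALLBACK
    return ROUTES.get(label, FALLBACK)


if __name__ == "__main__":
    print(route_document("invoice", 0.97))
    print(route_document("resume", 0.55))
```

Sending low-confidence scans to a person keeps the classifier in the "cheap gatekeeper" role described above, rather than letting it make final decisions.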
+
+ ### Out-of-Scope Use
+
+ This model is **not** suitable for:
+
+ * **Text Extraction (OCR):** It classifies the *type* of document; it does not output the text written on it.
+ * **Handwriting Recognition:** Although it has a "Handwritten" class, it only detects the *presence* of handwriting; it cannot read what is written.
+ * **Non-Document Images:** The model will perform poorly on natural images (photos of objects, people, landscapes).
+
+ ## Bias, Risks, and Limitations
+
+ Users should be aware of the following technical limitations, identified during evaluation:
+
+ * **Resolution Sensitivity (The "Blur" Problem):** Model inputs are resized to `224x224`. At this low resolution, dense text pages look like blurry gray blocks. This causes significant confusion between classes defined by dense text, most notably between **Scientific Reports** and generic **File Folders**.
+ * **Visual Similarity:** The model sometimes struggles to differentiate between **Forms** and **Questionnaires**, as they share very similar visual structures (checkboxes, lines, header fields).
+ * **Dataset Bias:** The model was trained on the RVL-CDIP dataset, which consists primarily of older, grayscale, lower-quality scans. It may be less accurate on modern, born-digital, color PDF documents.
+
+ ## How to Get Started with the Model
+
+ Use the code block below to build the model architecture, load the trained weights, preprocess an image, and run inference.
+
+ ```python
+ import torch
+ from torchvision import models, transforms
+ from PIL import Image
+
+ # --- Setup ---
+ # 1. Define the 16 distinct classes
+ #    (the order must match the label indices used during training)
+ class_names = [
+     'advertisement', 'budget', 'email', 'file folder', 'form', 'handwritten',
+     'invoice', 'letter', 'memo', 'news article', 'presentation', 'questionnaire',
+     'resume', 'scientific publication', 'scientific report', 'specification'
+ ]
+
+ # 2. Define the preprocessing transformation (Must match training!)
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     # Standard ImageNet normalization
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ # --- Model Loading ---
+ # 3. Load the ResNet-50 architecture and replace the final layer
+ # weights=None builds the network without downloading ImageNet weights
+ # (torchvision >= 0.13; older releases used the deprecated pretrained=False)
+ model = models.resnet50(weights=None)
+ num_ftrs = model.fc.in_features
+ model.fc = torch.nn.Linear(num_ftrs, len(class_names))
+
+ # 4. Load your trained weights (ensure path is correct)
+ # Note: map_location='cpu' ensures it loads even without a GPU
+ checkpoint = torch.load("resnet50_epoch_4.pth", map_location=torch.device('cpu'))
+ # Handle potential differences in how state_dict was saved
+ state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
+ model.load_state_dict(state_dict)
+
+ model.eval()  # Set to evaluation mode
+
+ # --- Inference ---
+ # 5. Load and preprocess an image
+ image_path = "path_to_your_test_document.jpg"
+ image = Image.open(image_path).convert('RGB')  # Ensure 3 channels
+ input_tensor = transform(image).unsqueeze(0)  # Add batch dimension (B, C, H, W)
+
+ # 6. Predict
+ with torch.no_grad():
+     outputs = model(input_tensor)
+     probabilities = torch.nn.functional.softmax(outputs, dim=1)
+     top_prob, top_catid = torch.topk(probabilities, 1)
+
+ print(f"Prediction: {class_names[top_catid.item()]}")
+ print(f"Confidence: {top_prob.item()*100:.2f}%")
+ ```
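Because this card highlights Top-3 accuracy for triage, it may be useful to return the three most likely classes instead of only the best one. The helper below is a minimal sketch, not part of the released code: `top_k_predictions` is a hypothetical name, and the demo uses synthetic logits in place of `model(input_tensor)`.

```python
import torch


def top_k_predictions(logits: torch.Tensor, class_names: list, k: int = 3) -> list:
    """Return the k most likely (label, probability) pairs for one image.

    `logits` is the raw model output with shape (1, num_classes).
    """
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    top_probs, top_ids = torch.topk(probabilities, k, dim=1)
    return [
        (class_names[idx], prob)
        for idx, prob in zip(top_ids[0].tolist(), top_probs[0].tolist())
    ]


if __name__ == "__main__":
    # Synthetic logits and a shortened class list stand in for the real model output.
    demo_classes = ["email", "letter", "memo", "resume"]
    demo_logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
    for label, prob in top_k_predictions(demo_logits, demo_classes, k=3):
        print(f"{label}: {prob * 100:.2f}%")
```

With the real model, the call would be `top_k_predictions(outputs, class_names)` inside the `torch.no_grad()` block.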
+
+ ## Training Details
+
+ ### Training Data
+
+ The model was trained on the **RVL-CDIP (Ryerson Vision Lab Complex Document Information Processing)** dataset.
+
+ * **Total Size:** 400,000 grayscale images.
+ * **Classes:** 16 perfectly balanced classes (25,000 images per class).
+ * **Split:** The standard split is 320k Train, 40k Validation, 40k Test. This model was trained on the training split, which contains 20,000 images per class.
+ * **Data Handling:** Original grayscale images were converted to 3-channel RGB to match the input expectations of the pre-trained ResNet backbone.
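The training-time conversion code is not shown in this card. As a sketch, assuming the same PIL `convert('RGB')` call used in the inference example, a grayscale scan becomes a 3-channel image by copying its single channel three times. The `to_three_channels` helper name is hypothetical.

```python
from PIL import Image


def to_three_channels(image: Image.Image) -> Image.Image:
    """Convert a grayscale ("L" mode) page to RGB by replicating the channel."""
    return image.convert("RGB")


if __name__ == "__main__":
    # A small synthetic mid-gray page stands in for a real scan.
    gray_page = Image.new("L", (4, 4), color=128)
    rgb_page = to_three_channels(gray_page)
    print(rgb_page.mode, rgb_page.getpixel((0, 0)))
```

The converted image carries no extra information; it only matches the 3-channel input shape that the ImageNet-pretrained first convolution expects.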
+
+ ### Training Procedure
+
+ #### Preprocessing
+ All images were resized to `224x224` pixels and normalized using standard ImageNet statistics (`mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]`).
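To make the normalization step concrete, the sketch below applies the per-channel formula `(value - mean) / std` to a single gray value that has already been scaled to `[0, 1]`, as `ToTensor()` does. The `normalize_pixel` helper is hypothetical and written only for illustration.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(value: float) -> list:
    """Apply per-channel ImageNet normalization to one gray value in [0, 1]."""
    return [(value - m) / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)]


if __name__ == "__main__":
    print("white:", [round(v, 3) for v in normalize_pixel(1.0)])
    print("black:", [round(v, 3) for v in normalize_pixel(0.0)])
```

Because the statistics differ per channel, a gray pixel that is identical across R, G, and B ends up with three slightly different normalized values.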
+
+ #### Training Hyperparameters
+ Training used common hyperparameters for fine-tuning CNNs:
+
+ * **Optimizer:** SGD (Stochastic Gradient Descent)
+ * **Learning Rate:** 0.01
+ * **Momentum:** 0.9
+ * **Batch Size:** 64
+ * **Epochs:** 5
+ * **Training Regime:** Mixed precision (PyTorch Automatic Mixed Precision, used for speed and memory efficiency).
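The original training script is not included in this card. The following is a minimal sketch of how the hyperparameters above could fit together in PyTorch, under these assumptions: `train_one_epoch` is a hypothetical helper, mixed precision is enabled only on CUDA, and the demo uses a tiny stand-in model with random data rather than ResNet-50 and RVL-CDIP.

```python
import torch
from torch import nn


def train_one_epoch(model, loader, optimizer, device):
    """Run one epoch of training and return the mean batch loss.

    Mixed precision (autocast plus gradient scaling) is used only on CUDA.
    """
    use_amp = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    criterion = nn.CrossEntropyLoss()
    model.train()
    total_loss, num_batches = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        # Autocast is a no-op when enabled=False (the CPU path here).
        with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
            loss = criterion(model(images), labels)
        if scaler is not None:
            scaler.scale(loss).backward()  # scale to avoid fp16 gradient underflow
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Stand-in model and data; the real setup would use ResNet-50 and a DataLoader.
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16)).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    loader = [(torch.randn(64, 3, 8, 8), torch.randint(0, 16, (64,))) for _ in range(2)]
    for epoch in range(5):
        print(f"epoch {epoch + 1}: loss {train_one_epoch(model, loader, optimizer, device):.4f}")
```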
+
+ ## Evaluation
+
+ ### Testing Data, Factors & Metrics
+
+ The model was evaluated on the standard, unseen **RVL-CDIP Test Split** containing 40,000 images.
+
+ **Metrics used:**
+ * **Accuracy:** The percentage of predictions that exactly matched the ground truth.
+ * **Top-3 Accuracy:** The percentage of times the correct label appeared in the model's top three highest-probability predictions. This is often the most relevant metric for human-in-the-loop triage systems.
+ * **Precision/Recall/F1-Score:** Evaluated on a per-class basis to identify specific strengths and weaknesses in the model's performance.
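The metrics above can be sketched in plain Python as follows. This is an illustrative reimplementation on a toy example, not the evaluation script used for the reported numbers; the function names are hypothetical.

```python
def top_k_accuracy(scores, labels, k=3):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        hits += label in ranked[:k]
    return hits / len(labels)


def per_class_prf(predictions, labels, num_classes):
    """Return a (precision, recall, f1) tuple for each class index."""
    results = []
    for c in range(num_classes):
        tp = sum(p == c and y == c for p, y in zip(predictions, labels))
        fp = sum(p == c and y != c for p, y in zip(predictions, labels))
        fn = sum(p != c and y == c for p, y in zip(predictions, labels))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results.append((precision, recall, f1))
    return results


if __name__ == "__main__":
    # Toy scores for three samples over three classes.
    scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
    labels = [1, 1, 0]
    predictions = [max(range(len(row)), key=lambda i: row[i]) for row in scores]
    print("top-1:", top_k_accuracy(scores, labels, k=1))
    print("top-2:", top_k_accuracy(scores, labels, k=2))
    print("per-class:", per_class_prf(predictions, labels, num_classes=3))
```

Plain accuracy is the `k=1` case of `top_k_accuracy`.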
+
+ ### Results
+
+ | Metric | Result | Notes |
+ | :--- | :--- | :--- |
+ | **Overall Accuracy** | **88.46%** | Solid baseline performance. |
+ | **Top-3 Accuracy** | **95.62%** | Excellent reliability for triage tasks. |
+
+ #### Detailed Performance Analysis (The "Traffic Light" Report)
+
+ An analysis of per-class F1-scores and precision reveals distinct tiers of performance:
+
+ * 🟢 **Excellent (>90% F1):** `Email`, `Resume`, `Memo`, `Handwritten`, `Specification`. The model is highly reliable for core administrative documents with distinct visual structures.
+ * 🟡 **Reliable (~85-89% F1):** `Invoice`, `Advertisement`, `News Article`, `Budget`.
+ * 🔴 **Challenging (<75% Precision):** `Scientific Report`, `Form`, `File Folder`. The major weakness is misclassifying Scientific Reports as File Folders, because the low input resolution blurs dense text.
+
+ ## Environmental Impact
+
+ The training was conducted locally on consumer-grade hardware, resulting in negligible environmental impact compared to large-scale language model training.
+
+ * **Hardware Type:** Apple M-Series Chip / single NVIDIA GPU
+ * **Hours used:** Approximately 5 hours (1 hour per epoch)
+ * **Carbon Emitted:** Negligible local usage.
+
+ ## Technical Specifications
+
+ ### Model Architecture and Objective
+
+ The model consists of the **ResNet-50 backbone** (a 50-layer deep Convolutional Neural Network using residual connections and bottleneck blocks) followed by a custom classification head.
+
+ * **Input Shape:** `(Batch_Size, 3, 224, 224)` (RGB Images)
+ * **Backbone Output:** 2048 feature maps of size `7x7`.
+ * **Pooling:** Global Average Pooling reduces dimensions to `(Batch_Size, 2048)`.
+ * **Classification Head:** A single fully connected linear layer mapping 2048 features to 16 class logits.
+ * **Objective:** Minimize Cross-Entropy Loss between predicted logits and ground truth class labels.
+
+ ## Citation
+
+ If you use this model or the RVL-CDIP dataset, please cite the original RVL-CDIP paper:
+
+ **BibTeX:**
+
+ ```bibtex
+ @inproceedings{harley2015icdar,
+   title = {Evaluation of Deep Convolutional Nets for Document Image Classification and Retrieval},
+   author = {Adam W. Harley and Alex Ufkes and Konstantinos G. Derpanis},
+   booktitle = {International Conference on Document Analysis and Recognition (ICDAR)},
+   year = {2015}
+ }
+ ```