SpyC0der77 committed
Commit 5b9ac52 · verified · 1 Parent(s): 8c52332

Update README.md

Files changed (1): README.md (+284, -3)

---
license: mit
datasets:
- ArtifactClfDurham/OrientalMuseum-white
language:
- en
base_model:
- google/efficientnet-b0
tags:
- artifact
- museum
---

# Artifact Classification Model v2 - Best Model Usage Guide

This directory contains the improved v2 artifact classification model, which classifies museum artifacts by both object type and material.

## Model Overview

The v2 model is a multi-output neural network that predicts two attributes simultaneously:
- **Object Name**: The type/category of the artifact (e.g., "vase", "statue", "pottery")
- **Material**: The material composition (e.g., "ceramic", "bronze", "stone")

### Key Improvements Over v1
- **EfficientNet Backbone**: Uses EfficientNet-B0 instead of ResNet-50 for better feature extraction
- **Attention Mechanism**: Includes an attention layer to focus on relevant features
- **Advanced Training**: Incorporates CutMix augmentation, Focal Loss, and mixed precision training
- **Better Regularization**: Uses dropout and batch normalization for improved generalization

## Quick Start

### Prerequisites

Ensure you have the required dependencies installed (the version specifiers are quoted so the shell does not interpret `>` as a redirect):

```bash
pip install "torch>=2.0.0" "torchvision>=0.15.0" "datasets>=2.0.0" "pillow>=9.0.0" "timm>=1.0.22" "huggingface-hub>=0.15.0"
```

### Basic Inference

```python
import torch
from PIL import Image
from torchvision import transforms
import sys
import os

# Add the project root to the Python path
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

from main import load_model, run_inference

# Load the model
model_path = "model/v2/best_model.pth"
model, label_mappings = load_model(model_path)

# Prepare image
image_path = "path/to/your/artifact.jpg"
image = Image.open(image_path).convert('RGB')

# Preprocessing transform
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

pixel_values = transform(image).unsqueeze(0)  # Add batch dimension

# Run inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, pixel_values, device)

# Get predictions
object_pred_id = preds_obj[0].item()
material_pred_id = preds_mat[0].item()
object_conf = confs_obj[0].item()
material_conf = confs_mat[0].item()

# Convert IDs to labels
object_name = label_mappings['object_name'].get(object_pred_id, f"class_{object_pred_id}")
material_name = label_mappings['material'].get(material_pred_id, f"class_{material_pred_id}")

print(f"Predicted Object: {object_name} (confidence: {object_conf:.3f})")
print(f"Predicted Material: {material_name} (confidence: {material_conf:.3f})")
```

## Model Files

- **`best_model.pth`**: The best-performing checkpoint, including trained weights and label mappings
- **`model_improved.pth`**: Final model after complete training
- **`checkpoint_epoch_*.pth`**: Intermediate checkpoints saved during training
- **`train.py`**: Training script used to create this model

## Model Architecture

```
ImprovedMultiOutputModel(
    backbone: EfficientNet-B0 (pretrained)
    attention: Linear(1280 → 512 → 1280) with Sigmoid
    object_classifier: Linear(1280 → 1024 → 512 → num_object_classes)
    material_classifier: Linear(1280 → 1024 → 512 → num_material_classes)
)
```

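The attention block summarized above can be sketched as a standalone module. This is an illustrative reconstruction from the summary, not the project's actual code; the class name `AttentionGate` is ours:

```python
import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    """Illustrative 1280 -> 512 -> 1280 gate with Sigmoid, per the summary above."""
    def __init__(self, dim: int = 1280, hidden: int = 512):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, dim),
            nn.Sigmoid(),
        )

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # Re-weight the pooled backbone features element-wise; shape is preserved.
        return feats * self.gate(feats)
```

Applied to the 1280-dimensional pooled EfficientNet-B0 features, the gate keeps the feature shape while down-weighting less relevant channels before the two classifier heads.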
### Input Requirements
- **Image Size**: 224×224 pixels (automatically resized and cropped)
- **Format**: RGB images
- **Normalization**: ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

### Output Format
The model returns a dictionary with:
- `'object_name'`: Logits for object classification
- `'material'`: Logits for material classification

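To turn that logits dictionary into predictions, apply a softmax per head and take the arg-max. A minimal sketch; the helper name `decode_outputs` is illustrative and mirrors what `run_inference` presumably does internally:

```python
import torch
import torch.nn.functional as F

def decode_outputs(outputs: dict) -> dict:
    """Map each head's logits to (predicted_id, confidence) via softmax."""
    results = {}
    for head, logits in outputs.items():
        probs = F.softmax(logits, dim=1)      # logits -> probabilities per class
        conf, pred = probs.max(dim=1)         # highest probability and its index
        results[head] = (pred, conf)
    return results
```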
## Evaluation

### Using the Main Evaluation Script

To evaluate the model on the Oriental Museum dataset:

```bash
# Evaluate on the validation set
python main.py --model_file model/v2/best_model.pth --output eval_results_v2.json

# Evaluate with a custom batch size
python main.py --model_file model/v2/best_model.pth --batch_size 16 --output eval_results_v2.json
```

### Evaluation Metrics

The evaluation script reports:
- **Object Classification Accuracy**: Accuracy of object-name prediction
- **Material Classification Accuracy**: Accuracy of material prediction
- **Overall Accuracy**: Fraction of samples for which both predictions are correct
- **Confidence Analysis**: Average confidence for correct vs. incorrect predictions
- **Per-sample Predictions**: Detailed results for each test sample

### Expected Performance

Based on validation during training:
- Object Classification: ~85-90% accuracy
- Material Classification: ~80-85% accuracy
- Overall Accuracy: ~75-80%

*Note: Actual performance may vary depending on the evaluation dataset and preprocessing.*

## Training Details

The model was trained with the following configuration:

- **Dataset**: ArtifactClfDurham/OrientalMuseum-white
- **Training Split**: 85% of data
- **Validation Split**: 15% of data
- **Batch Size**: 32
- **Epochs**: 20
- **Optimizer**: AdamW with differential learning rates
  - Backbone: 2e-4 (0.1× base LR)
  - Heads: 2e-3 (base LR)
- **Augmentation**: Advanced (CutMix, rotation, color jitter, Gaussian blur)
- **Loss Function**: Cross-Entropy (or Focal Loss, if enabled)
- **Scheduler**: Cosine annealing with warmup

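The differential learning rates above can be expressed as AdamW parameter groups. A sketch assuming backbone parameters are named with a `backbone` prefix (as in the architecture summary); `build_optimizer` is not part of the project's API:

```python
import torch
import torch.nn as nn

def build_optimizer(model: nn.Module, base_lr: float = 2e-3) -> torch.optim.AdamW:
    """One param group per learning rate: backbone at 0.1x base, heads at base."""
    backbone = [p for n, p in model.named_parameters() if n.startswith("backbone")]
    heads = [p for n, p in model.named_parameters() if not n.startswith("backbone")]
    return torch.optim.AdamW([
        {"params": backbone, "lr": base_lr * 0.1},  # 2e-4 for the pretrained backbone
        {"params": heads, "lr": base_lr},           # 2e-3 for the freshly initialized heads
    ])
```

A smaller backbone LR is a common fine-tuning choice: pretrained features are perturbed gently while the new heads learn at full speed.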
### Advanced Training Features

- **CutMix Augmentation**: Randomly mixes image patches between samples
- **Focal Loss**: Addresses class imbalance (optional)
- **Mixed Precision**: Automatic mixed precision training for speed
- **Gradient Scaling**: Prevents gradient underflow
- **Early Stopping**: Saves the best model based on validation accuracy

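CutMix pastes a random rectangle from a shuffled copy of the batch and mixes the labels in proportion to the surviving area. A self-contained sketch of the idea, not the project's training code:

```python
import torch

def cutmix(images: torch.Tensor, labels: torch.Tensor, alpha: float = 1.0):
    """Mix a random patch between samples; return both label sets and the mix weight."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(images.size(0))          # pairing: sample i gets a patch from perm[i]
    _, _, h, w = images.shape
    cut_h, cut_w = int(h * (1 - lam) ** 0.5), int(w * (1 - lam) ** 0.5)
    cy, cx = torch.randint(h, (1,)).item(), torch.randint(w, (1,)).item()
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Recompute lambda from the actual (clipped) patch area.
    lam = 1.0 - ((y2 - y1) * (x2 - x1)) / (h * w)
    return mixed, labels, labels[perm], lam
```

Each head's loss is then mixed as `lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)`.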
## Usage Examples

### Batch Inference

```python
import torch
from PIL import Image
from torchvision import transforms
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from main import load_model, run_inference

# Load model
model, label_mappings = load_model("model/v2/best_model.pth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load multiple images
image_paths = ["artifact1.jpg", "artifact2.jpg", "artifact3.jpg"]
images = []

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

for path in image_paths:
    img = Image.open(path).convert('RGB')
    images.append(transform(img))

# Batch tensor
batch = torch.stack(images)

# Run inference
preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, batch, device)

# Process results
for i, (obj_pred, obj_conf, mat_pred, mat_conf) in enumerate(zip(preds_obj, confs_obj, preds_mat, confs_mat)):
    obj_name = label_mappings['object_name'].get(obj_pred.item(), f"class_{obj_pred.item()}")
    mat_name = label_mappings['material'].get(mat_pred.item(), f"class_{mat_pred.item()}")

    print(f"Image {i+1}:")
    print(f"  Object: {obj_name} ({obj_conf.item():.3f})")
    print(f"  Material: {mat_name} ({mat_conf.item():.3f})")
```

### Custom Dataset Evaluation

```python
from datasets import load_dataset
from main import load_model

# Load your custom dataset
dataset = load_dataset("your-dataset", split="test")

# Load model
model, label_mappings = load_model("model/v2/best_model.pth")

# Run evaluation (adapt the evaluation logic in main.py as needed)
# ... evaluation code ...
```

## Troubleshooting

### Common Issues

1. **CUDA Out of Memory**
   - Reduce the batch size: `--batch_size 8`
   - Use the CPU: set the device to "cpu"

2. **Import Errors**
   - Ensure all dependencies are installed
   - Check that the Python path includes the project root

3. **Model Loading Errors**
   - Verify that the model file path is correct
   - Ensure PyTorch version compatibility

4. **Low Confidence Scores**
   - The model may not have been trained on similar artifacts
   - Check that the image preprocessing matches the training setup

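For model-loading errors, passing `map_location="cpu"` to `torch.load` lets a GPU-trained checkpoint open on a CPU-only machine. A runnable sketch using a stand-in checkpoint (the keys shown are illustrative, not necessarily the project's schema):

```python
import os
import tempfile

import torch

# Create a stand-in checkpoint so the example is runnable end-to-end.
path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
torch.save({"model_state_dict": {"w": torch.zeros(2)}, "label_mappings": {}}, path)

# map_location="cpu" remaps CUDA tensors so loading works without a GPU.
checkpoint = torch.load(path, map_location="cpu")
```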
### Performance Tips

- Use a GPU for faster inference
- Process images in batches for efficiency
- Use `best_model.pth` for production
- Consider model quantization for deployment

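On the quantization tip: PyTorch's dynamic quantization converts `nn.Linear` layers to int8 for CPU inference without retraining. A sketch on a stand-in classifier head with the dimensions from the architecture section, not the actual model:

```python
import torch
import torch.nn as nn

# Stand-in head with the dimensions described in the architecture section.
head = nn.Sequential(nn.Linear(1280, 512), nn.ReLU(), nn.Linear(512, 10))

# Quantize only the Linear layers to int8; activations stay in float.
quantized = torch.quantization.quantize_dynamic(head, {nn.Linear}, dtype=torch.qint8)
```

Dynamic quantization typically shrinks the linear layers to roughly a quarter of their size with little accuracy loss, but accuracy should still be validated on held-out data.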
## Model Limitations

- Trained specifically on Oriental Museum artifacts
- May not generalize well to artifacts from other cultures or regions
- Performance depends on image quality and lighting
- The multi-output design may trade off object accuracy against material accuracy

## Contributing

To improve the model:
1. Rerun the training script with different hyperparameters
2. Experiment with different backbones
3. Add more advanced augmentations
4. Fine-tune on additional datasets

## License

This model is part of the artifact identification project. See the main project license for usage terms.