---
license: mit
datasets:
- ArtifactClfDurham/OrientalMuseum-white
language:
- en
base_model:
- google/efficientnet-b0
tags:
- artifact
- museum
---

# Artifact Classification Model v2 - Best Model Usage Guide

This directory contains the improved v2 artifact classification model, which labels museum artifacts by both object type and material with strong accuracy on both tasks (see Expected Performance below).

## Model Overview

The v2 model is a multi-output neural network that predicts two attributes simultaneously:
- **Object Name**: The type/category of the artifact (e.g., "vase", "statue", "pottery")
- **Material**: The material composition (e.g., "ceramic", "bronze", "stone")

### Key Improvements Over v1
- **EfficientNet Backbone**: Uses EfficientNet-B0 instead of ResNet-50 for better feature extraction
- **Attention Mechanism**: Includes an attention layer that reweights features toward the most relevant ones
- **Advanced Training**: Incorporates CutMix augmentation, Focal Loss, and mixed-precision training
- **Better Regularization**: Uses dropout and batch normalization for improved generalization

## Quick Start

### Prerequisites

Install the required dependencies (quote the version specifiers so the shell does not treat `>=` as a redirect):

```bash
pip install "torch>=2.0.0" "torchvision>=0.15.0" "datasets>=2.0.0" "pillow>=9.0.0" "timm>=1.0.22" "huggingface-hub>=0.15.0"
```

### Basic Inference

```python
import torch
from PIL import Image
from torchvision import transforms
import sys
import os

# Add the project root to Python path
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

from main import load_model, run_inference

# Load the model
model_path = "model/v2/best_model.pth"
model, label_mappings = load_model(model_path)

# Prepare image
image_path = "path/to/your/artifact.jpg"
image = Image.open(image_path).convert('RGB')

# Preprocessing transform (must match the training setup)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

pixel_values = transform(image).unsqueeze(0)  # Add batch dimension

# Run inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, pixel_values, device)

# Get predictions
object_pred_id = preds_obj[0].item()
material_pred_id = preds_mat[0].item()
object_conf = confs_obj[0].item()
material_conf = confs_mat[0].item()

# Convert IDs to labels
object_name = label_mappings['object_name'].get(object_pred_id, f"class_{object_pred_id}")
material_name = label_mappings['material'].get(material_pred_id, f"class_{material_pred_id}")

print(f"Predicted Object: {object_name} (confidence: {object_conf:.3f})")
print(f"Predicted Material: {material_name} (confidence: {material_conf:.3f})")
```

## Model Files

- **`best_model.pth`**: The best-performing checkpoint, with trained weights and label mappings
- **`model_improved.pth`**: Final model after the complete training run
- **`checkpoint_epoch_*.pth`**: Intermediate checkpoints saved during training
- **`train.py`**: Training script used to create this model

## Model Architecture

```python
ImprovedMultiOutputModel(
    backbone: EfficientNet-B0 (pretrained)
    attention: Linear(1280 → 512 → 1280) with Sigmoid
    object_classifier: Linear(1280 → 1024 → 512 → num_object_classes)
    material_classifier: Linear(1280 → 1024 → 512 → num_material_classes)
)
```

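For orientation, here is a minimal runnable sketch of this architecture built with `timm` (already listed in the prerequisites). The hidden sizes follow the summary above, but the dropout rates, exact layer ordering, and class counts are assumptions; the authoritative definition is in `train.py`.

```python
import timm
import torch
import torch.nn as nn

class MultiOutputSketch(nn.Module):
    """Hypothetical re-creation of ImprovedMultiOutputModel; see train.py for the real one."""

    def __init__(self, num_object_classes: int, num_material_classes: int):
        super().__init__()
        # EfficientNet-B0 feature extractor; num_classes=0 returns pooled 1280-d features
        self.backbone = timm.create_model("efficientnet_b0", pretrained=True, num_classes=0)
        feat_dim = self.backbone.num_features  # 1280 for EfficientNet-B0
        # Attention: sigmoid gate that reweights the pooled feature vector
        self.attention = nn.Sequential(
            nn.Linear(feat_dim, 512), nn.ReLU(inplace=True),
            nn.Linear(512, feat_dim), nn.Sigmoid(),
        )

        def head(num_classes: int) -> nn.Sequential:
            # 1280 -> 1024 -> 512 -> num_classes, with BN and dropout for regularization
            return nn.Sequential(
                nn.Linear(feat_dim, 1024), nn.BatchNorm1d(1024), nn.ReLU(inplace=True), nn.Dropout(0.3),
                nn.Linear(1024, 512), nn.ReLU(inplace=True), nn.Dropout(0.3),
                nn.Linear(512, num_classes),
            )

        self.object_classifier = head(num_object_classes)
        self.material_classifier = head(num_material_classes)

    def forward(self, x: torch.Tensor) -> dict:
        feats = self.backbone(x)               # (B, 1280)
        feats = feats * self.attention(feats)  # element-wise feature gating
        return {
            "object_name": self.object_classifier(feats),
            "material": self.material_classifier(feats),
        }
```
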
### Input Requirements
- **Image Size**: 224×224 pixels (automatically resized and cropped)
- **Format**: RGB images
- **Normalization**: ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

### Output Format
The model's forward pass returns a dictionary with:
- `'object_name'`: Logits for object classification
- `'material'`: Logits for material classification

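If you call the model directly instead of going through `run_inference`, a softmax turns these logits into labels and confidences. A minimal sketch, reusing `model`, `label_mappings`, `pixel_values`, and `device` from the Quick Start:

```python
import torch

model.to(device).eval()
with torch.no_grad():
    outputs = model(pixel_values.to(device))  # {'object_name': logits, 'material': logits}
    obj_probs = outputs['object_name'].softmax(dim=-1)
    obj_conf, obj_id = obj_probs.max(dim=-1)  # per-sample confidence and class id
    print(label_mappings['object_name'].get(obj_id[0].item()), f"({obj_conf[0].item():.3f})")
```
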
## Evaluation

### Using the Main Evaluation Script

To evaluate the model on the Oriental Museum dataset:

```bash
# Evaluate on the validation set
python main.py --model_file model/v2/best_model.pth --output eval_results_v2.json

# Evaluate with a custom batch size
python main.py --model_file model/v2/best_model.pth --batch_size 16 --output eval_results_v2.json
```

### Evaluation Metrics

The evaluation script reports:
- **Object Classification Accuracy**: Accuracy of the object-name prediction
- **Material Classification Accuracy**: Accuracy of the material prediction
- **Overall Accuracy**: Fraction of samples where both predictions are correct
- **Confidence Analysis**: Average confidence for correct vs. incorrect predictions
- **Per-sample Predictions**: Detailed results for each test sample

### Expected Performance

Based on validation during training:
- Object Classification: ~85-90% accuracy
- Material Classification: ~80-85% accuracy
- Overall (both correct): ~75-80%

*Note: Actual performance may vary with the evaluation dataset and preprocessing.*

## Training Details

The model was trained with the following configuration (the optimizer setup is sketched in code after this list):

- **Dataset**: ArtifactClfDurham/OrientalMuseum-white
- **Training Split**: 85% of the data
- **Validation Split**: 15% of the data
- **Batch Size**: 32
- **Epochs**: 20
- **Optimizer**: AdamW with differential learning rates
  - Backbone: 2e-4 (0.1× base LR)
  - Heads: 2e-3 (base LR)
- **Augmentation**: Advanced (CutMix, rotation, color jitter, Gaussian blur)
- **Loss Function**: Cross-Entropy (or Focal Loss if enabled)
- **Scheduler**: Cosine annealing with warmup

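The differential learning rates correspond to AdamW parameter groups along these lines (a sketch; the attribute names follow the hypothetical module above, and the weight decay value is an assumption, so verify against `train.py`):

```python
from torch.optim import AdamW

base_lr = 2e-3
optimizer = AdamW(
    [
        # Pretrained backbone gets 0.1x the base LR
        {"params": model.backbone.parameters(), "lr": base_lr * 0.1},
        # Attention and the two classifier heads use the base LR
        {"params": model.attention.parameters()},
        {"params": model.object_classifier.parameters()},
        {"params": model.material_classifier.parameters()},
    ],
    lr=base_lr,
    weight_decay=1e-2,  # assumed default; verify against train.py
)
```
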
### Advanced Training Features

- **CutMix Augmentation**: Randomly mixes image patches (and their labels) between samples
- **Focal Loss**: Addresses class imbalance (optional)
- **Mixed Precision**: Automatic mixed precision training for speed
- **Gradient Scaling**: Prevents gradient underflow in fp16
- **Early Stopping**: Saves the best model based on validation accuracy

A sketch of how the focal-loss and mixed-precision pieces fit together follows this list.

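This is illustrative only: the focal-loss `gamma` and the equal weighting of the two heads are assumptions, CutMix is omitted for brevity, and the actual implementation lives in `train.py`.

```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

def focal_loss(logits, targets, gamma=2.0):
    """Cross-entropy down-weighted for easy examples (Lin et al., 2017)."""
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)  # probability assigned to the true class
    return ((1 - pt) ** gamma * ce).mean()

scaler = GradScaler()  # rescales gradients to avoid fp16 underflow

def train_step(model, optimizer, images, obj_targets, mat_targets, device):
    optimizer.zero_grad(set_to_none=True)
    with autocast():  # mixed-precision forward pass
        out = model(images.to(device))
        loss = focal_loss(out["object_name"], obj_targets.to(device)) \
             + focal_loss(out["material"], mat_targets.to(device))
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```
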
## Usage Examples

### Batch Inference

```python
import torch
from PIL import Image
from torchvision import transforms
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from main import load_model, run_inference

# Load model
model, label_mappings = load_model("model/v2/best_model.pth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load multiple images
image_paths = ["artifact1.jpg", "artifact2.jpg", "artifact3.jpg"]
images = []

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

for path in image_paths:
    img = Image.open(path).convert('RGB')
    images.append(transform(img))

# Stack into a single batch tensor
batch = torch.stack(images)

# Run inference
preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, batch, device)

# Process results
for i, (obj_pred, obj_conf, mat_pred, mat_conf) in enumerate(zip(preds_obj, confs_obj, preds_mat, confs_mat)):
    obj_name = label_mappings['object_name'].get(obj_pred.item(), f"class_{obj_pred.item()}")
    mat_name = label_mappings['material'].get(mat_pred.item(), f"class_{mat_pred.item()}")

    print(f"Image {i+1}:")
    print(f"  Object: {obj_name} ({obj_conf:.3f})")
    print(f"  Material: {mat_name} ({mat_conf:.3f})")
```

### Custom Dataset Evaluation

```python
from datasets import load_dataset
from main import load_model
import json

# Load your custom dataset
dataset = load_dataset("your-dataset", split="test")

# Load model
model, label_mappings = load_model("model/v2/best_model.pth")

# Run evaluation (adapt the evaluation logic from main.py as needed)
# ... evaluation code ...
```

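As a starting point for the elided evaluation code, here is a minimal accuracy loop. It assumes the dataset has `image`, `object_name`, and `material` columns holding a PIL image and integer label ids, and that `transform` is the preprocessing pipeline from the Quick Start; adjust to your schema.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

correct_obj = correct_mat = total = 0
with torch.no_grad():
    for example in dataset:
        pixel_values = transform(example["image"].convert("RGB")).unsqueeze(0).to(device)
        out = model(pixel_values)  # dict of logits, per the Output Format section
        correct_obj += int(out["object_name"].argmax(-1).item() == example["object_name"])
        correct_mat += int(out["material"].argmax(-1).item() == example["material"])
        total += 1

print(json.dumps({"object_acc": correct_obj / total, "material_acc": correct_mat / total}, indent=2))
```
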
## Troubleshooting

### Common Issues

1. **CUDA Out of Memory**
   - Reduce the batch size: `--batch_size 8`
   - Fall back to CPU: set the device to `"cpu"`

2. **Import Errors**
   - Ensure all dependencies are installed
   - Check that the Python path includes the project root

3. **Model Loading Errors**
   - Verify the model file path is correct
   - Ensure PyTorch version compatibility

4. **Low Confidence Scores**
   - The model may not have been trained on similar artifacts
   - Check that the image preprocessing matches the training setup

### Performance Tips

- Use a GPU for faster inference
- Process images in batches for efficiency
- Use `best_model.pth` for production
- Consider model quantization for deployment (see the sketch below)

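For CPU deployment, PyTorch's dynamic quantization is a low-effort option that converts the Linear layers in the classifier heads to int8. This is a sketch; its accuracy impact on this model is untested, so validate before deploying:

```python
import torch
import torch.nn as nn

model.eval()
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8  # int8 weights for all Linear layers
)
# Use `quantized` exactly like `model` for CPU inference.
```
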
## Model Limitations

- Trained specifically on Oriental Museum artifacts
- May not generalize well to artifacts from other cultures or regions
- Performance depends on image quality and lighting
- The multi-output design may trade off object accuracy against material accuracy

## Contributing

To improve the model:
1. Run the training script with different hyperparameters
2. Experiment with different backbones
3. Add more advanced augmentations
4. Fine-tune on additional datasets

## License

This model is part of the artifact identification project. Check the main project license for usage terms.