---
license: apache-2.0
tags:
- image-classification
- coral
- marine-biology
- conservation
- underwater
- reef-monitoring
- efficientnet
- pytorch
- timm
datasets:
- coralscapes
- koh-tao-coral-condition
pipeline_tag: image-classification
library_name: timm
---

# Naturecode Coral

A deep learning model for automated coral health assessment, classifying underwater coral images as **healthy** or **unhealthy** (bleached, diseased, or dead).

## Model Description

This model is designed for coral reef monitoring and conservation efforts. It uses an EfficientNetV2-M backbone trained on multiple coral health datasets to provide robust cross-domain generalization.

### Architecture

- **Base Model**: EfficientNetV2-M (timm)
- **Input Size**: 384x384 RGB images
- **Output**: Binary classification (healthy/unhealthy)
- **Parameters**: ~54M

### Training Data

- **Koh Tao Coral Condition Dataset**: Expert-labeled coral images from Thailand
- **Coralscapes**: Semantic segmentation dataset with derived health labels
- **Combined training**: ~15,000 labeled images across multiple reef ecosystems
## Performance |
|
|
|
|
|
### Cross-Domain Test Results (Coralscapes Test Set) |
|
|
|
|
|
| Configuration | Balanced Accuracy | Healthy Recall | Unhealthy Recall | |
|
|
|--------------|-------------------|----------------|------------------| |
|
|
| Standard (threshold=0.5) | 82.2% | 88.0% | 76.5% | |
|
|
| **Optimized (threshold=0.35)** | **84.6%** | 80.9% | **88.2%** | |
|
|
| High Sensitivity (threshold=0.25) | 82.9% | 73.6% | 92.2% | |
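
Balanced accuracy here is the unweighted mean of the two per-class recalls, so the table rows can be sanity-checked with a few lines of arithmetic (a quick check, not part of the released code):

```python
def balanced_accuracy(healthy_recall: float, unhealthy_recall: float) -> float:
    """Unweighted mean of per-class recalls (values in percent)."""
    return (healthy_recall + unhealthy_recall) / 2

print(balanced_accuracy(88.0, 76.5))  # threshold=0.5  -> 82.25 (reported as 82.2)
print(balanced_accuracy(80.9, 88.2))  # threshold=0.35 -> 84.55 (reported as 84.6)
print(balanced_accuracy(73.6, 92.2))  # threshold=0.25 -> 82.9
```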

### Recommended Settings

- **Default threshold**: 0.35 (balanced performance)
- **High-sensitivity screening**: 0.25 (catches 92%+ of unhealthy coral)
- **Test-Time Augmentation (TTA)**: enabled for best results
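
The presets above can be collected into a small helper; the names `THRESHOLDS` and `classify` are illustrative, not part of a released API:

```python
# Hypothetical preset table mirroring the recommended settings
THRESHOLDS = {
    "balanced": 0.35,          # default operating point
    "high_sensitivity": 0.25,  # screening: maximize unhealthy recall
    "standard": 0.50,          # equal treatment of both classes
}

def classify(unhealthy_prob: float, mode: str = "balanced") -> str:
    """Turn the model's unhealthy-class probability into a label for a preset."""
    return "unhealthy" if unhealthy_prob >= THRESHOLDS[mode] else "healthy"

print(classify(0.30))                      # 'healthy' at the 0.35 default
print(classify(0.30, "high_sensitivity"))  # 'unhealthy' at the 0.25 screening threshold
```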

## Usage

### Installation

```bash
pip install torch torchvision timm pillow
```

### Basic Inference

```python
import torch
import timm
from PIL import Image
from torchvision import transforms

# Load model
model = timm.create_model('tf_efficientnetv2_m', pretrained=False, num_classes=2)
checkpoint = torch.load('model.pt', map_location='cpu', weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Preprocessing
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Predict
image = Image.open('coral_image.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(input_tensor)
    probs = torch.softmax(output, dim=1)

# Using the optimized threshold
threshold = 0.35
unhealthy_prob = probs[0, 1].item()
prediction = 'unhealthy' if unhealthy_prob >= threshold else 'healthy'

print(f"Prediction: {prediction} (unhealthy probability: {unhealthy_prob:.1%})")
```

### Inference with Test-Time Augmentation (Recommended)

```python
import torch
import timm
from PIL import Image
from torchvision import transforms

def get_tta_transforms(img_size=384):
    """Get TTA transforms for more robust predictions."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Each Lambda/flip runs before ToTensor, so it operates on a PIL image
    return [
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(), normalize]),
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.RandomHorizontalFlip(p=1.0), transforms.ToTensor(), normalize]),
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.RandomVerticalFlip(p=1.0), transforms.ToTensor(), normalize]),
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.Lambda(lambda img: img.rotate(90)), transforms.ToTensor(), normalize]),
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.Lambda(lambda img: img.rotate(270)), transforms.ToTensor(), normalize]),
    ]

def predict_with_tta(model, image_path, device='cuda', threshold=0.35):
    """Predict with test-time augmentation."""
    model.eval()
    img = Image.open(image_path).convert('RGB')
    tta_transforms = get_tta_transforms()

    all_logits = []
    with torch.no_grad():
        for transform in tta_transforms:
            img_tensor = transform(img).unsqueeze(0).to(device)
            logits = model(img_tensor)
            all_logits.append(logits)

    # Average logits across the augmented views
    avg_logits = torch.stack(all_logits).mean(dim=0)
    probs = torch.softmax(avg_logits, dim=1)
    unhealthy_prob = probs[0, 1].item()

    prediction = 'unhealthy' if unhealthy_prob >= threshold else 'healthy'
    confidence = unhealthy_prob if prediction == 'unhealthy' else (1 - unhealthy_prob)

    return {
        'prediction': prediction,
        'confidence': confidence,
        'unhealthy_probability': unhealthy_prob,
        'healthy_probability': 1 - unhealthy_prob,
    }

# Example usage
model = timm.create_model('tf_efficientnetv2_m', pretrained=False, num_classes=2)
checkpoint = torch.load('model.pt', map_location='cuda', weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.cuda()

result = predict_with_tta(model, 'coral_image.jpg', threshold=0.35)
print(f"Prediction: {result['prediction']} ({result['confidence']:.1%} confidence)")
```

## Deployment Recommendations

### For Coral Monitoring Applications

1. **Assisted Screening (Recommended)**
   - Use a threshold of 0.25-0.35 for automated flagging
   - A human expert reviews flagged images
   - Catches 88-92% of unhealthy coral

2. **Balanced Operation**
   - Use a threshold of 0.5 for equal treatment of both classes
   - Best when false positives and false negatives are equally costly

3. **Two-Stage Pipeline**

   ```
   Stage 1: Automated screening (threshold=0.25) -> Flag potential issues
   Stage 2: Human expert review -> Final classification
   ```
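
The two-stage flow can be sketched in a few lines of Python; the probabilities below are made up, and in practice they would come from `predict_with_tta` above:

```python
SCREENING_THRESHOLD = 0.25  # Stage 1 uses the high-sensitivity setting

def screen_batch(probabilities: dict) -> tuple:
    """Stage 1: split image IDs into a flagged queue (expert review) and a pass list.

    `probabilities` maps image IDs to the model's unhealthy-class probability.
    """
    flagged, passed = [], []
    for image_id, unhealthy_prob in probabilities.items():
        (flagged if unhealthy_prob >= SCREENING_THRESHOLD else passed).append(image_id)
    return flagged, passed

# Illustrative probabilities, not real model output
batch = {"reef_001.jpg": 0.12, "reef_002.jpg": 0.41, "reef_003.jpg": 0.27}
flagged, passed = screen_batch(batch)
print(f"Queued for expert review (Stage 2): {flagged}")  # reef_002 and reef_003
```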

## Limitations

- **Not for fully autonomous decisions**: misses ~12% of unhealthy coral at the recommended settings
- **Domain shift**: performance may vary on significantly different reef ecosystems
- **Image quality**: best results with clear underwater images; may struggle in turbid water
- **Lighting conditions**: trained primarily on natural reef lighting

## Intended Use

- Coral reef health monitoring
- Marine conservation research
- Reef restoration project assessment
- Educational and research purposes

## Ethical Considerations

This model is intended to assist, not replace, marine biologists and conservation experts. Critical decisions about reef management should involve human expertise. The model should be validated on local reef conditions before deployment.

## Citation

If you use this model, please cite:

```bibtex
@software{naturecode_coral_2025,
  title  = {Naturecode Coral},
  author = {Naturecode},
  year   = {2025},
  url    = {https://huggingface.co/Naturecode/coral}
}
```

## License

Apache 2.0

## Contact

For questions, issues, or collaboration inquiries, please open an issue on the model repository.