# Naturecode Coral

A deep learning model for automated coral health assessment, classifying underwater coral images as healthy or unhealthy (bleached, diseased, or dead).
## Model Description
This model is designed for coral reef monitoring and conservation efforts. It uses an EfficientNetV2-M backbone trained on multiple coral health datasets to provide robust cross-domain generalization.
### Architecture
- Base Model: EfficientNetV2-M (timm)
- Input Size: 384x384 RGB images
- Output: Binary classification (healthy/unhealthy)
- Parameters: ~54M
## Training Data
- Koh Tao Coral Condition Dataset: Expert-labeled coral images from Thailand
- Coralscapes: Semantic segmentation dataset with derived health labels
- Combined training: ~15,000 labeled images across multiple reef ecosystems
## Performance

### Cross-Domain Test Results (Coralscapes Test Set)
| Configuration | Balanced Accuracy | Healthy Recall | Unhealthy Recall |
|---|---|---|---|
| Standard (threshold=0.5) | 82.2% | 88.0% | 76.5% |
| Optimized (threshold=0.35) | 84.6% | 80.9% | 88.2% |
| High Sensitivity (threshold=0.25) | 82.9% | 73.6% | 92.2% |
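Balanced accuracy here is the unweighted mean of the two per-class recalls, which is why the 0.35 threshold scores highest even though its healthy recall drops. A quick sanity check of the table (pure Python, no model required; small last-digit differences come from rounding in the reported recalls):

```python
# Balanced accuracy = mean of per-class recalls (healthy, unhealthy).
# Recall values taken from the table above.
configs = {
    0.50: (0.880, 0.765),
    0.35: (0.809, 0.882),
    0.25: (0.736, 0.922),
}

for threshold, (healthy_recall, unhealthy_recall) in configs.items():
    balanced_acc = (healthy_recall + unhealthy_recall) / 2
    print(f"threshold={threshold:.2f}: balanced accuracy = {balanced_acc:.1%}")
```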
### Recommended Settings
- Default threshold: 0.35 (balanced performance)
- High sensitivity screening: 0.25 (catches 92%+ of unhealthy coral)
- Test-Time Augmentation (TTA): Enabled for best results
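These operating points can be captured in a small helper that applies the chosen threshold to the model's unhealthy-class probability (a minimal pure-Python sketch; `THRESHOLDS` and `classify` are hypothetical names, and the probability would come from the softmax output shown in the inference examples below):

```python
# Operating points from the recommended settings above.
THRESHOLDS = {
    'default': 0.35,           # balanced performance
    'high_sensitivity': 0.25,  # catches 92%+ of unhealthy coral
    'balanced': 0.50,          # equal treatment of both classes
}

def classify(unhealthy_prob: float, mode: str = 'default') -> str:
    """Map an unhealthy-class probability to a label at the chosen operating point."""
    return 'unhealthy' if unhealthy_prob >= THRESHOLDS[mode] else 'healthy'

print(classify(0.30))                      # 'healthy' at the default threshold
print(classify(0.30, 'high_sensitivity'))  # 'unhealthy' when screening aggressively
```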
## Usage

### Installation

```bash
pip install torch torchvision timm pillow
```
### Basic Inference

```python
import torch
import timm
from PIL import Image
from torchvision import transforms

# Load model
model = timm.create_model('tf_efficientnetv2_m', pretrained=False, num_classes=2)
checkpoint = torch.load('model.pt', map_location='cpu', weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Preprocessing (ImageNet normalization statistics)
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Predict
image = Image.open('coral_image.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)
with torch.no_grad():
    output = model(input_tensor)
    probs = torch.softmax(output, dim=1)

# Apply the optimized decision threshold
threshold = 0.35
unhealthy_prob = probs[0, 1].item()
prediction = 'unhealthy' if unhealthy_prob >= threshold else 'healthy'
print(f"Prediction: {prediction} (unhealthy probability: {unhealthy_prob:.1%})")
```
### Inference with Test-Time Augmentation (Recommended)

```python
import torch
import timm
from PIL import Image
from torchvision import transforms

def get_tta_transforms(img_size=384):
    """Build TTA transforms (identity, flips, 90/270-degree rotations)."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return [
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(), normalize]),
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.RandomHorizontalFlip(p=1.0), transforms.ToTensor(), normalize]),
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.RandomVerticalFlip(p=1.0), transforms.ToTensor(), normalize]),
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.Lambda(lambda img: img.rotate(90)), transforms.ToTensor(), normalize]),
        transforms.Compose([transforms.Resize((img_size, img_size)), transforms.Lambda(lambda img: img.rotate(270)), transforms.ToTensor(), normalize]),
    ]

def predict_with_tta(model, image_path, device='cpu', threshold=0.35):
    """Predict with test-time augmentation by averaging logits over all views."""
    model.eval()
    img = Image.open(image_path).convert('RGB')
    tta_transforms = get_tta_transforms()
    all_logits = []
    with torch.no_grad():
        for transform in tta_transforms:
            img_tensor = transform(img).unsqueeze(0).to(device)
            logits = model(img_tensor)
            all_logits.append(logits)
    # Average logits across augmented views, then convert to probabilities
    avg_logits = torch.stack(all_logits).mean(dim=0)
    probs = torch.softmax(avg_logits, dim=1)
    unhealthy_prob = probs[0, 1].item()
    prediction = 'unhealthy' if unhealthy_prob >= threshold else 'healthy'
    confidence = unhealthy_prob if prediction == 'unhealthy' else (1 - unhealthy_prob)
    return {
        'prediction': prediction,
        'confidence': confidence,
        'unhealthy_probability': unhealthy_prob,
        'healthy_probability': 1 - unhealthy_prob,
    }

# Example usage (falls back to CPU when no GPU is available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = timm.create_model('tf_efficientnetv2_m', pretrained=False, num_classes=2)
checkpoint = torch.load('model.pt', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
result = predict_with_tta(model, 'coral_image.jpg', device=device, threshold=0.35)
print(f"Prediction: {result['prediction']} ({result['confidence']:.1%} confidence)")
```
## Deployment Recommendations

### For Coral Monitoring Applications

**Assisted Screening (Recommended)**
- Use a threshold of 0.25-0.35 for automated flagging
- A human expert reviews flagged images
- Catches 88-92% of unhealthy coral

**Balanced Operation**
- Use a threshold of 0.5 for equal treatment of both classes
- Best when false positives and false negatives are equally costly

**Two-Stage Pipeline**
1. Automated screening (threshold=0.25) -> flag potential issues
2. Human expert review -> final classification
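The two-stage routing can be sketched as a small helper (pure Python; `route` and the label strings are hypothetical names, and the probability comes from the model's softmax output):

```python
def route(unhealthy_prob: float, screen_threshold: float = 0.25) -> str:
    """Stage 1: automated screening. Images at or over the threshold go to expert review."""
    if unhealthy_prob >= screen_threshold:
        return 'flag_for_expert_review'   # Stage 2: a human makes the final call
    return 'auto_classified_healthy'

# Route a batch of per-image unhealthy probabilities
batch = {'img_001.jpg': 0.12, 'img_002.jpg': 0.41, 'img_003.jpg': 0.27}
for name, prob in batch.items():
    print(name, '->', route(prob))
```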
## Limitations
- **Not for fully autonomous decisions**: misses ~12% of unhealthy coral even at the optimized threshold
- **Domain shift**: performance may vary on reef ecosystems that differ significantly from the training data
- **Image quality**: best results with clear underwater images; may struggle in turbid water
- **Lighting conditions**: trained primarily on natural reef lighting
## Intended Use
- Coral reef health monitoring
- Marine conservation research
- Reef restoration project assessment
- Educational and research purposes
## Ethical Considerations
This model is intended to assist, not replace, marine biologists and conservation experts. Critical decisions about reef management should involve human expertise. The model should be validated on local reef conditions before deployment.
## Citation

If you use this model, please cite:

```bibtex
@software{naturecode_coral_2025,
  title  = {Naturecode Coral},
  author = {Naturecode},
  year   = {2025},
  url    = {https://huggingface.co/hilarl/naturecode-coral}
}
```
## License
Apache 2.0
## Contact
For questions, issues, or collaboration inquiries, please open an issue on the model repository.