# Prithvi EO 2.0 - Burn Scar Severity Detection
Fine-tuned Prithvi EO 2.0 (IBM/NASA Geospatial Foundation Model) for pixel-level wildfire burn scar severity classification using Sentinel-2 multi-temporal satellite imagery.
**Author:** Tushar Thokdar

## Model Description
This model performs semantic segmentation on multi-temporal Sentinel-2 satellite imagery to classify burn severity into 5 classes. It builds on the NASA-IBM Prithvi EO 2.0 Vision Transformer backbone with a UperNet decoder, fine-tuned using a novel Delta Channel Algorithm and a Freeze-then-Unfreeze training strategy.
## Architecture
| Component | Details |
|---|---|
| Backbone | prithvi_eo_v2_tiny_tl (ViT-based Masked Autoencoder) |
| Decoder | UperNet (Unified Perceptual Parsing Network) |
| Parameters | ~25.1M |
| Pre-training | IBM/NASA global satellite imagery |
| Fine-tuning | Burn scar severity from Sentinel-2 |
## Input Specification
- Shape: `(Batch, Time=3, Channels=6, Height=224, Width=224)`
- Temporal frames: Pre-fire, Post-fire, Delta (Post - Pre, clipped to [-1, 1])
- Spectral bands: B2 (Blue), B3 (Green), B4 (Red), B8A (NIR), B11 (SWIR1), B12 (SWIR2)
- Normalization: Surface reflectance values in [0, 1]
## Output Classes
| Class ID | Label | Color |
|---|---|---|
| 0 | Unburned | Green |
| 1 | Low Severity | Yellow |
| 2 | Moderate-Low | Orange |
| 3 | Moderate-High | Red-Orange |
| 4 | High Severity | Dark Red |
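For rendering predictions, the class palette above can be encoded as an RGB lookup table. A minimal sketch with NumPy; the exact RGB values are illustrative assumptions, not taken from the training pipeline:

```python
import numpy as np

# Illustrative RGB values for the 5 severity classes (assumed, not canonical)
PALETTE = np.array([
    [34, 139, 34],   # 0: Unburned (green)
    [255, 255, 0],   # 1: Low Severity (yellow)
    [255, 165, 0],   # 2: Moderate-Low (orange)
    [255, 69, 0],    # 3: Moderate-High (red-orange)
    [139, 0, 0],     # 4: High Severity (dark red)
], dtype=np.uint8)

def colorize(prediction: np.ndarray) -> np.ndarray:
    """Map an (H, W) class-ID map to an (H, W, 3) RGB image."""
    return PALETTE[prediction]

# Example: a 2x2 prediction patch
pred = np.array([[0, 4], [1, 2]])
rgb = colorize(pred)  # shape (2, 2, 3)
```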
## Performance
| Metric | Baseline (Pretrained) | Fine-Tuned | Improvement |
|---|---|---|---|
| Accuracy | 35.86% | 69.93% | +34.07 pp |
| Macro F1 | 0.1160 | 0.6218 | +0.5058 |
| Weighted F1 | 0.2035 | 0.7015 | +0.4980 |
| Burned F1 | 0.0133 | 0.5553 | +0.5420 |
### Per-Class F1 Scores
| Class | Baseline | Fine-Tuned |
|---|---|---|
| Unburned | 0.5667 | 0.8187 |
| Low Severity | 0.0000 | 0.5336 |
| Moderate-Low | 0.0000 | 0.4804 |
| Moderate-High | 0.0133 | 0.4029 |
| High Severity | 0.0509 | 0.7055 |
## Training Details
### Strategy: Freeze-then-Unfreeze
- Stage 1 (Epochs 0-5): ViT backbone frozen, only UperNet decoder trained
- Stage 2 (Epochs 6+): Full model unfrozen for joint end-to-end optimization
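The two stages above can be sketched by toggling `requires_grad` on the backbone parameters. This is a minimal stand-alone sketch; the `backbone`/`decoder` attribute names and the placeholder modules are assumptions, not the actual `PrithviFireSegmentation` internals:

```python
import torch.nn as nn

class SegModel(nn.Module):
    """Stand-in for the ViT backbone + UperNet decoder pair."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # placeholder for the ViT encoder
        self.decoder = nn.Linear(8, 5)   # placeholder for the UperNet head

def set_backbone_trainable(model: SegModel, trainable: bool) -> None:
    for p in model.backbone.parameters():
        p.requires_grad = trainable

model = SegModel()

# Stage 1 (epochs 0-5): freeze the backbone, train only the decoder
set_backbone_trainable(model, False)
stage1_trainable = [p for p in model.parameters() if p.requires_grad]

# Stage 2 (epochs 6+): unfreeze everything for end-to-end fine-tuning
set_backbone_trainable(model, True)
stage2_trainable = [p for p in model.parameters() if p.requires_grad]
```

Only the decoder's parameters receive gradients in stage 1; after unfreezing, the full parameter set is optimized jointly.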
### Hyperparameters
| Parameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning Rate | 1e-4 |
| Weight Decay | 0.01 |
| LR Scheduler | Cosine Annealing (T_max=30, eta_min=1e-6) |
| Batch Size | 8 |
| Max Epochs | 50 (early stopped at 38) |
| Early Stopping | Patience=12 on val_loss |
| Precision | 16-bit mixed (AMP) |
| Gradient Clipping | 1.0 |
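Under the table above, the optimizer and scheduler setup would look roughly like this. A sketch assuming a plain PyTorch loop rather than the exact Lightning configuration; the `model` here is a placeholder:

```python
import torch

model = torch.nn.Linear(8, 5)  # placeholder for the full segmentation model

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=30, eta_min=1e-6
)

# Gradient clipping at 1.0 would be applied each step after backward():
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

start_lr = optimizer.param_groups[0]["lr"]
for _ in range(30):  # anneal over T_max epochs
    scheduler.step()
end_lr = optimizer.param_groups[0]["lr"]  # reaches eta_min at T_max
```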
### Loss Function
Hybrid loss combining class-weighted Cross-Entropy and Dice Loss:

```text
Loss = CrossEntropy(class_weights) + 0.8 * DiceLoss(classes=[1, 2, 3, 4])
```

Class weights: `[0.7406, 1.3147, 1.0963, 1.2333, 0.8575]`
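A runnable sketch of this hybrid objective. The Dice term is a simplified hand-rolled soft Dice restricted to the burned classes, not necessarily identical to the `segmentation-models-pytorch` implementation listed in the dependencies:

```python
import torch
import torch.nn.functional as F

CLASS_WEIGHTS = torch.tensor([0.7406, 1.3147, 1.0963, 1.2333, 0.8575])

def dice_loss(logits, target, classes, eps=1e-6):
    """Soft Dice loss averaged over the given class IDs (simplified sketch)."""
    probs = F.softmax(logits, dim=1)
    onehot = F.one_hot(target, num_classes=logits.shape[1])
    onehot = onehot.permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # reduce over batch and spatial dims, keep classes
    inter = (probs * onehot).sum(dims)
    card = probs.sum(dims) + onehot.sum(dims)
    dice = (2 * inter + eps) / (card + eps)
    return 1 - dice[classes].mean()

def hybrid_loss(logits, target):
    ce = F.cross_entropy(logits, target, weight=CLASS_WEIGHTS)
    return ce + 0.8 * dice_loss(logits, target, classes=[1, 2, 3, 4])

logits = torch.randn(2, 5, 8, 8)          # (batch, classes, H, W)
target = torch.randint(0, 5, (2, 8, 8))   # integer class labels
loss = hybrid_loss(logits, target)
```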
## Key Innovation: Delta Channel
Standard approaches use 2-frame stacks (pre + post). This model uses a 3-frame stack that includes an explicit spectral difference:
```text
Delta = Clip(Post-fire - Pre-fire, -1.0, 1.0)
Input = Stack([Pre-fire, Post-fire, Delta])   # Shape: (3, 6, 224, 224)
```
This provides the model with direct change magnitude information, significantly improving boundary detection.
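Building the 3-frame input from two co-registered reflectance chips can be sketched with NumPy (the array names and random data here are illustrative):

```python
import numpy as np

# Two co-registered Sentinel-2 chips, 6 bands each, reflectance in [0, 1]
pre_fire = np.random.rand(6, 224, 224).astype(np.float32)
post_fire = np.random.rand(6, 224, 224).astype(np.float32)

# Explicit change channel, clipped to [-1, 1]
delta = np.clip(post_fire - pre_fire, -1.0, 1.0)

# Stack along a new time axis: (Time=3, Channels=6, H=224, W=224)
chip = np.stack([pre_fire, post_fire, delta], axis=0)
```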
## Usage
### Quick Inference
```python
import torch
import numpy as np
from huggingface_hub import hf_hub_download

# Download checkpoint
ckpt_path = hf_hub_download(
    repo_id="Tushar365/prithvi-burn-scar-model",
    filename="prithvi-delta-best.ckpt",
)

# Model class must be importable for load_from_checkpoint
from model import PrithviFireSegmentation

CLASS_WEIGHTS = torch.tensor([0.7406347, 1.3146889, 1.0963433, 1.233306, 0.8574721])
model = PrithviFireSegmentation.load_from_checkpoint(
    ckpt_path, class_weights=CLASS_WEIGHTS
)
model.eval()
model.cuda()

# Load a sample chip of shape (3, 6, 224, 224)
image = np.load("chip_000000.npy")
input_tensor = torch.from_numpy(image).float().unsqueeze(0).cuda()

with torch.no_grad():
    logits = model(input_tensor)

prediction = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
# prediction shape: (224, 224) with values 0-4
```
## Dependencies

```text
torch>=2.1.0
pytorch-lightning==2.1.0
terratorch
segmentation-models-pytorch
numpy==1.26.4
```
## Files in This Repository
| File | Description |
|---|---|
| `prithvi-delta-best.ckpt` | Fine-tuned model checkpoint (~100 MB) |
| `model.py` | `PrithviFireSegmentation` class definition (required for loading) |
| `README.md` | This model card |
## Limitations
- Trained on a single wildfire event in Northern California; generalization to other regions/fire types needs validation
- 20m spatial resolution may miss fine-grained burn patterns
- dNBR-derived labels have inherent uncertainty at class boundaries
- Performance on cloudy/smoky imagery has not been evaluated
## Citation

```bibtex
@misc{thokdar2024prithviburnscars,
  title={Fine-tuning Prithvi EO 2.0 for Burn Scar Severity Detection with Delta Channel Algorithm},
  author={Tushar Thokdar},
  year={2024},
  note={End-to-end wildfire damage assessment pipeline using NASA-IBM foundation model}
}
```
## Acknowledgments
- Prithvi EO 2.0 by IBM and NASA for the foundation model
- ESA Copernicus for Sentinel-2 satellite imagery
- Google Earth Engine for data access and preprocessing