File size: 5,268 Bytes
527ce0c 63234e6 527ce0c c031958 527ce0c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | import os
import ee
import numpy as np
import requests
import io
import base64
from rasterio.io import MemoryFile
import torch
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import gdown
from dotenv import load_dotenv
load_dotenv()

# Earth Engine service-account authentication. ee.ServiceAccountCredentials is
# given a key *file* path here, so the key contents held in the PRIVATE_KEY
# environment variable are first written to disk.
key_path = "/tmp/private-key.json"
_private_key = os.getenv("PRIVATE_KEY")
if not _private_key:
    # Fail with an actionable message instead of the opaque
    # "write() argument must be str, not None" TypeError.
    raise RuntimeError(
        "PRIVATE_KEY environment variable is not set; "
        "cannot authenticate to Earth Engine."
    )
# /tmp is shared between users: create the key file owner-only (0o600), and
# chmod as well in case the file already existed with looser permissions.
with open(os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), "w") as f:
    f.write(_private_key)
os.chmod(key_path, 0o600)
# NOTE(review): SERVICE_KEY_ID is passed as the service-account identifier
# (presumably the account e-mail) — confirm against the deployment config.
service_account = os.getenv("SERVICE_KEY_ID")
credentials = ee.ServiceAccountCredentials(service_account, key_path)
ee.Initialize(credentials)
MODEL_PATH = "deforestation_unet_full_model.pt"
MODEL_URL = os.getenv("MODEL_URL")

# Download the checkpoint only if it is not already cached on disk.
if not os.path.exists(MODEL_PATH):
    if not MODEL_URL:
        raise RuntimeError(
            f"{MODEL_PATH} not found and MODEL_URL is not set; cannot download the model."
        )
    print("Model not found. Downloading from Google Drive...")
    gdown.download(MODEL_URL, MODEL_PATH, quiet=False)
    if not os.path.exists(MODEL_PATH):
        raise RuntimeError(f"Downloading MODEL_URL did not produce {MODEL_PATH}.")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model once at import time. The checkpoint holds a whole pickled
# model object (the loaded value is used directly as the module), so there is
# no need to build an architecture first; map_location places its tensors on
# DEVICE.
# SECURITY: weights_only=False unpickles arbitrary Python objects — only point
# MODEL_URL / MODEL_PATH at a checkpoint you trust.
model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.eval()
def apply_scale_factors(image):
    """Rescale the raw Landsat Level-2 band values of an Earth Engine image.

    Bands whose names match the regex ``SR_B.`` get ``value * 0.0000275 - 0.2``
    and bands matching ``ST_B.*`` get ``value * 0.00341802 + 149.0``; the
    rescaled bands replace the originals (``addBands`` with overwrite=True).
    These look like the USGS Collection 2 Level-2 scale factors (surface
    reflectance / surface temperature) — confirm if reusing for other products.

    Args:
        image: an ``ee.Image`` (used as a mapper over an ImageCollection).

    Returns:
        The image with the scaled optical and thermal bands overwritten.
    """
    reflectance = image.select('SR_B.').multiply(0.0000275).add(-0.2)
    temperature = image.select('ST_B.*').multiply(0.00341802).add(149.0)
    with_reflectance = image.addBands(reflectance, None, True)
    return with_reflectance.addBands(temperature, None, True)
def fetch_rgb_ndvi(region, year, scale=30):
    """Build an annual Landsat 8 median composite over *region*.

    Images from LANDSAT/LC08/C02/T1_L2 intersecting *region*, acquired during
    calendar *year* and with CLOUD_COVER < 10, are rescaled with
    ``apply_scale_factors``, median-composited and clipped to *region*.

    Args:
        region: ``ee.Geometry`` to filter and clip to.
        year: calendar year (int).
        scale: unused — kept for backward compatibility; the pixel scale is
            chosen at download time (see ``download_geotiff_array``).

    Returns:
        ``(rgb, ndvi)`` — two lazy ``ee.Image`` objects: bands
        ``SR_B4, SR_B3, SR_B2`` and a single band ``NDVI`` computed as
        normalizedDifference(SR_B5, SR_B4).

    NOTE(review): if no image passes the filters the composite has no bands and
    the failure only surfaces at download time.
    """
    start = ee.Date.fromYMD(year, 1, 1)
    # filterDate's end bound is exclusive: advance a full year so images
    # acquired on Dec 31 are included.
    end = start.advance(1, 'year')
    col = (
        ee.ImageCollection("LANDSAT/LC08/C02/T1_L2")
        .filterBounds(region)
        .filterDate(start, end)
        # filterMetadata(..., 'less_than', ...) is deprecated; ee.Filter.lt is
        # the equivalent strict "<" comparison.
        .filter(ee.Filter.lt('CLOUD_COVER', 10))
        .map(apply_scale_factors)
    )
    image = col.median().clip(region)
    ndvi = image.normalizedDifference(['SR_B5', 'SR_B4']).rename('NDVI')
    image = image.addBands(ndvi)
    return image.select(['SR_B4', 'SR_B3', 'SR_B2']), image.select('NDVI')
def download_geotiff_array(img, region, bands, scale=30, timeout=300):
    """Download *bands* of an Earth Engine image over *region* as a float32 array.

    The image is requested as a GeoTIFF thumbnail at *scale* metres per pixel
    and decoded in memory with rasterio.

    Args:
        img: ``ee.Image`` to download.
        region: ``ee.Geometry`` defining the extent.
        bands: list of band names to include.
        scale: pixel size passed to Earth Engine (default 30).
        timeout: seconds to wait for the HTTP response (default 300).

    Returns:
        float32 ndarray as returned by rasterio's ``read()`` (bands, rows, cols).
        If the maximum value exceeds 1.5 the data is assumed to be 8-bit and is
        divided by 255.

    Raises:
        requests.HTTPError: Earth Engine answered with an error status.
        requests.Timeout: no response within *timeout* seconds.
    """
    url = img.getThumbURL({
        'scale': scale,
        'region': region,
        'format': 'GeoTIFF',
        'bands': bands
    })
    response = requests.get(url, timeout=timeout)
    # Earth Engine reports failures (e.g. too many pixels) as an HTTP error with
    # a text body; surface that instead of letting rasterio choke on it.
    response.raise_for_status()
    with MemoryFile(response.content) as memfile:
        with memfile.open() as src:
            arr = src.read().astype(np.float32)
    # Heuristic: values well above 1 are taken to be 0-255 and rescaled to [0, 1].
    # The size check avoids ValueError from max() on an empty array.
    if arr.size and arr.max() > 1.5:
        arr /= 255.0
    return arr
def predict_from_arrays(rgb_arr, ndvi_arr, threshold=0.5, *, net=None, device=None,
                        pad_multiple=32):
    """Run the segmentation model on one RGB + NDVI tile and return a binary mask.

    Args:
        rgb_arr: 3-D array (bands, H, W); only the first three bands are used.
        ndvi_arr: 3-D array (bands, H, W); only the first band is used.
        threshold: cutoff applied to the sigmoid probability (default 0.5).
        net: model to run; defaults to the module-level ``model``.
        device: torch device for inference; defaults to the module-level ``DEVICE``.
        pad_multiple: H and W are edge-padded up to a multiple of this before
            inference and the prediction is cropped back to (H, W). A U-Net with
            a 5-stage encoder (such as the resnet34 configuration used for this
            project) only accepts sizes divisible by 32, which arbitrary
            bounding boxes do not produce. Values <= 1 disable padding.

    Returns:
        float32 ndarray of shape (H, W) containing 0.0 / 1.0.
    """
    if net is None:
        net = model
    if device is None:
        device = DEVICE

    input_arr = np.concatenate(
        [rgb_arr[:3, :, :], ndvi_arr[:1, :, :]], axis=0
    ).astype(np.float32, copy=False)  # model weights are float32
    height, width = input_arr.shape[1:]

    if pad_multiple > 1:
        pad_h = -height % pad_multiple
        pad_w = -width % pad_multiple
    else:
        pad_h = pad_w = 0
    if pad_h or pad_w:
        # Pad only at the bottom/right so the original pixels keep their indices.
        input_arr = np.pad(input_arr, ((0, 0), (0, pad_h), (0, pad_w)), mode="edge")

    input_tensor = torch.from_numpy(np.ascontiguousarray(input_arr)).unsqueeze(0).to(device)
    with torch.no_grad():
        prob = torch.sigmoid(net(input_tensor))
    # Index batch/channel explicitly instead of squeeze(), which would also drop
    # a spatial axis of length 1; then crop away the padding.
    mask = (prob > threshold).float()[0, 0, :height, :width].contiguous()
    return mask.cpu().numpy()
def get_deforestation_color_map(mask_t0, mask_t1):
    """Colour-code per-pixel vegetation change between two 2-D masks.

    Returns an (H, W, 3) uint8 RGB image where each pixel is:
      (0, 255, 0)      green - value 1 in both masks (retained)
      (255, 0, 0)      red   - 1 in mask_t0, 0 in mask_t1 (lost)
      (65, 168, 255)   blue  - 0 in mask_t0, 1 in mask_t1 (gained)
      (255, 255, 255)  white - 0 in both masks (no vegetation)
    Pixels holding any other value stay black (0, 0, 0).
    """
    rows, cols = mask_t0.shape
    palette = (
        (1, 1, (0, 255, 0)),
        (1, 0, (255, 0, 0)),
        (0, 1, (65, 168, 255)),
        (0, 0, (255, 255, 255)),
    )
    rgb = np.zeros((rows, cols, 3), dtype=np.uint8)
    for before, after, colour in palette:
        rgb[np.logical_and(mask_t0 == before, mask_t1 == after)] = colour
    return rgb
def _fetch_year_arrays(region, year):
    """Download the RGB and NDVI arrays of *year*'s composite over *region*."""
    rgb_ee, ndvi_ee = fetch_rgb_ndvi(region, year)
    rgb = download_geotiff_array(rgb_ee, region, ['SR_B4', 'SR_B3', 'SR_B2'])
    ndvi = download_geotiff_array(ndvi_ee, region, ['NDVI'])
    return rgb, ndvi


def _change_statistics(mask_t0, mask_t1):
    """Return (percent_loss, percent_gain) between two 0/1 vegetation masks.

    Loss is relative to the pixels vegetated in mask_t0 (0 if there are none);
    gain is relative to the total pixel count (0 for an empty mask).
    """
    deforested_pixels = ((mask_t0 == 1) & (mask_t1 == 0)).sum()
    gained_pixels = ((mask_t0 == 0) & (mask_t1 == 1)).sum()
    total_vegetation_t0 = (mask_t0 == 1).sum()
    percent_loss = (deforested_pixels / total_vegetation_t0) * 100 if total_vegetation_t0 > 0 else 0
    percent_gain = (gained_pixels / mask_t0.size) * 100 if mask_t0.size else 0  # relative to total area
    return percent_loss, percent_gain


def _render_change_figure(mask_t0, mask_t1, color_mask, start_year, end_year):
    """Render the three-panel comparison figure and return it as base64 PNG text."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        # Pin the colour range to [0, 1]: with autoscaling, an all-0 and an
        # all-1 mask would both render with the same (lowest) colour.
        axes[0].imshow(mask_t0, cmap="Greens", vmin=0, vmax=1)
        axes[0].set_title(f"Vegetation in {start_year}")
        axes[0].axis("off")
        axes[1].imshow(mask_t1, cmap="Greens", vmin=0, vmax=1)
        axes[1].set_title(f"Vegetation in {end_year}")
        axes[1].axis("off")
        axes[2].imshow(color_mask)
        axes[2].set_title("Vegetation Change")
        axes[2].axis("off")
        fig.tight_layout()
        buf = io.BytesIO()
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if rendering fails, to avoid leaking
        # figures in a long-running process.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')


def run_deforestation_pipeline(lat_min, lat_max, lon_min, lon_max, start_year, end_year):
    """Estimate vegetation loss/regrowth in a lat/lon box between two years.

    Downloads an annual Landsat composite for *start_year* and *end_year*,
    segments vegetation with the model, and compares the two masks.

    Returns:
        dict with
          "percent_deforested": % of start-year vegetation pixels that are not
              vegetation in end_year (rounded to 2 decimals),
          "percent_regrowth":  newly vegetated pixels as % of the whole area
              (rounded to 2 decimals),
          "image_base64":      base64-encoded PNG of the comparison figure.

    Raises:
        ValueError: the two years' masks have different shapes.
    """
    region = ee.Geometry.Rectangle([lon_min, lat_min, lon_max, lat_max])
    rgb_t0, ndvi_t0 = _fetch_year_arrays(region, start_year)
    rgb_t1, ndvi_t1 = _fetch_year_arrays(region, end_year)
    mask_t0 = predict_from_arrays(rgb_t0, ndvi_t0)
    mask_t1 = predict_from_arrays(rgb_t1, ndvi_t1)
    if mask_t0.shape != mask_t1.shape:
        raise ValueError(
            f"Mask shapes differ between years: {mask_t0.shape} vs {mask_t1.shape}"
        )

    percent_loss, percent_gain = _change_statistics(mask_t0, mask_t1)
    color_mask = get_deforestation_color_map(mask_t0, mask_t1)
    img_base64 = _render_change_figure(mask_t0, mask_t1, color_mask, start_year, end_year)
    return {
        "percent_deforested": round(float(percent_loss), 2),
        "percent_regrowth": round(float(percent_gain), 2),
        "image_base64": img_base64,
    }
|