Spaces:
Sleeping
Sleeping
| import os | |
| import ee | |
| import numpy as np | |
| import requests | |
| import io | |
| import base64 | |
| from rasterio.io import MemoryFile | |
| import torch | |
| import segmentation_models_pytorch as smp | |
| import matplotlib.pyplot as plt | |
| import gdown | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| key_path = "/tmp/private-key.json" | |
| with open(key_path, "w") as f: | |
| f.write(os.getenv("PRIVATE_KEY")) | |
| service_account = os.getenv("SERVICE_KEY_ID") | |
| credentials = ee.ServiceAccountCredentials(service_account, key_path) | |
| ee.Initialize(credentials) | |
| MODEL_PATH = "deforestation_unet_full_model.pt" | |
| MODEL_URL = os.getenv("MODEL_URL") | |
| # Download model only if it doesn't exist | |
| if not os.path.exists(MODEL_PATH): | |
| print("Model not found. Downloading from Google Drive...") | |
| gdown.download(MODEL_URL, MODEL_PATH, quiet=False) | |
| # ee.Initialize(project=os.environ["project-id"]) | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| # Load model once | |
| model = smp.Unet( | |
| encoder_name="resnet34", | |
| encoder_weights=None, | |
| in_channels=4, | |
| classes=1, | |
| activation=None, | |
| ).to(DEVICE) | |
| model = torch.load("deforestation_unet_full_model.pt", map_location=DEVICE, weights_only=False) | |
| model.eval() | |
| def apply_scale_factors(image): | |
| optical_bands = image.select('SR_B.').multiply(0.0000275).add(-0.2) | |
| thermal_bands = image.select('ST_B.*').multiply(0.00341802).add(149.0) | |
| return image.addBands(optical_bands, None, True).addBands(thermal_bands, None, True) | |
| def fetch_rgb_ndvi(region, year, scale=30): | |
| start = ee.Date.fromYMD(year, 1, 1) | |
| end = ee.Date.fromYMD(year, 12, 31) | |
| col = (ee.ImageCollection("LANDSAT/LC08/C02/T1_L2") | |
| .filterBounds(region) | |
| .filterDate(start, end) | |
| .filterMetadata('CLOUD_COVER', 'less_than', 10) | |
| .map(apply_scale_factors)) | |
| image = col.median().clip(region) | |
| ndvi = image.normalizedDifference(['SR_B5', 'SR_B4']).rename('NDVI') | |
| image = image.addBands(ndvi) | |
| return image.select(['SR_B4', 'SR_B3', 'SR_B2']), image.select('NDVI') | |
| def download_geotiff_array(img, region, bands, scale=30): | |
| url = img.getThumbURL({ | |
| 'scale': scale, | |
| 'region': region, | |
| 'format': 'GeoTIFF', | |
| 'bands': bands | |
| }) | |
| response = requests.get(url) | |
| with MemoryFile(response.content) as memfile: | |
| with memfile.open() as src: | |
| arr = src.read().astype(np.float32) | |
| if arr.max() > 1.5: | |
| arr /= 255.0 | |
| return arr | |
| def predict_from_arrays(rgb_arr, ndvi_arr): | |
| rgb_arr = rgb_arr[:3, :, :] | |
| ndvi_arr = ndvi_arr[:1, :, :] | |
| input_arr = np.concatenate([rgb_arr, ndvi_arr], axis=0) | |
| input_tensor = torch.tensor(input_arr).unsqueeze(0).to(DEVICE) | |
| with torch.no_grad(): | |
| pred = torch.sigmoid(model(input_tensor)) | |
| return (pred > 0.5).float().squeeze().cpu().numpy() | |
| def get_deforestation_color_map(mask_t0, mask_t1): | |
| H, W = mask_t0.shape | |
| color_map = np.zeros((H, W, 3), dtype=np.uint8) | |
| retained = (mask_t0 == 1) & (mask_t1 == 1) | |
| lost = (mask_t0 == 1) & (mask_t1 == 0) | |
| gained = (mask_t0 == 0) & (mask_t1 == 1) | |
| none = (mask_t0 == 0) & (mask_t1 == 0) | |
| color_map[retained] = [0, 255, 0] # Green | |
| color_map[lost] = [255, 0, 0] # Red | |
| color_map[gained] = [65, 168, 255] # Blue (gain) | |
| color_map[none] = [255, 255, 255] # White (no change) | |
| return color_map | |
| def run_deforestation_pipeline(lat_min, lat_max, lon_min, lon_max, start_year, end_year): | |
| region = ee.Geometry.Rectangle([lon_min, lat_min, lon_max, lat_max]) | |
| rgb_t0_ee, ndvi_t0_ee = fetch_rgb_ndvi(region, start_year) | |
| rgb_t0 = download_geotiff_array(rgb_t0_ee, region, ['SR_B4', 'SR_B3', 'SR_B2']) | |
| ndvi_t0 = download_geotiff_array(ndvi_t0_ee, region, ['NDVI']) | |
| rgb_t1_ee, ndvi_t1_ee = fetch_rgb_ndvi(region, end_year) | |
| rgb_t1 = download_geotiff_array(rgb_t1_ee, region, ['SR_B4', 'SR_B3', 'SR_B2']) | |
| ndvi_t1 = download_geotiff_array(ndvi_t1_ee, region, ['NDVI']) | |
| mask_t0 = predict_from_arrays(rgb_t0, ndvi_t0) | |
| mask_t1 = predict_from_arrays(rgb_t1, ndvi_t1) | |
| deforested_pixels = ((mask_t0 == 1) & (mask_t1 == 0)).sum() | |
| gained_pixels = ((mask_t0 == 0) & (mask_t1 == 1)).sum() | |
| total_vegetation_t0 = (mask_t0 == 1).sum() | |
| percent_loss = (deforested_pixels / total_vegetation_t0) * 100 if total_vegetation_t0 > 0 else 0 | |
| percent_gain = (gained_pixels / mask_t0.size) * 100 # relative to total area | |
| color_mask = get_deforestation_color_map(mask_t0, mask_t1) | |
| # Generate figure in memory | |
| fig, axes = plt.subplots(1, 3, figsize=(12, 4)) | |
| axes[0].imshow(mask_t0, cmap="Greens") | |
| axes[0].set_title(f"Vegetation in {start_year}") | |
| axes[0].axis("off") | |
| axes[1].imshow(mask_t1, cmap="Greens") | |
| axes[1].set_title(f"Vegetation in {end_year}") | |
| axes[1].axis("off") | |
| axes[2].imshow(color_mask) | |
| axes[2].set_title(f"Vegetation Change") | |
| axes[2].axis("off") | |
| plt.tight_layout() | |
| buf = io.BytesIO() | |
| plt.savefig(buf, format="png") | |
| plt.close(fig) | |
| buf.seek(0) | |
| img_base64 = base64.b64encode(buf.read()).decode('utf-8') | |
| return { | |
| "percent_deforested": round(percent_loss, 2), | |
| "percent_regrowth": round(percent_gain, 2), | |
| "image_base64": img_base64 | |
| } | |