ojasrohatgi commited on
Commit
527ce0c
·
verified ·
1 Parent(s): 665e5d5

Update backend.py

Browse files
Files changed (1) hide show
  1. backend.py +156 -155
backend.py CHANGED
@@ -1,155 +1,156 @@
1
- import os
2
- import ee
3
- import numpy as np
4
- import requests
5
- import io
6
- import base64
7
- from rasterio.io import MemoryFile
8
- import torch
9
- import segmentation_models_pytorch as smp
10
- import matplotlib.pyplot as plt
11
- import gdown
12
- from dotenv import load_dotenv
13
-
14
- load_dotenv()
15
-
16
- PRIVATE_KEY = os.getenv("PRIVATE_KEY")
17
- with open("private-key.json", "w") as f:
18
- f.write(PRIVATE_KEY)
19
- service_account = os.getenv("SERVICE_KEY_ID")
20
- credentials = ee.ServiceAccountCredentials(service_account, 'private-key.json')
21
- ee.Initialize(credentials)
22
-
23
- MODEL_PATH = "deforestation_unet_full_model.pt"
24
- MODEL_URL = os.getenv("MODEL_URL")
25
-
26
- # Download model only if it doesn't exist
27
- if not os.path.exists(MODEL_PATH):
28
- print("Model not found. Downloading from Google Drive...")
29
- gdown.download(MODEL_URL, MODEL_PATH, quiet=False)
30
-
31
- # ee.Initialize(project=os.environ["project-id"])
32
-
33
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
34
-
35
- # Load model once
36
- model = smp.Unet(
37
- encoder_name="resnet34",
38
- encoder_weights=None,
39
- in_channels=4,
40
- classes=1,
41
- activation=None,
42
- ).to(DEVICE)
43
- model = torch.load("deforestation_unet_full_model.pt", map_location=DEVICE, weights_only=False)
44
- model.eval()
45
-
46
- def apply_scale_factors(image):
47
- optical_bands = image.select('SR_B.').multiply(0.0000275).add(-0.2)
48
- thermal_bands = image.select('ST_B.*').multiply(0.00341802).add(149.0)
49
- return image.addBands(optical_bands, None, True).addBands(thermal_bands, None, True)
50
-
51
- def fetch_rgb_ndvi(region, year, scale=30):
52
- start = ee.Date.fromYMD(year, 1, 1)
53
- end = ee.Date.fromYMD(year, 12, 31)
54
- col = (ee.ImageCollection("LANDSAT/LC08/C02/T1_L2")
55
- .filterBounds(region)
56
- .filterDate(start, end)
57
- .filterMetadata('CLOUD_COVER', 'less_than', 10)
58
- .map(apply_scale_factors))
59
- image = col.median().clip(region)
60
- ndvi = image.normalizedDifference(['SR_B5', 'SR_B4']).rename('NDVI')
61
- image = image.addBands(ndvi)
62
- return image.select(['SR_B4', 'SR_B3', 'SR_B2']), image.select('NDVI')
63
-
64
- def download_geotiff_array(img, region, bands, scale=30):
65
- url = img.getThumbURL({
66
- 'scale': scale,
67
- 'region': region,
68
- 'format': 'GeoTIFF',
69
- 'bands': bands
70
- })
71
- response = requests.get(url)
72
- with MemoryFile(response.content) as memfile:
73
- with memfile.open() as src:
74
- arr = src.read().astype(np.float32)
75
- if arr.max() > 1.5:
76
- arr /= 255.0
77
- return arr
78
-
79
- def predict_from_arrays(rgb_arr, ndvi_arr):
80
- rgb_arr = rgb_arr[:3, :, :]
81
- ndvi_arr = ndvi_arr[:1, :, :]
82
- input_arr = np.concatenate([rgb_arr, ndvi_arr], axis=0)
83
- input_tensor = torch.tensor(input_arr).unsqueeze(0).to(DEVICE)
84
- with torch.no_grad():
85
- pred = torch.sigmoid(model(input_tensor))
86
- return (pred > 0.5).float().squeeze().cpu().numpy()
87
-
88
- def get_deforestation_color_map(mask_t0, mask_t1):
89
- H, W = mask_t0.shape
90
- color_map = np.zeros((H, W, 3), dtype=np.uint8)
91
-
92
- retained = (mask_t0 == 1) & (mask_t1 == 1)
93
- lost = (mask_t0 == 1) & (mask_t1 == 0)
94
- gained = (mask_t0 == 0) & (mask_t1 == 1)
95
- none = (mask_t0 == 0) & (mask_t1 == 0)
96
-
97
- color_map[retained] = [0, 255, 0] # Green
98
- color_map[lost] = [255, 0, 0] # Red
99
- color_map[gained] = [65, 168, 255] # Blue (gain)
100
- color_map[none] = [255, 255, 255] # White (no change)
101
-
102
- return color_map
103
-
104
- def run_deforestation_pipeline(lat_min, lat_max, lon_min, lon_max, start_year, end_year):
105
- region = ee.Geometry.Rectangle([lon_min, lat_min, lon_max, lat_max])
106
-
107
- rgb_t0_ee, ndvi_t0_ee = fetch_rgb_ndvi(region, start_year)
108
- rgb_t0 = download_geotiff_array(rgb_t0_ee, region, ['SR_B4', 'SR_B3', 'SR_B2'])
109
- ndvi_t0 = download_geotiff_array(ndvi_t0_ee, region, ['NDVI'])
110
-
111
- rgb_t1_ee, ndvi_t1_ee = fetch_rgb_ndvi(region, end_year)
112
- rgb_t1 = download_geotiff_array(rgb_t1_ee, region, ['SR_B4', 'SR_B3', 'SR_B2'])
113
- ndvi_t1 = download_geotiff_array(ndvi_t1_ee, region, ['NDVI'])
114
-
115
- mask_t0 = predict_from_arrays(rgb_t0, ndvi_t0)
116
- mask_t1 = predict_from_arrays(rgb_t1, ndvi_t1)
117
-
118
- deforested_pixels = ((mask_t0 == 1) & (mask_t1 == 0)).sum()
119
- gained_pixels = ((mask_t0 == 0) & (mask_t1 == 1)).sum()
120
- total_vegetation_t0 = (mask_t0 == 1).sum()
121
-
122
- percent_loss = (deforested_pixels / total_vegetation_t0) * 100 if total_vegetation_t0 > 0 else 0
123
- percent_gain = (gained_pixels / mask_t0.size) * 100 # relative to total area
124
-
125
- color_mask = get_deforestation_color_map(mask_t0, mask_t1)
126
-
127
- # Generate figure in memory
128
- fig, axes = plt.subplots(1, 3, figsize=(12, 4))
129
-
130
- axes[0].imshow(mask_t0, cmap="Greens")
131
- axes[0].set_title(f"Vegetation in {start_year}")
132
- axes[0].axis("off")
133
-
134
- axes[1].imshow(mask_t1, cmap="Greens")
135
- axes[1].set_title(f"Vegetation in {end_year}")
136
- axes[1].axis("off")
137
-
138
- axes[2].imshow(color_mask)
139
- axes[2].set_title(f"Vegetation Change")
140
- axes[2].axis("off")
141
-
142
- plt.tight_layout()
143
-
144
- buf = io.BytesIO()
145
- plt.savefig(buf, format="png")
146
- plt.close(fig)
147
- buf.seek(0)
148
- img_base64 = base64.b64encode(buf.read()).decode('utf-8')
149
-
150
- return {
151
- "percent_deforested": round(percent_loss, 2),
152
- "percent_regrowth": round(percent_gain, 2),
153
- "image_base64": img_base64
154
- }
155
-
 
 
1
+ import os
2
+ import ee
3
+ import numpy as np
4
+ import requests
5
+ import io
6
+ import base64
7
+ from rasterio.io import MemoryFile
8
+ import torch
9
+ import segmentation_models_pytorch as smp
10
+ import matplotlib.pyplot as plt
11
+ import gdown
12
+ from dotenv import load_dotenv
13
+
14
+ load_dotenv()
15
+
16
+ key_path = "/tmp/private-key.json"
17
+ with open(key_path, "w") as f:
18
+ f.write(os.environos.getenv("PRIVATE_KEY"))
19
+
20
+ service_account = os.getenv("SERVICE_KEY_ID")
21
+ credentials = ee.ServiceAccountCredentials(service_account, 'private-key.json')
22
+ ee.Initialize(credentials)
23
+
24
+ MODEL_PATH = "deforestation_unet_full_model.pt"
25
+ MODEL_URL = os.getenv("MODEL_URL")
26
+
27
+ # Download model only if it doesn't exist
28
+ if not os.path.exists(MODEL_PATH):
29
+ print("Model not found. Downloading from Google Drive...")
30
+ gdown.download(MODEL_URL, MODEL_PATH, quiet=False)
31
+
32
+ # ee.Initialize(project=os.environ["project-id"])
33
+
34
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
+
36
+ # Load model once
37
+ model = smp.Unet(
38
+ encoder_name="resnet34",
39
+ encoder_weights=None,
40
+ in_channels=4,
41
+ classes=1,
42
+ activation=None,
43
+ ).to(DEVICE)
44
+ model = torch.load("deforestation_unet_full_model.pt", map_location=DEVICE, weights_only=False)
45
+ model.eval()
46
+
47
+ def apply_scale_factors(image):
48
+ optical_bands = image.select('SR_B.').multiply(0.0000275).add(-0.2)
49
+ thermal_bands = image.select('ST_B.*').multiply(0.00341802).add(149.0)
50
+ return image.addBands(optical_bands, None, True).addBands(thermal_bands, None, True)
51
+
52
+ def fetch_rgb_ndvi(region, year, scale=30):
53
+ start = ee.Date.fromYMD(year, 1, 1)
54
+ end = ee.Date.fromYMD(year, 12, 31)
55
+ col = (ee.ImageCollection("LANDSAT/LC08/C02/T1_L2")
56
+ .filterBounds(region)
57
+ .filterDate(start, end)
58
+ .filterMetadata('CLOUD_COVER', 'less_than', 10)
59
+ .map(apply_scale_factors))
60
+ image = col.median().clip(region)
61
+ ndvi = image.normalizedDifference(['SR_B5', 'SR_B4']).rename('NDVI')
62
+ image = image.addBands(ndvi)
63
+ return image.select(['SR_B4', 'SR_B3', 'SR_B2']), image.select('NDVI')
64
+
65
+ def download_geotiff_array(img, region, bands, scale=30):
66
+ url = img.getThumbURL({
67
+ 'scale': scale,
68
+ 'region': region,
69
+ 'format': 'GeoTIFF',
70
+ 'bands': bands
71
+ })
72
+ response = requests.get(url)
73
+ with MemoryFile(response.content) as memfile:
74
+ with memfile.open() as src:
75
+ arr = src.read().astype(np.float32)
76
+ if arr.max() > 1.5:
77
+ arr /= 255.0
78
+ return arr
79
+
80
+ def predict_from_arrays(rgb_arr, ndvi_arr):
81
+ rgb_arr = rgb_arr[:3, :, :]
82
+ ndvi_arr = ndvi_arr[:1, :, :]
83
+ input_arr = np.concatenate([rgb_arr, ndvi_arr], axis=0)
84
+ input_tensor = torch.tensor(input_arr).unsqueeze(0).to(DEVICE)
85
+ with torch.no_grad():
86
+ pred = torch.sigmoid(model(input_tensor))
87
+ return (pred > 0.5).float().squeeze().cpu().numpy()
88
+
89
+ def get_deforestation_color_map(mask_t0, mask_t1):
90
+ H, W = mask_t0.shape
91
+ color_map = np.zeros((H, W, 3), dtype=np.uint8)
92
+
93
+ retained = (mask_t0 == 1) & (mask_t1 == 1)
94
+ lost = (mask_t0 == 1) & (mask_t1 == 0)
95
+ gained = (mask_t0 == 0) & (mask_t1 == 1)
96
+ none = (mask_t0 == 0) & (mask_t1 == 0)
97
+
98
+ color_map[retained] = [0, 255, 0] # Green
99
+ color_map[lost] = [255, 0, 0] # Red
100
+ color_map[gained] = [65, 168, 255] # Blue (gain)
101
+ color_map[none] = [255, 255, 255] # White (no change)
102
+
103
+ return color_map
104
+
105
+ def run_deforestation_pipeline(lat_min, lat_max, lon_min, lon_max, start_year, end_year):
106
+ region = ee.Geometry.Rectangle([lon_min, lat_min, lon_max, lat_max])
107
+
108
+ rgb_t0_ee, ndvi_t0_ee = fetch_rgb_ndvi(region, start_year)
109
+ rgb_t0 = download_geotiff_array(rgb_t0_ee, region, ['SR_B4', 'SR_B3', 'SR_B2'])
110
+ ndvi_t0 = download_geotiff_array(ndvi_t0_ee, region, ['NDVI'])
111
+
112
+ rgb_t1_ee, ndvi_t1_ee = fetch_rgb_ndvi(region, end_year)
113
+ rgb_t1 = download_geotiff_array(rgb_t1_ee, region, ['SR_B4', 'SR_B3', 'SR_B2'])
114
+ ndvi_t1 = download_geotiff_array(ndvi_t1_ee, region, ['NDVI'])
115
+
116
+ mask_t0 = predict_from_arrays(rgb_t0, ndvi_t0)
117
+ mask_t1 = predict_from_arrays(rgb_t1, ndvi_t1)
118
+
119
+ deforested_pixels = ((mask_t0 == 1) & (mask_t1 == 0)).sum()
120
+ gained_pixels = ((mask_t0 == 0) & (mask_t1 == 1)).sum()
121
+ total_vegetation_t0 = (mask_t0 == 1).sum()
122
+
123
+ percent_loss = (deforested_pixels / total_vegetation_t0) * 100 if total_vegetation_t0 > 0 else 0
124
+ percent_gain = (gained_pixels / mask_t0.size) * 100 # relative to total area
125
+
126
+ color_mask = get_deforestation_color_map(mask_t0, mask_t1)
127
+
128
+ # Generate figure in memory
129
+ fig, axes = plt.subplots(1, 3, figsize=(12, 4))
130
+
131
+ axes[0].imshow(mask_t0, cmap="Greens")
132
+ axes[0].set_title(f"Vegetation in {start_year}")
133
+ axes[0].axis("off")
134
+
135
+ axes[1].imshow(mask_t1, cmap="Greens")
136
+ axes[1].set_title(f"Vegetation in {end_year}")
137
+ axes[1].axis("off")
138
+
139
+ axes[2].imshow(color_mask)
140
+ axes[2].set_title(f"Vegetation Change")
141
+ axes[2].axis("off")
142
+
143
+ plt.tight_layout()
144
+
145
+ buf = io.BytesIO()
146
+ plt.savefig(buf, format="png")
147
+ plt.close(fig)
148
+ buf.seek(0)
149
+ img_base64 = base64.b64encode(buf.read()).decode('utf-8')
150
+
151
+ return {
152
+ "percent_deforested": round(percent_loss, 2),
153
+ "percent_regrowth": round(percent_gain, 2),
154
+ "image_base64": img_base64
155
+ }
156
+