Spaces:
Sleeping
Sleeping
File size: 5,669 Bytes
52eabca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import zipfile
import tempfile
import uuid
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
from rasterio.enums import Resampling
def generate_temp_path(suffix=".tif"):
    """Return a fresh, unique file path inside the system temp directory.

    The file itself is not created; the name is a random UUID4 hex string
    followed by *suffix*.
    """
    tmp_root = tempfile.gettempdir()
    unique_name = f"{uuid.uuid4().hex}{suffix}"
    return os.path.join(tmp_root, unique_name)
def read_band(path):
    """Read band 1 of the raster at *path*.

    Returns:
        ``(data, profile)`` where *data* is band 1 converted to float32 and
        *profile* is the dataset's rasterio profile (reused elsewhere in this
        file as the template for exported GeoTIFFs).
    """
    with rasterio.open(path) as dataset:
        first_band = dataset.read(1)
        meta = dataset.profile
    return first_band.astype(np.float32), meta
def resample_to_10m(band_path, ref_shape, ref_transform):
    """Read the raster at *band_path* resampled to ``ref_shape`` pixels.

    The whole raster is read with bilinear resampling onto a
    ``(ref_shape[0], ref_shape[1])`` grid, and band 1 is returned as float32.

    NOTE(review): ``ref_transform`` is accepted but not used. No reprojection
    or alignment takes place, so this assumes the source raster covers the same
    footprint as the reference grid (presumably true for bands of the same
    Sentinel-2 tile — confirm).
    """
    with rasterio.open(band_path) as src:
        target = (1, ref_shape[0], ref_shape[1])
        stack = src.read(out_shape=target, resampling=Resampling.bilinear)
    return stack[0].astype(np.float32)
def normalize(array, scale=10000.0):
    """Scale raw band values into the [0, 1] range for display.

    Divides by *scale* and clips to [0, 1], returning a NEW array. The
    caller's array is never modified, so it can still be exported afterwards
    with its original values.

    Args:
        array: array-like of raw band values. Floating dtypes are kept;
            anything else (e.g. uint16 digital numbers) is promoted to float32.
        scale: divisor applied before clipping (10000.0 by default, the value
            this app has always used).

    Returns:
        A floating-point ndarray with values clipped to [0, 1].
    """
    values = np.asarray(array)
    dtype = values.dtype if np.issubdtype(values.dtype, np.floating) else np.float32
    # Private copy: in-place math below must not touch the caller's data.
    scaled = values.astype(dtype, copy=True)
    scaled /= scale
    # Clip into the same buffer to avoid allocating a second full-size array.
    np.clip(scaled, 0, 1, out=scaled)
    return scaled
def array_to_plot(img, title, cmap=None):
    """Render *img* into a standalone Matplotlib figure for ``gr.Plot``.

    The figure is built with the object-oriented API
    (``matplotlib.figure.Figure``) rather than pyplot. It is therefore never
    registered in pyplot's global figure list, which the previous
    implementation filled without ever closing. It also does not share
    pyplot's "current figure" state with other requests the Gradio server
    handles concurrently.

    Args:
        img: array passed to ``imshow``. In this file it is either a single
            band (drawn with *cmap*) or an (H, W, 3) RGB stack.
        title: axes title.
        cmap: Matplotlib colormap name. When given, a colorbar is also added.

    Returns:
        The ``matplotlib.figure.Figure`` holding the image.
    """
    from matplotlib.figure import Figure  # the module only imports pyplot

    fig = Figure(figsize=(6, 6))
    ax = fig.add_subplot()
    if cmap:
        image = ax.imshow(img, cmap=cmap)
        fig.colorbar(image, ax=ax)
    else:
        ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
    return fig
def save_tif(path, array, profile, count=3):
    """Write *array* to a tiled, DEFLATE-compressed float32 GeoTIFF at *path*.

    Args:
        path: destination file path.
        array: when ``count == 1``, a 2-D (rows, cols) array. Otherwise a
            (rows, cols, bands) array whose first *count* bands (last axis) are
            written as GeoTIFF bands 1..count.
        profile: rasterio profile used as the template (georeferencing, size,
            etc.). It is copied, never modified.
        count: number of bands to write.
    """
    profile = profile.copy()
    profile.update({
        'driver': 'GTiff',
        'count': count,
        'dtype': rasterio.float32,
        'compress': 'deflate',
        # 3 = floating-point predictor, matched to the float32 samples written here
        # (2, horizontal differencing, is meant for integer data).
        'predictor': 3,
        'tiled': True,
        'blockxsize': 512,
        'blockysize': 512
    })
    # Match the declared dataset dtype; this is a no-op for float32 input.
    data = np.asarray(array, dtype=np.float32)
    with rasterio.open(path, 'w', **profile) as dst:
        if count == 1:
            dst.write(data, 1)
        else:
            # Bands sit on the last axis of the input; GeoTIFF band indexes are 1-based.
            for band_idx in range(count):
                dst.write(data[:, :, band_idx], band_idx + 1)
# (R, G, B) band triplets for every composite offered in the UI.
_RGB_COMPOSITES = {
    "Natural Color (B4, B3, B2)": ("B04", "B03", "B02"),
    "False Color Vegetation (B8, B4, B3)": ("B08", "B04", "B03"),
    "False Color SWIR (B12, B8, B4)": ("B12", "B08", "B04"),
}
# Bands read from the R20m folder and bilinearly upsampled onto the 10 m grid.
_R20M_BANDS = frozenset({"B12"})


def _find_band_file(res_dir, band):
    """Return the first file (by name) in *res_dir* whose name contains ``_<band>``."""
    tag = f"_{band}"
    try:
        names = sorted(name for name in os.listdir(res_dir) if tag in name)
    except FileNotFoundError:
        names = []
    if not names:
        raise gr.Error(f"Band {band} not found in the {os.path.basename(res_dir)} folder.")
    return os.path.join(res_dir, names[0])


def _locate_img_data(extract_root):
    """Return the IMG_DATA folder of the first L2A granule under the extracted archive."""
    safe_dirs = [d for d in os.listdir(extract_root) if d.endswith(".SAFE")]
    if not safe_dirs:
        raise gr.Error(".SAFE folder not found inside the zip file.")
    granule_dir = os.path.join(extract_root, safe_dirs[0], "GRANULE")
    try:
        granules = [d for d in os.listdir(granule_dir) if d.startswith("L2A_")]
    except FileNotFoundError:
        granules = []
    if not granules:
        raise gr.Error("No L2A granule found in .SAFE/GRANULE (only Level-2A products are supported).")
    return os.path.join(granule_dir, granules[0], "IMG_DATA")


def _load_band_10m(img_data_dir, band, ref_shape, ref_transform):
    """Load *band* as float32 on the 10 m grid, upsampling bands listed in _R20M_BANDS."""
    if band in _R20M_BANDS:
        path = _find_band_file(os.path.join(img_data_dir, "R20m"), band)
        return resample_to_10m(path, ref_shape, ref_transform)
    data, _ = read_band(_find_band_file(os.path.join(img_data_dir, "R10m"), band))
    return data


def process_visualization(zip_file_path, vis_type):
    """Build the selected product from an uploaded Sentinel-2 L2A ``.SAFE`` zip.

    The archive is extracted into a temporary directory, which is removed on
    return. Bands are then read from the first L2A granule's ``IMG_DATA``
    folders; 20 m bands such as B12 are upsampled to B04's 10 m grid. Finally
    the product is written as a float32 GeoTIFF and a preview is rendered.

    Args:
        zip_file_path: path of the uploaded zip containing a ``*.SAFE`` folder.
        vis_type: a key of ``_RGB_COMPOSITES`` or ``"NDVI"``.

    Returns:
        ``(figure, tif_path)``: the Matplotlib preview and the path of the
        exported GeoTIFF. For composites the GeoTIFF holds the raw band values;
        only the preview is normalized.

    Raises:
        ValueError: if *vis_type* is not a supported product.
        gr.Error: if the upload is missing or not a zip, or if the expected
            folders or bands are absent.
    """
    # Validate cheap inputs before paying for extraction.
    if vis_type != "NDVI" and vis_type not in _RGB_COMPOSITES:
        raise ValueError("Invalid visualization type.")
    if not zip_file_path:
        raise gr.Error("Please upload a Sentinel-2 .SAFE.zip archive.")

    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
        except zipfile.BadZipFile as err:
            raise gr.Error("The uploaded file is not a valid zip archive.") from err
        img_data_dir = _locate_img_data(temp_dir)

        # B04 is used by every product and supplies the output grid and profile.
        b4, profile = read_band(_find_band_file(os.path.join(img_data_dir, "R10m"), "B04"))
        ref_shape, ref_transform = b4.shape, profile["transform"]
        # Lives in the system temp dir (not temp_dir), so it outlives this block.
        tif_path = generate_temp_path(".tif")

        if vis_type == "NDVI":
            b8 = _load_band_10m(img_data_dir, "B08", ref_shape, ref_transform)
            ndvi = (b8 - b4) / (b8 + b4 + 1e-6)  # epsilon guards against 0/0
            save_tif(tif_path, ndvi, profile, count=1)
            return array_to_plot(ndvi, "NDVI", cmap='RdYlGn'), tif_path

        composite = np.stack(
            [
                b4 if band == "B04" else _load_band_10m(img_data_dir, band, ref_shape, ref_transform)
                for band in _RGB_COMPOSITES[vis_type]
            ],
            axis=-1,
        )
        # Export before normalizing, so the GeoTIFF keeps raw values even if
        # normalize() scales its argument in place.
        save_tif(tif_path, composite, profile, count=3)
        return array_to_plot(normalize(composite), vis_type), tif_path
# === Gradio Interface ===
# Inputs map positionally onto process_visualization(zip_file_path, vis_type).
demo = gr.Interface(
    fn=process_visualization,
    inputs=[
        gr.File(label="Sentinel-2 Archive (.zip)", type="filepath", file_types=[".zip"]),
        gr.Dropdown(
            choices=[
                "Natural Color (B4, B3, B2)",
                "False Color Vegetation (B8, B4, B3)",
                "False Color SWIR (B12, B8, B4)",
                "NDVI"
            ],
            # Preselect a product: with an empty selection, process_visualization
            # matches none of its choices and raises "Invalid visualization type."
            value="Natural Color (B4, B3, B2)",
            label="Visualization Type"
        )
    ],
    outputs=[
        gr.Plot(label="Preview"),
        gr.File(label="Download GeoTIFF")
    ],
    title="Sentinel-2 Viewer + GeoTIFF Export",
    description="Upload a .SAFE.zip file, choose a visualization type, and download the corresponding GeoTIFF file."
)
# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()