File size: 8,583 Bytes
9d33171 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 | import math
import time
from dataclasses import dataclass
from io import BytesIO
from typing import Any
import requests
import streamlit as st
from PIL import Image
# Esri World Imagery XYZ tile endpoint. Note the path order is {z}/{y}/{x}.
ESRI_TILE_URL = (
    "https://server.arcgisonline.com"
    "/ArcGIS/rest/services/World_Imagery/MapServer"
    "/tile/{z}/{y}/{x}"
)
# Edge length, in pixels, of one XYZ tile (used when stitching and cropping).
TILE_SIZE = 256
# Sent with every tile request.
USER_AGENT = "eurosat-rgb-streamlit-demo/1.0"

# Side lengths (meters) considered an ideal EuroSAT-like footprint ("good").
EUROSAT_TARGET_MIN_M = 500
EUROSAT_TARGET_MAX_M = 1000
# Wider side-length window (meters) still accepted as "usable".
EUROSAT_ACCEPTABLE_MIN_M = 250
EUROSAT_ACCEPTABLE_MAX_M = 1500
# Longest side divided by shortest side must not exceed this.
EUROSAT_MAX_ASPECT_RATIO = 2.0
class TileFetchError(RuntimeError):
    """Signals that imagery could not be obtained from the Esri tile service.

    Raised for network/HTTP failures, tile payloads that cannot be decoded,
    and stitched imagery whose crop to the requested bbox came out empty.
    """
@dataclass(eq=True, frozen=True)
class BBox:
    """Immutable lon/lat bounding box, in degrees."""

    west: float  # minimum longitude
    south: float  # minimum latitude
    east: float  # maximum longitude
    north: float  # maximum latitude
@dataclass(eq=True, frozen=True)
class TileRange:
    """Inclusive rectangle of XYZ tile indices at a single zoom level."""

    zoom: int  # XYZ zoom level the indices belong to
    x_min: int  # westernmost tile column (inclusive)
    x_max: int  # easternmost tile column (inclusive)
    y_min: int  # northernmost tile row (inclusive); XYZ y grows southward
    y_max: int  # southernmost tile row (inclusive)
def extract_bbox_from_geojson(drawing: dict[str, Any]) -> BBox:
"""Extract a lon/lat bbox from a Folium Draw GeoJSON rectangle."""
geometry = drawing.get("geometry", {})
coordinates = geometry.get("coordinates")
if geometry.get("type") != "Polygon" or not coordinates:
raise ValueError("Expected a drawn rectangle polygon.")
ring = coordinates[0]
lons = [point[0] for point in ring]
lats = [point[1] for point in ring]
west, east = min(lons), max(lons)
south, north = min(lats), max(lats)
if west == east or south == north:
raise ValueError("The drawn rectangle has no area.")
return BBox(west=west, south=south, east=east, north=north)
def lonlat_to_tile_fraction(lon: float, lat: float, zoom: int) -> tuple[float, float]:
    """Map a lon/lat point to fractional XYZ tile coordinates at ``zoom``.

    Follows the OpenStreetMap slippy-map convention
    (https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames): x grows
    eastward from the antimeridian, and y is 0 at the northern edge of the
    world and grows southward. The integer part of each value is the tile
    index; the fractional part is the position inside that tile.
    """
    # Web Mercator cannot represent the poles; clamp to its latitude limit.
    clamped_lat = max(min(lat, 85.05112878), -85.05112878)
    phi = math.radians(clamped_lat)
    tiles_per_axis = 2**zoom

    x = (lon + 180.0) / 360.0 * tiles_per_axis
    mercator = math.log(math.tan(phi) + (1.0 / math.cos(phi)))
    y = (1.0 - mercator / math.pi) / 2.0 * tiles_per_axis
    return x, y
def bbox_to_tile_range(bbox: BBox, zoom: int) -> TileRange:
    """Return the inclusive XYZ tile range that covers ``bbox`` at ``zoom``.

    Tile indices are clamped to the valid ``[0, 2**zoom - 1]`` range, and each
    axis is normalised so that the minimum never exceeds the maximum.
    """
    highest_index = 2**zoom - 1

    def to_tile_index(fraction: float) -> int:
        # Floor to the containing tile, then clamp into the world's tile grid.
        return max(0, min(highest_index, math.floor(fraction)))

    # NW corner gives the smallest x/y, SE corner the largest (y grows south).
    west_x, north_y = lonlat_to_tile_fraction(bbox.west, bbox.north, zoom)
    east_x, south_y = lonlat_to_tile_fraction(bbox.east, bbox.south, zoom)

    low_x, high_x = sorted((to_tile_index(west_x), to_tile_index(east_x)))
    low_y, high_y = sorted((to_tile_index(north_y), to_tile_index(south_y)))
    return TileRange(zoom=zoom, x_min=low_x, x_max=high_x, y_min=low_y, y_max=high_y)
def choose_zoom_level(bbox: BBox) -> int:
    """Pick an XYZ zoom level from the longest side of ``bbox``.

    Rectangles up to 1 km use zoom 15, and up to 5 km use zoom 14; anything
    larger falls back to zoom 13. EuroSAT-scale rectangles therefore land
    on zoom 14-15.
    """
    longest_side_m = max(bbox_size_meters(bbox))
    for side_limit_m, zoom in ((1_000, 15), (5_000, 14)):
        if longest_side_m <= side_limit_m:
            return zoom
    return 13
def bbox_size_meters(bbox: BBox) -> tuple[float, float]:
    """Return the approximate ``(width_m, height_m)`` of ``bbox``.

    The width is the great-circle distance between the west and east edges,
    measured along the box's middle latitude. The height is the distance
    between the south and north edges, measured along the west edge.
    """
    center_lat = (bbox.north + bbox.south) / 2.0
    return (
        _haversine_meters(bbox.west, center_lat, bbox.east, center_lat),
        _haversine_meters(bbox.west, bbox.south, bbox.west, bbox.north),
    )
def size_warning_for_bbox(bbox: BBox) -> str | None:
    """Return a user-facing warning when ``bbox`` is outside the demo's range.

    Rectangles with a side under 50 m are reported as too small, and those
    with a side over 5 km as too large. Otherwise the function returns
    ``None``.
    """
    sides_m = bbox_size_meters(bbox)
    if min(sides_m) < 50:
        return "This rectangle is very small. Draw at least about 50m on a side."
    if max(sides_m) > 5_000:
        return "This rectangle is very large. Draw at most about 5km on a side."
    return None
def bbox_scale_status(bbox: BBox) -> tuple[str, str]:
    """Classify whether a bbox is close enough to EuroSAT-RGB tile scale.

    Returns:
        A ``(status, message)`` pair. ``status`` is one of:

        - ``"invalid"``: too small, too large, or too elongated;
        - ``"good"``: both sides lie within the EUROSAT_TARGET_* window;
        - ``"usable"``: otherwise.

        ``message`` is user-facing text explaining the status.
    """
    width_m, height_m = bbox_size_meters(bbox)
    min_side_m = min(width_m, height_m)
    max_side_m = max(width_m, height_m)
    if min_side_m < EUROSAT_ACCEPTABLE_MIN_M:
        return (
            "invalid",
            "This rectangle is too small for a useful EuroSAT-style prediction. "
            "Draw closer to 500m-1km on each side.",
        )
    if max_side_m > EUROSAT_ACCEPTABLE_MAX_M:
        return (
            "invalid",
            "This rectangle is too large for this EuroSAT-style demo. "
            "Zoom in and draw closer to 500m-1km on each side.",
        )
    # Compute the ratio only after the size guards. A degenerate box with a
    # zero-length side is rejected as "too small" above; computing the ratio
    # first would raise ZeroDivisionError for it instead.
    aspect_ratio = max_side_m / min_side_m
    if aspect_ratio > EUROSAT_MAX_ASPECT_RATIO:
        return (
            "invalid",
            "This rectangle is too stretched. Draw a more square region, like the original EuroSAT tiles.",
        )
    if (
        EUROSAT_TARGET_MIN_M <= min_side_m
        and max_side_m <= EUROSAT_TARGET_MAX_M
    ):
        return (
            "good",
            "Great scale: this is close to the original EuroSAT-RGB tile footprint.",
        )
    return (
        "usable",
        "Usable, but not ideal. For the most trustworthy demo result, draw 500m-1km on each side.",
    )
def fetch_bbox_image(bbox: BBox, zoom: int | None = None) -> Image.Image:
    """Return Esri satellite imagery covering ``bbox`` as one RGB image.

    Every XYZ tile intersecting the box is downloaded and pasted into a
    single mosaic. The mosaic is then cropped to the box's pixel footprint.

    Args:
        bbox: Lon/lat rectangle to render.
        zoom: XYZ zoom level. When omitted, ``choose_zoom_level`` picks one
            from the box size.

    Raises:
        TileFetchError: If a tile cannot be downloaded or decoded, or if the
            final crop has zero width or height.
    """
    if zoom is None:
        zoom = choose_zoom_level(bbox)
    tiles = bbox_to_tile_range(bbox, zoom)
    mosaic = _stitch_tiles(tiles)
    cropped = mosaic.crop(_bbox_crop_box(bbox, tiles, mosaic.size))
    if min(cropped.width, cropped.height) <= 0:
        raise TileFetchError("The fetched imagery crop was empty.")
    return cropped


def _stitch_tiles(tile_range: TileRange) -> Image.Image:
    """Download every tile in ``tile_range`` and paste it into one RGB mosaic."""
    columns = tile_range.x_max - tile_range.x_min + 1
    rows = tile_range.y_max - tile_range.y_min + 1
    mosaic = Image.new("RGB", (columns * TILE_SIZE, rows * TILE_SIZE))
    # Requests go column by column: x in the outer loop, y in the inner loop.
    for col in range(columns):
        for row in range(rows):
            tile = fetch_esri_tile(
                tile_range.zoom, tile_range.x_min + col, tile_range.y_min + row
            )
            mosaic.paste(tile, (col * TILE_SIZE, row * TILE_SIZE))
            # Short pause after each tile (presumably to throttle load on the
            # tile server); it also runs when the tile came from the cache.
            time.sleep(0.05)
    return mosaic
# The (z, x, y) key space is effectively unbounded. Cap the cache so a
# long-running server does not grow without limit. Each entry holds one
# decoded tile, roughly 256*256*3 bytes (~192 KiB).
@st.cache_data(show_spinner=False, max_entries=512)
def fetch_esri_tile(zoom: int, x: int, y: int) -> Image.Image:
    """Download one Esri World Imagery XYZ tile and decode it as RGB.

    Args:
        zoom: XYZ zoom level.
        x: Tile column.
        y: Tile row (0 at the northern edge).

    Returns:
        The decoded tile converted to RGB.

    Raises:
        TileFetchError: On any ``requests`` failure (including non-2xx HTTP
            status or the 10 s timeout) or if the payload cannot be decoded
            as an image.
    """
    url = ESRI_TILE_URL.format(z=zoom, x=x, y=y)
    try:
        response = requests.get(
            url,
            headers={"User-Agent": USER_AGENT},
            timeout=10,
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        raise TileFetchError(f"Could not download imagery tile z{zoom}/{x}/{y}.") from exc
    try:
        # ``convert`` returns a new, fully loaded image, so the source image
        # can be closed as soon as the conversion is done.
        with Image.open(BytesIO(response.content)) as tile:
            return tile.convert("RGB")
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        raise TileFetchError(f"Downloaded imagery tile z{zoom}/{x}/{y} was invalid.") from exc
def _bbox_crop_box(
    bbox: BBox, tile_range: TileRange, stitched_size: tuple[int, int]
) -> tuple[int, int, int, int]:
    """Return the ``(left, top, right, bottom)`` pixel box of ``bbox``.

    Coordinates are relative to the mosaic stitched from ``tile_range``, whose
    top-left pixel is tile ``(x_min, y_min)``. Left/top are floored and
    right/bottom are ceiled so that partially covered pixels are kept. Every
    value is then clamped to the mosaic's ``stitched_size``.
    """
    west_px, north_px = _lonlat_to_global_pixel(bbox.west, bbox.north, tile_range.zoom)
    east_px, south_px = _lonlat_to_global_pixel(bbox.east, bbox.south, tile_range.zoom)
    offset_x = tile_range.x_min * TILE_SIZE
    offset_y = tile_range.y_min * TILE_SIZE
    mosaic_w, mosaic_h = stitched_size

    def clamp(value: int, upper: int) -> int:
        return max(0, min(upper, value))

    return (
        clamp(math.floor(west_px - offset_x), mosaic_w),
        clamp(math.floor(north_px - offset_y), mosaic_h),
        clamp(math.ceil(east_px - offset_x), mosaic_w),
        clamp(math.ceil(south_px - offset_y), mosaic_h),
    )
def _lonlat_to_global_pixel(lon: float, lat: float, zoom: int) -> tuple[float, float]:
    """Convert lon/lat to fractional pixel coordinates in the world-wide XYZ grid.

    The result is the fractional tile coordinate scaled by ``TILE_SIZE``. At
    ``zoom`` the world is ``2**zoom * TILE_SIZE`` pixels on a side.
    """
    tile_coords = lonlat_to_tile_fraction(lon, lat, zoom)
    pixel_x = tile_coords[0] * TILE_SIZE
    pixel_y = tile_coords[1] * TILE_SIZE
    return pixel_x, pixel_y
def _haversine_meters(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Return the great-circle distance in meters between two lon/lat points.

    Uses the haversine formula on a spherical Earth with a radius of
    6,371,000 m. Inputs are in degrees.
    """
    earth_radius_m = 6_371_000
    lat1_rad, lat2_rad = math.radians(lat1), math.radians(lat2)
    half_dlat = math.radians(lat2 - lat1) / 2.0
    half_dlon = math.radians(lon2 - lon1) / 2.0
    hav = (
        math.sin(half_dlat) ** 2
        + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(half_dlon) ** 2
    )
    central_angle = 2.0 * math.atan2(math.sqrt(hav), math.sqrt(1.0 - hav))
    return earth_radius_m * central_angle
|