File size: 8,583 Bytes
9d33171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import math
import time
from dataclasses import dataclass
from io import BytesIO
from typing import Any

import requests
import streamlit as st
from PIL import Image


# XYZ tile template for Esri World Imagery. Note the path order in this
# template is z/y/x (row before column).
ESRI_TILE_URL = (
    "https://server.arcgisonline.com"
    "/ArcGIS/rest/services/World_Imagery/MapServer"
    "/tile/{z}/{y}/{x}"
)
# Edge length, in pixels, of one imagery tile; used to size the stitched mosaic.
TILE_SIZE = 256
# Sent as the User-Agent header on every tile request.
USER_AGENT = "eurosat-rgb-streamlit-demo/1.0"

# Side-length bands (meters) used to grade a drawn rectangle against the
# EuroSAT-RGB tile footprint. Every side inside the TARGET band is "good";
# any side outside the ACCEPTABLE band is "invalid".
EUROSAT_TARGET_MIN_M = 500
EUROSAT_TARGET_MAX_M = 1000
EUROSAT_ACCEPTABLE_MIN_M = 250
EUROSAT_ACCEPTABLE_MAX_M = 1500
# Longest side / shortest side above this ratio is rejected as too stretched.
EUROSAT_MAX_ASPECT_RATIO = 2.0


class TileFetchError(RuntimeError):
    """Signal a failure while obtaining Esri imagery for a bounding box.

    Covers a tile download that errors out, a downloaded tile that cannot be
    decoded as an image, and a stitched mosaic whose crop comes out empty.
    """


@dataclass(frozen=True)
class BBox:
    """Immutable lon/lat bounding box in degrees.

    Field order is west, south, east, north, i.e. (lon_min, lat_min,
    lon_max, lat_max).
    """

    west: float  # minimum longitude
    south: float  # minimum latitude
    east: float  # maximum longitude
    north: float  # maximum latitude


@dataclass(frozen=True)
class TileRange:
    """Inclusive rectangle of XYZ tile indices at one zoom level."""

    zoom: int  # zoom level shared by every tile in the range
    x_min: int  # westernmost tile column (inclusive)
    x_max: int  # easternmost tile column (inclusive)
    y_min: int  # northernmost tile row (inclusive; rows grow southward)
    y_max: int  # southernmost tile row (inclusive)


def extract_bbox_from_geojson(drawing: dict[str, Any]) -> BBox:
    """Extract a lon/lat bbox from a Folium Draw GeoJSON rectangle.

    Args:
        drawing: A GeoJSON Feature dict whose ``geometry`` is a Polygon. The
            first ring of ``coordinates`` is read as ``[lon, lat, ...]`` points.

    Returns:
        The axis-aligned bounding box (min/max lon and lat) of that ring.

    Raises:
        ValueError: If ``geometry`` is missing, null, or not a Polygon; if the
            outer ring is empty or its points are not ``[lon, lat]`` pairs;
            or if the ring has zero width or height.
    """
    # GeoJSON permits a Feature's "geometry" to be null, so fall back to {}
    # rather than calling .get() on None (which raised AttributeError).
    geometry = drawing.get("geometry") or {}
    coordinates = geometry.get("coordinates") if isinstance(geometry, dict) else None
    if (
        not isinstance(geometry, dict)
        or geometry.get("type") != "Polygon"
        or not coordinates
        or not coordinates[0]  # empty outer ring would make min() fail cryptically
    ):
        raise ValueError("Expected a drawn rectangle polygon.")

    ring = coordinates[0]
    try:
        lons = [point[0] for point in ring]
        lats = [point[1] for point in ring]
        west, east = min(lons), max(lons)
        south, north = min(lats), max(lats)
    except (TypeError, IndexError, KeyError) as exc:
        # Points that are not indexable [lon, lat] pairs, or hold values that
        # cannot be compared, are reported with the same contract error.
        raise ValueError("Expected a drawn rectangle polygon.") from exc
    if west == east or south == north:
        raise ValueError("The drawn rectangle has no area.")

    return BBox(west=west, south=south, east=east, north=north)


def lonlat_to_tile_fraction(lon: float, lat: float, zoom: int) -> tuple[float, float]:
    """Project a lon/lat point to fractional XYZ tile coordinates at *zoom*.

    Uses the OpenStreetMap slippy-map formulas
    (https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames). The integer
    part of each coordinate is the tile index and the fractional part is the
    position inside that tile. Tile y is 0 at the northern edge of the world
    and increases southward.

    Latitude is clamped to +/-85.05112878 degrees, the usual Web Mercator
    limit. Longitude is used as given and is not wrapped or clamped.
    """
    lat_limit = 85.05112878
    clamped_lat = min(max(lat, -lat_limit), lat_limit)
    phi = math.radians(clamped_lat)
    tiles_per_axis = 2**zoom

    tile_x = (lon + 180.0) / 360.0 * tiles_per_axis
    mercator = math.log(math.tan(phi) + (1.0 / math.cos(phi)))
    tile_y = (1.0 - mercator / math.pi) / 2.0 * tiles_per_axis
    return tile_x, tile_y


def bbox_to_tile_range(bbox: BBox, zoom: int) -> TileRange:
    """Compute the inclusive block of XYZ tiles that covers *bbox* at *zoom*.

    The north-west and south-east corners are projected to fractional tile
    coordinates, floored to tile indices, clamped to the ``[0, 2**zoom - 1]``
    grid, and ordered so that each ``*_min`` is never greater than ``*_max``.
    """
    last_index = 2**zoom - 1

    def _to_index(fraction: float) -> int:
        # A corner exactly on the east (lon=180) or south map edge projects to
        # 2**zoom, one past the last tile, so clamp into the valid grid.
        return max(0, min(last_index, math.floor(fraction)))

    nw_x, nw_y = lonlat_to_tile_fraction(bbox.west, bbox.north, zoom)
    se_x, se_y = lonlat_to_tile_fraction(bbox.east, bbox.south, zoom)

    cols = sorted((_to_index(nw_x), _to_index(se_x)))
    rows = sorted((_to_index(nw_y), _to_index(se_y)))
    return TileRange(
        zoom=zoom,
        x_min=cols[0],
        x_max=cols[1],
        y_min=rows[0],
        y_max=rows[1],
    )


def choose_zoom_level(bbox: BBox) -> int:
    """Pick a tile zoom from the bbox's longest side; EuroSAT-scale boxes get 14-15.

    Longest side of 1 km or less gives zoom 15, 5 km or less gives 14, and
    anything larger gives 13.
    """
    longest_side_m = max(bbox_size_meters(bbox))
    for limit_m, zoom in ((1_000, 15), (5_000, 14)):
        if longest_side_m <= limit_m:
            return zoom
    return 13


def bbox_size_meters(bbox: BBox) -> tuple[float, float]:
    """Return the approximate (width, height) of *bbox* in meters.

    Width is the great-circle distance between the west and east edges along
    the bbox's middle latitude. Height is the distance between the south and
    north edges along the west meridian.
    """
    center_lat = (bbox.north + bbox.south) / 2.0
    east_west_m = _haversine_meters(bbox.west, center_lat, bbox.east, center_lat)
    north_south_m = _haversine_meters(bbox.west, bbox.south, bbox.west, bbox.north)
    return east_west_m, north_south_m


def size_warning_for_bbox(bbox: BBox) -> str | None:
    """Describe why *bbox* falls outside the demo's 50m-5km side range.

    Returns a user-facing warning string, or ``None`` when the shortest side
    is at least 50m and the longest side is at most 5km.
    """
    sides_m = bbox_size_meters(bbox)
    if min(sides_m) < 50:
        return "This rectangle is very small. Draw at least about 50m on a side."
    if max(sides_m) > 5_000:
        return "This rectangle is very large. Draw at most about 5km on a side."
    return None


def bbox_scale_status(bbox: BBox) -> tuple[str, str]:
    """Classify whether a bbox is close enough to EuroSAT-RGB tile scale.

    Returns:
        A ``(status, message)`` pair. ``status`` is ``"invalid"`` when a side
        is outside the EUROSAT_ACCEPTABLE_* band or the aspect ratio exceeds
        EUROSAT_MAX_ASPECT_RATIO. It is ``"good"`` when every side lies within
        the EUROSAT_TARGET_* band, and ``"usable"`` otherwise. ``message`` is
        user-facing text explaining the verdict.
    """
    width_m, height_m = bbox_size_meters(bbox)
    min_side_m = min(width_m, height_m)
    max_side_m = max(width_m, height_m)

    if min_side_m < EUROSAT_ACCEPTABLE_MIN_M:
        return (
            "invalid",
            "This rectangle is too small for a useful EuroSAT-style prediction. "
            "Draw closer to 500m-1km on each side.",
        )
    if max_side_m > EUROSAT_ACCEPTABLE_MAX_M:
        return (
            "invalid",
            "This rectangle is too large for this EuroSAT-style demo. "
            "Zoom in and draw closer to 500m-1km on each side.",
        )

    # Divide only after the size guards. Past them, min_side_m is at least
    # EUROSAT_ACCEPTABLE_MIN_M (a positive bound), so a degenerate bbox with
    # a zero-length side is reported as "invalid" above instead of raising
    # ZeroDivisionError here.
    aspect_ratio = max_side_m / min_side_m
    if aspect_ratio > EUROSAT_MAX_ASPECT_RATIO:
        return (
            "invalid",
            "This rectangle is too stretched. Draw a more square region, like the original EuroSAT tiles.",
        )
    if (
        EUROSAT_TARGET_MIN_M <= min_side_m
        and max_side_m <= EUROSAT_TARGET_MAX_M
    ):
        return (
            "good",
            "Great scale: this is close to the original EuroSAT-RGB tile footprint.",
        )
    return (
        "usable",
        "Usable, but not ideal. For the most trustworthy demo result, draw 500m-1km on each side.",
    )


def fetch_bbox_image(bbox: BBox, zoom: int | None = None) -> Image.Image:
    """Download the Esri tiles covering *bbox*, stitch them, and crop to *bbox*.

    Args:
        bbox: Lon/lat area to fetch.
        zoom: XYZ zoom level. When ``None``, ``choose_zoom_level(bbox)`` picks it.

    Returns:
        An RGB image cropped to the bbox's pixel footprint.

    Raises:
        TileFetchError: If any tile fails to download or decode, or the final
            crop is empty.
    """
    if zoom is None:
        zoom = choose_zoom_level(bbox)
    tiles = bbox_to_tile_range(bbox, zoom)

    columns = range(tiles.x_min, tiles.x_max + 1)
    rows = range(tiles.y_min, tiles.y_max + 1)
    mosaic = Image.new("RGB", (len(columns) * TILE_SIZE, len(rows) * TILE_SIZE))

    # Column-major order: every tile in a column before moving east.
    for col_offset, tile_x in enumerate(columns):
        for row_offset, tile_y in enumerate(rows):
            tile_image = fetch_esri_tile(zoom, tile_x, tile_y)
            mosaic.paste(tile_image, (col_offset * TILE_SIZE, row_offset * TILE_SIZE))
            # Short pause after each tile (presumably to throttle requests to
            # the tile server).
            time.sleep(0.05)

    result = mosaic.crop(_bbox_crop_box(bbox, tiles, mosaic.size))
    if result.width <= 0 or result.height <= 0:
        raise TileFetchError("The fetched imagery crop was empty.")
    return result


@st.cache_data(show_spinner=False)
def fetch_esri_tile(zoom: int, x: int, y: int) -> Image.Image:
    """Download one Esri World Imagery XYZ tile and return it as an RGB image.

    Results are memoised by ``st.cache_data`` on ``(zoom, x, y)``.

    Raises:
        TileFetchError: On any ``requests`` failure (including an HTTP error
            status), or when the response body cannot be decoded as an image.
    """
    tile_id = f"z{zoom}/{x}/{y}"
    try:
        resp = requests.get(
            ESRI_TILE_URL.format(z=zoom, x=x, y=y),
            headers={"User-Agent": USER_AGENT},
            timeout=10,
        )
        resp.raise_for_status()
    except requests.RequestException as err:
        raise TileFetchError(f"Could not download imagery tile {tile_id}.") from err

    try:
        payload = BytesIO(resp.content)
        return Image.open(payload).convert("RGB")
    except OSError as err:
        raise TileFetchError(f"Downloaded imagery tile {tile_id} was invalid.") from err


def _bbox_crop_box(
    bbox: BBox, tile_range: TileRange, stitched_size: tuple[int, int]
) -> tuple[int, int, int, int]:
    """Return the ``(left, upper, right, lower)`` crop of *bbox* inside the mosaic.

    The bbox corners are projected to global pixel coordinates at the range's
    zoom and shifted so the range's top-left tile is the origin. They are
    rounded outward (floor for the north-west corner, ceil for the south-east
    corner) and clamped to the stitched image size.
    """
    canvas_w, canvas_h = stitched_size
    offset_x = tile_range.x_min * TILE_SIZE
    offset_y = tile_range.y_min * TILE_SIZE

    nw_x, nw_y = _lonlat_to_global_pixel(bbox.west, bbox.north, tile_range.zoom)
    se_x, se_y = _lonlat_to_global_pixel(bbox.east, bbox.south, tile_range.zoom)

    def _clamp(value: int, upper: int) -> int:
        return max(0, min(upper, value))

    return (
        _clamp(math.floor(nw_x - offset_x), canvas_w),
        _clamp(math.floor(nw_y - offset_y), canvas_h),
        _clamp(math.ceil(se_x - offset_x), canvas_w),
        _clamp(math.ceil(se_y - offset_y), canvas_h),
    )


def _lonlat_to_global_pixel(lon: float, lat: float, zoom: int) -> tuple[float, float]:
    """Project lon/lat to absolute (world-wide) pixel coordinates at *zoom*."""
    frac_x, frac_y = lonlat_to_tile_fraction(lon, lat, zoom)
    pixel_x = frac_x * TILE_SIZE
    pixel_y = frac_y * TILE_SIZE
    return pixel_x, pixel_y


def _haversine_meters(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    radius_m = 6_371_000
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)
    a = (
        math.sin(delta_phi / 2.0) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2.0) ** 2
    )
    return 2.0 * radius_m * math.atan2(math.sqrt(a), math.sqrt(1.0 - a))