File size: 8,520 Bytes
9d33171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f72eff0
9d33171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9713a88
9d33171
 
 
 
 
9713a88
 
 
 
 
 
 
 
 
 
 
 
9d33171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from pathlib import Path
import sys

import folium
import streamlit as st
from branca.element import MacroElement, Template
from folium.plugins import Draw, MeasureControl
from streamlit_folium import st_folium

# Make the repository root (the parent of this file's directory) importable,
# so the `app.*` imports below resolve when this file is executed directly
# as a script rather than imported as part of a package.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from app.model_utils import (  # noqa: E402
    CLASS_NAMES,
    DEFAULT_CHECKPOINT_PATH,
    load_rgb_model,
    predict_topk,
)
from app.tile_utils import (  # noqa: E402
    TileFetchError,
    bbox_scale_status,
    bbox_size_meters,
    choose_zoom_level,
    extract_bbox_from_geojson,
    fetch_bbox_image,
    size_warning_for_bbox,
)


# Esri World Imagery XYZ tile endpoint. Note this service orders the path as
# {z}/{y}/{x} (row before column).
ESRI_WORLD_IMAGERY = (
    "https://server.arcgisonline.com/ArcGIS/rest/services/"
    "World_Imagery/MapServer/tile/{z}/{y}/{x}"
)

# Default start position as [lat, lon], over farmland so that good example
# areas are visible right off the bat.
DEFAULT_MAP_CENTER = [50, 10]
# Initial zoom level; build_map restricts the map to zoom levels 13-18.
DEFAULT_MAP_ZOOM = 15


class SingleRectangleLimiter(MacroElement):
    """Leaflet hook that keeps only the most recently drawn shape.

    The rendered script listens for ``draw:created`` on the map. On each event
    it empties the draw plugin's feature group and adds back just the newly
    created layer, so at most one user-drawn rectangle is on the map at a time.

    Args:
        map_name: JavaScript variable name of the Leaflet map.
        drawn_items_name: JavaScript variable name of the draw plugin's
            feature group holding the user's shapes.
    """

    _template = Template(
        """
        {% macro script(this, kwargs) %}
        {{ this.map_name }}.on('draw:created', function(e) {
            {{ this.drawn_items_name }}.clearLayers();
            {{ this.drawn_items_name }}.addLayer(e.layer);
        });
        {% endmacro %}
        """
    )

    def __init__(self, map_name: str, drawn_items_name: str):
        super().__init__()
        self._name = "SingleRectangleLimiter"
        # Both names are interpolated verbatim into the template above.
        self.map_name, self.drawn_items_name = map_name, drawn_items_name


@st.cache_resource(show_spinner="Loading RGB ResNet-50 model...")
def get_model():
    """Return the RGB classifier loaded from the default checkpoint.

    Wrapped in ``st.cache_resource``, so the loaded model is reused by later
    calls instead of being reloaded on every script rerun.
    """
    checkpoint_path = DEFAULT_CHECKPOINT_PATH
    return load_rgb_model(checkpoint_path)


def build_map(
    drawing: dict | None = None,
    center: list[float] | None = None,
    zoom: int = DEFAULT_MAP_ZOOM,
) -> folium.Map:
    """Assemble the interactive Folium map used to select a region.

    The map gets:
    - an Esri World Imagery base layer,
    - a draw toolbar limited to rectangles, plus a limiter keeping one shape,
    - a measuring tool,
    - the previous selection overlaid as GeoJSON, when one is given.

    Args:
        drawing: GeoJSON of the last selected rectangle to overlay, or None.
        center: Initial ``[lat, lon]``. Falls back to DEFAULT_MAP_CENTER when
            None or empty.
        zoom: Initial zoom level.

    Returns:
        The configured ``folium.Map``.
    """
    start_location = center if center else DEFAULT_MAP_CENTER
    base_map = folium.Map(
        location=start_location,
        zoom_start=zoom,
        min_zoom=13,
        max_zoom=18,
        tiles=None,  # the imagery base layer is added explicitly below
        control_scale=True,
    )

    imagery = folium.TileLayer(
        tiles=ESRI_WORLD_IMAGERY,
        attr="Tiles © Esri — Source: Esri, Maxar, Earthstar Geographics, and GIS User Community",
        name="Esri World Imagery",
        overlay=False,
        control=True,
    )
    imagery.add_to(base_map)

    # Only the rectangle tool is enabled; every other shape tool is switched off.
    disabled_shapes = ("polyline", "polygon", "circle", "marker", "circlemarker")
    rectangle_style = {"color": "#ff7800", "weight": 2, "fillOpacity": 0.05}
    draw_options = {
        **dict.fromkeys(disabled_shapes, False),
        "rectangle": {"shapeOptions": rectangle_style},
    }
    draw_control = Draw(
        export=False,
        draw_options=draw_options,
        edit_options={"edit": True, "remove": True},
    )
    draw_control.add_to(base_map)

    # Clears the draw plugin's feature group on each new shape so only one remains.
    limiter = SingleRectangleLimiter(
        map_name=base_map.get_name(),
        drawn_items_name="drawnItems_" + draw_control.get_name(),
    )
    limiter.add_to(base_map)

    MeasureControl(
        position="bottomleft",
        primary_length_unit="meters",
        secondary_length_unit="kilometers",
        primary_area_unit="sqmeters",
    ).add_to(base_map)

    if not drawing:
        return base_map

    def selection_style(_feature):
        # Distinct colour so the stored selection differs from a fresh draw.
        return {"color": "#00bcd4", "weight": 2, "fillOpacity": 0.04}

    folium.GeoJson(
        drawing,
        name="Last selected rectangle",
        style_function=selection_style,
    ).add_to(base_map)
    return base_map


def render_sidebar() -> None:
    """Fill the sidebar with usage steps, a model-scope caveat and the class list."""
    sidebar = st.sidebar
    steps = (
        "Pan and zoom to a land area.",
        "Select the rectangle tool on the map.",
        "Use the map scale bar or measure tool as a guide.",
        "Draw a near-square box roughly 500m-1km on a side.",
        "Review the fetched image and top predictions.",
    )

    sidebar.header("How to use")
    # Render as a numbered Markdown list: "1. ...", "2. ...", one per line.
    sidebar.markdown(
        "\n".join(f"{number}. {step}" for number, step in enumerate(steps, start=1))
    )
    sidebar.warning(
        "This model was trained on EuroSAT-RGB tiles (~64x64 pixels, ~640m on a side). "
        "Predictions on arbitrary map regions are illustrative; for best results, draw "
        "a rectangle of roughly 500m-1km on a side over land."
    )

    sidebar.header("EuroSAT Classes")
    for name in CLASS_NAMES:
        sidebar.write(f"- {name}")


def render_prediction(drawing) -> None:
    """Validate a drawn rectangle, fetch its imagery and show the model output.

    Each stage stops early with a message in the page if it cannot proceed:
    - invalid geometry (``ValueError`` from extract_bbox_from_geojson),
    - a size warning from size_warning_for_bbox,
    - an "invalid" scale status,
    - a tile fetch failure (``TileFetchError``),
    - a missing model checkpoint (``FileNotFoundError``).

    Args:
        drawing: GeoJSON of the selected rectangle as returned by the map widget.
    """
    try:
        bbox = extract_bbox_from_geojson(drawing)
    except ValueError as err:
        st.error(str(err))
        return

    width_m, height_m = bbox_size_meters(bbox)
    scale_state, scale_message = bbox_scale_status(bbox)

    width_col, height_col, zoom_col = st.columns(3)
    width_col.metric("Rectangle width", format_meters(width_m))
    height_col.metric("Rectangle height", format_meters(height_m))
    zoom_col.metric("Tile zoom", choose_zoom_level(bbox))

    # Blocking problems: show the message and skip fetching/inference.
    size_warning = size_warning_for_bbox(bbox)
    if size_warning:
        st.warning(size_warning)
        return
    if scale_state == "invalid":
        st.warning(scale_message)
        return

    # Non-blocking: green for a good scale, yellow for anything else.
    report = st.success if scale_state == "good" else st.warning
    report(scale_message)

    try:
        with st.spinner("Fetching Esri imagery tiles..."):
            image = fetch_bbox_image(bbox)
    except TileFetchError as err:
        st.error(f"Could not fetch satellite imagery for this rectangle. {err}")
        return

    try:
        model = get_model()
    except FileNotFoundError as err:
        st.error(str(err))
        return

    with st.spinner("Running RGB land cover inference..."):
        top_predictions = predict_topk(model, image, top_k=3)

    image_col, result_col = st.columns([1, 1])
    with image_col:
        st.subheader("Fetched Tile Preview")
        st.image(image, caption="Cropped Esri World Imagery", width='stretch')

    with result_col:
        st.subheader("Prediction")
        best_class, best_prob = top_predictions[0]
        st.metric("Predicted class", best_class, f"{best_prob:.1%}")
        st.write("Top-3 class probabilities")
        # top_predictions is a sequence of (class_name, probability) pairs.
        st.bar_chart(
            {"Probability": dict(top_predictions)},
            horizontal=True,
        )


def format_meters(value: float) -> str:
    """Format a distance given in meters for display.

    Distances that display as 1000 m or more are shown in kilometers with two
    decimals (e.g. "1.25 km"). Smaller ones are shown as whole meters
    (e.g. "640 m").

    Args:
        value: Distance in meters.

    Returns:
        The human-readable distance string.
    """
    # Switch at 999.5 rather than 1000: ``.0f`` rounds values in
    # [999.5, 1000) up to "1000", which would otherwise print "1000 m".
    if value >= 999.5:
        return f"{value / 1_000:.2f} km"
    return f"{value:.0f} m"


def get_drawing(data: dict | None) -> dict | None:
    """Sync the latest rectangle from the map widget into session state.

    When the widget reports a drawing that differs from the stored one:
    1. The drawing is stored under ``"last_drawing"``.
    2. The map center is moved to the drawing's midpoint, when its geometry
       can be parsed.
    3. ``st.rerun()`` is called so the map is rebuilt around the new selection.

    Args:
        data: Value returned by ``st_folium``; only its
            ``"last_active_drawing"`` entry is read. May be None.

    Returns:
        The drawing currently stored in session state, or None.
    """
    incoming_drawing = data.get("last_active_drawing") if data else None
    current_drawing = st.session_state.get("last_drawing")
    if incoming_drawing and incoming_drawing != current_drawing:
        st.session_state["last_drawing"] = incoming_drawing
        try:
            st.session_state["map_center"] = drawing_center(incoming_drawing)
        except ValueError:
            # Deliberate: the geometry was rejected by extract_bbox_from_geojson.
            # Keep the current map view; render_prediction parses the same
            # drawing and reports the error to the user, instead of this call
            # aborting the script with a traceback.
            pass
        st.rerun()
    return st.session_state.get("last_drawing")


def drawing_center(drawing: dict) -> list[float]:
    """Return the midpoint of a drawing's bounding box as ``[lat, lon]``.

    Exceptions raised by extract_bbox_from_geojson propagate unchanged
    (render_prediction treats its ValueError as an invalid drawing).
    """
    box = extract_bbox_from_geojson(drawing)
    mid_lat = (box.south + box.north) / 2.0
    mid_lon = (box.west + box.east) / 2.0
    return [mid_lat, mid_lon]


def reset_map() -> None:
    """Increment the ``"map_version"`` counter in session state (starting from 0).

    main() embeds this counter in the st_folium widget key, so bumping it
    makes Streamlit create a fresh map widget on the next run.
    """
    state = st.session_state
    previous_version = state.get("map_version", 0)
    state["map_version"] = previous_version + 1


def clear_selection() -> None:
    """Forget the stored rectangle and map center, then force a new map widget."""
    for key in ("last_drawing", "map_center"):
        st.session_state.pop(key, None)
    reset_map()


def main() -> None:
    """Render the full Streamlit page.

    The page consists of:
    - the page config, title and intro,
    - the sidebar,
    - the map,
    - the prediction output for the current selection, or a hint to draw one.
    """
    st.set_page_config(
        page_title="EuroSAT RGB Land Cover Classifier",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.title("EuroSAT Land Cover Classifier (RGB Model)")
    # Intro text is plain Markdown and does not need raw-HTML permission.
    st.markdown(
        "This demo classifies RGB satellite imagery into the 10 EuroSAT land cover "
        "classes using a ResNet-50."
    )
    # Page CSS is emitted by a separate call. It used to be glued onto the
    # intro sentence by implicit string-literal concatenation. It is written
    # flush-left because indented lines can be parsed as a Markdown code
    # block or paragraph continuation instead of raw HTML.
    st.markdown(
        "<style>\n"
        "html { overflow-y: scroll; }\n"
        ".block-container { max-width: 1200px; margin: 0 auto; }\n"
        "</style>",
        unsafe_allow_html=True,
    )
    render_sidebar()

    previous_drawing = st.session_state.get("last_drawing")
    map_version = st.session_state.get("map_version", 0)
    map_center = st.session_state.get("map_center", DEFAULT_MAP_CENTER)
    # map_version is part of the key so reset_map() can force a fresh widget.
    data = st_folium(
        build_map(previous_drawing, center=map_center),
        key=f"eurosat-rgb-map-{map_version}",
        height=600,
        width='stretch',
        returned_objects=["last_active_drawing"],
    )

    drawing = get_drawing(data)
    if drawing:
        st.button("Reset rectangle", on_click=clear_selection)
        render_prediction(drawing)
    else:
        st.info(
            "Draw a near-square rectangle on the map to fetch imagery and run the classifier. "
            "Aim for 500m-1km on each side, similar to the original EuroSAT-RGB tiles."
        )


# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()