Commit 9d33171 by davidlsan · verified · 1 Parent(s): b7ab62b

Add Streamlit app source and RGB model weights

Files changed (9)
  1. README.md +55 -13
  2. app/.DS_Store +0 -0
  3. app/__init__.py +1 -0
  4. app/app.py +276 -0
  5. app/model_utils.py +55 -0
  6. app/tile_utils.py +261 -0
  7. requirements.txt +309 -3
  8. train.py +169 -0
  9. weights/rgb_e15_best.pt +3 -0
README.md CHANGED
@@ -1,19 +1,61 @@
  ---
  title: EuroSAT RGB Land Cover Classifier
- emoji: 🚀
- colorFrom: red
- colorTo: red
- sdk: docker
- app_port: 8501
- tags:
- - streamlit
- pinned: false
- short_description: Intereactive demo for classifying land cover classes
+ sdk: streamlit
+ app_file: app/app.py
  ---

- # Welcome to Streamlit!
+ # EuroSAT Land Cover Classification

- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
+ CNN-based land cover classification on EuroSAT, comparing RGB imagery with 13-band Sentinel-2 multispectral input.
+
+ ## Streamlit RGB Demo
+
+ `app/app.py` is a Hugging Face Spaces-ready Streamlit demo for the EuroSAT-RGB ResNet-50 classifier. It shows an Esri World Imagery map centered on Bergen, Norway, lets a user draw a rectangle, fetches the corresponding RGB map tiles, and displays the predicted EuroSAT land cover class plus the top-3 class probabilities.
+
+ The RGB model was trained on EuroSAT-RGB tiles, which are about 64x64 pixels and roughly 640m on a side. Predictions on arbitrary map regions are illustrative; for best results, draw a rectangle of roughly 500m-1km on a side over land.
+
+ Classes: Annual Crop, Forest, Herbaceous Vegetation, Highway, Industrial Buildings, Pasture, Permanent Crop, Residential Buildings, River, SeaLake.
+
+ Validation accuracy on EuroSAT-RGB: **96.8%**.
+
+ Main GitHub repo: [davidlsan/EuroSAT-Land-Cover-Classification](https://github.com/davidlsan/EuroSAT-Land-Cover-Classification)
+
+ ## Run Locally
+
+ Place the trained RGB checkpoint at:
+
+ ```bash
+ weights/rgb_e15_best.pt
+ ```
+
+ Install dependencies with uv and start the app:
+
+ ```bash
+ uv sync
+ uv run streamlit run app/app.py
+ ```
+
+ For Hugging Face Spaces, use the Streamlit SDK and include `app/app.py`, `app/model_utils.py`, `app/tile_utils.py`, `requirements.txt`, and the checkpoint at `weights/rgb_e15_best.pt`.
+
+ ## Notebooks
+
+ - `[notebooks/01_data_exploration.ipynb](notebooks/01_data_exploration.ipynb)` - RGB EDA (class balance, sample grid). Run from the repository root so PNGs land in `[figures/](figures/)`.
+ - `[notebooks/02_data_exploration_multispectral.ipynb](notebooks/02_data_exploration_multispectral.ipynb)` - 13-band MSI EDA (class balance, composites, per-band stats, class-mean spectra, band correlation, RGB-MSI alignment).
+
+ ## Training
+
+ Run from the repository root:
+
+ ```bash
+ python3 main.py --modality rgb --epochs 15 --batch-size 32 --num-workers 2 --lr 1e-3
+ ```
+
+ ### CLI Flags
+
+ - `--modality`: values `rgb` or `msi`. RGB loads the `blanchon/EuroSAT_RGB` dataset, while MSI loads `blanchon/EuroSAT_MSI`.
+ - `--epochs`: number of full passes over the training split. Defaults to `15`.
+ - `--batch-size`: number of samples per batch. Defaults to `32`.
+ - `--num-workers`: DataLoader worker processes. Defaults to `2`, but `0` is safer for debugging.
+ - `--lr`: learning rate for Adam. Defaults to `1e-3`.
+ - `--seed`: seeds Python, NumPy, and PyTorch for best-effort reproducibility.

- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
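Review note on the CLI flags documented in the README: the argument parser itself is not part of this commit, so the following is only a sketch of how those flags could be declared with `argparse`. The helper name `parse_args`, the `rgb` default for `--modality`, and the `42` default for `--seed` are assumptions; the README states defaults only for `--epochs`, `--batch-size`, `--num-workers`, and `--lr`.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser mirroring the flags listed in the README."""
    parser = argparse.ArgumentParser(description="Train a EuroSAT classifier.")
    # Assumption: the README lists the allowed values but not a default.
    parser.add_argument("--modality", choices=["rgb", "msi"], default="rgb")
    parser.add_argument("--epochs", type=int, default=15)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-3)
    # Assumption: the README does not state a default seed; 42 is a placeholder.
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args(argv)


if __name__ == "__main__":
    print(parse_args())
```

Passing an explicit list, for example `parse_args(["--modality", "msi"])`, makes the parser easy to exercise without touching `sys.argv`.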
app/.DS_Store ADDED
Binary file (6.15 kB).
app/__init__.py ADDED
@@ -0,0 +1 @@
+ """Streamlit EuroSAT RGB demo package."""
app/app.py ADDED
@@ -0,0 +1,276 @@
+ from pathlib import Path
+ import sys
+
+ import folium
+ import streamlit as st
+ from branca.element import MacroElement, Template
+ from folium.plugins import Draw, MeasureControl
+ from streamlit_folium import st_folium
+
+ REPO_ROOT = Path(__file__).resolve().parents[1]
+ if str(REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(REPO_ROOT))
+
+ from app.model_utils import (  # noqa: E402
+     CLASS_NAMES,
+     DEFAULT_CHECKPOINT_PATH,
+     load_rgb_model,
+     predict_topk,
+ )
+ from app.tile_utils import (  # noqa: E402
+     TileFetchError,
+     bbox_scale_status,
+     bbox_size_meters,
+     choose_zoom_level,
+     extract_bbox_from_geojson,
+     fetch_bbox_image,
+     size_warning_for_bbox,
+ )
+
+
+ ESRI_WORLD_IMAGERY = (
+     "https://server.arcgisonline.com/ArcGIS/rest/services/"
+     "World_Imagery/MapServer/tile/{z}/{y}/{x}"
+ )
+
+ # Farmland starting position, good for examples right off the bat
+ DEFAULT_MAP_CENTER = [50, 10]
+ DEFAULT_MAP_ZOOM = 15
+
+
+ class SingleRectangleLimiter(MacroElement):
+     _template = Template(
+         """
+         {% macro script(this, kwargs) %}
+         {{ this.map_name }}.on('draw:created', function(e) {
+             {{ this.drawn_items_name }}.clearLayers();
+             {{ this.drawn_items_name }}.addLayer(e.layer);
+         });
+         {% endmacro %}
+         """
+     )
+
+     def __init__(self, map_name: str, drawn_items_name: str):
+         super().__init__()
+         self._name = "SingleRectangleLimiter"
+         self.map_name = map_name
+         self.drawn_items_name = drawn_items_name
+
+
+ @st.cache_resource(show_spinner="Loading RGB ResNet-50 model...")
+ def get_model():
+     return load_rgb_model(DEFAULT_CHECKPOINT_PATH)
+
+
+ def build_map(
+     drawing: dict | None = None,
+     center: list[float] | None = None,
+     zoom: int = DEFAULT_MAP_ZOOM,
+ ) -> folium.Map:
+     fmap = folium.Map(
+         location=center or DEFAULT_MAP_CENTER,
+         zoom_start=zoom,
+         min_zoom=13,
+         max_zoom=18,
+         tiles=None,
+         control_scale=True,
+     )
+     folium.TileLayer(
+         tiles=ESRI_WORLD_IMAGERY,
+         attr="Tiles © Esri — Source: Esri, Maxar, Earthstar Geographics, and GIS User Community",
+         name="Esri World Imagery",
+         overlay=False,
+         control=True,
+     ).add_to(fmap)
+     draw_control = Draw(
+         export=False,
+         draw_options={
+             "polyline": False,
+             "polygon": False,
+             "circle": False,
+             "marker": False,
+             "circlemarker": False,
+             "rectangle": {
+                 "shapeOptions": {
+                     "color": "#ff7800",
+                     "weight": 2,
+                     "fillOpacity": 0.05,
+                 }
+             },
+         },
+         edit_options={"edit": True, "remove": True},
+     )
+     draw_control.add_to(fmap)
+     SingleRectangleLimiter(
+         map_name=fmap.get_name(),
+         drawn_items_name=f"drawnItems_{draw_control.get_name()}",
+     ).add_to(fmap)
+     MeasureControl(
+         position="bottomleft",
+         primary_length_unit="meters",
+         secondary_length_unit="kilometers",
+         primary_area_unit="sqmeters",
+     ).add_to(fmap)
+
+     if drawing:
+         folium.GeoJson(
+             drawing,
+             name="Last selected rectangle",
+             style_function=lambda _: {
+                 "color": "#00bcd4",
+                 "weight": 2,
+                 "fillOpacity": 0.04,
+             },
+         ).add_to(fmap)
+
+     return fmap
+
+
+ def render_sidebar() -> None:
+     st.sidebar.header("How to use")
+     st.sidebar.markdown(
+         "1. Pan and zoom to a land area.\n"
+         "2. Select the rectangle tool on the map.\n"
+         "3. Use the map scale bar or measure tool as a guide.\n"
+         "4. Draw a near-square box roughly 500m-1km on a side.\n"
+         "5. Review the fetched image and top predictions."
+     )
+     st.sidebar.warning(
+         "This model was trained on EuroSAT-RGB tiles (~64x64 pixels, ~640m on a side). "
+         "Predictions on arbitrary map regions are illustrative; for best results, draw "
+         "a rectangle of roughly 500m-1km on a side over land."
+     )
+     st.sidebar.header("EuroSAT Classes")
+     for class_name in CLASS_NAMES:
+         st.sidebar.write(f"- {class_name}")
+
+
+ def render_prediction(drawing) -> None:
+     try:
+         bbox = extract_bbox_from_geojson(drawing)
+     except ValueError as exc:
+         st.error(str(exc))
+         return
+
+     width_m, height_m = bbox_size_meters(bbox)
+     scale_state, scale_message = bbox_scale_status(bbox)
+
+     metric_col, scale_col, zoom_col = st.columns(3)
+     metric_col.metric("Rectangle width", format_meters(width_m))
+     scale_col.metric("Rectangle height", format_meters(height_m))
+     zoom_col.metric("Tile zoom", choose_zoom_level(bbox))
+
+     warning = size_warning_for_bbox(bbox)
+     if warning:
+         st.warning(warning)
+         return
+     if scale_state == "invalid":
+         st.warning(scale_message)
+         return
+     if scale_state == "good":
+         st.success(scale_message)
+     else:
+         st.warning(scale_message)
+
+     try:
+         with st.spinner("Fetching Esri imagery tiles..."):
+             image = fetch_bbox_image(bbox)
+     except TileFetchError as exc:
+         st.error(f"Could not fetch satellite imagery for this rectangle. {exc}")
+         return
+
+     try:
+         model = get_model()
+     except FileNotFoundError as exc:
+         st.error(str(exc))
+         return
+
+     with st.spinner("Running RGB land cover inference..."):
+         top_predictions = predict_topk(model, image, top_k=3)
+
+     preview_col, prediction_col = st.columns([1, 1])
+     with preview_col:
+         st.subheader("Fetched Tile Preview")
+         st.image(image, caption="Cropped Esri World Imagery", width='stretch')
+
+     with prediction_col:
+         st.subheader("Prediction")
+         best_class, best_prob = top_predictions[0]
+         st.metric("Predicted class", best_class, f"{best_prob:.1%}")
+         st.write("Top-3 class probabilities")
+         st.bar_chart(
+             {"Probability": {name: prob for name, prob in top_predictions}},
+             horizontal=True,
+         )
+
+
+ def format_meters(value: float) -> str:
+     if value >= 1_000:
+         return f"{value / 1_000:.2f} km"
+     return f"{value:.0f} m"
+
+
+ def get_drawing(data: dict | None) -> dict | None:
+     incoming_drawing = data.get("last_active_drawing") if data else None
+     current_drawing = st.session_state.get("last_drawing")
+     if incoming_drawing and incoming_drawing != current_drawing:
+         st.session_state["last_drawing"] = incoming_drawing
+         st.session_state["map_center"] = drawing_center(incoming_drawing)
+         st.rerun()
+     return st.session_state.get("last_drawing")
+
+
+ def drawing_center(drawing: dict) -> list[float]:
+     bbox = extract_bbox_from_geojson(drawing)
+     return [
+         (bbox.south + bbox.north) / 2.0,
+         (bbox.west + bbox.east) / 2.0,
+     ]
+
+
+ def reset_map() -> None:
+     st.session_state["map_version"] = st.session_state.get("map_version", 0) + 1
+
+
+ def clear_selection() -> None:
+     st.session_state.pop("last_drawing", None)
+     st.session_state.pop("map_center", None)
+     reset_map()
+
+
+ def main() -> None:
+     st.set_page_config(
+         page_title="EuroSAT RGB Land Cover Classifier",
+         layout="wide",
+     )
+     st.title("EuroSAT Land Cover Classifier (RGB Model)")
+     st.markdown(
+         "This demo classifies RGB satellite imagery into the 10 EuroSAT land cover "
+         "classes using a ResNet-50."
+     )
+     render_sidebar()
+
+     previous_drawing = st.session_state.get("last_drawing")
+     map_version = st.session_state.get("map_version", 0)
+     map_center = st.session_state.get("map_center", DEFAULT_MAP_CENTER)
+     data = st_folium(
+         build_map(previous_drawing, center=map_center),
+         key=f"eurosat-rgb-map-{map_version}",
+         height=600,
+         width='stretch',
+         returned_objects=["last_active_drawing"],
+     )
+
+     drawing = get_drawing(data)
+     if drawing:
+         st.button("Reset rectangle", on_click=clear_selection)
+         render_prediction(drawing)
+     else:
+         st.info(
+             "Draw a near-square rectangle on the map to fetch imagery and run the classifier. "
+             "Aim for 500m-1km on each side, similar to the original EuroSAT-RGB tiles."
+         )
+
+
+ if __name__ == "__main__":
+     main()
app/model_utils.py ADDED
@@ -0,0 +1,55 @@
+ from pathlib import Path
+
+ import torch
+ from PIL import Image
+
+ from train import build_model, build_rgb_transform
+
+
+ CLASS_NAMES = [
+     "Annual Crop",
+     "Forest",
+     "Herbaceous Vegetation",
+     "Highway",
+     "Industrial Buildings",
+     "Pasture",
+     "Permanent Crop",
+     "Residential Buildings",
+     "River",
+     "SeaLake",
+ ]
+
+ DEFAULT_CHECKPOINT_PATH = Path("weights/rgb_e15_best.pt")
+
+
+ def load_rgb_model(checkpoint_path: str | Path = DEFAULT_CHECKPOINT_PATH) -> torch.nn.Module:
+     """Load the EuroSAT-RGB ResNet-50 checkpoint for CPU inference."""
+     checkpoint_path = Path(checkpoint_path)
+     if not checkpoint_path.exists():
+         raise FileNotFoundError(
+             f"RGB checkpoint not found at {checkpoint_path}. "
+             "Add weights/rgb_e15_best.pt before running the demo."
+         )
+
+     device = torch.device("cpu")
+     model = build_model(num_classes=len(CLASS_NAMES), device=device, in_channels=3)
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+     model.load_state_dict(checkpoint["model_state_dict"])
+     model.eval()
+     return model
+
+
+ @torch.no_grad()
+ def predict_topk(
+     model: torch.nn.Module, image: Image.Image, top_k: int = 3
+ ) -> list[tuple[str, float]]:
+     """Run RGB inference and return class names with probabilities."""
+     transform = build_rgb_transform(train=False)
+     tensor = transform(image.convert("RGB")).unsqueeze(0)
+     logits = model(tensor)
+     probs = torch.softmax(logits, dim=1).squeeze(0)
+     top_probs, top_indices = torch.topk(probs, k=top_k)
+     return [
+         (CLASS_NAMES[int(class_idx)], float(prob))
+         for prob, class_idx in zip(top_probs, top_indices, strict=True)
+     ]
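Review note on `predict_topk`: the post-processing step is a softmax over the logits followed by a top-k selection. The sketch below reproduces that step in plain Python so it can be checked without PyTorch. It stands in for `torch.softmax` and `torch.topk` and is not the code from this commit; the helper names `softmax` and `top_k`, the four-class subset, and the logit values are made up for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def top_k(names, logits, k=3):
    """Return the k most probable (name, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(zip(names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Illustrative subset of the class list and made-up logits.
names = ["Forest", "River", "Highway", "Pasture"]
result = top_k(names, [2.0, 0.5, -1.0, 1.0], k=3)

for name, prob in result:
    print(f"{name}: {prob:.1%}")
```

The returned list has the same shape as the `list[tuple[str, float]]` that the Streamlit page feeds into `st.metric` and `st.bar_chart`.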
app/tile_utils.py ADDED
@@ -0,0 +1,261 @@
+ import math
+ import time
+ from dataclasses import dataclass
+ from io import BytesIO
+ from typing import Any
+
+ import requests
+ import streamlit as st
+ from PIL import Image
+
+
+ ESRI_TILE_URL = (
+     "https://server.arcgisonline.com/ArcGIS/rest/services/"
+     "World_Imagery/MapServer/tile/{z}/{y}/{x}"
+ )
+ TILE_SIZE = 256
+ USER_AGENT = "eurosat-rgb-streamlit-demo/1.0"
+ EUROSAT_TARGET_MIN_M = 500
+ EUROSAT_TARGET_MAX_M = 1_000
+ EUROSAT_ACCEPTABLE_MIN_M = 250
+ EUROSAT_ACCEPTABLE_MAX_M = 1_500
+ EUROSAT_MAX_ASPECT_RATIO = 2.0
+
+
+ class TileFetchError(RuntimeError):
+     """Raised when an Esri imagery tile cannot be fetched."""
+
+
+ @dataclass(frozen=True)
+ class BBox:
+     west: float
+     south: float
+     east: float
+     north: float
+
+
+ @dataclass(frozen=True)
+ class TileRange:
+     zoom: int
+     x_min: int
+     x_max: int
+     y_min: int
+     y_max: int
+
+
+ def extract_bbox_from_geojson(drawing: dict[str, Any]) -> BBox:
+     """Extract a lon/lat bbox from a Folium Draw GeoJSON rectangle."""
+     geometry = drawing.get("geometry", {})
+     coordinates = geometry.get("coordinates")
+     if geometry.get("type") != "Polygon" or not coordinates:
+         raise ValueError("Expected a drawn rectangle polygon.")
+
+     ring = coordinates[0]
+     lons = [point[0] for point in ring]
+     lats = [point[1] for point in ring]
+     west, east = min(lons), max(lons)
+     south, north = min(lats), max(lats)
+     if west == east or south == north:
+         raise ValueError("The drawn rectangle has no area.")
+
+     return BBox(west=west, south=south, east=east, north=north)
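Review note on `extract_bbox_from_geojson`: the function expects a GeoJSON Feature whose geometry is a closed Polygon ring in `[lon, lat]` order. The sample below is a hypothetical drawing of that shape (the coordinates are invented, near the app's default map center) and repeats the min/max reduction on it. It returns a plain tuple rather than the `BBox` dataclass to stay self-contained.

```python
# Hypothetical Feature shaped like the GeoJSON produced for a drawn rectangle.
drawing = {
    "type": "Feature",
    "properties": {},
    "geometry": {
        "type": "Polygon",
        # One closed ring; each point is [longitude, latitude].
        "coordinates": [[
            [10.00, 50.00],
            [10.00, 50.01],
            [10.01, 50.01],
            [10.01, 50.00],
            [10.00, 50.00],
        ]],
    },
}

ring = drawing["geometry"]["coordinates"][0]
lons = [point[0] for point in ring]
lats = [point[1] for point in ring]

# Order matches the BBox fields: west, south, east, north.
bbox = (min(lons), min(lats), max(lons), max(lats))
print(bbox)
```

Because only the extremes of the ring are used, the vertex order and the repeated closing point do not affect the result.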
+
+
+ def lonlat_to_tile_fraction(lon: float, lat: float, zoom: int) -> tuple[float, float]:
+     """Convert lon/lat to fractional XYZ tile coordinates.
+
+     Uses the OpenStreetMap slippy-map convention:
+     https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames
+     XYZ y coordinates start at 0 at the northern edge of the world.
+     """
+     lat = max(min(lat, 85.05112878), -85.05112878)
+     lat_rad = math.radians(lat)
+     n = 2**zoom
+     x = (lon + 180.0) / 360.0 * n
+     y = (
+         1.0
+         - math.log(math.tan(lat_rad) + (1.0 / math.cos(lat_rad))) / math.pi
+     ) / 2.0 * n
+     return x, y
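Review note on `lonlat_to_tile_fraction`: a quick way to sanity-check the slippy-map formula is to evaluate it at points with known answers. The sketch below copies the function body from the diff into a standalone snippet; the sample inputs (the origin at zoom 0, and the app's default center of 50°N, 10°E at zoom 15) are chosen here for illustration.

```python
import math


def lonlat_to_tile_fraction(lon, lat, zoom):
    """Fractional XYZ tile coordinates for a lon/lat pair (slippy-map convention)."""
    lat = max(min(lat, 85.05112878), -85.05112878)  # Web Mercator latitude limit
    lat_rad = math.radians(lat)
    n = 2 ** zoom  # number of tiles per axis at this zoom
    x = (lon + 180.0) / 360.0 * n
    y = (
        1.0
        - math.log(math.tan(lat_rad) + (1.0 / math.cos(lat_rad))) / math.pi
    ) / 2.0 * n
    return x, y


# At zoom 0 the whole world is one tile, so (0, 0) sits at its center.
print(lonlat_to_tile_fraction(0.0, 0.0, 0))

# Integer tile indices come from flooring the fractional coordinates.
x, y = lonlat_to_tile_fraction(10.0, 50.0, 15)
print(math.floor(x), math.floor(y))
```

The fractional part of each coordinate, multiplied by the tile size, gives the pixel offset inside the tile, which is what the crop step relies on.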
+
+
+ def bbox_to_tile_range(bbox: BBox, zoom: int) -> TileRange:
+     """Return the inclusive XYZ tile range covering a lon/lat bbox."""
+     max_tile = (2**zoom) - 1
+     x_west, y_north = lonlat_to_tile_fraction(bbox.west, bbox.north, zoom)
+     x_east, y_south = lonlat_to_tile_fraction(bbox.east, bbox.south, zoom)
+
+     x_min = max(0, min(max_tile, math.floor(x_west)))
+     x_max = max(0, min(max_tile, math.floor(x_east)))
+     y_min = max(0, min(max_tile, math.floor(y_north)))
+     y_max = max(0, min(max_tile, math.floor(y_south)))
+
+     return TileRange(
+         zoom=zoom,
+         x_min=min(x_min, x_max),
+         x_max=max(x_min, x_max),
+         y_min=min(y_min, y_max),
+         y_max=max(y_min, y_max),
+     )
+
+
+ def choose_zoom_level(bbox: BBox) -> int:
+     """Choose a tile zoom; EuroSAT-scale rectangles use zoom 14-15."""
+     width_m, height_m = bbox_size_meters(bbox)
+     max_side_m = max(width_m, height_m)
+     if max_side_m <= 1_000:
+         return 15
+     if max_side_m <= 5_000:
+         return 14
+     return 13
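Review note on `choose_zoom_level`: to see what the chosen zoom means in pixels, the standard Web Mercator ground-resolution formula for 256-pixel tiles can be applied. This is a sketch under assumptions and not part of the commit: the helper names `meters_per_pixel` and `pixels_for_side` are hypothetical, and the circumference constant is the usual WGS84 equatorial value.

```python
import math

EARTH_CIRCUMFERENCE_M = 40_075_016.686  # WGS84 equatorial circumference
TILE_SIZE = 256  # matches TILE_SIZE in tile_utils.py


def meters_per_pixel(lat_deg, zoom):
    """Approximate ground resolution of a Web Mercator tile pixel."""
    world_pixels = TILE_SIZE * 2 ** zoom
    return EARTH_CIRCUMFERENCE_M * math.cos(math.radians(lat_deg)) / world_pixels


def pixels_for_side(side_m, lat_deg, zoom):
    """How many tile pixels a ground distance spans at a given latitude and zoom."""
    return side_m / meters_per_pixel(lat_deg, zoom)


# Resolution at the equator for the three zoom levels the function can return.
for zoom in (13, 14, 15):
    print(zoom, round(meters_per_pixel(0.0, zoom), 2))

# A EuroSAT-sized 640 m side at 50 degrees north, zoom 15.
print(round(pixels_for_side(640, 50.0, 15)))
```

At zoom 15 and 50°N a 640 m side spans a little over 200 tile pixels, several times more than the 64 pixels of a EuroSAT tile, so the crop is considerably finer than the training imagery before any resizing the model transform applies.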
+
+
+ def bbox_size_meters(bbox: BBox) -> tuple[float, float]:
+     """Approximate bbox width and height in meters."""
+     mid_lat = (bbox.north + bbox.south) / 2.0
+     width_m = _haversine_meters(bbox.west, mid_lat, bbox.east, mid_lat)
+     height_m = _haversine_meters(bbox.west, bbox.south, bbox.west, bbox.north)
+     return width_m, height_m
+
+
+ def size_warning_for_bbox(bbox: BBox) -> str | None:
+     """Return a user-facing warning for rectangles outside the demo range."""
+     width_m, height_m = bbox_size_meters(bbox)
+     min_side_m = min(width_m, height_m)
+     max_side_m = max(width_m, height_m)
+     if min_side_m < 50:
+         return "This rectangle is very small. Draw at least about 50m on a side."
+     if max_side_m > 5_000:
+         return "This rectangle is very large. Draw at most about 5km on a side."
+     return None
+
+
+ def bbox_scale_status(bbox: BBox) -> tuple[str, str]:
+     """Classify whether a bbox is close enough to EuroSAT-RGB tile scale."""
+     width_m, height_m = bbox_size_meters(bbox)
+     min_side_m = min(width_m, height_m)
+     max_side_m = max(width_m, height_m)
+     aspect_ratio = max_side_m / min_side_m
+
+     if min_side_m < EUROSAT_ACCEPTABLE_MIN_M:
+         return (
+             "invalid",
+             "This rectangle is too small for a useful EuroSAT-style prediction. "
+             "Draw closer to 500m-1km on each side.",
+         )
+     if max_side_m > EUROSAT_ACCEPTABLE_MAX_M:
+         return (
+             "invalid",
+             "This rectangle is too large for this EuroSAT-style demo. "
+             "Zoom in and draw closer to 500m-1km on each side.",
+         )
+     if aspect_ratio > EUROSAT_MAX_ASPECT_RATIO:
+         return (
+             "invalid",
+             "This rectangle is too stretched. Draw a more square region, like the original EuroSAT tiles.",
+         )
+     if (
+         EUROSAT_TARGET_MIN_M <= min_side_m
+         and max_side_m <= EUROSAT_TARGET_MAX_M
+     ):
+         return (
+             "good",
+             "Great scale: this is close to the original EuroSAT-RGB tile footprint.",
+         )
+     return (
+         "usable",
+         "Usable, but not ideal. For the most trustworthy demo result, draw 500m-1km on each side.",
+     )
+
+
+ def fetch_bbox_image(bbox: BBox, zoom: int | None = None) -> Image.Image:
+     """Fetch Esri XYZ tiles for a bbox, stitch them, and crop to the bbox."""
+     zoom = choose_zoom_level(bbox) if zoom is None else zoom
+     tile_range = bbox_to_tile_range(bbox, zoom)
+
+     stitched = Image.new(
+         "RGB",
+         (
+             (tile_range.x_max - tile_range.x_min + 1) * TILE_SIZE,
+             (tile_range.y_max - tile_range.y_min + 1) * TILE_SIZE,
+         ),
+     )
+
+     for x in range(tile_range.x_min, tile_range.x_max + 1):
+         for y in range(tile_range.y_min, tile_range.y_max + 1):
+             tile = fetch_esri_tile(zoom, x, y)
+             stitched.paste(
+                 tile,
+                 (
+                     (x - tile_range.x_min) * TILE_SIZE,
+                     (y - tile_range.y_min) * TILE_SIZE,
+                 ),
+             )
+             time.sleep(0.05)
+
+     crop_box = _bbox_crop_box(bbox, tile_range, stitched.size)
+     cropped = stitched.crop(crop_box)
+     if cropped.width <= 0 or cropped.height <= 0:
+         raise TileFetchError("The fetched imagery crop was empty.")
+     return cropped
+
+
+ @st.cache_data(show_spinner=False)
+ def fetch_esri_tile(zoom: int, x: int, y: int) -> Image.Image:
+     """Download one Esri World Imagery XYZ tile."""
+     url = ESRI_TILE_URL.format(z=zoom, x=x, y=y)
+     try:
+         response = requests.get(
+             url,
+             headers={"User-Agent": USER_AGENT},
+             timeout=10,
+         )
+         response.raise_for_status()
+     except requests.RequestException as exc:
+         raise TileFetchError(f"Could not download imagery tile z{zoom}/{x}/{y}.") from exc
+
+     try:
+         return Image.open(BytesIO(response.content)).convert("RGB")
+     except OSError as exc:
+         raise TileFetchError(f"Downloaded imagery tile z{zoom}/{x}/{y} was invalid.") from exc
+
+
+ def _bbox_crop_box(
+     bbox: BBox, tile_range: TileRange, stitched_size: tuple[int, int]
+ ) -> tuple[int, int, int, int]:
+     zoom = tile_range.zoom
+     west_px, north_px = _lonlat_to_global_pixel(bbox.west, bbox.north, zoom)
+     east_px, south_px = _lonlat_to_global_pixel(bbox.east, bbox.south, zoom)
+     origin_x = tile_range.x_min * TILE_SIZE
+     origin_y = tile_range.y_min * TILE_SIZE
+
+     left = math.floor(west_px - origin_x)
+     top = math.floor(north_px - origin_y)
+     right = math.ceil(east_px - origin_x)
+     bottom = math.ceil(south_px - origin_y)
+
+     width, height = stitched_size
+     return (
+         max(0, min(width, left)),
+         max(0, min(height, top)),
+         max(0, min(width, right)),
+         max(0, min(height, bottom)),
+     )
+
+
+ def _lonlat_to_global_pixel(lon: float, lat: float, zoom: int) -> tuple[float, float]:
+     x_tile, y_tile = lonlat_to_tile_fraction(lon, lat, zoom)
+     return x_tile * TILE_SIZE, y_tile * TILE_SIZE
+
+
+ def _haversine_meters(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
+     radius_m = 6_371_000
+     phi1 = math.radians(lat1)
+     phi2 = math.radians(lat2)
+     delta_phi = math.radians(lat2 - lat1)
+     delta_lambda = math.radians(lon2 - lon1)
+     a = (
+         math.sin(delta_phi / 2.0) ** 2
+         + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2.0) ** 2
+     )
+     return 2.0 * radius_m * math.atan2(math.sqrt(a), math.sqrt(1.0 - a))
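Review note on `_haversine_meters`: the haversine distance can be checked against a known reference, since one degree of latitude on a sphere of radius 6,371,000 m is 2πR/360, or about 111.19 km. The snippet below copies the formula from the diff into a standalone function (renamed without the leading underscore, with the radius exposed as a parameter for illustration) and evaluates it at sample coordinates chosen here.

```python
import math


def haversine_meters(lon1, lat1, lon2, lat2, radius_m=6_371_000):
    """Great-circle distance in meters between two lon/lat points."""
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)
    a = (
        math.sin(delta_phi / 2.0) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2.0) ** 2
    )
    return 2.0 * radius_m * math.atan2(math.sqrt(a), math.sqrt(1.0 - a))


# One degree of latitude, measured along a meridian.
one_degree_lat = haversine_meters(10.0, 50.0, 10.0, 51.0)

# One degree of longitude at 50 degrees north is shorter by roughly cos(50 deg).
one_degree_lon = haversine_meters(10.0, 50.0, 11.0, 50.0)

print(round(one_degree_lat), round(one_degree_lon))
```

This is why `bbox_size_meters` measures the width at the mid-latitude of the box: the east-west extent of a degree shrinks toward the poles, while the north-south extent stays constant on a sphere.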
requirements.txt CHANGED
@@ -1,3 +1,309 @@
- altair
- pandas
- streamlit
+ # This file was autogenerated by uv via the following command:
+ #    uv export --format requirements-txt --no-hashes --output-file requirements.txt
+ aiohappyeyeballs==2.6.1
+     # via aiohttp
+ aiohttp==3.13.3
+     # via fsspec
+ aiosignal==1.4.0
+     # via aiohttp
+ altair==6.1.0
+     # via streamlit
+ annotated-doc==0.0.4
+     # via typer
+ anyio==4.12.1
+     # via httpx
+ attrs==25.4.0
+     # via
+     #   aiohttp
+     #   jsonschema
+     #   referencing
+ blinker==1.9.0
+     # via streamlit
+ branca==0.8.2
+     # via
+     #   folium
+     #   streamlit-folium
+ cachetools==7.0.6
+     # via streamlit
+ certifi==2026.2.25
+     # via
+     #   httpcore
+     #   httpx
+     #   requests
+ charset-normalizer==3.4.6
+     # via requests
+ click==8.3.1
+     # via
+     #   streamlit
+     #   typer
+ colorama==0.4.6 ; sys_platform == 'win32'
+     # via
+     #   click
+     #   tqdm
+ contourpy==1.3.3
+     # via matplotlib
+ cuda-bindings==12.9.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ cuda-pathfinder==1.4.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via cuda-bindings
+ cycler==0.12.1
+     # via matplotlib
+ datasets==4.8.2
+     # via eurosat-land-cover-classification
+ dill==0.4.1
+     # via
+     #   datasets
+     #   multiprocess
+ filelock==3.25.2
+     # via
+     #   datasets
+     #   huggingface-hub
+     #   torch
+ folium==0.20.0
+     # via
+     #   eurosat-land-cover-classification
+     #   streamlit-folium
+ fonttools==4.62.1
+     # via matplotlib
+ frozenlist==1.8.0
+     # via
+     #   aiohttp
+     #   aiosignal
+ fsspec==2026.2.0
+     # via
+     #   datasets
+     #   huggingface-hub
+     #   torch
+ gitdb==4.0.12
+     # via gitpython
+ gitpython==3.1.48
+     # via streamlit
+ h11==0.16.0
+     # via httpcore
+ hf-xet==1.4.2 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
+     # via huggingface-hub
+ httpcore==1.0.9
+     # via httpx
+ httpx==0.28.1
+     # via
+     #   datasets
+     #   huggingface-hub
+ huggingface-hub==1.7.1
+     # via datasets
+ idna==3.11
+     # via
+     #   anyio
+     #   httpx
+     #   requests
+     #   yarl
+ jinja2==3.1.6
+     # via
+     #   altair
+     #   branca
+     #   folium
+     #   pydeck
+     #   streamlit-folium
+     #   torch
+ joblib==1.5.3
+     # via scikit-learn
+ jsonschema==4.26.0
+     # via altair
+ jsonschema-specifications==2025.9.1
+     # via jsonschema
+ kiwisolver==1.5.0
+     # via matplotlib
+ markdown-it-py==4.0.0
+     # via rich
+ markupsafe==3.0.3
+     # via jinja2
+ matplotlib==3.10.8
+     # via
+     #   eurosat-land-cover-classification
+     #   seaborn
+ mdurl==0.1.2
+     # via markdown-it-py
+ mpmath==1.3.0
+     # via sympy
+ multidict==6.7.1
+     # via
+     #   aiohttp
+     #   yarl
+ multiprocess==0.70.19
+     # via datasets
+ narwhals==2.20.0
+     # via altair
+ networkx==3.6.1
+     # via torch
+ numpy==2.4.3
+     # via
+     #   contourpy
+     #   datasets
+     #   eurosat-land-cover-classification
+     #   folium
+     #   matplotlib
+     #   pandas
+     #   pydeck
+     #   scikit-learn
+     #   scipy
+     #   seaborn
+     #   streamlit
+     #   torchvision
+ nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via
+     #   nvidia-cudnn-cu12
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via
+     #   nvidia-cufft-cu12
+     #   nvidia-cusolver-cu12
+     #   nvidia-cusparse-cu12
+     #   torch
+ nvidia-nvshmem-cu12==3.4.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ packaging==26.0
+     # via
+     #   altair
+     #   datasets
+     #   huggingface-hub
+     #   matplotlib
+     #   streamlit
197
+ pandas==3.0.1
198
+ # via
199
+ # datasets
200
+ # seaborn
201
+ # streamlit
202
+ pillow==12.1.1
203
+ # via
204
+ # eurosat-land-cover-classification
205
+ # matplotlib
206
+ # streamlit
207
+ # torchvision
208
+ propcache==0.4.1
209
+ # via
210
+ # aiohttp
211
+ # yarl
212
+ protobuf==7.34.1
213
+ # via streamlit
214
+ pyarrow==23.0.1
215
+ # via
216
+ # datasets
217
+ # streamlit
218
+ pydeck==0.9.2
219
+ # via streamlit
220
+ pygments==2.19.2
221
+ # via rich
222
+ pyparsing==3.3.2
223
+ # via matplotlib
224
+ python-dateutil==2.9.0.post0
225
+ # via
226
+ # matplotlib
227
+ # pandas
228
+ pyyaml==6.0.3
229
+ # via
230
+ # datasets
231
+ # huggingface-hub
232
+ referencing==0.37.0
233
+ # via
234
+ # jsonschema
235
+ # jsonschema-specifications
236
+ requests==2.32.5
237
+ # via
238
+ # datasets
239
+ # folium
240
+ # streamlit
241
+ rich==14.3.3
242
+ # via typer
243
+ rpds-py==0.30.0
244
+ # via
245
+ # jsonschema
246
+ # referencing
247
+ scikit-learn==1.8.0
248
+ # via eurosat-land-cover-classification
249
+ scipy==1.17.1
250
+ # via scikit-learn
251
+ seaborn==0.13.2
252
+ # via eurosat-land-cover-classification
253
+ setuptools==82.0.1
254
+ # via torch
255
+ shellingham==1.5.4
256
+ # via typer
257
+ six==1.17.0
258
+ # via python-dateutil
259
+ smmap==5.0.3
260
+ # via gitdb
261
+ streamlit==1.56.0
262
+ # via
263
+ # eurosat-land-cover-classification
264
+ # streamlit-folium
265
+ streamlit-folium==0.27.1
266
+ # via eurosat-land-cover-classification
267
+ sympy==1.14.0
268
+ # via torch
269
+ tenacity==9.1.4
270
+ # via streamlit
271
+ threadpoolctl==3.6.0
272
+ # via scikit-learn
273
+ toml==0.10.2
274
+ # via streamlit
275
+ torch==2.10.0
276
+ # via
277
+ # eurosat-land-cover-classification
278
+ # torchvision
279
+ torchvision==0.25.0
280
+ # via eurosat-land-cover-classification
281
+ tornado==6.5.5
282
+ # via streamlit
283
+ tqdm==4.67.3
284
+ # via
285
+ # datasets
286
+ # eurosat-land-cover-classification
287
+ # huggingface-hub
288
+ triton==3.6.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
289
+ # via torch
290
+ typer==0.24.1
291
+ # via huggingface-hub
292
+ typing-extensions==4.15.0
293
+ # via
294
+ # altair
295
+ # huggingface-hub
296
+ # streamlit
297
+ # torch
298
+ tzdata==2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32'
299
+ # via pandas
300
+ urllib3==2.6.3
301
+ # via requests
302
+ watchdog==6.0.0 ; sys_platform != 'darwin'
303
+ # via streamlit
304
+ xxhash==3.6.0
305
+ # via datasets
306
+ xyzservices==2026.3.0
307
+ # via folium
308
+ yarl==1.23.0
309
+ # via aiohttp
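
Several pins in `requirements.txt` above carry an environment marker such as `platform_machine == 'x86_64' and sys_platform == 'linux'`, so the CUDA-related wheels (`nvidia-*-cu12`, `triton`) are only installed on x86-64 Linux. As a rough illustration, the sketch below mirrors that one marker using only the standard library. The function name `cuda_wheels_apply` is hypothetical and not part of this commit; in practice the installer evaluates markers itself.

```python
import platform
import sys


def cuda_wheels_apply(machine=None, sys_platform=None):
    """Mirror the marker: platform_machine == 'x86_64' and sys_platform == 'linux'.

    Hypothetical helper for illustration only. Per PEP 508, `platform_machine`
    corresponds to platform.machine() and `sys_platform` to sys.platform.
    """
    machine = platform.machine() if machine is None else machine
    sys_platform = sys.platform if sys_platform is None else sys_platform
    return machine == "x86_64" and sys_platform == "linux"


# Check a few representative environments explicitly.
for env in [("x86_64", "linux"), ("arm64", "darwin"), ("AMD64", "win32")]:
    print(env, cuda_wheels_apply(*env))
```

Under this reading, a macOS or Windows machine would skip those pins and rely on the platform's default `torch` build.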
train.py ADDED
@@ -0,0 +1,169 @@
+ import datasets
+ import numpy as np
+ import torch
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+ from torch import nn
+ import torchvision
+ from tqdm import tqdm
+ from dataset import EuroSATDataset
+ import torch.nn.functional as F
+
+ # Constants retrieved from:
+ # https://docs.pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
+ RESNET_50_WEIGHT_MEAN = [0.485, 0.456, 0.406]
+ RESNET_50_WEIGHT_STD = [0.229, 0.224, 0.225]
+
+ DATASET_CFG = {
+     "rgb": {"hf_id": "blanchon/EuroSAT_RGB", "in_channels": 3},
+     "msi": {"hf_id": "blanchon/EuroSAT_MSI", "in_channels": 13},
+ }
+
+
+ def to_chw_tensor(image):
+     hwc = np.array(image, dtype=np.float32)  # HWC typical shape: 64x64x3
+     chw = torch.from_numpy(hwc).permute(2, 0, 1)  # CHW typical shape: 3x64x64
+     return chw
+
+
+ def build_rgb_transform(train: bool):
+     ops = [transforms.Resize((224, 224))]
+     if train:
+         ops.append(transforms.RandomHorizontalFlip())
+     ops.extend(
+         [
+             transforms.ToTensor(),
+             transforms.Normalize(RESNET_50_WEIGHT_MEAN, RESNET_50_WEIGHT_STD),
+         ]
+     )
+     return transforms.Compose(ops)
+
+
+ def build_msi_transform(train: bool):
+     def _tf(image):
+         chw = to_chw_tensor(image)
+         chw = chw / 10000.0
+         if train and torch.rand(1).item() < 0.5:
+             chw = torch.flip(chw, dims=[2])
+         chw = F.interpolate(
+             chw.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
+         ).squeeze(0)
+         return chw
+
+     return _tf
+
+
+ def build_dataloaders(
+     modality: str,
+     batch_size: int,
+     num_workers: int,
+ ):
+     cfg = DATASET_CFG[modality]
+     ds = datasets.load_dataset(cfg["hf_id"])
+     in_channels = cfg["in_channels"]
+     num_classes = ds["train"].features["label"].num_classes
+
+     if modality == "rgb":
+         train_tf = build_rgb_transform(train=True)
+         eval_tf = build_rgb_transform(train=False)
+     else:
+         train_tf = build_msi_transform(train=True)
+         eval_tf = build_msi_transform(train=False)
+
+     train_ds = EuroSATDataset(ds["train"], train_tf)
+     val_ds = EuroSATDataset(ds["validation"], eval_tf)
+
+     train_loader = DataLoader(
+         train_ds,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=num_workers,
+         pin_memory=torch.cuda.is_available(),
+     )
+
+     val_loader = DataLoader(
+         val_ds,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=num_workers,
+         pin_memory=torch.cuda.is_available(),
+     )
+
+     return train_loader, val_loader, num_classes, in_channels
+
+
+ # Pick the best available training device: CUDA GPU first, then Apple MPS, then CPU.
+ def get_device() -> torch.device:
+     if torch.cuda.is_available():
+         return torch.device("cuda")
+     if torch.backends.mps.is_available():
+         return torch.device("mps")
+     return torch.device("cpu")
+
+
+ def build_model(num_classes: int, device: torch.device, in_channels: int) -> nn.Module:
+     model = torchvision.models.resnet50(weights=None)
+
+     if in_channels != 3:
+         model.conv1 = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=model.conv1.out_channels,
+             kernel_size=model.conv1.kernel_size,
+             stride=model.conv1.stride,
+             padding=model.conv1.padding,
+             bias=False,
+         )
+     model.fc = nn.Linear(model.fc.in_features, num_classes)
+     return model.to(device)
+
+
+ def train_one_epoch(
+     model: nn.Module,
+     loader: DataLoader,
+     criterion: nn.Module,
+     optimizer: torch.optim.Optimizer,
+     device: torch.device,
+ ):
+     model.train()
+     total_loss = 0.0
+     n = 0
+
+     for images, labels in tqdm(loader, desc="train", leave=False):
+         images = images.to(device)
+         labels = labels.to(device, dtype=torch.long)
+
+         optimizer.zero_grad()
+         logits = model(images)
+         loss = criterion(logits, labels)
+         loss.backward()
+         optimizer.step()
+
+         batch_n = labels.size(0)
+         total_loss += loss.item() * batch_n
+         n += batch_n
+
+     train_loss = total_loss / max(n, 1)  # max(n, 1) to avoid division by zero
+     return train_loss
+
+
+ @torch.no_grad()
+ def evaluate(
+     model: nn.Module, loader: DataLoader, criterion: nn.Module, device: torch.device
+ ):
+     model.eval()
+     total_loss, correct, total = 0.0, 0, 0
+     for images, labels in loader:
+         images = images.to(device)
+         labels = labels.to(device, dtype=torch.long)  # same dtype as in train_one_epoch
+         logits = model(images)
+         loss = criterion(logits, labels)
+
+         total_loss += loss.item() * labels.size(0)
+         correct += (logits.argmax(1) == labels).sum().item()
+         total += labels.size(0)
+
+     val_loss = total_loss / max(total, 1)  # max(total, 1) to avoid division by zero
+     val_acc = correct / max(total, 1)
+
+     return val_loss, val_acc
+
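
Both `train_one_epoch` and `evaluate` in `train.py` multiply each batch's loss by the batch size before summing, then divide by the total sample count. Assuming the criterion returns a per-batch mean (the default reduction for `nn.CrossEntropyLoss`; the actual criterion is not shown in this file), this yields a per-sample average that stays correct when the last batch is smaller. The pure-Python sketch below illustrates the arithmetic with made-up numbers. The helper name `weighted_mean_loss` is hypothetical and not part of the commit.

```python
def weighted_mean_loss(batch_losses, batch_sizes):
    """Per-sample mean loss from per-batch mean losses (hypothetical helper)."""
    total_loss = 0.0
    n = 0
    for loss, size in zip(batch_losses, batch_sizes):
        total_loss += loss * size  # undo the per-batch mean
        n += size
    return total_loss / max(n, 1)  # guard against an empty loader


# Illustrative values only: two full batches and a short final batch.
losses = [0.5, 1.0, 2.0]
sizes = [32, 32, 8]

weighted = weighted_mean_loss(losses, sizes)
naive = sum(losses) / len(losses)  # treats the short batch like a full one
print(f"weighted={weighted:.4f} naive={naive:.4f}")
```

With these numbers the naive average over batches overweights the short final batch, which is the distortion the size-weighted sum avoids.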
weights/rgb_e15_best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd1ceca358154f341114892ed4dcf0f8490492a695b982c7a09d00660f1d19f0
+ size 94429341
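
The three lines added for `weights/rgb_e15_best.pt` are a Git LFS pointer, not the checkpoint itself: they record the spec version, the SHA-256 of the real file, and its size in bytes. A minimal sketch of reading such a pointer is shown below, assuming the simple `key value` layout seen above. The function name `parse_lfs_pointer` is hypothetical.

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer into a dict (hypothetical helper, illustration only)."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")  # each line is "<key> <value>"
        fields[key] = value
    algo, _, digest = fields["oid"].partition(":")  # e.g. "sha256:<hex digest>"
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size_bytes": int(fields["size"]),
    }


# Pointer contents copied from the diff above.
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:bd1ceca358154f341114892ed4dcf0f8490492a695b982c7a09d00660f1d19f0
size 94429341
"""

info = parse_lfs_pointer(pointer)
print(info["hash_algo"], f"{info['size_bytes'] / 2**20:.1f} MiB")
```

The recorded size works out to roughly 90 MiB, so cloning without Git LFS installed would leave only this small text file in place of the weights.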