edornd commited on
Commit
df521bd
·
unverified Β·
1 Parent(s): 8f3b8a9

First commit

Browse files
Files changed (3) hide show
  1. README.md +16 -8
  2. requirements.txt +10 -3
  3. src/streamlit_app.py +579 -37
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Terramind Ad
3
- emoji: πŸš€
4
- colorFrom: red
5
- colorTo: red
6
  sdk: docker
7
  app_port: 8501
8
  tags:
@@ -12,9 +12,17 @@ short_description: Demo of Anomaly Detection with Terramind
12
  license: mit
13
  ---
14
 
15
- # Welcome to Streamlit!
16
 
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
 
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Terramind Change Detection
3
+ emoji: 🌍
4
+ colorFrom: blue
5
+ colorTo: yellow
6
  sdk: docker
7
  app_port: 8501
8
  tags:
 
12
  license: mit
13
  ---
14
 
 
15
 
16
+ # Terramind AD Interactive Dashboard
17
 
18
+ Launch an interactive Streamlit dashboard to explore results:
19
+
20
+ ```bash
21
+ uv run tools/app.py
22
+ ```
23
+
24
+ The dashboard provides:
25
+ - Spatial heatmap of detected changes
26
+ - Interactive patch selection
27
+ - Temporal analysis with embedding trajectories
28
+ - RGB imagery overlay with detected change areas
requirements.txt CHANGED
@@ -1,3 +1,10 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ rioxarray
3
+ shapely
4
+ xarray
5
+ numpy
6
+ matplotlib
7
+ huggingface_hub
8
+ pydantic
9
+ geojson-pydantic
10
+ zarr
src/streamlit_app.py CHANGED
@@ -1,40 +1,582 @@
1
- import altair as alt
 
 
2
  import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import json
2
+ from pathlib import Path
3
+
4
  import numpy as np
 
5
  import streamlit as st
6
+ import xarray as xr
7
+ from geojson_pydantic import FeatureCollection
8
+ from huggingface_hub import snapshot_download
9
+ from matplotlib import patches
10
+ from matplotlib import pyplot as plt
11
+ from matplotlib.axes import Axes
12
+ from numpy.typing import NDArray
13
+ from pydantic import BaseModel
14
+
15
+
16
# download dataset from Hugging Face on startup
@st.cache_resource
def download_dataset() -> Path:
    """Fetch the demo dataset snapshot from the Hugging Face Hub.

    Cached via st.cache_resource, so the download happens once per process.

    Returns:
        Local filesystem path of the downloaded dataset snapshot.
    """
    with st.spinner("Downloading dataset from Hugging Face..."):
        snapshot_path = snapshot_download(
            repo_id="edornd/terramind-ad-data",
            repo_type="dataset",
        )
    return Path(snapshot_path)
24
+
25
+
26
# configuration constants
# NOTE: importing this module triggers the dataset download (module-level side effect)
DATA_DIR = download_dataset()
# sensor subdirectory name (presumably Sentinel-2 — confirm against data layout)
SENSOR_DIR = "s2"
EVENTS_CONFIG_PATH = DATA_DIR / "events.json"
# display configuration
SPATIAL_MAP_SIZE = (2.5, 2.5)  # figure size for RGB, PCA, anomaly maps
TEMPORAL_PLOT_SIZE = (12, 2.2)  # figure size for temporal series
SPATIAL_DPI = 96  # DPI for spatial maps
TEMPORAL_DPI = 250  # DPI for temporal plots (higher for clarity)
35
+
36
+
37
class DisasterSite(BaseModel):
    """Configuration for a disaster site, as declared in events.json."""

    # unique site identifier (also the site's subdirectory name under DATA_DIR)
    id: str
    # human-readable site name shown in the site selector
    name: str
    event_type: str
    # event date string (format defined by events.json — TODO confirm ISO)
    event_date: str
    # GeoJSON feature collection outlining the observed event extent
    observed_event: FeatureCollection
    # EPSG code of the site's coordinate reference system
    epsg: int
    # start/end of the historical (pre-event) reference period
    historical_start: str
    historical_end: str
    description: str = ""
    # optional initial patch selection for the dashboard (None -> image center)
    default_patch_x: int | None = None
    default_patch_y: int | None = None
51
+
52
+
53
class SitesConfig(BaseModel):
    """Root configuration with all sites (top-level shape of events.json)."""

    # every configured disaster site, in file order
    sites: list[DisasterSite]
57
+
58
+
59
def min_max_scale(values: np.ndarray) -> np.ndarray:
    """Scale values to [0, 1] using min-max normalization.

    Args:
        values: input array

    Returns:
        scaled array in [0, 1]; all zeros when the input is (near-)constant
    """
    lo = values.min()
    hi = values.max()
    span = hi - lo
    # degenerate range: avoid dividing by ~0 for constant inputs
    if span < 1e-8:
        return np.zeros_like(values)
    return (values - lo) / span
73
+
74
+
75
def percentile_clip_scale(values: np.ndarray, lower: float = 2.0, upper: float = 98.0) -> np.ndarray:
    """Clip values to a percentile range and scale to [0, 1].

    Args:
        values: input array
        lower: lower percentile (default 2nd percentile)
        upper: upper percentile (default 98th percentile)

    Returns:
        clipped and scaled array in [0, 1]; all zeros for (near-)constant input
    """
    lo, hi = np.percentile(values, [lower, upper])
    clipped = np.clip(values, lo, hi)
    # degenerate percentile range: return a flat image instead of dividing by ~0
    if hi - lo < 1e-8:
        return np.zeros_like(clipped)
    return (clipped - lo) / (hi - lo)
91
+
92
+
93
# set matplotlib style for professional web plots
plt.style.use("seaborn-v0_8-darkgrid")
# global rcParams: small fonts and a light grid tuned for compact dashboard figures
plt.rcParams.update(
    {
        "font.size": 9,
        "axes.titlesize": 10,
        "axes.labelsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 8,
        "figure.dpi": SPATIAL_DPI,
        "savefig.dpi": SPATIAL_DPI,
        "axes.grid": True,
        "grid.alpha": 0.3,
        "grid.linewidth": 0.5,
    }
)
110
+
111
+
112
def draw_crosshair(
    ax: Axes,
    cx: int,
    cy: int,
    size: int = 32,
    color: str = "red",
    draw_lines: bool = True,
):
    """Outline a square marker of side `size` centered at (cx, cy).

    When `draw_lines` is set, dotted guide lines are drawn from the axes
    origin up to the square's left and top edges.
    """
    half = size // 2
    ax.add_patch(
        patches.Rectangle(
            (cx - half, cy - half),
            width=size,
            height=size,
            fill=False,
            edgecolor=color,
            linewidth=1,
        )
    )
    if not draw_lines:
        return
    ax.hlines(y=cy, xmin=0, xmax=(cx - half), linestyle=":", color=color, linewidth=1)
    ax.vlines(x=cx, ymin=0, ymax=(cy - half), linestyle=":", color=color, linewidth=1)
133
+
134
+
135
def render_rgb_image(
    rgb_data: xr.DataArray,
    time_idx: int,
    selected_patch: tuple[int, int] | None = None,
    downsample: int = 16,
) -> None:
    """Render RGB satellite image with optional patch marker.

    Loads only the requested timestep from zarr (lazy loading).

    Args:
        rgb_data: lazy DataArray with `time` and `band` dims (assumes bands
            indexed so that [3, 2, 1] -> B4, B3, B2 — TODO confirm ordering)
        time_idx: index of the timestep to display
        selected_patch: (x, y) patch coordinates to highlight, if any
        downsample: pixels per patch, used to place the crosshair marker
    """
    # lazy load only this timestep; band order B4, B3, B2 -> R, G, B
    rgb = rgb_data.isel(time=time_idx, band=[3, 2, 1]).values
    # fixed 0..5000 reflectance stretch mapped to 8-bit for display
    rgb = np.clip(rgb / 5000 * 255, 0, 255).astype(np.uint8)
    rgb = rgb.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)

    fig, ax = plt.subplots(figsize=SPATIAL_MAP_SIZE, facecolor="white")
    ax.imshow(rgb)
    ax.axis("off")

    if selected_patch is not None:
        px, py = selected_patch
        # patch coords -> pixel center in the full-resolution image
        cx = px * downsample + downsample // 2
        cy = py * downsample + downsample // 2
        draw_crosshair(ax, cx, cy)

    plt.tight_layout(pad=0.1)
    st.pyplot(fig, width="stretch")
    # BUGFIX: close this specific figure; bare plt.close() only closes the
    # "current" figure, which may differ if other code touched pyplot state
    plt.close(fig)
163
+
164
+
165
def render_pca_features(
    pca_data: xr.DataArray,
    time_idx: int,
    selected_patch: tuple[int, int] | None = None,
) -> None:
    """Render the 3-component PCA features as an RGB image.

    Loads only the requested timestep from zarr (lazy loading). Channels are
    jointly percentile-clipped (2-98%) and rescaled to [0, 1] for display.
    (Docstring fixed: the previous text claimed z-score normalization, but
    the code performs percentile clipping plus min-max scaling.)
    """
    # lazy load only this timestep (handle both xarray and zarr arrays)
    if hasattr(pca_data, "isel"):
        pca_t = pca_data.isel(time=time_idx).values  # xarray DataArray
    else:
        pca_t = pca_data[time_idx]  # zarr Array, shape (H, W, 3)

    # robust contrast stretch across all three channels jointly
    pca_flat = pca_t.reshape(-1, 3)
    pca_norm = percentile_clip_scale(pca_flat)
    # NOTE(review): percentile_clip_scale already returns values in [0, 1],
    # so this extra min-max pass is a no-op kept for behavioral parity
    pca_scaled = min_max_scale(pca_norm)
    pca_rgb = pca_scaled.reshape(pca_t.shape)

    fig, ax = plt.subplots(figsize=SPATIAL_MAP_SIZE, facecolor="white")
    ax.imshow(pca_rgb, interpolation="nearest")
    ax.axis("off")

    if selected_patch is not None:
        px, py = selected_patch
        # PCA map is at patch resolution, so no downsample scaling is needed
        draw_crosshair(ax, px, py, size=4, color="yellow")

    plt.tight_layout(pad=0.1)
    st.pyplot(fig, width="stretch")
    # BUGFIX: close this specific figure instead of whatever is "current"
    plt.close(fig)
197
+
198
+
199
def render_anomaly_map(
    accumulated_anomalies: NDArray,
    selected_patch: tuple[int, int] | None = None,
) -> None:
    """Render accumulated post-event anomaly heatmap.

    Args:
        accumulated_anomalies: (H, W) count of anomalies per pixel after event
        selected_patch: (x, y) coordinates of selected patch
    """
    fig, ax = plt.subplots(figsize=SPATIAL_MAP_SIZE, facecolor="white")

    # normalize counts to [0, 1] for a stable color scale; an all-zero map is
    # displayed as-is to avoid dividing by zero
    max_count = accumulated_anomalies.max()
    if max_count > 0:
        normalized = accumulated_anomalies / max_count
    else:
        normalized = accumulated_anomalies

    ax.imshow(normalized, cmap="magma", vmin=0, vmax=1, interpolation="nearest")
    ax.axis("off")

    if selected_patch is not None:
        px, py = selected_patch
        draw_crosshair(ax, px, py, size=3, draw_lines=False)

    plt.tight_layout(pad=0.1)
    st.pyplot(fig, width="stretch")
    # BUGFIX: close this specific figure; bare plt.close() targets the
    # "current" figure and can leak this one
    plt.close(fig)
228
+
229
+
230
def render_temporal_series(
    residuals: NDArray,
    anomaly_mask: NDArray,
    timestamps: list[str],
    patch_coord: tuple[int, int],
    time_idx: int,
    event_idx: int,
) -> None:
    """Render temporal evolution at the selected patch.

    Args:
        residuals: (T, H, W) residual time series
        anomaly_mask: (T, H, W) boolean anomaly flags
        timestamps: date label per timestep
        patch_coord: (x, y) patch to plot
        time_idx: currently selected timestep (orange dashed marker)
        event_idx: event timestep (purple dotted marker); may be None
    """
    px, py = patch_coord
    residuals_patch = residuals[:, py, px]
    anomaly_patch = anomaly_mask[:, py, px]

    fig, ax = plt.subplots(figsize=TEMPORAL_PLOT_SIZE, facecolor="white", dpi=TEMPORAL_DPI)
    time_indices = np.arange(len(timestamps))
    # plot residuals with professional styling
    ax.plot(
        time_indices,
        residuals_patch,
        "o-",
        color="#2E86AB",
        alpha=0.7,
        markersize=3.5,
        linewidth=1.3,
        label="Residual",
    )

    # mark anomalies with red X (only post-event)
    anom_indices = np.where(anomaly_patch)[0]
    # BUGFIX: the original compared `anom_indices >= event_idx` unguarded even
    # though event_idx is treated as possibly-None further below; guard both
    # uses consistently so a None event_idx cannot crash the filtering
    if event_idx is not None:
        anom_indices = anom_indices[anom_indices >= event_idx]
    if len(anom_indices) > 0:
        ax.scatter(
            anom_indices,
            residuals_patch[anom_indices],
            marker="x",
            s=60,
            c="#C73E1D",
            linewidths=2.2,
            zorder=5,
            label="Anomaly",
        )
    # mark current time and event
    ax.axvline(time_idx, color="#F18F01", linestyle="--", linewidth=1.8, alpha=0.7, label="Current")
    if event_idx is not None:
        ax.axvline(event_idx, color="#6A4C93", linestyle=":", linewidth=1.8, alpha=0.7, label="Event")

    ax.set_xlabel("Date", fontweight="semibold")
    ax.set_ylabel("PC1 Value", fontweight="semibold")
    ax.set_title(f"Temporal Profile at Patch ({px}, {py})", fontweight="bold", pad=10)

    # show dates on x-axis with smart ticking
    n_ticks = min(10, len(timestamps))
    tick_indices = np.linspace(0, len(timestamps) - 1, n_ticks, dtype=int)
    ax.set_xticks(tick_indices)
    ax.set_xticklabels([timestamps[i] for i in tick_indices], rotation=45, ha="right")
    ax.tick_params(labelsize=7)
    ax.legend(loc="upper left", fontsize=6, framealpha=0.95, ncol=3, edgecolor="gray", fancybox=True)

    plt.tight_layout()
    st.pyplot(fig, width="content")
    # BUGFIX: close this specific figure instead of the "current" one
    plt.close(fig)
291
+
292
+
293
def render_anomaly_timeline(
    anomaly_mask: NDArray,
    timestamps: list[str],
    time_idx: int,
    event_idx: int,
) -> None:
    """Render timeline of per-timestep anomalous patch counts.

    Args:
        anomaly_mask: (T, H, W) boolean anomaly flags
        timestamps: date label per timestep
        time_idx: currently selected timestep (highlighted orange bar)
        event_idx: event timestep (vertical marker); may be None
    """
    # T is implied by len(timestamps); only H and W are needed for the title
    _, H, W = anomaly_mask.shape
    anomaly_counts = anomaly_mask.sum(axis=(1, 2))

    fig, ax = plt.subplots(figsize=TEMPORAL_PLOT_SIZE, facecolor="white", dpi=TEMPORAL_DPI)

    time_indices = np.arange(len(timestamps))
    # highlight the currently selected timestep in orange
    colors = ["#F18F01" if i == time_idx else "#2E86AB" for i in range(len(timestamps))]

    ax.bar(time_indices, anomaly_counts, color=colors, alpha=0.75, width=0.85, edgecolor="white", linewidth=0.5)

    # mark event
    if event_idx is not None:
        ax.axvline(event_idx, color="#6A4C93", linestyle=":", linewidth=2, alpha=0.8, label="Event")
        ax.legend(loc="upper right", fontsize=8, framealpha=0.95, edgecolor="gray")

    ax.set_xlabel("Date", fontweight="semibold")
    ax.set_ylabel("Anomalous Patches", fontweight="semibold")
    ax.set_title(f"Spatial Anomaly Count Over Time (Total: {H * W} patches)", fontweight="bold", pad=10)

    # show dates on x-axis with smart ticking
    n_ticks = min(10, len(timestamps))
    tick_indices = np.linspace(0, len(timestamps) - 1, n_ticks, dtype=int)
    ax.set_xticks(tick_indices)
    ax.set_xticklabels([timestamps[i] for i in tick_indices], rotation=45, ha="right")
    ax.tick_params(labelsize=7)

    plt.tight_layout()
    st.pyplot(fig, width="content")
    # BUGFIX: close this specific figure; bare plt.close() only closes the
    # "current" figure and can leak this one across reruns
    plt.close(fig)
329
+
330
+
331
@st.cache_resource
def load_site_data(site_id: str) -> dict:
    """Load lazy references to zarr data (no eager loading into memory).

    Returns xarray DataArrays that load data on-demand when sliced.

    Args:
        site_id: site identifier (subdirectory of DATA_DIR)

    Returns:
        dict with keys: timestamps, rgb_data, metadata, pca_data, T, H, W.
        rgb_data and pca_data may be None when the corresponding store is absent.

    Raises:
        FileNotFoundError: if the main features.zarr store is missing.
    """
    features_path = DATA_DIR / site_id / "features" / SENSOR_DIR / "features.zarr"
    if not features_path.exists():
        raise FileNotFoundError(f"Features not found: {features_path}")

    # load as zarr group for metadata
    import zarr

    features_group = zarr.open(str(features_path), mode="r")
    # timestamps are stored as bytes in the zarr group; decode for display
    timestamps = [ts.decode("utf-8") for ts in features_group["timestamps"][:]]  # type: ignore
    metadata = dict(features_group.attrs)

    # load PC3 features for visualization (stored as zarr arrays, not xarray)
    pc3_path = DATA_DIR / site_id / "features" / SENSOR_DIR / "features_pc3.zarr"
    pca_data = None
    if pc3_path.exists():
        pca_group = zarr.open(str(pc3_path), mode="r")
        pca_data = pca_group["features"]  # type: ignore lazy array (T, H, W, 3)

    # lazy load RGB imagery
    sat_zarr_path = DATA_DIR / site_id / "images" / SENSOR_DIR / "timeseries.zarr"

    sat_data = None
    if sat_zarr_path.exists():
        ds = xr.open_zarr(sat_zarr_path, consolidated=True)
        # assumes the store holds exactly one data variable — TODO confirm
        sat_data = ds[list(ds.data_vars)[0]]  # lazy DataArray

    return {
        "timestamps": timestamps,
        "rgb_data": sat_data,
        "metadata": metadata,
        "pca_data": pca_data,
        "T": len(timestamps),
        # spatial dims fall back to 0 when PC3 features are missing
        "H": pca_data.shape[1] if pca_data is not None else 0,  # type: ignore
        "W": pca_data.shape[2] if pca_data is not None else 0,  # type: ignore
    }
372
+
373
+
374
@st.cache_data
def load_anomaly_data(site_id: str) -> dict | None:
    """Load pre-computed anomaly detection results.

    Args:
        site_id: site identifier (subdirectory of DATA_DIR)

    Returns:
        dict of arrays (residuals, masks, threshold, event index, accumulated
        post-event anomaly counts), or None when no detection.npz exists.
    """
    detection_path = DATA_DIR / site_id / "anomalies" / SENSOR_DIR / "detection.npz"

    if not detection_path.exists():
        return None

    data = np.load(detection_path)

    # compute anomaly mask: residuals > threshold (per-pixel)
    residuals = data["residuals"]  # (T, H, W)
    threshold = data["threshold"]  # (H, W)
    valid_mask = data["valid_mask"]  # (T, H, W)
    event_idx = int(data["event_idx"])

    # binary anomaly: where residual exceeds threshold AND observation is clear
    anomaly_mask = (residuals > threshold[None, :, :]) & valid_mask  # (T, H, W)

    # compute accumulated post-event anomalies for visualization
    post_event_mask = anomaly_mask[event_idx:]  # (T_post, H, W)
    accumulated_anomalies = post_event_mask.sum(axis=0).astype(float)  # (H, W)

    # try to load filtered results if available (overrides the raw accumulation)
    filtered_path = DATA_DIR / site_id / "anomalies" / SENSOR_DIR / "detection_filtered.npz"
    if filtered_path.exists():
        filtered_data = np.load(filtered_path)
        accumulated_anomalies = filtered_data["accumulated_filtered"]  # use filtered version

    return {
        "residuals_timeseries": residuals,
        "anomaly_mask_timeseries": anomaly_mask,
        "fitted_timeseries": data["fitted_values"],
        "valid_mask": valid_mask,
        "threshold": threshold,
        "event_idx": event_idx,
        "accumulated_anomalies": accumulated_anomalies,  # (H, W) accumulated post-event
    }
412
+
413
+
414
@st.cache_resource
def load_sites_config() -> SitesConfig:
    """Parse events.json into a validated SitesConfig (cached per process)."""
    raw = json.loads(EVENTS_CONFIG_PATH.read_text())
    return SitesConfig(**raw)
419
+
420
+
421
def run():
    """Streamlit entry point: render the full change-detection dashboard.

    Flow: site selection -> lazy data loading -> sidebar time/patch controls
    (persisted in st.session_state) -> spatial maps -> temporal plots.
    """
    st.set_page_config(page_title="TerraMind Anomaly Detection", page_icon="🌍", layout="wide")
    st.sidebar.title("TerraMind \nChange Detection")

    # load available sites
    config = load_sites_config()
    site_options = {site.id: site.name for site in config.sites}

    # site selection dropdown
    site_id = st.sidebar.selectbox(
        "Site",
        options=list(site_options.keys()),
        format_func=lambda x: site_options[x],
    )

    # get current site config
    current_site = next(site for site in config.sites if site.id == site_id)

    # load data (lazy references; heavy arrays are read only when sliced)
    try:
        with st.spinner("Loading data..."):
            data = load_site_data(site_id)
            anomaly_data = load_anomaly_data(site_id)
    except FileNotFoundError as e:
        st.error(f"❌ {e}")
        st.info(f"Expected structure:\n- `{DATA_DIR}/<site_id>/features/{SENSOR_DIR}/features.zarr`")
        return

    timestamps = data["timestamps"]
    rgb_data = data["rgb_data"]
    pca_data = data["pca_data"]
    T = data["T"]
    H = data["H"]
    W = data["W"]

    # bail out early (sidebar errors only) when derived products are missing
    if anomaly_data is None:
        st.sidebar.error("⚠️ No anomaly data")
        st.sidebar.info("Run: `uv run python tools/detect.py run --site-id <site_id>`")
        return
    if pca_data is None:
        st.sidebar.error("⚠️ No PC3 features")
        st.sidebar.info("Run: `uv run python tools/infer.py pca --site-id <site_id> --n-components 3`")
        return
    event_idx = anomaly_data["event_idx"]

    # controls
    st.sidebar.markdown("---")
    st.sidebar.subheader("πŸŽ›οΈ Controls")

    # reset state when site changes
    if "current_site_id" not in st.session_state or st.session_state.current_site_id != site_id:
        st.session_state.current_site_id = site_id
        st.session_state.time_idx = event_idx if event_idx is not None else 0
        # BUGFIX: use explicit None checks instead of `x or fallback`, which
        # silently discarded a legitimate configured default patch of 0
        st.session_state.patch_x = (
            current_site.default_patch_x if current_site.default_patch_x is not None else W // 2
        )
        st.session_state.patch_y = (
            current_site.default_patch_y if current_site.default_patch_y is not None else H // 2
        )

    # time control with +/- buttons
    st.sidebar.markdown("**⏱️ Time Selection**")

    # clamp time_idx to valid range (in case data size changed)
    st.session_state.time_idx = min(max(0, st.session_state.time_idx), T - 1)

    col_minus, col_slider, col_plus = st.sidebar.columns([1, 8, 1])
    with col_minus:
        if st.button(
            "",
            key="time_minus",
            type="tertiary",
            help="Previous timestep",
            icon=":material/do_not_disturb_on:",
        ):
            st.session_state.time_idx = max(0, st.session_state.time_idx - 1)
    with col_slider:
        # NOTE: `format` here is a static label, so the slider displays the
        # currently selected date rather than live values while dragging
        time_idx = st.slider(
            "Date",
            0,
            T - 1,
            st.session_state.time_idx,
            format=f"{timestamps[st.session_state.time_idx]}",
            label_visibility="collapsed",
        )
        st.session_state.time_idx = time_idx
    with col_plus:
        if st.button(
            "",
            key="time_plus",
            type="tertiary",
            help="Next timestep",
            icon=":material/add_circle:",
        ):
            st.session_state.time_idx = min(T - 1, st.session_state.time_idx + 1)
    time_idx = st.session_state.time_idx

    st.sidebar.markdown("**πŸ“ Patch Selection**")
    # clamp patch coordinates to valid range
    st.session_state.patch_x = min(max(0, st.session_state.patch_x), W - 1)
    st.session_state.patch_y = min(max(0, st.session_state.patch_y), H - 1)

    col1, col2 = st.sidebar.columns(2)
    patch_x = col1.number_input("X", 0, W - 1, st.session_state.patch_x, key="px")
    patch_y = col2.number_input("Y", 0, H - 1, st.session_state.patch_y, key="py")

    # update session state with any manual changes
    st.session_state.patch_x = patch_x
    st.session_state.patch_y = patch_y

    # main content: temporal analysis view (always shown)
    st.title(site_options[site_id])

    # spatial context (small maps)
    st.markdown(f"### πŸ—ΊοΈ Spatial Context β€” `{timestamps[time_idx]}`")
    col1, col2, col3 = st.columns(3)

    with col1:
        st.markdown("**RGB**")
        if rgb_data is not None:
            render_rgb_image(
                rgb_data,
                time_idx,
                (int(patch_x), int(patch_y)),
            )
        else:
            st.warning("RGB data not available")

    with col2:
        st.markdown("**PCA**")
        render_pca_features(
            pca_data,
            time_idx,
            (int(patch_x), int(patch_y)),
        )

    with col3:
        st.markdown("**Anomaly Heatmap**")
        render_anomaly_map(
            anomaly_data["accumulated_anomalies"],
            (int(patch_x), int(patch_y)),
        )

    st.markdown(f"### πŸ“ˆ Temporal Analysis β€” Patch `({patch_x}, {patch_y})`")

    # temporal series and anomaly timeline
    render_temporal_series(
        residuals=anomaly_data["residuals_timeseries"],
        anomaly_mask=anomaly_data["anomaly_mask_timeseries"],
        timestamps=timestamps,
        patch_coord=(int(patch_x), int(patch_y)),
        time_idx=time_idx,
        event_idx=event_idx,
    )

    render_anomaly_timeline(
        anomaly_mask=anomaly_data["anomaly_mask_timeseries"],
        timestamps=timestamps,
        time_idx=time_idx,
        event_idx=event_idx,
    )
579
+
580
 
581
# script entry point: build the dashboard when executed directly
if __name__ == "__main__":
    run()