"""Streamlit app for browsing TerraMind anomaly-detection results per disaster site.

On startup the dataset is fetched from the Hugging Face Hub. For a selected site,
timestep and feature patch the UI then shows the RGB image, a PCA feature map, the
accumulated post-event anomaly heatmap, the residual time series at the patch and
the per-timestep count of anomalous patches.
"""

import json
from pathlib import Path

import numpy as np
import streamlit as st
import xarray as xr
import zarr
from geojson_pydantic import FeatureCollection
from huggingface_hub import snapshot_download
from matplotlib import patches
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from numpy.typing import NDArray
from pydantic import BaseModel


# download dataset from Hugging Face on startup
@st.cache_resource
def download_dataset() -> Path:
    """Download the dataset repository from the Hugging Face Hub.

    Cached with ``st.cache_resource`` so the snapshot is fetched once per server
    process instead of on every script rerun.

    Returns:
        Local directory containing the downloaded snapshot.
    """
    repo_id = "edornd/terramind-ad-data"
    with st.spinner("Downloading dataset from Hugging Face..."):
        local_dir = snapshot_download(repo_id=repo_id, repo_type="dataset")
    return Path(local_dir)


# configuration constants
DATA_DIR = download_dataset()
SENSOR_DIR = "s2"
EVENTS_CONFIG_PATH = DATA_DIR / "events.json"

# display configuration
SPATIAL_MAP_SIZE = (2.5, 2.5)  # figure size for RGB, PCA, anomaly maps
TEMPORAL_PLOT_SIZE = (12, 2.2)  # figure size for temporal series
SPATIAL_DPI = 96  # DPI for spatial maps
TEMPORAL_DPI = 250  # DPI for temporal plots (higher for clarity)


class DisasterSite(BaseModel):
    """Configuration for a disaster site (one entry of ``events.json``)."""

    id: str
    name: str  # label shown in the site dropdown
    event_type: str
    event_date: str
    observed_event: FeatureCollection
    epsg: int
    historical_start: str
    historical_end: str
    description: str = ""
    # initial patch selection in feature-grid coordinates; None -> map centre
    default_patch_x: int | None = None
    default_patch_y: int | None = None


class SitesConfig(BaseModel):
    """Root configuration with all sites."""

    sites: list[DisasterSite]


def min_max_scale(values: np.ndarray) -> np.ndarray:
    """Scale values to [0, 1] using min-max normalization.

    Args:
        values: input array

    Returns:
        scaled array in [0, 1]; all zeros when the input is (nearly) constant
    """
    vmin = values.min()
    vmax = values.max()
    if vmax - vmin < 1e-8:
        return np.zeros_like(values)
    return (values - vmin) / (vmax - vmin)


def percentile_clip_scale(values: np.ndarray, lower: float = 2.0, upper: float = 98.0) -> np.ndarray:
    """Clip values to percentile range and scale to [0, 1].

    Percentiles are computed over the whole array, not per channel.

    Args:
        values: input array
        lower: lower percentile (default 2nd percentile)
        upper: upper percentile (default 98th percentile)

    Returns:
        clipped and scaled array in [0, 1]; all zeros when the percentile range is degenerate
    """
    vmin, vmax = np.percentile(values, [lower, upper])
    clipped = np.clip(values, vmin, vmax)
    if vmax - vmin < 1e-8:
        return np.zeros_like(clipped)
    return (clipped - vmin) / (vmax - vmin)


# set matplotlib style for professional web plots
plt.style.use("seaborn-v0_8-darkgrid")
plt.rcParams.update(
    {
        "font.size": 9,
        "axes.titlesize": 10,
        "axes.labelsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 8,
        "figure.dpi": SPATIAL_DPI,
        "savefig.dpi": SPATIAL_DPI,
        "axes.grid": True,
        "grid.alpha": 0.3,
        "grid.linewidth": 0.5,
    }
)


def draw_crosshair(
    ax: Axes,
    cx: int,
    cy: int,
    size: int = 32,
    color: str = "red",
    draw_lines: bool = True,
) -> None:
    """Draw a square marker centred on ``(cx, cy)`` in data coordinates.

    Args:
        ax: axes to draw on.
        cx: marker centre, x (column) coordinate.
        cy: marker centre, y (row) coordinate.
        size: side length of the square.
        color: edge/line colour.
        draw_lines: also draw dotted guide lines from coordinate 0 on each axis
            up to the square's edge.
    """
    half_size = size // 2
    square = patches.Rectangle(
        (cx - half_size, cy - half_size),
        width=size,
        height=size,
        fill=False,
        edgecolor=color,
        linewidth=1,
    )
    ax.add_patch(square)
    if draw_lines:
        ax.hlines(y=cy, xmin=0, xmax=(cx - half_size), linestyle=":", color=color, linewidth=1)
        ax.vlines(x=cx, ymin=0, ymax=(cy - half_size), linestyle=":", color=color, linewidth=1)


def render_rgb_image(
    rgb_data: xr.DataArray,
    time_idx: int,
    selected_patch: tuple[int, int] | None = None,
    downsample: int = 16,
) -> None:
    """Render a true-colour satellite image with an optional patch marker.

    Only the requested timestep is read from the lazily opened zarr store.

    Args:
        rgb_data: lazy image cube with ``time`` and ``band`` dimensions; band
            indices 3, 2, 1 are displayed as R, G, B (B4, B3, B2).
        time_idx: index along ``time`` to display.
        selected_patch: ``(x, y)`` patch coordinates in feature-grid units.
        downsample: image pixels per feature patch; converts patch coordinates
            to the pixel at the centre of that patch.
    """
    # lazy load only this timestep
    rgb = rgb_data.isel(time=time_idx, band=[3, 2, 1]).values  # B4, B3, B2
    # linear stretch to 8 bit: values >= 5000 saturate
    rgb = np.clip(rgb / 5000 * 255, 0, 255).astype(np.uint8)
    # assumes the remaining dims are ordered (band, y, x) -- TODO confirm against the store
    rgb = rgb.transpose(1, 2, 0)

    fig, ax = plt.subplots(figsize=SPATIAL_MAP_SIZE, facecolor="white")
    ax.imshow(rgb)
    ax.axis("off")
    if selected_patch is not None:
        px, py = selected_patch
        cx = px * downsample + downsample // 2
        cy = py * downsample + downsample // 2
        draw_crosshair(ax, cx, cy)
    plt.tight_layout(pad=0.1)
    st.pyplot(fig, width="stretch")
    # close this specific figure: the implicit "current figure" is global pyplot state
    plt.close(fig)


def render_pca_features(
    pca_data: xr.DataArray,
    time_idx: int,
    selected_patch: tuple[int, int] | None = None,
) -> None:
    """Render the 3-component PCA feature map of one timestep as an RGB image.

    Values are clipped to the 2nd-98th percentile (computed jointly over all
    three components) and scaled to [0, 1]. Only the requested timestep is
    loaded.

    Args:
        pca_data: xarray DataArray with a ``time`` dimension, or a zarr array
            indexed as ``[time]`` -> ``(H, W, 3)``.
        time_idx: timestep to display.
        selected_patch: ``(x, y)`` patch coordinates to mark.
    """
    # lazy load only this timestep (handle both xarray and zarr arrays)
    if hasattr(pca_data, "isel"):
        pca_t = pca_data.isel(time=time_idx).values  # xarray DataArray
    else:
        pca_t = pca_data[time_idx]  # zarr Array
    # (H, W, 3)

    # apply normalization
    pca_flat = pca_t.reshape(-1, 3)
    pca_norm = percentile_clip_scale(pca_flat)
    pca_scaled = min_max_scale(pca_norm)
    pca_rgb = pca_scaled.reshape(pca_t.shape)

    fig, ax = plt.subplots(figsize=SPATIAL_MAP_SIZE, facecolor="white")
    ax.imshow(pca_rgb, interpolation="nearest")
    ax.axis("off")
    if selected_patch is not None:
        px, py = selected_patch
        draw_crosshair(ax, px, py, size=4, color="yellow")
    plt.tight_layout(pad=0.1)
    st.pyplot(fig, width="stretch")
    plt.close(fig)


def render_anomaly_map(
    accumulated_anomalies: NDArray,
    selected_patch: tuple[int, int] | None = None,
) -> None:
    """Render accumulated post-event anomaly heatmap.

    Args:
        accumulated_anomalies: (H, W) count of anomalies per pixel after event
        selected_patch: (x, y) coordinates of selected patch
    """
    fig, ax = plt.subplots(figsize=SPATIAL_MAP_SIZE, facecolor="white")

    # normalize for visualization (an all-zero map is drawn as-is)
    max_count = accumulated_anomalies.max()
    if max_count > 0:
        normalized = accumulated_anomalies / max_count
    else:
        normalized = accumulated_anomalies

    ax.imshow(normalized, cmap="magma", vmin=0, vmax=1, interpolation="nearest")
    ax.axis("off")
    if selected_patch is not None:
        px, py = selected_patch
        draw_crosshair(ax, px, py, size=3, draw_lines=False)
    plt.tight_layout(pad=0.1)
    st.pyplot(fig, width="stretch")
    plt.close(fig)


def _set_date_ticks(ax: Axes, timestamps: list[str]) -> None:
    """Label at most 10 evenly spaced x positions with their timestamps."""
    n_ticks = min(10, len(timestamps))
    tick_indices = np.linspace(0, len(timestamps) - 1, n_ticks, dtype=int)
    ax.set_xticks(tick_indices)
    ax.set_xticklabels([timestamps[i] for i in tick_indices], rotation=45, ha="right")
    ax.tick_params(labelsize=7)


def render_temporal_series(
    residuals: NDArray,
    anomaly_mask: NDArray,
    timestamps: list[str],
    patch_coord: tuple[int, int],
    time_idx: int,
    event_idx: int | None,
) -> None:
    """Render the residual time series at one patch, marking post-event anomalies.

    Args:
        residuals: (T, H, W) residual values.
        anomaly_mask: (T, H, W) boolean anomaly flags.
        timestamps: one label per timestep.
        patch_coord: ``(x, y)`` patch to plot.
        time_idx: currently selected timestep (dashed vertical line).
        event_idx: event timestep; anomalies before it are not marked. When
            None, no event line is drawn and all anomalies are marked.
    """
    px, py = patch_coord
    residuals_patch = residuals[:, py, px]
    anomaly_patch = anomaly_mask[:, py, px]

    fig, ax = plt.subplots(figsize=TEMPORAL_PLOT_SIZE, facecolor="white", dpi=TEMPORAL_DPI)
    time_indices = np.arange(len(timestamps))

    # plot residuals with professional styling
    ax.plot(
        time_indices,
        residuals_patch,
        "o-",
        color="#2E86AB",
        alpha=0.7,
        markersize=3.5,
        linewidth=1.3,
        label="Residual",
    )

    # mark anomalies with red X (only post-event when the event is known)
    anom_indices = np.flatnonzero(anomaly_patch)
    if event_idx is not None:
        anom_indices = anom_indices[anom_indices >= event_idx]
    if len(anom_indices) > 0:
        ax.scatter(
            anom_indices,
            residuals_patch[anom_indices],
            marker="x",
            s=60,
            c="#C73E1D",
            linewidths=2.2,
            zorder=5,
            label="Anomaly",
        )

    # mark current time and event
    ax.axvline(time_idx, color="#F18F01", linestyle="--", linewidth=1.8, alpha=0.7, label="Current")
    if event_idx is not None:
        ax.axvline(event_idx, color="#6A4C93", linestyle=":", linewidth=1.8, alpha=0.7, label="Event")

    ax.set_xlabel("Date", fontweight="semibold")
    ax.set_ylabel("PC1 Value", fontweight="semibold")
    ax.set_title(f"Temporal Profile at Patch ({px}, {py})", fontweight="bold", pad=10)

    # show dates on x-axis with smart ticking
    _set_date_ticks(ax, timestamps)

    ax.legend(loc="upper left", fontsize=6, framealpha=0.95, ncol=3, edgecolor="gray", fancybox=True)
    plt.tight_layout()
    st.pyplot(fig, width="content")
    plt.close(fig)


def render_anomaly_timeline(
    anomaly_mask: NDArray,
    timestamps: list[str],
    time_idx: int,
    event_idx: int | None,
) -> None:
    """Render a bar chart of how many patches are anomalous at each timestep.

    Args:
        anomaly_mask: (T, H, W) boolean anomaly flags.
        timestamps: one label per timestep.
        time_idx: currently selected timestep (highlighted bar).
        event_idx: event timestep to mark, or None.
    """
    _, height, width = anomaly_mask.shape
    anomaly_counts = anomaly_mask.sum(axis=(1, 2))

    fig, ax = plt.subplots(figsize=TEMPORAL_PLOT_SIZE, facecolor="white", dpi=TEMPORAL_DPI)
    time_indices = np.arange(len(timestamps))
    colors = ["#F18F01" if i == time_idx else "#2E86AB" for i in range(len(timestamps))]
    ax.bar(time_indices, anomaly_counts, color=colors, alpha=0.75, width=0.85, edgecolor="white", linewidth=0.5)

    # mark event
    if event_idx is not None:
        ax.axvline(event_idx, color="#6A4C93", linestyle=":", linewidth=2, alpha=0.8, label="Event")
        ax.legend(loc="upper right", fontsize=8, framealpha=0.95, edgecolor="gray")

    ax.set_xlabel("Date", fontweight="semibold")
    ax.set_ylabel("Anomalous Patches", fontweight="semibold")
    ax.set_title(f"Spatial Anomaly Count Over Time (Total: {height * width} patches)", fontweight="bold", pad=10)

    # show dates on x-axis with smart ticking
    _set_date_ticks(ax, timestamps)

    plt.tight_layout()
    st.pyplot(fig, width="content")
    plt.close(fig)


@st.cache_resource
def load_site_data(site_id: str) -> dict:
    """Load lazy references to zarr data (no eager loading into memory).

    Returns lazy arrays that read data on demand when sliced.

    Raises:
        FileNotFoundError: when the site's ``features.zarr`` is missing.

    Returns:
        dict with ``timestamps`` (list[str]), ``rgb_data`` (xarray DataArray or
        None), ``metadata`` (zarr group attrs), ``pca_data`` (zarr array or
        None) and the grid sizes ``T``, ``H``, ``W`` (``H``/``W`` are 0 when
        PCA features are missing).
    """
    features_path = DATA_DIR / site_id / "features" / SENSOR_DIR / "features.zarr"
    if not features_path.exists():
        raise FileNotFoundError(f"Features not found: {features_path}")

    # load as zarr group for metadata
    features_group = zarr.open(str(features_path), mode="r")
    # timestamps may come back as bytes or as str depending on how they were stored
    timestamps = [
        ts.decode("utf-8") if isinstance(ts, bytes) else str(ts)
        for ts in features_group["timestamps"][:]  # type: ignore
    ]
    metadata = dict(features_group.attrs)

    # load PC3 features for visualization (stored as zarr arrays, not xarray)
    pc3_path = DATA_DIR / site_id / "features" / SENSOR_DIR / "features_pc3.zarr"
    pca_data = None
    if pc3_path.exists():
        pca_group = zarr.open(str(pc3_path), mode="r")
        pca_data = pca_group["features"]  # type: ignore  lazy array (T, H, W, 3)

    # lazy load RGB imagery (first data variable of the dataset)
    sat_zarr_path = DATA_DIR / site_id / "images" / SENSOR_DIR / "timeseries.zarr"
    sat_data = None
    if sat_zarr_path.exists():
        ds = xr.open_zarr(sat_zarr_path, consolidated=True)
        sat_data = ds[list(ds.data_vars)[0]]  # lazy DataArray

    return {
        "timestamps": timestamps,
        "rgb_data": sat_data,
        "metadata": metadata,
        "pca_data": pca_data,
        "T": len(timestamps),
        "H": pca_data.shape[1] if pca_data is not None else 0,  # type: ignore
        "W": pca_data.shape[2] if pca_data is not None else 0,  # type: ignore
    }


@st.cache_data
def load_anomaly_data(site_id: str) -> dict | None:
    """Load pre-computed anomaly detection results.

    Returns:
        None when ``detection.npz`` is missing; otherwise a dict with the
        residual/anomaly/fitted/valid time series (T, H, W), the per-pixel
        ``threshold`` (H, W), ``event_idx`` and ``accumulated_anomalies`` (H, W).
        When ``detection_filtered.npz`` exists, its ``accumulated_filtered``
        array replaces the raw accumulated counts.
    """
    detection_path = DATA_DIR / site_id / "anomalies" / SENSOR_DIR / "detection.npz"
    if not detection_path.exists():
        return None

    # NpzFile keeps the archive open; read everything needed, then close it
    with np.load(detection_path) as data:
        residuals = data["residuals"]  # (T, H, W)
        threshold = data["threshold"]  # (H, W)
        valid_mask = data["valid_mask"].astype(bool)  # (T, H, W)
        fitted_values = data["fitted_values"]
        event_idx = int(data["event_idx"])

    # binary anomaly: where residual exceeds threshold AND observation is clear
    anomaly_mask = (residuals > threshold[None, :, :]) & valid_mask  # (T, H, W)

    # compute accumulated post-event anomalies for visualization
    post_event_mask = anomaly_mask[event_idx:]  # (T_post, H, W)
    accumulated_anomalies = post_event_mask.sum(axis=0).astype(float)  # (H, W)

    # try to load filtered results if available
    filtered_path = DATA_DIR / site_id / "anomalies" / SENSOR_DIR / "detection_filtered.npz"
    if filtered_path.exists():
        with np.load(filtered_path) as filtered_data:
            accumulated_anomalies = filtered_data["accumulated_filtered"]  # use filtered version

    return {
        "residuals_timeseries": residuals,
        "anomaly_mask_timeseries": anomaly_mask,
        "fitted_timeseries": fitted_values,
        "valid_mask": valid_mask,
        "threshold": threshold,
        "event_idx": event_idx,
        "accumulated_anomalies": accumulated_anomalies,  # (H, W) accumulated post-event
    }


@st.cache_resource
def load_sites_config() -> SitesConfig:
    """Load site configurations from events.json (decoded as UTF-8, per the JSON spec)."""
    with EVENTS_CONFIG_PATH.open(encoding="utf-8") as f:
        return SitesConfig(**json.load(f))


def _shift_time_idx(delta: int, n_steps: int) -> None:
    """Move the selected timestep by ``delta``, clamped to ``[0, n_steps - 1]``.

    Used as a button ``on_click`` callback: Streamlit runs callbacks before the
    script reruns, so the slider rendered afterwards already shows the new value.
    """
    st.session_state.time_idx = min(max(0, st.session_state.time_idx + delta), n_steps - 1)


def run():
    """Build the page: sidebar controls, spatial maps and temporal plots."""
    st.set_page_config(page_title="TerraMind Anomaly Detection", page_icon="🌍", layout="wide")
    st.sidebar.title("TerraMind \nChange Detection")

    # load available sites
    config = load_sites_config()
    site_options = {site.id: site.name for site in config.sites}
    if not site_options:
        st.sidebar.error("⚠️ No sites configured")
        return

    # site selection dropdown
    site_id = st.sidebar.selectbox(
        "Site",
        options=list(site_options.keys()),
        format_func=lambda x: site_options[x],
    )
    # get current site config
    current_site = next(site for site in config.sites if site.id == site_id)

    # load data
    try:
        with st.spinner("Loading data..."):
            data = load_site_data(site_id)
            anomaly_data = load_anomaly_data(site_id)
    except FileNotFoundError as e:
        st.error(f"❌ {e}")
        st.info(f"Expected structure:\n- `{DATA_DIR}/{site_id}/features/{SENSOR_DIR}/features.zarr`")
        return

    timestamps = data["timestamps"]
    rgb_data = data["rgb_data"]
    pca_data = data["pca_data"]
    T = data["T"]
    H = data["H"]
    W = data["W"]

    # sidebar: show errors only
    if anomaly_data is None:
        st.sidebar.error("⚠️ No anomaly data")
        st.sidebar.info(f"Run: `uv run python tools/detect.py run --site-id {site_id}`")
        return
    if pca_data is None:
        st.sidebar.error("⚠️ No PC3 features")
        st.sidebar.info(f"Run: `uv run python tools/infer.py pca --site-id {site_id} --n-components 3`")
        return

    event_idx = anomaly_data["event_idx"]

    # controls
    st.sidebar.markdown("---")
    st.sidebar.subheader("🎛️ Controls")

    # reset state when site changes
    if "current_site_id" not in st.session_state or st.session_state.current_site_id != site_id:
        st.session_state.current_site_id = site_id
        st.session_state.time_idx = event_idx if event_idx is not None else 0
        # `is not None` (not `or`): 0 is a valid configured coordinate
        st.session_state.patch_x = (
            current_site.default_patch_x if current_site.default_patch_x is not None else W // 2
        )
        st.session_state.patch_y = (
            current_site.default_patch_y if current_site.default_patch_y is not None else H // 2
        )

    # time control with +/- buttons
    st.sidebar.markdown("**⏱️ Time Selection**")
    # clamp time_idx to valid range (in case data size changed)
    st.session_state.time_idx = min(max(0, st.session_state.time_idx), T - 1)

    col_minus, col_slider, col_plus = st.sidebar.columns([1, 8, 1])
    with col_minus:
        st.button(
            "",
            key="time_minus",
            type="tertiary",
            help="Previous timestep",
            icon=":material/do_not_disturb_on:",
            on_click=_shift_time_idx,
            args=(-1, T),
        )
    with col_slider:
        time_idx = st.slider(
            "Date",
            0,
            T - 1,
            st.session_state.time_idx,
            format=f"{timestamps[st.session_state.time_idx]}",
            label_visibility="collapsed",
        )
        st.session_state.time_idx = time_idx
    with col_plus:
        st.button(
            "",
            key="time_plus",
            type="tertiary",
            help="Next timestep",
            icon=":material/add_circle:",
            on_click=_shift_time_idx,
            args=(1, T),
        )

    st.sidebar.markdown("**📍 Patch Selection**")
    # clamp patch coordinates to valid range
    st.session_state.patch_x = min(max(0, st.session_state.patch_x), W - 1)
    st.session_state.patch_y = min(max(0, st.session_state.patch_y), H - 1)

    col1, col2 = st.sidebar.columns(2)
    patch_x = col1.number_input("X", 0, W - 1, st.session_state.patch_x, key="px")
    patch_y = col2.number_input("Y", 0, H - 1, st.session_state.patch_y, key="py")

    # update session state with any manual changes
    st.session_state.patch_x = patch_x
    st.session_state.patch_y = patch_y
    selected_patch = (int(patch_x), int(patch_y))

    # main content: temporal analysis view (always shown)
    st.title(site_options[site_id])

    # spatial context (small maps)
    st.markdown(f"### 🗺️ Spatial Context — `{timestamps[time_idx]}`")
    col1, col2, col3 = st.columns(3)

    with col1:
        st.markdown("**RGB**")
        if rgb_data is not None:
            render_rgb_image(rgb_data, time_idx, selected_patch)
        else:
            st.warning("RGB data not available")

    with col2:
        st.markdown("**PCA**")
        render_pca_features(pca_data, time_idx, selected_patch)

    with col3:
        st.markdown("**Anomaly Heatmap**")
        render_anomaly_map(anomaly_data["accumulated_anomalies"], selected_patch)

    st.markdown(f"### 📈 Temporal Analysis — Patch `({patch_x}, {patch_y})`")

    # temporal series and anomaly timeline
    render_temporal_series(
        residuals=anomaly_data["residuals_timeseries"],
        anomaly_mask=anomaly_data["anomaly_mask_timeseries"],
        timestamps=timestamps,
        patch_coord=selected_patch,
        time_idx=time_idx,
        event_idx=event_idx,
    )
    render_anomaly_timeline(
        anomaly_mask=anomaly_data["anomaly_mask_timeseries"],
        timestamps=timestamps,
        time_idx=time_idx,
        event_idx=event_idx,
    )


if __name__ == "__main__":
    run()