"""Shared session state: grid + travel times computed once, stored in st.session_state.

The module also owns the per-session list of active stations (the base
stations loaded from the data layer plus any added during the session) and the
cached incident distribution for the selected risk scenario.
"""

import numpy as np
import streamlit as st

from .data import (
    Station,
    load_water_polygon,
    load_stations_raw,
    classify_cells_by_zone,
    get_passage_coords,
    load_risk_scenarios,
    load_shoreline,
)
from .config import ensure_config, get_neighbor_offsets, get_config_value
from .grid import generate_grid, snap_to_grid
from .graph import build_graph
from .risk_distribution import IncidentDistribution
from .routing import compute_travel_times

STATIONS_STATE_KEY = "active_stations"

# Shared validation message for bad station coordinates / speed.
_INVALID_STATION_NUMBERS_MSG = "Координаты и скорость должны быть конечными числами"


def _compute(cell_size_m: int, neighbor_offsets: list[tuple[int, int]], neighbor_level: int):
    """Build the grid and routing graph, then compute travel times for all active stations.

    Returns a dict holding the grid coordinates, the travel-time array from
    ``compute_travel_times``, its minimum over axis 0, the stations used, and
    the parameters the result was built with (``get_results`` compares these
    to decide whether the cached value is still valid).
    """
    water = load_water_polygon()
    lats, lons, dlat, dlon = generate_grid(water, cell_size_m=cell_size_m)
    zones = classify_cells_by_zone(lats, lons)
    graph = build_graph(
        lats,
        lons,
        dlat,
        dlon,
        neighbor_offsets=neighbor_offsets,
        cell_zones=zones,
        passage_coords=get_passage_coords(),
        passage_radius_m=1000.0,
    )
    stations = get_active_stations()
    # One source cell and one speed per station, in station order.
    sources = [snap_to_grid(s.lat, s.lon, lats, lons) for s in stations]
    speeds = [s.speed_kmh for s in stations]
    travel_times = compute_travel_times(graph, sources, speeds)
    # Presumably axis 0 is the station axis (one row per source), making this the
    # best time over all stations per cell — TODO confirm against compute_travel_times.
    min_times = np.min(travel_times, axis=0)
    return {
        "cell_size": cell_size_m,
        "lats": lats,
        "lons": lons,
        "travel_times": travel_times,
        "min_times": min_times,
        "stations": stations,
        "neighbor_level": neighbor_level,
        "neighbor_offsets": tuple(neighbor_offsets),
    }


def sidebar_section(title: str, expanded: bool = True):
    """Named sidebar group for related page parameters."""
    return st.sidebar.expander(title, expanded=expanded)


def sidebar_controls(container=None):
    """Shared cell size slider. Persists via session_state.

    Draws the slider in ``container`` (the sidebar by default), seeds it from
    ``st.session_state["cell_size"]`` (200 on first use), writes the chosen
    value back to session_state and returns it.
    """
    if container is None:
        container = st.sidebar
    if "cell_size" not in st.session_state:
        st.session_state["cell_size"] = 200
    val = container.slider(
        "Размер ячейки (м)",
        40,
        1000,
        value=st.session_state["cell_size"],
        step=20,
    )
    st.session_state["cell_size"] = val
    return val


def _invalidate_station_results() -> None:
    """Drop cached routing results so the next ``get_results`` call recomputes them."""
    st.session_state.pop("results", None)


def _normalize_station_row(row: dict, fallback_id: str) -> dict:
    """Return a clean station dict: string ``id``/``name``, float ``lat``/``lon``/``speed_kmh``.

    A missing, blank or ``"nan"`` name is replaced by ``fallback_id``; a missing,
    blank or ``"nan"`` id is replaced by the (already normalized) name.

    Raises:
        KeyError: if ``lat``, ``lon`` or ``speed_kmh`` is absent from ``row``.
        ValueError: if any of those values is not convertible to float
            (including ``None``) or is not finite.
    """
    raw_name = row.get("name")
    raw_id = row.get("id")
    name = "" if raw_name is None else str(raw_name).strip()
    station_id = "" if raw_id is None else str(raw_id).strip()
    # A float NaN stringifies to "nan"; treat it the same as an empty value.
    if not name or name.lower() == "nan":
        name = fallback_id
    if not station_id or station_id.lower() == "nan":
        # name is guaranteed non-empty at this point, so it is always usable as an id.
        station_id = name
    try:
        lat = float(row["lat"])
        lon = float(row["lon"])
        speed_kmh = float(row["speed_kmh"])
    except (TypeError, ValueError) as err:
        # float(None) raises TypeError; report it like any other bad number so
        # callers only have to handle ValueError for invalid station data.
        raise ValueError(_INVALID_STATION_NUMBERS_MSG) from err
    if not (np.isfinite(lat) and np.isfinite(lon) and np.isfinite(speed_kmh)):
        raise ValueError(_INVALID_STATION_NUMBERS_MSG)
    return {
        "id": station_id,
        "name": name,
        "lat": lat,
        "lon": lon,
        "speed_kmh": speed_kmh,
    }


def _base_station_rows() -> list[dict]:
    """Normalized copies of the stations returned by ``load_stations_raw``."""
    return [
        _normalize_station_row(row, fallback_id=f"station_{i + 1}")
        for i, row in enumerate(load_stations_raw())
    ]


def _ensure_active_stations() -> None:
    """Initialize the session's active station list on first access.

    Seeds it with the base stations plus any rows stored under the
    ``"added_stations"`` session key. Does nothing if the list already exists.
    """
    if STATIONS_STATE_KEY in st.session_state:
        return
    rows = _base_station_rows()
    rows.extend(
        _normalize_station_row(row, fallback_id=f"added_{i + 1}")
        for i, row in enumerate(st.session_state.get("added_stations", []))
    )
    st.session_state[STATIONS_STATE_KEY] = rows


def get_added_stations_raw() -> list[dict]:
    """Stations added from optimization during the current Streamlit session.

    Returns copies of the active rows whose id is not one of the base station ids.
    """
    _ensure_active_stations()
    base_ids = {row["id"] for row in _base_station_rows()}
    return [
        row.copy()
        for row in st.session_state[STATIONS_STATE_KEY]
        if row["id"] not in base_ids
    ]


def get_active_stations() -> list[Station]:
    """Stations active in the current session."""
    return [Station(**row) for row in get_active_stations_raw()]


def get_active_stations_raw() -> list[dict]:
    """Raw station dicts for pydeck layers, including session edits.

    Each dict is a copy, so callers may mutate it without touching session state.
    """
    _ensure_active_stations()
    return [row.copy() for row in st.session_state[STATIONS_STATE_KEY]]


def set_active_stations_raw(rows: list[dict]) -> None:
    """Replace all active stations for the current session.

    Rows are normalized with ``_normalize_station_row``; cached routing
    results are invalidated on success.

    Raises:
        ValueError: if a row is invalid, ids are not unique, or ``rows`` is empty.
    """
    normalized = [
        _normalize_station_row(row, fallback_id=f"station_{i + 1}")
        for i, row in enumerate(rows)
    ]
    ids = [row["id"] for row in normalized]
    if len(set(ids)) != len(ids):
        raise ValueError("ID станций должны быть уникальными")
    if not normalized:
        raise ValueError("Нужна хотя бы одна станция")
    st.session_state[STATIONS_STATE_KEY] = normalized
    _invalidate_station_results()


def reset_active_stations() -> None:
    """Restore stations from data/stations.json for the current session."""
    st.session_state[STATIONS_STATE_KEY] = _base_station_rows()
    st.session_state.pop("added_stations", None)
    _invalidate_station_results()


def active_stations_signature() -> tuple:
    """Hashable signature for station-dependent caches.

    Coordinates are rounded to 7 decimal places so float noise does not
    change the signature.
    """
    return tuple(
        (s.id, round(float(s.lat), 7), round(float(s.lon), 7), float(s.speed_kmh))
        for s in get_active_stations()
    )


def add_session_station(name: str, lat: float, lon: float, speed_kmh: float, station_id: str) -> bool:
    """Add one station to the current session. Returns False if it already exists.

    Raises:
        ValueError: propagated from ``set_active_stations_raw`` for invalid data.
    """
    rows = get_active_stations_raw()
    ids = {row["id"] for row in rows}
    if station_id in ids:
        return False
    rows.append(
        {
            "id": station_id,
            "name": name,
            "lat": float(lat),
            "lon": float(lon),
            "speed_kmh": float(speed_kmh),
        }
    )
    set_active_stations_raw(rows)
    return True


def risk_scenario_control(cfg: dict, scenarios: dict, label: str = "Сценарий", container=None):
    """Shared risk scenario selector. Stores the selected key in cfg.

    Falls back to the first scenario when the stored key is unknown. Assumes
    ``scenarios`` is non-empty. Each option is labelled with its ``"title"``
    entry if present, otherwise with its key.
    """
    if container is None:
        container = st.sidebar
    options = list(scenarios)
    current = cfg.get("risk_scenario", "summer")
    if current not in scenarios:
        current = options[0]
    cfg["risk_scenario"] = container.selectbox(
        label,
        options,
        index=options.index(current),
        format_func=lambda key: scenarios[key].get("title", key),
    )
    return cfg["risk_scenario"]


def get_results():
    """Return ``(lats, lons, travel_times, min_times, stations)`` for the current settings.

    Results are cached in ``st.session_state["results"]``. They are recomputed
    when the cell size, the configured neighbor level, the neighbor offsets or
    the active stations (see ``active_stations_signature``) differ from the
    cached ones.
    """
    cell_size = st.session_state.get("cell_size", 200)
    neighbor_level = int(get_config_value("neighbor_level"))
    neighbor_offsets = get_neighbor_offsets()
    neighbor_offsets_sig = tuple(neighbor_offsets)
    stations_sig = active_stations_signature()

    cached = st.session_state.get("results")
    if (
        cached is None
        or cached.get("cell_size") != cell_size
        or cached.get("neighbor_level") != neighbor_level
        or cached.get("neighbor_offsets") != neighbor_offsets_sig
        or cached.get("stations_signature") != stations_sig
    ):
        with st.spinner("Расчёт сетки и маршрутов..."):
            results = _compute(cell_size, neighbor_offsets, neighbor_level)
            # Attach the signature before publishing, so session_state never
            # holds a results dict without it.
            results["stations_signature"] = stations_sig
            st.session_state["results"] = results

    r = st.session_state["results"]
    return r["lats"], r["lons"], r["travel_times"], r["min_times"], r["stations"]


def get_risk_distribution() -> IncidentDistribution:
    """Return the configured incident distribution over the current grid.

    The distribution is cached in ``st.session_state["risk_distribution"]``
    and rebuilt when the scenario, the cell size or the number of grid cells
    changes. An unknown scenario in the config is replaced by the first
    available one, and written back to the config.
    """
    lats, lons, _, _, _ = get_results()
    cfg = ensure_config()
    scenarios = load_risk_scenarios()
    scenario = cfg.get("risk_scenario", "summer")
    if scenario not in scenarios:
        scenario = next(iter(scenarios))
        cfg["risk_scenario"] = scenario

    cell_size = st.session_state.get("cell_size", 200)
    cached = st.session_state.get("risk_distribution")
    signature = (scenario, cell_size, len(lats))
    if cached is None or cached.get("signature") != signature:
        with st.spinner("Расчёт модельной плотности происшествий..."):
            dist = IncidentDistribution.from_scenario(
                scenario,
                lats,
                lons,
                scenarios,
                water_polygon=load_water_polygon(),
                shoreline=load_shoreline(),
            )
        st.session_state["risk_distribution"] = {
            "signature": signature,
            "distribution": dist,
        }
    return st.session_state["risk_distribution"]["distribution"]