| """Shared session state: grid + travel times computed once, stored in st.session_state.""" |
|
|
| import numpy as np |
| import streamlit as st |
|
|
| from .data import ( |
| Station, |
| load_water_polygon, load_stations_raw, classify_cells_by_zone, get_passage_coords, |
| load_risk_scenarios, load_shoreline, |
| ) |
| from .config import ensure_config, get_neighbor_offsets, get_config_value |
| from .grid import generate_grid, snap_to_grid |
| from .graph import build_graph |
| from .risk_distribution import IncidentDistribution |
| from .routing import compute_travel_times |
|
|
# Key in st.session_state under which the list of active station rows
# (normalized dicts) is stored for the current session.
STATIONS_STATE_KEY = "active_stations"
|
|
|
|
def _compute(cell_size_m: int, neighbor_offsets: list[tuple[int, int]], neighbor_level: int):
    """Build the grid and routing graph, then compute travel times.

    Args:
        cell_size_m: Grid cell size passed to ``generate_grid``.
        neighbor_offsets: Neighbor offsets passed to ``build_graph``.
        neighbor_level: Stored in the result unchanged; not used in the
            computation itself.

    Returns:
        A dict with the grid coordinates (``lats``, ``lons``), the
        ``travel_times`` returned by ``compute_travel_times``, ``min_times``
        (the minimum of ``travel_times`` along axis 0), the ``stations``
        used, and the inputs the result was computed for (``cell_size``,
        ``neighbor_level``, ``neighbor_offsets`` as a tuple).
    """
    water_polygon = load_water_polygon()
    lats, lons, dlat, dlon = generate_grid(water_polygon, cell_size_m=cell_size_m)

    cell_zones = classify_cells_by_zone(lats, lons)
    passages = get_passage_coords()
    graph = build_graph(
        lats, lons, dlat, dlon,
        neighbor_offsets=neighbor_offsets,
        cell_zones=cell_zones,
        passage_coords=passages,
        passage_radius_m=1000.0,
    )

    # Each active station becomes one routing source with its own speed.
    stations = get_active_stations()
    sources = [
        snap_to_grid(station.lat, station.lon, lats, lons)
        for station in stations
    ]
    speeds = [station.speed_kmh for station in stations]
    travel_times = compute_travel_times(graph, sources, speeds)

    result = {"cell_size": cell_size_m, "lats": lats, "lons": lons}
    result["travel_times"] = travel_times
    result["min_times"] = np.min(travel_times, axis=0)
    result["stations"] = stations
    result["neighbor_level"] = neighbor_level
    result["neighbor_offsets"] = tuple(neighbor_offsets)
    return result
|
|
|
|
def sidebar_section(title: str, expanded: bool = True):
    """Return a titled sidebar expander used to group related page parameters.

    Args:
        title: Label shown on the expander.
        expanded: Forwarded to the expander as its ``expanded`` argument.
    """
    sidebar = st.sidebar
    section = sidebar.expander(title, expanded=expanded)
    return section
|
|
|
|
def sidebar_controls(container=None):
    """Render the shared cell-size slider and return the chosen value.

    The value is kept in ``st.session_state["cell_size"]`` (seeded with 200
    when absent) and written back after every render.

    Args:
        container: Object providing ``slider``; defaults to ``st.sidebar``.

    Returns:
        The value returned by the slider.
    """
    target = container if container is not None else st.sidebar

    if "cell_size" not in st.session_state:
        st.session_state["cell_size"] = 200
    current = st.session_state["cell_size"]

    chosen = target.slider("Размер ячейки (м)", 40, 1000, value=current, step=20)
    st.session_state["cell_size"] = chosen
    return chosen
|
|
|
|
def _invalidate_station_results():
    """Drop the cached ``"results"`` entry from the session state, if any."""
    state = st.session_state
    state.pop("results", None)
|
|
|
|
def _normalize_station_row(row: dict, fallback_id: str) -> dict:
    """Return a cleaned copy of a station row with exactly five keys.

    ``name`` and ``id`` are stringified and stripped; a missing, empty or
    ``"nan"`` name falls back to ``fallback_id``, and a missing, empty or
    ``"nan"`` id falls back to the resolved name. ``lat``, ``lon`` and
    ``speed_kmh`` are converted with ``float``.

    Args:
        row: Mapping with required ``lat``, ``lon``, ``speed_kmh`` keys and
            optional ``name`` / ``id`` keys.
        fallback_id: Value used when the row carries no usable name.

    Returns:
        A new dict with keys ``id``, ``name``, ``lat``, ``lon``, ``speed_kmh``.

    Raises:
        ValueError: If any of the three numeric fields is not finite.
    """

    def _as_text(value) -> str:
        # None means "absent"; everything else is stringified and trimmed.
        return "" if value is None else str(value).strip()

    def _is_blank(text: str) -> bool:
        # "nan" is what a stringified NaN looks like; treat it as absent too.
        return not text or text.lower() == "nan"

    name = _as_text(row.get("name"))
    station_id = _as_text(row.get("id"))
    if _is_blank(name):
        name = fallback_id
    if _is_blank(station_id):
        station_id = name or fallback_id

    lat, lon, speed_kmh = (float(row[key]) for key in ("lat", "lon", "speed_kmh"))
    if not all(np.isfinite(number) for number in (lat, lon, speed_kmh)):
        raise ValueError("Координаты и скорость должны быть конечными числами")

    return dict(id=station_id, name=name, lat=lat, lon=lon, speed_kmh=speed_kmh)
|
|
|
|
def _base_station_rows() -> list[dict]:
    """Return the rows from ``load_stations_raw()`` in normalized form.

    Rows without a usable name get the fallback ``station_<n>``, where ``n``
    is the row's 1-based position.
    """
    normalized = []
    for position, raw in enumerate(load_stations_raw(), start=1):
        normalized.append(
            _normalize_station_row(raw, fallback_id=f"station_{position}")
        )
    return normalized
|
|
|
|
def _ensure_active_stations():
    """Seed the active-station list in the session state on first use.

    Does nothing when the list already exists. Otherwise the list is built
    from the base rows followed by any rows found under the
    ``"added_stations"`` session key (fallback names ``added_<n>``).
    """
    state = st.session_state
    if STATIONS_STATE_KEY not in state:
        seeded = _base_station_rows()
        extras = state.get("added_stations", [])
        for position, raw in enumerate(extras, start=1):
            seeded.append(
                _normalize_station_row(raw, fallback_id=f"added_{position}")
            )
        state[STATIONS_STATE_KEY] = seeded
|
|
|
|
def get_added_stations_raw() -> list[dict]:
    """Return copies of the active station rows whose id is not a base id.

    "Base" ids are those of the rows produced by ``_base_station_rows()``;
    everything else in the active list counts as added in this session.
    """
    _ensure_active_stations()

    base_ids = set()
    for base_row in _base_station_rows():
        base_ids.add(base_row["id"])

    added = []
    for row in st.session_state[STATIONS_STATE_KEY]:
        if row["id"] in base_ids:
            continue
        added.append(row.copy())
    return added
|
|
|
|
def get_active_stations() -> list[Station]:
    """Return the session's active stations as ``Station`` objects.

    Each raw row is expanded into keyword arguments of ``Station``.
    """
    stations = []
    for fields in get_active_stations_raw():
        stations.append(Station(**fields))
    return stations
|
|
|
|
def get_active_stations_raw() -> list[dict]:
    """Return copies of the active station rows, including session edits.

    The rows are copied so callers can modify them without touching the
    list kept in the session state.
    """
    _ensure_active_stations()
    stored_rows = st.session_state[STATIONS_STATE_KEY]

    copies = []
    for stored in stored_rows:
        copies.append(stored.copy())
    return copies
|
|
|
|
def set_active_stations_raw(rows: list[dict]):
    """Replace the session's active stations with ``rows``.

    Every row is normalized first (fallback names ``station_<n>``), then the
    set is validated, stored, and the cached results are invalidated.

    Raises:
        ValueError: If a row fails normalization, if two rows share an id,
            or if ``rows`` is empty.
    """
    cleaned = []
    for position, raw in enumerate(rows, start=1):
        cleaned.append(_normalize_station_row(raw, fallback_id=f"station_{position}"))

    # The two checks cannot both fail (an empty list has no duplicates),
    # so their relative order does not affect which error is reported.
    if not cleaned:
        raise ValueError("Нужна хотя бы одна станция")
    distinct_ids = {entry["id"] for entry in cleaned}
    if len(distinct_ids) != len(cleaned):
        raise ValueError("ID станций должны быть уникальными")

    st.session_state[STATIONS_STATE_KEY] = cleaned
    _invalidate_station_results()
|
|
|
|
def reset_active_stations():
    """Restore the base station list for the current session.

    The active list is replaced by the normalized base rows, the
    ``"added_stations"`` session entry is removed, and cached results are
    invalidated.
    """
    state = st.session_state
    defaults = _base_station_rows()
    state[STATIONS_STATE_KEY] = defaults
    state.pop("added_stations", None)
    _invalidate_station_results()
|
|
|
|
def active_stations_signature() -> tuple:
    """Return a hashable fingerprint of the active stations.

    One ``(id, lat, lon, speed_kmh)`` entry per station, with coordinates
    rounded to 7 decimal places. Intended as a key for caches that depend
    on the station set.
    """
    entries = []
    for station in get_active_stations():
        lat = round(float(station.lat), 7)
        lon = round(float(station.lon), 7)
        speed = float(station.speed_kmh)
        entries.append((station.id, lat, lon, speed))
    return tuple(entries)
|
|
|
|
def add_session_station(name: str, lat: float, lon: float, speed_kmh: float, station_id: str) -> bool:
    """Add one station to the current session.

    The duplicate check is done on both the id as given and the id after
    normalization. ``set_active_stations_raw`` normalizes ids (stringify,
    strip, fall back to the name) before enforcing uniqueness, so checking
    only the raw id would let e.g. ``" s1 "`` through and then fail with a
    ``ValueError`` instead of reporting "already exists".

    Args:
        name: Display name of the station.
        lat: Latitude; converted with ``float``.
        lon: Longitude; converted with ``float``.
        speed_kmh: Station speed; converted with ``float``.
        station_id: Identifier of the new station.

    Returns:
        ``True`` if the station was added, ``False`` if a station with the
        same (raw or normalized) id is already active.

    Raises:
        ValueError: If a coordinate or the speed is not a finite number, or
            if the already-active rows fail validation in
            ``set_active_stations_raw``.
    """
    rows = get_active_stations_raw()
    existing_ids = {row["id"] for row in rows}

    # Fast path, identical to the historical behavior: an exact id match is
    # reported before any validation of the numeric fields.
    if station_id in existing_ids:
        return False

    # Use the same fallback the row would get inside set_active_stations_raw
    # (its 1-based position in the extended list).
    candidate = _normalize_station_row(
        {
            "id": station_id,
            "name": name,
            "lat": float(lat),
            "lon": float(lon),
            "speed_kmh": float(speed_kmh),
        },
        fallback_id=f"station_{len(rows) + 1}",
    )
    if candidate["id"] in existing_ids:
        return False

    rows.append(candidate)
    set_active_stations_raw(rows)
    return True
|
|
|
|
def risk_scenario_control(cfg: dict, scenarios: dict, label: str = "Сценарий", container=None):
    """Render the shared risk-scenario selector and record the choice in ``cfg``.

    The preselected option is ``cfg["risk_scenario"]`` (default ``"summer"``);
    when that key is not among ``scenarios`` the first scenario is used.
    Options are displayed by their ``"title"`` entry, falling back to the key.

    Args:
        cfg: Config mapping; its ``"risk_scenario"`` entry is overwritten.
        scenarios: Mapping of scenario key to a scenario dict.
        label: Label of the select box.
        container: Object providing ``selectbox``; defaults to ``st.sidebar``.

    Returns:
        The selected scenario key.
    """
    target = st.sidebar if container is None else container

    keys = list(scenarios)
    preselected = cfg.get("risk_scenario", "summer")
    if preselected not in scenarios:
        preselected = keys[0]
    position = keys.index(preselected)

    choice = target.selectbox(
        label,
        keys,
        index=position,
        format_func=lambda key: scenarios[key].get("title", key),
    )
    cfg["risk_scenario"] = choice
    return choice
|
|
|
|
def get_results():
    """Return the cached grid and travel times, recomputing when stale.

    The cached entry in ``st.session_state["results"]`` is reused only when
    it was computed for the current cell size, neighbor level, neighbor
    offsets and active-station signature; otherwise ``_compute`` is run
    again and the entry is replaced.

    Returns:
        A tuple ``(lats, lons, travel_times, min_times, stations)``.
    """
    cell_size = st.session_state.get("cell_size", 200)
    neighbor_level = int(get_config_value("neighbor_level"))
    offsets = get_neighbor_offsets()
    offsets_key = tuple(offsets)
    stations_key = active_stations_signature()

    cached = st.session_state.get("results")
    stale = (
        cached is None
        or cached["cell_size"] != cell_size
        or cached.get("neighbor_level") != neighbor_level
        or cached.get("neighbor_offsets") != offsets_key
        or cached.get("stations_signature") != stations_key
    )
    if stale:
        with st.spinner("Расчёт сетки и маршрутов..."):
            st.session_state["results"] = _compute(cell_size, offsets, neighbor_level)
            # _compute does not know the signature; attach it for the next check.
            st.session_state["results"]["stations_signature"] = stations_key

    results = st.session_state["results"]
    return (
        results["lats"],
        results["lons"],
        results["travel_times"],
        results["min_times"],
        results["stations"],
    )
|
|
|
|
def get_risk_distribution() -> IncidentDistribution:
    """Return the incident distribution for the configured scenario and grid.

    The scenario key comes from the config (default ``"summer"``); an unknown
    key is replaced by the first available scenario and written back to the
    config. The distribution is cached in
    ``st.session_state["risk_distribution"]`` and rebuilt only when the
    scenario, the cell size, or the number of grid cells changes.
    """
    grid = get_results()
    lats, lons = grid[0], grid[1]

    cfg = ensure_config()
    scenarios = load_risk_scenarios()
    scenario = cfg.get("risk_scenario", "summer")
    if scenario not in scenarios:
        scenario = next(iter(scenarios))
        cfg["risk_scenario"] = scenario

    cell_size = st.session_state.get("cell_size", 200)
    signature = (scenario, cell_size, len(lats))
    cached = st.session_state.get("risk_distribution")

    up_to_date = cached is not None and not (cached.get("signature") != signature)
    if not up_to_date:
        with st.spinner("Расчёт модельной плотности происшествий..."):
            distribution = IncidentDistribution.from_scenario(
                scenario,
                lats,
                lons,
                scenarios,
                water_polygon=load_water_polygon(),
                shoreline=load_shoreline(),
            )
            entry = {"signature": signature, "distribution": distribution}
            st.session_state["risk_distribution"] = entry

    return st.session_state["risk_distribution"]["distribution"]
|
|