"""Shared session state: grid + travel times computed once, stored in st.session_state."""
import numpy as np
import streamlit as st
from .data import (
Station,
load_water_polygon, load_stations_raw, classify_cells_by_zone, get_passage_coords,
load_risk_scenarios, load_shoreline,
)
from .config import ensure_config, get_neighbor_offsets, get_config_value
from .grid import generate_grid, snap_to_grid
from .graph import build_graph
from .risk_distribution import IncidentDistribution
from .routing import compute_travel_times
# st.session_state key under which the normalized list of active station dicts
# (keys: id, name, lat, lon, speed_kmh) is stored for the current session.
STATIONS_STATE_KEY = "active_stations"
def _compute(cell_size_m: int, neighbor_offsets: list[tuple[int, int]], neighbor_level: int):
    """Build the grid, routing graph and per-station travel times from scratch.

    Pipeline: grid over the water polygon with ``cell_size_m`` cells, zone
    classification of each cell, graph construction (``neighbor_offsets``,
    passage coordinates with ``passage_radius_m=1000.0``), snapping of every
    active station onto the grid, and routing at each station's speed.

    Returns a dict holding the grid coordinates, the raw ``travel_times``,
    ``min_times`` (minimum of ``travel_times`` over axis 0 — presumably the
    station axis, i.e. best time per cell; confirm against
    ``compute_travel_times``), the ``Station`` objects used, and the
    parameters the result was built from so callers can detect staleness.
    """
    water_polygon = load_water_polygon()
    grid_lats, grid_lons, step_lat, step_lon = generate_grid(water_polygon, cell_size_m=cell_size_m)
    cell_zones = classify_cells_by_zone(grid_lats, grid_lons)
    passages = get_passage_coords()
    routing_graph = build_graph(
        grid_lats,
        grid_lons,
        step_lat,
        step_lon,
        neighbor_offsets=neighbor_offsets,
        cell_zones=cell_zones,
        passage_coords=passages,
        passage_radius_m=1000.0,
    )

    active = get_active_stations()
    start_cells = [snap_to_grid(station.lat, station.lon, grid_lats, grid_lons) for station in active]
    station_speeds = [station.speed_kmh for station in active]
    per_station_times = compute_travel_times(routing_graph, start_cells, station_speeds)
    best_times = np.min(per_station_times, axis=0)

    return dict(
        cell_size=cell_size_m,
        lats=grid_lats,
        lons=grid_lons,
        travel_times=per_station_times,
        min_times=best_times,
        stations=active,
        neighbor_level=neighbor_level,
        neighbor_offsets=tuple(neighbor_offsets),
    )
def sidebar_section(title: str, expanded: bool = True):
    """Return a collapsible sidebar expander used to group related page parameters.

    Args:
        title: header text of the expander.
        expanded: whether the expander starts open.
    """
    sidebar = st.sidebar
    return sidebar.expander(title, expanded=expanded)
def sidebar_controls(container=None):
    """Render the shared cell-size slider and persist its value in session state.

    The slider spans 40–1000 m in steps of 20 and starts at
    ``st.session_state["cell_size"]`` (seeded with 200 on first use). The
    chosen value is written back to that key and returned.

    Args:
        container: Streamlit container to draw into; ``st.sidebar`` if omitted.
    """
    target = st.sidebar if container is None else container
    state = st.session_state
    if "cell_size" not in state:
        state["cell_size"] = 200
    chosen = target.slider(
        "Размер ячейки (м)",
        40,
        1000,
        value=state["cell_size"],
        step=20,
    )
    state["cell_size"] = chosen
    return chosen
def _invalidate_station_results():
    """Forget the cached grid/routing results so ``get_results`` rebuilds them."""
    if "results" in st.session_state:
        del st.session_state["results"]
def _normalize_station_row(row: dict, fallback_id: str) -> dict:
raw_name = row.get("name")
raw_id = row.get("id")
name = "" if raw_name is None else str(raw_name).strip()
station_id = "" if raw_id is None else str(raw_id).strip()
if not name or name.lower() == "nan":
name = fallback_id
if not station_id or station_id.lower() == "nan":
station_id = name or fallback_id
lat = float(row["lat"])
lon = float(row["lon"])
speed_kmh = float(row["speed_kmh"])
if not (np.isfinite(lat) and np.isfinite(lon) and np.isfinite(speed_kmh)):
raise ValueError("Координаты и скорость должны быть конечными числами")
return {
"id": station_id,
"name": name,
"lat": lat,
"lon": lon,
"speed_kmh": speed_kmh,
}
def _base_station_rows() -> list[dict]:
    """Load the base stations and normalize each; unnamed rows become ``station_<n>`` (1-based)."""
    raw_rows = load_stations_raw()
    return [
        _normalize_station_row(entry, fallback_id=f"station_{position}")
        for position, entry in enumerate(raw_rows, start=1)
    ]
def _ensure_active_stations():
    """Seed ``st.session_state[STATIONS_STATE_KEY]`` the first time it is needed.

    The active list starts as the base stations plus any rows stored under the
    ``"added_stations"`` session key (unnamed ones become ``added_<n>``).
    Added rows whose normalized ID is already in the list are skipped, so the
    seeded list obeys the same unique-ID rule ``set_active_stations_raw``
    enforces — otherwise any later edit of the list would be rejected.
    Does nothing when the key already exists.
    """
    if STATIONS_STATE_KEY in st.session_state:
        return
    rows = _base_station_rows()
    seen_ids = {row["id"] for row in rows}
    for i, raw in enumerate(st.session_state.get("added_stations", [])):
        row = _normalize_station_row(raw, fallback_id=f"added_{i + 1}")
        if row["id"] in seen_ids:
            # Same semantics as add_session_station: an existing ID wins.
            continue
        seen_ids.add(row["id"])
        rows.append(row)
    st.session_state[STATIONS_STATE_KEY] = rows
def get_added_stations_raw() -> list[dict]:
    """Return copies of active stations whose ID is not among the base stations.

    Per the original note, these are the stations added (e.g. from
    optimization) during the current Streamlit session.
    """
    _ensure_active_stations()
    base_ids = set()
    for base_row in _base_station_rows():
        base_ids.add(base_row["id"])
    extra = []
    for row in st.session_state[STATIONS_STATE_KEY]:
        if row["id"] in base_ids:
            continue
        extra.append(row.copy())
    return extra
def get_active_stations() -> list[Station]:
    """Return the session's active stations as ``Station`` objects."""
    stations: list[Station] = []
    for fields in get_active_stations_raw():
        stations.append(Station(**fields))
    return stations
def get_active_stations_raw() -> list[dict]:
    """Return shallow copies of the active station dicts (session edits included).

    Callers may mutate the returned dicts without touching session state.
    """
    _ensure_active_stations()
    stored = st.session_state[STATIONS_STATE_KEY]
    return [entry.copy() for entry in stored]
def set_active_stations_raw(rows: list[dict]):
    """Replace the whole active-station list for this session.

    Each row is normalized (unnamed rows get ``station_<n>``, 1-based), the
    list must be non-empty with unique IDs, and cached results are dropped.

    Raises:
        ValueError: on an empty list, duplicate IDs, or an invalid row.
    """
    normalized = []
    for position, entry in enumerate(rows, start=1):
        normalized.append(_normalize_station_row(entry, fallback_id=f"station_{position}"))
    if not normalized:
        raise ValueError("Нужна хотя бы одна станция")
    all_ids = [entry["id"] for entry in normalized]
    if len(all_ids) != len(set(all_ids)):
        raise ValueError("ID станций должны быть уникальными")
    st.session_state[STATIONS_STATE_KEY] = normalized
    _invalidate_station_results()
def reset_active_stations():
    """Restore the base station list for this session and drop session additions.

    Base stations come from ``load_stations_raw`` (data/stations.json per the
    original note). Also clears ``"added_stations"`` and the cached results.
    """
    fresh_rows = _base_station_rows()
    state = st.session_state
    state[STATIONS_STATE_KEY] = fresh_rows
    if "added_stations" in state:
        del state["added_stations"]
    _invalidate_station_results()
def active_stations_signature() -> tuple:
    """Return a hashable fingerprint of the active stations for cache checks.

    One ``(id, lat, lon, speed_kmh)`` tuple per station, with coordinates
    rounded to 7 decimals; names are not part of the fingerprint.
    """
    parts = []
    for station in get_active_stations():
        lat_key = round(float(station.lat), 7)
        lon_key = round(float(station.lon), 7)
        parts.append((station.id, lat_key, lon_key, float(station.speed_kmh)))
    return tuple(parts)
def add_session_station(name: str, lat: float, lon: float, speed_kmh: float, station_id: str) -> bool:
    """Append one station to the session's active list.

    The candidate is normalized the same way ``set_active_stations_raw``
    will store it, so the duplicate check compares the ID that would really
    be saved. For example, ``" s1 "`` matches ``"s1"``, a non-string ID
    matches its ``str()`` form, and an empty ID matches via its fallback name.

    Returns:
        ``False`` (nothing changed) if a station with that ID already exists,
        ``True`` once the station has been added.

    Raises:
        Whatever ``_normalize_station_row`` raises for invalid coordinates or
        speed (``ValueError`` for non-finite values).
    """
    rows = get_active_stations_raw()
    candidate = _normalize_station_row(
        {
            "id": station_id,
            "name": name,
            "lat": lat,
            "lon": lon,
            "speed_kmh": speed_kmh,
        },
        # The same fallback set_active_stations_raw gives the appended row.
        fallback_id=f"station_{len(rows) + 1}",
    )
    if any(row["id"] == candidate["id"] for row in rows):
        return False
    rows.append(candidate)
    set_active_stations_raw(rows)
    return True
def risk_scenario_control(cfg: dict, scenarios: dict, label: str = "Сценарий", container=None):
    """Render the shared risk-scenario selectbox and store the pick in ``cfg``.

    The preselected option is ``cfg["risk_scenario"]`` (default ``"summer"``),
    or the first scenario if that key is unknown. Options are displayed by
    their ``"title"`` entry, falling back to the key itself.

    Args:
        cfg: mutable config dict; ``"risk_scenario"`` is overwritten.
        scenarios: mapping of scenario key -> scenario dict.
        label: selectbox label.
        container: Streamlit container to draw into; ``st.sidebar`` if omitted.

    Returns:
        The selected scenario key.
    """
    target = st.sidebar if container is None else container
    keys = list(scenarios)
    chosen = cfg.get("risk_scenario", "summer")
    if chosen not in scenarios:
        chosen = keys[0]

    def _title(key):
        return scenarios[key].get("title", key)

    cfg["risk_scenario"] = target.selectbox(
        label,
        keys,
        index=keys.index(chosen),
        format_func=_title,
    )
    return cfg["risk_scenario"]
def get_results():
    """Return ``(lats, lons, travel_times, min_times, stations)`` for the current settings.

    The expensive grid/graph/routing computation is cached in
    ``st.session_state["results"]``. It is redone when any of these differs
    from what the cached entry was built with: the cell size (session state,
    default 200), the configured neighbor level, the neighbor offsets, or the
    active-station signature.
    """
    cell_size = st.session_state.get("cell_size", 200)
    neighbor_level = int(get_config_value("neighbor_level"))
    neighbor_offsets = get_neighbor_offsets()
    offsets_key = tuple(neighbor_offsets)
    stations_key = active_stations_signature()

    cached = st.session_state.get("results")
    needs_rebuild = cached is None
    if not needs_rebuild:
        comparisons = (
            (cached["cell_size"], cell_size),
            (cached.get("neighbor_level"), neighbor_level),
            (cached.get("neighbor_offsets"), offsets_key),
            (cached.get("stations_signature"), stations_key),
        )
        needs_rebuild = any(old != new for old, new in comparisons)

    if needs_rebuild:
        with st.spinner("Расчёт сетки и маршрутов..."):
            fresh = _compute(cell_size, neighbor_offsets, neighbor_level)
            fresh["stations_signature"] = stations_key
            st.session_state["results"] = fresh

    results = st.session_state["results"]
    return (
        results["lats"],
        results["lons"],
        results["travel_times"],
        results["min_times"],
        results["stations"],
    )
def get_risk_distribution() -> IncidentDistribution:
    """Return the incident distribution for the configured scenario on the current grid.

    The scenario is ``cfg["risk_scenario"]`` (default ``"summer"``). If that
    key is not among the loaded scenarios, the first one is used and written
    back into the config. The distribution is cached in
    ``st.session_state["risk_distribution"]``, keyed on (scenario, cell size,
    number of grid cells), and rebuilt whenever that key changes.
    """
    grid_lats, grid_lons, _, _, _ = get_results()
    cfg = ensure_config()
    scenarios = load_risk_scenarios()

    requested = cfg.get("risk_scenario", "summer")
    if requested in scenarios:
        scenario = requested
    else:
        # Unknown or missing choice: fall back to the first scenario and persist it.
        scenario = next(iter(scenarios))
        cfg["risk_scenario"] = scenario

    cache_key = (scenario, st.session_state.get("cell_size", 200), len(grid_lats))
    entry = st.session_state.get("risk_distribution")
    if entry is None or entry.get("signature") != cache_key:
        with st.spinner("Расчёт модельной плотности происшествий..."):
            water = load_water_polygon()
            shoreline = load_shoreline()
            distribution = IncidentDistribution.from_scenario(
                scenario,
                grid_lats,
                grid_lons,
                scenarios,
                water_polygon=water,
                shoreline=shoreline,
            )
            st.session_state["risk_distribution"] = {
                "signature": cache_key,
                "distribution": distribution,
            }
    return st.session_state["risk_distribution"]["distribution"]
|