File size: 8,499 Bytes
aec8fff
23cf709
aec8fff
 
23cf709
aec8fff
19dc079
7ed4d70
23cf709
95805de
23cf709
 
 
 
95805de
 
5af8aff
aec8fff
23cf709
aec8fff
 
23cf709
 
 
 
 
 
 
 
 
 
 
 
aec8fff
 
 
 
 
 
 
 
 
 
7ed4d70
 
23cf709
 
 
 
 
 
 
7ed4d70
 
 
23cf709
 
 
 
 
7ed4d70
 
 
 
5af8aff
 
 
 
 
 
 
 
 
 
 
 
 
 
23cf709
7ed4d70
 
 
 
aec8fff
 
23cf709
5af8aff
23cf709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5af8aff
 
 
23cf709
 
 
 
 
 
 
 
 
 
 
 
 
 
5af8aff
 
 
 
23cf709
19dc079
23cf709
 
 
 
 
aec8fff
7ed4d70
 
23cf709
 
7ed4d70
5af8aff
 
 
 
 
 
 
 
 
 
bc24113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aec8fff
 
 
0bc6627
 
 
 
 
19dc079
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import calendar
from datetime import datetime

import streamlit as st
from streamlit.errors import StreamlitAPIException

from admin import AuthManager, login
from components import render_dataset_metadata, render_records_by_year
from config import AppConfig
from dashboard_analytics import log_visit
from utils.data_loading import get_dataset_metadata
from utils.session import ensure_session_initialized

# Page bootstrap: make sure session state is initialised, then log the visit.
ensure_session_initialized()
log_visit("Settings")

# Two page columns with relative widths 2:1.
col1, col2 = st.columns([2, 1])

# Read the currently active filters out of session state.
data_manager = st.session_state.data_manager
current_start = st.session_state.get("start_date")
current_end = st.session_state.get("end_date")
reporting_month: int = st.session_state.get(
    "reporting_month", AppConfig.DEFAULT_REPORTING_MONTH
)

# Load the data for those filters and keep the result in session state.
st.session_state.data = data_manager.load_data(
    start_date=current_start,
    end_date=current_end,
    reporting_month=reporting_month,
)

# `full_metadata` comes from the data manager; `metadata` is derived from the
# raw frame that was just loaded.
full_metadata = data_manager.metadata
metadata = get_dataset_metadata(
    st.session_state.data["raw_df"], reporting_month=reporting_month
)

# Date range reported by the data manager's metadata.
date_range = full_metadata["date_range"]
raw_start = date_range["start"]
raw_end = date_range["end"]

# Widen the range to whole months: first day of the start month through the
# last day of the end month.
min_date = raw_start.replace(day=1)
_, last_day_of_end_month = calendar.monthrange(raw_end.year, raw_end.month)
max_date = raw_end.replace(day=last_day_of_end_month)


def on_date_change():
    """Validate the date-picker values and commit them to session state.

    Reads ``dataset_start_date`` and ``dataset_end_date`` from
    ``st.session_state``. Only when both are present, correctly ordered and
    inside ``[min_date, max_date]`` are they copied to ``start_date`` and
    ``end_date``. On any validation failure an ``st.error`` is emitted and the
    previously committed ``start_date`` / ``end_date`` are left unchanged.

    Returns:
        None
    """

    def _as_date(value):
        # Python refuses to order a ``datetime`` against a plain ``date``
        # (TypeError). NOTE(review): the type of the metadata bounds is not
        # visible here; reducing every ``datetime`` (pandas.Timestamp is a
        # subclass) to its calendar date keeps the comparisons below safe
        # whichever type they turn out to be.
        return value.date() if isinstance(value, datetime) else value

    # Nothing to validate until both widget keys exist in session state
    if (
        "dataset_start_date" not in st.session_state
        or "dataset_end_date" not in st.session_state
    ):
        return

    start = st.session_state.dataset_start_date
    end = st.session_state.dataset_end_date

    # A missing date keeps the previously committed values
    if start is None or end is None:
        st.error("Both start and end dates must be selected")
        return

    start_day, end_day = _as_date(start), _as_date(end)
    lower_bound, upper_bound = _as_date(min_date), _as_date(max_date)

    if start_day > end_day:
        st.error("Start date must be before end date")
        return

    if start_day < lower_bound or end_day > upper_bound:
        earliest = min_date.strftime("%m/%d/%Y")
        latest = max_date.strftime("%m/%d/%Y")
        st.error(f"Dates must be between {earliest} and {latest}")
        # Remove the out-of-range widget value(s) from session state
        # (presumably so the input falls back to its default on the next
        # run - confirm against Streamlit's widget-state behaviour).
        if start_day < lower_bound:
            del st.session_state["dataset_start_date"]
        if end_day > upper_bound:
            del st.session_state["dataset_end_date"]
        # Toggle a session-state flag; the original intent was to force a UI update
        st.session_state["force_refresh"] = not st.session_state.get(
            "force_refresh", False
        )
        return

    # Only commit once both dates passed every check (stored as received)
    st.session_state.start_date = start
    st.session_state.end_date = end


with col1:
    st.subheader("Reporting Period")
    # Three bottom-aligned columns for the date filter controls
    filter_col1, filter_col2, filter_col3 = st.columns(3, vertical_alignment="bottom")

    # Prefer the committed dates; a missing or empty value falls back to the
    # dataset bounds
    current_start = st.session_state.get("start_date") or min_date
    current_end = st.session_state.get("end_date") or max_date

    def _date_input_or_fallback(label, fallback, widget_key):
        """Render a bounded date input and return its value.

        If Streamlit raises ``StreamlitAPIException`` while creating the
        widget, an error naming the allowed range is shown and *fallback* is
        returned instead.
        """
        try:
            return st.date_input(
                label,
                value=fallback,
                min_value=min_date,
                max_value=max_date,
                format="MM/DD/YYYY",
                key=widget_key,
                on_change=on_date_change,
            )
        except StreamlitAPIException:
            st.error(
                f"Date must be between {min_date.strftime('%m/%d/%Y')} and {max_date.strftime('%m/%d/%Y')}"
            )
            return fallback

    with filter_col1:
        start_date = _date_input_or_fallback(
            "Start Date", current_start, "dataset_start_date"
        )

    with filter_col2:
        end_date = _date_input_or_fallback(
            "End Date", current_end, "dataset_end_date"
        )

    # Use the config stored in session state, or build one from the environment
    config = st.session_state.get("config")
    if not config:
        config = AppConfig.from_env()

    # Starting month: the committed reporting month, else the configured default
    initial_reporting_month = st.session_state.get(
        "reporting_month", config.DEFAULT_REPORTING_MONTH
    )
    # Value currently held under the selectbox key, if any
    initial_dataset_month = st.session_state.get(
        "dataset_reporting_month", initial_reporting_month
    )

    def on_reporting_month_change():
        """Copy the month selectbox value into ``reporting_month``."""
        # Leave the committed month alone when the widget key is absent
        if "dataset_reporting_month" not in st.session_state:
            return
        selected_month = st.session_state.dataset_reporting_month
        st.session_state.reporting_month = selected_month

    def _month_label(month_number):
        """Return the full month name for a 1-12 month number."""
        # Year and day are arbitrary; only the month is formatted
        return datetime(2000, month_number, 1).strftime("%B")

    # The selector occupies the first of two columns (relative widths 1:2)
    filter_row2_col1, _ = st.columns([1, 2])
    with filter_row2_col1:
        reporting_month = st.selectbox(
            "Reporting Year End Month",
            options=range(1, 13),
            format_func=_month_label,
            # Options are 1-based months, the index is 0-based
            index=initial_reporting_month - 1,
            key="dataset_reporting_month",
            on_change=on_reporting_month_change,
        )

    st.subheader("Data Exclusions")
    exclusion_col1, exclusion_col2 = st.columns(2)

    def _reload_filtered_data():
        """Reload the dataset for the filters currently in session state.

        Stores the result of ``data_manager.load_data`` in
        ``st.session_state.data``.
        """
        st.session_state.data = st.session_state.data_manager.load_data(
            start_date=st.session_state.get("start_date"),
            end_date=st.session_state.get("end_date"),
            # Use the same default as the page's first load_data call instead
            # of passing None when no reporting month has been stored yet
            reporting_month=st.session_state.get(
                "reporting_month", AppConfig.DEFAULT_REPORTING_MONTH
            ),
        )

    def on_sector_exclusion_change():
        """Callback for sector exclusion changes.

        Copies the multiselect value into ``persistent_excluded_sectors`` and
        reloads the data. Does nothing when the widget key is absent.
        """
        if "excluded_sectors_widget" not in st.session_state:
            return
        # Update the persistent storage with widget values
        st.session_state.persistent_excluded_sectors = (
            st.session_state.excluded_sectors_widget
        )
        # Reload data with new exclusions (presumably load_data reads the
        # persistent lists from session state - confirm in the data manager)
        _reload_filtered_data()

    def on_station_exclusion_change():
        """Callback for station exclusion changes.

        Copies the multiselect value into ``persistent_excluded_stations`` and
        reloads the data. Does nothing when the widget key is absent.
        """
        if "excluded_stations_widget" not in st.session_state:
            return
        # Update the persistent storage with widget values
        st.session_state.persistent_excluded_stations = (
            st.session_state.excluded_stations_widget
        )
        # Reload data with new exclusions (see the hedge in the sector callback)
        _reload_filtered_data()

    # Initialize persistent storage if not exists
    if "persistent_excluded_sectors" not in st.session_state:
        st.session_state.persistent_excluded_sectors = []
    if "persistent_excluded_stations" not in st.session_state:
        st.session_state.persistent_excluded_stations = []

    # Reload data if there are any exclusions
    if st.session_state.get("persistent_excluded_sectors") or st.session_state.get(
        "persistent_excluded_stations"
    ):
        _reload_filtered_data()

    with exclusion_col1:
        # Options are the data manager's complete sector list
        st.multiselect(
            "Exclude Sectors",
            options=st.session_state.data_manager.all_sectors,
            default=st.session_state.persistent_excluded_sectors,
            help="Select sectors to exclude from all analyses",
            key="excluded_sectors_widget",
            on_change=on_sector_exclusion_change,
        )

    with exclusion_col2:
        # Options are the data manager's complete station list
        st.multiselect(
            "Exclude Stations",
            options=st.session_state.data_manager.all_stations,
            default=st.session_state.persistent_excluded_stations,
            help="Select stations to exclude from all analyses",
            key="excluded_stations_widget",
            on_change=on_station_exclusion_change,
        )

    # Dataset summary, then the per-year record view for the loaded raw frame
    render_dataset_metadata(metadata, min_date, max_date)
    loaded_raw_df = st.session_state.data["raw_df"]
    render_records_by_year(loaded_raw_df, reporting_month)

    # Admin login form inside an expander, limited to the left half of the row
    login_col1, _ = st.columns([1, 1])
    with login_col1, st.expander("Admin Login"):
        login(AuthManager(config))