Spaces:
Build error
Build error
| from forecast.page_config import APP_PAGE_HEADER | |
| import streamlit as st | |
| import pandas as pd | |
| APP_PAGE_HEADER() | |
| class InputData: | |
| def get_data(cls) -> pd.DataFrame: | |
| """ | |
| Datasets sources: | |
| avg_daily_air_temp_celsius_helsinki: http://shorturl.at/gBR06 | |
| Returns: | |
| """ | |
| st1, st2 = st.columns(2) | |
| sample = st1.selectbox( | |
| "Sample datasets", | |
| options=["", "sample1", "sample2", "avg_daily_air_temp_celsius_helsinki"], | |
| ) | |
| if sample: | |
| file_ = f"data/{sample}.csv" | |
| return pd.read_csv(file_) | |
| uploaded_data = cls.read_file(st2) | |
| if uploaded_data is not None: | |
| return uploaded_data | |
| def read_file(cls, st_): | |
| file_ = st_.file_uploader("Upload your dataset (csv file)") | |
| if not file_: | |
| st.stop() | |
| if file_: | |
| sep = st.selectbox("column sep", options=[",", ";", "|"]) | |
| df = pd.read_csv(file_, sep=sep) | |
| cols = df.columns.tolist() | |
| # -- choose date/target columns | |
| st1, st3, st2 = st.columns(3) | |
| date_col = st1.selectbox( | |
| "Date column (x-axis / index)", options=[""] + cols | |
| ) | |
| st.session_state.date_formatter = st2.text_input( | |
| "Optional: Date format e.g. %Y-%m-%d" | |
| ) | |
| target = st3.selectbox( | |
| "Target column (y-axis / target variable)", options=[""] + cols | |
| ) | |
| # -- display | |
| if date_col and target: | |
| df = df[[date_col, target]] | |
| return df | |
| st.write(df) | |
| st.stop() | |
| return file_ | |
| def preprocess_data(cls, df: pd.DataFrame) -> pd.DataFrame: | |
| df.columns = ["ds", "y"] | |
| # -- date column | |
| try: | |
| date_formatter = ( | |
| st.session_state.date_formatter | |
| if "date_formatter" in st.session_state | |
| else None | |
| ) | |
| if date_formatter: | |
| st.write("hello") | |
| st.stop() | |
| df["ds"] = pd.to_datetime( | |
| df["ds"], format=st.session_state.date_formatter | |
| ) | |
| else: | |
| df["ds"] = pd.to_datetime(df["ds"]).dt.date | |
| except: | |
| st.error("Date column is not in correct format") | |
| st.write(df["ds"]) | |
| st.stop() | |
| # -- target column | |
| df["y"] = df["y"].apply(lambda x: float(str(x).replace(",", ""))) | |
| df["y"] = df["y"].astype(float) | |
| return df | |
| class PredictionApp: | |
| def run_prediction(df: pd.DataFrame) -> pd.DataFrame: | |
| # -- prepare data and user input | |
| st1, st2 = st.columns(2) | |
| st1.write(df) | |
| st2.line_chart(df.set_index("ds")) | |
| segmented_df = PredictionApp.split_df_by_date(df) | |
| # -- future date picker | |
| n = segmented_df.shape[0] | |
| future = st.slider("Number of days to predict", 7, n * 2, value=int(n / 2)) | |
| params = PredictionApp.user_input_model_params() | |
| run = st.button("Run") | |
| PredictionApp.display_prophet_docs() | |
| if not run: | |
| return | |
| # -- run prediction | |
| with st.spinner("running prediction engine .."): | |
| from forecast.fbprophet.model import ProphetModel | |
| model = ProphetModel() | |
| pred = model.predict(segmented_df, period=future, **params) | |
| return pred | |
| def split_df_by_date(df: pd.DataFrame) -> pd.DataFrame: | |
| # -- split dataframe by date | |
| st.caption("Choose the target fitting period") | |
| st1, st2 = st.columns(2) | |
| from_ = st1.date_input( | |
| "from", min_value=df.ds.min(), max_value=df.ds.max(), value=df.ds.min() | |
| ) | |
| to_ = st2.date_input( | |
| "to", min_value=df.ds.min(), max_value=df.ds.max(), value=df.ds.max() | |
| ) | |
| ix1 = df.index[df.ds == from_][0] | |
| ix2 = df.index[df.ds == to_][0] | |
| new_df = PredictionApp._displayed_segmented_dataframe(df, from_=ix1, to_=ix2) | |
| new_df.reset_index(inplace=True) | |
| st1.write(f"{new_df['ds'].min()}") | |
| st2.write(f"{new_df['ds'].max()}") | |
| return new_df | |
| def _displayed_segmented_dataframe( | |
| cls, df: pd.DataFrame, from_: int, to_: int | |
| ) -> pd.DataFrame: | |
| df = df.set_index("ds") | |
| df_ = df[from_ : to_ + 1] | |
| st.line_chart(df_) | |
| return df_ | |
| def user_input_model_params(cls): | |
| raw_params = st.text_input( | |
| "Model params. Type param name and its value e.g. growth=logistic" | |
| ) | |
| if raw_params: | |
| in_params = [x.strip() for x in raw_params.split(",")] | |
| params = {} | |
| for param in in_params: | |
| k, v = param.split("=") | |
| params[k] = float(v) if v.isdigit() else v | |
| if "growth" in params and params["growth"] == "logistic": | |
| cap = st.text_input("cap") | |
| if cap: | |
| params["cap"] = float(cap) | |
| else: | |
| st.warning("Cap is required for logistic growth") | |
| st.stop() | |
| st.write("Your input params:") | |
| st.write(params) | |
| return params | |
| return {} | |
| def display_prophet_docs(cls): | |
| from prophet import Prophet | |
| with st.expander("View model params"): | |
| st.write(Prophet.__doc__) | |
| st.markdown( | |
| "> more details: [visit](https://facebook.github.io/prophet/)", | |
| unsafe_allow_html=True, | |
| ) | |
| def app(): | |
| data = InputData.get_data() | |
| data = InputData.preprocess_data(data) | |
| PredictionApp.run_prediction(data) | |
| if __name__ == "__main__": | |
| app() | |