# Streamlit page: Data Assessment (data validation and insights).
# NOTE: the original export carried extraction residue ("Spaces:"/"Build error")
# above the imports; replaced with this header comment.
| import streamlit as st | |
| import pandas as pd | |
| from data_analysis import * | |
| import numpy as np | |
| import pickle | |
| import streamlit as st | |
| from utilities import set_header, load_local_css, update_db, project_selection | |
| from post_gres_cred import db_cred | |
| from utilities import update_db | |
| import re | |
# Page chrome: wide layout, sidebar collapsed by default.
st.set_page_config(
    page_title="Data Assessment",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

schema = db_cred["schema"]
load_local_css("styles.css")
set_header()

# Seed the session keys this page reads so later lookups never KeyError.
for _session_key in ("username", "project_name"):
    if _session_key not in st.session_state:
        st.session_state[_session_key] = None

# Without a loaded project there is nothing to assess: send the user to
# project selection and halt this run of the script.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()
# Guard clause: this page only renders for a signed-in user.  The original
# wrapped the entire remainder of the script in
# ``if st.session_state["username"] is not None:``; stopping early is
# behavior-equivalent (nothing rendered either way) and keeps the page flat.
if "username" not in st.session_state or st.session_state["username"] is None:
    st.stop()

# Data must have been imported (and imputed) on the Data Import page first.
if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
    # was an f-string with no placeholders
    st.error("Please import data from the Data Import Page")
    st.stop()

st.session_state["cleaned_data"] = st.session_state["project_dct"]["data_import"][
    "imputed_tool_df"
]
st.session_state["category_dict"] = st.session_state["project_dct"]["data_import"][
    "category_dict"
]

# Header row: greeting on the left, active project on the right.
cols1 = st.columns([2, 1])
with cols1[0]:
    st.markdown(f"**Welcome {st.session_state['username']}**")
with cols1[1]:
    st.markdown(f"**Current Project: {st.session_state['project_name']}**")

st.title("Data Assessment")
# Collect the "Response Metrics" column list.  Kept as a (single-element)
# list of lists: the ``list(*...)`` unpack just below flattens it.
target_variables = [
    metric_cols
    for category, metric_cols in st.session_state["category_dict"].items()
    if category == "Response Metrics"
]
def format_display(inp):
    """Humanise a raw column name for display in widgets.

    Title-cases the name, turns underscores into spaces, strips the
    bookkeeping tokens "Media" and "Cnt", and trims surrounding whitespace.
    """
    label = inp.title().replace("_", " ")
    for token in ("Media", "Cnt"):
        label = label.replace(token, "")
    return label.strip()
# Flatten the single nested list produced above.
target_variables = list(*target_variables)

# Pre-select the previously saved choice, then persist the new one (as an
# index) so the widget survives reruns.
_prev_target_idx = st.session_state["project_dct"]["data_validation"]["target_column"]
target_column = st.selectbox(
    "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
    target_variables,
    index=_prev_target_idx,
    format_func=format_display,
)
st.session_state["project_dct"]["data_validation"]["target_column"] = (
    target_variables.index(target_column)
)
st.session_state["target_column"] = target_column
# A data set without a "panel" column is treated as one aggregated panel;
# in that case the panel picker is disabled.
# Fixes: removed stray debug output ``st.write('True')`` and the unused
# local ``panels`` in the else branch.
if "panel" not in st.session_state["cleaned_data"].columns:
    st.session_state["cleaned_data"]["panel"] = ["Aggregated"] * len(
        st.session_state["cleaned_data"]
    )
    disable = True
else:
    disable = False

selected_panels = st.multiselect(
    "Please choose the panels you wish to analyze.If no panels are selected, insights will be derived from the overall data.",
    st.session_state["cleaned_data"]["panel"].unique(),
    default=st.session_state["project_dct"]["data_validation"]["selected_panels"],
    disabled=disable,
)
st.session_state["project_dct"]["data_validation"][
    "selected_panels"
] = selected_panels
# Per-column aggregation for the date-level rollup: sum "Media" columns,
# mean everything else; skip the index-like columns and anything not
# actually present in the cleaned data.  (The original built the full dict
# and then filtered it — fused here into one comprehension, same result.)
_cleaned_cols = set(st.session_state["cleaned_data"].columns)
aggregation_dict = {
    col: "sum" if category == "Media" else "mean"
    for category, cols in st.session_state["category_dict"].items()
    for col in cols
    if col not in ("date", "panel") and col in _cleaned_cols
}
with st.expander("**Target Variable Analysis**"):
    # Restrict to the chosen panels (when any are selected), then collapse
    # panels into one row per date using the aggregation rules above.
    # Fix: both branches previously duplicated the identical
    # groupby/agg/reset_index sequence — factored to a single path.
    _base_df = st.session_state["cleaned_data"]
    if len(selected_panels) > 0:
        _base_df = _base_df[_base_df["panel"].isin(selected_panels)]
    st.session_state["Cleaned_data_panel"] = (
        _base_df.groupby(by="date").agg(aggregation_dict).reset_index()
    )

    fig = line_plot_target(
        st.session_state["Cleaned_data_panel"],
        target=target_column,
        title=f"{target_column} Over Time",
    )
    st.plotly_chart(fig, use_container_width=True)
def _category_columns(category):
    """Column list registered under *category* at import time ([] if absent)."""
    return list(st.session_state["category_dict"].get(category, []))


# Fix: the original repeated a fragile ``list(*[...comprehension...])``
# unpack four times; ``dict.get`` expresses the same lookup directly and
# behaves identically whether the key is present or missing.
media_channel = _category_columns("Media")
spends_features = _category_columns("Spends")
exo_var = _category_columns("Exogenous")
internal_var = _category_columns("Internal")

Non_media_variables = exo_var + internal_var
# Year-level rollup table of media, spends and the target variable.
st.markdown("### Annual Data Summary")
summary_df = summary(
    st.session_state["Cleaned_data_panel"],
    media_channel + [target_column] + spends_features,
    spends=None,
    Target=True,
)
st.dataframe(summary_df.sort_index(axis=1), use_container_width=True)
if st.checkbox("View Raw Data"):
    # Fix: the original had ``st.cache_resource(show_spinner=False)`` as a
    # bare statement — the returned decorator was discarded, so nothing was
    # ever cached.  Decorate properly, and take the frame as an argument so
    # st.cache_data invalidates when the underlying data changes.
    @st.cache_data(show_spinner=False)
    def raw_df_gen(panel_df):
        """Return *panel_df* sorted by date, numbers formatted for display."""
        # Keep 'date' as datetime for sorting; stringify only afterwards.
        dates = pd.to_datetime(panel_df["date"])
        # NOTE(review): DataFrame.applymap is deprecated in pandas >= 2.1
        # (use .map) — confirm the pinned pandas version before switching.
        raw_df = pd.concat(
            [dates, panel_df.select_dtypes(np.number).applymap(format_numbers)],
            axis=1,
        )
        raw_df = raw_df.sort_values(by="date", ascending=True)
        raw_df["date"] = raw_df["date"].dt.strftime("%m/%d/%Y")
        return raw_df

    # Display the sorted DataFrame in Streamlit.
    st.dataframe(raw_df_gen(st.session_state["Cleaned_data_panel"]))
col1 = st.columns(1)

if "selected_feature" not in st.session_state:
    st.session_state["selected_feature"] = None

with st.expander("Media Variables Analysis"):
    st.session_state["selected_feature"] = st.selectbox(
        "Select Media", media_channel + spends_features, format_func=format_display
    )

    # Auto-match the spends column for the selected metric: a spends column
    # matches when its name after the "cost_"/"spends_" marker occurs in the
    # selected feature's name.
    # Fix: a column lacking either marker used to raise IndexError on the
    # ``re.split(...)[1]`` access — now skipped instead.
    spends_col = st.columns(2)
    spends_feature = []
    for col in spends_features:
        pieces = re.split(r"cost_|spends_", col.lower())
        if len(pieces) > 1 and pieces[1] in st.session_state["selected_feature"]:
            spends_feature.append(col)

    with spends_col[0]:
        if len(spends_feature) == 0:
            st.warning(
                "The selected metric does not include a 'spends' variable in the data. Please verify that the columns are correctly named or select the appropriate columns in the provided selection box."
            )
        else:
            st.write(
                f'Selected "{spends_feature[0]}" as the corresponding spends variable automatically. If this is incorrect, please click the checkbox to change the variable.'
            )
    with spends_col[1]:
        if len(spends_feature) == 0 or st.checkbox(
            'Select "Spends" variable for CPM and CPC calculation'
        ):
            spends_feature = [st.selectbox("Spends Variable", spends_features)]

    # Session-local working copy of the validated-variables list.
    if "validation" not in st.session_state:
        st.session_state["validation"] = st.session_state["project_dct"][
            "data_validation"
        ]["validated_variables"]

    val_variables = [col for col in media_channel if col != "date"]

    if not set(
        st.session_state["project_dct"]["data_validation"]["validated_variables"]
    ).issubset(set(val_variables)):
        # Saved validations no longer match the current media columns —
        # reset rather than carry stale entries forward.
        st.session_state["validation"] = []
    else:
        fig_row1 = line_plot(
            st.session_state["Cleaned_data_panel"],
            x_col="date",
            y1_cols=[st.session_state["selected_feature"]],
            y2_cols=[target_column],
            title=f'Analysis of {st.session_state["selected_feature"]} and {[target_column][0]} Over Time',
        )
        st.plotly_chart(fig_row1, use_container_width=True)

        st.markdown("### Summary")
        st.dataframe(
            summary(
                st.session_state["Cleaned_data_panel"],
                [st.session_state["selected_feature"]],
                spends=spends_feature[0],
            ),
            use_container_width=True,
        )

        cols2 = st.columns(2)
        # Fix: renamed ``help`` -> ``help_text`` (shadowed the builtin).
        if len(
            set(st.session_state["validation"]).intersection(val_variables)
        ) == len(val_variables):
            disable = True
            help_text = "All media variables are validated"
        else:
            disable = False
            help_text = ""

        with cols2[0]:
            if st.button("Validate", disabled=disable, help=help_text):
                st.session_state["validation"].append(
                    st.session_state["selected_feature"]
                )
        with cols2[1]:
            if st.checkbox("Validate All", disabled=disable, help=help_text):
                st.session_state["validation"].extend(val_variables)
                st.success("All media variables are validated ✅")

        # Re-test after the widgets above may have mutated the list.
        if len(
            set(st.session_state["validation"]).intersection(val_variables)
        ) != len(val_variables):
            validation_data = pd.DataFrame(
                {
                    # Fix: ``True if x else False`` -> plain boolean expr.
                    "Validate": [
                        col in st.session_state["validation"]
                        for col in val_variables
                    ],
                    "Variables": val_variables,
                }
            )
            sorted_validation_df = validation_data.sort_values(
                by="Variables", ascending=True, na_position="first"
            )
            cols3 = st.columns([1, 30])
            with cols3[1]:
                validation_df = st.data_editor(
                    sorted_validation_df,
                    column_config={
                        "Validate": st.column_config.CheckboxColumn(
                            default=False,
                            width=100,
                        ),
                        "Variables": st.column_config.TextColumn(width=1000),
                    },
                    hide_index=True,
                )
                # Fix: boolean mask instead of ``== True`` comparison.
                selected_rows = validation_df[validation_df["Validate"]][
                    "Variables"
                ]
                st.session_state["validation"].extend(selected_rows)

            # Persist and report anything still unchecked.
            # NOTE(review): source indentation was ambiguous here; these
            # statements are placed at expander level — confirm placement.
            st.session_state["project_dct"]["data_validation"][
                "validated_variables"
            ] = st.session_state["validation"]
            not_validated_variables = [
                col
                for col in val_variables
                if col not in st.session_state["validation"]
            ]
            if not_validated_variables:
                not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
                st.warning(not_validated_message)
with st.expander("Non-Media Variables Analysis"):
    if not Non_media_variables:
        st.warning("Non-Media variables not present")
    else:
        # Pre-select the previously saved channel, then persist the new pick.
        selected_columns_row4 = st.selectbox(
            "Select Channel",
            Non_media_variables,
            format_func=format_display,
            index=st.session_state["project_dct"]["data_validation"][
                "Non_media_variables"
            ],
        )
        st.session_state["project_dct"]["data_validation"][
            "Non_media_variables"
        ] = Non_media_variables.index(selected_columns_row4)

        # Dual-axis line plot of the chosen channel vs the target.
        fig_row4 = line_plot(
            st.session_state["Cleaned_data_panel"],
            x_col="date",
            y1_cols=[selected_columns_row4],
            y2_cols=[target_column],
            title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
        )
        st.plotly_chart(fig_row4, use_container_width=True)

        # Year-level sums with a grand-total row, formatted for display.
        sum_df = st.session_state["Cleaned_data_panel"][
            ["date", selected_columns_row4, target_column]
        ]
        sum_df["Year"] = pd.to_datetime(
            st.session_state["Cleaned_data_panel"]["date"]
        ).dt.year
        sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
        sum_df.loc["Grand Total"] = sum_df.sum()
        sum_df = sum_df.applymap(format_numbers)
        sum_df.fillna("-", inplace=True)
        sum_df = sum_df.replace({"0.0": "-", "nan": "-"})

        st.markdown("### Summary")
        st.dataframe(sum_df, use_container_width=True)
with st.expander("Correlation Analysis"):
    numeric_cols = list(
        st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
    )
    # First visit to this project: no saved selection yet.
    if "correlation" not in st.session_state["project_dct"]["data_import"]:
        st.session_state["project_dct"]["data_import"]["correlation"] = []

    selected_options = st.multiselect(
        "Select Variables for Correlation Plot",
        [var for var in numeric_cols if var != target_column],
        default=st.session_state["project_dct"]["data_import"]["correlation"],
    )
    st.session_state["project_dct"]["data_import"]["correlation"] = selected_options

    st.pyplot(
        correlation_plot(
            st.session_state["Cleaned_data_panel"],
            selected_options,
            target_column,
        )
    )
# Persist the whole project dict back to the database on demand.
if st.button("Save Changes", use_container_width=True):
    serialized_dct = pickle.dumps(st.session_state["project_dct"])
    update_db(
        prj_id=st.session_state["project_number"],
        page_nam="Data Validation and Insights",
        file_nam="project_dct",
        pkl_obj=serialized_dct,
        schema=schema,
    )
    st.success("Changes saved")