Spaces:
Build error
Build error
| # Importing necessary libraries | |
| import streamlit as st | |
| st.set_page_config( | |
| page_title="AI Model Transformations", | |
| page_icon="⚖️", | |
| layout="wide", | |
| initial_sidebar_state="collapsed", | |
| ) | |
| import sys | |
| import pickle | |
| import traceback | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from post_gres_cred import db_cred | |
| from log_application import log_message | |
| from utilities import ( | |
| set_header, | |
| load_local_css, | |
| update_db, | |
| project_selection, | |
| delete_entries, | |
| retrieve_pkl_object, | |
| ) | |
| from constants import ( | |
| predefined_defaults, | |
| lead_min_value, | |
| lead_max_value, | |
| lead_step, | |
| lag_min_value, | |
| lag_max_value, | |
| lag_step, | |
| moving_average_min_value, | |
| moving_average_max_value, | |
| moving_average_step, | |
| saturation_min_value, | |
| saturation_max_value, | |
| saturation_step, | |
| power_min_value, | |
| power_max_value, | |
| power_step, | |
| adstock_min_value, | |
| adstock_max_value, | |
| adstock_step, | |
| display_max_col, | |
| ) | |
schema = db_cred["schema"]
load_local_css("styles.css")
set_header()

# Every page expects a project-name slot, even before a project is chosen.
if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

# No project dictionary yet: let the user pick a project and end this run.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

# Greeting row: signed-in user on the left, active project on the right.
if "username" in st.session_state and st.session_state["username"] is not None:
    cols1 = st.columns([2, 1])
    with cols1[0]:
        st.markdown(f"**Welcome {st.session_state['username']}**")
    with cols1[1]:
        st.markdown(f"**Current Project: {st.session_state['project_name']}**")

# This page works on the dataframe saved by the Data Import page; without it,
# warn the user, log the event and end the run.
if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is not None:
    final_df_loaded = st.session_state["project_dct"]["data_import"][
        "imputed_tool_df"
    ].copy()
    bin_dict_loaded = st.session_state["project_dct"]["data_import"][
        "category_dict"
    ].copy()
    unique_panels = st.session_state["project_dct"]["data_import"][
        "unique_panels"
    ].copy()
else:
    st.warning(
        "The data import is incomplete. Please go back to the Data Import page and complete the save.",
        icon="🔙",
    )
    log_message(
        "warning",
        "The data import is incomplete. Please go back to the Data Import page and complete the save.",
        "Transformations",
    )
    st.stop()

# First visit: the transformed dataframe starts out as the imported data.
if st.session_state["project_dct"]["transformations"]["final_df"] is None:
    st.session_state["project_dct"]["transformations"]["final_df"] = final_df_loaded

# Source columns per transformable category; categories absent from the
# imported category dictionary are skipped.
original_columns = {
    category: bin_dict_loaded[category]
    for category in ("Media", "Internal", "Exogenous")
    if category in bin_dict_loaded
}

# Retrieve panel columns: "panel" is used as a grouping key only when the data
# contains more than one panel.
if len(unique_panels) > 1:
    panel = ["panel"]
else:
    panel = []
| # Function to clear model metadata | |
def clear_pages():
    """Reset the stored state of the model-building pages that follow this one.

    Replaces the ``model_build``, ``model_tuning`` and ``saved_model_results``
    entries of ``st.session_state["project_dct"]`` with fresh default
    dictionaries, then removes the ``model_results_df``,
    ``model_results_data`` and ``coefficients_df`` session-state entries when
    they are present.
    """
    project_dct = st.session_state["project_dct"]

    project_dct["model_build"] = dict(
        sel_target_col=None,
        all_iters_check=False,
        iterations=0,
        build_button=False,
        show_results_check=False,
        session_state_saved={},
    )
    project_dct["model_tuning"] = dict(
        sel_target_col=None,
        sel_model={},
        flag_expander=False,
        start_date_default=None,
        end_date_default=None,
        repeat_default="No",
        flags={},
        select_all_flags_check={},
        selected_flags={},
        trend_check=False,
        week_num_check=False,
        sine_cosine_check=False,
        session_state_saved={},
    )
    project_dct["saved_model_results"] = dict(
        selected_options=None,
        model_grid_sel=[1],
    )

    # Cached model artefacts from earlier runs are no longer valid.
    for cached_key in ("model_results_df", "model_results_data", "coefficients_df"):
        if cached_key in st.session_state:
            del st.session_state[cached_key]
| # Function to update transformation change | |
def transformation_change(category, transformation, key):
    """``on_change`` callback that persists a category-wide slider value.

    Copies the widget value stored under ``st.session_state[key]`` into
    ``project_dct["transformations"][category][transformation]``.
    """
    widget_value = st.session_state[key]
    st.session_state["project_dct"]["transformations"][category][
        transformation
    ] = widget_value
| # Function to update specific transformation change | |
def transformation_specific_change(channel_name, transformation, key):
    """``on_change`` callback that persists a channel-specific slider value.

    Stores ``st.session_state[key]`` under
    ``project_dct["transformations"]["Specific"][transformation][channel_name]``.
    """
    widget_value = st.session_state[key]
    per_channel = st.session_state["project_dct"]["transformations"]["Specific"][
        transformation
    ]
    per_channel[channel_name] = widget_value
| # Function to update transformations to apply change | |
def transformations_to_apply_change(category, key):
    """``on_change`` callback that persists a category's selected transformations.

    The multiselect value at ``st.session_state[key]`` is saved under the same
    ``key`` inside ``project_dct["transformations"][category]``.
    """
    selected = st.session_state[key]
    st.session_state["project_dct"]["transformations"][category][key] = selected
| # Function to update channel select specific change | |
def channel_select_specific_change():
    """``on_change`` callback that persists the channels chosen for specific transforms.

    Saves the ``channel_select_specific`` widget value under
    ``project_dct["transformations"]["Specific"]["channel_select_specific"]``.
    """
    chosen_channels = st.session_state["channel_select_specific"]
    specific_state = st.session_state["project_dct"]["transformations"]["Specific"]
    specific_state["channel_select_specific"] = chosen_channels
| # Function to update specific transformation change | |
def specific_transformation_change(specific_transformation_key):
    """``on_change`` callback that persists a channel's selected transformations.

    Saves ``st.session_state[specific_transformation_key]`` under the same key
    inside ``project_dct["transformations"]["Specific"]``.
    """
    selected = st.session_state[specific_transformation_key]
    specific_state = st.session_state["project_dct"]["transformations"]["Specific"]
    specific_state[specific_transformation_key] = selected
| # Function to build transformation widgets | |
| def transformation_widgets(category, transform_params, date_granularity): | |
| # Transformation Options | |
| transformation_options = { | |
| "Media": [ | |
| "Lag", | |
| "Moving Average", | |
| "Saturation", | |
| "Power", | |
| "Adstock", | |
| ], | |
| "Internal": ["Lead", "Lag", "Moving Average"], | |
| "Exogenous": ["Lead", "Lag", "Moving Average"], | |
| } | |
| # Define a helper function to create widgets for each transformation | |
| def create_transformation_widgets(column, transformations): | |
| with column: | |
| for transformation in transformations: | |
| transformation_key = f"{transformation}_{category}" | |
| slider_value = st.session_state["project_dct"]["transformations"][ | |
| category | |
| ].get(transformation, predefined_defaults[transformation]) | |
| # Conditionally create widgets for selected transformations | |
| if transformation == "Lead": | |
| st.markdown(f"**{transformation} ({date_granularity})**") | |
| lead = st.slider( | |
| label="Lead periods", | |
| min_value=lead_min_value, | |
| max_value=lead_max_value, | |
| step=lead_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_change, | |
| args=( | |
| category, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = lead[0] | |
| end = lead[1] | |
| step = lead_step | |
| transform_params[category][transformation] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Lag": | |
| st.markdown(f"**{transformation} ({date_granularity})**") | |
| lag = st.slider( | |
| label="Lag periods", | |
| min_value=lag_min_value, | |
| max_value=lag_max_value, | |
| step=lag_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_change, | |
| args=( | |
| category, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = lag[0] | |
| end = lag[1] | |
| step = lag_step | |
| transform_params[category][transformation] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Moving Average": | |
| st.markdown(f"**{transformation} ({date_granularity})**") | |
| window = st.slider( | |
| label="Window size for Moving Average", | |
| min_value=moving_average_min_value, | |
| max_value=moving_average_max_value, | |
| step=moving_average_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_change, | |
| args=( | |
| category, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = window[0] | |
| end = window[1] | |
| step = moving_average_step | |
| transform_params[category][transformation] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Saturation": | |
| st.markdown(f"**{transformation} (%)**") | |
| saturation_point = st.slider( | |
| label="Saturation Percentage", | |
| min_value=saturation_min_value, | |
| max_value=saturation_max_value, | |
| step=saturation_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_change, | |
| args=( | |
| category, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = saturation_point[0] | |
| end = saturation_point[1] | |
| step = saturation_step | |
| transform_params[category][transformation] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Power": | |
| st.markdown(f"**{transformation}**") | |
| power = st.slider( | |
| label="Power", | |
| min_value=power_min_value, | |
| max_value=power_max_value, | |
| step=power_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_change, | |
| args=( | |
| category, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = power[0] | |
| end = power[1] | |
| step = power_step | |
| transform_params[category][transformation] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Adstock": | |
| st.markdown(f"**{transformation}**") | |
| rate = st.slider( | |
| label="Decay Factor", | |
| min_value=adstock_min_value, | |
| max_value=adstock_max_value, | |
| step=adstock_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_change, | |
| args=( | |
| category, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = rate[0] | |
| end = rate[1] | |
| step = adstock_step | |
| adstock_range = [ | |
| round(a, 3) for a in np.arange(start, end + step, step) | |
| ] | |
| transform_params[category][transformation] = np.array(adstock_range) | |
| with st.expander(f"All {category} Transformations", expanded=True): | |
| transformation_key = f"transformation_{category}" | |
| # Select which transformations to apply | |
| sel_transformations = st.session_state["project_dct"]["transformations"][ | |
| category | |
| ].get(transformation_key, []) | |
| # Reset default selected channels list if options are changed | |
| for channel in sel_transformations: | |
| if channel not in transformation_options[category]: | |
| ( | |
| st.session_state["project_dct"]["transformations"][category][ | |
| transformation_key | |
| ], | |
| sel_transformations, | |
| ) = ([], []) | |
| transformations_to_apply = st.multiselect( | |
| label="Select transformations to apply", | |
| options=transformation_options[category], | |
| default=sel_transformations, | |
| key=transformation_key, | |
| on_change=transformations_to_apply_change, | |
| args=( | |
| category, | |
| transformation_key, | |
| ), | |
| ) | |
| # Determine the number of transformations to put in each column | |
| transformations_per_column = ( | |
| len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2 | |
| ) | |
| # Create two columns | |
| col1, col2 = st.columns(2) | |
| # Assign transformations to each column | |
| transformations_col1 = transformations_to_apply[:transformations_per_column] | |
| transformations_col2 = transformations_to_apply[transformations_per_column:] | |
| # Create widgets in each column | |
| create_transformation_widgets(col1, transformations_col1) | |
| create_transformation_widgets(col2, transformations_col2) | |
| # Define a helper function to create widgets for each specific transformation | |
| def create_specific_transformation_widgets( | |
| column, | |
| transformations, | |
| channel_name, | |
| date_granularity, | |
| specific_transform_params, | |
| ): | |
| with column: | |
| for transformation in transformations: | |
| transformation_key = f"{transformation}_{channel_name}_specific" | |
| if ( | |
| transformation | |
| not in st.session_state["project_dct"]["transformations"]["Specific"] | |
| ): | |
| st.session_state["project_dct"]["transformations"]["Specific"][ | |
| transformation | |
| ] = {} | |
| slider_value = st.session_state["project_dct"]["transformations"][ | |
| "Specific" | |
| ][transformation].get(channel_name, predefined_defaults[transformation]) | |
| # Conditionally create widgets for selected transformations | |
| if transformation == "Lead": | |
| st.markdown(f"**Lead ({date_granularity})**") | |
| lead = st.slider( | |
| label="Lead periods", | |
| min_value=lead_min_value, | |
| max_value=lead_max_value, | |
| step=lead_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_specific_change, | |
| args=( | |
| channel_name, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = lead[0] | |
| end = lead[1] | |
| step = lead_step | |
| specific_transform_params[channel_name]["Lead"] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Lag": | |
| st.markdown(f"**Lag ({date_granularity})**") | |
| lag = st.slider( | |
| label="Lag periods", | |
| min_value=lag_min_value, | |
| max_value=lag_max_value, | |
| step=lag_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_specific_change, | |
| args=( | |
| channel_name, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = lag[0] | |
| end = lag[1] | |
| step = lag_step | |
| specific_transform_params[channel_name]["Lag"] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Moving Average": | |
| st.markdown(f"**Moving Average ({date_granularity})**") | |
| window = st.slider( | |
| label="Window size for Moving Average", | |
| min_value=moving_average_min_value, | |
| max_value=moving_average_max_value, | |
| step=moving_average_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_specific_change, | |
| args=( | |
| channel_name, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = window[0] | |
| end = window[1] | |
| step = moving_average_step | |
| specific_transform_params[channel_name]["Moving Average"] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Saturation": | |
| st.markdown("**Saturation (%)**") | |
| saturation_point = st.slider( | |
| label="Saturation Percentage", | |
| min_value=saturation_min_value, | |
| max_value=saturation_max_value, | |
| step=saturation_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_specific_change, | |
| args=( | |
| channel_name, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = saturation_point[0] | |
| end = saturation_point[1] | |
| step = saturation_step | |
| specific_transform_params[channel_name]["Saturation"] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Power": | |
| st.markdown("**Power**") | |
| power = st.slider( | |
| label="Power", | |
| min_value=power_min_value, | |
| max_value=power_max_value, | |
| step=power_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_specific_change, | |
| args=( | |
| channel_name, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = power[0] | |
| end = power[1] | |
| step = power_step | |
| specific_transform_params[channel_name]["Power"] = np.arange( | |
| start, end + step, step | |
| ) | |
| if transformation == "Adstock": | |
| st.markdown("**Adstock**") | |
| rate = st.slider( | |
| label="Decay Factor", | |
| min_value=adstock_min_value, | |
| max_value=adstock_max_value, | |
| step=adstock_step, | |
| value=slider_value, | |
| key=transformation_key, | |
| label_visibility="collapsed", | |
| on_change=transformation_specific_change, | |
| args=( | |
| channel_name, | |
| transformation, | |
| transformation_key, | |
| ), | |
| ) | |
| start = rate[0] | |
| end = rate[1] | |
| step = adstock_step | |
| adstock_range = [ | |
| round(a, 3) for a in np.arange(start, end + step, step) | |
| ] | |
| specific_transform_params[channel_name]["Adstock"] = np.array( | |
| adstock_range | |
| ) | |
| # Function to apply Lag transformation | |
def apply_lag(df, lag):
    """Shift values ``lag`` positions later; the first ``lag`` rows become NaN."""
    return df.shift(periods=lag)
| # Function to apply Lead transformation | |
def apply_lead(df, lead):
    """Shift values ``lead`` positions earlier; the last ``lead`` rows become NaN."""
    return df.shift(periods=-lead)
| # Function to apply Moving Average transformation | |
def apply_moving_average(df, window_size):
    """Trailing rolling mean over ``window_size`` rows (pandas default min_periods)."""
    rolling_window = df.rolling(window=window_size)
    return rolling_window.mean()
| # Function to apply Saturation transformation | |
def apply_saturation(df, saturation_percent_100):
    """Apply an S-shaped saturation curve to a series of values.

    Each value ``x`` becomes ``x / (1 + (saturation_point / x) ** steepness)``.
    ``saturation_point = (min + s * max) / 2`` and ``steepness`` are derived
    from the series' own min/max. They are calibrated so the series maximum
    maps to ``s * max``, where ``s`` is the saturation fraction. Zero values
    stay 0. A constant series is scaled by ``s``.

    Args:
        df: Series of values. NOTE(review): assumes non-negative data with a
            positive maximum; otherwise the logarithms below are undefined
            and the result contains NaN/inf.
        saturation_percent_100: Saturation level in percent, clamped to
            [0.01, 99.99].

    Returns:
        A Series with the same index as ``df``.
    """
    # Convert percentage to fraction
    saturation_percent = min(max(saturation_percent_100, 0.01), 99.99) / 100.0
    column_max = df.max()
    column_min = df.min()

    # If the data is constant, scale it directly
    if column_min == column_max:
        return df * saturation_percent

    saturation_point = (column_min + saturation_percent * column_max) / 2
    # Chosen so that at x == max the multiplier equals saturation_percent.
    steepness = np.log((1 / saturation_percent) - 1) / np.log(
        saturation_point / column_max
    )

    # Vectorized form of the former per-element lambda. Zeros are replaced by a
    # tiny value only inside the ratio; multiplying by the original df keeps
    # them at 0.
    safe_df = df.where(df != 0, 1e-9)
    return (1 / (1 + (saturation_point / safe_df) ** steepness)) * df
| # Function to apply Power transformation | |
def apply_power(df, power):
    """Raise every value to ``power`` (element-wise exponentiation)."""
    return df.pow(power)
| # Function to apply Adstock transformation | |
def apply_adstock(df, factor):
    """Geometric adstock: ``out[t] = out[t - 1] * factor + df[t]`` with ``out[-1] = 0``.

    Returns a Series aligned to ``df``'s index.
    """
    carried = 0
    accumulated = []
    for value in df:
        carried = carried * factor + value
        accumulated.append(carried)
    return pd.Series(accumulated, index=df.index)
| # Function to generate transformed columns names | |
def generate_transformed_columns(
    original_columns, transform_params, specific_transform_params
):
    """List the transformed column names per source column and build a summary.

    For every column in ``original_columns`` (category -> list of columns):

    * with a non-empty entry in ``specific_transform_params``, names come from
      that channel-specific mapping;
    * otherwise, when its category is in ``transform_params``, names come from
      the category-wide mapping. A column whose specific entry exists but is
      empty gets no names, and its summary reads "No transformation selected".

    One name ``"<column>@<transformation>_<value>"`` is produced per value.

    Returns:
        ``(transformed_columns, summary_string)``. The first is a dict mapping
        each column to its list of names. The second is a newline-separated,
        numbered list of ``"<strong>column</strong>: ..."`` lines; columns with
        no details are omitted.
    """

    def describe(transformation, values):
        # Render e.g. "Lag (1, 2 and 3)"; a single value is shown on its own.
        if len(values) > 1:
            shown = ", ".join(map(str, values[:-1])) + " and " + str(values[-1])
        else:
            shown = str(values[0])
        return f"{transformation} ({shown})"

    transformed_columns = {}
    summary = {}
    for category, columns in original_columns.items():
        for column in columns:
            names = transformed_columns[column] = []
            details = []
            has_specific = column in specific_transform_params
            if has_specific and len(specific_transform_params[column]) > 0:
                for transformation, values in specific_transform_params[
                    column
                ].items():
                    names.extend(
                        f"{column}@{transformation}_{value}" for value in values
                    )
                    details.append(describe(transformation, values))
            elif category in transform_params:
                for transformation, values in transform_params[category].items():
                    if has_specific:
                        details = ["No transformation selected"]
                        continue
                    names.extend(
                        f"{column}@{transformation}_{value}" for value in values
                    )
                    details.append(describe(transformation, values))
            if details:
                summary[column] = f"<strong>{column}</strong>: " + "⮕ ".join(details)

    summary_string = "\n".join(
        f"{position}. {line}"
        for position, line in enumerate(summary.values(), start=1)
    )
    return transformed_columns, summary_string
| # Function to transform Dataframe slice | |
def transform_slice(
    transform_params,
    transformation_functions,
    panel,
    df,
    df_slice,
    category,
    category_df,
):
    """Apply one category's transformations to ``df_slice`` as a chain.

    Transformations run in the order of ``transform_params[category]``. Each
    is evaluated once per parameter value. Every following transformation
    takes the previous step's output columns as input, so names grow like
    ``col@Lag_1@Adstock_0.5``. NaNs produced by a step (e.g. from shifting)
    are replaced with 0.

    Args:
        transform_params: Mapping category -> {transformation: parameter values}.
        transformation_functions: Mapping transformation name -> callable
            ``f(values, parameter)``.
        panel: List of panel key columns. When non-empty, each step runs per
            panel group via ``groupby(panel).transform``. The key columns must
            be present in ``df_slice`` and ``df``.
        df: Full dataframe; only ``df[panel]`` is read.
        df_slice: Columns to transform (plus the panel keys, if any).
        category: Key into ``transform_params``.
        category_df: Returned unchanged if no transformation runs.

    Returns:
        ``(category_df, df, df_slice)``:
        * the columns produced by the last step;
        * ``df`` unchanged;
        * the panel keys concatenated with that last output.
    """
    for transformation, parameters in transform_params[category].items():
        if len(parameters) == 0:
            # Nothing to evaluate, and pd.concat rejects an empty list.
            continue
        transformation_function = transformation_functions[transformation]

        if len(panel) > 0:
            # Per-group evaluation keeps shifts/rolling windows inside a panel.
            grouped = df_slice.groupby(panel)
            step_outputs = [
                grouped.transform(transformation_function, p).add_suffix(
                    f"@{transformation}_{p}"
                )
                for p in parameters
            ]
        else:
            step_outputs = [
                df_slice.apply(
                    lambda series, p=p: transformation_function(series, p), axis=0
                ).rename(
                    lambda name, p=p: f"{name}@{transformation}_{p}",
                    axis="columns",
                )
                for p in parameters
            ]

        # Keep only this step's output, matching the panel branch. Accumulating
        # would make every later step re-process all intermediate columns too.
        category_df = pd.concat(step_outputs, axis=1).fillna(0)
        # The next step consumes this output (plus panel keys for grouping).
        df_slice = pd.concat([df[panel], category_df], axis=1)

    return category_df, df, df_slice
| # Function to apply transformations to DataFrame slices based on specified categories and parameters | |
| def apply_category_transformations( | |
| df_main, bin_dict, transform_params, panel, specific_transform_params | |
| ): | |
| # Dictionary for function mapping | |
| transformation_functions = { | |
| "Lead": apply_lead, | |
| "Lag": apply_lag, | |
| "Moving Average": apply_moving_average, | |
| "Saturation": apply_saturation, | |
| "Power": apply_power, | |
| "Adstock": apply_adstock, | |
| } | |
| # List to collect all transformed DataFrames | |
| transformed_dfs = [] | |
| # Iterate through each category specified in transform_params | |
| for category in ["Media", "Exogenous", "Internal"]: | |
| if ( | |
| category not in transform_params | |
| or category not in bin_dict | |
| or not transform_params[category] | |
| ): | |
| continue # Skip categories without transformations | |
| # Initialize category_df as an empty DataFrame | |
| category_df = pd.DataFrame() | |
| # Slice the DataFrame based on the columns specified in bin_dict for the current category | |
| df_slice = df_main[bin_dict[category] + panel].copy() | |
| # Drop the column from df_slice to skip specific transformations | |
| df_slice = df_slice.drop( | |
| columns=list(specific_transform_params.keys()), errors="ignore" | |
| ).copy() | |
| category_df, df, df_slice_updated = transform_slice( | |
| transform_params.copy(), | |
| transformation_functions.copy(), | |
| panel, | |
| df_main.copy(), | |
| df_slice.copy(), | |
| category, | |
| category_df.copy(), | |
| ) | |
| # Append the transformed category DataFrame to the list if it's not empty | |
| if not category_df.empty: | |
| transformed_dfs.append(category_df) | |
| # Apply channel specific transforms | |
| for channel_specific in specific_transform_params: | |
| # Initialize category_df as an empty DataFrame | |
| category_df = pd.DataFrame() | |
| df_slice_specific = df_main[[channel_specific] + panel].copy() | |
| transform_params_specific = { | |
| "Media": specific_transform_params[channel_specific] | |
| } | |
| category_df, df, df_slice_specific_updated = transform_slice( | |
| transform_params_specific.copy(), | |
| transformation_functions.copy(), | |
| panel, | |
| df_main.copy(), | |
| df_slice_specific.copy(), | |
| "Media", | |
| category_df.copy(), | |
| ) | |
| # Append the transformed category DataFrame to the list if it's not empty | |
| if not category_df.empty: | |
| transformed_dfs.append(category_df) | |
| # If category_df has been modified, concatenate it with the panel and response metrics from the original DataFrame | |
| if len(transformed_dfs) > 0: | |
| final_df = pd.concat([df_main] + transformed_dfs, axis=1) | |
| else: | |
| # If no transformations were applied, use the original DataFrame | |
| final_df = df_main | |
| # Find columns with '@' in their names | |
| columns_with_at = [col for col in final_df.columns if "@" in col] | |
| # Create a set of columns to drop | |
| columns_to_drop = set() | |
| # Iterate through columns with '@' to find shorter names to drop | |
| for col in columns_with_at: | |
| base_name = col.split("@")[0] | |
| for other_col in columns_with_at: | |
| if other_col.startswith(base_name) and len(other_col.split("@")) > len( | |
| col.split("@") | |
| ): | |
| columns_to_drop.add(col) | |
| break | |
| # Drop the identified columns from the DataFrame | |
| final_df.drop(columns=list(columns_to_drop), inplace=True) | |
| return final_df | |
| # Function to infers the granularity of the date column in a DataFrame | |
def infer_date_granularity(df):
    """Classify the spacing of ``df["date"]`` as daily/weekly/monthly/irregular.

    The distinct dates are converted with ``pd.to_datetime`` and sorted
    chronologically before the gaps are computed, so row order (e.g.
    panel-stacked data) does not matter. The most frequent gap in days
    decides the label:

    * 1 -> "daily"
    * 7 -> "weekly"
    * 28-31 -> "monthly"
    * anything else -> "irregular"

    Fewer than two distinct dates also yields "irregular". On a tie, the
    smallest gap wins (``mode()`` returns ties in ascending order).
    """
    unique_dates = (
        pd.Series(pd.to_datetime(df["date"]).unique()).dropna().sort_values()
    )
    gaps_in_days = unique_dates.diff().dt.days.dropna()
    if gaps_in_days.empty:
        return "irregular"
    common_freq = gaps_in_days.mode().iloc[0]

    if common_freq == 1:
        return "daily"
    if common_freq == 7:
        return "weekly"
    if 28 <= common_freq <= 31:
        return "monthly"
    return "irregular"
| # Function to clean display DataFrame | |
def clean_display_df(df, display_max_col=500):
    """Prepare a dataframe for on-screen display.

    Processing steps:

    1. Rows are sorted by ``panel`` then ``date``, with missing values first.
    2. Duplicate column names are collapsed to their first occurrence.
    3. ``date`` and ``panel`` are moved to the front, followed by the
       remaining columns in sorted order.
    4. If the de-duplicated frame has more than ``display_max_col`` columns,
       only the first ``display_max_col`` columns of that order are kept.

    Either of ``date`` / ``panel`` may be absent; it is then left out of the
    sort keys and the leading columns.

    Returns:
        ``(display_df, exceeds_max_col)``: the display frame and whether the
        de-duplicated frame had more than ``display_max_col`` columns.
    """
    leading = [name for name in ("date", "panel") if name in df.columns]
    sort_columns = [name for name in ("panel", "date") if name in df.columns]

    if sort_columns:
        sorted_df = df.sort_values(
            by=sort_columns, ascending=True, na_position="first"
        )
    else:
        sorted_df = df
    # Drop duplicate columns
    sorted_df = sorted_df.loc[:, ~sorted_df.columns.duplicated()]

    exceeds_max_col = sorted_df.shape[1] > display_max_col
    remaining = sorted_df.columns.difference(leading).tolist()
    if exceeds_max_col:
        # Leading columns count toward the limit; never slice with a negative bound.
        remaining = remaining[: max(display_max_col - len(leading), 0)]

    display_df = sorted_df[leading + remaining]
    return display_df, exceeds_max_col
| ######################################################################################################################################################### | |
| # User input for transformations | |
| ######################################################################################################################################################### | |
| try: | |
| # Page Title | |
| st.title("AI Model Transformations") | |
| # Infer date granularity | |
| date_granularity = infer_date_granularity(final_df_loaded) | |
| # Initialize the main dictionary to store the transformation parameters for each category | |
| transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}} | |
| st.markdown("### Select Transformations to Apply") | |
| with st.expander("Specific Media Transformations"): | |
| # Select which transformations to apply | |
| sel_channel_specific = st.session_state["project_dct"]["transformations"][ | |
| "Specific" | |
| ].get("channel_select_specific", []) | |
| # Reset default selected channels list if options are changed | |
| for channel in sel_channel_specific: | |
| if channel not in bin_dict_loaded["Media"]: | |
| ( | |
| st.session_state["project_dct"]["transformations"]["Specific"][ | |
| "channel_select_specific" | |
| ], | |
| sel_channel_specific, | |
| ) = ([], []) | |
| select_specific_channels = st.multiselect( | |
| label="Select channel variable", | |
| default=sel_channel_specific, | |
| options=bin_dict_loaded["Media"], | |
| key="channel_select_specific", | |
| on_change=channel_select_specific_change, | |
| max_selections=30, | |
| ) | |
# Per-channel ("specific") transformation parameters keyed by channel name.
# Each channel starts with an empty dict; presumably it is filled in by
# create_specific_transformation_widgets, which receives this mapping.
specific_transform_params = {}
for select_specific_channel in select_specific_channels:
    specific_transform_params[select_specific_channel] = {}
    st.divider()
    channel_name = str(select_specific_channel).replace("_", " ").title()
    st.markdown(f"###### {channel_name}")
    specific_transformation_key = (
        f"specific_transformation_{select_specific_channel}_Media"
    )
    transformations_options = [
        "Lag",
        "Moving Average",
        "Saturation",
        "Power",
        "Adstock",
    ]
    # Previously persisted selection of transformations for this channel
    specific_selections = st.session_state["project_dct"]["transformations"][
        "Specific"
    ]
    sel_transformations = (
        specific_selections.get(specific_transformation_key, []) or []
    )
    # If the persisted selection contains a transformation that is no longer
    # offered, clear it in the project state AND in the local default:
    # st.multiselect raises when any default value is missing from its options,
    # so resetting only the stored copy would still crash this render.
    if any(
        transformation not in transformations_options
        for transformation in sel_transformations
    ):
        specific_selections[specific_transformation_key] = []
        sel_transformations = []
    # Select which transformations to apply to this channel
    transformations_to_apply = st.multiselect(
        label="Select transformations to apply",
        options=transformations_options,
        default=sel_transformations,
        key=specific_transformation_key,
        on_change=specific_transformation_change,
        args=(specific_transformation_key,),
    )
    # Split the chosen transformations over two columns; the first column gets
    # ceil(n / 2) of them, i.e. it takes the extra one when n is odd.
    transformations_per_column = (len(transformations_to_apply) + 1) // 2
    col1, col2 = st.columns(2)
    transformations_col1 = transformations_to_apply[:transformations_per_column]
    transformations_col2 = transformations_to_apply[transformations_per_column:]
    # Create the parameter widgets for each column's transformations
    create_specific_transformation_widgets(
        col1,
        transformations_col1,
        select_specific_channel,
        date_granularity,
        specific_transform_params,
    )
    create_specific_transformation_widgets(
        col2,
        transformations_col2,
        select_specific_channel,
        date_granularity,
        specific_transform_params,
    )
# Create Widgets
# Render transformation widgets for the Media and Exogenous categories only;
# the Internal category is never transformed on this page. A category that has
# no columns assigned to it gets an informational notice instead of widgets.
for category in ("Media", "Exogenous"):
    category_has_columns = (
        category in bin_dict_loaded and len(bin_dict_loaded[category]) != 0
    )
    if not category_has_columns:
        st.info(
            f"{str(category).title()} category has no column associated with it. Skipping transformation step for this category.",
            icon="💬",
        )
        continue
    transformation_widgets(category, transform_params, date_granularity)
# ==============================================================================
# Apply transformations
# ==============================================================================
# Two action buttons side by side: [Accept and Proceed | Reset to Default]
button_col = st.columns(2)

# Right-hand button: drop every persisted transformation selection and rerun
with button_col[1]:
    if st.button("Reset to Default", use_container_width=True):
        transformations_state = st.session_state["project_dct"]["transformations"]
        for section in ("Media", "Exogenous", "Internal", "Specific"):
            transformations_state[section] = {}
        log_message(
            "info",
            "All persistent selections have been reset to their default settings and cleared.",
            "Transformations",
        )
        st.rerun()

# Left-hand button: apply category and channel-specific transformations to
# copies of the loaded inputs, then persist the result and its summary
with button_col[0]:
    if st.button("Accept and Proceed", use_container_width=True):
        with st.spinner("Applying transformations ..."):
            final_df = apply_category_transformations(
                final_df_loaded.copy(),
                bin_dict_loaded.copy(),
                transform_params.copy(),
                panel.copy(),
                specific_transform_params.copy(),
            )
            # Mapping of original column -> transformed column names, plus a
            # summary string that is displayed further down the page
            transformed_columns_dict, summary_string = generate_transformed_columns(
                original_columns, transform_params, specific_transform_params
            )
            transformations_state = st.session_state["project_dct"]["transformations"]
            transformations_state["final_df"] = final_df
            transformations_state["summary_string"] = summary_string
            st.success("Transformation of the DataFrame is successful!", icon="✅")
            log_message(
                "info",
                "Transformation of the DataFrame is successful!",
                "Transformations",
            )
# Render the persisted transformed DataFrame (possibly column-truncated), its
# full shape, and the transformation summary when one has been stored.
| ######################################################################################################################################################### | |
| # Display the transformed DataFrame and summary | |
| ######################################################################################################################################################### | |
| # Display the transformed DataFrame in the Streamlit app | |
| st.markdown("### Transformed DataFrame") | |
| with st.spinner("Please wait while the transformed DataFrame is loading ..."): | |
# Work on a copy of the DataFrame stored under project_dct["transformations"]
# ("Accept and Proceed" writes it there).
# NOTE(review): .copy() assumes "final_df" already holds a DataFrame on this
# render; a None value would raise here — confirm it is initialized upstream.
| final_df = st.session_state["project_dct"]["transformations"]["final_df"].copy() | |
| # Clean display DataFrame | |
# clean_display_df returns the frame to show plus a flag telling whether the
# original exceeded display_max_col columns (used for the notice below).
| display_df, exceeds_max_col = clean_display_df(final_df, display_max_col) | |
| # Check the number of columns and show only the first display_max_col if there are more | |
| if exceeds_max_col: | |
| # Display an info notice if the DataFrame has more than display_max_col columns | |
| st.info( | |
| f"The transformed DataFrame has more than {display_max_col} columns. Displaying only the first {display_max_col} columns.", | |
| icon="💬", | |
| ) | |
| # Display Final DataFrame | |
# The "date" column is rendered as a date formatted YYYY-MM-DD.
| st.dataframe( | |
| display_df, | |
| hide_index=True, | |
| column_config={ | |
| "date": st.column_config.DateColumn("date", format="YYYY-MM-DD") | |
| }, | |
| ) | |
| # Total rows and columns | |
# Shape is taken from the full stored DataFrame, not the truncated display_df,
# so the reported column count can exceed what is shown above.
| total_rows, total_columns = st.session_state["project_dct"]["transformations"][ | |
| "final_df" | |
| ].shape | |
| st.markdown( | |
| f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>", | |
| unsafe_allow_html=True, | |
| ) | |
| # Display the summary of transformations as markdown | |
# Shown only when a summary key exists and is non-empty (truthy).
| if ( | |
| "summary_string" in st.session_state["project_dct"]["transformations"] | |
| and st.session_state["project_dct"]["transformations"]["summary_string"] | |
| ): | |
| with st.expander("Summary of Transformations"): | |
| st.markdown("### Summary of Transformations") | |
| st.markdown( | |
| st.session_state["project_dct"]["transformations"][ | |
| "summary_string" | |
| ], | |
| unsafe_allow_html=True, | |
| ) | |
# ------------------------------------------------------------------------------
# Correlation Plot
# ------------------------------------------------------------------------------
# Candidate variables: every column except the 'date' and 'panel' identifiers
variables = [
    col for col in final_df.columns if col.lower() not in ["date", "panel"]
]
with st.expander("Transformed Variable Correlation Plot"):
    # The persisted selection may name columns that no longer exist (e.g. after
    # transformations were changed and re-applied). st.multiselect raises when a
    # default value is not among its options, so keep only available variables.
    saved_corr_selection = (
        st.session_state["project_dct"]["transformations"][
            "correlation_plot_selection"
        ]
        or []
    )
    available_variables = set(variables)
    selected_vars = st.multiselect(
        label="Choose variables for correlation plot:",
        options=variables,
        max_selections=30,
        default=[var for var in saved_corr_selection if var in available_variables],
        key="correlation_plot_key",
    )
    if selected_vars:
        # Pairwise correlation matrix of the selected variables
        corr_df = final_df[selected_vars].corr()
        # One text label per heatmap cell, formatted to 2 decimal places
        annotations = [
            go.layout.Annotation(
                text=f"{corr_df.iat[i, j]:.2f}",
                x=col_name,
                y=row_name,
                showarrow=False,
                font=dict(color="black"),
            )
            for i, row_name in enumerate(corr_df.index)
            for j, col_name in enumerate(corr_df.columns)
        ]
        # Diverging colour scale pinned to the full correlation range [-1, 1]
        heatmap = go.Heatmap(
            z=corr_df.values,
            x=corr_df.columns,
            y=corr_df.index,
            colorscale="RdBu",
            zmin=-1,
            zmax=1,
        )
        layout = go.Layout(
            title="Transformed Variable Correlation Plot",
            xaxis=dict(title="Variables"),
            yaxis=dict(title="Variables"),
            width=1000,
            height=1000,
            annotations=annotations,
        )
        fig = go.Figure(data=[heatmap], layout=layout)
        st.plotly_chart(fig)
    else:
        st.write("Please select at least one variable to plot.")
# Persist the project state to the DB and drop saved Model_Build/Model_Tuning
# data for this project, warning first if saved models already exist.
| ######################################################################################################################################################### | |
| # Accept and Save | |
| ######################################################################################################################################################### | |
| # Check for saved model | |
# A non-None "best_models" pickle for the Model_Build page means models have
# already been saved for this project, so saving here will discard them.
| if ( | |
| retrieve_pkl_object( | |
| st.session_state["project_number"], "Model_Build", "best_models", schema | |
| ) | |
| is not None | |
| ): # db | |
| st.warning( | |
| "Saving transformations will overwrite existing ones and delete all saved models. To keep previous models, please start a new project.", | |
| icon="⚠️", | |
| ) | |
| if st.button("Accept and Save", use_container_width=True): | |
| with st.spinner("Saving Changes"): | |
| # Update correlation plot selection | |
# Copy the correlation multiselect's current value (widget key
# "correlation_plot_key") into project_dct so it is restored as the default.
| st.session_state["project_dct"]["transformations"][ | |
| "correlation_plot_selection" | |
| ] = st.session_state["correlation_plot_key"] | |
| # Clear model metadata | |
# NOTE(review): clear_pages is defined elsewhere; its exact effect is not
# visible here beyond the comment above — confirm.
| clear_pages() | |
| # Update DB | |
# The whole project_dct is pickled and stored under page "Transformations",
# file "project_dct".
| update_db( | |
| prj_id=st.session_state["project_number"], | |
| page_nam="Transformations", | |
| file_nam="project_dct", | |
| pkl_obj=pickle.dumps(st.session_state["project_dct"]), | |
| schema=schema, | |
| ) | |
| # Clear data from DB | |
# Remove this project's stored entries for the downstream Model_Build and
# Model_Tuning pages, which depend on the transformations just saved.
| delete_entries( | |
| st.session_state["project_number"], | |
| ["Model_Build", "Model_Tuning"], | |
| db_cred, | |
| schema, | |
| ) | |
| # Success message | |
| st.success("Saved Successfully!", icon="💾") | |
| st.toast("Saved Successfully!", icon="💾") | |
| # Log message | |
| log_message("info", "Saved Successfully!", "Transformations") | |
| except Exception as e: | |
# Catch-all for the enclosing try block (opened above this section): any
# Exception raised in its body is logged with the full formatted traceback and
# replaced by a generic on-page warning. The error is not re-raised, and the
# bound name `e` is not used.
| # Capture the error details | |
| exc_type, exc_value, exc_traceback = sys.exc_info() | |
| error_message = "".join( | |
| traceback.format_exception(exc_type, exc_value, exc_traceback) | |
| ) | |
| # Log message | |
| log_message("error", f"An error occurred: {error_message}.", "Transformations") | |
| # Display a warning message | |
| st.warning( | |
| "Oops! Something went wrong. Please try refreshing the tool or creating a new project.", | |
| icon="⚠️", | |
| ) | |