| import html |
| import json |
| import os |
| from io import StringIO |
| import streamlit as st |
| import pandas as pd |
| from bs4 import BeautifulSoup |
| from snowflake.snowpark import Session |
|
|
| from Messaging_system.Permes import Permes |
| from Messaging_system.context_validator import Validator |
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
| |
| |
@st.cache_data
def load_data(file_path):
    """
    Parse a CSV source into a pandas DataFrame.

    Wrapped in ``st.cache_data`` so Streamlit can reuse the parsed result across reruns
    instead of re-reading the same input every time the script re-executes.

    :param file_path: anything ``pandas.read_csv`` accepts — a path or a file-like
        object such as the upload returned by ``st.file_uploader``.
    :return: the parsed DataFrame.
    """
    frame = pd.read_csv(file_path)
    return frame
|
|
| |
|
|
def load_config_(file_path):
    """
    Load a JSON configuration file from local disk (e.g. the message-system config).

    The file is decoded explicitly as UTF-8 rather than with the platform default
    encoding, which differs between operating systems and can garble non-ASCII values.
    ``utf-8-sig`` is used so a leading byte-order mark (as written by some Windows
    editors) is stripped instead of making ``json.load`` fail; BOM-less files decode
    exactly as plain UTF-8.

    :param file_path: path to the JSON file (``str`` or ``os.PathLike``).
    :return: the parsed JSON value.
    :raises FileNotFoundError: if ``file_path`` does not exist.
    :raises json.JSONDecodeError: if the file content is not valid JSON.
    """
    with open(file_path, "r", encoding="utf-8-sig") as file:
        return json.load(file)
|
|
|
|
| |
|
|
|
|
| |
# Page-level configuration: browser tab title and icon, plus the wide layout.
st.set_page_config(page_title="Personalized Message Generator", page_icon=":mailbox_with_mail:", layout="wide")


# Inject custom CSS: black body background with gold text/headings, gold buttons with
# black labels, and two helper classes (.section, .input-label).
# unsafe_allow_html=True is required for Streamlit to emit the raw <style> tag; the CSS
# is a static literal, so no user input reaches it.
# NOTE(review): Streamlit's own theme styles may take precedence over the `body`
# colours — confirm these rules actually take effect in the deployed version.
st.markdown(
    """
    <style>
    body {
        background-color: #000000;
        color: #FFD700;
    }
    .stButton > button {
        background-color: #FFD700;
        color: #000000;
    }
    h1, h2, h3, h4, h5, h6 {
        color: #FFD700;
    }
    .section {
        margin-bottom: 30px;
    }
    .input-label {
        font-size: 18px;
        font-weight: bold;
        margin-top: 10px;
    }
    </style>
    """,
    unsafe_allow_html=True
)
|
|
|
|
| |
|
|
def filter_validated_users(users):
    """
    Keep only the rows whose ``valid`` flag is true.

    The validator may hand back the flag as the strings ``'True'``/``'False'``, as real
    booleans, or in another spelling. Every value is normalised to a real ``bool``:
    booleans are kept as-is, and anything else counts as true only when its text is
    ``"true"`` (case-insensitive, surrounding whitespace ignored). Missing or
    unrecognised values therefore count as *not* valid, instead of becoming NaN and
    crashing the boolean mask.

    The caller's DataFrame is not modified.

    :param users: pandas DataFrame with a ``valid`` column.
    :return: a new DataFrame holding only the valid rows, with a fresh 0..n-1 index and a
        boolean ``valid`` column.
    :raises KeyError: if ``users`` has no ``valid`` column.
    """

    def _is_valid(value):
        # Real booleans pass straight through; everything else is compared by text.
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() == "true"

    # Work on a copy so the caller's frame keeps its original 'valid' values.
    result = users.copy()
    result["valid"] = result["valid"].map(_is_valid).astype(bool)
    return result[result["valid"]].reset_index(drop=True)
|
|
| |
|
|
| |
|
|
|
|
|
|
| |
def clean_html_tags(users_df):
    """
    Strip HTML markup and entities from every cell of ``users_df``.

    Each column is passed element-wise through ``clean_text``. Non-string cells come
    back unchanged. The DataFrame is modified in place and the same object is also
    returned for convenience.

    :param users_df: pandas DataFrame whose text cells may contain HTML.
    :return: ``users_df`` itself, with its cells cleaned.
    """
    for column_name in list(users_df.columns):
        cleaned_column = users_df[column_name].apply(clean_text)
        users_df[column_name] = cleaned_column
    return users_df
|
|
|
|
| |
|
|
def clean_text(text):
    """
    Return the visible text of an HTML fragment.

    HTML entities are decoded first (so escaped markup such as ``&lt;b&gt;`` becomes real
    tags). The result is then parsed with BeautifulSoup and only its text content is
    kept. Non-string values (numbers, None, NaN, …) are returned untouched.

    :param text: a cell value, possibly containing HTML.
    :return: the plain text for strings, otherwise ``text`` unchanged.
    """
    if not isinstance(text, str):
        return text
    unescaped = html.unescape(text)
    return BeautifulSoup(unescaped, "html.parser").get_text()
|
|
|
|
| |
| |
# Read the OpenAI key from the OPENAI_API environment variable (which load_dotenv() may
# have populated from a .env file). It is kept as a module global and also copied into
# session state, where the validator setup reads it. The value is None when the
# variable is unset.
openai_api_key = os.getenv('OPENAI_API')
st.session_state.openai_api_key = openai_api_key
| |
|
|
|
|
| |
|
|
def initialize_session_state():
    """
    Prepare ``st.session_state`` for the current Streamlit run.

    A few flags and the validation instructions are reset on every run. The widgets
    rendered later in the same run set them again when relevant, so values from a
    previous interaction cannot leak in. All other keys are only created, with the value
    ``None``, when they are missing, so choices made in earlier runs survive reruns.
    """
    # Reset unconditionally on every rerun.
    st.session_state["involve_recsys_result"] = False
    st.session_state["involve_last_interaction"] = False
    st.session_state.valid_instructions = ""
    st.session_state.invalid_instructions = ""

    default_none_keys = [
        "data", "brand", "recsys_contents", "generated", "csv_output", "users_message",
        "messaging_mode", "messaging_type", "target_column", "ugc_column",
        "identifier_column", "input_validator",
        # Keep these on separate list entries: without a comma between them, Python's
        # implicit string concatenation would fuse them into one bogus key.
        "selected_input_features",
        "selected_features",
        "additional_instructions", "segment_info", "message_style", "sample_example",
        "CTA", "all_features", "number_of_messages", "instructionset", "segment_name",
        "number_of_samples", "selected_source_features", "platform",
        # Read by generate_personalized_messages; make sure it always exists.
        "user_generated_context",
    ]
    for key in default_none_keys:
        if key not in st.session_state:
            st.session_state[key] = None
|
|
|
|
def upload_csv_file():
    """
    Render the CSV upload widget and load the chosen file.

    On upload, the DataFrame is stored in ``st.session_state.data`` and its column names
    are shown.

    :return: the loaded DataFrame, or ``None`` while no file has been uploaded.
    """
    st.header("Upload CSV File")
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    if uploaded_file is None:
        return None

    users = load_data(uploaded_file)
    st.write(f"Data loaded from {uploaded_file.name}")
    st.session_state.data = users

    st.subheader("Available Columns in Uploaded CSV")
    st.write(users.columns.tolist())
    return users
|
|
|
|
def select_identifier_column(users):
    """
    Let the user choose which column uniquely identifies each user.

    The choice is stored in ``st.session_state.identifier_column``.

    :param users: uploaded users DataFrame; its columns are offered as options.
    """
    st.header("Select Identifier Column")
    st.session_state.identifier_column = st.selectbox(
        "Select the identifier column",
        users.columns.tolist(),
    )
    st.markdown("---")
|
|
|
|
def select_target_audience():
    """Let the user pick the brand the messages target; stored in ``st.session_state.brand``."""
    st.header("Select Target Audience")
    brand_options = ["drumeo", "pianote", "guitareo", "singeo"]
    st.session_state.brand = st.selectbox("Choose the brand for the users", brand_options)
    st.markdown("---")
|
|
|
|
def select_target_messaging_type():
    """
    Let the user pick the delivery channel.

    The human-readable label is translated to the short platform code that the
    generator expects ("push" or "app") and stored in
    ``st.session_state.messaging_type``.
    """
    st.header("Select Target Messaging Type")
    platform_by_label = {
        "Push Notification": "push",
        "In-App Notification": "app",
    }
    chosen_label = st.selectbox("Choose the target messaging type", list(platform_by_label))
    st.session_state.messaging_type = platform_by_label[chosen_label]
    st.markdown("---")
|
|
|
|
def parse_number_of_samples(raw, default=20):
    """
    Convert the free-text "number of samples" field into a positive int.

    :param raw: text typed by the user (may be ``None``, empty, or padded with spaces).
    :param default: value returned when the field is left blank.
    :return: the parsed positive int, ``default`` for blank input, or ``None`` when the
        text is not a positive whole number (e.g. ``"abc"``, ``"2.5"``, ``"0"``).
    """
    text = "" if raw is None else str(raw).strip()
    if not text:
        return default
    try:
        value = int(text)
    except ValueError:
        return None
    return value if value > 0 else None


def input_personalization_parameters():
    """
    Render the personalization inputs and store them in ``st.session_state``.

    Stores ``segment_info``, ``CTA``, ``message_style``, ``sample_example`` and
    ``number_of_samples``. Invalid sample counts no longer crash the app: a warning is
    shown and the default of 20 is used instead.
    """
    st.header("Personalization Parameters")
    st.session_state.segment_info = st.text_area("Segment Info", "", placeholder="Tell us more about the users...")
    st.session_state.CTA = st.text_area("CTA", "", placeholder="e.g., check out 'Inspired by your activity' that we have crafted just for you!")
    st.session_state.message_style = st.text_area("Message Style", "", placeholder="(optional) e.g., be kind and friendly (it's better to be as specific as possible)")
    st.session_state.sample_example = st.text_area("Sample Example", "", placeholder="(optional) e.g., Hello! We have crafted a perfect set of courses just for you!")
    number_of_samples = st.text_input("Number of samples to generate messages", "20", placeholder="(optional) default is 20")

    parsed_samples = parse_number_of_samples(number_of_samples)
    if parsed_samples is None:
        # int() on arbitrary text used to raise and abort the whole page.
        st.warning(f"'{number_of_samples}' is not a valid positive whole number; using the default of 20.")
        parsed_samples = 20
    st.session_state.number_of_samples = parsed_samples
    st.markdown("---")
|
|
|
|
def input_message_sequence_parameters():
    """
    Collect the settings for generating a sequence of messages per user.

    Stores in ``st.session_state``:

    - ``number_of_messages``: how many sequential messages to generate (1-10).
    - ``segment_name``: free-text segment label.
    - ``instructionset``: a dict mapping each 1-based message position to its
      stripped instruction text. Positions left blank are omitted.
    """
    st.header("Sequential Messaging Parameters")

    message_count = st.number_input(
        "Number of sequential messages to generate (per user)",
        min_value=1,
        max_value=10,
        value=1,
        step=1,
        key="num_seq_msgs",
    )
    st.session_state.number_of_messages = message_count

    st.session_state.segment_name = st.text_input(
        "Segment Name",
        value="",
        placeholder="e.g., no_recent_activity",
        key="segment_name_input",
    )

    st.subheader("Instructions per Message")
    st.caption("Provide additional tone or style instructions for each sequential message. Leave blank to inherit the main instructions.")

    # One column per message; each gets its own instruction box keyed by position.
    per_message_instructions = {}
    for position, column in enumerate(st.columns(message_count), start=1):
        with column:
            typed = st.text_input(
                f"Message {position} instructions",
                value="",
                placeholder="e.g., Be Cheerful & Motivational",
                key=f"instr_{position}",
            )
        trimmed = typed.strip()
        if trimmed:
            per_message_instructions[position] = trimmed

    st.session_state.instructionset = per_message_instructions
    st.markdown("---")
|
|
|
|
def select_features_from_source_info():
    """
    Let the user choose which source-side user attributes to use in messages.

    ``"instrument"`` is always appended after the user's picks. The result is stored in
    ``st.session_state.selected_source_features``.
    """
    st.header("Select Features from Available Source Information")
    candidate_features = ["first_name", "biography", "birthday_reminder", "goals", "Minutes_practiced", "Last_completed_content"]
    picked = st.multiselect("Select features to use from available source information", candidate_features)
    st.session_state.selected_source_features = picked + ["instrument"]
    st.markdown("---")
|
|
|
|
def select_features_from_input_file(users):
    """
    Let the user choose which uploaded-file columns to reference in messages.

    The chosen column names are stored in ``st.session_state.selected_features``.

    :param users: uploaded users DataFrame; its columns are the options.
    """
    st.header("Select Features from your Input file")
    st.session_state.selected_features = st.multiselect(
        "Select features to use in generated messages from the input file",
        users.columns.tolist(),
    )
    st.markdown("---")
|
|
|
|
def provide_additional_instructions():
    """Capture free-text guidance on using the selected features; stored as ``additional_instructions``."""
    st.header("Additional Instructions")
    st.session_state.additional_instructions = st.text_area(
        "Provide additional instructions on how to use selected features in the generated message",
        "",
    )
    st.markdown("---")
|
|
|
|
def parse_user_generated_context(users):
    """
    Render the user-generated-context (UGC) validation section.

    Records in ``st.session_state`` whether UGC filtering is wanted. When it is, also
    records which column holds the UGC and any extra valid/invalid instructions. A
    ``Validator`` built with ``st.session_state.openai_api_key`` is stored as
    ``st.session_state.input_validator``.

    :param users: uploaded users DataFrame; its columns populate the UGC column picker.
    """
    st.header("Parsing User-Generated Context")
    wants_ugc_filter = st.checkbox("Do we have a user-generated context provided in the input that you wish to filter?")
    st.session_state.user_generated_context = wants_ugc_filter

    if wants_ugc_filter:
        st.session_state.ugc_column = st.selectbox(
            "Select the column that contains User-Generated Context",
            users.columns.tolist(),
        )

        st.subheader("Provide Additional Instructions for Validation (Optional)")
        valid_text = st.text_area("Instructions for valid context", placeholder="Provide instructions for what constitutes valid context...")
        invalid_text = st.text_area("Instructions for invalid context", placeholder="Provide instructions for what constitutes invalid context...")
        st.session_state.valid_instructions = valid_text
        st.session_state.invalid_instructions = invalid_text

    st.session_state.input_validator = Validator(api_key=st.session_state.openai_api_key)
    st.markdown("---")
|
|
|
|
def include_content_recommendations():
    """
    Render the content-recommendation section and record the chosen mode.

    Three outcomes are possible:

    - unchecked: plain ``"message"`` mode with no recommendation target.
    - "Musora Recommender System": ``"recsys_result"`` mode with the selected content
      types in ``recsys_contents``.
    - "Input File": ``"message"`` mode with a ``target_column`` taken from the upload.

    Every branch now clears the state belonging to the other branches. Previously a
    value chosen earlier (e.g. ``target_column`` before switching to the recommender, or
    ``recsys_contents`` after unchecking) was silently passed on to generation.
    Each cleared value is reset to ``None``, which is also its initial value.
    """
    st.header("Include Content Recommendations")
    include_recommendation = st.checkbox("Would you like to include content in the message to recommend to the students?")
    st.session_state.include_recommendation = include_recommendation

    if include_recommendation:
        recommendation_source = st.radio("Select recommendation source", ["Input File", "Musora Recommender System"])
        st.session_state.recommendation_source = recommendation_source

        if recommendation_source == "Musora Recommender System":
            st.session_state.involve_recsys_result = True
            st.session_state.messaging_mode = "recsys_result"
            st.session_state.target_column = None

            list_of_content_types = ["song", "workout", "quick_tips", "course"]
            selected_content_types = st.multiselect("Select content_types that you would like to recommend", list_of_content_types)
            st.session_state.recsys_contents = selected_content_types
        else:
            st.session_state.involve_recsys_result = False
            st.session_state.messaging_mode = "message"
            st.session_state.recsys_contents = None
            columns = st.session_state.data.columns.tolist()
            target_column = st.selectbox("Select the target column for recommendations", columns)
            st.session_state.target_column = target_column
    else:
        st.session_state.involve_recsys_result = False
        st.session_state.messaging_mode = "message"
        st.session_state.target_column = None
        st.session_state.recsys_contents = None
    st.markdown("---")
|
|
|
|
def _snowflake_connection_params():
    """Collect Snowpark connection parameters from ``snowflake_*`` env vars (unset -> None)."""
    parameter_names = ("user", "password", "account", "role", "database", "warehouse", "schema")
    return {name: os.environ.get(f"snowflake_{name}") for name in parameter_names}


def _make_progress_callback(progress_bar, status_text, label):
    """
    Build a ``progress_callback(progress, total)`` that drives a Streamlit progress bar.

    :param progress_bar: element returned by ``st.progress``.
    :param status_text: placeholder returned by ``st.empty`` for the percentage line.
    :param label: text shown before the percentage, e.g. ``"Processing"``.
    """

    def progress_callback(progress, total):
        # A zero total (nothing to process) would raise ZeroDivisionError; treat it as done.
        percent_complete = int(progress / total * 100) if total else 100
        # st.progress only accepts 0-100, so clamp out-of-range reports.
        percent_complete = max(0, min(100, percent_complete))
        progress_bar.progress(percent_complete)
        status_text.text(f"{label}: {percent_complete}%")

    return progress_callback


def _validate_user_generated_context(users):
    """
    Run the stored Validator over the UGC column and drop the rows it marks invalid.

    :param users: pandas DataFrame of users.
    :return: the DataFrame returned by ``filter_validated_users``.
    """
    validator = st.session_state.input_validator
    valid_instructions = st.session_state.valid_instructions or ""
    invalid_instructions = st.session_state.invalid_instructions or ""
    if valid_instructions.strip() or invalid_instructions.strip():
        validator.set_validator_instructions(
            valid_instructions=valid_instructions,
            invalid_instructions=invalid_instructions,
        )
    else:
        validator.set_validator_instructions()

    progress_bar = st.progress(0)
    status_text = st.empty()
    st.info("Validating user-generated content. This may take a few moments...")
    try:
        validated = validator.validate_dataframe(
            dataframe=users,
            target_column=st.session_state.ugc_column,
            progress_callback=_make_progress_callback(progress_bar, status_text, "Validating user_generated_context"),
        )
    finally:
        # Remove the validation progress widgets once done (or on failure).
        progress_bar.empty()
        status_text.empty()

    filtered = filter_validated_users(validated)
    st.success("User-generated content has been validated and filtered.")
    return filtered


def _run_message_generation(users):
    """
    Optionally validate ``users``, generate their messages, and store the results.

    Results go to ``st.session_state.users_message``, ``csv_output`` and ``generated``.
    A Snowpark session is opened for the run and always closed afterwards, including
    when validation or generation raises.
    """
    config_file = load_config_('Config_files/message_system_config.json')
    # Connect before the (potentially slow) validation step so bad credentials fail fast.
    session = Session.builder.configs(_snowflake_connection_params()).create()
    try:
        if st.session_state.user_generated_context:
            users = _validate_user_generated_context(users)
            if users.empty:
                st.warning("No users passed the user-generated context validation, so no messages were generated.")
                return

        st.session_state.all_features = (
            st.session_state.selected_source_features + st.session_state.selected_features
        )
        st.session_state.involve_last_interaction = (
            "Last_completed_content" in st.session_state.selected_source_features
        )

        progress_bar = st.progress(0)
        status_text = st.empty()
        try:
            users_message = Permes().create_personalize_messages(
                session=session,
                users=users,
                brand=st.session_state.brand,
                config_file=config_file,
                openai_api_key=os.environ.get('OPENAI_API'),
                CTA=st.session_state.CTA,
                segment_info=st.session_state.segment_info,
                number_of_samples=st.session_state.number_of_samples,
                message_style=st.session_state.message_style,
                sample_example=st.session_state.sample_example,
                selected_input_features=st.session_state.selected_features,
                selected_source_features=st.session_state.selected_source_features,
                additional_instructions=st.session_state.additional_instructions,
                platform=st.session_state.messaging_type,
                involve_last_interaction=st.session_state.involve_last_interaction,
                involve_recsys_result=st.session_state.involve_recsys_result,
                messaging_mode=st.session_state.messaging_mode,
                identifier_column=st.session_state.identifier_column,
                target_column=st.session_state.target_column,
                recsys_contents=st.session_state.recsys_contents,
                progress_callback=_make_progress_callback(progress_bar, status_text, "Processing"),
                number_of_messages=st.session_state.number_of_messages,
                instructionset=st.session_state.instructionset,
                segment_name=st.session_state.segment_name,
            )
        finally:
            progress_bar.empty()
            status_text.empty()
    finally:
        # The session was previously never closed, leaking a Snowflake connection per click.
        session.close()

    st.session_state.csv_output = users_message.to_csv(encoding='utf-8-sig', index=False)
    st.session_state.users_message = users_message
    st.session_state.generated = True
    st.success("Personalized messages have been generated.")


def generate_personalized_messages(users):
    """
    Render the generate button and, when it is clicked, produce the messages.

    CTA and Segment Info are mandatory; if either is blank an error is shown and nothing
    is generated.

    :param users: uploaded users DataFrame to generate messages for.
    """
    st.header("Generate Personalized Messages")
    if st.button("Generate Personalized Messages"):
        cta = st.session_state.CTA or ""
        segment_info = st.session_state.segment_info or ""
        if not cta.strip() or not segment_info.strip():
            st.error("CTA and Segment Info are mandatory fields and cannot be left empty.")
        else:
            _run_message_generation(users)
    st.markdown("---")
|
|
|
|
def download_generated_messages():
    """
    Offer the generated messages as a CSV download once generation has completed.

    The CSV text is encoded as ``utf-8-sig``, i.e. UTF-8 with a leading byte-order mark.
    Presumably this is so spreadsheet tools detect the encoding of non-ASCII text.
    """
    if not st.session_state.get('generated', False):
        return

    st.header("Download Generated Messages")
    messages_df = st.session_state.users_message

    # to_csv() without a target returns the CSV text; encoding it adds the BOM once.
    csv_bytes = messages_df.to_csv(index=False).encode('utf-8-sig')

    st.download_button(
        label="Download output messages as a CSV file",
        data=csv_bytes,
        file_name='personalized_messages.csv',
        mime='text/csv',
    )
|
|
|
|
def parse_message_payload(raw):
    """
    Decode a generated ``message`` cell into a list of message dicts.

    Accepted shapes are a JSON list of message objects, or a JSON object with a
    ``messages_sequence`` list. A cell that already holds a Python list or dict is used
    directly.

    :param raw: the raw cell value; normally a JSON string, but may be NaN/None when a
        row has no messages.
    :return: ``(messages, status)``. ``messages`` is a list of dicts. ``status`` is one of:

        - ``"ok"``: the payload was fully usable.
        - ``"invalid_json"``: the value could not be decoded at all.
        - ``"unexpected_structure"``: the payload decoded to something other than the
          accepted shapes.
        - ``"skipped_items"``: the list contained non-object entries, which were dropped.
    """
    if isinstance(raw, (list, dict)):
        parsed = raw
    else:
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):
            # TypeError: NaN/None cells; ValueError (incl. JSONDecodeError): malformed text.
            return [], "invalid_json"

    if isinstance(parsed, dict) and "messages_sequence" in parsed:
        parsed = parsed["messages_sequence"]
    if not isinstance(parsed, list):
        return [], "unexpected_structure"

    messages = [item for item in parsed if isinstance(item, dict)]
    return messages, ("ok" if len(messages) == len(parsed) else "skipped_items")


def view_generated_messages():
    """
    Render a per-user review of the generated messages.

    Nothing is rendered until generation has completed. Each user gets an expander
    (only the first is expanded) that shows the selected features followed by each
    message's thumbnail, header, body and link. Malformed message payloads produce an
    error or warning for that user instead of crashing the page.
    """
    if not st.session_state.get('generated', False):
        return

    st.title("🔔 Generated Push Notifications Review")
    df = st.session_state.users_message
    # NOTE(review): lower-casing assumes the generator returns lower-cased column names.
    identifier = st.session_state.identifier_column.lower()
    features = st.session_state.all_features or []

    for idx, (_, user_row) in enumerate(df.iterrows(), start=1):
        user_id = user_row.get(identifier, "N/A")

        with st.expander(f"{idx}. User ID: {user_id}", expanded=(idx == 1)):
            st.markdown("##### 👤 User Features")
            # Lay the features out across three columns, filling them round-robin.
            feature_cols = st.columns(3)
            for i, feat in enumerate(features):
                val = user_row.get(feat, "N/A")
                feature_cols[i % 3].write(f"**{feat}**: {val}")

            st.markdown("---")
            st.markdown("##### 📝 Generated Messages")
            messages, status = parse_message_payload(user_row.get('message', '[]'))
            if status == "invalid_json":
                st.error("Could not parse message JSON")
            elif status == "unexpected_structure":
                st.warning(
                    "Unexpected JSON structure for messages; expected a list or {'messages_sequence': [...]}")
            elif status == "skipped_items":
                st.warning("Some message entries were not JSON objects and were skipped.")

            for m_idx, msg in enumerate(messages, start=1):
                c_img, c_text = st.columns([1, 3])
                with c_img:
                    thumb = msg.get('thumbnail_url')
                    if thumb:
                        st.image(thumb, width=80)
                    else:
                        st.write("No image")

                with c_text:
                    # `or` fallbacks also cover keys present with a null value.
                    header = msg.get('header') or ''
                    body = msg.get('message') or ''
                    link = msg.get('web_url_path') or '#'
                    st.markdown(f"**{m_idx}. {header}**")
                    st.markdown(str(body))
                    st.markdown(f"[Read more →]({link})")

                st.markdown("---")
|
|
|
|
if __name__ == "__main__":
    st.title("Personalized Message Generator")

    # Reset per-run flags and create any session_state keys that do not exist yet.
    initialize_session_state()

    # Returns None until the user has uploaded a CSV file.
    users = upload_csv_file()

    if users is not None:
        # Each call renders one page section and records the user's choices in
        # st.session_state. generate_personalized_messages reads those choices when
        # its button is clicked, so these sections must render before it.
        select_identifier_column(users)
        select_target_audience()
        select_target_messaging_type()
        input_personalization_parameters()
        input_message_sequence_parameters()
        select_features_from_source_info()
        select_features_from_input_file(users)
        provide_additional_instructions()
        parse_user_generated_context(users)
        include_content_recommendations()
        generate_personalized_messages(users)
        # Both of these render only after messages have been generated.
        download_generated_messages()
        view_generated_messages()
|
|