Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import time | |
| import pandas as pd | |
| import plotly.express as px | |
| import snowflake.connector | |
| import base64 | |
| from datetime import timedelta, datetime | |
| from cryptography.hazmat.primitives import serialization | |
| from cryptography.hazmat.backends import default_backend | |
| import concurrent.futures | |
| # Import SQL query functions. | |
| from house_ad_queries import ( | |
| get_main_query, | |
| get_flex_query, | |
| get_bidder_query, | |
| get_deal_query, | |
| get_ad_unit_query, | |
| get_browser_query, | |
| get_device_query, | |
| get_random_integer_query, | |
| get_hb_pb_query, | |
| get_hb_size_query, | |
| ) | |
| # Import the house ad section config. | |
| from house_ad_section_utils import update_section_generic | |
| # Import the NEXT_STEPS_INSTRUCTIONS at the top. | |
| from house_ad_instructions import NEXT_STEPS_INSTRUCTIONS | |
# Seed every session-state key this page reads. setdefault only writes a key
# that is missing, so values stored by earlier reruns are preserved.
for _state_key, _state_default in (
    ("query_run", False),
    ("findings_messages", []),
    ("key_findings_output", None),
    ("query_df", None),
    ("agg_df", None),
    ("top_level_spike_time", None),
):
    st.session_state.setdefault(_state_key, _state_default)
| # --- Helper Functions --- | |
| # def load_private_key(key_str): | |
| # """Load a PEM-formatted private key.""" | |
| # return serialization.load_pem_private_key( | |
| # key_str.encode("utf-8"), | |
| # password=None, | |
| # backend=default_backend() | |
| # ) | |
# NOTE(review): despite its name, no caching decorator is applied to this
# function, so every call opens a fresh Snowflake connection and re-runs the
# query.
def cached_run_query(
    query,
    private_key_b64: str,
    user: str,
    account_identifier: str,
    warehouse: str,
    database: str,
    schema: str,
    role: str,
):
    """Connect to Snowflake with key-pair auth, run ``query`` and return the rows.

    Args:
        query: SQL text to execute.
        private_key_b64: Base64-encoded DER private key; it is decoded and
            passed to the connector as ``private_key``.
        user, account_identifier, warehouse, database, schema, role:
            Snowflake connection parameters, forwarded unchanged to
            ``snowflake.connector.connect``.

    Returns:
        A ``pandas.DataFrame`` holding every fetched row, with column names
        taken from the cursor description.

    The session statement timeout is raised to 1800 seconds before the query
    runs. The cursor and the connection are always closed, even when a
    statement fails; the error is propagated to the caller.
    """
    # Decode the base64-encoded DER key the connector expects as raw bytes.
    der = base64.b64decode(private_key_b64)
    conn = snowflake.connector.connect(
        user=user,
        account=account_identifier,
        warehouse=warehouse,
        database=database,
        schema=schema,
        role=role,
        private_key=der,
    )
    try:
        cs = conn.cursor()
        try:
            cs.execute("ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = 1800")
            cs.execute(query)
            results = cs.fetchall()
            columns = [col[0] for col in cs.description]
        finally:
            cs.close()
    finally:
        # Release the connection even if execute/fetch raised.
        conn.close()
    return pd.DataFrame(results, columns=columns)
# --- Main Function for House Ad Spike Analysis ---
def _prepare_top_level_frames(df):
    """Normalize the top-level result and aggregate it into 5-minute buckets.

    Upper-cases the column names, then builds a ``timestamp`` column from
    EST_DATE / EST_HOUR / EST_MINUTE. It floors that column into a ``5min``
    column and sums CNT per bucket. Returns ``(df, agg_df)``.
    """
    df.columns = [col.upper() for col in df.columns]
    df["timestamp"] = pd.to_datetime(
        df["EST_DATE"].astype(str)
        + " "
        + df["EST_HOUR"].astype(str).str.zfill(2)
        + ":"
        + df["EST_MINUTE"].astype(str).str.zfill(2)
    )
    # Sort on the full timestamp (not just hour/minute) so ranges spanning
    # several days stay chronological.
    df.sort_values(by="timestamp", inplace=True)
    # "5min" replaces the "5T" alias, which pandas 2.2 deprecated.
    df["5min"] = df["timestamp"].dt.floor("5min")
    agg_df = df.groupby("5min", as_index=False)["CNT"].sum()
    return df, agg_df


def _find_spike_start(agg_df, baseline, min_consecutive=2, interval=timedelta(minutes=5)):
    """Return the start of the first spike run in ``agg_df``, or None.

    A spike run is ``min_consecutive`` or more adjacent buckets whose CNT
    exceeds ``baseline``. Adjacent means exactly ``interval`` apart, so
    buckets separated by a missing (empty) bucket do not count as one run.
    """
    ordered = agg_df.sort_values("5min")
    run_start = None
    run_length = 0
    prev_time = None
    for bucket_time, is_spike in zip(ordered["5min"], ordered["CNT"] > baseline):
        contiguous = prev_time is not None and bucket_time - prev_time == interval
        if is_spike:
            if run_length and contiguous:
                run_length += 1
            else:
                run_start = bucket_time
                run_length = 1
            if run_length >= min_consecutive:
                return run_start
        else:
            run_length = 0
        prev_time = bucket_time
    return None


def _build_key_findings_prompt(findings, flex_jira_info):
    """Build the LLM prompt from the findings text and optional Jira info."""
    jira_section = (
        f"\nJira Ticket Information from Flex Bucket section:\n{flex_jira_info}\n"
        if flex_jira_info
        else ""
    )
    return (
        "You are a helpful analyst investigating a spike in house ads. A house ad spike detection dashboard has compiled a list of findings "
        "showing potential spikes across different dimensions. Below are the detailed findings from the dashboard, along with any flagged Jira ticket "
        "information. The NEXT_STEPS_INSTRUCTIONS file contains recommended next steps for each section depending on the spike(s) flagged in the dashboard:\n\n"
        f"Findings:\n{findings}\n"
        f"{jira_section}\n"
        "Next Steps Instructions:\n"
        f"{NEXT_STEPS_INSTRUCTIONS}\n\n"
        "Using the Findings, jira section information, and Next Steps Instructions as helpful context, create a concise summary "
        "that identifies the likely cause/causes of any detected house ad spikes and recommends actionable next steps. The summary "
        "should be a few sentences long followed by bullet points with the main findings and recommended next steps. Please output "
        "the summary in Markdown format with each bullet point on a new line, and indent sub-bullets properly. Ensure that each bullet "
        "point is on its own line. There is no need to explicitly mention the Instructions file in the summary, just use it to "
        "inform your analysis. "
    )


def run_house_ad_spike_query(
    table,
    start_datetime,
    end_datetime,
    message_filter,
    campaign_id,
    private_key_str,
    user,
    account_identifier,
    warehouse,
    database,
    schema,
    role,
    client,
):
    """
    Run the house ad spike query along with additional dimensions,
    generate key findings via OpenAI, and display the results.

    Args:
        table, start_datetime, end_datetime, message_filter, campaign_id:
            Forwarded unchanged to every ``get_*_query`` SQL builder.
        private_key_str: Base64-encoded DER private key, passed through to
            ``cached_run_query``.
        user, account_identifier, warehouse, database, schema, role:
            Snowflake connection parameters for ``cached_run_query``.
        client: Object exposing ``responses.create(...)``; presumably an
            OpenAI client (TODO confirm against the caller).

    The top-level query runs only while ``st.session_state["query_run"]`` is
    False. After that, the stored frames are reused whatever the arguments
    are, until something resets that flag. The per-dimension queries and the
    key-findings call run on every invocation.
    """
    query_args = (table, start_datetime, end_datetime, message_filter, campaign_id)
    conn_args = (private_key_str, user, account_identifier, warehouse, database, schema, role)

    # --- Generate SQL Queries ---
    main_sql = get_main_query(*query_args)
    # Keys are the section names handed to update_section_generic.
    query_dict = {
        "flex bucket": get_flex_query(*query_args),
        "bidder": get_bidder_query(*query_args),
        "deal": get_deal_query(*query_args),
        "ad_unit": get_ad_unit_query(*query_args),
        "browser": get_browser_query(*query_args),
        "device": get_device_query(*query_args),
        "random_integer": get_random_integer_query(*query_args),
        "hb_pb": get_hb_pb_query(*query_args),
        "hb_size": get_hb_size_query(*query_args),
    }

    # --- Main Query Execution ---
    # Run query only if it hasn't been run already.
    if not st.session_state["query_run"]:
        try:
            start_main = time.time()
            with st.spinner("Connecting to Snowflake and running top-level query..."):
                df = cached_run_query(main_sql, *conn_args)
            elapsed_main = time.time() - start_main
            elapsed_minutes = int(elapsed_main // 60)
            elapsed_seconds = elapsed_main % 60
            st.info(
                f"Top-level SQL query executed in {elapsed_minutes} minute(s) and {elapsed_seconds:.2f} seconds."
            )
            df, agg_df = _prepare_top_level_frames(df)
            st.session_state["query_df"] = df
            st.session_state["agg_df"] = agg_df
            st.session_state["query_run"] = True
        except Exception as e:
            st.error(f"Error during main query execution: {e}")
            return
    else:
        # Use stored data.
        df = st.session_state.get("query_df")
        agg_df = st.session_state.get("agg_df")

    # --- Display Main Query Results ---
    st.header("Top-Level Data")
    top_level_baseline = 30
    # Vectorized comparison: unlike apply(axis=1), this also works on an
    # empty frame, where apply would return a DataFrame and the assignment
    # would raise.
    agg_df["is_spike"] = agg_df["CNT"] > top_level_baseline
    spike_start = _find_spike_start(agg_df, top_level_baseline)
    if spike_start is not None:
        msg = f"Top-Level: House ad increase detected starting around {spike_start.strftime('%I:%M %p')}."
        st.success(msg)
    else:
        msg = "Top-Level: No large, consistent spike detected in the current data."
        st.info(msg)
    # Append the message only once across reruns.
    findings_messages = st.session_state.setdefault("findings_messages", [])
    if msg not in findings_messages:
        findings_messages.append(msg)
    st.session_state["top_level_spike_time"] = spike_start

    with st.expander("Show Raw Data"):
        st.dataframe(df)
    with st.expander("Show Raw 5-Minute Aggregated Data with Spike Alert"):
        st.dataframe(agg_df)
    title_text = "House Ads Count by 5-Minute Interval"
    fig = px.line(
        agg_df,
        x="5min",
        y="CNT",
        title=title_text,
        labels={"5min": "Time", "CNT": "House Ads Count"},
    )
    fig.update_xaxes(tickformat="%I:%M %p")
    st.plotly_chart(fig, use_container_width=True)
    st.markdown("<hr style='border: 3px solid gray;'>", unsafe_allow_html=True)

    # --- Key Findings via OpenAI ---
    st.header("Key Findings and Next Steps")
    # A single-element slot: each write replaces the previous content, so the
    # findings are never rendered twice (stale output followed by new output).
    key_findings_slot = st.empty()
    previous_output = st.session_state.get("key_findings_output")
    # NOTE(review): LLM output is rendered with unsafe_allow_html=True. The
    # prompt asks for Markdown, so raw HTML should not be needed; consider
    # dropping the flag.
    if previous_output:
        key_findings_slot.markdown(previous_output, unsafe_allow_html=True)
    else:
        key_findings_slot.info(
            "Key findings will appear here once additional queries have finished."
        )

    def generate_key_findings_callback():
        """Build the prompt from accumulated findings and store the LLM summary."""
        findings = "\n".join(st.session_state.get("findings_messages", []))
        # Presumably populated by update_section_generic for the flex bucket
        # section; TODO confirm.
        flex_jira_info = st.session_state.get("flex_jira_info", "")
        prompt = _build_key_findings_prompt(findings, flex_jira_info)
        st.session_state["key_findings"] = prompt
        try:
            response = client.responses.create(
                model="o3-mini",
                instructions="You are a helpful analyst who provides insights and recommends next steps.",
                input=prompt,
            )
            st.session_state["key_findings_output"] = response.output_text.strip()
        except Exception as e:
            st.session_state["key_findings_output"] = f"Error calling OpenAI API: {e}"

    # --- Additional Queries for Specific Dimensions ---
    st.header("Specific Dimensions Data")
    st.info("Checking specific dimensions for house ad spikes...")
    with st.spinner("Running additional queries..."):
        with concurrent.futures.ThreadPoolExecutor() as executor:
            start_times = {}
            future_to_key = {}
            for key, query in query_dict.items():
                start_times[key] = time.time()
                future_to_key[executor.submit(cached_run_query, query, *conn_args)] = key
            # Create the containers in a fixed order so the page layout does
            # not depend on which query finishes first.
            containers = {key: st.container() for key in query_dict}
            spike_time = st.session_state.get("top_level_spike_time")
            # Streamlit calls happen here on the script thread; worker threads
            # only run the queries.
            for future in concurrent.futures.as_completed(future_to_key):
                key = future_to_key[future]
                try:
                    df_result = future.result()
                except Exception as e:
                    # Keep the other sections alive when one dimension fails.
                    containers[key].error(f"Error running {key} query: {e}")
                    continue
                update_section_generic(
                    key, df_result, start_times, containers[key], spike_time
                )
        # Once all additional queries have completed, automatically generate key findings.
        generate_key_findings_callback()

    # Replace the placeholder (or stale findings) with the new output.
    key_findings_slot.markdown(
        st.session_state.get("key_findings_output", ""),
        unsafe_allow_html=True,
    )