# File size: 13,012 Bytes
# 08c9602
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
import streamlit as st
import time
import pandas as pd
import plotly.express as px
import snowflake.connector
import base64
from datetime import timedelta, datetime
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
import concurrent.futures

# Import SQL query functions.
from house_ad_queries import (
    get_main_query,
    get_flex_query,
    get_bidder_query,
    get_deal_query,
    get_ad_unit_query,
    get_browser_query,
    get_device_query,
    get_random_integer_query,
    get_hb_pb_query,
    get_hb_size_query,
)

# Import the house ad section config.
from house_ad_section_utils import update_section_generic

# Import the NEXT_STEPS_INSTRUCTIONS at the top.
from house_ad_instructions import NEXT_STEPS_INSTRUCTIONS

# Seed every session-state key this module relies on. ``setdefault`` only
# writes a key that is missing, so values already stored by an earlier script
# run are left untouched.
for _state_key, _initial_value in (
    ("query_run", False),
    ("findings_messages", []),
    ("key_findings_output", None),
    ("query_df", None),
    ("agg_df", None),
    ("top_level_spike_time", None),
):
    st.session_state.setdefault(_state_key, _initial_value)

# --- Helper Functions ---

# def load_private_key(key_str):
#    """Load a PEM-formatted private key."""
#    return serialization.load_pem_private_key(
#        key_str.encode("utf-8"),
#        password=None,
#        backend=default_backend()
#    )


# Use caching to avoid re-running the same query on every interaction.
@st.cache_data(show_spinner=False)
def cached_run_query(
    query,
    private_key_b64: str,
    user: str,
    account_identifier: str,
    warehouse: str,
    database: str,
    schema: str,
    role: str,
) -> pd.DataFrame:
    """Connect to Snowflake, execute ``query`` and return the rows as a DataFrame.

    The function is wrapped in ``st.cache_data`` to avoid re-running the same
    query on every Streamlit interaction.

    Args:
        query: SQL text to execute.
        private_key_b64: Base64-encoded private key; it is decoded and the
            resulting bytes are handed to the connector as ``private_key``.
        user: Snowflake user name.
        account_identifier: Snowflake account identifier.
        warehouse: Warehouse to run the query on.
        database: Database to connect to.
        schema: Schema to connect to.
        role: Role to assume for the session.

    Returns:
        A DataFrame holding every fetched row, with column names taken from
        the cursor description.

    Raises:
        Any exception raised while decoding the key, connecting, or executing
        the query propagates to the caller; the cursor and connection are
        closed first.
    """
    # Decode the base64-encoded key into raw bytes for the connector.
    der = base64.b64decode(private_key_b64)
    conn = snowflake.connector.connect(
        user=user,
        account=account_identifier,
        warehouse=warehouse,
        database=database,
        schema=schema,
        role=role,
        private_key=der,
    )
    # try/finally guarantees the cursor and the connection are released even
    # when the query fails or times out.
    try:
        cs = conn.cursor()
        try:
            # Allow long-running statements (up to 30 minutes) in this session.
            cs.execute("ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = 1800")
            cs.execute(query)
            results = cs.fetchall()
            columns = [col[0] for col in cs.description]
        finally:
            cs.close()
    finally:
        conn.close()
    return pd.DataFrame(results, columns=columns)


# --- Main Function for House Ad Spike Analysis ---


def _prepare_top_level_frames(df):
    """Normalize the top-level query result in place and aggregate it by 5 minutes.

    Upper-cases the column names, builds a ``timestamp`` column from
    ``EST_DATE``/``EST_HOUR``/``EST_MINUTE``, sorts the rows chronologically
    and sums ``CNT`` per 5-minute bucket.

    Returns:
        A ``(df, agg_df)`` tuple: the mutated input frame and the aggregate
        with columns ``5min`` and ``CNT``.
    """
    df.columns = [col.upper() for col in df.columns]
    df["timestamp"] = pd.to_datetime(
        df["EST_DATE"].astype(str)
        + " "
        + df["EST_HOUR"].astype(str).str.zfill(2)
        + ":"
        + df["EST_MINUTE"].astype(str).str.zfill(2)
    )
    # Sort on the full timestamp (date included) so rows from different dates
    # are not interleaved by hour/minute alone.
    df.sort_values(by="timestamp", inplace=True)
    # "5min" is the non-deprecated spelling of the former "5T" offset alias.
    df["5min"] = df["timestamp"].dt.floor("5min")
    agg_df = df.groupby("5min", as_index=False)["CNT"].sum()
    return df, agg_df


def _find_spike_start(agg_df):
    """Return the start of the first run of two consecutive spike rows, or None.

    Rows are scanned in ``5min`` order. When the second consecutive row flagged
    in ``is_spike`` is reached, the start is reported as that row's bucket
    minus five minutes.
    """
    ordered = agg_df.sort_values("5min")
    consecutive = 0
    for bucket, is_spike in zip(ordered["5min"], ordered["is_spike"]):
        if not is_spike:
            consecutive = 0
            continue
        consecutive += 1
        if consecutive == 2:
            return bucket - timedelta(minutes=5)
    return None


def run_house_ad_spike_query(
    table,
    start_datetime,
    end_datetime,
    message_filter,
    campaign_id,
    private_key_str,
    user,
    account_identifier,
    warehouse,
    database,
    schema,
    role,
    client,
):
    """
    Run the house ad spike query along with additional dimensions,
    generate key findings via OpenAI, and display the results.

    Args:
        table, start_datetime, end_datetime, message_filter, campaign_id:
            Forwarded unchanged to every ``get_*_query`` SQL builder.
        private_key_str, user, account_identifier, warehouse, database,
        schema, role: Forwarded unchanged to ``cached_run_query``.
        client: Object exposing ``responses.create(...)``; the returned
            response's ``output_text`` is used as the key-findings text.

    Returns:
        None. Results are rendered through Streamlit and stored in
        ``st.session_state`` (``query_df``, ``agg_df``, ``query_run``,
        ``top_level_spike_time``, ``findings_messages``, ``key_findings``,
        ``key_findings_output``).
    """
    # --- Generate SQL Queries ---
    # All queries are built up front so a failing builder surfaces before any
    # query is sent to Snowflake.
    query_args = (table, start_datetime, end_datetime, message_filter, campaign_id)
    main_sql = get_main_query(*query_args)
    query_dict = {
        "flex bucket": get_flex_query(*query_args),
        "bidder": get_bidder_query(*query_args),
        "deal": get_deal_query(*query_args),
        "ad_unit": get_ad_unit_query(*query_args),
        "browser": get_browser_query(*query_args),
        "device": get_device_query(*query_args),
        "random_integer": get_random_integer_query(*query_args),
        "hb_pb": get_hb_pb_query(*query_args),
        "hb_size": get_hb_size_query(*query_args),
    }

    # --- Main Query Execution ---
    # Run query only if it hasn't been run already.
    if not st.session_state["query_run"]:
        try:
            start_main = time.time()
            with st.spinner("Connecting to Snowflake and running top-level query..."):
                df = cached_run_query(
                    main_sql,
                    private_key_str,
                    user,
                    account_identifier,
                    warehouse,
                    database,
                    schema,
                    role,
                )
            elapsed_main = time.time() - start_main
            elapsed_minutes = int(elapsed_main // 60)
            elapsed_seconds = elapsed_main % 60

            st.info(
                f"Top-level SQL query executed in {elapsed_minutes} minute(s) and {elapsed_seconds:.2f} seconds."
            )

            # Process the results.
            df, agg_df = _prepare_top_level_frames(df)

            st.session_state["query_df"] = df
            st.session_state["agg_df"] = agg_df
            st.session_state["query_run"] = True
        except Exception as e:
            st.error(f"Error during main query execution: {e}")
            return
    else:
        # Use stored data.
        df = st.session_state.get("query_df")
        agg_df = st.session_state.get("agg_df")

    # --- Display Main Query Results ---
    st.header("Top-Level Data")
    # A 5-minute bucket counts as a spike when its total exceeds this baseline.
    top_level_baseline = 30
    agg_df["is_spike"] = agg_df["CNT"] > top_level_baseline
    spike_start = _find_spike_start(agg_df)

    if spike_start is not None:
        msg = f"Top-Level: House ad increase detected starting around {spike_start.strftime('%I:%M %p')}."
        st.success(msg)
    else:
        msg = "Top-Level: No large, consistent spike detected in the current data."
        st.info(msg)
    # Append the message only once.
    findings_messages = st.session_state.setdefault("findings_messages", [])
    if msg not in findings_messages:
        findings_messages.append(msg)
    st.session_state["top_level_spike_time"] = spike_start

    with st.expander("Show Raw Data"):
        st.dataframe(df)
    with st.expander("Show Raw 5-Minute Aggregated Data with Spike Alert"):
        st.dataframe(agg_df)

    title_text = "House Ads Count by 5-Minute Interval"
    fig = px.line(
        agg_df,
        x="5min",
        y="CNT",
        title=title_text,
        labels={"5min": "Time", "CNT": "House Ads Count"},
    )
    fig.update_xaxes(tickformat="%I:%M %p")
    st.plotly_chart(fig, use_container_width=True)

    st.markdown("<hr style='border: 3px solid gray;'>", unsafe_allow_html=True)

    # --- Key Findings via OpenAI ---
    st.header("Key Findings and Next Steps")
    # A single-element placeholder: writing to it again REPLACES its content,
    # so the fresh findings do not get stacked under the previous output (or
    # under the "will appear here" notice).
    key_findings_placeholder = st.empty()

    # Initially display what's in session_state (if anything) or a notice.
    previous_output = st.session_state.get("key_findings_output")
    if previous_output:
        key_findings_placeholder.markdown(previous_output, unsafe_allow_html=True)
    else:
        key_findings_placeholder.info(
            "Key findings will appear here once additional queries have finished."
        )

    def generate_key_findings_callback():
        """Build the summary prompt from session state and store the model's answer."""
        findings = "\n".join(st.session_state.get("findings_messages", []))
        # NOTE(review): "flex_jira_info" is not written in this function;
        # presumably the flex bucket section sets it - confirm.
        flex_jira_info = st.session_state.get("flex_jira_info", "")
        jira_section = (
            f"\nJira Ticket Information from Flex Bucket section:\n{flex_jira_info}\n"
            if flex_jira_info
            else ""
        )
        prompt = (
            "You are a helpful analyst investigating a spike in house ads. A house ad spike detection dashboard has compiled a list of findings "
            "showing potential spikes across different dimensions. Below are the detailed findings from the dashboard, along with any flagged Jira ticket "
            "information. The NEXT_STEPS_INSTRUCTIONS file contains recommended next steps for each section depending on the spike(s) flagged in the dashboard:\n\n"
            f"Findings:\n{findings}\n"
            f"{jira_section}\n"
            "Next Steps Instructions:\n"
            f"{NEXT_STEPS_INSTRUCTIONS}\n\n"
            "Using the Findings, jira section information, and Next Steps Instructions as helpful context, create a concise summary "
            "that identifies the likely cause/causes of any detected house ad spikes and recommends actionable next steps. The summary "
            "should be a few sentences long followed by bullet points with the main findings and recommended next steps. Please output "
            "the summary in Markdown format with each bullet point on a new line, and indent sub-bullets properly. Ensure that each bullet "
            "point is on its own line. There is no need to explicitly mention the Instructions file in the summary, just use it to "
            "inform your analysis. "
        )
        st.session_state["key_findings"] = prompt
        try:
            response = client.responses.create(
                model="o3-mini",
                instructions="You are a helpful analyst who provides insights and recommends next steps.",
                input=prompt,
            )
            st.session_state["key_findings_output"] = response.output_text.strip()
        except Exception as e:
            # Best effort: surface the API failure as the findings text.
            st.session_state["key_findings_output"] = f"Error calling OpenAI API: {e}"

    # --- Additional Queries for Specific Dimensions ---
    st.header("Specific Dimensions Data")
    st.info("Checking specific dimensions for house ad spikes...")

    with st.spinner("Running additional queries..."):
        with concurrent.futures.ThreadPoolExecutor() as executor:
            start_times = {}
            future_to_key = {}
            for key, query in query_dict.items():
                start_times[key] = time.time()
                future = executor.submit(
                    cached_run_query,
                    query,
                    private_key_str,
                    user,
                    account_identifier,
                    warehouse,
                    database,
                    schema,
                    role,
                )
                future_to_key[future] = key

            # One container per dimension, created in a fixed order so the
            # page layout does not depend on which query finishes first.
            containers = {key: st.container() for key in query_dict}

            spike_time = st.session_state.get("top_level_spike_time")

            # Render each dimension as soon as its query completes.
            for future in concurrent.futures.as_completed(future_to_key):
                key = future_to_key[future]
                try:
                    df_result = future.result()
                except Exception as e:
                    # One failing dimension must not abort the others; report
                    # it in that dimension's own container instead.
                    containers[key].error(f"Error during {key} query execution: {e}")
                    continue
                update_section_generic(
                    key, df_result, start_times, containers[key], spike_time
                )

    # Once all additional queries have completed, automatically generate key findings.
    generate_key_findings_callback()

    # Replace the placeholder content with the new output.
    key_findings_placeholder.markdown(
        st.session_state.get("key_findings_output") or "",
        unsafe_allow_html=True,
    )