Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
def _viz_rank(results):
    """Render the concept ranking as HTML, one circle per concept.

    Concepts are ordered by increasing mean rejection time (``tau`` averaged
    over axis 0), so the first entry is the most important concept. Each
    concept name is shown next to a circle whose diameter is proportional to
    ``1 - tau``. Diameters are rescaled so the largest circle is 80px wide.

    Args:
        results: Mapping with at least:
            - ``"tau"``: array whose ``mean(axis=0)`` yields one rejection time
              per concept (presumably shaped ``(n_tests, n_concepts)`` —
              confirm against the producer of ``results``).
            - ``"concepts"``: sequence of concept names, indexed in the same
              order as the last axis of ``tau``.
    """
    import html  # local import: only this block needs HTML escaping

    tau_mu = results["tau"].mean(axis=0)
    concepts = results["concepts"]
    sorted_idx = np.argsort(tau_mu)

    widths = 1 - tau_mu[sorted_idx]
    # ``initial=0.0`` makes the max well-defined for an empty concept list.
    max_width = widths.max(initial=0.0)
    if max_width > 0:
        widths = widths / max_width * 80
    else:
        # No concept has tau < 1, so normalising would divide 0 by 0 and emit
        # NaN CSS widths; draw zero-size circles instead.
        widths = np.zeros_like(widths)

    rows = []
    for idx, width in zip(sorted_idx, widths):
        circle_style = (
            "background: #418FDE;border-radius: 50%;width:"
            f" {width}px;padding-bottom:"
            f" {width}px;"
        )
        # Concept names are user-provided and rendered with
        # unsafe_allow_html=True, so escape them before embedding.
        # NOTE(review): the ids below repeat for every concept (invalid HTML);
        # they are kept as-is because external CSS may target them — confirm.
        name = html.escape(str(concepts[idx]))
        rows.append(
            "<div id='conceptContainer'><p"
            f" id='concept'><strong>{name}</strong></p><div id='circleContainer'><div"
            f" style='{circle_style}'></div></div></div>"
        )
    st.markdown("".join(rows), unsafe_allow_html=True)
def _viz_test(results):
    """Plot per-concept rejection time (bars) and rejection rate (line).

    Both quantities are averaged over axis 0 of the corresponding arrays in
    ``results``. A dashed vertical line marks the significance level. The
    chart is drawn in the centre column of a 1:3:1 Streamlit column layout.

    Args:
        results: Mapping with at least:
            - ``"tau"``: array whose ``mean(axis=0)`` gives one (normalized)
              rejection time per concept.
            - ``"rejected"``: array whose ``mean(axis=0)`` gives one rejection
              rate per concept.
            - ``"concepts"``: sequence of concept names aligned with the last
              axis of the arrays above.
            - ``"significance_level"``: the test level alpha (x-position of
              the vertical reference line).
    """
    tau_mu = results["tau"].mean(axis=0)
    rejected_mu = results["rejected"].mean(axis=0)
    concepts = results["concepts"]
    significance_level = results["significance_level"]

    # Descending tau: the most important concept (smallest tau) comes last.
    # Presumably this is so it appears at the top of Plotly's categorical
    # y axis — confirm visually.
    sorted_idx = np.argsort(tau_mu)[::-1]
    sorted_concepts = [concepts[idx] for idx in sorted_idx]
    rank_df = pd.DataFrame(
        {
            "concept": sorted_concepts,
            "tau": tau_mu[sorted_idx],
            "rejected": rejected_mu[sorted_idx],
        }
    )

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=rank_df["rejected"],
            y=rank_df["concept"],
            marker=dict(size=8),
            line=dict(color="#1f78b4", dash="dash"),
            name="Rejection rate",
        )
    )
    fig.add_trace(
        go.Bar(
            x=rank_df["tau"],
            y=rank_df["concept"],
            orientation="h",
            marker=dict(color="#a6cee3"),
            name="Rejection time",
        )
    )
    # Zero-length line at (alpha, first concept). It draws nothing visible but
    # presumably exists to give the vline below a legend entry.
    fig.add_trace(
        go.Scatter(
            x=[significance_level, significance_level],
            y=[sorted_concepts[0], sorted_concepts[0]],
            mode="lines",
            line=dict(color="black", dash="dash"),
            name="significance level",
        )
    )
    fig.add_vline(significance_level, line_dash="dash", line_color="black")
    fig.update_layout(
        yaxis_title="Rank of importance",
        xaxis_title="",
        margin=dict(l=20, r=20, t=20, b=20),
    )
    # Move the legend to x=0.3, y=1.0 with a border when some concept's mean
    # rejection time is at most 0.3 (threshold kept from the original code).
    if rank_df["tau"].min() <= 0.3:
        fig.update_layout(
            legend=dict(
                x=0.3,
                y=1.0,
                bordercolor="black",
                borderwidth=1,
            ),
        )

    _, centercol, _ = st.columns([1, 3, 1])
    with centercol:
        st.plotly_chart(fig, use_container_width=True)
def _viz_wealth(results):
    """Plot the mean wealth process of each concept's sequential test.

    ``wealth`` is averaged over axis 0. The resulting 2-D array is indexed as
    ``[time_step, concept]``. A dashed horizontal line marks the rejection
    threshold ``1 / significance_level``. The y axis is clipped to
    ``[0, 1.5 / significance_level]``.

    Args:
        results: Mapping with at least:
            - ``"wealth"``: 3-D array; after ``mean(axis=0)`` it is
              ``(n_steps, n_concepts)``. Axis 0 is presumably the repeated
              tests/draws — confirm against the producer.
            - ``"concepts"``: sequence of concept names for the last axis.
            - ``"significance_level"``: the test level alpha.
    """
    wealth_mu = results["wealth"].mean(axis=0)
    concepts = results["concepts"]
    significance_level = results["significance_level"]

    # Long-format frame in concept-major order (all steps of concept 0, then
    # concept 1, ...). Built with vectorized ops instead of one dict per point.
    n_steps = wealth_mu.shape[0]
    n_concepts = len(concepts)
    wealth_df = pd.DataFrame(
        {
            "time": np.tile(np.arange(n_steps), n_concepts),
            "concept": np.repeat(np.asarray(concepts, dtype=object), n_steps),
            # Transpose so ravel walks time fastest within each concept.
            "wealth": wealth_mu[:, :n_concepts].T.ravel(),
        }
    )

    fig = px.line(wealth_df, x="time", y="wealth", color="concept")
    fig.add_hline(
        y=1 / significance_level,
        line_dash="dash",
        line_color="black",
        annotation_text="Rejection threshold (1 / α)",
        annotation_position="bottom right",
    )
    fig.update_yaxes(range=[0, 1.5 * 1 / significance_level])
    fig.update_layout(margin=dict(l=20, r=20, t=20, b=20))
    st.plotly_chart(fig, use_container_width=True)
def _render_or_wait(results, render):
    """Call ``render(results)`` and add a divider, or show a waiting notice if ``results`` is None."""
    if results is None:
        st.info("Waiting for results", icon="ℹ️")
        return
    render(results)
    st.divider()


def viz_results():
    """Render the "Results" section with rank, testing, and wealth tabs.

    Reads ``st.session_state.results``. When it is None, each tab shows its
    explanatory text plus a "Waiting for results" notice instead of a chart.
    """
    results = st.session_state.results
    st.header("Results")
    rank_tab, test_tab, wealth_tab = st.tabs(
        ["Rank of importance", "Testing results", "Wealth process"]
    )

    with rank_tab:
        st.subheader("Rank of Importance")
        # _viz_rank encodes importance as circle size (not font size).
        st.write(
            """
            This tab visually shows the rank of importance of the specified concepts
            for the prediction of the model on the input image. Larger circles indicate
            higher importance. See the other two tabs for more details.
            """
        )
        _render_or_wait(results, _viz_rank)

    with test_tab:
        st.subheader("Testing Results")
        st.write(
            """
            Importance is measured by performing sequential tests of statistical independence.
            This tab shows the results of these tests and how the rank of importance is computed.
            Concepts are sorted by increasing rejection time, where a shorter rejection time indicates
            higher importance.
            """
        )
        with st.expander("Details"):
            # The rejection rate is a mean over tests, i.e. a fraction.
            st.markdown(
                """
                Results are averaged over multiple random draws of conditioning subsets of
                concepts. The number of tests can be controlled under `Advanced settings`.
                - **Rejection rate**: The fraction of tests that are rejected for a concept.
                - **Rejection time**: The (normalized) average number of steps before the test is
                rejected for a concept.
                - **Significance level**: The level at which the test is rejected for a concept.
                """
            )
        _render_or_wait(results, _viz_test)

    with wealth_tab:
        st.subheader("Wealth Process of Testing Procedures")
        st.markdown(
            """
            Sequential tests instantiate a wealth process for each concept. Once the
            wealth reaches a value of 1/α, the test is rejected with Type I error control at
            level α. This tab shows the average wealth process of the testing procedures for
            each concept.
            """
        )
        _render_or_wait(results, _viz_wealth)