Spaces:
Build error
Build error
| # --- Visualization --- | |
| import altair as alt | |
| import streamlit as st | |
| import plotly.graph_objects as go | |
| from streamlit_vega_lite import altair_component | |
| # --- Data --- | |
| import pandas as pd | |
def base_chart(df, linked_vis=False, max_width=150, col_val=None, min_size=100, size_domain=None):
    """Visualize the model's performance across subsets of the data.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format metric table. The linked chart encodes the columns
        ``metric_value``, ``displayName``, ``metric_type``, ``size``,
        ``source`` and ``name``; the legacy chart encodes ``metric_value``,
        ``metric_type`` and ``name``.
    linked_vis : bool
        True builds the click-selectable chart faceted by ``metric_type``.
        False builds the legacy single bar chart (marked deprecated).
    max_width : int
        Width of the legacy chart only; the linked chart uses a fixed width of 125.
    col_val : str or None
        Fixed bar colour of the legacy chart only.
    min_size : int
        Unused; kept for interface compatibility.
    size_domain : list or None
        Domain shared by the stroke width / opacity / colour scales on the
        ``size`` field. ``None`` (the default) means an empty domain list.

    Returns
    -------
    An Altair chart (configured with axis/legend font sizes in the linked case).
    """
    # None instead of a shared mutable [] default; a fresh list per call.
    if size_domain is None:
        size_domain = []

    # Fixed ordering so each data subpopulation always gets the same fill colour.
    pop_domain = ["Overall Performance", "Custom Slice", "User Custom Sentence", "US Protected Class"]
    color_range = ["#5778a4", "#e49444", "#b8b0ac", "#85b6b2"]

    base = alt.Chart(df)
    if linked_vis:
        # Clicking a bar selects it by (name, source); nothing is selected initially.
        selected = alt.selection_single(on="click", empty="none", fields=["name", "source"])
        base = base.add_selection(selected)
        base = (
            base.mark_bar()
            .encode(
                alt.X("metric_value", scale=alt.Scale(domain=(0, 1)), title=""),
                alt.Y("displayName", title=""),
                alt.Column("metric_type", title=""),
                # The slice-size bucket drives the bar outline (width, opacity, colour).
                alt.StrokeWidth(
                    "size:N",
                    scale=alt.Scale(domain=size_domain, range=[0, 1.25]),
                    title="#sentences",
                ),
                alt.StrokeOpacity("size:N", scale=alt.Scale(domain=size_domain, range=[0, 1])),
                alt.Stroke("size:N", scale=alt.Scale(domain=size_domain, range=["white", "red"])),
                alt.Fill(
                    "source",
                    scale=alt.Scale(domain=pop_domain, range=color_range),
                    title="Data Subpopulation",
                ),
                # Selected bar is fully opaque, everything else is dimmed.
                opacity=alt.condition(selected, alt.value(1), alt.value(0.5)),
                tooltip=["name", "metric_type", "metric_value"],
            )
            .properties(width=125)
            .configure_axis(labelFontSize=14)
            .configure_legend(labelFontSize=14)
        )
    else:
        # Legacy, unlinked layout; marked deprecated and not expected to be used.
        base = (
            base.mark_bar()
            .encode(
                alt.X("metric_value", scale=alt.Scale(domain=(0, 1)), title=""),
                alt.Y(
                    "metric_type",
                    title="",
                    sort=["Overall Performance", "Your Sentences"],
                ),
                color=alt.value(col_val),
                tooltip=["name", "metric_type", "metric_value"],
            )
            .properties(width=max_width)
        )
    return base
def visualize_metrics(metrics, max_width=150, linked_vis=False, col_val="#1f77b4", min_size=1000):
    """
    Visualize the metrics of the model.

    Parameters
    ----------
    metrics : dict
        Maps a slice name to a dict with keys ``"metrics"`` (mapping of
        metric type -> value), ``"size"`` (number of sentences in the slice)
        and ``"source"`` (data subpopulation label).
    max_width : int
        Forwarded to ``base_chart`` (affects only its legacy, unlinked chart).
    linked_vis : bool
        Forwarded to ``base_chart``.
    col_val : str
        Forwarded to ``base_chart`` (bar colour of its legacy chart).
    min_size : int
        Slices with ``size >= min_size`` are labelled ``">={min_size} sentences"``,
        smaller ones ``"<{min_size} sentences"``.

    Returns
    -------
    The chart built by ``base_chart``.
    """
    # Build one small frame per slice and concatenate once at the end:
    # DataFrame.append was removed in pandas 2.0 and was quadratic anyway.
    frames = []
    for key, entry in metrics.items():
        slice_metrics = entry["metrics"]
        metric_types = list(slice_metrics.keys())
        metric_values = [slice_metrics[mt] for mt in metric_types]
        n_rows = len(metric_types)
        size_label = (
            f">={min_size} sentences" if entry["size"] >= min_size else f"<{min_size} sentences"
        )
        frames.append(
            pd.DataFrame(
                {
                    "name": [key] * n_rows,
                    "metric_type": metric_types,
                    "metric_value": metric_values,
                    "source": [entry["source"]] * n_rows,
                    "size": [size_label] * n_rows,
                }
            )
        )

    if frames:
        # Per-frame indices are kept (not reset), matching the former append behaviour.
        metric_df = pd.concat(frames)
    else:
        # No slices: an empty table with the expected columns instead of a KeyError below.
        metric_df = pd.DataFrame(columns=["name", "metric_type", "metric_value", "source", "size"])

    # Human-friendly display name (not RG's backend name): take the part after
    # "->" when present, then drop everything from the first "@".
    name_parts = [name.split("->") for name in metric_df["name"]]
    metric_df["displayName"] = [
        (parts[0] if len(parts) <= 1 else parts[1]).split("@")[0] for parts in name_parts
    ]

    # Same two labels (and order) used for the "size" column above.
    size_domain = [f">={min_size} sentences", f"<{min_size} sentences"]
    return base_chart(
        metric_df,
        linked_vis,
        max_width=max_width,
        col_val=col_val,
        size_domain=size_domain,
    )
| #@st.cache(allow_output_mutation=True) | |
def data_comparison(df):
    """Plot sentences as a 2-D scatter next to a clickable legend.

    Parameters
    ----------
    df : pandas.DataFrame
        Uses the columns ``x``, ``y`` (point position), ``source`` (colour),
        ``sentiment`` (marker shape), ``name`` (legend rows) and ``sentence``
        (tooltip). Presumably ``sentiment`` has two values, given the two
        marker shapes — TODO confirm.

    Returns
    -------
    A horizontally concatenated Altair chart: the scatter plot on the left and
    the legend (which owns the selection) on the right, with grid lines and
    view border hidden.
    """
    # Multi-select on (name, sentiment) pairs, driven by clicks in the legend panel.
    picked = alt.selection_multi(fields=['name', 'sentiment'])

    # Picked points keep their source colour and stay fairly opaque;
    # everything else turns light gray and fades.
    highlight_color = alt.condition(
        picked,
        alt.Color('source:N', legend=None),
        alt.value('lightgray'),
    )
    highlight_opacity = alt.condition(picked, alt.value(0.7), alt.value(0.25))

    points = (
        alt.Chart(df)
        .mark_point(size=100, filled=True)
        .encode(
            x=alt.X('x', axis=None),
            y=alt.Y('y', axis=None),
            color=highlight_color,
            shape=alt.Shape('sentiment', scale=alt.Scale(range=['circle', 'diamond'])),
            tooltip=['source', 'name', 'sentence', 'sentiment'],
            opacity=highlight_opacity,
        )
        .properties(width=600, height=700)
        .interactive()
    )

    legend_panel = (
        alt.Chart(df)
        .mark_point()
        .encode(
            y=alt.Y('name:N', axis=alt.Axis(orient='right'), title=""),
            x=alt.X("sentiment"),
            shape=alt.Shape('sentiment', scale=alt.Scale(range=['circle', 'diamond']), legend=None),
            color=highlight_color,
        )
        .add_selection(picked)
    )

    side_by_side = points | legend_panel
    return side_by_side.configure_axis(grid=False).configure_view(strokeOpacity=0)
def vis_table(df, userInput=False):
    """DEPRECATED: Render sentence predictions as a Plotly table figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``"sentence"``, ``"model label"`` and
        ``"probability"``; any other columns are ignored.
    userInput : bool
        Unused; kept for interface compatibility.

    Returns
    -------
    plotly.graph_objects.Figure
        A single ``go.Table`` trace.
    """
    # Header and cells come from the same list so every label sits above its
    # own data, even if df carries extra columns or a different column order.
    columns = ["sentence", "model label", "probability"]
    fig = go.Figure(
        data=[
            go.Table(
                header=dict(values=columns, fill_color="paleturquoise", align="left"),
                # Wide first column for the sentence text, narrow ones for label/probability.
                columnwidth=[400, 50, 50],
                cells=dict(
                    values=[df[col] for col in columns],
                    fill_color="lavender",
                    align="left",
                ),
            )
        ]
    )
    return fig