Spaces:
Build error
Build error
| # streamlit | |
| import streamlit as st | |
| from streamlit_vega_lite import altair_component | |
| import base64 | |
| # data | |
| import pandas as pd | |
| # utils | |
| from numpy import round | |
| from interactive_model_cards import utils as ut | |
def _icon_html(path):
    """Return an ``<img>`` tag that inlines the PNG at ``path`` as a base64 data URI.

    The file is read inside a ``with`` block so the handle is closed as soon as
    the bytes are read, instead of being left for the garbage collector.
    """
    with open(path, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode()
    return f'<img src="data:image/png;base64,{encoded}">'


def perf_interact(type="model perf", min_size=0):
    """Render instructions for interacting with the current visualization.

    Args:
        type: ``"model perf"`` renders the small-subpopulation warning and the
            click hint for the performance bar chart; any other value renders the
            instructions for the subpopulation-comparison view. (The name shadows
            the builtin ``type`` but is kept because callers pass it by keyword.)
        min_size: Size threshold quoted in the warning; only used when
            ``type == "model perf"``.
    """
    if type == "model perf":
        # Icons are read only in the branch that displays them.
        warning_icon = _icon_html("./assets/img/warning-black.png")
        click_icon = _icon_html("./assets/img/click.png")
        st.markdown(
            f"""
            <span>
            {warning_icon} All subpopulations with <strong>fewer than {min_size}</strong> sentences are reporting potentially unreliable results. These are <strong style="color:red">identified with a red border</strong> around the bar.
            </span>
            """,
            unsafe_allow_html=True,
        )
        st.markdown("")  # just to space them out
        st.markdown(
            f"""
            <span>
            {click_icon} Click on the bars to see example sentences.
            </span>
            """,
            unsafe_allow_html=True,
        )
        st.markdown("")  # just to space them out
    else:
        click_icon = _icon_html("./assets/img/click.png")
        st.write("This visualization shows a representation of the data according to how similar two sentences are *relative to the data the model was trained on*.")
        st.markdown(
            f"""
            <span>
            {click_icon} <strong>Here are ways to interact with this view</strong>:
            </span>
            """,
            unsafe_allow_html=True,
        )
        st.write("* You can `zoom in and out` of the visualization")
        st.write("* You can `hover` over a data point to see the sentence and sentiment")
        st.write("* You can `click on the legend` to emphasize subpopulations in the data according to positive or negative sentiment.")
def _slice_frame(sst_db, slice_id):
    """Return ``ut.slice_to_df`` of the slice in ``sst_db.slices`` whose id is ``slice_id``.

    Ids are matched against ``ut.get_sliceid``; ``list.index`` raises ValueError
    when no slice has that id.
    """
    slices = list(sst_db.slices)
    position = ut.get_sliceid(slices).index(slice_id)
    return ut.slice_to_df(slices[position])


def _collect_metrics(sst_db):
    """Build the mapping ``slice id -> {"metrics", "source", "size"}`` for the chart.

    Reads ``st.session_state["quant_ex"]``, a mapping of source name to either
    None (skipped) or a mapping of slice id to that slice's metrics.
    """
    all_metrics = {}
    for source, metrics_by_slice in st.session_state["quant_ex"].items():
        if metrics_by_slice is None:
            continue
        for slice_id, metrics in metrics_by_slice.items():
            entry = {"metrics": metrics, "source": source}
            if source == "Overall Performance":
                entry["size"] = _slice_frame(sst_db, slice_id).shape[0]
                # Custom and protected-class slices are stored among the overall
                # slices, so their real source is recovered from the slice id.
                if "RGDataset" in slice_id:
                    entry["source"] = "Custom Slice"
                elif "protected" in slice_id:
                    entry["source"] = "US Protected Class"
            else:
                entry["size"] = st.session_state["user_data"].shape[0]
            all_metrics[slice_id] = entry
    return all_metrics


def _display_name(slice_id):
    """Return the readable part of a slice id.

    If the id contains ``"->"``, this is the text after the first ``"->"`` up to
    the next ``"@"``, stripped. Otherwise it is the id itself as a string.
    """
    parts = str(slice_id).split("->")
    if len(parts) > 1:
        return parts[1].split("@")[0].strip()
    return parts[0]


def _render_slice_details(sst_db, selected_slice):
    """Render the "Data Details" section for the clicked bar ``{"name", "source"}``."""
    selected = selected_slice["name"]
    source = selected_slice["source"]

    if source in ("Overall Performance", "Custom Slice", "US Protected Class"):
        df = _slice_frame(sst_db, selected)
        display_name = _display_name(selected)
        total = df.shape[0]

        st.warning("**Data Details**")
        if total == 0:
            # st.number_input below needs max_value >= min_value (1); an empty
            # slice would raise, so report its size and stop here.
            st.markdown(f"* The slice `{display_name}` has a total size of `0 sentences`")
            return

        with st.expander("Customize Data Sample"):
            with st.form("Sample Form"):
                st.number_input(
                    "Number of Samples",
                    value=min(total, 10),
                    min_value=1,
                    max_value=total,
                    key="sampleNum",
                )
                st.selectbox(
                    "Sample Type",
                    [
                        "Random Sample",
                        "Highest Probabilities",
                        "Lowest Probabilities",
                        "Mid Probabilities",
                    ],
                    index=0,
                    key="sampleType",
                )
                st.form_submit_button("Generate Sample")

        # Summarize the slice, then the sample size and sampling method.
        st.markdown(f"* The slice `{display_name}` has a total size of `{total} sentences`")
        st.markdown(
            f"* Shown is a subsample of `{st.session_state['sampleNum']}` sentences, sampled by `{st.session_state['sampleType']}`"
        )

        if source == "Custom Slice":
            terms_str = ", ".join(st.session_state["slice_terms"][selected])
            st.markdown(f"* This slice contains sentences with one or more of the following terms: `{terms_str}`")
        elif source == "US Protected Class":
            terms_str = ", ".join(st.session_state["protected_class"][display_name])
            st.markdown(f"* Sentences pertaining to this US Protected Class contain the following terms: `{terms_str}`")
            with open("./assets/img/warning-black.png", "rb") as fh:
                warning_icon = base64.b64encode(fh.read()).decode()
            st.markdown(
                f"""
                <span>
                <img src="data:image/png;base64,{warning_icon}"> Detecting US Protected classes by keyword search is not perfect. Some sentences below may not be pertinent to a protected class; for example, the word 'black' can refer to individuals but not always.
                </span>
                """,
                unsafe_allow_html=True,
            )

        st.table(
            ut.subsample_df(
                df,
                st.session_state["sampleNum"],
                st.session_state["sampleType"],
            )
        )
    elif source == "User Custom Sentence":
        st.warning("**Data Details**")
        df = st.session_state["user_data"]
        st.markdown(f"These are your `{df.shape[0]}` custom sentences")
        st.write(df)


def quant_panel(sst_db, embedding, col, data_view):
    """Lay out the quantitative panel inside the Streamlit container ``col``.

    Args:
        sst_db: Dataset whose ``slices`` are looked up by the ids produced by
            ``ut.get_sliceid`` (presumably a Robustness Gym dataset — confirm).
        embedding: Data passed through ``ut.down_samp`` and ``ut.data_comparison``
            to draw the subpopulation-comparison chart.
        col: Streamlit container the panel is rendered into.
        data_view: ``"Model Performance Metrics"`` renders the metrics bar chart
            and the drill-down for a clicked bar. Any other value renders the
            subpopulation-comparison view.

    Reads the ``st.session_state`` keys ``quant_ex``, ``user_data``,
    ``selected_slice``, ``sampleNum``, ``sampleType``, ``slice_terms`` and
    ``protected_class``. Writes ``selected_slice`` when a bar is clicked.
    """
    with col:
        if data_view == "Model Performance Metrics":
            st.warning("**Model Performance Metrics**")
            st.markdown("* Evaluation metrics include [accuracy](https://simple.wikipedia.org/wiki/Accuracy_and_precision), [precision](https://en.wikipedia.org/wiki/Precision_and_recall), and [recall](https://en.wikipedia.org/wiki/Precision_and_recall).")
            st.markdown(" * Performance is shown for the training and testing set, as well as special groups within this dataset that have been automatically associated with US protected groups")
            min_size = st.number_input(
                "Flag (with a red border) subpopulations with fewer than the following number of sentences:",
                value=100,
                min_value=30,
                max_value=10000,
            )
            perf_interact(type="model perf", min_size=min_size)

            chart = ut.visualize_metrics(
                _collect_metrics(sst_db),
                max_width=100,
                linked_vis=True,
                min_size=min_size,
            )
            event_dict = altair_component(altair_chart=chart)

            # A click on a bar reports the selection's fields as lists.
            if event_dict and "name" in event_dict:
                st.session_state["selected_slice"] = {
                    "name": event_dict["name"][0],
                    "source": event_dict["source"][0],
                }

            # .get: the key may not have been initialised before the first click.
            selected_slice = st.session_state.get("selected_slice")
            if selected_slice is not None:
                _render_slice_details(sst_db, selected_slice)
        else:
            st.warning("**Subpopulation Comparison**")
            perf_interact(type="comparison")
            with st.expander("how to read this chart:"):
                st.markdown("* each **point** is a single sentence")
                st.markdown("* the **position** of each dot is determined mathematically based upon an analysis of the words in a sentence. The **closer** two points on the visualization the **more similar** the sentences are. The **further apart** two points on the visualization the **more different** the sentences are")
                st.markdown(" * the **shape** of each point reflects whether it is a positive (diamond) or negative sentiment (circle)")
                # TODO(review): this sentence is unfinished in the original; say what colour encodes.
                st.markdown("* the **color** of each point is the ")
            # Down-sample the embedding because of Altair's row limits.
            st.altair_chart(ut.data_comparison(ut.down_samp(embedding)))
# Only the panel entry point is exported by `from ... import *`.
__all__ = [
    "quant_panel",
]