# retrodict / app.py — Streamlit retrodiction demo
# (Hugging Face Space page artifacts removed: author caption "junsol's picture",
#  commit message "Update app.py", commit hash 23decd3)
import re
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from pygam import GAM
from rapidfuzz import fuzz, utils
@st.cache_data
def convert_df(df):
    '''Serialize *df* to UTF-8 CSV bytes (no index column); cached by Streamlit.'''
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
def highlight(s, query):
    '''Return *s* with every case-insensitive occurrence of *query* wrapped
    in a yellow-highlight <span>.

    BUG FIX: the query is now escaped with re.escape, so user input
    containing regex metacharacters (e.g. "(", "+", "?") can no longer
    crash re.sub or match unintended text.
    '''
    highlighted_query = '<span style="background-color: #FFFF00">{}</span>'
    # m.group(0) keeps the matched text's original casing in the output
    result = re.sub(re.escape(query),
                    lambda m: highlighted_query.format(m.group(0)),
                    s, flags=re.IGNORECASE)
    return result
def score_similarity(df, column, query):
    '''Filter *df* to rows whose *column* contains *query* (case-insensitive),
    rank them by fuzzy similarity, and HTML-highlight the query.

    Returns a copy of the matching rows sorted by rapidfuzz WRatio
    (most similar first) with an added 'wratio' score column.

    BUG FIX: regex=False treats the user's query as a literal string —
    previously a query with regex metacharacters could raise re.error.
    '''
    # .copy() so the column assignments below don't hit a slice view
    # (SettingWithCopyWarning) of the caller's frame
    df = df[df[column].str.lower().str.contains(query, regex=False)].copy()
    df['wratio'] = df[column].apply(lambda x: fuzz.WRatio(query, x))
    df = df.sort_values('wratio', ascending=False)
    df[column] = df[column].apply(lambda x: highlight(str(x), query))
    return df
def search_data(query, search_criteria, search_filter):
    '''Search *query* in the GSS variable metadata and return matching rows.

    search_criteria: which column to search — one of 'All',
    'Variable names', 'Variable descriptions', 'Survey questions',
    'GSS tags' (must match the Criteria selectbox labels in main()).
    search_filter: list of GSS-tag options to restrict the search to.
    Uses the module-level `data`, `count_meta` and `subject` frames.
    '''
    if len(search_filter) == 0:
        data_sub = data
    else:
        # map tag options -> subjects -> variable names, then subset the data
        selected_subjects = count_meta.loc[count_meta.option.isin(search_filter), 'subject']
        selected_vars = subject.loc[subject.subject.isin(selected_subjects), 'var_name']
        data_sub = data.loc[data.var_name.isin(selected_vars)]
    if len(query) == 0:
        return data_sub
    query = query.lower()
    r = [score_similarity(data_sub, col, query) for col in ['var_name', 'var_description', 'question', 'subject']]
    # BUG FIX: keys must match the selectbox labels exactly — they were
    # title-cased ('Variable Names', ...), so every criteria except 'All'
    # returned None and crashed the caller's .reset_index(drop=True).
    results = {
        'All': pd.concat(r, axis=0).drop_duplicates('original_var_name').reset_index(drop=True),
        'Variable names': r[0],
        'Variable descriptions': r[1],
        'Survey questions': r[2],
        'GSS tags': r[3]
    }
    # default to 'All' so an unexpected criteria value can never return None
    return results.get(search_criteria, results['All'])
def get_table_for_figure(var):
    '''Create the table that will be used to draw the chart.

    Merges the model predictions (pred_type 'rescale') with the observed
    survey estimates (pred_type 'obsbin') for *var* from the module-level
    `dt_summary1`, flags each year's prediction, and rescales proportions
    to percentages. Note the statements are order-dependent: lci/uci are
    overwritten, then overlap is coded, then values are scaled.
    '''
    # split this variable's rows into predictions vs. observed estimates
    tab_pred = dt_summary1.loc[(dt_summary1.variable == var) & (dt_summary1.pred_type == 'rescale')]
    tab_obs = dt_summary1.loc[(dt_summary1.variable == var) & (dt_summary1.pred_type == 'obsbin')]
    # right-merge keeps one observation row per prediction year (NaN if no
    # observation that year); rename to obs_* to avoid column clashes
    tab_obs = tab_obs.merge(tab_pred[['year']], how='right').rename({'mean': 'obs_mean', 'lci': 'obs_lci', 'uci': 'obs_uci'}, axis=1)
    tab_pred = pd.merge(tab_pred, tab_obs[['year','obs_mean','obs_lci','obs_uci']], on = 'year')
    # prediction interval = mean +/- 3% margin of error (matches the chart
    # note rendered in the app)
    tab_pred['lci'] = tab_pred['mean'] - 0.03
    tab_pred['uci'] = tab_pred['mean'] + 0.03
    # overlap codes: 0 = correct prediction (interval covers the observed
    # value), 1 = incorrect prediction, 2 = novel prediction (no observation)
    tab_pred.loc[(tab_pred.lci <= tab_pred.obs_mean) & (tab_pred.obs_mean <= tab_pred.uci) , 'overlap'] = 0
    tab_pred.loc[pd.notnull(tab_pred.obs_mean) & ~((tab_pred.lci <= tab_pred.obs_mean) & (tab_pred.obs_mean <= tab_pred.uci)) , 'overlap'] = 1
    tab_pred.loc[pd.isnull(tab_pred.obs_mean), 'overlap'] = 2
    # proportions -> percentages for plotting
    tab_pred[['mean', 'lci', 'uci', 'obs_mean', 'obs_lci', 'obs_uci']] *= 100
    # a variable observed in exactly one year has no retrodictions: blank
    # out the prediction for that year so only the observed point is drawn
    if pd.notnull(tab_pred['obs_mean']).sum() == 1:
        tab_pred.loc[pd.notnull(tab_pred.obs_mean), ['mean', 'lci', 'uci']] = None
    return tab_pred
def get_figure(var, var_desc):
    '''Draw the retrodiction chart for *var* (described by *var_desc*).

    Plots predicted vs. observed positive-response percentages by year with
    a GAM trend line, saves the chart to '<var>.png', and returns
    (matplotlib figure, plotted table, png filename).
    '''
    df = get_table_for_figure(var)
    range_y = (min(df['lci'].min(), df.obs_lci.min()) - 15,
               max(df['uci'].max(), df.obs_uci.max()) + 15)
    # GAM input: prediction where available, otherwise the observed estimate
    df['mean_gam'] = df['mean']
    df.loc[pd.isnull(df['mean']), 'mean_gam'] = df.loc[pd.isnull(df['mean']), 'obs_mean']
    gam = GAM().gridsearch(np.array(df[['year']]), np.array(df[['mean_gam']])) # GAM to draw the trend line
    XX = gam.generate_X_grid(term=0)
    # BUG FIX: the original called plt.figure(figsize=(10, 6)) and then
    # plt.subplots(), which leaked an orphan figure on every Streamlit rerun
    # AND silently discarded the intended size; pass figsize to subplots.
    fig, ax = plt.subplots(figsize=(10, 6))
    # per-overlap-code plot styles (codes defined in get_table_for_figure)
    marker_dict = {0: 'o', 1: 'o', 2: (8, 2, 0)}
    color_dict = {0: 'black', 1: 'black', 2: 'b'}
    fill_color_dict = {0: 'black', 1: 'none', 2: 'b'}
    label_dict = {0: 'correct prediction', 1: 'incorrect prediction', 2: 'novel prediction'}
    size_dict = {0: 5.5, 1: 5.8, 2: 8}
    plt.plot(XX[:, 0].flatten(), gam.predict(XX), 'black') # GAM plot
    # compute the 95% confidence band once instead of twice
    ci = gam.confidence_intervals(XX, width=.95)
    plt.fill_between(XX[:, 0].flatten(), ci[:, 0], ci[:, 1], color='lightgray')
    for i in range(3): # Plot predictions, one series per overlap code
        plt.plot(df.loc[df.overlap == i, 'year'], df.loc[df.overlap == i, 'mean'], label=label_dict[i], marker=marker_dict[i],
                 c=color_dict[i], markerfacecolor=fill_color_dict[i], ms=size_dict[i], mew=1, linestyle = 'None')
        plt.errorbar(df.loc[df.overlap == i, 'year'], df.loc[df.overlap == i, 'mean'],
                     yerr=[df.loc[df.overlap == i, 'mean']-df.loc[df.overlap == i, 'lci'], df.loc[df.overlap == i, 'uci']-df.loc[df.overlap == i, 'mean']],
                     fmt="none", color=color_dict[i], ecolor=color_dict[i], elinewidth=1, capsize=0)
    plt.plot(df['year'], df['obs_mean'], marker="+", c='r', ms=7, mew=1.1, linestyle = 'None', label='observed') # Plot observations
    plt.errorbar(df['year'], df['obs_mean'], yerr=[df['obs_mean']-df['obs_lci'], df['obs_uci']-df['obs_mean']],
                 fmt='none', color='r', ecolor='r', elinewidth=1, capsize=0)
    # Plot styles
    plt.xlabel('Year')
    plt.ylabel('Positive Response (%)')
    plt.legend(ncol=4, loc="upper center", bbox_to_anchor=(0.5, 1.09),
               frameon=False, shadow=False, prop = { "size": 10 }, columnspacing=0.1, handletextpad=0.01)
    plt.title(r"$\bf{" + f'({var})' + "}$" + f' {var_desc}', loc='left', fontdict={'fontsize': 10}, wrap=True, y=1.06)
    plt.grid(which='major', color='lightgray', linestyle='-', alpha=0.2)
    plt.grid(which='minor', color='lightgray', linestyle='-', alpha=0.4)
    plt.minorticks_on()
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    ax.xaxis.set_minor_locator(MultipleLocator(5))
    plt.ylim(range_y)
    plt.xlim((1970, 2023))
    fn = f'{var}.png'
    plt.savefig(fn)
    return fig, df, fn
def assign_current_var(x):
    '''Button callback: remember *x* as the currently selected variable
    (None returns the app to the search view).'''
    st.session_state['current_var'] = x
def view_help():
    '''Sidebar Help callback: clear the selected variable and search query
    and reset paging so the help screen is shown on the next rerun.'''
    state = st.session_state
    state['current_var'] = None
    state['search_query'] = None
    state['current_page'] = 1
def main():
    '''Main function to run the Streamlit app.

    Renders one of three views based on session state:
    - the help screen (first visit, or after clicking Help),
    - paginated search results rendered as cards,
    - the detail page (chart / accuracy / binarization / data tabs) for a
      variable selected via its card button.
    Relies on the module-level frames loaded in the __main__ guard.
    '''
    st.set_page_config(page_title="AI-Augmented Surveys: Retrodiction Demo", layout="wide")
    # CSS to inject contained in a string
    hide_table_row_index = """
            <style>
            thead tr th:first-child {display:none}
            tbody th {display:none}
            </style>
            """
    # Inject CSS with Markdown to remove index from dataframes
    st.markdown(hide_table_row_index, unsafe_allow_html=True)
    with st.sidebar:
        st.title("AI-Augmented Surveys")
        st.markdown("*Retrodiction Demo*", unsafe_allow_html=True)
        help_click = st.button("Help", on_click=view_help)
    n_var_per_page = 25 # parameters for search results
    n_cards_per_row = 4
    if st.session_state['current_var'] is None: # When there is no variable that the user clicked in the search results
        col1, col2, col3 = st.columns([3, 1, 1])
        with col1:
            search_query = st.text_input('Search GSS variables')
        with col2:
            # NOTE: these labels are the keys search_data() dispatches on
            search_criteria = st.selectbox("Criteria", ['All', 'Variable names', 'Variable descriptions', 'Survey questions', 'GSS tags'])
        with col3:
            search_filter = st.multiselect("Filter by GSS tags", count_meta['option'])
        search_button = st.button("Search")
        if search_query or search_button: # Perform search when the search button is clicked
            st.session_state['search_query'] = search_query
            st.session_state['current_page'] = 1
        if help_click or st.session_state['search_query'] is None:
            # help view: four columns of usage notes
            help1, help2, help3, help4 = st.columns([1, 1, 1, 1])
            with help1:
                st.write("---")
                st.subheader("📈")
                st.markdown("This demo site presents public opinion trends as predicted by our General Social Survey (GSS)-based AI language models, specifically Alpaca-7b fine-tuned for retrodiction tasks. Users can search for variables of interest using keywords (e.g., “gay”) or by filtering using GSS tags. After selecting a variable, click on “View this variable” to explore further.")
            with help2:
                st.write("---")
                st.subheader("🗃")
                st.markdown("Take advantage of the different tabs available! Here, you’ll find information about the predictive accuracy of our models, and how we binarize response options into 1 or 0. You can download data in CSV or image format as well.")
            with help3:
                st.write("---")
                st.subheader("🧐")
                st.markdown("Please note that not all GSS variables are included in our system. We only feature those with binarized options, excluding those with numerous categories and many others (see our paper for more details). We’d be glad to investigate why those variables aren’t visible.")
            with help4:
                st.write("---")
                st.subheader("⚙️")
                st.markdown("Currently, we’re refining our model to incorporate the 2022 GSS survey data and additional GSS variables. Stay tuned for updates!")
        else:
            # run the search against the stored query
            results = search_data(st.session_state['search_query'], search_criteria, search_filter).reset_index(drop=True)
            if results.empty:
                st.warning('No variables found.')
            else:
                page = st.session_state['current_page'] # print search results
                st.success(f'Found {len(results)} variable(s).')
                col21, col22, col23, col24 = st.columns([4, 2, 2, 2])
                with col23:
                    n_var_per_page = st.selectbox('Variables per page', [25, 50, 100, 200])
                    maxpage = int(np.ceil(len(results) / n_var_per_page))
                with col24:
                    st.session_state['current_page'] = st.selectbox('Move to another page', list(range(1, maxpage+1)))
                    page = st.session_state['current_page']
                with col22:
                    sortby = st.selectbox('Sort by', ['Relevance', 'AUC'])
                with col21:
                    st.markdown(f'Page {page} of {maxpage}')
                button_list = []
                if sortby == 'AUC':
                    # drop rows with no variable name, then rank by AUC
                    results = results.loc[pd.notnull(results['var_name'])]
                    results = results.merge(performance_partial[['var_name', 'auc']], on='var_name', how='left').sort_values('auc', ascending=False).reset_index(drop=True)
                else:
                    # Relevance order (WRatio) is already set by search_data
                    results = results.loc[pd.notnull(results['var_name'])]
                # paginate: .loc slicing here is label-based and inclusive on
                # both ends, hence the -1 on the upper bound
                for n_row, row in results.loc[(page-1)*n_var_per_page:page*n_var_per_page-1].reset_index().iterrows():
                    i = n_row % n_cards_per_row
                    if i == 0:
                        # start a new row of cards
                        cols = st.columns(n_cards_per_row, gap="large")
                    # draw the card
                    with cols[n_row % n_cards_per_row]:
                        st.write("---")
                        st.markdown(f"**{row['var_name'].strip()}**", unsafe_allow_html=True)
                        st.markdown(f"**Variable description:** *{row['var_description'].strip()}*", unsafe_allow_html=True)
                        st.markdown(f"**Survey question:** {row['question']}", unsafe_allow_html=True)
                        st.markdown(f"**GSS tags:** {row['subject']}", unsafe_allow_html=True)
                        button_list.append(st.button('View this variable', key=row['var_name'],
                                                     on_click=assign_current_var,
                                                     args=(row['original_var_name'], )))
                # NOTE(review): current_variable is assigned but never used —
                # selection is handled by the assign_current_var callback
                # above, so this loop looks like dead code; confirm before
                # removing.
                for i in range(len(button_list)):
                    if button_list[i]:
                        current_variable = str(results.loc[i, 'var_name'])
    else: # When there is a variable that the user clicked in the search results
        st.button('Back', on_click=assign_current_var, args=(None, )) # If user click back, return to the search results
        col31, col32 = st.columns([2, 3])
        with col31:
            # left column: variable metadata
            st.subheader(f"**{st.session_state['current_var']}**")
            var_desc = data.loc[data['var_name'] == st.session_state['current_var'], 'var_description'].tolist()[0]
            st.markdown(f"**Variable description:** *{var_desc}*", unsafe_allow_html=True)
            st.markdown(f"**Survey question:** {data.loc[data['var_name'] == st.session_state['current_var'], 'question'].tolist()[0]}", unsafe_allow_html=True)
            st.markdown(f"**GSS tags:** {data.loc[data['var_name'] == st.session_state['current_var'], 'subject'].tolist()[0]}", unsafe_allow_html=True)
        with col32:
            # right column: chart, accuracy, binarization and data tabs
            tab1, tab2, tab3, tab4 = st.tabs(["📈 Retrodiction Chart", "🏆 Accuracy", "✏️ Binarization", "🗃 Retrodiction Data"])
            var = st.session_state['current_var']
            fig, df, fn = get_figure(var, var_desc)
            tab1.pyplot(fig)
            tab1.markdown("*Note:* The generalized additive model has been used to estimate the trend. We define the correct prediction when the prediction interval within 3% margin of error includes the observed estimate. ")
            # get_figure saved the chart to fn; offer it for download
            with open(fn, "rb") as img:
                tab1.download_button(
                    label="Download image",
                    data=img,
                    file_name=fn,
                    mime="image/png"
                )
            tab4.table(df[['year', 'mean', 'obs_mean']])
            tab4.download_button(
                label="Download data as CSV",
                data=convert_df(df[['year', 'mean', 'obs_mean']]),
                file_name=f'{var}.csv',
                mime='text/csv',
            )
            current_var_perform = performance_partial.loc[performance_partial.var_name == var, ['auc', 'accuracy', 'f1']].reset_index(drop=True)
            tab2.table(current_var_perform)
            if len(current_var_perform) == 0:
                # variables measured once were excluded from performance_partial
                tab2.markdown('*Note:* If the variable has only been measured for a single year, there are no accuracy metrics available for retrodiction tasks. For information regarding the simulated AUC of a variable measured once, please refer to the paper.')
            tab3.table(binary.loc[binary.variable == var, ['binarized', 'response']].sort_values(['binarized', 'response']).reset_index(drop=True))
if __name__ == '__main__':
    matplotlib.rcParams.update({'font.size': 14})
    # one-time initialization of Streamlit session state
    if 'current_var' not in st.session_state:
        st.session_state['current_var'] = None
    if 'search_query' not in st.session_state:
        st.session_state['search_query'] = None
    if 'current_page' not in st.session_state:
        st.session_state['current_page'] = 1
    if 'show_help' not in st.session_state:
        st.session_state['show_help'] = False
    # load data — these become the module-level frames the functions above use
    data = pd.read_parquet('var_meta.parquet')  # variable metadata (name, description, question, subject)
    subject = pd.read_parquet('subject_meta.parquet')  # variable -> GSS subject mapping
    count_meta = pd.read_parquet('count_meta.parquet')  # tag-filter options (rebuilt below)
    dt_summary1 = pd.read_parquet('dt_summary1_alpaca.parquet')  # predictions + observed estimates
    performance_partial = pd.read_parquet('performance_partial.parquet')  # per-variable accuracy metrics
    binary = pd.read_parquet('binary.parquet')  # response-option binarization table
    binary['binarized'] = binary['binarized'].astype(int)
    df_prop = pd.read_parquet('df_prop.parquet')
    # exclude variables if retrodiction cases are not available:
    # a variable observed in only one year has no accuracy metrics
    to_be_excluded = dt_summary1.loc[(dt_summary1.pred_type == 'obsbin')].groupby('variable')['pred_type'].count().reset_index()
    to_be_excluded = to_be_excluded.loc[to_be_excluded.pred_type == 1, 'variable']
    performance_partial = performance_partial.loc[~performance_partial.var_name.isin(to_be_excluded)]
    # drop variables flagged via df_prop (prop >= 0.5 with sb False —
    # presumably a data-quality screen; semantics defined upstream, TODO
    # confirm) plus a hard-coded list of party-id variables
    to_be_excluded = df_prop.loc[(df_prop.prop >= 0.5) & (df_prop.sb == False), 'variable'].tolist()
    to_be_excluded += ['partyid1', 'partyid2', 'partyid3']
    data = data.loc[~data.var_name.isin(to_be_excluded)]
    # rebuild the tag-filter options as 'subject (count)' labels over the
    # variables that survived the exclusions
    dt_count = subject.loc[subject.var_name.isin(data.var_name)].subject.value_counts().reset_index()
    dt_count['option'] = dt_count.subject + ' (' + dt_count['count'].astype(str) + ')'
    count_meta = dt_count.sort_values('subject')
    main()