Spaces:
Sleeping
Sleeping
# Dependencies of the Streamlit page defined in this file.
import streamlit as st
import streamlit.components.v1 as components
import shap
from datashap import DataSHAP as ds
import matplotlib.pyplot as plt
import warnings

# Suppress DeprecationWarning messages emitted after this point.
warnings.simplefilter("ignore", category=DeprecationWarning)
def select_fala():
    """Sync session state with the row picked in the ``df_falas`` dataframe.

    Reads the current selection stored under ``st.session_state.df_falas``.
    When a row is selected, its position is written to
    ``st.session_state.num_fala`` and the row's ``tag`` (a Portuguese label)
    is translated back to its uppercase English code and written to
    ``st.session_state.rotulo``.

    NOTE(review): the original indentation was lost; this assumes every
    statement is guarded by the "a row is selected" check - confirm.
    """
    selected_rows = st.session_state.df_falas['selection']['rows']
    if not selected_rows:
        # Nothing selected: leave the current state untouched.
        return

    row = selected_rows[0]
    st.session_state.num_fala = row

    # Reports are stored per company key as a list indexed by quarter - 1.
    company_reports = st.session_state[st.session_state.empresa]
    frame = company_reports[st.session_state.trimestre - 1].df

    label_to_tag = {
        'Neutro': 'NEUTRAL',
        'Positivo': 'POSITIVE',
        'Negativo': 'NEGATIVE',
    }
    st.session_state.rotulo = label_to_tag[frame.iloc[row]['tag']]
def get_dataSHAP(file, company, trim):
    """Create a ``DataSHAP`` object and relabel its tags in Portuguese.

    Args:
        file: Passed through to ``ds.DataSHAP`` (presumably the path of a
            saved SHAP-values file - confirm against the ``datashap`` package).
        company: Passed through to ``ds.DataSHAP``.
        trim: Passed through to ``ds.DataSHAP``.

    Returns:
        The ``DataSHAP`` instance, whose ``df['tag']`` column has the values
        ``POSITIVE`` / ``NEGATIVE`` / ``NEUTRAL`` replaced by
        ``Positivo`` / ``Negativo`` / ``Neutro``.
    """
    tag_labels = {
        'POSITIVE': 'Positivo',
        'NEGATIVE': 'Negativo',
        'NEUTRAL': 'Neutro',
    }
    data = ds.DataSHAP(file, company, trim)
    data.df['tag'] = data.df['tag'].replace(tag_labels)
    return data
def init_session(key, val):
    """Load the four quarterly reports of one company into session state.

    Does nothing when ``key`` is already present in ``st.session_state``.
    Otherwise one ``DataSHAP`` object is built per quarter (1 to 4) and the
    resulting list is stored under ``st.session_state[key]``, so the report
    for quarter ``q`` ends up at index ``q - 1``.

    The list is assembled locally and stored only after every quarter has
    loaded. Storing an empty list first would leave a partial (or empty)
    entry behind if a load raised, and because the key would then exist,
    later reruns would never retry and would fail when indexing the list.

    Args:
        key: Company key; used as the session-state key, in the file name,
            and to look up the company name in ``empresa_dict``.
        val: Folder name of the company inside the data directory.
    """
    if key in st.session_state:
        return

    company = empresa_dict[key]
    reports = []
    for quarter in range(1, 5):
        arquivo = f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{quarter}t24.save'
        reports.append(get_dataSHAP(arquivo, company, quarter))

    st.session_state[key] = reports
st.set_page_config(page_title="TCCPoliUSPPro")

# Company key -> folder name handed to init_session for the data-file path.
pasta = {
    'vale': 'VALE',
    'petr': 'Petrobras',
    'bb': 'BB',
}

# Company key -> company name (shown in the selector, passed to DataSHAP).
empresa_dict = {
    'petr': 'Petrobras',
    'vale': 'Vale',
    'bb': 'Banco do Brasil',
}

# Uppercase tag -> Portuguese label; its key order drives the label radio.
option_map = {
    'NEUTRAL': 'Neutro',
    'POSITIVE': 'Positivo',
    'NEGATIVE': 'Negativo',
}

# Score columns of each report data frame.
title_score = ['positive_score', 'negative_score', 'neutral_score']

# Make sure every company's reports are in session state, then expose them
# under the same keys through a plain dictionary.
for key, val in pasta.items():
    init_session(key, val)
shap_values = {name: st.session_state[name] for name in pasta}
st.header("Sentimento da fala e Scores")

col1, col2, col3, col4 = st.columns(
    [1.7, 1.2, 1.2, 2], gap="small", vertical_alignment="bottom"
)

# Company selector: options are company keys, displayed by company name.
empresa = col1.selectbox(
    "**Qual empresa quer analisar:**",
    ("vale", "bb", "petr"),
    format_func=lambda option: empresa_dict[option],
    key='empresa',
)

# Quarter selector (1-4); reports are stored at index quarter - 1.
trim = col2.number_input("**Trimestre de 2024:**", 1, max_value=4, key='trimestre')

# Highest valid speech index for the report currently selected.
max_fala = len(shap_values[empresa][trim-1].shap_value) - 1

# The speech index persists in session state across reruns. After switching
# to a report with fewer speeches the remembered index may be past the end,
# so pull it back into range before the widget is created.
if st.session_state.get('num_fala', 0) > max_fala:
    st.session_state.num_fala = max_fala

text_num = col3.number_input(
    "**Fala número:**",
    0, max_value=max_fala,
    key='num_fala',
)

df = shap_values[empresa][trim-1].df

total_tokens, h, m, s = shap_values[empresa][trim-1].get_performance()
# Two trailing spaces before the newline are a Markdown hard line break.
# Seconds use fixed-point ".2f": a bare ".2" means two significant digits in
# general format and can switch to exponent notation.
col4.write(
    f"**Total tokens:** {total_tokens}  \n"
    f"**Compute time:** {h}h {m}m {s:.2f}s"
)
tab1, tab2, tab3 = st.tabs(["**Data Frame**", "**Estatística Score**", '**Gráfico Estatística**'])

with tab1:
    # Highlight, row by row, the largest of the three score columns.
    highlighted = df.style.highlight_max(
        axis=1,
        color='lightgreen',
        subset=title_score,
    )
    # Column headers and formats shown in the table.
    column_labels = {
        'speech': st.column_config.Column('Fala', width=100),
        'qty_tokens': st.column_config.NumberColumn("Qtde. Tokens", format='%d'),
        'positive_score': st.column_config.NumberColumn("Score Positivo"),
        'negative_score': st.column_config.NumberColumn("Score Negativo"),
        'neutral_score': st.column_config.NumberColumn("Score Neutro"),
        'tag': "Rótulo",
    }
    # Selecting a row triggers select_fala.
    st.dataframe(
        highlighted,
        selection_mode='single-row',
        key='df_falas',
        on_select=select_fala,
        column_config=column_labels,
        height=200,
    )

with tab2:
    st.dataframe(shap_values[empresa][trim-1].statistic)

with tab3:
    st.plotly_chart(shap_values[empresa][trim-1].plot)
# The three scores of the selected speech, in title_score column order.
score_positive, score_negative, score_neutral = df.loc[text_num, title_score]

# Tag -> score, so each radio caption is looked up by the option it belongs to.
score_by_tag = {
    'NEUTRAL': score_neutral,
    'POSITIVE': score_positive,
    'NEGATIVE': score_negative,
}

# Label selector: options are the uppercase tags, shown with their Portuguese
# labels and captioned with the matching score.
rotulo = st.radio(
    "**Rótulo**",
    option_map.keys(),
    horizontal=True,
    format_func=lambda option: option_map[option],
    captions=[f'{score_by_tag[tag]:.4}' for tag in option_map],
    key='rotulo',
)

# Render the HTML returned by shap_plot_text for the chosen speech and label.
plot_text = shap_values[empresa][trim-1].shap_plot_text(text_num, rotulo)
components.html(plot_text, height=180, scrolling=True)
st.header("Gráfico waterfall dos termos e Valores de Shapley")

with st.expander("Expand"):
    # Token count of the selected speech bounds the slider.
    token_count = int(df['qty_tokens'][text_num])

    # Default is one third of the tokens (rounded down) plus one.
    # NOTE(review): with a one-token speech min and max are both 1 - confirm
    # that st.slider accepts equal bounds.
    max_display = st.slider(
        "**Máximo de exibição:**",
        1, max_value=token_count,
        value=token_count // 3 + 1,
    )

    plot_waterfall = shap_values[empresa][trim-1].shap_waterfall(text_num, rotulo, max_display)
    st.pyplot(plot_waterfall)
st.header('Rank de termos do documento em Gráfico Barra')

with st.expander("Expand"):
    # get_plot_rank returns three values; `ax` is not used below.
    plot_bar, ax, rank = shap_values[empresa][trim-1].get_plot_rank()

    # One sub-section per label, titled in Portuguese, with the figure stored
    # under that label's uppercase tag.
    for key, val in option_map.items():
        figure = plot_bar[key]
        st.subheader(val)
        st.pyplot(figure)

    # NOTE(review): indentation was lost in the source; the rank table is
    # assumed to be shown once, after the per-label figures - confirm.
    st.dataframe(rank)