Spaces:
Sleeping
Sleeping
Commit ·
acf89de
1
Parent(s): ae315c9
Using oop
Browse files- src/datashap/DataSHAP.py +148 -0
- src/datashap/__init__.py +0 -0
- src/streamlit_app.py +28 -137
src/datashap/DataSHAP.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
from datetime import timedelta
|
| 3 |
+
import re
|
| 4 |
+
import plotly.express as px
|
| 5 |
+
from plotly.subplots import make_subplots
|
| 6 |
+
import numpy as np
|
| 7 |
+
from pandas import DataFrame
|
| 8 |
+
import shap
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from huggingface_hub import HfFileSystem
|
| 11 |
+
import warnings
|
| 12 |
+
warnings.simplefilter("ignore", category=DeprecationWarning)
|
| 13 |
+
|
| 14 |
+
class DataSHAP:
|
| 15 |
+
|
| 16 |
+
def __init__(self, file_addr, company, trim):
|
| 17 |
+
self.file_addr = file_addr
|
| 18 |
+
self.shap_value = self.load_file(self.file_addr)
|
| 19 |
+
self.df = self.mount_df(self.shap_value)
|
| 20 |
+
self.total_tokens, self.h, self.m, self.s = self.calc_performance(self.shap_value.compute_time,
|
| 21 |
+
self.df['qty_tokens'])
|
| 22 |
+
self.statistic = self.get_statistic(self.df)
|
| 23 |
+
self.trim = trim
|
| 24 |
+
self.company = company
|
| 25 |
+
self.plot = self.three_plot(self.df)
|
| 26 |
+
self.plot_bar, self.axis, self.rank = self.plot_rank()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_file(self, file_addr):
|
| 31 |
+
shap_value = 0
|
| 32 |
+
fs = HfFileSystem()
|
| 33 |
+
with fs.open(file_addr, 'rb') as inp:
|
| 34 |
+
shap_value = pickle.load(inp)
|
| 35 |
+
inp.close()
|
| 36 |
+
return shap_value
|
| 37 |
+
|
| 38 |
+
def join_text(self, list):
|
| 39 |
+
return np.array([''.join(x) for x in list])
|
| 40 |
+
|
| 41 |
+
def count_token(self, list):
|
| 42 |
+
return np.array([len(x) for x in list])
|
| 43 |
+
|
| 44 |
+
def calc_allscore(self, shap_value):
|
| 45 |
+
base_values = DataFrame(shap_value.base_values, columns=['POSITIVE', 'NEGATIVE', 'NEUTRAL'])
|
| 46 |
+
df = DataFrame()
|
| 47 |
+
df['NEUTRAL'] = DataFrame(shap_value[:,:,'NEUTRAL'].values,
|
| 48 |
+
columns=['NEUTRAL'])['NEUTRAL'].apply(lambda x: x.sum()) + base_values['NEUTRAL']
|
| 49 |
+
df['POSITIVE'] = DataFrame(shap_value[:,:,'POSITIVE'].values,
|
| 50 |
+
columns=['POSITIVE'])['POSITIVE'].apply(lambda x: x.sum()) + base_values['POSITIVE']
|
| 51 |
+
df['NEGATIVE'] = DataFrame(shap_value[:,:,'NEGATIVE'].values,
|
| 52 |
+
columns=['NEGATIVE'])['NEGATIVE'].apply(lambda x: x.sum()) + base_values['NEGATIVE']
|
| 53 |
+
return df
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def mount_df(self, shap_value):
|
| 57 |
+
scores = self.calc_allscore(shap_value)
|
| 58 |
+
text = self.join_text(shap_value.data)
|
| 59 |
+
token_qty = self.count_token(shap_value.data)
|
| 60 |
+
df = np.stack((text, token_qty), axis=-1)
|
| 61 |
+
df = np.hstack((df, scores))
|
| 62 |
+
df = DataFrame(df, columns=['speech', 'qty_tokens',
|
| 63 |
+
'neutral_score', 'positive_score', 'negative_score'])
|
| 64 |
+
title_score = ['positive_score', 'negative_score', 'neutral_score']
|
| 65 |
+
df = self.df_idxmax_score(df, title_score)
|
| 66 |
+
return df
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def df_idxmax_score(self, df, title_score):
|
| 70 |
+
df[title_score] = df[title_score].astype('float')
|
| 71 |
+
df['tag'] = df[title_score].idxmax(axis="columns")
|
| 72 |
+
df['score'] = df[title_score].max(axis='columns')
|
| 73 |
+
df['tag'] = df['tag'].replace({'positive_score': 'POSITIVE', 'negative_score': 'NEGATIVE', 'neutral_score': 'NEUTRAL'})
|
| 74 |
+
return df
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def calc_performance(self, time_s, df:DataFrame):
|
| 78 |
+
total_tokens = df.astype('int64').sum()
|
| 79 |
+
proc_time=timedelta(seconds=time_s)
|
| 80 |
+
h,m,s = re.split(':', str(proc_time))
|
| 81 |
+
return total_tokens, h, m, s
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_statistic(self, df)->DataFrame:
|
| 85 |
+
statistic = DataFrame()
|
| 86 |
+
statistic['positive_score'] = df[df['tag']=='POSITIVE']['score'].describe()
|
| 87 |
+
statistic['negative_score'] = df[df['tag']=='NEGATIVE']['score'].describe()
|
| 88 |
+
statistic['neutral_score'] = df[df['tag']=='NEUTRAL']['score'].describe()
|
| 89 |
+
return statistic
|
| 90 |
+
|
| 91 |
+
def get_performance(self):
|
| 92 |
+
return self.total_tokens, self.h, self.m, self.s
|
| 93 |
+
|
| 94 |
+
def three_plot(self, df):
|
| 95 |
+
fig = make_subplots(rows=2, cols=2, horizontal_spacing = 0.0, vertical_spacing = 0.05,
|
| 96 |
+
shared_xaxes=True, shared_yaxes=True,
|
| 97 |
+
row_heights=[0.4, 0.6], column_widths=[0.8, 0.2])
|
| 98 |
+
|
| 99 |
+
fig_scatter = px.scatter(df, x=df.index, y=['score'], color="tag",)
|
| 100 |
+
fig_histogram = px.histogram(df, x=df.index, color='tag', nbins=20,)
|
| 101 |
+
fig_box = px.box(df, x='tag', y="score", color='tag',)
|
| 102 |
+
|
| 103 |
+
fig_scatter.data[1]['marker']={'color': '#000007'}
|
| 104 |
+
fig_histogram.data[1]['marker']={'color': '#000007', 'pattern': {'shape': ''}}
|
| 105 |
+
fig_box.data[1]['marker']={'color': '#000007'}
|
| 106 |
+
|
| 107 |
+
for x in range(3):
|
| 108 |
+
fig_histogram.data[x]['showlegend']=False
|
| 109 |
+
fig_box.data[x]['showlegend']=False
|
| 110 |
+
fig.add_trace(fig_scatter.data[x], row=2, col=1)
|
| 111 |
+
fig.add_trace(fig_histogram.data[x], row=1, col=1)
|
| 112 |
+
fig.add_trace(fig_box.data[x], row=2, col=2,)
|
| 113 |
+
|
| 114 |
+
fig.update_layout(barmode='overlay', title=f'''Estatísticas: {self.company}<br>Trimestre {self.trim} de 2024''',
|
| 115 |
+
xaxis3_rangeslider=dict(visible=True, bgcolor="#636EFA", thickness=0.03),
|
| 116 |
+
legend=dict(orientation="h", yanchor="top",
|
| 117 |
+
y=1.3, xanchor="center", x=0.5),
|
| 118 |
+
scene = dict(yaxis = dict(title=''),))
|
| 119 |
+
fig.update_xaxes(showticklabels=False, showgrid=True, row=1, col=1)
|
| 120 |
+
fig.update_xaxes(title_text='#Fala', showgrid=True, row=2, col=1)
|
| 121 |
+
fig.update_yaxes(title_text='Score', row=2, col=1)
|
| 122 |
+
fig.update_yaxes(title_text='Frequência', row=1, col=1)
|
| 123 |
+
fig.update_traces(marker={"opacity": 0.7})
|
| 124 |
+
return fig
|
| 125 |
+
|
| 126 |
+
def plot_rank(self, tag={'NEUTRAL':'Neutro',
|
| 127 |
+
'POSITIVE':'Positivo',
|
| 128 |
+
'NEGATIVE':'Negativo',},
|
| 129 |
+
max_display=11):
|
| 130 |
+
plot_bar = dict()
|
| 131 |
+
axis = dict()
|
| 132 |
+
rank = DataFrame()
|
| 133 |
+
for key, val in tag.items():
|
| 134 |
+
plot, ax = plt.subplots()
|
| 135 |
+
shap.plots.bar(self.shap_value[:,:,key], show=False, max_display=max_display,)
|
| 136 |
+
plot_bar[key] = plot
|
| 137 |
+
axis[key] = ax
|
| 138 |
+
rank[val] = DataFrame(ax.get_yticklabels()[:-max_display-1]).astype(str)
|
| 139 |
+
plt.close()
|
| 140 |
+
|
| 141 |
+
rank[list(tag.values())] = rank[list(tag.values())].replace(r"(([T])\w+|(\d+,)|(\d+.\d+,)|(['\(\)\s]))", '', regex=True)
|
| 142 |
+
|
| 143 |
+
return plot_bar, axis, rank
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def get_plot_rank(self, max_display=11):
|
| 147 |
+
self.plot_bar, self.axis, self.rank = self.plot_rank(max_display=max_display)
|
| 148 |
+
return self.plot_bar, self.axis, self.rank
|
src/datashap/__init__.py
ADDED
|
File without changes
|
src/streamlit_app.py
CHANGED
|
@@ -1,133 +1,34 @@
|
|
| 1 |
-
import pickle
|
| 2 |
-
from datetime import timedelta
|
| 3 |
-
import re
|
| 4 |
-
import plotly.express as px
|
| 5 |
-
from plotly.subplots import make_subplots
|
| 6 |
-
import numpy as np
|
| 7 |
-
import pandas as pd
|
| 8 |
import streamlit as st
|
| 9 |
import streamlit.components.v1 as components
|
| 10 |
import shap
|
|
|
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
-
from huggingface_hub import HfFileSystem
|
| 13 |
import warnings
|
| 14 |
warnings.simplefilter("ignore", category=DeprecationWarning)
|
| 15 |
|
| 16 |
-
fs = HfFileSystem()
|
| 17 |
-
|
| 18 |
-
def join_text(lista):
|
| 19 |
-
return np.array([''.join(x) for x in lista])
|
| 20 |
-
|
| 21 |
-
def count_token(lista):
|
| 22 |
-
return np.array([len(x) for x in lista])
|
| 23 |
-
|
| 24 |
-
def calc_allscore(lista):
|
| 25 |
-
base_values = pd.DataFrame(lista.base_values, columns=['POSITIVE', 'NEGATIVE', 'NEUTRAL'])
|
| 26 |
-
a = pd.DataFrame()
|
| 27 |
-
a['NEUTRAL'] = pd.DataFrame(lista[:,:,'NEUTRAL'].values, columns=['NEUTRAL'])['NEUTRAL'].apply(lambda x: x.sum()) + base_values['NEUTRAL']
|
| 28 |
-
a['POSITIVE'] = pd.DataFrame(lista[:,:,'POSITIVE'].values, columns=['POSITIVE'])['POSITIVE'].apply(lambda x: x.sum()) + base_values['POSITIVE']
|
| 29 |
-
a['NEGATIVE'] = pd.DataFrame(lista[:,:,'NEGATIVE'].values, columns=['NEGATIVE'])['NEGATIVE'].apply(lambda x: x.sum()) + base_values['NEGATIVE']
|
| 30 |
-
return a
|
| 31 |
|
| 32 |
def select_fala():
|
| 33 |
if st.session_state.df_falas['selection']['rows']:
|
| 34 |
st.session_state.num_fala = st.session_state.df_falas['selection']['rows'][0]
|
| 35 |
num = st.session_state.num_fala
|
| 36 |
-
df=st.session_state[
|
| 37 |
-
rotulo = df.iloc[num]['
|
| 38 |
option_map = {'Neutro':'NEUTRAL', 'Positivo':'POSITIVE', 'Negativo':'NEGATIVE',}
|
| 39 |
st.session_state.rotulo = option_map[rotulo]
|
| 40 |
|
| 41 |
@st.cache_resource
|
| 42 |
-
def
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
rank = pd.DataFrame()
|
| 46 |
-
for key, val in rotulo.items():
|
| 47 |
-
plot_bar, ax = plt.subplots()
|
| 48 |
-
shap.plots.bar(shap_values[empresa][trim][:,:,key], show=False, max_display=max_display,)
|
| 49 |
-
plot[key] = plot_bar
|
| 50 |
-
axis[key] = ax
|
| 51 |
-
rank[val] = pd.DataFrame(ax.get_yticklabels()[:-max_display-1]).astype(str)
|
| 52 |
-
|
| 53 |
-
rank[list(rotulo.values())] = rank[list(rotulo.values())].replace(r"(([T])\w+|(\d+,)|(\d+.\d+,)|(['\(\)\s]))", '', regex=True)
|
| 54 |
-
|
| 55 |
-
return plot, axis, rank
|
| 56 |
-
|
| 57 |
-
@st.cache_data
|
| 58 |
-
def load_file(arquivo):
|
| 59 |
-
shap_value = 0
|
| 60 |
-
with fs.open(arquivo, 'rb') as inp:
|
| 61 |
-
shap_value = pickle.load(inp)
|
| 62 |
-
inp.close()
|
| 63 |
return shap_value
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
df = pd.DataFrame(df, columns=['fala', 'qtde_tokens',
|
| 73 |
-
'neutral_score', 'positive_score', 'negative_score'])
|
| 74 |
-
return df
|
| 75 |
-
|
| 76 |
-
@st.cache_data
|
| 77 |
-
def calc_performance(time_s, df):
|
| 78 |
-
total_tokens = df.astype('int64').sum()
|
| 79 |
-
proc_time=timedelta(seconds=time_s)
|
| 80 |
-
h,m,s = re.split(':', str(proc_time))
|
| 81 |
-
return total_tokens, h, m, s
|
| 82 |
-
|
| 83 |
-
@st.cache_data
|
| 84 |
-
def get_statistic(df):
|
| 85 |
-
estatistica = pd.DataFrame()
|
| 86 |
-
estatistica['Positivo'] = df[df['rotulo']=='Positivo']['score'].describe()
|
| 87 |
-
estatistica['Negativo'] = df[df['rotulo']=='Negativo']['score'].describe()
|
| 88 |
-
estatistica['Neutro'] = df[df['rotulo']=='Neutro']['score'].describe()
|
| 89 |
-
return estatistica
|
| 90 |
-
|
| 91 |
-
@st.cache_resource
|
| 92 |
-
def three_plot(df):
|
| 93 |
-
fig = make_subplots(rows=2, cols=2, horizontal_spacing = 0.0, vertical_spacing = 0.05,
|
| 94 |
-
shared_xaxes=True, shared_yaxes=True,
|
| 95 |
-
row_heights=[0.4, 0.6], column_widths=[0.8, 0.2])
|
| 96 |
-
|
| 97 |
-
fig_scatter = px.scatter(df, x=df.index, y=['score'], color="rotulo",)
|
| 98 |
-
fig_histogram = px.histogram(df, x=df.index, color='rotulo', nbins=20,)
|
| 99 |
-
fig_box = px.box(df, x='rotulo', y="score", color='rotulo',)
|
| 100 |
-
|
| 101 |
-
fig_scatter.data[1]['marker']={'color': '#000007'}
|
| 102 |
-
fig_histogram.data[1]['marker']={'color': '#000007', 'pattern': {'shape': ''}}
|
| 103 |
-
fig_box.data[1]['marker']={'color': '#000007'}
|
| 104 |
-
|
| 105 |
-
for x in range(3):
|
| 106 |
-
fig_histogram.data[x]['showlegend']=False
|
| 107 |
-
fig_box.data[x]['showlegend']=False
|
| 108 |
-
fig.add_trace(fig_scatter.data[x], row=2, col=1)
|
| 109 |
-
fig.add_trace(fig_histogram.data[x], row=1, col=1)
|
| 110 |
-
fig.add_trace(fig_box.data[x], row=2, col=2,)
|
| 111 |
-
|
| 112 |
-
fig.update_layout(barmode='overlay', title=f'''Estatísticas: {empresa_dict[empresa]}<br>Trimestre {trim} de 2024''',
|
| 113 |
-
xaxis3_rangeslider=dict(visible=True, bgcolor="#636EFA", thickness=0.03),
|
| 114 |
-
legend=dict(orientation="h", yanchor="top",
|
| 115 |
-
y=1.3, xanchor="center", x=0.5),
|
| 116 |
-
scene = dict(yaxis = dict(title=''),))
|
| 117 |
-
fig.update_xaxes(showticklabels=False, showgrid=True, row=1, col=1)
|
| 118 |
-
fig.update_xaxes(title_text='#Fala', showgrid=True, row=2, col=1)
|
| 119 |
-
fig.update_yaxes(title_text='Score', row=2, col=1)
|
| 120 |
-
fig.update_yaxes(title_text='Frequência', row=1, col=1)
|
| 121 |
-
fig.update_traces(marker={"opacity": 0.7})
|
| 122 |
-
return fig
|
| 123 |
-
|
| 124 |
-
@st.cache_data
|
| 125 |
-
def df_idxmax_score(df, empresa, trim, title_score):
|
| 126 |
-
df[title_score] = df[title_score].astype('float')
|
| 127 |
-
df['rotulo'] = df[title_score].idxmax(axis="columns")
|
| 128 |
-
df['score'] = df[title_score].max(axis='columns')
|
| 129 |
-
df['rotulo'] = df['rotulo'].replace({'positive_score': 'Positivo', 'negative_score': 'Negativo', 'neutral_score': 'Neutro'})
|
| 130 |
-
return df
|
| 131 |
|
| 132 |
st.set_page_config(page_title="TCCPoliUSPPro", )
|
| 133 |
|
|
@@ -138,16 +39,7 @@ shap_values = {}
|
|
| 138 |
title_score = ['positive_score', 'negative_score', 'neutral_score']
|
| 139 |
|
| 140 |
for key, val in pasta.items():
|
| 141 |
-
|
| 142 |
-
st.session_state[key] = []
|
| 143 |
-
st.session_state[f'df_{key}']=[]
|
| 144 |
-
for i in range(1,5):
|
| 145 |
-
arquivo=f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{i}t24.save'
|
| 146 |
-
shap_value = load_file(arquivo)
|
| 147 |
-
df = mount_df(shap_value, val, i)
|
| 148 |
-
df = df_idxmax_score(df, val, i, title_score)
|
| 149 |
-
st.session_state[key].append(shap_value)
|
| 150 |
-
st.session_state[f'df_{key}'].append(df)
|
| 151 |
shap_values[key] = st.session_state[key]
|
| 152 |
|
| 153 |
st.header("Sentimento da fala e Scores")
|
|
@@ -165,12 +57,12 @@ trim = col2.number_input("**Trimestre de 2024:**", 1, max_value = 4, key='trimes
|
|
| 165 |
|
| 166 |
text_num = col3.number_input(
|
| 167 |
"**Fala número:**",
|
| 168 |
-
0, max_value = len(shap_values[empresa][trim-1])-1,
|
| 169 |
key='num_fala',)
|
| 170 |
|
| 171 |
-
df=
|
| 172 |
|
| 173 |
-
total_tokens, h, m, s =
|
| 174 |
|
| 175 |
col4.write(f"**Total tokens:** {total_tokens} \
|
| 176 |
\n**Compute time:** {h}h {m}m {s:.2}s")
|
|
@@ -182,22 +74,20 @@ with tab1:
|
|
| 182 |
selection_mode = 'single-row',
|
| 183 |
key='df_falas',
|
| 184 |
on_select=select_fala,
|
| 185 |
-
column_config={'
|
| 186 |
-
'
|
| 187 |
'positive_score':st.column_config.NumberColumn("Score Positivo",),
|
| 188 |
'negative_score':st.column_config.NumberColumn("Score Negativo",),
|
| 189 |
'neutral_score':st.column_config.NumberColumn("Score Neutro",),
|
| 190 |
-
'
|
| 191 |
},
|
| 192 |
height=200,)
|
| 193 |
|
| 194 |
with tab2:
|
| 195 |
-
|
| 196 |
-
st.dataframe(estatistica, )
|
| 197 |
|
| 198 |
with tab3:
|
| 199 |
-
|
| 200 |
-
st.plotly_chart(fig)
|
| 201 |
|
| 202 |
|
| 203 |
score_positive, score_negative, score_neutral = df.loc[text_num, title_score]
|
|
@@ -211,7 +101,7 @@ rotulo = st.radio(
|
|
| 211 |
key='rotulo'
|
| 212 |
)
|
| 213 |
|
| 214 |
-
plot_text = shap.plots.text(shap_values[empresa][trim-1][text_num, :, rotulo], display = False)
|
| 215 |
components.html(plot_text, height = 180, scrolling = True)
|
| 216 |
|
| 217 |
st.header("Gráfico waterfall dos termos e Valores de Shapley")
|
|
@@ -219,18 +109,19 @@ st.header("Gráfico waterfall dos termos e Valores de Shapley")
|
|
| 219 |
with st.expander("Expand"):
|
| 220 |
max_display = st.slider(
|
| 221 |
"**Máximo de exibição:**",
|
| 222 |
-
1, max_value = int(df['
|
| 223 |
-
value=int(int(df['
|
| 224 |
)
|
| 225 |
|
| 226 |
plot_waterfall, ax = plt.subplots()
|
| 227 |
-
shap.plots.waterfall(shap_values[empresa][trim-1][text_num, :, rotulo], show=False, max_display=max_display)
|
| 228 |
st.pyplot(plot_waterfall)
|
|
|
|
| 229 |
|
| 230 |
st.header('Rank de termos do documento em Gráfico Barra')
|
| 231 |
|
| 232 |
with st.expander("Expand"):
|
| 233 |
-
plot_bar, ax, rank =
|
| 234 |
for key, val in option_map.items():
|
| 235 |
st.subheader(val)
|
| 236 |
st.pyplot(plot_bar[key])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import streamlit.components.v1 as components
|
| 3 |
import shap
|
| 4 |
+
from datashap import DataSHAP as ds
|
| 5 |
import matplotlib.pyplot as plt
|
|
|
|
| 6 |
import warnings
|
| 7 |
warnings.simplefilter("ignore", category=DeprecationWarning)
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def select_fala():
|
| 11 |
if st.session_state.df_falas['selection']['rows']:
|
| 12 |
st.session_state.num_fala = st.session_state.df_falas['selection']['rows'][0]
|
| 13 |
num = st.session_state.num_fala
|
| 14 |
+
df=st.session_state[st.session_state.empresa][st.session_state.trimestre-1].df
|
| 15 |
+
rotulo = df.iloc[num]['tag']
|
| 16 |
option_map = {'Neutro':'NEUTRAL', 'Positivo':'POSITIVE', 'Negativo':'NEGATIVE',}
|
| 17 |
st.session_state.rotulo = option_map[rotulo]
|
| 18 |
|
| 19 |
@st.cache_resource
|
| 20 |
+
def get_dataSHAP(file, company, trim):
|
| 21 |
+
shap_value=ds.DataSHAP(file, company, trim)
|
| 22 |
+
shap_value.df['tag'] = shap_value.df['tag'].replace({'POSITIVE':'Positivo', 'NEGATIVE':'Negativo', 'NEUTRAL':'Neutro'})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
return shap_value
|
| 24 |
|
| 25 |
+
def init_session(key, val):
|
| 26 |
+
if key not in st.session_state:
|
| 27 |
+
st.session_state[key] = []
|
| 28 |
+
for i in range(1,5):
|
| 29 |
+
arquivo=f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{i}t24.save'
|
| 30 |
+
shap_value = get_dataSHAP(arquivo, empresa_dict[key], i)
|
| 31 |
+
st.session_state[key].append(shap_value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
st.set_page_config(page_title="TCCPoliUSPPro", )
|
| 34 |
|
|
|
|
| 39 |
title_score = ['positive_score', 'negative_score', 'neutral_score']
|
| 40 |
|
| 41 |
for key, val in pasta.items():
|
| 42 |
+
init_session(key, val)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
shap_values[key] = st.session_state[key]
|
| 44 |
|
| 45 |
st.header("Sentimento da fala e Scores")
|
|
|
|
| 57 |
|
| 58 |
text_num = col3.number_input(
|
| 59 |
"**Fala número:**",
|
| 60 |
+
0, max_value = len(shap_values[empresa][trim-1].shap_value)-1,
|
| 61 |
key='num_fala',)
|
| 62 |
|
| 63 |
+
df=shap_values[empresa][trim-1].df
|
| 64 |
|
| 65 |
+
total_tokens, h, m, s = shap_values[empresa][trim-1].get_performance()
|
| 66 |
|
| 67 |
col4.write(f"**Total tokens:** {total_tokens} \
|
| 68 |
\n**Compute time:** {h}h {m}m {s:.2}s")
|
|
|
|
| 74 |
selection_mode = 'single-row',
|
| 75 |
key='df_falas',
|
| 76 |
on_select=select_fala,
|
| 77 |
+
column_config={'speech':st.column_config.Column('Fala', width=100),
|
| 78 |
+
'qty_tokens':st.column_config.NumberColumn("Qtde. Tokens", format='%d'),
|
| 79 |
'positive_score':st.column_config.NumberColumn("Score Positivo",),
|
| 80 |
'negative_score':st.column_config.NumberColumn("Score Negativo",),
|
| 81 |
'neutral_score':st.column_config.NumberColumn("Score Neutro",),
|
| 82 |
+
'tag':"Rótulo",
|
| 83 |
},
|
| 84 |
height=200,)
|
| 85 |
|
| 86 |
with tab2:
|
| 87 |
+
st.dataframe(shap_values[empresa][trim-1].statistic, )
|
|
|
|
| 88 |
|
| 89 |
with tab3:
|
| 90 |
+
st.plotly_chart(shap_values[empresa][trim-1].plot)
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
score_positive, score_negative, score_neutral = df.loc[text_num, title_score]
|
|
|
|
| 101 |
key='rotulo'
|
| 102 |
)
|
| 103 |
|
| 104 |
+
plot_text = shap.plots.text(shap_values[empresa][trim-1].shap_value[text_num, :, rotulo], display = False)
|
| 105 |
components.html(plot_text, height = 180, scrolling = True)
|
| 106 |
|
| 107 |
st.header("Gráfico waterfall dos termos e Valores de Shapley")
|
|
|
|
| 109 |
with st.expander("Expand"):
|
| 110 |
max_display = st.slider(
|
| 111 |
"**Máximo de exibição:**",
|
| 112 |
+
1, max_value = int(df['qty_tokens'][text_num]),
|
| 113 |
+
value=int(int(df['qty_tokens'][text_num])/3)+1
|
| 114 |
)
|
| 115 |
|
| 116 |
plot_waterfall, ax = plt.subplots()
|
| 117 |
+
shap.plots.waterfall(shap_values[empresa][trim-1].shap_value[text_num, :, rotulo], show=False, max_display=max_display)
|
| 118 |
st.pyplot(plot_waterfall)
|
| 119 |
+
plt.close()
|
| 120 |
|
| 121 |
st.header('Rank de termos do documento em Gráfico Barra')
|
| 122 |
|
| 123 |
with st.expander("Expand"):
|
| 124 |
+
plot_bar, ax, rank = shap_values[empresa][trim-1].get_plot_rank()
|
| 125 |
for key, val in option_map.items():
|
| 126 |
st.subheader(val)
|
| 127 |
st.pyplot(plot_bar[key])
|