marcossuzuki committed on
Commit
acf89de
·
1 Parent(s): ae315c9

Using oop

Browse files
src/datashap/DataSHAP.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from datetime import timedelta
3
+ import re
4
+ import plotly.express as px
5
+ from plotly.subplots import make_subplots
6
+ import numpy as np
7
+ from pandas import DataFrame
8
+ import shap
9
+ import matplotlib.pyplot as plt
10
+ from huggingface_hub import HfFileSystem
11
+ import warnings
12
+ warnings.simplefilter("ignore", category=DeprecationWarning)
13
+
14
class DataSHAP:
    """A pickled SHAP result together with the tables and figures built from it.

    Everything is computed eagerly in ``__init__``: the pickled object is read
    through ``HfFileSystem`` and the per-speech table, the summary statistics,
    the plotly figure and the term-rank bar charts are derived from it.

    The unpickled object is used as follows: it is indexed as
    ``obj[:, :, label]`` and its ``values``, ``base_values``, ``data`` and
    ``compute_time`` attributes are read.

    Attributes set by ``__init__``:
        file_addr: path handed to ``HfFileSystem.open``.
        shap_value: the unpickled object.
        df: one row per speech (see ``mount_df``).
        total_tokens, h, m, s: see ``calc_performance``.
        statistic: see ``get_statistic``.
        company, trim: only used in the title of ``plot``.
        plot: see ``three_plot``.
        plot_bar, axis, rank: see ``plot_rank``.
    """

    # Class label -> column header used in the rank table. Iteration order
    # defines the column order of ``rank``.
    _DEFAULT_TAGS = {
        'NEUTRAL': 'Neutro',
        'POSITIVE': 'Positivo',
        'NEGATIVE': 'Negativo',
    }

    def __init__(self, file_addr, company, trim):
        self.file_addr = file_addr
        self.shap_value = self.load_file(self.file_addr)
        self.df = self.mount_df(self.shap_value)
        self.total_tokens, self.h, self.m, self.s = self.calc_performance(
            self.shap_value.compute_time, self.df['qty_tokens'])
        self.statistic = self.get_statistic(self.df)
        # three_plot reads company/trim for its title, so set them first.
        self.trim = trim
        self.company = company
        self.plot = self.three_plot(self.df)
        self.plot_bar, self.axis, self.rank = self.plot_rank()

    def load_file(self, file_addr):
        """Open ``file_addr`` through ``HfFileSystem`` and unpickle its content."""
        fs = HfFileSystem()
        # NOTE(security): pickle.load can execute arbitrary code from the file;
        # only point this at repositories you trust.
        with fs.open(file_addr, 'rb') as inp:
            return pickle.load(inp)

    def join_text(self, list):
        """Concatenate each token sequence into one string.

        The parameter keeps its original name (which shadows the builtin) so
        existing keyword callers keep working.

        Returns:
            A numpy array with one joined string per input sequence.
        """
        return np.array([''.join(tokens) for tokens in list])

    def count_token(self, list):
        """Return a numpy array with the length of each token sequence.

        The parameter keeps its original name (which shadows the builtin) so
        existing keyword callers keep working.
        """
        return np.array([len(tokens) for tokens in list])

    def calc_allscore(self, shap_value):
        """Compute one score per speech and class label.

        For each label the per-token values of every speech are summed and
        the matching base value is added.

        Returns:
            DataFrame with columns ``NEUTRAL``, ``POSITIVE``, ``NEGATIVE``
            (in that order).
        """
        # NOTE(review): the base-value column order is hard-coded; confirm it
        # matches the output order of the model that produced the pickle.
        base_values = DataFrame(shap_value.base_values,
                                columns=['POSITIVE', 'NEGATIVE', 'NEUTRAL'])
        df = DataFrame()
        for label in ('NEUTRAL', 'POSITIVE', 'NEGATIVE'):
            per_speech = DataFrame(shap_value[:, :, label].values,
                                   columns=[label])[label]
            df[label] = per_speech.apply(lambda x: x.sum()) + base_values[label]
        return df

    def mount_df(self, shap_value):
        """Build the per-speech table.

        Returns:
            DataFrame with columns ``speech``, ``qty_tokens``,
            ``neutral_score``, ``positive_score``, ``negative_score``,
            ``tag`` and ``score``.
        """
        scores = self.calc_allscore(shap_value)
        text = self.join_text(shap_value.data)
        token_qty = self.count_token(shap_value.data)
        # Stacking text with numbers goes through a single numpy array, so
        # the numeric columns are re-cast afterwards: the scores in
        # df_idxmax_score, qty_tokens by its consumers (calc_performance
        # casts it to int64).
        df = np.stack((text, token_qty), axis=-1)
        df = np.hstack((df, scores))
        df = DataFrame(df, columns=['speech', 'qty_tokens',
                                    'neutral_score', 'positive_score', 'negative_score'])
        title_score = ['positive_score', 'negative_score', 'neutral_score']
        df = self.df_idxmax_score(df, title_score)
        return df

    def df_idxmax_score(self, df, title_score):
        """Add ``tag`` (winning label) and ``score`` (its value) to ``df``.

        ``df`` is modified in place and also returned. The columns named in
        ``title_score`` are cast to float first.
        """
        df[title_score] = df[title_score].astype('float')
        df['tag'] = df[title_score].idxmax(axis="columns")
        df['score'] = df[title_score].max(axis='columns')
        df['tag'] = df['tag'].replace({'positive_score': 'POSITIVE',
                                       'negative_score': 'NEGATIVE',
                                       'neutral_score': 'NEUTRAL'})
        return df

    def calc_performance(self, time_s, df):
        """Sum the token counts and split the compute time into h/m/s strings.

        Args:
            time_s: compute time in seconds.
            df: the token-count column; values are cast to ``int64``.

        Returns:
            ``(total_tokens, h, m, s)`` where ``h``, ``m`` and ``s`` are
            strings formatted like the fields of ``str(timedelta)``
            (``m`` and ``s`` zero-padded, ``s`` with a six-digit fraction when
            there are microseconds).
        """
        total_tokens = df.astype('int64').sum()
        proc_time = timedelta(seconds=float(time_s))
        # Derive the fields arithmetically instead of splitting
        # str(timedelta): for durations of a day or more that string starts
        # with "N day(s), " and the hour field would be corrupted. Days are
        # folded into the hour count.
        minutes, seconds = divmod(proc_time.seconds, 60)
        hours, minutes = divmod(minutes, 60)
        hours += proc_time.days * 24
        s = f'{seconds:02d}'
        if proc_time.microseconds:
            s += f'.{proc_time.microseconds:06d}'
        return total_tokens, str(hours), f'{minutes:02d}', s

    def get_statistic(self, df) -> DataFrame:
        """Return ``describe()`` of ``score`` for the rows of each tag.

        The result has one column per label (``positive_score``,
        ``negative_score``, ``neutral_score``).
        """
        statistic = DataFrame()
        for column, label in (('positive_score', 'POSITIVE'),
                              ('negative_score', 'NEGATIVE'),
                              ('neutral_score', 'NEUTRAL')):
            statistic[column] = df[df['tag'] == label]['score'].describe()
        return statistic

    def get_performance(self):
        """Return the ``(total_tokens, h, m, s)`` computed at construction."""
        return self.total_tokens, self.h, self.m, self.s

    def three_plot(self, df):
        """Build the combined histogram / scatter / box figure for ``df``.

        Layout on a 2x2 grid: histogram in row 1 col 1, scatter in row 2
        col 1, box plot in row 2 col 2. Reads ``self.company`` and
        ``self.trim`` for the title.
        """
        fig = make_subplots(rows=2, cols=2, horizontal_spacing=0.0, vertical_spacing=0.05,
                            shared_xaxes=True, shared_yaxes=True,
                            row_heights=[0.4, 0.6], column_widths=[0.8, 0.2])

        fig_scatter = px.scatter(df, x=df.index, y=['score'], color="tag",)
        fig_histogram = px.histogram(df, x=df.index, color='tag', nbins=20,)
        fig_box = px.box(df, x='tag', y="score", color='tag',)

        # Recolor the second trace of each figure, but only when it exists:
        # there may be fewer than two distinct tags in the data.
        if len(fig_scatter.data) > 1:
            fig_scatter.data[1]['marker'] = {'color': '#000007'}
        if len(fig_histogram.data) > 1:
            fig_histogram.data[1]['marker'] = {'color': '#000007', 'pattern': {'shape': ''}}
        if len(fig_box.data) > 1:
            fig_box.data[1]['marker'] = {'color': '#000007'}

        # Iterate over the traces that are actually present instead of
        # assuming exactly three. Only the scatter traces keep a legend entry.
        for scatter, histogram, box in zip(fig_scatter.data, fig_histogram.data, fig_box.data):
            histogram['showlegend'] = False
            box['showlegend'] = False
            fig.add_trace(scatter, row=2, col=1)
            fig.add_trace(histogram, row=1, col=1)
            fig.add_trace(box, row=2, col=2,)

        fig.update_layout(barmode='overlay', title=f'''Estatísticas: {self.company}<br>Trimestre {self.trim} de 2024''',
                          xaxis3_rangeslider=dict(visible=True, bgcolor="#636EFA", thickness=0.03),
                          legend=dict(orientation="h", yanchor="top",
                                      y=1.3, xanchor="center", x=0.5),
                          scene=dict(yaxis=dict(title=''),))
        fig.update_xaxes(showticklabels=False, showgrid=True, row=1, col=1)
        fig.update_xaxes(title_text='#Fala', showgrid=True, row=2, col=1)
        fig.update_yaxes(title_text='Score', row=2, col=1)
        fig.update_yaxes(title_text='Frequência', row=1, col=1)
        fig.update_traces(marker={"opacity": 0.7})
        return fig

    def plot_rank(self, tag=None, max_display=11):
        """Draw one SHAP bar chart per label and collect the tick labels.

        Args:
            tag: mapping of class label -> column name in ``rank``. ``None``
                selects the default NEUTRAL/POSITIVE/NEGATIVE mapping.
            max_display: forwarded to ``shap.plots.bar`` and used to slice
                the tick labels.

        Returns:
            ``(plot_bar, axis, rank)``: dicts of figure and axes keyed by
            label, and a DataFrame with one column of cleaned tick-label
            text per label.
        """
        if tag is None:
            tag = self._DEFAULT_TAGS
        plot_bar = {}
        axis = {}
        rank = DataFrame()
        for key, val in tag.items():
            plot, ax = plt.subplots()
            shap.plots.bar(self.shap_value[:, :, key], show=False, max_display=max_display,)
            plot_bar[key] = plot
            axis[key] = ax
            # NOTE(review): this slice drops the last max_display + 1 tick
            # labels; confirm it matches the label layout shap produces.
            rank[val] = DataFrame(ax.get_yticklabels()[:-max_display-1]).astype(str)
            # Release the figure from pyplot's registry; the Figure object
            # itself stays usable through plot_bar.
            plt.close(plot)

        columns = list(tag.values())
        # Strip the Text(...) repr wrapper, coordinates, quotes and
        # whitespace so only the label text remains.
        rank[columns] = rank[columns].replace(r"(([T])\w+|(\d+,)|(\d+.\d+,)|(['\(\)\s]))", '', regex=True)

        return plot_bar, axis, rank

    def get_plot_rank(self, max_display=11):
        """Recompute the rank plots with ``max_display``, store and return them."""
        self.plot_bar, self.axis, self.rank = self.plot_rank(max_display=max_display)
        return self.plot_bar, self.axis, self.rank
src/datashap/__init__.py ADDED
File without changes
src/streamlit_app.py CHANGED
@@ -1,133 +1,34 @@
1
- import pickle
2
- from datetime import timedelta
3
- import re
4
- import plotly.express as px
5
- from plotly.subplots import make_subplots
6
- import numpy as np
7
- import pandas as pd
8
  import streamlit as st
9
  import streamlit.components.v1 as components
10
  import shap
 
11
  import matplotlib.pyplot as plt
12
- from huggingface_hub import HfFileSystem
13
  import warnings
14
  warnings.simplefilter("ignore", category=DeprecationWarning)
15
 
16
- fs = HfFileSystem()
17
-
18
- def join_text(lista):
19
- return np.array([''.join(x) for x in lista])
20
-
21
- def count_token(lista):
22
- return np.array([len(x) for x in lista])
23
-
24
- def calc_allscore(lista):
25
- base_values = pd.DataFrame(lista.base_values, columns=['POSITIVE', 'NEGATIVE', 'NEUTRAL'])
26
- a = pd.DataFrame()
27
- a['NEUTRAL'] = pd.DataFrame(lista[:,:,'NEUTRAL'].values, columns=['NEUTRAL'])['NEUTRAL'].apply(lambda x: x.sum()) + base_values['NEUTRAL']
28
- a['POSITIVE'] = pd.DataFrame(lista[:,:,'POSITIVE'].values, columns=['POSITIVE'])['POSITIVE'].apply(lambda x: x.sum()) + base_values['POSITIVE']
29
- a['NEGATIVE'] = pd.DataFrame(lista[:,:,'NEGATIVE'].values, columns=['NEGATIVE'])['NEGATIVE'].apply(lambda x: x.sum()) + base_values['NEGATIVE']
30
- return a
31
 
32
  def select_fala():
33
  if st.session_state.df_falas['selection']['rows']:
34
  st.session_state.num_fala = st.session_state.df_falas['selection']['rows'][0]
35
  num = st.session_state.num_fala
36
- df=st.session_state[f'df_{st.session_state.empresa}'][st.session_state.trimestre-1]
37
- rotulo = df.iloc[num]['rotulo']
38
  option_map = {'Neutro':'NEUTRAL', 'Positivo':'POSITIVE', 'Negativo':'NEGATIVE',}
39
  st.session_state.rotulo = option_map[rotulo]
40
 
41
  @st.cache_resource
42
- def plot_rank(empresa, trim, rotulo, max_display):
43
- plot = dict()
44
- axis = dict()
45
- rank = pd.DataFrame()
46
- for key, val in rotulo.items():
47
- plot_bar, ax = plt.subplots()
48
- shap.plots.bar(shap_values[empresa][trim][:,:,key], show=False, max_display=max_display,)
49
- plot[key] = plot_bar
50
- axis[key] = ax
51
- rank[val] = pd.DataFrame(ax.get_yticklabels()[:-max_display-1]).astype(str)
52
-
53
- rank[list(rotulo.values())] = rank[list(rotulo.values())].replace(r"(([T])\w+|(\d+,)|(\d+.\d+,)|(['\(\)\s]))", '', regex=True)
54
-
55
- return plot, axis, rank
56
-
57
- @st.cache_data
58
- def load_file(arquivo):
59
- shap_value = 0
60
- with fs.open(arquivo, 'rb') as inp:
61
- shap_value = pickle.load(inp)
62
- inp.close()
63
  return shap_value
64
 
65
- @st.cache_data
66
- def mount_df(_shap_value, val, i):
67
- scores = calc_allscore(_shap_value)
68
- texto_junto = join_text(_shap_value.data)
69
- token_qtde = count_token(_shap_value.data)
70
- df = np.stack((texto_junto, token_qtde), axis=-1)
71
- df = np.hstack((df, scores))
72
- df = pd.DataFrame(df, columns=['fala', 'qtde_tokens',
73
- 'neutral_score', 'positive_score', 'negative_score'])
74
- return df
75
-
76
- @st.cache_data
77
- def calc_performance(time_s, df):
78
- total_tokens = df.astype('int64').sum()
79
- proc_time=timedelta(seconds=time_s)
80
- h,m,s = re.split(':', str(proc_time))
81
- return total_tokens, h, m, s
82
-
83
- @st.cache_data
84
- def get_statistic(df):
85
- estatistica = pd.DataFrame()
86
- estatistica['Positivo'] = df[df['rotulo']=='Positivo']['score'].describe()
87
- estatistica['Negativo'] = df[df['rotulo']=='Negativo']['score'].describe()
88
- estatistica['Neutro'] = df[df['rotulo']=='Neutro']['score'].describe()
89
- return estatistica
90
-
91
- @st.cache_resource
92
- def three_plot(df):
93
- fig = make_subplots(rows=2, cols=2, horizontal_spacing = 0.0, vertical_spacing = 0.05,
94
- shared_xaxes=True, shared_yaxes=True,
95
- row_heights=[0.4, 0.6], column_widths=[0.8, 0.2])
96
-
97
- fig_scatter = px.scatter(df, x=df.index, y=['score'], color="rotulo",)
98
- fig_histogram = px.histogram(df, x=df.index, color='rotulo', nbins=20,)
99
- fig_box = px.box(df, x='rotulo', y="score", color='rotulo',)
100
-
101
- fig_scatter.data[1]['marker']={'color': '#000007'}
102
- fig_histogram.data[1]['marker']={'color': '#000007', 'pattern': {'shape': ''}}
103
- fig_box.data[1]['marker']={'color': '#000007'}
104
-
105
- for x in range(3):
106
- fig_histogram.data[x]['showlegend']=False
107
- fig_box.data[x]['showlegend']=False
108
- fig.add_trace(fig_scatter.data[x], row=2, col=1)
109
- fig.add_trace(fig_histogram.data[x], row=1, col=1)
110
- fig.add_trace(fig_box.data[x], row=2, col=2,)
111
-
112
- fig.update_layout(barmode='overlay', title=f'''Estatísticas: {empresa_dict[empresa]}<br>Trimestre {trim} de 2024''',
113
- xaxis3_rangeslider=dict(visible=True, bgcolor="#636EFA", thickness=0.03),
114
- legend=dict(orientation="h", yanchor="top",
115
- y=1.3, xanchor="center", x=0.5),
116
- scene = dict(yaxis = dict(title=''),))
117
- fig.update_xaxes(showticklabels=False, showgrid=True, row=1, col=1)
118
- fig.update_xaxes(title_text='#Fala', showgrid=True, row=2, col=1)
119
- fig.update_yaxes(title_text='Score', row=2, col=1)
120
- fig.update_yaxes(title_text='Frequência', row=1, col=1)
121
- fig.update_traces(marker={"opacity": 0.7})
122
- return fig
123
-
124
- @st.cache_data
125
- def df_idxmax_score(df, empresa, trim, title_score):
126
- df[title_score] = df[title_score].astype('float')
127
- df['rotulo'] = df[title_score].idxmax(axis="columns")
128
- df['score'] = df[title_score].max(axis='columns')
129
- df['rotulo'] = df['rotulo'].replace({'positive_score': 'Positivo', 'negative_score': 'Negativo', 'neutral_score': 'Neutro'})
130
- return df
131
 
132
  st.set_page_config(page_title="TCCPoliUSPPro", )
133
 
@@ -138,16 +39,7 @@ shap_values = {}
138
  title_score = ['positive_score', 'negative_score', 'neutral_score']
139
 
140
  for key, val in pasta.items():
141
- if key not in st.session_state:
142
- st.session_state[key] = []
143
- st.session_state[f'df_{key}']=[]
144
- for i in range(1,5):
145
- arquivo=f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{i}t24.save'
146
- shap_value = load_file(arquivo)
147
- df = mount_df(shap_value, val, i)
148
- df = df_idxmax_score(df, val, i, title_score)
149
- st.session_state[key].append(shap_value)
150
- st.session_state[f'df_{key}'].append(df)
151
  shap_values[key] = st.session_state[key]
152
 
153
  st.header("Sentimento da fala e Scores")
@@ -165,12 +57,12 @@ trim = col2.number_input("**Trimestre de 2024:**", 1, max_value = 4, key='trimes
165
 
166
  text_num = col3.number_input(
167
  "**Fala número:**",
168
- 0, max_value = len(shap_values[empresa][trim-1])-1,
169
  key='num_fala',)
170
 
171
- df=st.session_state[f'df_{empresa}'][trim-1]
172
 
173
- total_tokens, h, m, s = calc_performance(shap_values[empresa][trim-1].compute_time, df['qtde_tokens'])
174
 
175
  col4.write(f"**Total tokens:** {total_tokens} \
176
  \n**Compute time:** {h}h {m}m {s:.2}s")
@@ -182,22 +74,20 @@ with tab1:
182
  selection_mode = 'single-row',
183
  key='df_falas',
184
  on_select=select_fala,
185
- column_config={'fala':st.column_config.Column('Fala', width=100),
186
- 'qtde_tokens':st.column_config.NumberColumn("Qtde. Tokens", format='%d'),
187
  'positive_score':st.column_config.NumberColumn("Score Positivo",),
188
  'negative_score':st.column_config.NumberColumn("Score Negativo",),
189
  'neutral_score':st.column_config.NumberColumn("Score Neutro",),
190
- 'rotulo':"Rótulo",
191
  },
192
  height=200,)
193
 
194
  with tab2:
195
- estatistica = get_statistic(df)
196
- st.dataframe(estatistica, )
197
 
198
  with tab3:
199
- fig = three_plot(df)
200
- st.plotly_chart(fig)
201
 
202
 
203
  score_positive, score_negative, score_neutral = df.loc[text_num, title_score]
@@ -211,7 +101,7 @@ rotulo = st.radio(
211
  key='rotulo'
212
  )
213
 
214
- plot_text = shap.plots.text(shap_values[empresa][trim-1][text_num, :, rotulo], display = False)
215
  components.html(plot_text, height = 180, scrolling = True)
216
 
217
  st.header("Gráfico waterfall dos termos e Valores de Shapley")
@@ -219,18 +109,19 @@ st.header("Gráfico waterfall dos termos e Valores de Shapley")
219
  with st.expander("Expand"):
220
  max_display = st.slider(
221
  "**Máximo de exibição:**",
222
- 1, max_value = int(df['qtde_tokens'][text_num]),
223
- value=int(int(df['qtde_tokens'][text_num])/3)+1
224
  )
225
 
226
  plot_waterfall, ax = plt.subplots()
227
- shap.plots.waterfall(shap_values[empresa][trim-1][text_num, :, rotulo], show=False, max_display=max_display)
228
  st.pyplot(plot_waterfall)
 
229
 
230
  st.header('Rank de termos do documento em Gráfico Barra')
231
 
232
  with st.expander("Expand"):
233
- plot_bar, ax, rank = plot_rank(empresa, trim-1, option_map, 11)
234
  for key, val in option_map.items():
235
  st.subheader(val)
236
  st.pyplot(plot_bar[key])
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import streamlit.components.v1 as components
3
  import shap
4
+ from datashap import DataSHAP as ds
5
  import matplotlib.pyplot as plt
 
6
  import warnings
7
  warnings.simplefilter("ignore", category=DeprecationWarning)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def select_fala():
11
  if st.session_state.df_falas['selection']['rows']:
12
  st.session_state.num_fala = st.session_state.df_falas['selection']['rows'][0]
13
  num = st.session_state.num_fala
14
+ df=st.session_state[st.session_state.empresa][st.session_state.trimestre-1].df
15
+ rotulo = df.iloc[num]['tag']
16
  option_map = {'Neutro':'NEUTRAL', 'Positivo':'POSITIVE', 'Negativo':'NEGATIVE',}
17
  st.session_state.rotulo = option_map[rotulo]
18
 
19
@st.cache_resource
def get_dataSHAP(file, company, trim):
    """Create the ``DataSHAP`` for one file and relabel its tags in Portuguese.

    The ``tag`` column of the object's ``df`` is rewritten from
    POSITIVE/NEGATIVE/NEUTRAL to Positivo/Negativo/Neutro before the object
    is returned. The result is cached by ``st.cache_resource``.
    """
    portuguese_labels = {'POSITIVE':'Positivo', 'NEGATIVE':'Negativo', 'NEUTRAL':'Neutro'}
    data_shap = ds.DataSHAP(file, company, trim)
    data_shap.df['tag'] = data_shap.df['tag'].replace(portuguese_labels)
    return data_shap
24
 
25
def init_session(key, val):
    """Load the four 2024 quarters of one company into ``st.session_state``.

    Args:
        key: session-state key; also part of the file name and the lookup
            key into ``empresa_dict``.
        val: folder name inside the transcript directory.

    Does nothing when ``key`` is already present in the session state.
    """
    if key in st.session_state:
        return
    # Build the complete list before publishing it. If a load fails, the key
    # stays absent and the next rerun retries, instead of leaving a partial
    # list behind that later indexing by quarter would trip over.
    st.session_state[key] = [
        get_dataSHAP(
            f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{i}t24.save',
            empresa_dict[key],
            i,
        )
        for i in range(1, 5)
    ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  st.set_page_config(page_title="TCCPoliUSPPro", )
34
 
 
39
  title_score = ['positive_score', 'negative_score', 'neutral_score']
40
 
41
  for key, val in pasta.items():
42
+ init_session(key, val)
 
 
 
 
 
 
 
 
 
43
  shap_values[key] = st.session_state[key]
44
 
45
  st.header("Sentimento da fala e Scores")
 
57
 
58
  text_num = col3.number_input(
59
  "**Fala número:**",
60
+ 0, max_value = len(shap_values[empresa][trim-1].shap_value)-1,
61
  key='num_fala',)
62
 
63
+ df=shap_values[empresa][trim-1].df
64
 
65
+ total_tokens, h, m, s = shap_values[empresa][trim-1].get_performance()
66
 
67
  col4.write(f"**Total tokens:** {total_tokens} \
68
  \n**Compute time:** {h}h {m}m {s:.2}s")
 
74
  selection_mode = 'single-row',
75
  key='df_falas',
76
  on_select=select_fala,
77
+ column_config={'speech':st.column_config.Column('Fala', width=100),
78
+ 'qty_tokens':st.column_config.NumberColumn("Qtde. Tokens", format='%d'),
79
  'positive_score':st.column_config.NumberColumn("Score Positivo",),
80
  'negative_score':st.column_config.NumberColumn("Score Negativo",),
81
  'neutral_score':st.column_config.NumberColumn("Score Neutro",),
82
+ 'tag':"Rótulo",
83
  },
84
  height=200,)
85
 
86
  with tab2:
87
+ st.dataframe(shap_values[empresa][trim-1].statistic, )
 
88
 
89
  with tab3:
90
+ st.plotly_chart(shap_values[empresa][trim-1].plot)
 
91
 
92
 
93
  score_positive, score_negative, score_neutral = df.loc[text_num, title_score]
 
101
  key='rotulo'
102
  )
103
 
104
+ plot_text = shap.plots.text(shap_values[empresa][trim-1].shap_value[text_num, :, rotulo], display = False)
105
  components.html(plot_text, height = 180, scrolling = True)
106
 
107
  st.header("Gráfico waterfall dos termos e Valores de Shapley")
 
109
  with st.expander("Expand"):
110
  max_display = st.slider(
111
  "**Máximo de exibição:**",
112
+ 1, max_value = int(df['qty_tokens'][text_num]),
113
+ value=int(int(df['qty_tokens'][text_num])/3)+1
114
  )
115
 
116
  plot_waterfall, ax = plt.subplots()
117
+ shap.plots.waterfall(shap_values[empresa][trim-1].shap_value[text_num, :, rotulo], show=False, max_display=max_display)
118
  st.pyplot(plot_waterfall)
119
+ plt.close()
120
 
121
  st.header('Rank de termos do documento em Gráfico Barra')
122
 
123
  with st.expander("Expand"):
124
+ plot_bar, ax, rank = shap_values[empresa][trim-1].get_plot_rank()
125
  for key, val in option_map.items():
126
  st.subheader(val)
127
  st.pyplot(plot_bar[key])