# marcossuzuki — commit 2b7bcf8 (verified): change tag->rotulo in three_plot
import pickle
from datetime import timedelta
import re
import plotly.express as px
from plotly.subplots import make_subplots
import numpy as np
from pandas import DataFrame
import shap
import matplotlib.pyplot as plt
from huggingface_hub import HfFileSystem
import warnings
# Module-level side effect: this installs a process-wide filter that silences
# every DeprecationWarning as soon as this module is imported, not only the
# warnings raised by code in this file.
warnings.simplefilter("ignore", category=DeprecationWarning)
class DataSHAP:
    """Load a pickled SHAP explanation for a 3-label classifier and summarize it.

    Construction downloads the explanation and builds:
      * ``df`` - one row per speech: text, token count, one score per label,
        the winning label (``tag``) and its score;
      * ``statistic`` - ``describe()`` of the winning scores per label;
      * ``plot`` - a combined Plotly scatter/histogram/box figure;
      * ``plot_bar`` / ``axis`` / ``rank`` - per-label SHAP bar plots and the
        token labels they rank.

    Args:
        file_addr: Path opened with ``huggingface_hub.HfFileSystem().open``.
        company: Company name shown in the figure title.
        trim: Quarter number shown in the figure title ("Trimestre ... de 2024").
    """

    def __init__(self, file_addr, company, trim):
        self.file_addr = file_addr
        self.shap_value = self.load_file(self.file_addr)
        self.df = self.mount_df(self.shap_value)
        self.total_tokens, self.h, self.m, self.s = self.calc_performance(
            self.shap_value.compute_time, self.df['qty_tokens'])
        self.statistic = self.get_statistic(self.df)
        # three_plot reads company/trim for the figure title, so set them first.
        self.trim = trim
        self.company = company
        self.plot = self.three_plot(self.df)
        self.plot_bar, self.axis, self.rank = self.plot_rank()

    def load_file(self, file_addr):
        """Open *file_addr* through ``HfFileSystem`` and return the unpickled object.

        NOTE(security): ``pickle.load`` can execute arbitrary code embedded in
        the file; only load files from a repository you trust.
        """
        fs = HfFileSystem()
        with fs.open(file_addr, 'rb') as inp:
            return pickle.load(inp)

    def join_text(self, list):
        """Concatenate each sequence of tokens into one string; return a numpy array.

        (The parameter name shadows the builtin ``list``; it is kept for
        keyword-argument compatibility.)
        """
        return np.array([''.join(x) for x in list])

    def count_token(self, list):
        """Return a numpy array with the number of tokens in each sequence."""
        return np.array([len(x) for x in list])

    def calc_allscore(self, shap_value):
        """Return per-speech scores rebuilt from SHAP values.

        For each label the score is the sum of that label's token
        contributions plus the label's base value (SHAP additivity). Columns
        are ordered NEUTRAL, POSITIVE, NEGATIVE; ``mount_df`` relies on this
        order.

        Works whether ``shap_value[:, :, label].values`` is an object array of
        per-speech arrays (texts of different lengths) or a regular 2-D array
        (all texts with the same token count).
        """
        # NOTE(review): assumes base_values columns are ordered POSITIVE,
        # NEGATIVE, NEUTRAL — confirm against the model's output order.
        base_values = DataFrame(shap_value.base_values,
                                columns=['POSITIVE', 'NEGATIVE', 'NEUTRAL'])
        df = DataFrame()
        for label in ('NEUTRAL', 'POSITIVE', 'NEGATIVE'):
            contributions = shap_value[:, :, label].values
            # Iterating rows handles both ragged object arrays and 2-D arrays.
            token_sums = np.asarray([np.sum(row) for row in contributions], dtype=float)
            df[label] = base_values[label] + token_sums
        return df

    def mount_df(self, shap_value):
        """Build the per-speech DataFrame from the SHAP explanation.

        Columns: speech, qty_tokens, neutral_score, positive_score,
        negative_score, tag, score.
        """
        scores = self.calc_allscore(shap_value)
        text = self.join_text(shap_value.data)
        token_qty = self.count_token(shap_value.data)
        # np.stack/np.hstack coerce the mixed text/number inputs to one string
        # array, which is why the score columns are cast back to float in
        # df_idxmax_score and qty_tokens is cast to int in calc_performance.
        df = np.stack((text, token_qty), axis=-1)
        df = np.hstack((df, scores))
        df = DataFrame(df, columns=['speech', 'qty_tokens',
                                    'neutral_score', 'positive_score', 'negative_score'])
        title_score = ['positive_score', 'negative_score', 'neutral_score']
        df = self.df_idxmax_score(df, title_score)
        return df

    def df_idxmax_score(self, df, title_score):
        """Cast *title_score* columns to float and add the winning label and score.

        Adds ``tag`` ('POSITIVE' / 'NEGATIVE' / 'NEUTRAL', the column with the
        highest score) and ``score`` (that maximum). Modifies *df* in place and
        returns it.
        """
        df[title_score] = df[title_score].astype('float')
        df['tag'] = df[title_score].idxmax(axis="columns")
        df['score'] = df[title_score].max(axis='columns')
        df['tag'] = df['tag'].replace({'positive_score': 'POSITIVE',
                                       'negative_score': 'NEGATIVE',
                                       'neutral_score': 'NEUTRAL'})
        return df

    def calc_performance(self, time_s, df):
        """Return ``(total_tokens, h, m, s)`` for a processing time in seconds.

        Args:
            time_s: Processing time in seconds.
            df: Column of token counts (a pandas Series, as passed by __init__);
                values may be strings and are cast to int64.

        Returns:
            The token total, then hours, minutes and seconds as strings in the
            ``str(timedelta)`` field format (e.g. '1', '02', '05.500000').
            Durations of a day or more are expressed as total hours
            (e.g. '25') instead of leaking "N day(s)" into the hour field.
        """
        total_tokens = df.astype('int64').sum()
        proc_time = timedelta(seconds=time_s)
        # str(timedelta) renders "D day(s), H:MM:SS" from 24 h upwards; format
        # only the sub-day remainder and fold the days into the hours.
        days = proc_time.days
        h, m, s = str(proc_time - timedelta(days=days)).split(':')
        h = str(int(h) + 24 * days)
        return total_tokens, h, m, s

    def get_statistic(self, df) -> DataFrame:
        """Return ``describe()`` of the winning score for each English label in df['tag']."""
        statistic = DataFrame()
        statistic['Score Positivo'] = df[df['tag'] == 'POSITIVE']['score'].describe()
        statistic['Score Negativo'] = df[df['tag'] == 'NEGATIVE']['score'].describe()
        statistic['Score Neutro'] = df[df['tag'] == 'NEUTRAL']['score'].describe()
        return statistic

    def get_performance(self):
        """Return the cached ``(total_tokens, h, m, s)`` computed at construction."""
        return self.total_tokens, self.h, self.m, self.s

    def three_plot(self, df):
        """Build a 2x2 Plotly figure of the winning scores, coloured by label.

        Top-left: histogram over speech index; bottom-left: score scatter;
        bottom-right: box plot per label. Labels are shown in Portuguese under
        a 'rotulo' column. The caller's *df* is not modified.
        """
        # Relabel a renamed copy so the caller's frame (self.df) keeps its
        # English 'tag' values for get_statistic and other consumers.
        plot_df = df.rename(columns={'tag': 'rotulo'})
        plot_df['rotulo'] = plot_df['rotulo'].replace(
            {'POSITIVE': 'Positivo', 'NEGATIVE': 'Negativo', 'NEUTRAL': 'Neutro'})
        fig = make_subplots(rows=2, cols=2, horizontal_spacing=0.0, vertical_spacing=0.05,
                            shared_xaxes=True, shared_yaxes=True,
                            row_heights=[0.4, 0.6], column_widths=[0.8, 0.2])
        fig_scatter = px.scatter(plot_df, x=plot_df.index, y=['score'], color="rotulo")
        fig_histogram = px.histogram(plot_df, x=plot_df.index, color='rotulo', nbins=20)
        fig_box = px.box(plot_df, x='rotulo', y="score", color='rotulo')
        # Draw the second colour group in (near) black in all three panels;
        # guarded because a dataset may contain fewer than two labels.
        # NOTE(review): data[1] is whichever label plotly puts second (order
        # of first appearance by default) — confirm which label is intended.
        if len(fig_scatter.data) > 1:
            fig_scatter.data[1]['marker'] = {'color': '#000007'}
            fig_histogram.data[1]['marker'] = {'color': '#000007', 'pattern': {'shape': ''}}
            fig_box.data[1]['marker'] = {'color': '#000007'}
        # One trace per label actually present (not assumed to be three);
        # only the scatter traces contribute legend entries.
        for scatter_trace, histogram_trace, box_trace in zip(
                fig_scatter.data, fig_histogram.data, fig_box.data):
            histogram_trace['showlegend'] = False
            box_trace['showlegend'] = False
            fig.add_trace(scatter_trace, row=2, col=1)
            fig.add_trace(histogram_trace, row=1, col=1)
            fig.add_trace(box_trace, row=2, col=2)
        fig.update_layout(barmode='overlay',
                          title=f'''Estatísticas: {self.company}<br>Trimestre {self.trim} de 2024''',
                          xaxis3_rangeslider=dict(visible=True, bgcolor="#636EFA", thickness=0.03),
                          legend=dict(orientation="h", yanchor="top",
                                      y=1.3, xanchor="center", x=0.5),
                          scene=dict(yaxis=dict(title='')))
        fig.update_xaxes(showticklabels=False, showgrid=True, row=1, col=1)
        fig.update_xaxes(title_text='#Fala', showgrid=True, row=2, col=1)
        fig.update_yaxes(title_text='Score', row=2, col=1)
        fig.update_yaxes(title_text='Frequência', row=1, col=1)
        fig.update_traces(marker={"opacity": 0.7})
        return fig

    def plot_rank(self, tag=None, max_display=11):
        """Draw one SHAP bar plot per label and collect the ranked token labels.

        Args:
            tag: Mapping of explanation label -> display column name. Defaults
                to NEUTRAL/POSITIVE/NEGATIVE -> Neutro/Positivo/Negativo.
            max_display: Passed to ``shap.plots.bar``.

        Returns:
            ``(plot_bar, axis, rank)``: dicts of figure and axes keyed by label,
            and a DataFrame with one column of tick-label texts per display name.
        """
        if tag is None:
            tag = {'NEUTRAL': 'Neutro', 'POSITIVE': 'Positivo', 'NEGATIVE': 'Negativo'}
        plot_bar = {}
        axis = {}
        rank = DataFrame()
        for key, val in tag.items():
            plot, ax = plt.subplots()
            shap.plots.bar(self.shap_value[:, :, key], show=False, max_display=max_display)
            plot_bar[key] = plot
            axis[key] = ax
            # NOTE(review): drops the last max_display + 1 tick labels, as the
            # original did — presumably extra ticks added by shap.plots.bar;
            # confirm against the installed shap version.
            # get_text() yields the label itself; the former regex over
            # str(Text) also erased label content such as any 'T\w+' run.
            labels = [label.get_text().strip()
                      for label in ax.get_yticklabels()[:-max_display - 1]]
            # Assign as a Series so columns of unequal length align on the index.
            rank[val] = DataFrame({val: labels})[val]
            plt.close()
        return plot_bar, axis, rank

    def get_plot_rank(self, max_display=11):
        """Recompute the bar plots/rankings with *max_display*, cache and return them."""
        self.plot_bar, self.axis, self.rank = self.plot_rank(max_display=max_display)
        return self.plot_bar, self.axis, self.rank

    def shap_plot_text(self, num, tag):
        """Return ``shap.plots.text`` for speech *num* and label *tag* with display=False."""
        return shap.plots.text(self.shap_value[num, :, tag], display=False)

    def shap_waterfall(self, num, tag, max_display):
        """Return a closed matplotlib figure with the SHAP waterfall of speech *num*/label *tag*."""
        plot_waterfall, _ = plt.subplots()
        shap.plots.waterfall(self.shap_value[num, :, tag], show=False, max_display=max_display)
        plt.close()
        return plot_waterfall