import pickle
import re
import warnings
from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import shap
from huggingface_hub import HfFileSystem
from pandas import DataFrame
from plotly.subplots import make_subplots

warnings.simplefilter("ignore", category=DeprecationWarning)


class DataSHAP:
    """Load a pickled SHAP explanation and derive tables and figures from it.

    On construction the object downloads/unpickles the explanation, builds a
    per-speech DataFrame with the three sentiment scores, computes summary
    statistics and timing information, and pre-renders the Plotly overview
    figure and the SHAP bar-plot rankings.
    """

    def __init__(self, file_addr, company, trim):
        """Build every derived artefact for one explanation file.

        Args:
            file_addr: Path understood by ``HfFileSystem.open``.
            company: Company name interpolated into the overview plot title.
            trim: Quarter identifier interpolated into the overview plot title.
        """
        self.file_addr = file_addr
        self.shap_value = self.load_file(self.file_addr)
        self.df = self.mount_df(self.shap_value)
        self.total_tokens, self.h, self.m, self.s = self.calc_performance(
            self.shap_value.compute_time, self.df['qty_tokens'])
        # Must run before three_plot(): get_statistic() filters on the
        # upper-case English tags, which three_plot() rewrites in place.
        self.statistic = self.get_statistic(self.df)
        self.trim = trim
        self.company = company
        self.plot = self.three_plot(self.df)
        self.plot_bar, self.axis, self.rank = self.plot_rank()

    def load_file(self, file_addr):
        """Open ``file_addr`` through ``HfFileSystem`` and unpickle its content.

        Returns:
            Whatever object the pickle contains (the rest of the class treats
            it as a SHAP explanation).
        """
        fs = HfFileSystem()
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point this at repositories whose content is trusted.
        with fs.open(file_addr, 'rb') as inp:
            # The context manager closes the handle; no explicit close needed.
            return pickle.load(inp)

    # NOTE: the parameter name ``list`` shadows the builtin; it is kept so
    # that keyword callers (``join_text(list=...)``) continue to work.
    def join_text(self, list):
        """Concatenate each inner sequence of strings into a single string.

        Returns:
            numpy array with one joined string per element of ``list``.
        """
        return np.array([''.join(x) for x in list])

    def count_token(self, list):
        """Return a numpy array with the length of each element of ``list``."""
        return np.array([len(x) for x in list])

    def calc_allscore(self, shap_value):
        """Compute one score per sample and label.

        For every label the per-sample SHAP values are summed and the
        matching base value is added.

        Returns:
            DataFrame with columns ``NEUTRAL``, ``POSITIVE``, ``NEGATIVE`` in
            that order (``mount_df`` relies on this column order).
        """
        base_values = DataFrame(shap_value.base_values,
                                columns=['POSITIVE', 'NEGATIVE', 'NEUTRAL'])
        df = DataFrame()
        for label in ('NEUTRAL', 'POSITIVE', 'NEGATIVE'):
            per_sample = DataFrame(shap_value[:, :, label].values,
                                   columns=[label])[label]
            df[label] = per_sample.apply(lambda x: x.sum()) + base_values[label]
        return df

    def mount_df(self, shap_value):
        """Assemble the per-speech table.

        Returns:
            DataFrame with columns ``speech``, ``qty_tokens``,
            ``neutral_score``, ``positive_score``, ``negative_score``, ``tag``
            and ``score``.
        """
        scores = self.calc_allscore(shap_value)
        text = self.join_text(shap_value.data)
        token_qty = self.count_token(shap_value.data)
        # Stacking text with numbers yields a string-typed array; the score
        # columns are converted back to float in df_idxmax_score() and the
        # token counts back to int in calc_performance().
        df = np.stack((text, token_qty), axis=-1)
        df = np.hstack((df, scores))
        df = DataFrame(df, columns=['speech', 'qty_tokens', 'neutral_score',
                                    'positive_score', 'negative_score'])
        title_score = ['positive_score', 'negative_score', 'neutral_score']
        return self.df_idxmax_score(df, title_score)

    def df_idxmax_score(self, df, title_score):
        """Add ``tag`` (winning label) and ``score`` (its value) to ``df``.

        The columns named in ``title_score`` are cast to float in place.

        Returns:
            The same DataFrame object, modified.
        """
        df[title_score] = df[title_score].astype('float')
        df['tag'] = df[title_score].idxmax(axis="columns")
        df['score'] = df[title_score].max(axis='columns')
        df['tag'] = df['tag'].replace({'positive_score': 'POSITIVE',
                                       'negative_score': 'NEGATIVE',
                                       'neutral_score': 'NEUTRAL'})
        return df

    def calc_performance(self, time_s, df: DataFrame):
        """Summarise token volume and processing time.

        Args:
            time_s: Processing time in seconds.
            df: Token counts (castable to ``int64``).

        Returns:
            ``(total_tokens, h, m, s)`` where ``h``, ``m`` and ``s`` are
            strings: hours, zero-padded minutes and zero-padded seconds
            (with a fractional part when ``time_s`` has one).
        """
        total_tokens = df.astype('int64').sum()
        proc_time = timedelta(seconds=time_s)
        # str(timedelta) becomes "N day(s), H:MM:SS" from 24 h upwards, which
        # used to leak "N day, H" into the hours field. Format only the
        # sub-day remainder and fold whole days back into the hour count.
        whole_days = proc_time.days
        remainder = proc_time - timedelta(days=whole_days)
        h, m, s = str(remainder).split(':')
        h = str(int(h) + 24 * whole_days)
        return total_tokens, h, m, s

    def get_statistic(self, df) -> DataFrame:
        """Return ``describe()`` of ``score`` for each sentiment tag.

        Expects ``df['tag']`` to hold the upper-case English labels.
        """
        statistic = DataFrame()
        statistic['Score Positivo'] = df[df['tag'] == 'POSITIVE']['score'].describe()
        statistic['Score Negativo'] = df[df['tag'] == 'NEGATIVE']['score'].describe()
        statistic['Score Neutro'] = df[df['tag'] == 'NEUTRAL']['score'].describe()
        return statistic

    def get_performance(self):
        """Return the cached ``(total_tokens, h, m, s)`` tuple."""
        return self.total_tokens, self.h, self.m, self.s

    def three_plot(self, df):
        """Build the 2x2 overview figure (histogram, scatter and box plot).

        Side effect: ``df['tag']`` is rewritten in place with the Portuguese
        labels; this is kept because callers may rely on it.

        Returns:
            The assembled Plotly figure.
        """
        df['tag'] = df['tag'].replace({'POSITIVE': 'Positivo',
                                       'NEGATIVE': 'Negativo',
                                       'NEUTRAL': 'Neutro'})
        df = df.rename(columns={'tag': 'rotulo'})
        fig = make_subplots(rows=2, cols=2,
                            horizontal_spacing=0.0, vertical_spacing=0.05,
                            shared_xaxes=True, shared_yaxes=True,
                            row_heights=[0.4, 0.6], column_widths=[0.8, 0.2])
        fig_scatter = px.scatter(df, x=df.index, y=['score'], color="rotulo")
        fig_histogram = px.histogram(df, x=df.index, color='rotulo', nbins=20)
        fig_box = px.box(df, x='rotulo', y="score", color='rotulo')

        # Recolour the second trace of each figure. Guarded so that data with
        # a single label no longer raises IndexError.
        if len(fig_scatter.data) > 1:
            fig_scatter.data[1]['marker'] = {'color': '#000007'}
        if len(fig_histogram.data) > 1:
            fig_histogram.data[1]['marker'] = {'color': '#000007',
                                               'pattern': {'shape': ''}}
        if len(fig_box.data) > 1:
            fig_box.data[1]['marker'] = {'color': '#000007'}

        # Iterate over the traces that exist instead of a fixed range(3), so a
        # missing sentiment label does not crash the constructor.
        for scatter_trace, histogram_trace, box_trace in zip(
                fig_scatter.data, fig_histogram.data, fig_box.data):
            # Only the scatter traces keep their legend entry.
            histogram_trace['showlegend'] = False
            box_trace['showlegend'] = False
            fig.add_trace(scatter_trace, row=2, col=1)
            fig.add_trace(histogram_trace, row=1, col=1)
            fig.add_trace(box_trace, row=2, col=2)

        # Accented characters in the title/axis label were mojibake
        # (UTF-8 decoded with the wrong codec) and have been repaired.
        fig.update_layout(
            barmode='overlay',
            title=f"Estatísticas: {self.company}\nTrimestre {self.trim} de 2024",
            xaxis3_rangeslider=dict(visible=True, bgcolor="#636EFA", thickness=0.03),
            legend=dict(orientation="h", yanchor="top", y=1.3,
                        xanchor="center", x=0.5),
            scene=dict(yaxis=dict(title='')),
        )
        fig.update_xaxes(showticklabels=False, showgrid=True, row=1, col=1)
        fig.update_xaxes(title_text='#Fala', showgrid=True, row=2, col=1)
        fig.update_yaxes(title_text='Score', row=2, col=1)
        fig.update_yaxes(title_text='Frequência', row=1, col=1)
        fig.update_traces(marker={"opacity": 0.7})
        return fig

    def plot_rank(self, tag=None, max_display=11):
        """Render one SHAP bar plot per label and extract a feature ranking.

        Args:
            tag: Mapping from SHAP output name to the column title used in the
                ranking table. Defaults to the three sentiment labels with
                their Portuguese titles.
            max_display: Forwarded to ``shap.plots.bar``; also determines how
                many trailing tick labels are discarded.

        Returns:
            ``(plot_bar, axis, rank)``: dict of figures keyed by label, dict of
            axes keyed by label, and a DataFrame with one column per title.
        """
        if tag is None:
            # Built per call instead of a shared mutable default argument.
            tag = {'NEUTRAL': 'Neutro', 'POSITIVE': 'Positivo',
                   'NEGATIVE': 'Negativo'}
        plot_bar = dict()
        axis = dict()
        rank = DataFrame()
        for key, val in tag.items():
            plot, ax = plt.subplots()
            shap.plots.bar(self.shap_value[:, :, key], show=False,
                           max_display=max_display)
            plot_bar[key] = plot
            axis[key] = ax
            # Drops the last ``max_display + 1`` tick labels.
            # NOTE(review): presumably this removes duplicated/aggregate
            # labels emitted by shap's bar plot - confirm against the
            # installed shap version.
            rank[val] = DataFrame(ax.get_yticklabels()[:-max_display-1]).astype(str)
            # Release the pyplot-managed figure; the Figure object itself
            # stays referenced in plot_bar.
            plt.close()
        # Strip the "Text(x, y, '...')" wrapper left by astype(str) so that
        # only the label text remains.
        rank[list(tag.values())] = rank[list(tag.values())].replace(
            r"(([T])\w+|(\d+,)|(\d+.\d+,)|(['\(\)\s]))", '', regex=True)
        return plot_bar, axis, rank

    def get_plot_rank(self, max_display=11):
        """Recompute the rankings with ``max_display`` and cache the result."""
        self.plot_bar, self.axis, self.rank = self.plot_rank(max_display=max_display)
        return self.plot_bar, self.axis, self.rank

    def shap_plot_text(self, num, tag):
        """Return the ``shap.plots.text`` output for sample ``num`` and label ``tag``."""
        return shap.plots.text(self.shap_value[num, :, tag], display=False)

    def shap_waterfall(self, num, tag, max_display):
        """Draw a SHAP waterfall plot for sample ``num`` and label ``tag``.

        Returns:
            The matplotlib figure created for the plot (already closed in
            pyplot's figure manager).
        """
        plot_waterfall, _ = plt.subplots()
        shap.plots.waterfall(self.shap_value[num, :, tag], show=False,
                             max_display=max_display)
        plt.close()
        return plot_waterfall