brunaaaz commited on
Commit
6638655
·
verified ·
1 Parent(s): 70f15a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -30
app.py CHANGED
@@ -58,8 +58,9 @@ def processar_dados(df):
58
 
59
  return X, y, df # Retorna df original limpo para visualização
60
 
 
61
  @st.cache_resource
62
- def treinar_modelo(X, y):
63
  # Split
64
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
65
 
@@ -82,6 +83,12 @@ def treinar_modelo(X, y):
82
  model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42, base_score=0.5)
83
  model.fit(X_train_final, y_train_bal)
84
 
 
 
 
 
 
 
85
  return model, scaler, X_test_final, y_test, X_train_final, feature_names
86
 
87
  # --- LOGICA PRINCIPAL ---
@@ -94,7 +101,8 @@ if df_raw is not None:
94
  # Mostra um spinner enquanto carrega para o usuário saber que está trabalhando
95
  with st.spinner('Inicializando sistema: Processando dados e treinando IA...'):
96
  X, y, df_clean = processar_dados(df_raw)
97
- model, scaler, X_test, y_test, X_train, feature_names = treinar_modelo(X, y)
 
98
 
99
  # --- SIDEBAR (Simulador) ---
100
  st.sidebar.header("📂 Menu")
@@ -188,37 +196,44 @@ if df_raw is not None:
188
 
189
  try:
190
  # Calcular SHAP
 
191
  explainer = shap.TreeExplainer(model)
192
  shap_values = explainer.shap_values(X_test)
193
-
194
- st.markdown("**1. Visão Global (Quais variáveis importam mais?)**")
195
- # Correção para exibir o gráfico sem warning: criar figura explícita e passar para st.pyplot
196
- fig_summary, ax = plt.subplots()
197
- shap.summary_plot(shap_values, X_test, show=False)
198
- st.pyplot(plt.gcf())
199
- plt.clf() # Limpar figura atual
200
-
201
- st.markdown("---")
202
- st.markdown("**2. Visão Local (Análise caso a caso)**")
203
-
204
- # Seletor de índice
205
- idx = st.number_input("Selecione o ID do Cliente para auditar:", min_value=0, max_value=len(X_test)-1, value=0)
206
-
207
- real_val = y_test.iloc[idx]
208
- pred_val = y_pred[idx]
209
- st.write(f"Cliente ID {idx} | Real: {'Bad' if real_val==1 else 'Good'} | Predito: {'Bad' if pred_val==1 else 'Good'}")
210
-
211
- # Waterfall Plot
212
- fig_waterfall = plt.figure()
213
- shap.plots.waterfall(shap.Explanation(values=shap_values[idx],
214
- base_values=explainer.expected_value,
215
- data=X_test.iloc[idx],
216
- feature_names=X_test.columns.tolist()),
217
- max_display=10, show=False)
218
- st.pyplot(fig_waterfall)
219
  except Exception as e:
220
- st.error(f"Erro ao calcular SHAP: {e}")
221
- st.warning("Dica: Tente recarregar a página ou verifique compatibilidade de versões.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  # TAB 4: Clusters
224
  with tab4:
 
58
 
59
  return X, y, df # Retorna df original limpo para visualização
60
 
61
+ # Renomeado para v2 para forçar o Streamlit a limpar o cache antigo e aplicar o fix do base_score
62
  @st.cache_resource
63
+ def treinar_modelo_v2(X, y):
64
  # Split
65
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
66
 
 
83
  model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42, base_score=0.5)
84
  model.fit(X_train_final, y_train_bal)
85
 
86
+ # Garantia extra: forçar parametro no booster interno
87
+ try:
88
+ model.get_booster().set_param({'base_score': 0.5})
89
+ except:
90
+ pass
91
+
92
  return model, scaler, X_test_final, y_test, X_train_final, feature_names
93
 
94
  # --- LOGICA PRINCIPAL ---
 
101
  # Mostra um spinner enquanto carrega para o usuário saber que está trabalhando
102
  with st.spinner('Inicializando sistema: Processando dados e treinando IA...'):
103
  X, y, df_clean = processar_dados(df_raw)
104
+ # Chamando a função v2 para garantir que o fix seja usado
105
+ model, scaler, X_test, y_test, X_train, feature_names = treinar_modelo_v2(X, y)
106
 
107
  # --- SIDEBAR (Simulador) ---
108
  st.sidebar.header("📂 Menu")
 
196
 
197
  try:
198
  # Calcular SHAP
199
+ # TENTATIVA 1: Explainer Padrão
200
  explainer = shap.TreeExplainer(model)
201
  shap_values = explainer.shap_values(X_test)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  except Exception as e:
203
+ # TENTATIVA 2: Fallback para erro de versão XGBoost/SHAP
204
+ if "could not convert string to float" in str(e):
205
+ st.warning("🔄 Ajustando compatibilidade do SHAP... (Isso é normal em versões novas)")
206
+ # Usa o booster interno diretamente, ignorando o wrapper sklearn que causa o erro
207
+ explainer = shap.TreeExplainer(model.get_booster())
208
+ shap_values = explainer.shap_values(X_test)
209
+ else:
210
+ raise e
211
+
212
+ st.markdown("**1. Visão Global (Quais variáveis importam mais?)**")
213
+ # Correção para exibir o gráfico sem warning: criar figura explícita e passar para st.pyplot
214
+ fig_summary, ax = plt.subplots()
215
+ shap.summary_plot(shap_values, X_test, show=False)
216
+ st.pyplot(plt.gcf())
217
+ plt.clf() # Limpar figura atual
218
+
219
+ st.markdown("---")
220
+ st.markdown("**2. Visão Local (Análise caso a caso)**")
221
+
222
+ # Seletor de índice
223
+ idx = st.number_input("Selecione o ID do Cliente para auditar:", min_value=0, max_value=len(X_test)-1, value=0)
224
+
225
+ real_val = y_test.iloc[idx]
226
+ pred_val = y_pred[idx]
227
+ st.write(f"Cliente ID {idx} | Real: {'Bad' if real_val==1 else 'Good'} | Predito: {'Bad' if pred_val==1 else 'Good'}")
228
+
229
+ # Waterfall Plot
230
+ fig_waterfall = plt.figure()
231
+ shap.plots.waterfall(shap.Explanation(values=shap_values[idx],
232
+ base_values=explainer.expected_value,
233
+ data=X_test.iloc[idx],
234
+ feature_names=X_test.columns.tolist()),
235
+ max_display=10, show=False)
236
+ st.pyplot(fig_waterfall)
237
 
238
  # TAB 4: Clusters
239
  with tab4: