Decision_tree / app.py
Eric2mangel's picture
Upload app.py
1d476df verified
import streamlit as st
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris, load_wine, load_breast_cancer, load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.preprocessing import LabelEncoder
import seaborn as sns
import matplotlib.pyplot as plt
from io import BytesIO
# ============================
# CONFIGURATION (zéro jitter + stable)
# ============================
st.set_page_config(
page_title="Arbre de décision interactif",
page_icon="Tree",
layout="centered",
initial_sidebar_state="expanded"
)
# ============================
# CSS léger (Intégration des ajustements de marge/padding)
# ============================
st.markdown("""
<style>
/* Sidebar */
[data-testid="stSidebar"] {background: linear-gradient(180deg, #d1fae5 0%, #a7f3d0 100%);}
/* Titre Principal H1 (personnalisé) */
.big-title {font-size: 2.2rem !important; color: #065f46; font-weight: bold;}
/* Titre Arbre H3 (personnalisé) */
.tree-title {font-size: 1.6rem !important; color: #047857; margin: 0.5rem 0 0.5rem 0 !important;}
/* 🗑️ Suppression de l'espace H1 (padding 1.25rem 0 1rem) */
.st-emotion-cache-3uj0rx h1 {
padding: 0.8rem !important;
}
/* ⚙️ Réduction du GAP dans certains conteneurs (était 1rem) */
.st-emotion-cache-tn0cau {
gap: 0.4rem !important;
}
/* ⚙️ Réduction de la hauteur et de la marge du conteneur de message (ex: st.success) */
.st-emotion-cache-10p9htt {
height: 1rem !important;
margin-bottom: 0.5rem !important;
}
/* 🎯 ALIGNEMENT BOUTON : Conteneur du bouton, devient la référence */
.aligned-button-container {
position: relative;
}
/* 🚀 DÉPLACEMENT DU BOUTON : Cible la classe exacte du bouton et le remonte */
/* Utilisez la classe que vous avez trouvée : st-emotion-cache-1c14umz */
.aligned-button-container .st-emotion-cache-1c14umz {
top: -2.5rem !important; /* Déplace l'élément de 2.5rem vers le haut */
position: relative !important;
}
</style>
""", unsafe_allow_html=True)
# ============================
# Calcul de la taille de police pour l'arbre (basé sur la profondeur réelle)
# ============================
def calculate_fontsize(depth):
"""
Calcule la taille de police pour l'affichage de l'arbre en fonction de sa profondeur réelle.
"""
# Mapping profondeur -> taille de police (ajusté pour une meilleure transition)
font_map = {
1: 30,
2: 25,
3: 22,
4: 18,
5: 16,
6: 12,
7: 10,
}
return font_map.get(depth, 8)
# ============================
# Datasets intégrés
# ============================
@st.cache_data
def load_datasets():
return {
"Iris": load_iris(as_frame=True).frame,
"Wine": load_wine(as_frame=True).frame,
"Breast Cancer": load_breast_cancer(as_frame=True).frame,
"Diabetes": load_diabetes(as_frame=True).frame,
"Penguins": sns.load_dataset("penguins").dropna(),
"Tips": sns.load_dataset("tips").dropna(),
}
DATASETS = load_datasets()
# ============================
# Titre + bouton
# ============================
col_title, col_button = st.columns([4, 1])
with col_title:
st.markdown("<h1 class='big-title'>Arbre de décision interactif</h1>",
unsafe_allow_html=True)
with col_button:
# Utilisation de la classe stylisée qui servira de référence pour le positionnement relatif
st.markdown("<div class='aligned-button-container'>", unsafe_allow_html=True)
if 'train_clicked' not in st.session_state:
st.session_state.train_clicked = False
st.button("Lancer", width=130, type="primary",
on_click=lambda: st.session_state.update(train_clicked=True))
st.markdown("</div>", unsafe_allow_html=True)
# ============================
# SIDEBAR (corrigée pour éviter st.stop() trop tôt)
# ============================
with st.sidebar:
st.header("Configuration")
# Source de données
source = st.radio("Source", ["Dataset intégré", "Fichier CSV"])
if source == "Dataset intégré":
# L'ajout de la clé force un rafraîchissement de l'état de la session
dataset_name = st.selectbox("Dataset", list(DATASETS.keys()), key='dataset_selection')
df = DATASETS[dataset_name].copy()
st.success(f"{len(df)} lignes chargées")
else:
uploaded_file = st.file_uploader("Charger un CSV", type=["csv", "txt"])
if uploaded_file is None:
st.info("En attente d'un fichier CSV...")
st.stop()
sep = st.selectbox("Séparateur", [",", ";", "\t", " "], index=0)
sep = "\t" if sep == "\t" else sep
df = pd.read_csv(uploaded_file, sep=sep, engine="python")
st.success(f"{len(df)} lignes chargées")
# Déterminer la sélection par défaut de la variable cible
all_columns = df.columns.tolist()
default_target = 'target' if 'target' in all_columns else all_columns[0]
default_target_index = all_columns.index(default_target)
# Sélection des variables
target = st.selectbox("Variable cible (y)", options=all_columns, index=default_target_index)
features = st.multiselect(
"Variables explicatives (X)",
options=[c for c in df.columns if c != target],
default=[c for c in df.columns if c != target]
)
if not features:
st.warning("Sélectionne au moins une variable explicative")
st.stop()
# Paramètres du modèle
st.markdown("**Paramètres**")
max_depth = st.slider("Profondeur max", 1, 30, 2)
with st.expander("Avancé", expanded=False):
min_samples_split = st.slider("Min samples split", 2, 20, 2)
min_samples_leaf = st.slider("Min samples leaf", 1, 20, 1)
criterion = st.selectbox("Critère", ["gini", "entropy", "squared_error", "friedman_mse"])
# ============================
# Fonction d'entraînement
# ============================
@st.cache_data(show_spinner=False)
def train_and_evaluate_model(X, y, max_depth, min_samples_split, min_samples_leaf, criterion):
X = X.copy()
y = y.copy()
# Encodage des variables catégorielles
for col in X.select_dtypes(include=['object', 'category']):
X[col] = LabelEncoder().fit_transform(X[col].astype(str))
is_classification = y.dtype != 'float64' and y.nunique() <= 20
if is_classification:
y = LabelEncoder().fit_transform(y.astype(str))
class_names = None
model = DecisionTreeClassifier(
max_depth=max_depth or None,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
criterion=criterion if criterion in ["gini", "entropy"] else "gini",
random_state=42
)
else:
class_names = None
model = DecisionTreeRegressor(
max_depth=max_depth or None,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
criterion=criterion if criterion in ["squared_error", "friedman_mse"] else "squared_error",
random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
if is_classification:
score = accuracy_score(y_test, y_pred)
metric_name = "Accuracy"
else:
mse = mean_squared_error(y_test, y_pred)
score = np.sqrt(mse)
metric_name = "RMSE"
actual_depth = model.get_depth()
return model, score, metric_name, X.columns.tolist(), class_names, actual_depth
# ============================
# Entraînement + affichage
# ============================
if st.session_state.train_clicked:
with st.spinner("Entraînement en cours..."):
model, score, metric_name, feature_names, class_names, actual_depth = train_and_evaluate_model(
df[features], df[target], max_depth, min_samples_split, min_samples_leaf, criterion
)
# Performances compactes
st.markdown("<h4 style='margin-top: -0.5rem;'>Performances</h4>", unsafe_allow_html=True)
col1, col2, col3 = st.columns(3)
with col1:
st.metric(metric_name, f"{score:.3f}" if metric_name == "RMSE" else f"{score:.1%}")
with col2:
st.metric("Profondeur", actual_depth)
with col3:
st.metric("Feuilles", model.get_n_leaves())
# Ligne de séparation
st.markdown("<hr style='margin-top: 0.5rem; margin-bottom: 0.5rem;'>", unsafe_allow_html=True)
# Arbre magnifique
tree_fontsize = calculate_fontsize(actual_depth)
# Titre de l'arbre tiré vers le haut (-2rem pour compenser la marge Streamlit)
st.markdown(f"<h3 class='tree-title' style='margin-top:-2rem;'>Arbre de décision — {target}</h3>", unsafe_allow_html=True)
fig, ax = plt.subplots(figsize=(24, 16), dpi=140)
plot_tree(
model,
feature_names=feature_names,
class_names=class_names,
filled=True,
rounded=True,
fontsize=tree_fontsize,
impurity=True,
precision=2,
proportion=True,
ax=ax
)
plt.tight_layout()
buf = BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight", dpi=180, facecolor="white")
plt.close(fig)
buf.seek(0)
st.image(buf, use_container_width=True)