import streamlit as st import pandas as pd import numpy as np from sklearn.datasets import load_iris, load_wine, load_breast_cancer, load_diabetes from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree from sklearn.metrics import accuracy_score, mean_squared_error from sklearn.preprocessing import LabelEncoder import seaborn as sns import matplotlib.pyplot as plt from io import BytesIO # ============================ # CONFIGURATION (zéro jitter + stable) # ============================ st.set_page_config( page_title="Arbre de décision interactif", page_icon="Tree", layout="centered", initial_sidebar_state="expanded" ) # ============================ # CSS léger (Intégration des ajustements de marge/padding) # ============================ st.markdown(""" """, unsafe_allow_html=True) # ============================ # Calcul de la taille de police pour l'arbre (basé sur la profondeur réelle) # ============================ def calculate_fontsize(depth): """ Calcule la taille de police pour l'affichage de l'arbre en fonction de sa profondeur réelle. """ # Mapping profondeur -> taille de police (ajusté pour une meilleure transition) font_map = { 1: 30, 2: 25, 3: 22, 4: 18, 5: 16, 6: 12, 7: 10, } return font_map.get(depth, 8) # ============================ # Datasets intégrés # ============================ @st.cache_data def load_datasets(): return { "Iris": load_iris(as_frame=True).frame, "Wine": load_wine(as_frame=True).frame, "Breast Cancer": load_breast_cancer(as_frame=True).frame, "Diabetes": load_diabetes(as_frame=True).frame, "Penguins": sns.load_dataset("penguins").dropna(), "Tips": sns.load_dataset("tips").dropna(), } DATASETS = load_datasets() # ============================ # Titre + bouton # ============================ col_title, col_button = st.columns([4, 1]) with col_title: st.markdown("

Arbre de décision interactif

", unsafe_allow_html=True) with col_button: # Utilisation de la classe stylisée qui servira de référence pour le positionnement relatif st.markdown("
", unsafe_allow_html=True) if 'train_clicked' not in st.session_state: st.session_state.train_clicked = False st.button("Lancer", width=130, type="primary", on_click=lambda: st.session_state.update(train_clicked=True)) st.markdown("
", unsafe_allow_html=True) # ============================ # SIDEBAR (corrigée pour éviter st.stop() trop tôt) # ============================ with st.sidebar: st.header("Configuration") # Source de données source = st.radio("Source", ["Dataset intégré", "Fichier CSV"]) if source == "Dataset intégré": # L'ajout de la clé force un rafraîchissement de l'état de la session dataset_name = st.selectbox("Dataset", list(DATASETS.keys()), key='dataset_selection') df = DATASETS[dataset_name].copy() st.success(f"{len(df)} lignes chargées") else: uploaded_file = st.file_uploader("Charger un CSV", type=["csv", "txt"]) if uploaded_file is None: st.info("En attente d'un fichier CSV...") st.stop() sep = st.selectbox("Séparateur", [",", ";", "\t", " "], index=0) sep = "\t" if sep == "\t" else sep df = pd.read_csv(uploaded_file, sep=sep, engine="python") st.success(f"{len(df)} lignes chargées") # Déterminer la sélection par défaut de la variable cible all_columns = df.columns.tolist() default_target = 'target' if 'target' in all_columns else all_columns[0] default_target_index = all_columns.index(default_target) # Sélection des variables target = st.selectbox("Variable cible (y)", options=all_columns, index=default_target_index) features = st.multiselect( "Variables explicatives (X)", options=[c for c in df.columns if c != target], default=[c for c in df.columns if c != target] ) if not features: st.warning("Sélectionne au moins une variable explicative") st.stop() # Paramètres du modèle st.markdown("**Paramètres**") max_depth = st.slider("Profondeur max", 1, 30, 2) with st.expander("Avancé", expanded=False): min_samples_split = st.slider("Min samples split", 2, 20, 2) min_samples_leaf = st.slider("Min samples leaf", 1, 20, 1) criterion = st.selectbox("Critère", ["gini", "entropy", "squared_error", "friedman_mse"]) # ============================ # Fonction d'entraînement # ============================ @st.cache_data(show_spinner=False) def train_and_evaluate_model(X, y, max_depth, min_samples_split, min_samples_leaf, criterion): X = X.copy() y = y.copy() # Encodage des variables catégorielles for col in X.select_dtypes(include=['object', 'category']): X[col] = LabelEncoder().fit_transform(X[col].astype(str)) is_classification = y.dtype != 'float64' and y.nunique() <= 20 if is_classification: y = LabelEncoder().fit_transform(y.astype(str)) class_names = None model = DecisionTreeClassifier( max_depth=max_depth or None, min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf, criterion=criterion if criterion in ["gini", "entropy"] else "gini", random_state=42 ) else: class_names = None model = DecisionTreeRegressor( max_depth=max_depth or None, min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf, criterion=criterion if criterion in ["squared_error", "friedman_mse"] else "squared_error", random_state=42 ) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) model.fit(X_train, y_train) y_pred = model.predict(X_test) if is_classification: score = accuracy_score(y_test, y_pred) metric_name = "Accuracy" else: mse = mean_squared_error(y_test, y_pred) score = np.sqrt(mse) metric_name = "RMSE" actual_depth = model.get_depth() return model, score, metric_name, X.columns.tolist(), class_names, actual_depth # ============================ # Entraînement + affichage # ============================ if st.session_state.train_clicked: with st.spinner("Entraînement en cours..."): model, score, metric_name, feature_names, class_names, actual_depth = train_and_evaluate_model( df[features], df[target], max_depth, min_samples_split, min_samples_leaf, criterion ) # Performances compactes st.markdown("

Performances

", unsafe_allow_html=True) col1, col2, col3 = st.columns(3) with col1: st.metric(metric_name, f"{score:.3f}" if metric_name == "RMSE" else f"{score:.1%}") with col2: st.metric("Profondeur", actual_depth) with col3: st.metric("Feuilles", model.get_n_leaves()) # Ligne de séparation st.markdown("
", unsafe_allow_html=True) # Arbre magnifique tree_fontsize = calculate_fontsize(actual_depth) # Titre de l'arbre tiré vers le haut (-2rem pour compenser la marge Streamlit) st.markdown(f"

Arbre de décision — {target}

", unsafe_allow_html=True) fig, ax = plt.subplots(figsize=(24, 16), dpi=140) plot_tree( model, feature_names=feature_names, class_names=class_names, filled=True, rounded=True, fontsize=tree_fontsize, impurity=True, precision=2, proportion=True, ax=ax ) plt.tight_layout() buf = BytesIO() plt.savefig(buf, format="png", bbox_inches="tight", dpi=180, facecolor="white") plt.close(fig) buf.seek(0) st.image(buf, use_container_width=True)