# Source: Hugging Face Space (status shown at capture time: "Paused").
import html
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
from sklearn.datasets import load_iris, load_wine, load_breast_cancer, load_diabetes
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
# ============================
# Page configuration
# ============================
st.set_page_config(
    page_title="Arbre de décision interactif",
    # page_icon must be an emoji, an emoji shortcode, a Material icon, an image
    # or an image URL; the plain word "Tree" is none of these.
    page_icon="🌳",
    layout="centered",
    initial_sidebar_state="expanded",
)
# ============================
# Light custom CSS (margin/padding tweaks)
# ============================
# NOTE(review): the `.st-emotion-cache-*` selectors below target class names
# generated by Streamlit's styling layer; they are not a public API and are
# likely to change between Streamlit releases, so re-check them after upgrading.
# The former `.aligned-button-container` rules were removed: each st.markdown
# call renders as its own element, so a <div> opened in one call never wraps a
# widget created afterwards and those selectors could not match anything.
st.markdown("""
<style>
/* Sidebar */
[data-testid="stSidebar"] {background: linear-gradient(180deg, #d1fae5 0%, #a7f3d0 100%);}
/* Titre Principal H1 (personnalisé) */
.big-title {font-size: 2.2rem !important; color: #065f46; font-weight: bold;}
/* Titre Arbre H3 (personnalisé) */
.tree-title {font-size: 1.6rem !important; color: #047857; margin: 0.5rem 0 0.5rem 0 !important;}
/* 🗑️ Suppression de l'espace H1 (padding 1.25rem 0 1rem) */
.st-emotion-cache-3uj0rx h1 {
padding: 0.8rem !important;
}
/* ⚙️ Réduction du GAP dans certains conteneurs (était 1rem) */
.st-emotion-cache-tn0cau {
gap: 0.4rem !important;
}
/* ⚙️ Réduction de la hauteur et de la marge du conteneur de message (ex: st.success) */
.st-emotion-cache-10p9htt {
height: 1rem !important;
margin-bottom: 0.5rem !important;
}
</style>
""", unsafe_allow_html=True)
# ============================
# Font size for the tree plot (based on the fitted tree's real depth)
# ============================
def calculate_fontsize(depth):
    """Return the font size to pass to ``plot_tree`` for a tree of this depth.

    Deeper trees draw more boxes, so the font shrinks as depth grows:
    depths 0 and 1 use the largest size (30), depths above 7 use the
    smallest (8).

    Args:
        depth: depth of the fitted tree (as returned by ``get_depth()``);
            0 means the tree is a single leaf.

    Returns:
        The font size value for ``plot_tree(fontsize=...)``.
    """
    font_map = {
        1: 30,
        2: 25,
        3: 22,
        4: 18,
        5: 16,
        6: 12,
        7: 10,
    }
    # A depth-0 tree is a single box: give it the largest font instead of
    # falling through to the smallest default.
    if depth <= 1:
        return font_map[1]
    return font_map.get(depth, 8)
# ============================
# Bundled datasets
# ============================
@st.cache_data
def load_datasets():
    """Return the bundled example datasets as a ``{name: DataFrame}`` dict.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching all six loaders would run on every rerun. ``st.cache_data``
    computes the dict once and hands each caller a copy of the cached value.
    Per the seaborn docs, ``sns.load_dataset`` fetches its data from an online
    repository (seaborn keeps a local cache), so the first call may need
    network access. Rows with missing values are dropped from the seaborn
    datasets.
    """
    return {
        "Iris": load_iris(as_frame=True).frame,
        "Wine": load_wine(as_frame=True).frame,
        "Breast Cancer": load_breast_cancer(as_frame=True).frame,
        "Diabetes": load_diabetes(as_frame=True).frame,
        "Penguins": sns.load_dataset("penguins").dropna(),
        "Tips": sns.load_dataset("tips").dropna(),
    }


DATASETS = load_datasets()
# ============================
# Title + train button
# ============================
# Flag read by the training section. Nothing resets it, so once the button has
# been clicked every later rerun (e.g. a slider change) retrains the model.
if 'train_clicked' not in st.session_state:
    st.session_state.train_clicked = False

# vertical_alignment lines the button up with the title natively. Wrapping the
# button between st.markdown("<div ...>") and st.markdown("</div>") does not
# work: each st.markdown is rendered as its own element, so the button never
# ends up inside that div and CSS aimed at it never applies.
col_title, col_button = st.columns([4, 1], vertical_alignment="center")
with col_title:
    st.markdown("<h1 class='big-title'>Arbre de décision interactif</h1>",
                unsafe_allow_html=True)
with col_button:
    st.button("Lancer", width=130, type="primary",
              on_click=lambda: st.session_state.update(train_clicked=True))
# ============================
# SIDEBAR: data source, variables and model parameters.
# Every st.stop() below ends the whole script run (not only the sidebar)
# until the user supplies usable input.
# ============================
with st.sidebar:
    st.header("Configuration")

    # Data source: one of the bundled datasets or a user-uploaded CSV.
    source = st.radio("Source", ["Dataset intégré", "Fichier CSV"])
    if source == "Dataset intégré":
        # Explicit key: the selection is also stored in st.session_state.
        dataset_name = st.selectbox("Dataset", list(DATASETS.keys()), key='dataset_selection')
        # Work on a copy so nothing downstream can alter the DATASETS entry.
        df = DATASETS[dataset_name].copy()
    else:
        uploaded_file = st.file_uploader("Charger un CSV", type=["csv", "txt"])
        if uploaded_file is None:
            st.info("En attente d'un fichier CSV...")
            st.stop()
        # Tab and space are invisible as raw option labels, so display names.
        sep = st.selectbox(
            "Séparateur", [",", ";", "\t", " "], index=0,
            format_func=lambda s: {"\t": "Tabulation", " ": "Espace"}.get(s, s),
        )
        try:
            df = pd.read_csv(uploaded_file, sep=sep, engine="python")
        except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as exc:
            # Malformed, empty or non-UTF-8 files: report instead of crashing.
            st.error(f"Impossible de lire le fichier : {exc}")
            st.stop()

    # A header-only CSV parses fine but leaves nothing to train on.
    if df.empty:
        st.warning("Le fichier ne contient aucune donnée")
        st.stop()
    st.success(f"{len(df)} lignes chargées")

    # Default target: a column literally named 'target' (as in the sklearn
    # datasets), otherwise the first column.
    all_columns = df.columns.tolist()
    default_target = 'target' if 'target' in all_columns else all_columns[0]
    default_target_index = all_columns.index(default_target)

    # Variable selection: every column other than the target is a candidate
    # explanatory variable, all selected by default.
    target = st.selectbox("Variable cible (y)", options=all_columns, index=default_target_index)
    candidate_features = [c for c in all_columns if c != target]
    features = st.multiselect(
        "Variables explicatives (X)",
        options=candidate_features,
        default=candidate_features,
    )
    if not features:
        st.warning("Sélectionne au moins une variable explicative")
        st.stop()

    # Model hyper-parameters
    st.markdown("**Paramètres**")
    max_depth = st.slider("Profondeur max", 1, 30, 2)
    with st.expander("Avancé", expanded=False):
        min_samples_split = st.slider("Min samples split", 2, 20, 2)
        min_samples_leaf = st.slider("Min samples leaf", 1, 20, 1)
        # All four criteria are offered whatever the task; the training
        # function falls back to a valid one when the choice does not match.
        criterion = st.selectbox("Critère", ["gini", "entropy", "squared_error", "friedman_mse"])
# ============================
# Training function
# ============================
def train_and_evaluate_model(X, y, max_depth, min_samples_split, min_samples_leaf, criterion):
    """Fit a decision tree on an 80/20 split and score it on the held-out 20%.

    The task is inferred from ``y``: non-numeric targets, and non-float
    targets with at most 20 distinct values, are classified; every other
    target is regressed.

    Args:
        X: DataFrame of explanatory variables. Non-numeric columns are
            label-encoded (via their string form) before fitting.
        y: Series holding the target. Rows where it is missing are dropped.
        max_depth: maximum tree depth; a falsy value means unlimited.
        min_samples_split: forwarded to the sklearn estimator.
        min_samples_leaf: forwarded to the sklearn estimator.
        criterion: split criterion. A value that does not fit the inferred
            task silently falls back to "gini" (classification) or
            "squared_error" (regression).

    Returns:
        Tuple ``(model, score, metric_name, feature_names, class_names,
        actual_depth)``. ``score`` is the test accuracy ("Accuracy") or the
        test RMSE ("RMSE"). ``class_names`` lists the original label of each
        class in ``model.classes_`` order (None for regression).
        ``actual_depth`` is the fitted tree's depth.

    Raises:
        ValueError: propagated from scikit-learn when the data cannot be
            split or fitted (e.g. too few rows left).
    """
    # Rows without a target can be neither learned from nor scored, and a NaN
    # target makes DecisionTreeRegressor.fit fail.
    has_target = y.notna().to_numpy()
    X = X.loc[has_target].copy()
    y = y.loc[has_target].copy()

    # Trees need numeric inputs: label-encode every non-numeric feature
    # (object, category, string, datetime, ...).
    for col in X.columns:
        if not pd.api.types.is_numeric_dtype(X[col]):
            X[col] = LabelEncoder().fit_transform(X[col].astype(str))

    # Text/categorical targets can only be classified (a regressor cannot fit
    # them whatever their cardinality); integer-like targets with few distinct
    # values are treated as classes too.
    is_classification = (
        not pd.api.types.is_numeric_dtype(y)
        or (not pd.api.types.is_float_dtype(y) and y.nunique() <= 20)
    )

    if is_classification:
        label_encoder = LabelEncoder()
        y = label_encoder.fit_transform(y.astype(str))
        model = DecisionTreeClassifier(
            max_depth=max_depth or None,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            criterion=criterion if criterion in ["gini", "entropy"] else "gini",
            random_state=42,
        )
    else:
        model = DecisionTreeRegressor(
            max_depth=max_depth or None,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            criterion=criterion if criterion in ["squared_error", "friedman_mse"] else "squared_error",
            random_state=42,
        )

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    if is_classification:
        score = accuracy_score(y_test, y_pred)
        metric_name = "Accuracy"
        # plot_tree looks names up by position in model.classes_, which only
        # holds the encoded labels present in the training split; map each
        # back to its original label.
        class_names = [str(label_encoder.classes_[code]) for code in model.classes_]
    else:
        score = np.sqrt(mean_squared_error(y_test, y_pred))
        metric_name = "RMSE"
        class_names = None

    actual_depth = model.get_depth()
    return model, score, metric_name, X.columns.tolist(), class_names, actual_depth
# ============================
# Training + display (runs on every rerun once the button has been clicked)
# ============================
if st.session_state.train_clicked:
    with st.spinner("Entraînement en cours..."):
        try:
            model, score, metric_name, feature_names, class_names, actual_depth = train_and_evaluate_model(
                df[features], df[target], max_depth, min_samples_split, min_samples_leaf, criterion
            )
        except ValueError as exc:
            # scikit-learn signals unusable input data with ValueError; show
            # it in the page instead of a raw traceback.
            st.error(f"Entraînement impossible : {exc}")
            st.stop()

    # Compact performance summary
    st.markdown("<h4 style='margin-top: -0.5rem;'>Performances</h4>", unsafe_allow_html=True)
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric(metric_name, f"{score:.3f}" if metric_name == "RMSE" else f"{score:.1%}")
    with col2:
        st.metric("Profondeur", actual_depth)
    with col3:
        st.metric("Feuilles", model.get_n_leaves())

    # Separator line
    st.markdown("<hr style='margin-top: 0.5rem; margin-bottom: 0.5rem;'>", unsafe_allow_html=True)

    # Tree plot
    tree_fontsize = calculate_fontsize(actual_depth)
    # The target name may come from an uploaded CSV header: escape it before
    # injecting it into raw HTML. The negative top margin pulls the title up
    # to compensate for Streamlit's own spacing.
    safe_target = html.escape(str(target))
    st.markdown(f"<h3 class='tree-title' style='margin-top:-2rem;'>Arbre de décision — {safe_target}</h3>", unsafe_allow_html=True)

    fig, ax = plt.subplots(figsize=(24, 16), dpi=140)
    try:
        plot_tree(
            model,
            feature_names=feature_names,
            class_names=class_names,
            filled=True,
            rounded=True,
            fontsize=tree_fontsize,
            impurity=True,
            precision=2,
            proportion=True,
            ax=ax,
        )
        # Operate on this Figure explicitly: pyplot's "current figure" is
        # global state shared by every session running in the server process.
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=180, facecolor="white")
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)
    buf.seek(0)
    st.image(buf, use_container_width=True)