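# Streamlit app: tabular Q-learning on a small grid world, with a live view of
# the Q-table, a convergence test, and plots of the learned policy.
# A minimal way to run it locally (assuming this file is saved as app.py):
#     streamlit run app.py
# Dependencies: streamlit, numpy, matplotlib.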
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt

st.title("Q-Learning - Grid and Convergence")

# --- SIMULATION PARAMETERS ---
st.sidebar.header("1. Environment")
n_rows = st.sidebar.number_input("Number of rows", min_value=2, max_value=10, value=3)
n_cols = st.sidebar.number_input("Number of columns", min_value=2, max_value=10, value=3)
start_row = st.sidebar.number_input("Start state row (1-indexed)", min_value=1, max_value=n_rows, value=1) - 1
start_col = st.sidebar.number_input("Start state column (1-indexed)", min_value=1, max_value=n_cols, value=1) - 1
goal_row = st.sidebar.number_input("Goal state row (1-indexed)", min_value=1, max_value=n_rows, value=n_rows) - 1
goal_col = st.sidebar.number_input("Goal state column (1-indexed)", min_value=1, max_value=n_cols, value=n_cols) - 1
reward_else = st.sidebar.number_input("Default reward (normal cells)", value=-0.01)
reward_goal = st.sidebar.number_input("Goal state reward", value=1.00)

st.sidebar.header("2. Q-Learning parameters")
alpha = st.sidebar.slider("Alpha (learning rate)", 0.0, 1.0, 0.5)
gamma = st.sidebar.slider("Gamma (discount factor)", 0.0, 1.0, 0.9)
epsilon = st.sidebar.slider("Epsilon (exploration rate)", 0.0, 1.0, 0.1)

st.sidebar.header("3. Convergence criteria")
nb_max_episode = st.sidebar.number_input("Max number of episodes", min_value=1, value=1000)
max_steps_per_episode = st.sidebar.number_input("Max steps per episode", min_value=1, value=15)
seuil_convergence = st.sidebar.number_input("Convergence threshold", value=0.01)
nb_ep_stables = st.sidebar.number_input("Consecutive episodes without change", min_value=1, value=2)

start_button = st.button("Start learning")
if start_button:
    nb_actions = 4  # Up, down, left, right
    nb_states = n_rows * n_cols
    # Q-table indexed as [action, state]
    Q_table = np.zeros((nb_actions, nb_states))

    def state_index(row, col):
        # Row-major flattening: cell (row, col) -> single state index
        return row * n_cols + col

    def get_reward(state):
        return reward_goal if state == goal_state else reward_else
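    # Note: get_reward refers to goal_state, which is only assigned further
    # down; this works because Python looks the global up at call time, and
    # the first call happens inside the training loop, after the assignment.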
    def next_state(state, action):
        row, col = divmod(state, n_cols)
        if action == 0 and row > 0:  # Up
            row -= 1
        elif action == 1 and row < n_rows - 1:  # Down
            row += 1
        elif action == 2 and col > 0:  # Left
            col -= 1
        elif action == 3 and col < n_cols - 1:  # Right
            col += 1
        return state_index(row, col)

    start_state = state_index(start_row, start_col)
    goal_state = state_index(goal_row, goal_col)

    Q_old = Q_table.copy()
    stable_counter = 0
    q_table_placeholder = st.empty()
    for episode in range(nb_max_episode):
        state = start_state
        done = False
        steps = 0
        while not done and steps < max_steps_per_episode:
            steps += 1
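            # Epsilon-greedy action selection: with probability epsilon pick a
            # random action (exploration); otherwise take the action with the
            # highest current Q-value (exploitation).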
            if np.random.uniform(0, 1) < epsilon:
                action = np.random.choice(nb_actions)
            else:
                action = np.argmax(Q_table[:, state])
            new_state = next_state(state, action)
            reward = get_reward(new_state)
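            # Q-learning update:
            # Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a'))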
            Q_table[action, state] = (1 - alpha) * Q_table[action, state] + alpha * (
                reward + gamma * np.max(Q_table[:, new_state])
            )
            state = new_state
            if state == goal_state:
                done = True
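
        # Convergence test: largest absolute change in any Q-value since the
        # end of the previous episode; count consecutive "stable" episodes.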
        delta = np.abs(Q_table - Q_old).max()
        if delta < seuil_convergence:
            stable_counter += 1
        else:
            stable_counter = 0

        # Display the updated Q-table (use a container so the subheader and
        # the table are shown together; successive calls on the same
        # st.empty() placeholder would overwrite each other)
        with q_table_placeholder.container():
            st.subheader(f"Q-Table after episode {episode + 1}")
            st.dataframe(Q_table.round(2))

        if stable_counter >= nb_ep_stables:
            st.success(f"Convergence reached at episode {episode + 1}")
            break

        Q_old = Q_table.copy()

    # ----- Display of the optimal policy -----
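    # Greedy policy: for each state, the action with the highest Q-value,
    # reshaped back onto the grid (np.argmax breaks ties toward action 0,
    # i.e. "up", so states never updated default to an upward arrow).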
    policy = np.argmax(Q_table, axis=0).reshape(n_rows, n_cols)
    st.subheader("Optimal policy as arrows")
    fig, ax = plt.subplots(figsize=(n_cols, n_rows))
    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(True)
    for i in range(n_rows):
        for j in range(n_cols):
            action = policy[i, j]
            dx, dy = 0, 0
            if action == 0:    # up
                dx, dy = 0, 0.4
            elif action == 1:  # down
                dx, dy = 0, -0.4
            elif action == 2:  # left
                dx, dy = -0.4, 0
            elif action == 3:  # right
                dx, dy = 0.4, 0
            # Grid row 0 is drawn at the top, so the y-coordinate is flipped
            ax.arrow(j + 0.5, n_rows - i - 0.5, dx, dy, head_width=0.2, head_length=0.1, fc='blue', ec='blue')
    # Goal cell in green
    ax.add_patch(plt.Rectangle((goal_col, n_rows - goal_row - 1), 1, 1, fill=True, color='green', alpha=0.3))
    ax.text(goal_col + 0.5, n_rows - goal_row - 0.5, 'GOAL', ha='center', va='center', fontsize=12, fontweight='bold')
    st.pyplot(fig)
    # ----- Max Q-value per state -----
    st.subheader("Maximum Q-value per state")
    fig2, ax2 = plt.subplots()
    ax2.plot(np.max(Q_table, axis=0), marker='o')
    ax2.set_xlabel("State (flattened, row-major index)")
    ax2.set_ylabel("Max Q-value")
    ax2.grid(True)
    st.pyplot(fig2)