import streamlit as st import numpy as np import matplotlib.pyplot as plt st.title("Q-Learning - Grille et Convergence") # --- PARAMÈTRES DE SIMULATION --- st.sidebar.header("1. Environnement") n_rows = st.sidebar.number_input("Nombre de lignes", min_value=2, max_value=10, value=3) n_cols = st.sidebar.number_input("Nombre de colonnes", min_value=2, max_value=10, value=3) start_row = st.sidebar.number_input("Ligne de l'état initial (1-indexé)", min_value=1, max_value=n_rows, value=1) - 1 start_col = st.sidebar.number_input("Colonne de l'état initial (1-indexé)", min_value=1, max_value=n_cols, value=1) - 1 goal_row = st.sidebar.number_input("Ligne de l'état final (1-indexé)", min_value=1, max_value=n_rows, value=n_rows) - 1 goal_col = st.sidebar.number_input("Colonne de l'état final (1-indexé)", min_value=1, max_value=n_cols, value=n_cols) - 1 reward_else = st.sidebar.number_input("Récompense par défaut (cases normales)", value=-0.01) reward_goal = st.sidebar.number_input("Récompense de l'état final", value=1.00) st.sidebar.header("2. Paramètres Q-Learning") alpha = st.sidebar.slider("Alpha (taux d'apprentissage)", 0.0, 1.0, 0.5) gamma = st.sidebar.slider("Gamma (facteur de réduction)", 0.0, 1.0, 0.9) epsilon = st.sidebar.slider("Epsilon (exploration)", 0.0, 1.0, 0.1) st.sidebar.header("3. Critères de convergence") nb_max_episode = st.sidebar.number_input("Nombre max d'épisodes", min_value=1, value=1000) max_steps_per_episode = st.sidebar.number_input("Nombre max d'étapes par épisode", min_value=1, value=15) seuil_convergence = st.sidebar.number_input("Seuil de convergence", value=0.01) nb_ep_stables = st.sidebar.number_input("Épisodes consécutifs sans changement", min_value=1, value=2) start_button = st.button("Lancer l'apprentissage") if start_button: nb_actions = 4 # Haut, bas, gauche, droite nb_states = n_rows * n_cols Q_table = np.zeros((nb_actions, nb_states)) def state_index(row, col): return row * n_cols + col def get_reward(state): return reward_goal if state == goal_state else reward_else def next_state(state, action): row, col = divmod(state, n_cols) if action == 0 and row > 0: # Haut row -= 1 elif action == 1 and row < n_rows - 1: # Bas row += 1 elif action == 2 and col > 0: # Gauche col -= 1 elif action == 3 and col < n_cols - 1: # Droite col += 1 return state_index(row, col) start_state = state_index(start_row, start_col) goal_state = state_index(goal_row, goal_col) Q_old = Q_table.copy() stable_counter = 0 q_table_placeholder = st.empty() for episode in range(nb_max_episode): state = start_state done = False steps = 0 while not done and steps < max_steps_per_episode: steps += 1 if np.random.uniform(0, 1) < epsilon: action = np.random.choice(nb_actions) else: action = np.argmax(Q_table[:, state]) new_state = next_state(state, action) reward = get_reward(new_state) Q_table[action, state] = (1 - alpha) * Q_table[action, state] + alpha * ( reward + gamma * np.max(Q_table[:, new_state]) ) state = new_state if state == goal_state: done = True delta = np.abs(Q_table - Q_old).max() if delta < seuil_convergence: stable_counter += 1 else: stable_counter = 0 # Affichage de la Q-table mise à jour q_table_placeholder.subheader(f"Q-Table après l’épisode {episode + 1}") q_table_placeholder.dataframe(Q_table.round(2)) if stable_counter >= nb_ep_stables: st.success(f"Convergence atteinte à l’épisode {episode + 1}") break Q_old = Q_table.copy() # ----- Affichage de la politique optimale ----- policy = np.argmax(Q_table, axis=0).reshape(n_rows, n_cols) st.subheader("Politique optimale sous forme de flèches") fig, ax = plt.subplots(figsize=(n_cols, n_rows)) ax.set_xlim(0, n_cols) ax.set_ylim(0, n_rows) ax.set_xticks(np.arange(n_cols)) ax.set_yticks(np.arange(n_rows)) ax.set_xticklabels([]) ax.set_yticklabels([]) ax.grid(True) for i in range(n_rows): for j in range(n_cols): action = policy[i, j] dx, dy = 0, 0 if action == 0: # haut dx, dy = 0, 0.4 elif action == 1: # bas dx, dy = 0, -0.4 elif action == 2: # gauche dx, dy = -0.4, 0 elif action == 3: # droite dx, dy = 0.4, 0 ax.arrow(j + 0.5, n_rows - i - 0.5, dx, dy, head_width=0.2, head_length=0.1, fc='blue', ec='blue') # Objectif en vert ax.add_patch(plt.Rectangle((goal_col, n_rows - goal_row - 1), 1, 1, fill=True, color='green', alpha=0.3)) ax.text(goal_col + 0.5, n_rows - goal_row - 0.5, 'GOAL', ha='center', va='center', fontsize=12, fontweight='bold') st.pyplot(fig) # ----- Valeur max Q par état ----- st.subheader("Valeur maximale de Q par état") fig2, ax2 = plt.subplots() ax2.plot(np.max(Q_table, axis=0), marker='o') ax2.set_xlabel("État") ax2.set_ylabel("Valeur max Q") ax2.grid(True) st.pyplot(fig2)