# Q_Learning/app.py
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
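# Run locally with: streamlit run app.py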
st.title("Q-Learning - Grille et Convergence")
# --- PARAMÈTRES DE SIMULATION ---
st.sidebar.header("1. Environnement")
n_rows = st.sidebar.number_input("Nombre de lignes", min_value=2, max_value=10, value=3)
n_cols = st.sidebar.number_input("Nombre de colonnes", min_value=2, max_value=10, value=3)
start_row = st.sidebar.number_input("Ligne de l'état initial (1-indexé)", min_value=1, max_value=n_rows, value=1) - 1
start_col = st.sidebar.number_input("Colonne de l'état initial (1-indexé)", min_value=1, max_value=n_cols, value=1) - 1
goal_row = st.sidebar.number_input("Ligne de l'état final (1-indexé)", min_value=1, max_value=n_rows, value=n_rows) - 1
goal_col = st.sidebar.number_input("Colonne de l'état final (1-indexé)", min_value=1, max_value=n_cols, value=n_cols) - 1
reward_else = st.sidebar.number_input("Récompense par défaut (cases normales)", value=-0.01)
reward_goal = st.sidebar.number_input("Récompense de l'état final", value=1.00)
st.sidebar.header("2. Paramètres Q-Learning")
alpha = st.sidebar.slider("Alpha (taux d'apprentissage)", 0.0, 1.0, 0.5)
gamma = st.sidebar.slider("Gamma (facteur de réduction)", 0.0, 1.0, 0.9)
epsilon = st.sidebar.slider("Epsilon (exploration)", 0.0, 1.0, 0.1)
st.sidebar.header("3. Critères de convergence")
nb_max_episode = st.sidebar.number_input("Nombre max d'épisodes", min_value=1, value=1000)
max_steps_per_episode = st.sidebar.number_input("Nombre max d'étapes par épisode", min_value=1, value=15)
seuil_convergence = st.sidebar.number_input("Seuil de convergence", value=0.01)
nb_ep_stables = st.sidebar.number_input("Épisodes consécutifs sans changement", min_value=1, value=2)
start_button = st.button("Lancer l'apprentissage")
if start_button:
    nb_actions = 4  # up, down, left, right
    nb_states = n_rows * n_cols
    Q_table = np.zeros((nb_actions, nb_states))
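    # Q_table[a, s] stores the estimated return for taking action a in state s.
    # Map a (row, col) cell to a flat state index (row-major order).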
    def state_index(row, col):
        return row * n_cols + col
    def get_reward(state):
        return reward_goal if state == goal_state else reward_else
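    # Apply an action to a state; moves that would leave the grid keep the agent in place.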
    def next_state(state, action):
        row, col = divmod(state, n_cols)
        if action == 0 and row > 0:  # up
            row -= 1
        elif action == 1 and row < n_rows - 1:  # down
            row += 1
        elif action == 2 and col > 0:  # left
            col -= 1
        elif action == 3 and col < n_cols - 1:  # right
            col += 1
        return state_index(row, col)
    start_state = state_index(start_row, start_col)
    goal_state = state_index(goal_row, goal_col)
    Q_old = Q_table.copy()
    stable_counter = 0
    # st.empty() holds a single element, so use separate placeholders for the
    # title and the table; otherwise the dataframe would overwrite the subheader.
    q_title_placeholder = st.empty()
    q_table_placeholder = st.empty()
    for episode in range(nb_max_episode):
        state = start_state
        done = False
        steps = 0
        while not done and steps < max_steps_per_episode:
            steps += 1
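            # Epsilon-greedy action selection: explore a random action with
            # probability epsilon, otherwise exploit the current best estimate.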
            if np.random.uniform(0, 1) < epsilon:
                action = np.random.choice(nb_actions)
            else:
                action = np.argmax(Q_table[:, state])
            new_state = next_state(state, action)
            reward = get_reward(new_state)
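            # Q-learning update:
            # Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a'))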
            Q_table[action, state] = (1 - alpha) * Q_table[action, state] + alpha * (
                reward + gamma * np.max(Q_table[:, new_state])
            )
            state = new_state
            if state == goal_state:
                done = True
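        # Convergence check: largest absolute change in the Q-table since the previous episode.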
        delta = np.abs(Q_table - Q_old).max()
        if delta < seuil_convergence:
            stable_counter += 1
        else:
            stable_counter = 0
        # Display the updated Q-table
        q_title_placeholder.subheader(f"Q-Table after episode {episode + 1}")
        q_table_placeholder.dataframe(Q_table.round(2))
        if stable_counter >= nb_ep_stables:
            st.success(f"Convergence reached at episode {episode + 1}")
            break
        Q_old = Q_table.copy()
    # ----- Display the optimal policy -----
    policy = np.argmax(Q_table, axis=0).reshape(n_rows, n_cols)
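    # policy[i, j] is the greedy action (0=up, 1=down, 2=left, 3=right) for cell (i, j);
    # rows are drawn top-down, so row 0 appears at the top of the figure.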
st.subheader("Politique optimale sous forme de flèches")
fig, ax = plt.subplots(figsize=(n_cols, n_rows))
ax.set_xlim(0, n_cols)
ax.set_ylim(0, n_rows)
ax.set_xticks(np.arange(n_cols))
ax.set_yticks(np.arange(n_rows))
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.grid(True)
for i in range(n_rows):
for j in range(n_cols):
action = policy[i, j]
dx, dy = 0, 0
if action == 0: # haut
dx, dy = 0, 0.4
elif action == 1: # bas
dx, dy = 0, -0.4
elif action == 2: # gauche
dx, dy = -0.4, 0
elif action == 3: # droite
dx, dy = 0.4, 0
ax.arrow(j + 0.5, n_rows - i - 0.5, dx, dy, head_width=0.2, head_length=0.1, fc='blue', ec='blue')
# Objectif en vert
ax.add_patch(plt.Rectangle((goal_col, n_rows - goal_row - 1), 1, 1, fill=True, color='green', alpha=0.3))
ax.text(goal_col + 0.5, n_rows - goal_row - 0.5, 'GOAL', ha='center', va='center', fontsize=12, fontweight='bold')
st.pyplot(fig)
    # ----- Max Q-value per state -----
    st.subheader("Maximum Q-value per state")
    fig2, ax2 = plt.subplots()
    ax2.plot(np.max(Q_table, axis=0), marker='o')
    ax2.set_xlabel("State")
    ax2.set_ylabel("Max Q-value")
    ax2.grid(True)
    st.pyplot(fig2)