# Non-code header lines, commented out so the file parses as Python:
# Marcel0123's picture
# Update app.py
# 0c6bd2f verified
# Generate a clean app.py (the generated app itself contains no file-writing code, which fixes the earlier runtime error) plus its requirements.txt.
app_py_clean = """import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# =============================
# Gridworld RL demo (visual + step-by-step)
# =============================
# Action labels, one character each; index i pairs with DELTAS[i]:
# up, right, down, left.
ACTIONS = list("↑→↓←")
# (row, col) offset applied by each action, in the same order as ACTIONS.
DELTAS = [(-1, 0), (0, 1), (1, 0), (0, -1)]
def clamp(x, lo, hi):
    # Limit x to the range [lo, hi]; if lo > hi, lo wins.
    capped = min(hi, x)
    return capped if capped > lo else lo
# -----------------------------
# Environment
# -----------------------------
class Gridworld:
    # Square n x n grid world.
    # - The agent starts in the top-left cell (0, 0).
    # - The goal is the bottom-right cell (n - 1, n - 1): reward +1.0, episode ends.
    # - One trap sits at (n // 2, n // 2): reward -1.0, episode ends.
    # - Every other move costs step_penalty; moving into a wall keeps the agent
    #   on the border (coordinates are clamped).
    # States are the flattened cell index row * n + col.

    def __init__(self, n=6, step_penalty=-0.01):
        self.n = n
        self.goal = (n - 1, n - 1)
        self.traps = {(n // 2, n // 2)}
        self.step_penalty = float(step_penalty)
        self.reset()

    def reset(self):
        # Put the agent back on the start cell and return that state index.
        self.pos = (0, 0)
        return self.state()

    def state(self):
        # Flatten the agent position (row, col) into row * n + col.
        row, col = self.pos
        return row * self.n + col

    def step(self, a):
        # Apply action index a (an index into DELTAS).
        # Returns (next_state, reward, done).
        d_row, d_col = DELTAS[a]
        row, col = self.pos
        last = self.n - 1
        self.pos = (clamp(row + d_row, 0, last), clamp(col + d_col, 0, last))
        # The goal is checked before the traps, so if they coincide
        # (possible on tiny grids) the goal wins.
        if self.pos == self.goal:
            reward, done = 1.0, True
        elif self.pos in self.traps:
            reward, done = -1.0, True
        else:
            reward, done = self.step_penalty, False
        return self.state(), reward, done
# -----------------------------
# RL helpers
# -----------------------------
def epsilon_greedy(Q, s, eps):
    # Epsilon-greedy action selection over the Q-table row Q[s]
    # (Q is indexed as Q[state, action]).
    # With probability eps: explore, i.e. pick a uniformly random action index.
    # Otherwise: exploit, i.e. pick the first action with the highest Q-value.
    # The action is returned as a plain int.
    explore = np.random.rand() < eps
    if not explore:
        return int(np.argmax(Q[s]))
    num_actions = Q.shape[1]
    return int(np.random.randint(num_actions))
# -----------------------------
# Rendering (HTML + plots)
# -----------------------------
def render_grid_html(env):
    # Render the grid world as an HTML table.
    # Legend: A = agent, S = start, G = goal, X = trap, · = empty cell.
    # The agent is checked first so it stays visible while standing on the
    # start cell, which is where every episode begins after env.reset().
    n = env.n
    start = (0, 0)  # same start cell that Gridworld.reset() uses
    goal = env.goal
    agent = env.pos

    def cell(bg, txt, bold=False):
        w = "font-weight:700;" if bold else ""
        return (
            f"<td style='background:{bg};{w}border:1px solid #ddd;"
            "width:42px;height:42px;text-align:center;font-size:18px'>"
            f"{txt}</td>"
        )

    html = ["<table style='border-collapse:collapse'>"]
    for r in range(n):
        html.append("<tr>")
        for c in range(n):
            pos = (r, c)
            if pos == agent:
                html.append(cell("#fef9c3", "A", True))
            elif pos == start:
                html.append(cell("#dbeafe", "S", True))
            elif pos == goal:
                html.append(cell("#dcfce7", "G", True))
            elif pos in env.traps:
                html.append(cell("#fee2e2", "X", True))
            else:
                html.append(cell("#ffffff", "·"))
        html.append("</tr>")
    html.append("</table>")
    return "".join(html)
def render_policy_html(Q, env):
    # Render the current greedy policy (argmax of Q[s] per cell) as an HTML
    # table styled like the grid view.
    # S / G / X mark the start, goal and trap cells. A state whose Q-row is
    # still all zeros has learned nothing yet; it shows "·" instead of the
    # arbitrary first arrow that argmax would pick from a row of zeros.
    n = env.n
    start = (0, 0)  # same start cell that Gridworld.reset() uses

    def cell(txt):
        return (
            "<td style='border:1px solid #ddd;width:42px;height:42px;"
            "text-align:center;font-size:18px'>"
            f"{txt}</td>"
        )

    html = ["<table style='border-collapse:collapse'>"]
    for r in range(n):
        html.append("<tr>")
        for c in range(n):
            pos = (r, c)
            if pos == start:
                label = "S"
            elif pos == env.goal:
                label = "G"
            elif pos in env.traps:
                label = "X"
            else:
                q_row = Q[r * n + c]  # same flattening as Gridworld.state()
                if np.any(q_row):
                    label = ACTIONS[int(np.argmax(q_row))]
                else:
                    label = "·"
            html.append(cell(label))
        html.append("</tr>")
    html.append("</table>")
    return "".join(html)
def reward_plot(rewards, current=None):
    # Plot the total reward of each finished episode. If current is given,
    # it is appended as the running reward of the episode in progress.
    # The last point is highlighted with a marker.
    # Returns a matplotlib Figure.
    #
    # This runs on every button click. It uses an explicit Figure/Axes and
    # closes the figure in pyplot's registry; otherwise one open figure would
    # pile up per click (growing memory, "too many open figures" warning).
    # NOTE(review): assumes gr.Plot renders the returned Figure via
    # fig.savefig, which still works after plt.close(fig) -- confirm with
    # the installed Gradio version.
    fig, ax = plt.subplots()
    ys = list(rewards)
    if current is not None:
        ys.append(current)
    if ys:
        ax.plot(ys)
        ax.scatter(len(ys) - 1, ys[-1])
    ax.set_xlabel("Episode")
    ax.set_ylabel("Total reward")
    fig.tight_layout()
    plt.close(fig)
    return fig
# -----------------------------
# State + step-by-step learning
# -----------------------------
def init_state(n=6):
    # Build a fresh per-session state dict:
    # - the environment;
    # - a zero-initialised Q-table with n*n states x 4 actions;
    # - the exploration/learning parameters (epsilon, alpha, gamma, eps_decay);
    # - episode bookkeeping and the initial info text.
    env = Gridworld(n=n)
    q_table = np.zeros((n * n, 4))
    return dict(
        env=env,
        Q=q_table,
        epsilon=0.6,
        alpha=0.3,
        gamma=0.95,
        eps_decay=0.98,
        episode_reward=0.0,
        rewards=[],
        steps=0,
        max_steps=50,
        last_info="Klik op ‘Next step’ om te starten.",
    )
def next_step(state):
    # Advance the demo by one environment step and apply one Q-learning update:
    #   Q(s,a) += alpha * (r + gamma * max over a2 of Q(s2,a2) - Q(s,a))
    # There is no bootstrap term when the step ended the episode (done).
    # An episode also ends after max_steps. Then its total reward is recorded,
    # epsilon is decayed and the environment is reset.
    # Returns the Gradio outputs:
    #   (state, grid html, policy html, reward plot, info text).
    env = state["env"]
    Q = state["Q"]

    s = env.state()
    a = epsilon_greedy(Q, s, state["epsilon"])
    s2, r, done = env.step(a)

    td_target = r + (0 if done else state["gamma"] * np.max(Q[s2]))
    td_error = td_target - Q[s, a]
    Q[s, a] += state["alpha"] * td_error

    state["episode_reward"] += r
    state["steps"] += 1
    state["last_info"] = (
        f"State s = {s}\\n"
        f"Action a = {ACTIONS[a]}\\n"
        f"Reward r = {r}\\n"
        f"Next state s' = {s2}\\n\\n"
        f"TD target = {td_target:.3f}\\n"
        f"TD error = {td_error:.3f}\\n\\n"
        f"Q(s,a) = {Q[s, a]:.3f}"
    )

    # Render the grid BEFORE a possible reset. This makes the final move of an
    # episode (reaching the goal/trap or hitting max_steps) visible, instead of
    # the agent silently jumping back to the start.
    grid_html = render_grid_html(env)

    if done or state["steps"] >= state["max_steps"]:
        finished_reward = state["episode_reward"]
        state["rewards"].append(finished_reward)
        state["episode_reward"] = 0.0
        state["steps"] = 0
        state["epsilon"] *= state["eps_decay"]
        env.reset()
        state["last_info"] += (
            f"\\n\\nEpisode finished (total reward {finished_reward:.3f}); "
            "the agent restarts at S on the next step."
        )

    return (
        state,
        grid_html,
        render_policy_html(Q, env),
        reward_plot(state["rewards"], state["episode_reward"]),
        state["last_info"],
    )
# -----------------------------
# UI
# -----------------------------
with gr.Blocks() as demo:
    # Page layout, top to bottom: intro text, grid, policy, reward plot,
    # explanation box, step button. gr.State keeps the per-session
    # environment and Q-table.
    gr.Markdown(
        '''
# 🎮 Gridworld Reinforcement Learning (Q-learning)
Klik **Next step** om **één echte reinforcement learning update** te zien.
Je ziet de agent bewegen, de reward oplopen en de Q-waarden veranderen.
'''
    )
    state = gr.State(init_state())
    grid = gr.HTML(label="Gridworld")
    policy = gr.HTML(label="Policy")
    plot = gr.Plot(label="Reward per episode")
    info = gr.Textbox(label="Wat gebeurt er nu?", lines=10)
    btn = gr.Button("Next step")

    # Both events refresh the same five outputs.
    outputs = [state, grid, policy, plot, info]

    # One click = one environment step + one Q-learning update.
    btn.click(next_step, inputs=state, outputs=outputs)

    # On page load, render the initial state without taking a step.
    demo.load(
        lambda st: (
            st,
            render_grid_html(st["env"]),
            render_policy_html(st["Q"], st["env"]),
            reward_plot(st["rewards"], st["episode_reward"]),
            st["last_info"],
        ),
        inputs=state,
        outputs=outputs,
    )

demo.launch()
"""
# Write the generated app and its requirements file.
# The output directory defaults to /mnt/data as before. It can be overridden
# with the APP_OUTPUT_DIR environment variable and is created if missing,
# so the script no longer fails with FileNotFoundError when that directory
# does not exist.
import os
from pathlib import Path

req = "gradio\nnumpy\nmatplotlib\n"

out_dir = Path(os.environ.get("APP_OUTPUT_DIR", "/mnt/data"))
out_dir.mkdir(parents=True, exist_ok=True)

app_path = out_dir / "app.py"
req_path = out_dir / "requirements.txt"
app_path.write_text(app_py_clean, encoding="utf-8")
req_path.write_text(req, encoding="utf-8")

# Report where the files went. The original bare tuple expression only
# displayed anything in an interactive/notebook session.
print(app_path)
print(req_path)