Marcel0123 committed
Commit 0c6bd2f · verified · 1 Parent(s): 8e0fabc

Update app.py

Files changed (1)
  1. app.py +106 -81
app.py CHANGED
@@ -1,18 +1,22 @@
- # Fix the syntax error by correcting the f-string quotes and rewrite files

- app_py_fixed = """import gradio as gr
  import numpy as np
  import matplotlib.pyplot as plt

- # -----------------------------
  # Gridworld RL demo (visual + step-by-step)
- # -----------------------------
  ACTIONS = ["↑", "→", "↓", "←"]
  DELTAS = [(-1, 0), (0, 1), (1, 0), (0, -1)]

  def clamp(x, lo, hi):
  return max(lo, min(hi, x))

  class Gridworld:
  def __init__(self, n=6, step_penalty=-0.01):
  self.n = n
@@ -42,13 +46,16 @@ class Gridworld:
  return self.state(), -1.0, True
  return self.state(), self.step_penalty, False

  def epsilon_greedy(Q, s, eps):
  if np.random.rand() < eps:
  return int(np.random.randint(Q.shape[1]))
  return int(np.argmax(Q[s]))

  # -----------------------------
- # Rendering helpers
  # -----------------------------
  def render_grid_html(env):
  n = env.n
@@ -58,50 +65,53 @@ def render_grid_html(env):

  def cell(bg, txt, bold=False):
  w = "font-weight:700;" if bold else ""
- return f\"\"\"<td style='background:{bg};{w}
- border:1px solid #ddd;width:44px;height:44px;
- text-align:center;font-size:18px'>{txt}</td>\"\"\"

- html = [\"<table style='border-collapse:collapse'>\"]
  for r in range(n):
- html.append(\"<tr>\")
  for c in range(n):
  pos = (r, c)
  if pos == (sr, sc):
- html.append(cell(\"#dbeafe\", \"S\", True))
  elif pos == (gr_, gc_):
- html.append(cell(\"#dcfce7\", \"G\", True))
  elif pos in env.traps:
- html.append(cell(\"#fee2e2\", \"X\", True))
  elif pos == (ar, ac):
- html.append(cell(\"#fef9c3\", \"A\", True))
  else:
- html.append(cell(\"#ffffff\", \"·\"))
- html.append(\"</tr>\")
- html.append(\"</table>\")
- return \"\".join(html)

  def render_policy_html(Q, env):
  n = env.n
  sr, sc = (0, 0)
  gr_, gc_ = env.goal
- html = [\"<table style='border-collapse:collapse'>\"]
  for r in range(n):
- html.append(\"<tr>\")
  for c in range(n):
  pos = (r, c)
  s = r * n + c
  if pos == (sr, sc):
- html.append(f\"<td>S</td>\")
  elif pos == (gr_, gc_):
- html.append(f\"<td>G</td>\")
  elif pos in env.traps:
- html.append(f\"<td>X</td>\")
  else:
- html.append(f\"<td>{ACTIONS[int(np.argmax(Q[s]))]}</td>\")
- html.append(\"</tr>\")
- html.append(\"</table>\")
- return \"\".join(html)

  def reward_plot(rewards, current=None):
  fig = plt.figure()
@@ -110,105 +120,120 @@ def reward_plot(rewards, current=None):
  ys.append(current)
  if ys:
  plt.plot(ys)
- plt.scatter(len(ys)-1, ys[-1])
- plt.xlabel(\"Episode\")
- plt.ylabel(\"Total reward\")
  return fig

  # -----------------------------
- # State + RL step
  # -----------------------------
  def init_state(n=6):
  env = Gridworld(n=n)
  return {
- \"env\": env,
- \"Q\": np.zeros((n*n, 4)),
- \"epsilon\": 0.6,
- \"alpha\": 0.3,
- \"gamma\": 0.95,
- \"eps_decay\": 0.98,
- \"episode_reward\": 0.0,
- \"rewards\": [],
- \"steps\": 0,
- \"max_steps\": 50,
- \"last_info\": \"Click Next step to start\"
  }

  def next_step(state):
- env = state[\"env\"]
- Q = state[\"Q\"]

  s = env.state()
- a = epsilon_greedy(Q, s, state[\"epsilon\"])
  s2, r, done = env.step(a)

- td_target = r + (0 if done else state[\"gamma\"] * np.max(Q[s2]))
  td_error = td_target - Q[s, a]
- Q[s, a] += state[\"alpha\"] * td_error
-
- state[\"episode_reward\"] += r
- state[\"steps\"] += 1
-
- state[\"last_info\"] = (
- f\"State s = {s}\\n\"
- f\"Action a = {ACTIONS[a]}\\n\"
- f\"Reward r = {r}\\n\"
- f\"Next state s' = {s2}\\n\\n\"
- f\"TD target = {td_target:.3f}\\n\"
- f\"TD error = {td_error:.3f}\\n\\n\"
- f\"Q(s,a) updated to {Q[s, a]:.3f}\"
  )

- if done or state[\"steps\"] >= state[\"max_steps\"]:
- state[\"rewards\"].append(state[\"episode_reward\"])
- state[\"episode_reward\"] = 0.0
- state[\"steps\"] = 0
- state[\"epsilon\"] *= state[\"eps_decay\"]
  env.reset()

  return (
  state,
  render_grid_html(env),
  render_policy_html(Q, env),
- reward_plot(state[\"rewards\"], state[\"episode_reward\"]),
- state[\"last_info\"],
  )

  # -----------------------------
  # UI
  # -----------------------------
  with gr.Blocks() as demo:
- gr.Markdown(\"\"\"
  # 🎮 Gridworld Reinforcement Learning (Q-learning)

- Click **Next step** and watch the agent learn via TD updates.
- \"\"\")

  state = gr.State(init_state())

- grid = gr.HTML()
- policy = gr.HTML()
- plot = gr.Plot()
- info = gr.Textbox(lines=10)

- btn = gr.Button(\"Next step\"")

- btn.click(next_step, inputs=state, outputs=[state, grid, policy, plot, info])

  demo.load(
- lambda st: (st, render_grid_html(st[\"env\"]), render_policy_html(st[\"Q\"], st[\"env\"]),
- reward_plot(st[\"rewards\"], st[\"episode_reward\"]), st[\"last_info\"]),
  inputs=state,
- outputs=[state, grid, policy, plot, info]
  )

  demo.launch()
  """

  with open("/mnt/data/app.py", "w", encoding="utf-8") as f:
- f.write(app_py_fixed)

  with open("/mnt/data/requirements.txt", "w", encoding="utf-8") as f:
- f.write("gradio\nnumpy\nmatplotlib\n")

- "/mnt/data/app.py", "/mnt/data/requirements.txt"

+ # Create a clean app.py WITHOUT any file-writing code (fixing the runtime error)

+ app_py_clean = """import gradio as gr
  import numpy as np
  import matplotlib.pyplot as plt

+ # =============================
  # Gridworld RL demo (visual + step-by-step)
+ # =============================
+
  ACTIONS = ["↑", "→", "↓", "←"]
  DELTAS = [(-1, 0), (0, 1), (1, 0), (0, -1)]

  def clamp(x, lo, hi):
  return max(lo, min(hi, x))

+ # -----------------------------
+ # Environment
+ # -----------------------------
  class Gridworld:
  def __init__(self, n=6, step_penalty=-0.01):
  self.n = n
 
  return self.state(), -1.0, True
  return self.state(), self.step_penalty, False

+ # -----------------------------
+ # RL helpers
+ # -----------------------------
  def epsilon_greedy(Q, s, eps):
  if np.random.rand() < eps:
  return int(np.random.randint(Q.shape[1]))
  return int(np.argmax(Q[s]))

  # -----------------------------
+ # Rendering (HTML + plots)
  # -----------------------------
  def render_grid_html(env):
  n = env.n
 

  def cell(bg, txt, bold=False):
  w = "font-weight:700;" if bold else ""
+ return (
+ f\"<td style='background:{bg};{w}border:1px solid #ddd;"
+ "width:42px;height:42px;text-align:center;font-size:18px'>"
+ f\"{txt}</td>\"
+ )

+ html = ["<table style='border-collapse:collapse'>"]
  for r in range(n):
+ html.append("<tr>")
  for c in range(n):
  pos = (r, c)
  if pos == (sr, sc):
+ html.append(cell("#dbeafe", "S", True))
  elif pos == (gr_, gc_):
+ html.append(cell("#dcfce7", "G", True))
  elif pos in env.traps:
+ html.append(cell("#fee2e2", "X", True))
  elif pos == (ar, ac):
+ html.append(cell("#fef9c3", "A", True))
  else:
+ html.append(cell("#ffffff", "·"))
+ html.append("</tr>")
+ html.append("</table>")
+ return "".join(html)

  def render_policy_html(Q, env):
  n = env.n
  sr, sc = (0, 0)
  gr_, gc_ = env.goal
+
+ html = ["<table style='border-collapse:collapse'>"]
  for r in range(n):
+ html.append("<tr>")
  for c in range(n):
  pos = (r, c)
  s = r * n + c
  if pos == (sr, sc):
+ html.append("<td>S</td>")
  elif pos == (gr_, gc_):
+ html.append("<td>G</td>")
  elif pos in env.traps:
+ html.append("<td>X</td>")
  else:
+ html.append(f"<td>{ACTIONS[int(np.argmax(Q[s]))]}</td>")
+ html.append("</tr>")
+ html.append("</table>")
+ return "".join(html)

  def reward_plot(rewards, current=None):
  fig = plt.figure()
 
  ys.append(current)
  if ys:
  plt.plot(ys)
+ plt.scatter(len(ys) - 1, ys[-1])
+ plt.xlabel("Episode")
+ plt.ylabel("Total reward")
+ plt.tight_layout()
  return fig

  # -----------------------------
+ # State + step-by-step learning
  # -----------------------------
  def init_state(n=6):
  env = Gridworld(n=n)
  return {
+ "env": env,
+ "Q": np.zeros((n * n, 4)),
+ "epsilon": 0.6,
+ "alpha": 0.3,
+ "gamma": 0.95,
+ "eps_decay": 0.98,
+ "episode_reward": 0.0,
+ "rewards": [],
+ "steps": 0,
+ "max_steps": 50,
+ "last_info": "Click Next step to start."
  }

  def next_step(state):
+ env = state["env"]
+ Q = state["Q"]

  s = env.state()
+ a = epsilon_greedy(Q, s, state["epsilon"])
  s2, r, done = env.step(a)

+ td_target = r + (0 if done else state["gamma"] * np.max(Q[s2]))
  td_error = td_target - Q[s, a]
+ Q[s, a] += state["alpha"] * td_error
+
+ state["episode_reward"] += r
+ state["steps"] += 1
+
+ state["last_info"] = (
+ f"State s = {s}\\n"
+ f"Action a = {ACTIONS[a]}\\n"
+ f"Reward r = {r}\\n"
+ f"Next state s' = {s2}\\n\\n"
+ f"TD target = {td_target:.3f}\\n"
+ f"TD error = {td_error:.3f}\\n\\n"
+ f"Q(s,a) = {Q[s, a]:.3f}"
  )

+ if done or state["steps"] >= state["max_steps"]:
+ state["rewards"].append(state["episode_reward"])
+ state["episode_reward"] = 0.0
+ state["steps"] = 0
+ state["epsilon"] *= state["eps_decay"]
  env.reset()

  return (
  state,
  render_grid_html(env),
  render_policy_html(Q, env),
+ reward_plot(state["rewards"], state["episode_reward"]),
+ state["last_info"],
  )

  # -----------------------------
  # UI
  # -----------------------------
  with gr.Blocks() as demo:
+ gr.Markdown(
+ \"\"\"
  # 🎮 Gridworld Reinforcement Learning (Q-learning)

+ Click **Next step** to see **one real reinforcement learning update**.
+ You will see the agent move, the reward accumulate, and the Q-values change.
+ \"\"\"
+ )

  state = gr.State(init_state())

+ grid = gr.HTML(label="Gridworld")
+ policy = gr.HTML(label="Policy")
+ plot = gr.Plot(label="Reward per episode")
+ info = gr.Textbox(label="What is happening now?", lines=10)

+ btn = gr.Button("Next step")

+ btn.click(
+ next_step,
+ inputs=state,
+ outputs=[state, grid, policy, plot, info],
+ )

  demo.load(
+ lambda st: (
+ st,
+ render_grid_html(st["env"]),
+ render_policy_html(st["Q"], st["env"]),
+ reward_plot(st["rewards"], st["episode_reward"]),
+ st["last_info"],
+ ),
  inputs=state,
+ outputs=[state, grid, policy, plot, info],
  )

  demo.launch()
  """

+ req = "gradio\nnumpy\nmatplotlib\n"
+
  with open("/mnt/data/app.py", "w", encoding="utf-8") as f:
+ f.write(app_py_clean)

  with open("/mnt/data/requirements.txt", "w", encoding="utf-8") as f:
+ f.write(req)

+ ("/mnt/data/app.py", "/mnt/data/requirements.txt")