sam522 commited on
Commit
18cc826
·
verified ·
1 Parent(s): a6c2c46

Upload PPO LunarLander model

Browse files
Files changed (6) hide show
  1. README.md +104 -0
  2. config.json +18 -0
  3. model.pt +3 -0
  4. requirements.txt +4 -0
  5. results.json +866 -0
  6. test_model.py +47 -0
README.md ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - LunarLander-v2
4
+ - ppo
5
+ - deep-reinforcement-learning
6
+ - reinforcement-learning
7
+ - custom-implementation
8
+ model-index:
9
+ - name: PPO
10
+ results:
11
+ - task:
12
+ type: reinforcement-learning
13
+ name: reinforcement-learning
14
+ dataset:
15
+ name: LunarLander-v2
16
+ type: LunarLander-v2
17
+ metrics:
18
+ - type: mean_reward
19
+ value: -9.92 +/- 91.50
20
+ name: mean_reward
21
+ verified: false
22
+ ---
23
+
24
+ # **PPO** Agent playing **LunarLander-v2**
25
+
26
+ This is a trained model of a **PPO** agent playing **LunarLander-v2** using a custom implementation.
27
+
28
+ ## Usage
29
+
30
+ ```python
31
+ import torch
32
+ import gymnasium as gym
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ from torch.distributions import Categorical
36
+ import numpy as np
37
+
38
+ # Define the Actor network
39
+ class Actor(nn.Module):
40
+ def __init__(self, state_dim, action_dim, hidden_size=64):
41
+ super().__init__()
42
+ self.network = nn.Sequential(
43
+ nn.Linear(state_dim, hidden_size),
44
+ nn.Tanh(),
45
+ nn.Linear(hidden_size, hidden_size),
46
+ nn.Tanh(),
47
+ nn.Linear(hidden_size, action_dim)
48
+ )
49
+
50
+ def forward(self, x):
51
+ return self.network(x)
52
+
53
+ # Load the model
54
+ checkpoint = torch.load("model.pt", map_location='cpu')
55
+ actor = Actor(state_dim=8, action_dim=4, hidden_size=checkpoint['config']['hidden_size'])
56
+ actor.load_state_dict(checkpoint['actor_state_dict'])
57
+ actor.eval()
58
+
59
+ # Test the agent
60
+ env = gym.make("LunarLander-v2")
61
+ state, _ = env.reset()
62
+ total_reward = 0
63
+
64
+ for _ in range(1000): # Max steps
65
+ with torch.no_grad():
66
+ state_tensor = torch.FloatTensor(state).unsqueeze(0)
67
+ logits = actor(state_tensor)
68
+ action = torch.argmax(logits, dim=-1).item()
69
+
70
+ state, reward, terminated, truncated, _ = env.step(action)
71
+ total_reward += reward
72
+
73
+ if terminated or truncated:
74
+ break
75
+
76
+ print(f"Total reward: {total_reward:.2f}")
77
+ ```
78
+
79
+ ## Training Results
80
+
81
+ - **Mean reward**: -9.92 ± 91.50
82
+ - **Best reward**: 143.33
83
+ - **Success rate**: 0.0% (episodes with reward > 200)
84
+ - **Total episodes**: 353
85
+ - **Total timesteps**: 100,000
86
+
87
+ ## Algorithm Configuration
88
+
89
+ - **Algorithm**: Proximal Policy Optimization (PPO)
90
+ - **Learning rate**: 0.0003
91
+ - **Batch size**: 2048
92
+ - **Clip coefficient**: 0.2
93
+ - **Entropy coefficient**: 0.01
94
+ - **Value coefficient**: 0.5
95
+ - **Gamma**: 0.99
96
+ - **GAE Lambda**: 0.95
97
+
98
+ ## Training Environment
99
+
100
+ - **Environment**: LunarLander-v2
101
+ - **Framework**: PyTorch + Gymnasium
102
+ - **Training date**: 2025-09-04
103
+
104
+ This model was trained as part of the Hugging Face Deep RL Course.
config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "env_id": "LunarLander-v2",
3
+ "max_episode_steps": 1000,
4
+ "total_timesteps": 100000,
5
+ "batch_size": 2048,
6
+ "minibatch_size": 64,
7
+ "num_epochs": 4,
8
+ "learning_rate": 0.0003,
9
+ "gamma": 0.99,
10
+ "gae_lambda": 0.95,
11
+ "clip_coef": 0.2,
12
+ "entropy_coef": 0.01,
13
+ "value_coef": 0.5,
14
+ "max_grad_norm": 0.5,
15
+ "hidden_size": 64,
16
+ "log_frequency": 1000,
17
+ "eval_frequency": 5000
18
+ }
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2984c79fe12f2cffd8ed7a990bf3ca40532c544591abdd38b7c1fee9f74f5413
3
+ size 65115
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=1.9.0
2
+ gymnasium[box2d]>=0.28.0
3
+ numpy>=1.21.0
4
+ matplotlib>=3.3.0
results.json ADDED
@@ -0,0 +1,866 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "episode_rewards": [
3
+ -0.06825948606616805,
4
+ -169.11205875315596,
5
+ -215.3913619178971,
6
+ -139.8789067666994,
7
+ -509.00763208625216,
8
+ -109.10158643440585,
9
+ -166.36546227855413,
10
+ -200.2774249561457,
11
+ -177.853047471592,
12
+ -117.68830352738027,
13
+ -377.8491467246267,
14
+ -88.45872897085687,
15
+ -87.81601692565125,
16
+ -95.75884386471762,
17
+ -42.98813630908335,
18
+ -243.49287632618473,
19
+ -139.87800064187832,
20
+ -104.03428716974196,
21
+ -390.8048731990434,
22
+ -225.5430582633193,
23
+ -185.89923995942115,
24
+ -88.42383149105987,
25
+ -408.9225997545925,
26
+ -106.22383469588478,
27
+ -141.10328573561222,
28
+ -87.93442530313699,
29
+ -79.96148810811498,
30
+ -132.9337179443088,
31
+ -213.97733893091453,
32
+ -147.42151069764924,
33
+ -162.99907861266115,
34
+ -144.6632461507486,
35
+ -84.81459503872658,
36
+ -69.14377921641031,
37
+ -104.60268858246573,
38
+ -57.19420918932683,
39
+ -161.42868563828253,
40
+ -344.4091094104822,
41
+ -101.16971336992118,
42
+ -188.49448757234114,
43
+ -13.678164633553635,
44
+ -108.50645100939873,
45
+ -170.76606043027135,
46
+ -186.7262446158587,
47
+ -15.216783348278511,
48
+ -114.86771345279944,
49
+ -227.3692629653085,
50
+ -147.20910772391082,
51
+ -129.51296531717037,
52
+ -169.92751893892438,
53
+ -92.67419463241187,
54
+ -66.52312287592397,
55
+ -205.00254460937958,
56
+ -294.4399920099744,
57
+ -68.70825604829983,
58
+ -95.73002166672707,
59
+ -203.00938076883114,
60
+ -180.01476651566466,
61
+ -218.65008045854182,
62
+ -109.9569345200619,
63
+ -249.13048322194444,
64
+ -66.71251085050328,
65
+ -96.39204623549715,
66
+ -115.68446613589023,
67
+ -163.87003732668143,
68
+ -177.90045049613406,
69
+ -75.71805993225172,
70
+ -123.02066775256463,
71
+ -73.64389095234229,
72
+ -261.12391153130307,
73
+ -106.47186764493826,
74
+ -101.80564931343925,
75
+ -141.86797230927073,
76
+ -156.61153573618833,
77
+ -95.2393424830547,
78
+ -90.22571965883145,
79
+ -143.6636948877111,
80
+ -51.35073548228167,
81
+ -519.0832165573745,
82
+ -341.6084146927546,
83
+ -79.28020255548074,
84
+ -119.36101244775585,
85
+ -88.1950180493556,
86
+ -122.61044158853123,
87
+ -135.71958328905532,
88
+ -207.58081741472392,
89
+ -144.75407671414126,
90
+ -83.38352184054375,
91
+ -176.40191781750895,
92
+ -101.60612676644908,
93
+ -75.21788949050259,
94
+ -315.1279736505058,
95
+ -207.51669932911682,
96
+ -48.31907149982139,
97
+ -308.26873149443145,
98
+ -219.67469874966247,
99
+ -137.19352252669634,
100
+ -44.806359347531014,
101
+ -107.76555172357948,
102
+ -297.23318223874264,
103
+ -186.3685812608377,
104
+ -110.65739435770014,
105
+ -253.53768551432907,
106
+ -178.73459600772205,
107
+ -86.00970854249826,
108
+ -71.44504434241344,
109
+ -315.7744485503431,
110
+ -94.420414320224,
111
+ -207.2584697830588,
112
+ -71.94779060839778,
113
+ -135.1856849970423,
114
+ -55.45307758594119,
115
+ -137.02197706792916,
116
+ -173.33883858632197,
117
+ -116.13620467096102,
118
+ -100.33774882902676,
119
+ -96.21422856581302,
120
+ -306.7805674821406,
121
+ -115.20336496789074,
122
+ -91.53036352443942,
123
+ -124.2550546719635,
124
+ -44.74700688984777,
125
+ -74.32923924165786,
126
+ -106.72254156258597,
127
+ -145.1791552903415,
128
+ -325.6934198758087,
129
+ -492.27073137336924,
130
+ -86.99830724178076,
131
+ -175.6210715671713,
132
+ -331.7468607743321,
133
+ -85.70705823312146,
134
+ -105.0045904423163,
135
+ -99.45017495028988,
136
+ -78.43215638806684,
137
+ -161.86971764719186,
138
+ -99.54601422371243,
139
+ -117.09788268141342,
140
+ -247.03344075957054,
141
+ -165.6021271585296,
142
+ -134.35843308237764,
143
+ -49.73205798092371,
144
+ 23.212409800418598,
145
+ -115.15950679935591,
146
+ -80.7186329423706,
147
+ -63.94584994249098,
148
+ -94.55277414944797,
149
+ -85.40629043215664,
150
+ -50.799756615195946,
151
+ -87.72580112853257,
152
+ -135.9770036205606,
153
+ -138.92725370070528,
154
+ -76.66805661003464,
155
+ -117.18797733566853,
156
+ -128.15344308974312,
157
+ -167.25237561311252,
158
+ -180.84063140488854,
159
+ -93.59998844842957,
160
+ -191.19317336997725,
161
+ -26.29411332936955,
162
+ -96.49525715689076,
163
+ -145.29278234776177,
164
+ -119.57008506165374,
165
+ -133.13342043800822,
166
+ -134.65198237996353,
167
+ -76.58624460994515,
168
+ -70.34120337383676,
169
+ -59.63430906498967,
170
+ -357.6935432146067,
171
+ -179.39487446119162,
172
+ -337.56760565449656,
173
+ -49.05309353091758,
174
+ -152.00949803010212,
175
+ -67.1993251644055,
176
+ -6.099140936544586,
177
+ -14.201823869525427,
178
+ -81.7722795475811,
179
+ -200.1448172618197,
180
+ -184.37469314203383,
181
+ -36.92118377875647,
182
+ -90.34611709777525,
183
+ -68.6805904834934,
184
+ -53.716997900380434,
185
+ -23.113558688573406,
186
+ -8.903564346705423,
187
+ -27.55353003701832,
188
+ -55.428221828307414,
189
+ -84.96989184224451,
190
+ -115.05392030264886,
191
+ -53.77243026924064,
192
+ -125.56693025600609,
193
+ -127.89382603352178,
194
+ -298.05774179855473,
195
+ -59.40924420265998,
196
+ 14.639449861942666,
197
+ -125.90139888201305,
198
+ -28.511688106958374,
199
+ -227.56883861154222,
200
+ -226.9705435777895,
201
+ -233.8241736362481,
202
+ -145.25056610537473,
203
+ -145.1836740137075,
204
+ -380.3525086402928,
205
+ -439.9677258270466,
206
+ -211.5204267450631,
207
+ 0.11140423555376344,
208
+ -173.3141272144474,
209
+ -369.91312641847395,
210
+ -247.012128973268,
211
+ -2.5561605861201144,
212
+ -88.12799381680097,
213
+ -84.81944076821246,
214
+ -112.23794903888829,
215
+ -134.88684454749006,
216
+ -46.5822621434159,
217
+ -37.89523863193384,
218
+ -55.66770982032418,
219
+ -192.16655727616046,
220
+ -221.83124282901133,
221
+ -69.30029113382633,
222
+ -63.65819587905243,
223
+ -159.41511505490848,
224
+ -268.658235468461,
225
+ 1.5979156618273151,
226
+ -132.945999852996,
227
+ -2.5585262389231644,
228
+ -229.86666718444704,
229
+ -33.50987358508034,
230
+ -91.74437579021038,
231
+ -13.499084717157515,
232
+ -33.91904110843143,
233
+ 83.74854941453819,
234
+ 14.740936458834327,
235
+ -17.505239573516818,
236
+ 17.005275529258597,
237
+ -64.13899263345624,
238
+ -1.8494117922500948,
239
+ -13.80694293885719,
240
+ -161.59203053864863,
241
+ -5.111877413591586,
242
+ -62.267203761515994,
243
+ -203.1636986569103,
244
+ -16.33521283772599,
245
+ -43.76419289703527,
246
+ -231.79356791613571,
247
+ -19.82721510771465,
248
+ -24.94162644332951,
249
+ -140.13489191875996,
250
+ -35.33350810426582,
251
+ -25.517750232653157,
252
+ -323.82197700896745,
253
+ -33.1830087723218,
254
+ -58.84664699201487,
255
+ -53.8204562164564,
256
+ -94.41007698626146,
257
+ 20.1768324848328,
258
+ 39.568090697977766,
259
+ 28.742513615873946,
260
+ -10.876059937404719,
261
+ 36.887560542335905,
262
+ 57.72817332988835,
263
+ -121.2429640394513,
264
+ -277.64083651135974,
265
+ 36.56318243984942,
266
+ -189.75141689003476,
267
+ -102.57304618954247,
268
+ -25.353143435251624,
269
+ -81.33489379005088,
270
+ -34.462574286471536,
271
+ -13.184893660098453,
272
+ -20.9300185028616,
273
+ -43.849572287832835,
274
+ -61.405688917778264,
275
+ 4.850469385603915,
276
+ 73.42787441544748,
277
+ -14.142018323762258,
278
+ -10.218390830410101,
279
+ -42.09291614967388,
280
+ 6.446253695358948,
281
+ -89.34496052487404,
282
+ 7.730260234043357,
283
+ -3.2801650797843678,
284
+ 40.05943458622408,
285
+ 3.4305993288540044,
286
+ -181.9884743359073,
287
+ -48.15835818232749,
288
+ -62.52116058763089,
289
+ -69.72306275970666,
290
+ -36.67234925154844,
291
+ -2.576860974653428,
292
+ 74.94141121003486,
293
+ 7.110245109088272,
294
+ -121.51373599032131,
295
+ 41.62826356222841,
296
+ 32.69544163530229,
297
+ -10.915920698598285,
298
+ 38.00125649565851,
299
+ -127.1549093744554,
300
+ 45.17158197251346,
301
+ -291.45632312199194,
302
+ -347.4260170142368,
303
+ -177.93849720833748,
304
+ -151.55347718264545,
305
+ -187.41104423710306,
306
+ -24.44516627724018,
307
+ -92.84446205697854,
308
+ -68.53762724577287,
309
+ -117.39158029036854,
310
+ 6.951101977978475,
311
+ -111.45460200674995,
312
+ 37.59049691017192,
313
+ 76.11035439251663,
314
+ 60.072023854697335,
315
+ 132.17751182796914,
316
+ 95.66148791008895,
317
+ 127.01426732947682,
318
+ -28.125136920530366,
319
+ 80.49117019603592,
320
+ 99.07148572522087,
321
+ 80.90938719489925,
322
+ -12.548944087098235,
323
+ 91.4527720834068,
324
+ 10.987130469205631,
325
+ 63.18406103139541,
326
+ -35.32557212165236,
327
+ -12.361553198005272,
328
+ 85.70227644691255,
329
+ 109.80887623490239,
330
+ 60.50440531795487,
331
+ 50.73971224683154,
332
+ -92.68465262159447,
333
+ -19.7173009646866,
334
+ 49.82634045083176,
335
+ 34.626965702974054,
336
+ 9.144020124847328,
337
+ 84.83352703661873,
338
+ 30.00431252390878,
339
+ 57.90072859438916,
340
+ 85.08031680993007,
341
+ -51.65908002767094,
342
+ -39.021653890489404,
343
+ -44.40585873624194,
344
+ 60.18370016305159,
345
+ 87.8427834734688,
346
+ 85.46914135213747,
347
+ 91.95126816001013,
348
+ -71.52945051996913,
349
+ -92.6781186850232,
350
+ 51.24324737252195,
351
+ 87.4071023140425,
352
+ 143.33163700976513,
353
+ 59.59649815264238,
354
+ 67.47989524602491,
355
+ 26.09265425121593
356
+ ],
357
+ "episode_lengths": [
358
+ 99,
359
+ 109,
360
+ 113,
361
+ 120,
362
+ 87,
363
+ 107,
364
+ 74,
365
+ 108,
366
+ 88,
367
+ 82,
368
+ 90,
369
+ 158,
370
+ 74,
371
+ 66,
372
+ 99,
373
+ 113,
374
+ 91,
375
+ 79,
376
+ 118,
377
+ 85,
378
+ 89,
379
+ 77,
380
+ 103,
381
+ 96,
382
+ 109,
383
+ 77,
384
+ 105,
385
+ 73,
386
+ 141,
387
+ 128,
388
+ 66,
389
+ 122,
390
+ 95,
391
+ 75,
392
+ 82,
393
+ 86,
394
+ 118,
395
+ 121,
396
+ 110,
397
+ 89,
398
+ 89,
399
+ 94,
400
+ 70,
401
+ 100,
402
+ 74,
403
+ 106,
404
+ 111,
405
+ 71,
406
+ 116,
407
+ 105,
408
+ 66,
409
+ 68,
410
+ 93,
411
+ 125,
412
+ 70,
413
+ 89,
414
+ 102,
415
+ 89,
416
+ 125,
417
+ 78,
418
+ 113,
419
+ 97,
420
+ 80,
421
+ 71,
422
+ 154,
423
+ 93,
424
+ 66,
425
+ 113,
426
+ 93,
427
+ 106,
428
+ 94,
429
+ 129,
430
+ 70,
431
+ 74,
432
+ 83,
433
+ 75,
434
+ 83,
435
+ 68,
436
+ 113,
437
+ 133,
438
+ 68,
439
+ 69,
440
+ 66,
441
+ 82,
442
+ 92,
443
+ 144,
444
+ 101,
445
+ 72,
446
+ 99,
447
+ 129,
448
+ 72,
449
+ 133,
450
+ 127,
451
+ 75,
452
+ 136,
453
+ 113,
454
+ 77,
455
+ 106,
456
+ 106,
457
+ 148,
458
+ 111,
459
+ 71,
460
+ 131,
461
+ 124,
462
+ 72,
463
+ 139,
464
+ 120,
465
+ 132,
466
+ 142,
467
+ 72,
468
+ 88,
469
+ 153,
470
+ 134,
471
+ 108,
472
+ 91,
473
+ 192,
474
+ 81,
475
+ 146,
476
+ 84,
477
+ 79,
478
+ 97,
479
+ 159,
480
+ 103,
481
+ 101,
482
+ 133,
483
+ 174,
484
+ 177,
485
+ 134,
486
+ 101,
487
+ 137,
488
+ 77,
489
+ 87,
490
+ 79,
491
+ 141,
492
+ 122,
493
+ 97,
494
+ 85,
495
+ 158,
496
+ 117,
497
+ 110,
498
+ 173,
499
+ 108,
500
+ 77,
501
+ 100,
502
+ 72,
503
+ 96,
504
+ 77,
505
+ 73,
506
+ 72,
507
+ 109,
508
+ 178,
509
+ 107,
510
+ 116,
511
+ 94,
512
+ 112,
513
+ 135,
514
+ 151,
515
+ 171,
516
+ 92,
517
+ 131,
518
+ 103,
519
+ 157,
520
+ 163,
521
+ 94,
522
+ 71,
523
+ 116,
524
+ 116,
525
+ 234,
526
+ 100,
527
+ 130,
528
+ 162,
529
+ 202,
530
+ 142,
531
+ 142,
532
+ 102,
533
+ 99,
534
+ 206,
535
+ 132,
536
+ 175,
537
+ 159,
538
+ 98,
539
+ 170,
540
+ 98,
541
+ 97,
542
+ 82,
543
+ 172,
544
+ 157,
545
+ 130,
546
+ 129,
547
+ 222,
548
+ 245,
549
+ 107,
550
+ 116,
551
+ 153,
552
+ 95,
553
+ 76,
554
+ 178,
555
+ 158,
556
+ 186,
557
+ 185,
558
+ 254,
559
+ 313,
560
+ 221,
561
+ 147,
562
+ 177,
563
+ 125,
564
+ 227,
565
+ 203,
566
+ 124,
567
+ 275,
568
+ 111,
569
+ 182,
570
+ 281,
571
+ 215,
572
+ 154,
573
+ 109,
574
+ 233,
575
+ 192,
576
+ 127,
577
+ 211,
578
+ 324,
579
+ 206,
580
+ 112,
581
+ 139,
582
+ 203,
583
+ 226,
584
+ 131,
585
+ 126,
586
+ 159,
587
+ 115,
588
+ 1000,
589
+ 250,
590
+ 192,
591
+ 174,
592
+ 94,
593
+ 90,
594
+ 139,
595
+ 199,
596
+ 137,
597
+ 119,
598
+ 148,
599
+ 120,
600
+ 110,
601
+ 325,
602
+ 173,
603
+ 274,
604
+ 304,
605
+ 281,
606
+ 225,
607
+ 211,
608
+ 222,
609
+ 219,
610
+ 225,
611
+ 343,
612
+ 191,
613
+ 200,
614
+ 137,
615
+ 170,
616
+ 143,
617
+ 1000,
618
+ 452,
619
+ 260,
620
+ 1000,
621
+ 273,
622
+ 212,
623
+ 208,
624
+ 220,
625
+ 210,
626
+ 132,
627
+ 178,
628
+ 185,
629
+ 243,
630
+ 348,
631
+ 1000,
632
+ 205,
633
+ 322,
634
+ 136,
635
+ 1000,
636
+ 208,
637
+ 223,
638
+ 211,
639
+ 1000,
640
+ 1000,
641
+ 205,
642
+ 194,
643
+ 215,
644
+ 376,
645
+ 334,
646
+ 363,
647
+ 1000,
648
+ 1000,
649
+ 656,
650
+ 1000,
651
+ 1000,
652
+ 257,
653
+ 1000,
654
+ 702,
655
+ 1000,
656
+ 573,
657
+ 732,
658
+ 330,
659
+ 1000,
660
+ 314,
661
+ 1000,
662
+ 445,
663
+ 523,
664
+ 377,
665
+ 1000,
666
+ 690,
667
+ 1000,
668
+ 1000,
669
+ 1000,
670
+ 1000,
671
+ 1000,
672
+ 1000,
673
+ 479,
674
+ 1000,
675
+ 1000,
676
+ 1000,
677
+ 308,
678
+ 1000,
679
+ 1000,
680
+ 1000,
681
+ 289,
682
+ 1000,
683
+ 1000,
684
+ 1000,
685
+ 1000,
686
+ 1000,
687
+ 674,
688
+ 1000,
689
+ 1000,
690
+ 1000,
691
+ 1000,
692
+ 1000,
693
+ 1000,
694
+ 1000,
695
+ 1000,
696
+ 568,
697
+ 336,
698
+ 510,
699
+ 1000,
700
+ 1000,
701
+ 1000,
702
+ 1000,
703
+ 426,
704
+ 450,
705
+ 1000,
706
+ 1000,
707
+ 1000,
708
+ 1000,
709
+ 1000,
710
+ 1000
711
+ ],
712
+ "training_losses": {
713
+ "actor": [
714
+ -0.01979029180802172,
715
+ -0.01952682618139079,
716
+ -0.023040959014906548,
717
+ -0.021843695227289572,
718
+ -0.01714207920304034,
719
+ -0.014676372746180277,
720
+ -0.017646560751018114,
721
+ -0.0215543580125086,
722
+ -0.014346814270538744,
723
+ -0.021339075115974993,
724
+ -0.016831182496389374,
725
+ -0.022772795979108196,
726
+ -0.023433746107912157,
727
+ -0.018633067615155596,
728
+ -0.01597626227157889,
729
+ -0.01767163053591503,
730
+ -0.01928800484893145,
731
+ -0.02197402838646667,
732
+ -0.018846331498934887,
733
+ -0.018185104410804342,
734
+ -0.017785934993298724,
735
+ -0.017179835092974827,
736
+ -0.018634800864674617,
737
+ -0.016147430505952798,
738
+ -0.020907533988065552,
739
+ -0.014310308593849186,
740
+ -0.01923305022501154,
741
+ -0.018182744635851122,
742
+ -0.016129654854012188,
743
+ -0.018532538808358368,
744
+ -0.015569949602650013,
745
+ -0.017233232538274024,
746
+ -0.01921561570270569,
747
+ -0.014850864667096175,
748
+ -0.013190674544603098,
749
+ -0.011923887774173636,
750
+ -0.017301808853517286,
751
+ -0.01568447455065325,
752
+ -0.012880522626801394,
753
+ -0.016250163862423506,
754
+ -0.01898401835205732,
755
+ -0.016680118802469224,
756
+ -0.015046855827677064,
757
+ -0.013486761541571468,
758
+ -0.014262968943512533,
759
+ -0.012893502258521039,
760
+ -0.014304315525805578,
761
+ -0.012067950345226564
762
+ ],
763
+ "critic": [
764
+ 1680.4626302719116,
765
+ 1030.951669216156,
766
+ 1021.7877514362335,
767
+ 1272.4353795051575,
768
+ 1158.8028721809387,
769
+ 727.7568366527557,
770
+ 901.111081123352,
771
+ 543.9186744689941,
772
+ 530.4165495634079,
773
+ 489.02578616142273,
774
+ 528.2152299880981,
775
+ 495.15413546562195,
776
+ 269.65131109952927,
777
+ 315.82098948955536,
778
+ 372.054971575737,
779
+ 369.43128794431686,
780
+ 209.46067798137665,
781
+ 245.78451192378998,
782
+ 218.96768248081207,
783
+ 120.53185752034187,
784
+ 160.47752958536148,
785
+ 162.42764976620674,
786
+ 156.59158584475517,
787
+ 67.61967559158802,
788
+ 56.876462534070015,
789
+ 80.85796889662743,
790
+ 218.38137596845627,
791
+ 178.75969290733337,
792
+ 131.04233753681183,
793
+ 98.44907432794571,
794
+ 30.182169884443283,
795
+ 46.721587389707565,
796
+ 103.00322636961937,
797
+ 28.124917030334473,
798
+ 90.61402375996113,
799
+ 33.96754629909992,
800
+ 53.8245303183794,
801
+ 24.37088042497635,
802
+ 115.26408568024635,
803
+ 29.321963973343372,
804
+ 26.280027896165848,
805
+ 19.028836239129305,
806
+ 15.724595122039318,
807
+ 168.65730379521847,
808
+ 17.816705368459225,
809
+ 117.39989982545376,
810
+ 25.387532763183117,
811
+ 7.2297560181468725
812
+ ],
813
+ "entropy": [
814
+ 1.384157688356936,
815
+ 1.3705531377345324,
816
+ 1.3620696673169732,
817
+ 1.34917978849262,
818
+ 1.33616404235363,
819
+ 1.333935335278511,
820
+ 1.3117941915988922,
821
+ 1.2774329232051969,
822
+ 1.2390660690143704,
823
+ 1.2097994657233357,
824
+ 1.1650842120870948,
825
+ 1.163279932923615,
826
+ 1.1738641979172826,
827
+ 1.2158568240702152,
828
+ 1.114059129729867,
829
+ 1.124722182750702,
830
+ 1.129445598460734,
831
+ 1.2377208853140473,
832
+ 1.1542788790538907,
833
+ 1.2222738303244114,
834
+ 1.1691155284643173,
835
+ 1.165976474992931,
836
+ 1.1471045780926943,
837
+ 1.1464839773252606,
838
+ 1.1270105578005314,
839
+ 0.9089711038395762,
840
+ 1.1209891475737095,
841
+ 1.0858190059661865,
842
+ 1.0796098113059998,
843
+ 1.1417838828638196,
844
+ 1.1433162745088339,
845
+ 1.0739809516817331,
846
+ 1.085077359341085,
847
+ 1.0622714115306735,
848
+ 1.0335699897259474,
849
+ 1.0388436089269817,
850
+ 1.064063385128975,
851
+ 1.034223263617605,
852
+ 1.0418634568341076,
853
+ 0.9904597336426377,
854
+ 0.895952848251909,
855
+ 1.0142980953678489,
856
+ 1.0475381733849645,
857
+ 0.9823213117197156,
858
+ 0.9869043030776083,
859
+ 0.9921243949793279,
860
+ 0.8546506622806191,
861
+ 0.9553067521192133
862
+ ]
863
+ },
864
+ "total_timesteps": 100000,
865
+ "total_episodes": 353
866
+ }
test_model.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gymnasium as gym
3
+ import torch.nn as nn
4
+
5
+ class Actor(nn.Module):
6
+ def __init__(self, state_dim, action_dim, hidden_size=64):
7
+ super().__init__()
8
+ self.network = nn.Sequential(
9
+ nn.Linear(state_dim, hidden_size),
10
+ nn.Tanh(),
11
+ nn.Linear(hidden_size, hidden_size),
12
+ nn.Tanh(),
13
+ nn.Linear(hidden_size, action_dim)
14
+ )
15
+
16
+ def forward(self, x):
17
+ return self.network(x)
18
+
19
+ def test_model():
20
+ # Load the model
21
+ checkpoint = torch.load("model.pt", map_location='cpu')
22
+ actor = Actor(state_dim=8, action_dim=4, hidden_size=checkpoint['config']['hidden_size'])
23
+ actor.load_state_dict(checkpoint['actor_state_dict'])
24
+ actor.eval()
25
+
26
+ # Test the agent
27
+ env = gym.make("LunarLander-v2", render_mode="human")
28
+ state, _ = env.reset()
29
+ total_reward = 0
30
+
31
+ for _ in range(1000):
32
+ with torch.no_grad():
33
+ state_tensor = torch.FloatTensor(state).unsqueeze(0)
34
+ logits = actor(state_tensor)
35
+ action = torch.argmax(logits, dim=-1).item()
36
+
37
+ state, reward, terminated, truncated, _ = env.step(action)
38
+ total_reward += reward
39
+
40
+ if terminated or truncated:
41
+ break
42
+
43
+ env.close()
44
+ print(f"Total reward: {total_reward:.2f}")
45
+
46
+ if __name__ == "__main__":
47
+ test_model()