ledmands
commited on
Commit
·
47aa47a
1
Parent(s):
91b9fbd
Removed unecessary comments in plot_improvement.py
Browse files- plot_improvement.py +3 -9
plot_improvement.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
-
import argparse
|
| 2 |
import numpy as np
|
| 3 |
import os
|
| 4 |
from matplotlib import pyplot as plt
|
| 5 |
|
| 6 |
def calc_stats(filepath):
|
| 7 |
-
# load the numpy file
|
| 8 |
data = np.load(filepath)["results"]
|
| 9 |
# sort the arrays and delete the first and last elements
|
| 10 |
data = np.sort(data, axis=1)
|
|
@@ -19,11 +18,6 @@ def calc_stats(filepath):
|
|
| 19 |
# parser.add_argument("-s", "--save", help="Specify whether to save the chart.", action="store_const", const=True)
|
| 20 |
# args = parser.parse_args()
|
| 21 |
|
| 22 |
-
# Get the file paths and store in list.
|
| 23 |
-
# For each file path, I want to calculate the mean reward. This would be the mean reward for the training run over all evaluations.
|
| 24 |
-
# For each file path, append the mean reward to an averages list
|
| 25 |
-
# Plot the averages!
|
| 26 |
-
|
| 27 |
filepaths = []
|
| 28 |
for d in os.listdir("agents/"):
|
| 29 |
if "dqn_v2" in d:
|
|
@@ -40,8 +34,8 @@ for path in filepaths:
|
|
| 40 |
runs = []
|
| 41 |
for i in range(len(filepaths)):
|
| 42 |
runs.append(i + 1)
|
| 43 |
-
plt.xlabel("
|
| 44 |
-
plt.ylabel("
|
| 45 |
plt.bar(runs, means)
|
| 46 |
plt.bar(runs, stds)
|
| 47 |
plt.legend(["Mean evaluation score", "Standard deviation"])
|
|
|
|
| 1 |
+
# import argparse
|
| 2 |
import numpy as np
|
| 3 |
import os
|
| 4 |
from matplotlib import pyplot as plt
|
| 5 |
|
| 6 |
def calc_stats(filepath):
|
|
|
|
| 7 |
data = np.load(filepath)["results"]
|
| 8 |
# sort the arrays and delete the first and last elements
|
| 9 |
data = np.sort(data, axis=1)
|
|
|
|
| 18 |
# parser.add_argument("-s", "--save", help="Specify whether to save the chart.", action="store_const", const=True)
|
| 19 |
# args = parser.parse_args()
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
filepaths = []
|
| 22 |
for d in os.listdir("agents/"):
|
| 23 |
if "dqn_v2" in d:
|
|
|
|
| 34 |
runs = []
|
| 35 |
for i in range(len(filepaths)):
|
| 36 |
runs.append(i + 1)
|
| 37 |
+
plt.xlabel("Training Run")
|
| 38 |
+
plt.ylabel("Score")
|
| 39 |
plt.bar(runs, means)
|
| 40 |
plt.bar(runs, stds)
|
| 41 |
plt.legend(["Mean evaluation score", "Standard deviation"])
|