privateboss commited on
Commit
6797b6d
·
verified ·
1 Parent(s): f9355c8

Create plot_utility_Trainer.py

Browse files
Files changed (1) hide show
  1. plot_utility_Trainer.py +86 -0
plot_utility_Trainer.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import os
4
+ import time
5
+
6
+ def smooth_curve(points, factor=0.9):
7
+
8
+     smoothed_points = []
9
+     if points:
10
+         smoothed_points.append(points[0])
11
+         for i in range(1, len(points)):
12
+             smoothed_points.append(smoothed_points[-1] * factor + points[i] * (1 - factor))
13
+     return smoothed_points
14
+
15
+ def plot_rewards(rewards_history, log_interval, save_dir, filename="rewards_plot.png", show_plot=True):
16
+    
17
+     os.makedirs(save_dir, exist_ok=True)
18
+    
19
+     plt.figure(figsize=(12, 6))
20
+     episodes = [i * log_interval for i in range(1, len(rewards_history) + 1)]
21
+     plt.plot(episodes, rewards_history, label='Average Reward')
22
+     plt.xlabel('Episodes')
23
+     plt.ylabel('Average Reward')
24
+     plt.title('PPO Training Progress (Average Reward per Episode)')
25
+     plt.grid(True)
26
+     plt.legend()
27
+     plt.tight_layout()
28
+    
29
+     save_path = os.path.join(save_dir, filename)
30
+     plt.savefig(save_path)
31
+     print(f"Plot saved to: {os.path.abspath(save_path)}")
32
+    
33
+     if show_plot:
34
+         plt.show()
35
+
36
+ def init_live_plot(save_dir, filename="live_rewards_plot.png"):
37
+
38
+     plt.ion() # Turn on interactive mode
39
+     fig, ax = plt.subplots(figsize=(12, 6))
40
+     line, = ax.plot([], [], label='Smoothed Average Reward')
41
+     ax.set_xlabel('Episodes')
42
+     ax.set_ylabel('Average Reward')
43
+     ax.set_title('Live PPO Training Progress')
44
+     ax.grid(True)
45
+     ax.legend()
46
+     plt.tight_layout()
47
+    
48
+     ax._save_path_final = os.path.join(save_dir, filename)
49
+    
50
+     return fig, ax, line
51
+
52
+ def update_live_plot(fig, ax, line, episodes, smoothed_rewards, current_timestep=None, total_timesteps=None):
53
+     """
54
+     Updates the live plot with new data.
55
+     """
56
+     if not episodes or not smoothed_rewards:
57
+         return
58
+
59
+     line.set_data(episodes, smoothed_rewards)
60
+    
61
+     ax.set_xlim(0, max(episodes) * 1.05 if episodes else 1)
62
+    
63
+     min_y = min(smoothed_rewards) * 0.9 if smoothed_rewards else -1
64
+     max_y = max(smoothed_rewards) * 1.1 if smoothed_rewards else 1
65
+
66
+     if abs(max_y - min_y) < 0.1:
67
+         min_y -= 0.05
68
+         max_y += 0.05
69
+     ax.set_ylim(min_y, max_y)
70
+
71
+     if current_timestep is not None and total_timesteps is not None:
72
+         ax.set_title(f'Live PPO Training Progress (Timestep: {current_timestep:,}/{total_timesteps:,})')
73
+
74
+     fig.canvas.draw()
75
+     fig.canvas.flush_events()
76
+     time.sleep(0.01)
77
+
78
+ def save_live_plot_final(fig, ax):
79
+    
80
+     plt.ioff()
81
+     save_path = getattr(ax, '_save_path_final', None)
82
+     if save_path:
83
+         plt.savefig(save_path)
84
+         print(f"Final live plot saved to: {os.path.abspath(save_path)}")
85
+     plt.close(fig)
86
+     plt.show()