File size: 1,005 Bytes
b100cf9
 
 
 
 
 
129e0fe
b100cf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
from stable_baselines3.common.callbacks import BaseCallback
import imageio
import numpy as np

class SaveFramesCallback(BaseCallback):
    """Collect rendered frames during training and write them out as a GIF.

    Each time the timestep counter crosses a multiple of ``save_freq`` the
    training environment is rendered and the frame is kept in memory. When
    training ends, the collected frames are written to
    ``<save_path>/training.gif``.

    Args:
        save_freq: Number of timesteps between captured frames. Must be
            positive.
        save_path: Directory that receives ``training.gif``; it is created
            if it does not exist.
        verbose: Values above 0 print a message for every captured frame and
            for the final GIF.

    Raises:
        ValueError: If ``save_freq`` is not positive.
    """

    def __init__(self, save_freq, save_path, verbose=0):
        # Reject a non-positive frequency up front; otherwise the first step
        # would fail with a ZeroDivisionError (or never capture anything).
        if save_freq <= 0:
            raise ValueError(f"save_freq must be positive, got {save_freq!r}")
        super().__init__(verbose)
        self.save_freq = save_freq
        self.save_path = save_path
        self.frames = []
        # Index of the last ``save_freq``-sized interval a capture was made for.
        self._last_interval = 0
        os.makedirs(self.save_path, exist_ok=True)

    def _on_training_start(self) -> None:
        """Align the capture interval with the model's current timestep count."""
        # The model may already have been trained (counter not starting at 0);
        # start from its current interval so the first step does not trigger
        # a spurious capture.
        self._last_interval = self.model.num_timesteps // self.save_freq

    def _on_step(self) -> bool:
        """Capture a frame when a multiple of ``save_freq`` has been crossed.

        Returns:
            Always ``True``, so training is never interrupted by this callback.
        """
        # Compare interval indices instead of testing ``% save_freq == 0``:
        # with several parallel environments the counter advances by more than
        # one per call and can step over exact multiples. With a single
        # environment both tests fire on the same timesteps.
        interval = self.num_timesteps // self.save_freq
        if interval <= self._last_interval:
            return True
        self._last_interval = interval

        frame = self.training_env.render(mode='rgb_array')
        if frame is None:
            # No image was produced (e.g. the environment was not created with
            # an rgb_array render mode); a None entry would break GIF writing.
            if self.verbose > 0:
                print(f"No frame rendered at timestep {self.num_timesteps}; skipping")
            return True

        self.frames.append(frame)
        if self.verbose > 0:
            print(f"Saved frame at timestep {self.num_timesteps}")
        return True

    def _on_training_end(self) -> None:
        """Write the collected frames to ``training.gif`` if any were captured."""
        if not self.frames:
            return
        gif_path = os.path.join(self.save_path, "training.gif")
        imageio.mimsave(gif_path, self.frames, fps=10)
        if self.verbose > 0:
            print(f"Saved training GIF to {gif_path}")