Spaces:
Build error
Build error
| import os | |
| from stable_baselines3.common.callbacks import BaseCallback | |
| import imageio | |
| import numpy as np | |
class SaveFramesCallback(BaseCallback):
    """Capture rendered frames during training and write them out as a GIF.

    Roughly every ``save_freq`` environment timesteps, the training
    environment is rendered with ``mode='rgb_array'`` and the frame is kept in
    memory. When training ends, all collected frames are written to
    ``os.path.join(save_path, filename)`` with ``imageio.mimsave``.

    Note: every captured frame stays in memory until training ends, so a
    small ``save_freq`` on a long run can use a lot of RAM.

    :param save_freq: Capture one frame each time the timestep counter
        crosses a new multiple of this value. Must be positive.
    :param save_path: Directory for the output GIF; created if missing.
    :param verbose: Verbosity level; values ``> 0`` print progress messages.
    :param fps: Frame rate passed to ``imageio.mimsave`` (default 10).
    :param filename: Name of the GIF written inside ``save_path``
        (default ``"training.gif"``).
    :raises ValueError: If ``save_freq`` is not positive.
    """

    def __init__(self, save_freq, save_path, verbose=0, fps=10, filename="training.gif"):
        super().__init__(verbose)
        if save_freq <= 0:
            # A zero frequency previously surfaced as a ZeroDivisionError on
            # the first step; fail early with a clear message instead.
            raise ValueError(f"save_freq must be a positive integer, got {save_freq!r}")
        self.save_freq = save_freq
        self.save_path = save_path
        self.fps = fps
        self.filename = filename
        self.frames = []
        # Index of the last save_freq interval a frame was captured for
        # (num_timesteps // save_freq). Re-synced at each training start.
        self._last_capture_bucket = 0
        os.makedirs(self.save_path, exist_ok=True)

    def _on_training_start(self) -> None:
        # Sync with the current counter so that both a fresh run
        # (num_timesteps == 0) and a continued run
        # (reset_num_timesteps=False) capture on the next crossed multiple.
        self._last_capture_bucket = self.num_timesteps // self.save_freq

    def _on_step(self) -> bool:
        # With n_envs > 1, num_timesteps grows by n_envs per call, so an exact
        # `num_timesteps % save_freq == 0` test can skip multiples. Comparing
        # interval indices captures once per crossed multiple instead. With a
        # single env this matches the exact-multiple behaviour.
        bucket = self.num_timesteps // self.save_freq
        if bucket > self._last_capture_bucket:
            self._last_capture_bucket = bucket
            frame = self.training_env.render(mode='rgb_array')
            if frame is None:
                # render() may return None (e.g. when the env was not created
                # with a compatible render mode). A None entry would make
                # imageio.mimsave fail at the end of training, so skip it.
                return True
            self.frames.append(frame)
            if self.verbose > 0:
                print(f"Saved frame at timestep {self.num_timesteps}")
        return True

    def _on_training_end(self) -> None:
        """Write all collected frames to the GIF file, if any were captured."""
        if not self.frames:
            return
        gif_path = os.path.join(self.save_path, self.filename)
        imageio.mimsave(gif_path, self.frames, fps=self.fps)
        if self.verbose > 0:
            print(f"Saved training GIF to {gif_path}")