# Tetris-RL / callbacks.py
# Last commit: 129e0fe "Update callbacks.py" by BaljinderH (verified)
import os
from stable_baselines3.common.callbacks import BaseCallback
import imageio
import numpy as np
class SaveFramesCallback(BaseCallback):
    """Capture rendered frames during training and write them out as a GIF.

    A frame is grabbed from ``self.training_env`` each time the global
    timestep counter crosses a multiple of ``save_freq``. Frames are kept in
    memory and written to ``<save_path>/<filename>`` when training ends.

    :param save_freq: Capture interval in environment timesteps; must be > 0.
    :param save_path: Directory that receives the GIF; created if missing.
    :param verbose: Verbosity level; ``> 0`` prints a message for each
        captured frame and when the GIF is written.
    :param fps: Frame rate passed to ``imageio.mimsave`` (default 10).
    :param filename: Name of the GIF file inside ``save_path``
        (default ``"training.gif"``).
    :raises ValueError: If ``save_freq`` is not positive.
    """

    def __init__(self, save_freq, save_path, verbose=0, *, fps=10, filename="training.gif"):
        super().__init__(verbose)
        if save_freq <= 0:
            raise ValueError(f"save_freq must be a positive integer, got {save_freq!r}")
        self.save_freq = save_freq
        self.save_path = save_path
        self.fps = fps
        self.filename = filename
        self.frames = []
        # Index (num_timesteps // save_freq) of the last interval that was
        # handled; a new frame is captured only when this index increases.
        self._last_capture_bucket = 0
        os.makedirs(self.save_path, exist_ok=True)

    def _on_training_start(self) -> None:
        # Align with the counter's starting value so a resumed run
        # (num_timesteps already > 0) does not capture a frame immediately.
        self._last_capture_bucket = self.num_timesteps // self.save_freq

    def _on_step(self) -> bool:
        # num_timesteps grows by n_envs per call, so with several parallel
        # envs it can jump over exact multiples of save_freq. Comparing the
        # interval index captures once per crossed multiple instead of
        # relying on `num_timesteps % save_freq == 0`.
        bucket = self.num_timesteps // self.save_freq
        if bucket <= self._last_capture_bucket:
            return True
        self._last_capture_bucket = bucket

        frame = self.training_env.render(mode='rgb_array')
        # VecEnv.render() can return None (e.g. when the env was not created
        # with render_mode="rgb_array"); storing None would make the GIF
        # writer fail at the end of training, so skip it.
        if frame is None:
            return True

        self.frames.append(frame)
        if self.verbose > 0:
            print(f"Saved frame at timestep {self.num_timesteps}")
        # Returning True tells Stable-Baselines3 to keep training.
        return True

    def _on_training_end(self) -> None:
        # Nothing was captured (e.g. training shorter than save_freq).
        if not self.frames:
            return
        gif_path = os.path.join(self.save_path, self.filename)
        imageio.mimsave(gif_path, self.frames, fps=self.fps)
        if self.verbose > 0:
            print(f"Saved training GIF to {gif_path}")