Spaces:
Sleeping
Sleeping
| from stable_baselines3 import PPO | |
| from stable_baselines3.common.callbacks import BaseCallback | |
| import os | |
| from PIL import Image | |
| import logging | |
| import json | |
| import numpy as np | |
| import csv | |
| import gymnasium | |
| from vizdoom import gymnasium_wrapper # This import is needed to register the env | |
# Root folder that receives everything the recorder writes.
DATASET_DIR = "gamelogs"

# Individual PNG frames are kept in their own sub-folder of the dataset root.
FRAMES_DIR = os.path.join(DATASET_DIR, "frames")

# Creating the leaf directory also creates DATASET_DIR when it is missing;
# exist_ok keeps repeated runs from failing on an existing folder.
os.makedirs(FRAMES_DIR, exist_ok=True)
class NpEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars and arrays.

    The standard encoder rejects NumPy types that are not subclasses of a
    built-in Python type. This encoder converts them to their plain Python
    equivalents so they can be passed straight to ``json.dumps(..., cls=NpEncoder)``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Args:
            obj: A value the base encoder could not serialize on its own.

        Returns:
            ``bool`` for NumPy booleans, ``int`` for NumPy integers, ``float``
            for NumPy floats, and a (nested) ``list`` for NumPy arrays.

        Raises:
            TypeError: Raised by the base class for any other unsupported type.
        """
        # np.bool_ is neither an np.integer nor a Python bool subclass, so it
        # needs its own branch or json.dumps raises TypeError on it.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
class GameNGenCallback(BaseCallback):
    """Training callback that records one (frame, action) pair per step.

    On every step the ``screen`` observation of the first environment is
    written as a PNG under ``<save_path>/frames`` and the matching action is
    appended, JSON-encoded, to ``<save_path>/metadata.csv``.

    Args:
        verbose: Verbosity level forwarded to ``BaseCallback`` (``True`` /
            ``False`` are accepted because ``bool`` is an ``int``).
        save_path: Directory that receives ``metadata.csv`` and the
            ``frames`` sub-directory. It is created if it does not exist.
    """

    def __init__(self, verbose: int, save_path: str):
        super().__init__(verbose)
        self.save_path = save_path
        # Keep frames next to the metadata file instead of relying on a
        # module-level directory, so a custom save_path is fully honored.
        self.frames_dir = os.path.join(self.save_path, "frames")
        os.makedirs(self.frames_dir, exist_ok=True)
        # newline="" is what the csv module requires to control line endings.
        self.frame_log = open(
            os.path.join(self.save_path, "metadata.csv"), mode="w", newline=""
        )
        self.csv_writer = csv.writer(self.frame_log)
        # CSV header
        self.csv_writer.writerow(["frame_id", "action"])

    def _save_frame(self, key, frame_data, action) -> None:
        """Write one channels-first frame as PNG and log its action row."""
        # Pillow expects (H, W, C); the frame arrives as (C, H, W).
        image = Image.fromarray(np.transpose(frame_data, (1, 2, 0)))
        image.save(os.path.join(self.frames_dir, f"frame_{key}.png"))
        # The CSV row is written only after the PNG was saved, so every
        # logged frame_id has a file on disk.
        self.csv_writer.writerow([key, json.dumps(action, cls=NpEncoder)])

    def _on_step(self) -> bool:
        """Record the current frame and action; never interrupts training.

        Returns:
            Always ``True`` so that training continues, even when a frame
            could not be saved.
        """
        # n_calls is zero-padded so that file names sort in step order.
        key = f"{self.n_calls:09d}"
        try:
            # Index 0 selects the first environment of the vectorized batch;
            # any additional environments are not recorded.
            frame_data = self.locals["new_obs"]["screen"][0]
            action = self.locals["actions"][0]

            # Only 3-channel, channels-first (C, H, W) frames are valid RGB
            # images; anything else is skipped and reported.
            if frame_data.ndim == 3 and frame_data.shape[0] == 3:
                self._save_frame(key, frame_data, action)
            else:
                logging.warning(
                    "Skipping corrupted frame %s with invalid shape: %s",
                    key,
                    frame_data.shape,
                )
        except Exception as e:
            # Deliberate best-effort: a single bad frame must not abort a
            # long training run, so the error is logged and the step skipped.
            logging.error(
                "Could not process or save frame %s due to an unexpected error: %s",
                key,
                e,
            )
        return True

    def _on_training_end(self) -> None:
        """Flush and close the metadata CSV once training has finished."""
        # close() on an already-closed file is a no-op, so this is safe even
        # if the handle was closed elsewhere.
        self.frame_log.close()
# --- Main script ---
logging.basicConfig(level=logging.INFO)

# Create the VizDoom environment. No wrappers are needed.
env = gymnasium.make("VizdoomHealthGatheringSupreme-v0")
try:
    callback = GameNGenCallback(verbose=True, save_path=DATASET_DIR)
    try:
        model = PPO(
            "MultiInputPolicy",
            env,
            verbose=1,
        )
        model.learn(total_timesteps=2_000_000, callback=callback)
    finally:
        # If learn() is interrupted (e.g. Ctrl-C during the long run) the
        # callback's end-of-training hook may not have run, so close the
        # metadata file here as well; closing twice is harmless.
        callback.frame_log.close()
finally:
    # Always release the game instance, even when training fails.
    env.close()