File size: 3,141 Bytes
2ad4d00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
import os
from PIL import Image
import logging
import json
import numpy as np
import csv
import gymnasium
from vizdoom import gymnasium_wrapper # This import is needed to register the env

DATASET_DIR = "gamelogs"
FRAMES_DIR = os.path.join(DATASET_DIR, "frames")
os.makedirs(FRAMES_DIR, exist_ok=True)


class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NpEncoder, self).default(obj)

class GameNGenCallback(BaseCallback):
    def __init__(self, verbose: bool, save_path: str):
        super(GameNGenCallback, self).__init__(verbose)
        self.save_path = save_path
        self.frame_log = open(os.path.join(self.save_path, "metadata.csv"), mode="w", newline="")
        self.csv_writer = csv.writer(self.frame_log)
        # CSV Header
        self.csv_writer.writerow(["frame_id", "action"])
    
    def _on_step(self) -> bool:
        frame_id = self.n_calls
        key = f"{frame_id:09d}"

        try:
            obs_dict = self.locals["new_obs"]
            # The observation from the callback is in Channels-First format (C, H, W)
            frame_data = obs_dict['screen'][0]
            action = self.locals["actions"][0]

            # --- DEFINITIVE FIX ---
            # Check if the frame is in the expected Channels-First format (C, H, W).
            # A valid RGB image will have 3 channels in its first dimension.
            if frame_data.ndim == 3 and frame_data.shape[0] == 3:
                # Pillow's fromarray function needs the image in Channels-Last format (H, W, C).
                # We must transpose the axes from (C, H, W) to (H, W, C).
                transposed_frame = np.transpose(frame_data, (1, 2, 0))
                image = Image.fromarray(transposed_frame)
                image.save(os.path.join(FRAMES_DIR, f"frame_{key}.png"))
                
                json_action = json.dumps(action, cls=NpEncoder)
                self.csv_writer.writerow([key, json_action])
            else:
                # This will now correctly catch the junk frames from terminal states.
                logging.warning(f"Skipping corrupted frame {key} with invalid shape: {frame_data.shape}")

        except Exception as e:
            # This will now only catch truly unexpected errors.
            logging.error(f"Could not process or save frame {key} due to an unexpected error: {e}")

        return True

    def _on_training_end(self) -> None:
        self.frame_log.close()

# --- Main script ---
logging.basicConfig(level=logging.INFO)

# Create the VizDoom environment. No wrappers are needed.
env = gymnasium.make("VizdoomHealthGatheringSupreme-v0")

callback = GameNGenCallback(verbose=True, save_path=DATASET_DIR)

model = PPO(
    "MultiInputPolicy",
    env,  
    verbose=1,
)

model.learn(total_timesteps=2_000_000, callback=callback)

env.close()