File size: 5,442 Bytes
c64c726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import argparse
from pathlib import Path

from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
import torch

from agent import Agent
from envs import WorldModelEnv
from game import Game, PlayEnv
from utils import get_path_agent_ckpt


# Register a custom OmegaConf resolver named "eval" that is backed by Python's
# built-in eval, so config values can use ${eval:...} interpolations.
# NOTE(review): eval runs arbitrary expressions; only load trusted config files.
OmegaConf.register_new_resolver("eval", eval)


def parse_args() -> argparse.Namespace:
    """Parse the play-mode command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding one attribute per declared option.
    """
    cli = argparse.ArgumentParser()

    # Declared in a single ordered table so the --help output keeps its order.
    option_table = (
        (("--checkpoint-dir",), {"type": str, "help": "Path to training output directory"}),
        (("--epoch",), {"type": int, "default": -1, "help": "Epoch to load, -1 for latest checkpoint"}),
        (("-r", "--record"), {"action": "store_true", "help": "Record episodes in PlayEnv."}),
        (("--store-denoising-trajectory",), {"action": "store_true", "help": "Save denoising steps in info."}),
        (("--store-original-obs",), {"action": "store_true", "help": "Save original obs (pre resizing) in info."}),
        (("--mouse-multiplier",), {"type": int, "default": 10, "help": "Multiplication factor for the mouse movement."}),
        (("--size-multiplier",), {"type": int, "default": 2, "help": "Multiplication factor for the screen size."}),
        (("--compile",), {"action": "store_true", "help": "Turn on model compilation."}),
        (("--fps",), {"type": int, "default": 15, "help": "Frame rate."}),
        (("--no-header",), {"action": "store_true"}),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    return cli.parse_args()


def check_args(args: argparse.Namespace) -> bool:
    """Sanity-check the parsed options and report whether to continue.

    Prints a warning when a ``--store*`` option is given without ``--record``,
    since those options are ignored outside recording mode.

    Args:
        args: Parsed command-line options; reads ``record``,
            ``store_denoising_trajectory`` and ``store_original_obs``.

    Returns:
        Always ``True``: the only check performed is a non-fatal warning.
    """
    store_requested = args.store_denoising_trajectory or args.store_original_obs
    if not args.record and store_requested:
        print("Warning: not in recording mode, ignoring --store* options")
    return True


def prepare_play_mode(cfg: DictConfig, args: argparse.Namespace) -> PlayEnv:
    """Build the interactive ``PlayEnv`` from a trained agent checkpoint.

    Overrides the ``agent``, ``env`` and ``world_model_env`` entries of ``cfg``
    from the training config, picks a torch device, loads the agent weights,
    wraps the agent's models in a ``WorldModelEnv`` and returns a ``PlayEnv``.

    Args:
        cfg: Composed config; mutated in place (``agent``, ``env`` and
            ``world_model_env`` are overwritten).
        args: Parsed command-line options; reads ``compile``, ``record``,
            ``store_denoising_trajectory`` and ``store_original_obs``.
            NOTE(review): ``checkpoint_dir`` and ``epoch`` are currently not
            used here; all paths below are hard-coded.

    Returns:
        The ``PlayEnv`` wrapping the loaded agent and the world model env.

    Raises:
        FileNotFoundError: If the training config is missing, or if neither the
            hard-coded checkpoint nor any fallback ``*.pt`` file can be found.
        AssertionError: If ``cfg.env.train.id`` is not ``"csgo"``.
    """
    # Load training config
    config_path = Path("/home/alienware3/Documents/diamond/config/trainer.yaml")
    if not config_path.exists():
        raise FileNotFoundError(f"Training config not found: {config_path}")

    training_cfg = OmegaConf.load(config_path)

    # Override config
    # NOTE(review): this reads entries of the raw ``defaults`` list by position;
    # presumably each entry must resolve to a full sub-config for the
    # instantiate() calls below to work -- confirm against trainer.yaml.
    cfg.agent = training_cfg.defaults[2].agent
    cfg.env = training_cfg.defaults[1].env
    cfg.world_model_env = training_cfg.defaults[3].world_model_env

    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print("----------------------------------------------------------------------")
    print(f"Using {device} for rendering.")
    if not torch.cuda.is_available():
        print("If you have a CUDA GPU available and it is not being used, please follow the instructions at https://pytorch.org/get-started/locally/ to reinstall torch with CUDA support and try again.")
    print("----------------------------------------------------------------------")

    # Get model checkpoint path
    # Must be a Path (not a str) so the "/" join in the fallback below works.
    ckpt_dir = Path("checkpoints")

    path_ckpt = Path("/home/alienware3/Documents/diamond/agent_epoch_00206.pt")

    if not path_ckpt.exists():
        # Fall back to the last checkpoint (in sorted filename order) found
        # under <ckpt_dir>/agent_versions.
        agent_versions_dir = ckpt_dir / "agent_versions"
        if not agent_versions_dir.exists():
            raise FileNotFoundError(f"Agent checkpoints directory not found: {agent_versions_dir}")
        available_ckpts = sorted(agent_versions_dir.glob("*.pt"))
        if not available_ckpts:
            raise FileNotFoundError("No agent checkpoint files found")
        path_ckpt = available_ckpts[-1]

    spawn_dir = Path("/home/alienware3/Documents/diamond/csgo/spawn")
    assert cfg.env.train.id == "csgo"
    num_actions = cfg.env.num_actions

    # Models
    agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
    agent.load(path_ckpt)

    # World model environment
    # Use the larger of the denoiser's and (if present) the upsampler's
    # num_steps_conditioning.
    sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
    if agent.upsampler is not None:
        sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
    wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
    wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model, spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)

    # Compilation is opt-in (--compile) and only attempted on CUDA devices.
    if device.type == "cuda" and args.compile:
        print("Compiling models...")
        wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
        if agent.upsampler is not None:
            wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")

    play_env = PlayEnv(
        agent,
        wm_env,
        args.record,
        args.store_denoising_trajectory,
        args.store_original_obs,
    )

    return play_env


@torch.no_grad()
def main():
    """Script entry point: parse options, compose the config and run the game.

    Runs with gradient tracking disabled. Returns early, without starting the
    game, if ``check_args`` reports a falsy result.
    """
    cli_args = parse_args()
    if not check_args(cli_args):
        return

    with initialize(version_base="1.3", config_path="../config"):
        cfg = compose(config_name="trainer")

    # Window size: an int size means a square frame, otherwise it unpacks to
    # (h, w). This is read before prepare_play_mode overrides cfg.env.
    if isinstance(cfg.env.train.size, int):
        h = w = cfg.env.train.size
    else:
        h, w = cfg.env.train.size
    scale = cli_args.size_multiplier
    window_size = (h * scale, w * scale)

    play_env = prepare_play_mode(cfg, cli_args)
    game = Game(
        play_env,
        window_size,
        cli_args.mouse_multiplier,
        fps=cli_args.fps,
        verbose=not cli_args.no_header,
    )
    game.run()


# Start play mode only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()