Spaces:
Sleeping
Sleeping
File size: 5,786 Bytes
c64c726 f1594be c64c726 ded2bd6 c64c726 f1594be c64c726 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
from collections import defaultdict, namedtuple
import math
from pathlib import Path
from typing import Any, Dict, List, Tuple
try:
import pygame # type: ignore
except Exception: # Handle ImportError and headless environments gracefully
pygame = None # type: ignore
import torch
from torch import Tensor
from ..agent import Agent
from ..csgo.action_processing import CSGOAction, decode_csgo_action, encode_csgo_action, print_csgo_action
from ..csgo.keymap import CSGO_KEYMAP
from ..data import Dataset, Episode
from ..envs import WorldModelEnv
# Simple (name, env) record; not referenced in the visible code of this module.
NamedEnv = namedtuple("NamedEnv", ["name", "env"])
# One transition as recorded by PlayEnv.step: (obs, act, rew, end, trunc).
OneStepData = namedtuple("OneStepData", ["obs", "act", "rew", "end", "trunc"])
class PlayEnv:
    """Play a world-model environment either with human input or by replaying actions.

    In human mode, the action passed to ``step`` is encoded and played. In replay mode,
    actions are taken from ``wm_env.next_act`` (and ``wm_env.act_buffer`` on the first
    step), and the episode is truncated once the replayed actions are exhausted.
    When ``recording_mode`` is on, every step is buffered and the finished episode is
    saved to a ``Dataset`` under ``dataset/``.
    """

    def __init__(
        self,
        agent: Agent,
        wm_env: WorldModelEnv,
        recording_mode: bool,
        store_denoising_trajectory: bool,
        store_original_obs: bool,
    ) -> None:
        """
        Args:
            agent: Its ``device`` is used when encoding human actions.
            wm_env: World-model environment that is reset and stepped.
            recording_mode: If True, steps are buffered and saved as episodes.
            store_denoising_trajectory: Also record ``env_info["denoising_trajectory"]`` when present.
            store_original_obs: Also record ``env_info["original_obs"]`` when present.
        """
        self.agent = agent
        self.keymap = CSGO_KEYMAP
        self.recording_mode = recording_mode
        self.store_denoising_trajectory = store_denoising_trajectory
        self.store_original_obs = store_original_obs
        self.is_human_player = True
        self.env_id = 0
        self.env_name = "world model"
        self.env = wm_env
        # Set by reset(); buffer/rec_dataset are (re)created by reset_recording().
        self.obs, self.t, self.buffer, self.rec_dataset = (None,) * 4

    def print_controls(self) -> None:
        """Print the key -> action mapping from ``self.keymap``, skipping ``None`` keys."""
        print("\nEnvironment actions:\n")
        for key, action_name in self.keymap.items():
            if key is None:
                continue
            print(f"{self._key_display_name(key)} : {action_name}")

    @staticmethod
    def _key_display_name(key: Any) -> str:
        """Return a printable name for ``key``: pygame's name when available, else ``str(key)``."""
        key_name = str(key)
        if pygame is not None:
            try:
                key_name = pygame.key.name(key)
            except Exception:
                # Deliberate best effort: keep the str() fallback if pygame cannot name this key.
                pass
        return "⎵" if key_name == "space" else key_name

    def next_mode(self) -> bool:
        """Toggle between human control and action replay; always reports a change."""
        self.switch_controller()
        return True

    def next_axis_1(self) -> bool:
        """No-op for this environment; returns False (nothing changed)."""
        return False

    def prev_axis_1(self) -> bool:
        """No-op for this environment; returns False (nothing changed)."""
        return False

    def next_axis_2(self) -> bool:
        """No-op for this environment; returns False (nothing changed)."""
        return False

    def prev_axis_2(self) -> bool:
        """No-op for this environment; returns False (nothing changed)."""
        return False

    def print_env(self) -> None:
        """Print the current environment name."""
        print(f"> Environment: {self.env_name}")

    def str_control(self) -> str:
        """Return a label describing who currently controls the actions."""
        return "human" if self.is_human_player else "replay actions (test dataset)"

    def print_control(self) -> None:
        """Print the current control label."""
        print(f"> Control: {self.str_control()}")

    def switch_controller(self) -> None:
        """Flip between human control and replayed actions, then print the new mode."""
        self.is_human_player = not self.is_human_player
        self.print_control()

    def update_wm_horizon(self, incr: int) -> None:
        """Add ``incr`` to the world-model horizon, clamping it to at least 1."""
        self.env.horizon = max(1, self.env.horizon + incr)

    def reset_recording(self) -> None:
        """Start a fresh step buffer and load the recording dataset from disk.

        The dataset directory is ``dataset/rec_<env_name>_<H|R>``: ``H`` when a human
        is in control, ``R`` when actions are replayed.
        """
        self.buffer = defaultdict(list)
        self.buffer["info"] = defaultdict(list)
        controller_tag = "H" if self.is_human_player else "R"
        rec_dir = Path("dataset") / f"rec_{self.env_name}_{controller_tag}"
        self.rec_dataset = Dataset(rec_dir, None)
        self.rec_dataset.load_from_default_path()

    def reset(self) -> Tuple[Tensor, None]:
        """Reset the environment and step counter (and the recording, if enabled).

        Returns:
            ``(obs, None)`` where ``obs`` is the observation returned by ``env.reset()``.
        """
        self.obs, _ = self.env.reset()
        self.t = 0
        if self.recording_mode:
            self.reset_recording()
        return self.obs, None

    @torch.no_grad()
    def step(self, csgo_action: CSGOAction) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]:
        """Advance the world model by one step.

        Args:
            csgo_action: Action to play when a human is in control. In replay mode it is
                ignored; the replayed action is decoded instead and used for the header.

        Returns:
            ``(next_obs, rew, end, trunc, info)``. ``info`` always contains ``"header"``
            (text lines describing the step) and ``"obs_low_res"`` when the environment
            provides it.
        """
        action, csgo_action = self._select_action(csgo_action)
        next_obs, rew, end, trunc, env_info = self.env.step(action)

        # Replay mode: at t == len(next_act) the last replayed action (next_act[-1])
        # has just been played, so mark the episode as truncated.
        if not self.is_human_player and self.t == self.env.next_act.size(0):
            trunc[0] = 1

        info = {"header": self._build_header(csgo_action)}
        if "obs_low_res" in env_info:
            info["obs_low_res"] = env_info["obs_low_res"]

        if self.recording_mode:
            self._record_step(OneStepData(self.obs, action, rew, end, trunc), env_info)
            if end or trunc:
                self._save_episode()

        self.obs = next_obs
        self.t += 1
        return next_obs, rew, end, trunc, info

    def _select_action(self, csgo_action: CSGOAction) -> Tuple[Tensor, CSGOAction]:
        """Return the encoded action to play and its ``CSGOAction`` form (for display)."""
        if self.is_human_player:
            return encode_csgo_action(csgo_action, device=self.agent.device), csgo_action
        # Replay: step 0 repeats act_buffer[0, -1]; step t > 0 plays next_act[t - 1].
        if self.t > 0:
            action = self.env.next_act[self.t - 1]
        else:
            action = self.env.act_buffer[0, -1].clone()
        return action, decode_csgo_action(action.cpu())

    def _build_header(self, csgo_action: CSGOAction) -> List[List[str]]:
        """Build the text lines returned in ``info["header"]`` for the current step."""
        keys, mouse, clicks = print_csgo_action(csgo_action)
        if self.is_human_player:
            horizon = self.env.horizon
        else:
            # In replay mode the horizon cannot exceed the number of replayed actions.
            horizon = min(self.env.horizon, self.env.next_act.size(0))
        return [
            [
                f"Env : {self.env_name}",
                f"Control : {self.str_control()}",
                f"Timestep: {self.t + 1}",
                f"Horizon : {horizon}",
                "",
                f"Keys : {keys}",
                f"Mouse : {mouse}",
                f"Clicks: {clicks}",
            ],
        ]

    def _record_step(self, data: OneStepData, env_info: Dict[str, Any]) -> None:
        """Append one transition and the requested ``env_info`` extras to the buffer."""
        for k, v in data._asdict().items():
            self.buffer[k].append(v)
        extras = self.buffer["info"]
        if "obs_low_res" in env_info:
            extras["obs_low_res"].append(env_info["obs_low_res"])
        if self.store_denoising_trajectory and "denoising_trajectory" in env_info:
            extras["denoising_trajectory"].append(env_info["denoising_trajectory"])
        if self.store_original_obs and "original_obs" in env_info:
            # NOTE(review): permute(2, 0, 1) assumes original_obs[0] is channels-last;
            # the result gets a leading dim of size 1 so it can be concatenated on dim 0.
            original_obs = torch.tensor(env_info["original_obs"][0]).permute(2, 0, 1).unsqueeze(0).contiguous()
            extras["original_obs"].append(original_obs)

    def _save_episode(self) -> None:
        """Concatenate buffered steps into an ``Episode`` and save it to the recording dataset.

        The buffer itself is not cleared here; ``reset()`` starts a new one through
        ``reset_recording()`` when recording is enabled.
        """
        ep_dict = {k: torch.cat(v, dim=0) for k, v in self.buffer.items() if k != "info"}
        ep_info = {k: torch.cat(v, dim=0) for k, v in self.buffer["info"].items()}
        ep = Episode(**ep_dict, info=ep_info).to("cpu")
        self.rec_dataset.add_episode(ep)
        self.rec_dataset.save_to_default_path()