Spaces:
Sleeping
Sleeping
| """ | |
| Credits: some parts are taken and modified from the file `config.py` from https://github.com/TeaPearce/Counter-Strike_Behavioural_Cloning/ | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Set, Tuple | |
| import numpy as np | |
| try: | |
| import pygame # type: ignore | |
| except Exception: | |
| pygame = None # type: ignore | |
| import torch | |
| from .keymap import CSGO_FORBIDDEN_COMBINATIONS, CSGO_KEYMAP | |
@dataclass
class CSGOAction:
    """One CS:GO control step: pressed keys, a discretised mouse move and clicks.

    ``__post_init__`` normalises every freshly built action: ``keys`` is passed
    through ``filter_keys_pressed_forbidden`` and the mouse deltas are snapped
    to the allowed bins by ``process_mouse``.
    """

    # Key codes looked up in CSGO_KEYMAP by filter_keys_pressed_forbidden.
    # NOTE(review): decode_csgo_action may pass key-name strings when pygame is
    # unavailable — confirm how the keymap treats those.
    keys: List[int]
    mouse_x: float
    mouse_y: float
    l_click: bool
    r_click: bool

    def __post_init__(self) -> None:
        self.keys = filter_keys_pressed_forbidden(self.keys)
        self.process_mouse()

    @property
    def key_names(self) -> List[str]:
        """Human-readable names of ``keys``.

        Uses ``pygame.key.name`` when pygame imported successfully. Any key it
        cannot name (and every key when pygame is missing, e.g. headless runs)
        falls back to ``str(key)``.
        """
        names: List[str] = []
        for key in self.keys:
            if pygame is not None:
                try:
                    names.append(pygame.key.name(key))
                    continue
                except Exception:
                    # Best effort: e.g. keys that are already strings; use str() below.
                    pass
            names.append(str(key))
        return names

    def process_mouse(self) -> None:
        """Clip the mouse deltas, snap them to the allowed bins, then apply arrow keys.

        Each axis is clipped to ``MOUSE_*_LIM`` and replaced by the closest
        entry of ``MOUSE_*_POSSIBLES``; ``min`` keeps the first of two equally
        close bins, i.e. the smaller one since the lists are ascending.
        Arrow keys then override the result: left/right set ``mouse_x`` to
        -60/+60 and up/down set ``mouse_y`` to -50/+50 (later keys win).
        """
        x = np.clip(self.mouse_x, MOUSE_X_LIM[0], MOUSE_X_LIM[1])
        y = np.clip(self.mouse_y, MOUSE_Y_LIM[0], MOUSE_Y_LIM[1])
        self.mouse_x = min(MOUSE_X_POSSIBLES, key=lambda candidate: abs(candidate - x))
        self.mouse_y = min(MOUSE_Y_POSSIBLES, key=lambda candidate: abs(candidate - y))
        for key in self.key_names:
            if key == "left":
                self.mouse_x = -60
            elif key == "right":
                self.mouse_x = +60
            elif key == "up":
                self.mouse_y = -50
            elif key == "down":
                self.mouse_y = +50
def print_csgo_action(action: CSGOAction) -> Tuple[str, str, str]:
    """Format ``action`` for display as a ``(keys, mouse, clicks)`` triple of strings.

    * ``keys``: the CSGO_KEYMAP names of ``action.keys`` joined with " + ",
      leaving out names that start with "camera_";
    * ``mouse``: ``str((mouse_x, mouse_y))``, or "" when both deltas are 0;
    * ``clicks``: "L", "R", "L + R" or "".

    Despite its name, nothing is printed; the caller decides what to do with
    the strings.
    """
    action_names = [CSGO_KEYMAP[k] for k in action.keys]
    keys = " + ".join(name for name in action_names if not name.startswith("camera_"))
    moved = action.mouse_x != 0 or action.mouse_y != 0
    mouse = str((action.mouse_x, action.mouse_y)) if moved else ""
    clicks = " + ".join(
        label
        for label, pressed in (("L", action.l_click), ("R", action.r_click))
        if pressed
    )
    return keys, mouse, clicks
# Discrete mouse-movement bins per axis. Actions are snapped to these by
# CSGOAction.process_mouse and one-hot encoded over them. Keep them ascending:
# the first/last entries define the clipping limits below, and process_mouse
# breaks distance ties towards the earlier (smaller) bin.
MOUSE_X_POSSIBLES = [
    -1000, -500, -300, -200, -100, -60, -30, -20, -10, -4, -2,
    0,
    2, 4, 10, 20, 30, 60, 100, 200, 300, 500, 1000,
]
MOUSE_Y_POSSIBLES = [
    -200, -100, -50, -20, -10, -4, -2,
    0,
    2, 4, 10, 20, 50, 100, 200,
]

# Clipping range applied before snapping: (smallest bin, largest bin).
MOUSE_X_LIM = (MOUSE_X_POSSIBLES[0], MOUSE_X_POSSIBLES[-1])
MOUSE_Y_LIM = (MOUSE_Y_POSSIBLES[0], MOUSE_Y_POSSIBLES[-1])

# Segment sizes of the encoded action vector.
N_KEYS = 11  # keyboard outputs, in slot order: w, a, s, d, space, left ctrl, left shift, 1, 2, 3, r
N_CLICKS = 2  # mouse buttons: left, right
N_MOUSE_X = len(MOUSE_X_POSSIBLES)  # number of one-hot mouse-x bins
N_MOUSE_Y = len(MOUSE_Y_POSSIBLES)  # number of one-hot mouse-y bins
def encode_csgo_action(csgo_action: CSGOAction, device: torch.device) -> torch.Tensor:
    """Encode ``csgo_action`` as a flat float32 tensor on ``device``.

    Layout, in order (length ``N_KEYS + N_CLICKS + N_MOUSE_X + N_MOUSE_Y``):

    * ``N_KEYS`` multi-hot key slots for w, a, s, d, space, left ctrl,
      left shift, 1, 2, 3, r (any other key name is ignored);
    * left click, then right click (1.0 if pressed, else 0.0);
    * one-hot position of ``mouse_x`` in ``MOUSE_X_POSSIBLES``;
    * one-hot position of ``mouse_y`` in ``MOUSE_Y_POSSIBLES``.

    Raises:
        ValueError: if ``mouse_x`` or ``mouse_y`` is not one of the allowed
            bins (``CSGOAction.process_mouse`` normally snaps them onto one).
    """
    # Slot order must stay in sync with decode_csgo_action.
    key_order = ("w", "a", "s", "d", "space", "left ctrl", "left shift", "1", "2", "3", "r")
    pressed = set(csgo_action.key_names)
    keys_pressed_onehot = np.zeros(N_KEYS)
    for slot, name in enumerate(key_order):
        if name in pressed:
            keys_pressed_onehot[slot] = 1

    clicks_onehot = np.array(
        [int(csgo_action.l_click), int(csgo_action.r_click)], dtype=np.float64
    )

    mouse_x_onehot = np.zeros(N_MOUSE_X)
    mouse_y_onehot = np.zeros(N_MOUSE_Y)
    try:
        mouse_x_onehot[MOUSE_X_POSSIBLES.index(csgo_action.mouse_x)] = 1
        mouse_y_onehot[MOUSE_Y_POSSIBLES.index(csgo_action.mouse_y)] = 1
    except ValueError as err:
        raise ValueError(
            f"mouse movement ({csgo_action.mouse_x}, {csgo_action.mouse_y}) "
            "is not one of the allowed bins"
        ) from err

    return torch.tensor(
        np.concatenate((keys_pressed_onehot, clicks_onehot, mouse_x_onehot, mouse_y_onehot)),
        device=device,
        dtype=torch.float32,
    )
def decode_csgo_action(y_preds: torch.Tensor) -> CSGOAction:
    """Decode a model output vector into a ``CSGOAction``.

    After squeezing, ``y_preds`` is read with the ``encode_csgo_action``
    layout: ``N_KEYS`` key scores, a left-click score, a right-click score,
    ``N_MOUSE_X`` mouse-x scores, then ``N_MOUSE_Y`` mouse-y scores.

    * A key is pressed when its score rounds (``np.round``, half-to-even)
      to exactly 1; a click is pressed when its score rounds to non-zero.
    * Each mouse axis takes the bin with the highest score (``np.argmax``,
      first index on ties).

    Key names are converted to pygame key codes when pygame is available. If
    pygame is missing or the conversion fails, they are kept as strings.
    """
    if isinstance(y_preds, torch.Tensor):
        # NumPy cannot read tensors that require grad or live off the CPU.
        y_preds = y_preds.detach().cpu().numpy()
    preds = np.asarray(y_preds).squeeze()

    mouse_x_start = N_KEYS + N_CLICKS
    mouse_y_start = mouse_x_start + N_MOUSE_X
    mouse_y_end = mouse_y_start + N_MOUSE_Y

    # Slot order must stay in sync with encode_csgo_action.
    key_order = ("w", "a", "s", "d", "space", "left ctrl", "left shift", "1", "2", "3", "r")
    key_scores = np.round(preds[:N_KEYS])
    keys_pressed = [name for name, score in zip(key_order, key_scores) if score == 1]

    # Index scalars directly: int()/bool() on a 1-element array is deprecated in NumPy.
    l_click = bool(np.round(preds[N_KEYS]))
    r_click = bool(np.round(preds[N_KEYS + 1]))

    mouse_x = MOUSE_X_POSSIBLES[int(np.argmax(preds[mouse_x_start:mouse_y_start]))]
    mouse_y = MOUSE_Y_POSSIBLES[int(np.argmax(preds[mouse_y_start:mouse_y_end]))]

    if pygame is not None:
        try:
            keys_pressed = [pygame.key.key_code(name) for name in keys_pressed]
        except Exception:
            # Best effort: keep the string names rather than failing the decode.
            pass

    return CSGOAction(keys_pressed, mouse_x, mouse_y, l_click, r_click)
def filter_keys_pressed_forbidden(
    keys_pressed: List[int],
    keymap: Dict[int, str] = CSGO_KEYMAP,
    forbidden_combinations: List[Set[str]] = CSGO_FORBIDDEN_COMBINATIONS,
) -> List[int]:
    """Drop unknown keys and keys that would complete a forbidden combination.

    Keys are examined in order. A key absent from ``keymap`` is skipped.
    Otherwise the key is tentatively accepted. If the accepted key names then
    contain any set from ``forbidden_combinations``, that key and its name are
    withdrawn again, so earlier keys take precedence over later ones.

    Returns the entries of ``keys_pressed`` whose key ended up accepted, in
    their original order and with duplicates preserved.
    """
    accepted: Set[int] = set()
    accepted_names: Set[str] = set()
    for code in keys_pressed:
        if code in keymap:
            name = keymap[code]
            accepted.add(code)
            accepted_names.add(name)
            if any(combo.issubset(accepted_names) for combo in forbidden_combinations):
                accepted.remove(code)
                accepted_names.remove(name)
    return [code for code in keys_pressed if code in accepted]