Spaces:
Sleeping
Sleeping
| """ | |
| Web-compatible action processing for CSGO actions | |
| Converts web keyboard inputs to CSGO actions without pygame dependency | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Set, Tuple | |
| import numpy as np | |
| import torch | |
# Maps browser KeyboardEvent.code strings to CSGO action names.
WEB_KEYMAP = {
    # Movement
    "KeyW": "up",
    "KeyD": "right",
    "KeyA": "left",
    "KeyS": "down",
    "Space": "jump",
    "ControlLeft": "crouch",
    "ShiftLeft": "walk",
    # Weapon selection and reload
    "Digit1": "weapon1",
    "Digit2": "weapon2",
    "Digit3": "weapon3",
    "KeyR": "reload",
    # Camera control through the arrow keys
    "ArrowUp": "camera_up",
    "ArrowRight": "camera_right",
    "ArrowLeft": "camera_left",
    "ArrowDown": "camera_down",
}
# Sets of action names that must not be active together (same logic as original).
WEB_FORBIDDEN_COMBINATIONS = [
    # Opposite movement directions
    {"up", "down"},
    {"left", "right"},
    # Only one weapon slot at a time
    {"weapon1", "weapon2"},
    {"weapon1", "weapon3"},
    {"weapon2", "weapon3"},
    # Opposite camera directions
    {"camera_up", "camera_down"},
    {"camera_left", "camera_right"},
]
@dataclass
class WebCSGOAction:
    """Web-compatible CSGO action without pygame dependencies.

    The ``@dataclass`` decorator generates ``__init__`` from the annotated
    fields and invokes ``__post_init__`` after construction; without it the
    class cannot be instantiated with arguments and the post-processing
    below never runs.

    Attributes are normalised on construction: conflicting keys are dropped
    and the mouse deltas are snapped to the discrete values listed in
    ``MOUSE_X_POSSIBLES`` / ``MOUSE_Y_POSSIBLES``.
    """

    key_names: List[str]  # Action-name strings instead of pygame key codes
    mouse_x: float
    mouse_y: float
    l_click: bool
    r_click: bool

    def __post_init__(self) -> None:
        """Filter forbidden key combinations, then discretise the mouse."""
        self.key_names = filter_web_keys_forbidden(self.key_names)
        self.process_mouse()

    def process_mouse(self) -> None:
        """Clip the mouse deltas, snap them to the nearest allowed value and
        apply camera-key overrides.

        Mutates ``mouse_x`` and ``mouse_y`` in place.
        """
        # Import mouse constants
        from .action_processing import MOUSE_X_POSSIBLES, MOUSE_Y_POSSIBLES, MOUSE_X_LIM, MOUSE_Y_LIM

        # Clip to the allowed range, then pick the closest listed value
        x = np.clip(self.mouse_x, MOUSE_X_LIM[0], MOUSE_X_LIM[1])
        y = np.clip(self.mouse_y, MOUSE_Y_LIM[0], MOUSE_Y_LIM[1])
        self.mouse_x = min(MOUSE_X_POSSIBLES, key=lambda x_: abs(x_ - x))
        self.mouse_y = min(MOUSE_Y_POSSIBLES, key=lambda x_: abs(x_ - y))

        # Camera keys override the mouse with fixed deltas.
        # NOTE(review): encode_web_csgo_action looks these values up with
        # MOUSE_*_POSSIBLES.index(), so presumably +/-60 and +/-50 are
        # members of those lists -- confirm against action_processing.
        for key_name in self.key_names:
            if key_name == "camera_left":
                self.mouse_x = -60
            elif key_name == "camera_right":
                self.mouse_x = +60
            elif key_name == "camera_up":
                self.mouse_y = -50
            elif key_name == "camera_down":
                self.mouse_y = +50
def filter_web_keys_forbidden(key_names: List[str]) -> List[str]:
    """Filter out keys that would complete a forbidden combination.

    Keys are considered in input order; a key is dropped when adding it to
    the keys already kept would make some entry of
    ``WEB_FORBIDDEN_COMBINATIONS`` fully present. The first key of a
    conflicting pair therefore wins.

    Args:
        key_names: Action names in priority order. Read in a single pass,
            so a one-shot iterable is handled correctly as well.

    Returns:
        A new list with the surviving action names in their original order
        (duplicates of a kept name are preserved).
    """
    kept: List[str] = []
    kept_set: Set[str] = set()  # running view of `kept` for subset checks
    for key_name in key_names:
        candidate = kept_set | {key_name}
        if any(combo.issubset(candidate) for combo in WEB_FORBIDDEN_COMBINATIONS):
            continue
        kept.append(key_name)
        kept_set.add(key_name)
    return kept
def web_keys_to_csgo_action_names(pressed_web_keys: Set[str]) -> List[str]:
    """Translate pressed web key codes into CSGO action names.

    Key codes without an entry in ``WEB_KEYMAP`` are ignored. The result
    follows the iteration order of ``pressed_web_keys``.
    """
    return [
        WEB_KEYMAP[code]
        for code in pressed_web_keys
        if code in WEB_KEYMAP
    ]
def encode_web_csgo_action(web_action: WebCSGOAction, device: torch.device) -> torch.Tensor:
    """Encode a web CSGO action as a flat float32 tensor (compatible with original encoding).

    The result concatenates, in this order: the key vector (``N_KEYS``
    entries), the left-click flag, the right-click flag, the mouse-x
    one-hot (``N_MOUSE_X`` entries) and the mouse-y one-hot
    (``N_MOUSE_Y`` entries).

    Args:
        web_action: Action whose ``mouse_x`` / ``mouse_y`` are looked up
            with ``.index()`` in ``MOUSE_X_POSSIBLES`` / ``MOUSE_Y_POSSIBLES``.
        device: Device on which the tensor is created.

    Returns:
        A 1-D ``torch.float32`` tensor on ``device``.
    """
    from .action_processing import MOUSE_X_POSSIBLES, MOUSE_Y_POSSIBLES, N_KEYS, N_CLICKS, N_MOUSE_X, N_MOUSE_Y

    # Position of each action name inside the key vector; names that are
    # not listed here (e.g. camera_*) do not set any key entry.
    key_slots = {
        "up": 0,  # w key
        "left": 1,  # a key
        "down": 2,  # s key
        "right": 3,  # d key
        "jump": 4,  # space
        "crouch": 5,  # ctrl
        "walk": 6,  # shift
        "weapon1": 7,  # 1
        "weapon2": 8,  # 2
        "weapon3": 9,  # 3
        "reload": 10,  # r
    }

    keys_vec = np.zeros(N_KEYS)
    mouse_x_vec = np.zeros(N_MOUSE_X)
    mouse_y_vec = np.zeros(N_MOUSE_Y)
    l_click_vec = np.zeros(1)
    r_click_vec = np.zeros(1)

    for name in web_action.key_names:
        slot = key_slots.get(name)
        if slot is not None:
            keys_vec[slot] = 1

    l_click_vec[0] = int(web_action.l_click)
    r_click_vec[0] = int(web_action.r_click)

    x_slot = MOUSE_X_POSSIBLES.index(web_action.mouse_x)
    mouse_x_vec[x_slot] = 1
    y_slot = MOUSE_Y_POSSIBLES.index(web_action.mouse_y)
    mouse_y_vec[y_slot] = 1

    assert mouse_x_vec.sum() == 1
    assert mouse_y_vec.sum() == 1

    flat = np.concatenate((
        keys_vec,
        l_click_vec,
        r_click_vec,
        mouse_x_vec,
        mouse_y_vec,
    ))
    return torch.tensor(flat, device=device, dtype=torch.float32)
def print_web_csgo_action(action: WebCSGOAction) -> Tuple[str, str, str]:
    """Format a web CSGO action as readable strings.

    Nothing is written to stdout; the caller receives the pieces.

    Returns:
        A ``(keys, mouse, clicks)`` tuple:
        * ``keys``   -- non-camera action names joined with " + ".
        * ``mouse``  -- ``"(mouse_x, mouse_y)"``, or "" when both are zero.
        * ``clicks`` -- "L", "R", "L + R", or "" when no button is pressed.
    """
    keys = " + ".join(
        name for name in action.key_names if not name.startswith("camera_")
    )

    has_motion = action.mouse_x != 0 or action.mouse_y != 0
    mouse = str((action.mouse_x, action.mouse_y)) if has_motion else ""

    buttons = (("L", action.l_click), ("R", action.r_click))
    clicks = " + ".join(label for label, is_down in buttons if is_down)

    return keys, mouse, clicks