PIWM / src /csgo /web_action_processing.py
musictimer's picture
Initial Diamond CSGO AI deployment
c64c726
"""
Web-compatible action processing for CSGO actions
Converts web keyboard inputs to CSGO actions without pygame dependency
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import torch
# Browser ``KeyboardEvent.code`` values -> CSGO action names used by the rest of
# this module. Dict insertion order is preserved, so iterating this mapping
# follows the listing below.
WEB_KEYMAP = dict(
    # Movement and stance
    KeyW="up",
    KeyD="right",
    KeyA="left",
    KeyS="down",
    Space="jump",
    ControlLeft="crouch",
    ShiftLeft="walk",
    # Weapon slots and reload
    Digit1="weapon1",
    Digit2="weapon2",
    Digit3="weapon3",
    KeyR="reload",
    # Arrow keys steer the camera (WebCSGOAction.process_mouse turns these
    # into fixed mouse deltas)
    ArrowUp="camera_up",
    ArrowRight="camera_right",
    ArrowLeft="camera_left",
    ArrowDown="camera_down",
)

# Groups of action names that may not be active at the same time: opposite
# movement directions, two weapon slots at once, opposite camera directions.
# Consumed by filter_web_keys_forbidden.
WEB_FORBIDDEN_COMBINATIONS = [
    set(pair)
    for pair in (
        ("up", "down"),
        ("left", "right"),
        ("weapon1", "weapon2"),
        ("weapon1", "weapon3"),
        ("weapon2", "weapon3"),
        ("camera_up", "camera_down"),
        ("camera_left", "camera_right"),
    )
]
@dataclass
class WebCSGOAction:
    """One CSGO action built from browser input, with no pygame dependency.

    Construction normalizes the instance in place:
      * ``key_names`` is passed through ``filter_web_keys_forbidden``;
      * ``mouse_x`` / ``mouse_y`` are clipped to the limits and snapped to the
        nearest allowed value from ``.action_processing``;
      * camera action names then overwrite the matching mouse axis.
    """

    key_names: List[str]  # action names such as "up" or "camera_left", not pygame key codes
    mouse_x: float
    mouse_y: float
    l_click: bool
    r_click: bool

    def __post_init__(self) -> None:
        # Conflicts are resolved first so that process_mouse only sees the
        # surviving camera keys.
        self.key_names = filter_web_keys_forbidden(self.key_names)
        self.process_mouse()

    def process_mouse(self) -> None:
        """Discretize the mouse deltas, then apply arrow-key camera overrides."""
        from .action_processing import MOUSE_X_LIM, MOUSE_X_POSSIBLES, MOUSE_Y_LIM, MOUSE_Y_POSSIBLES

        clipped_x = np.clip(self.mouse_x, MOUSE_X_LIM[0], MOUSE_X_LIM[1])
        clipped_y = np.clip(self.mouse_y, MOUSE_Y_LIM[0], MOUSE_Y_LIM[1])
        # Nearest allowed value per axis; on a tie min() keeps the earliest entry.
        self.mouse_x = min(MOUSE_X_POSSIBLES, key=lambda candidate: abs(candidate - clipped_x))
        self.mouse_y = min(MOUSE_Y_POSSIBLES, key=lambda candidate: abs(candidate - clipped_y))

        # Each camera action forces a fixed delta on one axis; key_names is
        # walked in order, so a later entry for the same axis would win.
        # NOTE(review): encode_web_csgo_action looks these values up with
        # .index(), so -60/+60 and -50/+50 presumably belong to
        # MOUSE_X_POSSIBLES / MOUSE_Y_POSSIBLES - confirm.
        camera_overrides = {
            "camera_left": ("mouse_x", -60),
            "camera_right": ("mouse_x", +60),
            "camera_up": ("mouse_y", -50),
            "camera_down": ("mouse_y", +50),
        }
        for name in self.key_names:
            override = camera_overrides.get(name)
            if override is None:
                continue
            axis, delta = override
            setattr(self, axis, delta)
def filter_web_keys_forbidden(
    key_names: List[str],
    forbidden_combinations: Optional[List[Set[str]]] = None,
) -> List[str]:
    """Drop action names that would complete a forbidden combination.

    Names are examined in the given order. A name is kept only if it, together
    with the names already kept, does not fully contain any forbidden
    combination. For pairwise combinations this means the earlier name of a
    conflicting pair wins and the later one is dropped. Duplicates and the
    relative order of surviving names are preserved.

    Args:
        key_names: action names, in priority order. Not modified.
        forbidden_combinations: sets of action names that may not all be active
            at once. Defaults to ``WEB_FORBIDDEN_COMBINATIONS``.

    Returns:
        A new list containing the surviving names.
    """
    if forbidden_combinations is None:
        forbidden_combinations = WEB_FORBIDDEN_COMBINATIONS

    kept: List[str] = []
    kept_set: Set[str] = set()  # mirrors `kept` so each check avoids rebuilding a set
    for key_name in key_names:
        candidate = kept_set | {key_name}
        if any(combo.issubset(candidate) for combo in forbidden_combinations):
            continue
        kept.append(key_name)
        kept_set.add(key_name)
    return kept
def web_keys_to_csgo_action_names(
    pressed_web_keys: Set[str],
    keymap: Optional[Dict[str, str]] = None,
) -> List[str]:
    """Convert pressed browser key codes to CSGO action names.

    Names are emitted in the keymap's (insertion) order, not in the iteration
    order of ``pressed_web_keys``. The iteration order of a set of strings
    depends on the per-process hash seed, and WebCSGOAction resolves
    conflicting keys by keeping the earlier one. Iterating the set directly
    therefore let e.g. W+S resolve to "up" in one run and "down" in another.
    Key codes missing from the keymap are ignored.

    Args:
        pressed_web_keys: key codes currently held; any container supporting
            ``in`` works.
        keymap: key code -> action name. Defaults to ``WEB_KEYMAP``.

    Returns:
        One action name per mapped pressed key, in keymap order.
    """
    if keymap is None:
        keymap = WEB_KEYMAP
    return [action for web_key, action in keymap.items() if web_key in pressed_web_keys]
def encode_web_csgo_action(web_action: WebCSGOAction, device: torch.device) -> torch.Tensor:
    """Encode a WebCSGOAction as a flat float32 tensor (original encoding layout).

    Sections, concatenated in this order:
      1. N_KEYS key flags (slot per action name, see ``key_slot`` below);
      2. left-click flag, then right-click flag;
      3. N_MOUSE_X one-hot for ``mouse_x`` (position in MOUSE_X_POSSIBLES);
      4. N_MOUSE_Y one-hot for ``mouse_y`` (position in MOUSE_Y_POSSIBLES).

    Camera action names have no key slot; they only influence the mouse
    sections. ``list.index`` raises ValueError if a mouse value is not one of
    the allowed values.
    """
    from .action_processing import MOUSE_X_POSSIBLES, MOUSE_Y_POSSIBLES, N_KEYS, N_CLICKS, N_MOUSE_X, N_MOUSE_Y

    # Position of each action name within the key section.
    key_slot = {
        "up": 0,       # w
        "left": 1,     # a
        "down": 2,     # s
        "right": 3,    # d
        "jump": 4,     # space
        "crouch": 5,   # ctrl
        "walk": 6,     # shift
        "weapon1": 7,  # 1
        "weapon2": 8,  # 2
        "weapon3": 9,  # 3
        "reload": 10,  # r
    }

    key_section = np.zeros(N_KEYS)
    for name in web_action.key_names:
        slot = key_slot.get(name)
        if slot is not None:
            key_section[slot] = 1

    # Left then right, each stored as 0.0 / 1.0.
    click_section = np.array(
        [int(web_action.l_click), int(web_action.r_click)],
        dtype=np.float64,
    )

    mouse_x_section = np.zeros(N_MOUSE_X)
    mouse_x_section[MOUSE_X_POSSIBLES.index(web_action.mouse_x)] = 1
    mouse_y_section = np.zeros(N_MOUSE_Y)
    mouse_y_section[MOUSE_Y_POSSIBLES.index(web_action.mouse_y)] = 1
    assert mouse_x_section.sum() == 1
    assert mouse_y_section.sum() == 1

    flat = np.concatenate((key_section, click_section, mouse_x_section, mouse_y_section))
    return torch.tensor(flat, device=device, dtype=torch.float32)
def print_web_csgo_action(action: WebCSGOAction) -> Tuple[str, str, str]:
    """Render an action as display strings (nothing is actually printed).

    Returns:
        ``(keys, mouse, clicks)`` where
          * ``keys`` joins the non-camera action names with " + ";
          * ``mouse`` is the text of the ``(mouse_x, mouse_y)`` tuple, or ""
            when both deltas are zero;
          * ``clicks`` is "L", "R", "L + R" or "".
    """
    keys = " + ".join(name for name in action.key_names if not name.startswith("camera_"))

    has_motion = action.mouse_x != 0 or action.mouse_y != 0
    mouse = str((action.mouse_x, action.mouse_y)) if has_motion else ""

    pressed_buttons = []
    if action.l_click:
        pressed_buttons.append("L")
    if action.r_click:
        pressed_buttons.append("R")
    clicks = " + ".join(pressed_buttons)

    return keys, mouse, clicks