Spaces:
Sleeping
Sleeping
| """ | |
| Web-compatible PlayEnv that handles web input and AI inference | |
| """ | |
| from typing import Any, Dict, List, Set, Tuple | |
| import torch | |
| from torch import Tensor | |
| from torch.distributions.categorical import Categorical | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ..agent import Agent | |
| from ..envs import WorldModelEnv | |
| from ..csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names, encode_web_csgo_action | |
| from .play_env import PlayEnv | |
class WebPlayEnv(PlayEnv):
    """Web-compatible version of PlayEnv that handles web input and AI inference.

    Each call to :meth:`step_from_web_input` receives the raw browser input
    (pressed keys, mouse values, button state), picks an action either from
    that input or from the agent's actor-critic, and advances the wrapped
    environment by one step.
    """

    def __init__(
        self,
        agent: Agent,
        wm_env: WorldModelEnv,
        recording_mode: bool,
        store_denoising_trajectory: bool,
        store_original_obs: bool,
    ) -> None:
        """Create the environment with AI control enabled by default.

        Args:
            agent: Agent providing ``device`` and, optionally, ``actor_critic``
                (may be ``None``; AI mode then falls back to human input).
            wm_env: World-model environment, forwarded to :class:`PlayEnv`.
            recording_mode: Forwarded to :class:`PlayEnv`.
            store_denoising_trajectory: Forwarded to :class:`PlayEnv`.
            store_original_obs: Forwarded to :class:`PlayEnv`.
        """
        super().__init__(agent, wm_env, recording_mode, store_denoising_trajectory, store_original_obs)
        # For the web demo the AI drives by default.
        self.is_human_player = False
        # When True, human input is used regardless of ``is_human_player``.
        self.human_input_override = False
        # Ensures the "no actor-critic" warning is printed only once, not every frame.
        self._warned_no_actor_critic = False
        # LSTM (h, c) state for the actor-critic; None when the agent has no actor-critic.
        if agent.actor_critic is not None:
            self.hx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
            self.cx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
        else:
            self.hx = None
            self.cx = None

    def switch_controller(self) -> None:
        """Toggle between AI and human control and report the new mode."""
        self.is_human_player = not self.is_human_player
        print(f"Switched to {'human' if self.is_human_player else 'AI'} control")

    def str_control(self) -> str:
        """Return a short label for the current control mode."""
        if self.human_input_override:
            return "Human Override"
        return "Human" if self.is_human_player else "AI"

    def step_from_web_input(
        self,
        pressed_keys: Set[str],
        mouse_x: float,
        mouse_y: float,
        l_click: bool,
        r_click: bool,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]:
        """Step the environment with web input.

        In AI mode the action is sampled from the agent's actor-critic. In human
        mode, under ``human_input_override``, when the agent has no actor-critic,
        when no observation is available, or when AI inference fails, the action
        is encoded from the given web input instead.

        Args:
            pressed_keys: Names of the keys currently pressed in the browser.
            mouse_x: Horizontal mouse value passed through to ``WebCSGOAction``.
            mouse_y: Vertical mouse value passed through to ``WebCSGOAction``.
            l_click: Whether the left mouse button is pressed.
            r_click: Whether the right mouse button is pressed.

        Returns:
            The ``(next_obs, rew, end, trunc, env_info)`` tuple from ``self.env.step``.
        """
        web_action = WebCSGOAction(
            key_names=web_keys_to_csgo_action_names(pressed_keys),
            mouse_x=mouse_x,
            mouse_y=mouse_y,
            l_click=l_click,
            r_click=r_click,
        )

        # Ensure we have an observation; if reset fails, report it and let the
        # action selection below fall back to human input.
        if self.obs is None:
            try:
                self.obs, _ = self.reset()
            except Exception as e:
                print(f"Environment reset failed: {e}")

        action = None
        if not (self.human_input_override or self.is_human_player):
            action = self._try_ai_action()
        if action is None:
            action = encode_web_csgo_action(web_action, device=self.agent.device)

        next_obs, rew, end, trunc, env_info = self.env.step(action)
        self.obs = next_obs
        self.t += 1

        # A new episode starts from a fresh recurrent state.
        if end.any() or trunc.any():
            self._reset_hidden_state()
        return next_obs, rew, end, trunc, env_info

    def _try_ai_action(self) -> "Tensor | None":
        """Return an AI-sampled action, or None if the caller should use human input."""
        if self.agent.actor_critic is None:
            if not self._warned_no_actor_critic:
                print("AI control requested but the agent has no actor-critic; using human input")
                self._warned_no_actor_critic = True
            return None
        if self.obs is None:
            # Reset failed earlier; there is nothing to run the policy on.
            return None
        try:
            return self._predict_ai_action()
        except Exception as e:
            print(f"AI inference failed: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _predict_ai_action(self) -> Tensor:
        """Sample an action from the actor-critic and advance the LSTM state.

        Updates ``self.hx`` / ``self.cx`` only if ``predict_act_value`` returns
        successfully; any exception propagates to the caller.
        """
        obs = self.obs
        if obs.ndim == 3:  # CHW -> BCHW
            obs = obs.unsqueeze(0)
        if obs.device != self.agent.device:
            obs = obs.to(self.agent.device, non_blocking=True)
        obs = self._resize_for_actor_critic(obs)

        # no_grad: pure inference, so avoid building an autograd graph each step.
        # (inference_mode is avoided because hx/cx are later zeroed in place
        # outside of it, which inference tensors do not allow.)
        with torch.no_grad():
            logits_act, _value, (self.hx, self.cx) = self.agent.actor_critic.predict_act_value(
                obs, (self.hx, self.cx)
            )
            action = Categorical(logits=logits_act).sample()

        # Drop a singleton batch dimension.
        if action.ndim > 0 and action.size(0) == 1:
            action = action.squeeze(0)
        return action

    def _resize_for_actor_critic(self, obs: Tensor) -> Tensor:
        """Best-effort resize of a BCHW observation to a square actor-critic input.

        The side length is ``2 ** n`` where ``n`` is the number of ``nn.MaxPool2d``
        layers in ``actor_critic.encoder.encoder``, so the encoder output is 1x1.
        With no pooling layers the side is the smaller of the current height and
        width. NOTE(review): this assumes the LSTM input size was configured for a
        1x1 encoder output — confirm against the actor-critic config. Any failure
        returns the observation unchanged.
        """
        try:
            n_pools = sum(
                1 for m in self.agent.actor_critic.encoder.encoder if isinstance(m, nn.MaxPool2d)
            )
            height, width = int(obs.size(-2)), int(obs.size(-1))
            target_hw = 2 ** n_pools if n_pools > 0 else min(height, width)
            if height != target_hw or width != target_hw:
                obs = F.interpolate(
                    obs, size=(target_hw, target_hw), mode="bilinear", align_corners=False
                )
        except Exception:
            # Shape inference is best-effort; keep the unresized observation.
            pass
        return obs

    def _reset_hidden_state(self) -> None:
        """Zero the actor-critic LSTM state in place (no-op when there is none)."""
        for state in (self.hx, self.cx):
            if state is not None:
                state.zero_()