Spaces:
Sleeping
Sleeping
File size: 6,798 Bytes
c64c726 ded2bd6 c64c726 ded2bd6 c64c726 ded2bd6 b8159f9 c64c726 ded2bd6 c64c726 ded2bd6 c64c726 ded2bd6 c64c726 ded2bd6 c64c726 ded2bd6 c64c726 1d96a61 ded2bd6 c64c726 ded2bd6 c64c726 ded2bd6 c64c726 b8159f9 ded2bd6 c64c726 ded2bd6 b8159f9 ded2bd6 1d96a61 ded2bd6 b8159f9 ded2bd6 c64c726 ded2bd6 c64c726 ded2bd6 c64c726 1d96a61 ded2bd6 1d96a61 c64c726 ded2bd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
"""
Web-compatible PlayEnv that handles web input and AI inference
"""
from typing import Any, Dict, List, Set, Tuple
import torch
from torch import Tensor
from torch.distributions.categorical import Categorical
import torch.nn as nn
import torch.nn.functional as F
from ..agent import Agent
from ..envs import WorldModelEnv
from ..csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names, encode_web_csgo_action
from .play_env import PlayEnv
class WebPlayEnv(PlayEnv):
    """Web-compatible version of PlayEnv that handles web input and AI inference.

    Each step picks its action from one of two sources:

    * the agent's actor-critic (AI mode, the default for the web demo), or
    * the browser's keyboard/mouse state (human mode or human override).

    The actor-critic's LSTM (hidden, cell) state is kept on the instance between
    steps. It is zeroed when an episode ends or the environment is reset.
    """

    def __init__(
        self,
        agent: Agent,
        wm_env: WorldModelEnv,
        recording_mode: bool,
        store_denoising_trajectory: bool,
        store_original_obs: bool,
    ) -> None:
        super().__init__(agent, wm_env, recording_mode, store_denoising_trajectory, store_original_obs)
        # For web demo, we want AI control by default
        self.is_human_player = False  # AI controls the actions
        self.human_input_override = False  # Can be set to True to allow human input
        # Keys of warnings already printed, so per-frame fallbacks don't spam the log.
        self._warned: Set[str] = set()
        # LSTM hidden/cell states for the actor-critic (None when the agent has none).
        if agent.actor_critic is not None:
            self.hx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
            self.cx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
        else:
            self.hx = None
            self.cx = None

    def switch_controller(self) -> None:
        """Toggle between AI and human control."""
        self.is_human_player = not self.is_human_player
        print(f"Switched to {'human' if self.is_human_player else 'AI'} control")

    def str_control(self) -> str:
        """Return a label for the current control mode; the override takes precedence."""
        if self.human_input_override:
            return "Human Override"
        return "Human" if self.is_human_player else "AI"

    def _warn_once(self, key: str, message: str) -> None:
        """Print ``message`` only the first time ``key`` is seen on this instance."""
        if key not in self._warned:
            self._warned.add(key)
            print(message)

    def _reset_hidden_states(self) -> None:
        """Zero the actor-critic LSTM states in place; no-op when they don't exist."""
        if self.hx is not None:
            self.hx.zero_()
        if self.cx is not None:
            self.cx.zero_()

    def _ensure_obs(self) -> None:
        """Reset the environment when no observation is available yet.

        This is best effort. If the reset raises, ``self.obs`` stays None and
        the caller falls back to the human-input action for this step.
        """
        if self.obs is not None:
            return
        try:
            self.obs, _ = self.reset()
        except Exception as e:
            print(f"Environment reset failed, falling back to human input: {e}")
            return
        # A fresh episode must not inherit recurrent state from the previous one.
        self._reset_hidden_states()

    def _prepare_obs_for_actor_critic(self, obs: Tensor) -> Tensor:
        """Batch, move and resize ``obs`` so the actor-critic encoder can consume it.

        Args:
            obs: Current observation; a 3-D tensor is treated as CHW and gets a
                batch dimension added.

        Returns:
            The observation on the agent's device, resized to a square input.
            Its side is ``2 ** n`` for ``n`` ``MaxPool2d`` layers in the encoder,
            so the encoder output is 1x1 and flattening matches the LSTM input
            size configured at init time.
        """
        if obs.ndim == 3:  # CHW -> BCHW
            obs = obs.unsqueeze(0)
        if obs.device != self.agent.device:
            obs = obs.to(self.agent.device, non_blocking=True)
        try:
            n_pools = sum(
                1 for m in self.agent.actor_critic.encoder.encoder if isinstance(m, nn.MaxPool2d)
            )
            # With no pooling layers there is no target size to infer; just make it square.
            target_hw = 2 ** n_pools if n_pools > 0 else min(int(obs.size(-2)), int(obs.size(-1)))
            if obs.size(-2) != target_hw or obs.size(-1) != target_hw:
                obs = F.interpolate(
                    obs, size=(target_hw, target_hw), mode="bilinear", align_corners=False
                )
        except Exception as e:
            # Best effort: continue with the unresized observation, but say so once.
            self._warn_once("resize", f"Observation resize skipped: {e}")
        return obs

    def _predict_ai_action(self) -> Tensor:
        """Sample an action from the actor-critic and advance its LSTM state.

        Must be called under ``torch.no_grad()``. ``step_from_web_input`` ensures
        this, so the hidden states never track gradients and need no detaching.
        """
        obs = self._prepare_obs_for_actor_critic(self.obs)
        logits_act, _value, (self.hx, self.cx) = self.agent.actor_critic.predict_act_value(
            obs, (self.hx, self.cx)
        )
        action = Categorical(logits=logits_act).sample()
        # Drop a singleton batch dimension so the action matches what env.step expects.
        if action.ndim > 0 and action.size(0) == 1:
            action = action.squeeze(0)
        return action

    @torch.no_grad()
    def step_from_web_input(
        self,
        pressed_keys: Set[str],
        mouse_x: float,
        mouse_y: float,
        l_click: bool,
        r_click: bool,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]:
        """Step the environment with web input.

        In AI mode, the action is sampled from the actor-critic. In human mode,
        or when the override is set, it is encoded from the web input. AI mode
        also falls back to the human action when no actor-critic or observation
        is available, or when inference raises.

        Args:
            pressed_keys: Browser key identifiers currently held down.
            mouse_x: Horizontal mouse input.
            mouse_y: Vertical mouse input.
            l_click: Whether the left mouse button is pressed.
            r_click: Whether the right mouse button is pressed.

        Returns:
            The ``(next_obs, rew, end, trunc, env_info)`` tuple from ``self.env.step``.
        """
        web_action = WebCSGOAction(
            key_names=web_keys_to_csgo_action_names(pressed_keys),
            mouse_x=mouse_x,
            mouse_y=mouse_y,
            l_click=l_click,
            r_click=r_click,
        )
        self._ensure_obs()

        action = None
        wants_ai = not (self.human_input_override or self.is_human_player)
        if wants_ai:
            if self.agent.actor_critic is None:
                self._warn_once(
                    "no_actor_critic",
                    "AI control requested but the agent has no actor-critic; using human input",
                )
            elif self.obs is not None:
                try:
                    action = self._predict_ai_action()
                except Exception as e:
                    print(f"AI inference failed: {e}")
                    import traceback
                    traceback.print_exc()
        if action is None:
            # Human mode, override, or AI unavailable/failed: use the web input.
            action = encode_web_csgo_action(web_action, device=self.agent.device)

        next_obs, rew, end, trunc, env_info = self.env.step(action)
        self.obs = next_obs
        self.t += 1

        # Start the next episode with fresh recurrent state.
        if end.any() or trunc.any():
            self._reset_hidden_states()

        return next_obs, rew, end, trunc, env_info
|