PIWM / src /game /web_play_env.py
musictimer's picture
Fix initial bugs
1d96a61
"""
Web-compatible PlayEnv that handles web input and AI inference
"""
import traceback
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributions.categorical import Categorical

from ..agent import Agent
from ..envs import WorldModelEnv
from ..csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names, encode_web_csgo_action
from .play_env import PlayEnv
class WebPlayEnv(PlayEnv):
    """Web-compatible version of PlayEnv that handles web input and AI inference.

    Actions come from the agent's actor-critic by default. Human input from the
    browser is used instead when ``is_human_player`` or ``human_input_override``
    is set, and as a fallback whenever AI inference is unavailable or fails.
    """

    def __init__(
        self,
        agent: Agent,
        wm_env: WorldModelEnv,
        recording_mode: bool,
        store_denoising_trajectory: bool,
        store_original_obs: bool,
    ) -> None:
        super().__init__(agent, wm_env, recording_mode, store_denoising_trajectory, store_original_obs)
        # For web demo, we want AI control by default
        self.is_human_player = False  # AI controls the actions
        self.human_input_override = False  # Can be set to True to allow human input
        # Set once the "no actor-critic" warning has been printed, so it is not repeated every frame.
        self._warned_no_actor_critic = False
        # LSTM hidden/cell states for the actor-critic, shape (1, lstm_dim);
        # None when the agent has no actor-critic.
        if agent.actor_critic is not None:
            self.hx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
            self.cx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
        else:
            self.hx = None
            self.cx = None

    def switch_controller(self) -> None:
        """Toggle between AI and human control and print the new mode."""
        self.is_human_player = not self.is_human_player
        print(f"Switched to {'human' if self.is_human_player else 'AI'} control")

    def str_control(self) -> str:
        """Return the control-mode label: "Human Override", "Human" or "AI"."""
        if self.human_input_override:
            return "Human Override"
        return "Human" if self.is_human_player else "AI"

    @torch.no_grad()
    def step_from_web_input(
        self,
        pressed_keys: Set[str],
        mouse_x: float,
        mouse_y: float,
        l_click: bool,
        r_click: bool,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]:
        """Advance the environment by one step using web input or the AI policy.

        In AI mode the actor-critic chooses the action. In human mode, under
        ``human_input_override``, or whenever AI inference is unavailable or
        fails, the browser input is encoded and used instead.

        Args:
            pressed_keys: Names of the keys currently held in the browser.
            mouse_x: Horizontal mouse input forwarded to ``WebCSGOAction``.
            mouse_y: Vertical mouse input forwarded to ``WebCSGOAction``.
            l_click: Whether the left mouse button is pressed.
            r_click: Whether the right mouse button is pressed.

        Returns:
            ``(next_obs, rew, end, trunc, env_info)`` exactly as returned by ``self.env.step``.
        """
        web_action = WebCSGOAction(
            key_names=web_keys_to_csgo_action_names(pressed_keys),
            mouse_x=mouse_x,
            mouse_y=mouse_y,
            l_click=l_click,
            r_click=r_click,
        )

        # Ensure we have a valid observation; if not, try to reset the environment.
        if self.obs is None:
            try:
                self.obs, _ = self.reset()
            except Exception as e:
                # Best effort: without an observation the AI cannot act, so the
                # human input below is used for this step.
                print(f"Environment reset failed: {e}")

        action: Optional[Tensor] = None
        if not (self.human_input_override or self.is_human_player):
            action = self._predict_ai_action()
        if action is None:
            # Human mode, override, or AI unavailable/failed: use the web input.
            # NOTE(review): the AI path yields a sampled categorical index while this
            # path yields encode_web_csgo_action's encoding; confirm env.step accepts both.
            action = encode_web_csgo_action(web_action, device=self.agent.device)

        next_obs, rew, end, trunc, env_info = self.env.step(action)
        self.obs = next_obs
        self.t += 1

        # A new episode starts with a fresh recurrent state.
        if end.any() or trunc.any():
            self._reset_hidden_states()

        return next_obs, rew, end, trunc, env_info

    def _predict_ai_action(self) -> Optional[Tensor]:
        """Sample an action from the actor-critic for the current observation.

        On success, updates ``self.hx``/``self.cx`` with the new LSTM state and
        returns the sampled action. A leading batch dimension of size 1 is
        squeezed away. Returns None (caller falls back to human input) when
        there is no actor-critic, no observation, or inference raises.
        """
        actor_critic = self.agent.actor_critic
        if actor_critic is None:
            if not self._warned_no_actor_critic:
                print("AI inference unavailable: agent has no actor-critic, using human input")
                self._warned_no_actor_critic = True
            return None
        if self.obs is None:
            # A reset failure has already been reported; nothing to run inference on.
            return None

        try:
            obs = self.obs
            if obs.ndim == 3:  # CHW -> BCHW
                obs = obs.unsqueeze(0)
            if obs.device != self.agent.device:
                obs = obs.to(self.agent.device, non_blocking=True)
            obs = self._resize_obs_for_actor_critic(obs)

            # Detach hidden states to prevent gradient tracking (only if they exist).
            if self.hx is not None:
                self.hx = self.hx.detach()
            if self.cx is not None:
                self.cx = self.cx.detach()

            logits_act, _value, (self.hx, self.cx) = actor_critic.predict_act_value(obs, (self.hx, self.cx))
            action = Categorical(logits=logits_act).sample()
            # Remove the batch dimension when it is 1.
            if action.ndim > 0 and action.size(0) == 1:
                action = action.squeeze(0)
            return action
        except Exception as e:
            print(f"AI inference failed: {e}")
            traceback.print_exc()
            return None

    def _resize_obs_for_actor_critic(self, obs: Tensor) -> Tensor:
        """Resize a BCHW observation to the square size the actor-critic encoder expects.

        Counts the ``nn.MaxPool2d`` layers in ``actor_critic.encoder.encoder``.
        Assuming each one halves the spatial size, an input of ``2 ** n_pools``
        per side reduces to 1x1. Flattening then matches the LSTM input size
        configured at init time. With no pooling layers the observation is made
        square at its smaller side. This is best effort: on any failure ``obs``
        is returned unchanged.
        """
        try:
            n_pools = sum(
                1 for m in self.agent.actor_critic.encoder.encoder if isinstance(m, nn.MaxPool2d)
            )
            height, width = int(obs.size(-2)), int(obs.size(-1))
            target_hw = 2 ** n_pools if n_pools > 0 else min(height, width)
            if height != target_hw or width != target_hw:
                obs = F.interpolate(
                    obs, size=(target_hw, target_hw), mode="bilinear", align_corners=False
                )
        except Exception:
            # Deliberately best effort: an unexpected encoder layout falls back to the unresized input.
            return obs
        return obs

    def _reset_hidden_states(self) -> None:
        """Zero the actor-critic LSTM state in place (no-op when it does not exist)."""
        for state in (self.hx, self.cx):
            if state is not None:
                state.zero_()