File size: 6,798 Bytes
c64c726
ded2bd6
c64c726
 
ded2bd6
c64c726
 
ded2bd6
b8159f9
 
c64c726
 
 
ded2bd6
 
c64c726
 
ded2bd6
 
c64c726
 
 
 
 
ded2bd6
 
 
c64c726
ded2bd6
c64c726
ded2bd6
 
 
c64c726
1d96a61
 
 
 
 
 
 
ded2bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c64c726
ded2bd6
c64c726
 
 
ded2bd6
c64c726
 
 
 
 
 
 
 
b8159f9
 
 
 
 
 
 
 
ded2bd6
 
 
 
 
c64c726
ded2bd6
 
 
 
 
 
b8159f9
 
 
ded2bd6
1d96a61
 
 
 
 
ded2bd6
b8159f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ded2bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c64c726
ded2bd6
 
c64c726
ded2bd6
 
 
c64c726
1d96a61
ded2bd6
1d96a61
 
 
 
c64c726
ded2bd6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
Web-compatible PlayEnv that handles web input and AI inference
"""

from typing import Any, Dict, List, Set, Tuple
import torch
from torch import Tensor
from torch.distributions.categorical import Categorical
import torch.nn as nn
import torch.nn.functional as F

from ..agent import Agent
from ..envs import WorldModelEnv
from ..csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names, encode_web_csgo_action
from .play_env import PlayEnv


class WebPlayEnv(PlayEnv):
    """Web-compatible version of PlayEnv that handles web input and AI inference.

    Each web step supplies keyboard/mouse state. The action sent to the world
    model comes from the agent's actor-critic (AI mode, the default) or from
    that web input (human mode, or whenever ``human_input_override`` is set).
    If AI inference is not possible, the web input is used as a fallback.
    """

    def __init__(
        self,
        agent: Agent,
        wm_env: WorldModelEnv,
        recording_mode: bool,
        store_denoising_trajectory: bool,
        store_original_obs: bool,
    ) -> None:
        super().__init__(agent, wm_env, recording_mode, store_denoising_trajectory, store_original_obs)

        # For the web demo the AI controls the actions by default.
        self.is_human_player = False
        # Set to True to force web (human) input regardless of is_human_player.
        self.human_input_override = False

        # LSTM hidden/cell states for the actor-critic; only allocated when one exists.
        if agent.actor_critic is not None:
            self.hx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
            self.cx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
        else:
            self.hx = None
            self.cx = None

        # Keys of warnings already printed, so conditions that recur every frame
        # are reported once instead of flooding the server log.
        self._warned_keys: Set[str] = set()

    def switch_controller(self) -> None:
        """Switch between AI and human control"""
        self.is_human_player = not self.is_human_player
        print(f"Switched to {'human' if self.is_human_player else 'AI'} control")

    def str_control(self) -> str:
        """Return control mode string"""
        if self.human_input_override:
            return "Human Override"
        return "Human" if self.is_human_player else "AI"

    def _warn_once(self, key: str, message: str) -> None:
        """Print ``message`` the first time ``key`` is seen; stay silent afterwards."""
        if key not in self._warned_keys:
            self._warned_keys.add(key)
            print(message)

    def _ensure_obs(self) -> None:
        """Reset the environment when no observation is available yet.

        A failed reset is reported (once) rather than raised, so the step can
        still proceed using the web input.
        """
        if self.obs is not None:
            return
        try:
            self.obs, _ = self.reset()
        except Exception as e:  # best effort at the web-request boundary
            self._warn_once("reset_failed", f"Environment reset failed: {e}")

    def _resize_for_actor_critic(self, obs: Tensor) -> Tensor:
        """Resize ``obs`` (expected BCHW) to the square size the actor-critic encoder expects.

        The downsampling factor is inferred by counting ``nn.MaxPool2d`` layers in
        ``actor_critic.encoder.encoder``. The spatial size after the encoder should
        be 1x1 so that flattening matches the LSTM input size configured at init
        time. With ``n`` halvings that requires a ``2**n`` input. With no pooling
        layers the observation is resized to a square of its smaller side.

        Returns ``obs`` unchanged (and warns once) if the encoder cannot be
        inspected or the resize fails.
        """
        try:
            n_pools = sum(
                1 for m in self.agent.actor_critic.encoder.encoder if isinstance(m, nn.MaxPool2d)
            )
            target_hw = 2 ** n_pools if n_pools > 0 else min(int(obs.size(-2)), int(obs.size(-1)))
            if obs.size(-2) != target_hw or obs.size(-1) != target_hw:
                obs = F.interpolate(
                    obs, size=(target_hw, target_hw), mode="bilinear", align_corners=False
                )
        except Exception as e:
            # Deliberate best effort: continue with the unresized observation.
            self._warn_once(
                "resize_failed",
                f"Could not resize observation for actor-critic ({e}); using it as-is",
            )
        return obs

    def _predict_ai_action(self) -> Tensor:
        """Sample an action from the actor-critic for the current observation.

        Replaces ``self.hx``/``self.cx`` with the recurrent state returned by
        ``predict_act_value``. Callers must ensure ``self.obs`` and
        ``self.agent.actor_critic`` are not None.
        """
        obs = self.obs
        if obs.ndim == 3:  # CHW -> BCHW
            obs = obs.unsqueeze(0)
        # Ensure obs is on the same device as the models.
        if obs.device != self.agent.device:
            obs = obs.to(self.agent.device, non_blocking=True)

        # Detach hidden states so no autograd history is carried across steps.
        if self.hx is not None:
            self.hx = self.hx.detach()
        if self.cx is not None:
            self.cx = self.cx.detach()

        obs = self._resize_for_actor_critic(obs)

        logits_act, _value, (self.hx, self.cx) = self.agent.actor_critic.predict_act_value(
            obs, (self.hx, self.cx)
        )
        action = Categorical(logits=logits_act).sample()

        # Drop a leading batch dimension of size 1.
        if action.ndim > 0 and action.size(0) == 1:
            action = action.squeeze(0)
        return action

    def _select_action(self, web_action: WebCSGOAction) -> Tensor:
        """Choose this step's action: the AI prediction when possible, else the web input."""
        use_web_input = self.human_input_override or self.is_human_player

        if not use_web_input and self.agent.actor_critic is None:
            # AI mode without a policy: fall back quietly instead of raising and
            # printing a traceback on every frame.
            self._warn_once(
                "no_actor_critic",
                "AI control requested but the agent has no actor-critic; using web input",
            )
            use_web_input = True
        if not use_web_input and self.obs is None:
            # Nothing to act on (reset failed, already reported); use web input.
            use_web_input = True

        if not use_web_input:
            try:
                return self._predict_ai_action()
            except Exception as e:
                print(f"AI inference failed: {e}")
                import traceback
                traceback.print_exc()
                # Fall through to the web input below.

        return encode_web_csgo_action(web_action, device=self.agent.device)

    @torch.no_grad()
    def step_from_web_input(
        self,
        pressed_keys: Set[str],
        mouse_x: float,
        mouse_y: float,
        l_click: bool,
        r_click: bool,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]:
        """
        Step the environment with web input.

        If AI mode is enabled, use AI inference. If human mode or override is
        active, or AI inference is unavailable or fails, use the web input.

        Args:
            pressed_keys: Web key names currently held down.
            mouse_x: Mouse x value from the web client.
            mouse_y: Mouse y value from the web client.
            l_click: Whether the left mouse button is pressed.
            r_click: Whether the right mouse button is pressed.

        Returns:
            ``(next_obs, rew, end, trunc, env_info)`` as returned by ``self.env.step``.
        """
        web_action = WebCSGOAction(
            key_names=web_keys_to_csgo_action_names(pressed_keys),
            mouse_x=mouse_x,
            mouse_y=mouse_y,
            l_click=l_click,
            r_click=r_click,
        )

        # Make sure there is an observation to act on; reset if needed.
        self._ensure_obs()

        action = self._select_action(web_action)

        # Step the environment with the chosen action.
        next_obs, rew, end, trunc, env_info = self.env.step(action)

        # Update internal state.
        self.obs = next_obs
        self.t += 1

        # Reset recurrent state on episode end (only if it exists).
        if end.any() or trunc.any():
            if self.hx is not None:
                self.hx.zero_()
            if self.cx is not None:
                self.cx.zero_()

        return next_obs, rew, end, trunc, env_info