Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from stable_baselines3.common.vec_env import SubprocVecEnv | |
| from sb3_contrib import MaskablePPO | |
| from sb3_contrib import MaskablePPO | |
| import os | |
| import time | |
| from ai.student_model import StudentActor # Safe import now | |
class BatchedSubprocVecEnv(SubprocVecEnv):
    """
    A specialized SubprocVecEnv that handles batched opponent inference
    in the main process (GPU) instead of individual workers (CPU).
    Uses standard worker commands (env_method) for better Windows stability.

    Opponent selection (see _load_opponent):
      * USE_DISTILLED_OPPONENT=1 -> distilled student network (loaded once).
      * otherwise -> SB3 MaskablePPO checkpoint at ``opponent_model_path``,
        hot-reloaded whenever the checkpoint file's mtime advances.
      * if neither is available, opponents act uniformly at random over
        their legal-action mask.
    """

    def __init__(self, env_fns, opponent_model_path=None):
        """
        Args:
            env_fns: list of callables, each creating one worker environment.
            opponent_model_path: optional path to a MaskablePPO checkpoint
                used for opponent moves; reloaded when the file changes.
        """
        super().__init__(env_fns)
        self.opponent_model = None
        self.opponent_path = opponent_model_path
        # mtime of the last successfully loaded SB3 checkpoint (0 = never).
        self.last_load_time = 0
        # Guards against re-reading the distilled .pth from disk every step
        # (the original reloaded it on every _load_opponent call).
        self._distilled_loaded = False
        self._load_opponent()

    def _load_opponent(self):
        """Load or refresh the opponent model; never raises.

        Failures are logged and swallowed on purpose: with no model loaded,
        _handle_opponent_moves falls back to random legal moves.
        """
        use_distilled = os.getenv("USE_DISTILLED_OPPONENT", "0") == "1"
        if use_distilled:
            if self._distilled_loaded:
                # Already cached in memory -- do not hit the disk again.
                return
            # Path to student model
            distilled_path = os.path.join("checkpoints", "student_model.pth")
            if os.path.exists(distilled_path):
                try:
                    print(f" [BatchedEnv] Loading DISTILLED opponent from {distilled_path}...", flush=True)
                    # SECURITY: weights_only=False unpickles arbitrary objects.
                    # Only acceptable because this checkpoint is produced by
                    # our own training pipeline -- never point it at
                    # untrusted files.
                    self.opponent_model = torch.load(distilled_path, weights_only=False)
                    self.opponent_model.eval()  # Inference mode
                    # Move to CUDA if available
                    if torch.cuda.is_available():
                        self.opponent_model.to("cuda")
                    self._distilled_loaded = True
                    print(f" [BatchedEnv] Loaded Distilled Student (Fast Mode).", flush=True)
                    return
                except Exception as e:
                    # Fall through to the SB3 checkpoint path below.
                    print(f" [BatchedEnv] Failed to load distilled model: {e}. Falling back...", flush=True)
        if self.opponent_path and os.path.exists(self.opponent_path):
            try:
                mtime = os.path.getmtime(self.opponent_path)
                # Only reload when the checkpoint file actually changed.
                if mtime > self.last_load_time:
                    print(f" [BatchedEnv] Loading opponent model from {self.opponent_path}...", flush=True)
                    # GPU preferred for batched inference; CPU would also work
                    # if the GPU is saturated by the learner.
                    self.opponent_model = MaskablePPO.load(self.opponent_path, device="cuda")
                    # Apply Dynamic Quantization if enabled
                    if os.getenv("ENABLE_QUANTIZATION", "0") == "1":
                        print(" [BatchedEnv] Applying dynamic quantization to opponent...", flush=True)
                        try:
                            # PyTorch's dynamic INT8 quantization is CPU-only;
                            # on GPU the practical reduction is half precision,
                            # so we cast the wrapped policy (SB3 stores it on
                            # .policy) to FP16. Inputs are cast to match in
                            # _handle_opponent_moves.
                            self.opponent_model.policy.to(torch.float16)
                            print(" [BatchedEnv] Converted opponent to FP16.", flush=True)
                        except Exception as qe:
                            print(f" [BatchedEnv] Quantization failed: {qe}", flush=True)
                    self.last_load_time = mtime
                    print(f" [BatchedEnv] Successfully loaded opponent model.", flush=True)
            except Exception as e:
                print(f" [BatchedEnv] Warning: Failed to load opponent: {e}", flush=True)
                # We don't crash the whole thing, we just fallback to random

    def reset(self):
        """Reset all envs, then resolve any opponent moves that must happen
        before the learner acts (i.e. envs where the opponent plays first)."""
        obs = super().reset()
        # Pull infos to check for immediate opponent moves after reset
        infos = self.env_method("get_current_info")
        needs_opp = [i for i, info in enumerate(infos) if info.get("needs_opponent", False)]
        if needs_opp:
            # Convert to mutable structures for in-place assignment inside
            # _handle_opponent_moves. Rewards/term/trunc are throwaway
            # placeholders: reset() discards everything except obs.
            obs_list = list(obs)
            obs_list, _, _, _, _ = self._handle_opponent_moves(
                obs_list,
                [0.0] * self.num_envs,
                [False] * self.num_envs,
                [False] * self.num_envs,
                list(infos),
                needs_opp,
            )
            obs = np.stack(obs_list)
        return obs

    def step_wait(self):
        """Collect worker step results, run batched opponent inference for any
        env that reports ``needs_opponent``, and inject timing diagnostics
        into infos[0].

        Supports both the 5-tuple (obs, rews, terms, truncs, infos) Gym-style
        API and the legacy 4-tuple SB3 API; the return arity mirrors whatever
        ``super().step_wait()`` produced.
        """
        start_wait = time.perf_counter()
        # 1. Main Agent Step (Distributed)
        res = super().step_wait()
        if len(res) == 5:
            obs, rews, terms, truncs, infos = res
        else:
            # Legacy SB3 API has no separate truncation signal.
            obs, rews, terms, infos = res
            truncs = np.zeros_like(terms, dtype=bool)

        # Convert to mutable structures
        obs_list = list(obs)
        rews_list = list(rews)
        terms_list = list(terms)
        truncs_list = list(truncs)
        infos_list = list(infos)

        # 2. Opponent Moves
        needs_opp = [i for i, info in enumerate(infos_list) if info.get("needs_opponent", False)]
        start_opp = time.perf_counter()
        if needs_opp:
            obs_list, rews_list, terms_list, truncs_list, infos_list = self._handle_opponent_moves(
                obs_list, rews_list, terms_list, truncs_list, infos_list, needs_opp
            )
            obs = np.stack(obs_list)
            rews = np.array(rews_list)
            terms = np.array(terms_list)
            truncs = np.array(truncs_list)
            infos = tuple(infos_list)
        opp_time = time.perf_counter() - start_opp
        total_time = time.perf_counter() - start_wait

        # Inject Debug Timing into env 0's info dict.
        if len(infos) > 0:
            if not isinstance(infos[0], dict):
                # infos[0] must be a mutable dict to carry the metrics.
                infos_list[0] = dict(infos[0])
                infos = tuple(infos_list)
            infos[0]["t_opp_batch"] = opp_time
            infos[0]["t_total_batch"] = total_time
            # Aggregate per-worker instrumentation (from gym_env) if present.
            # len(infos) > 0 is guaranteed here, so a plain mean is safe
            # (the original divided by len + 1e-6, skewing the average).
            eng_times = [info.get("time_engine", 0) for info in infos]
            obs_times = [info.get("time_obs", 0) for info in infos]
            infos[0]["t_worker_eng_avg"] = sum(eng_times) / len(eng_times)
            infos[0]["t_worker_obs_avg"] = sum(obs_times) / len(obs_times)

        if len(res) == 5:
            return obs, rews, terms, truncs, infos
        return obs, rews, terms, infos

    def _handle_opponent_moves(self, obs, rews, terms, truncs, infos, needs_opp):
        """Resolve opponent turns for the envs listed in ``needs_opp``.

        Runs one batched forward pass per "round" of opponent moves, sends the
        chosen actions to the workers via ``step_opponent``, and repeats until
        no env reports ``needs_opponent`` (multi-turn games). All six
        sequences are mutated in place by index and returned.
        """
        # Refresh the opponent once per call. The original re-checked the
        # checkpoint inside the loop, paying a disk stat (or, for the
        # distilled model, a full reload) on every multi-turn iteration.
        self._load_opponent()
        while needs_opp:
            # 1. Batch Inference: workers publish their opponent-side
            # observation/mask via info["opp_obs"] / info["opp_masks"].
            batch_obs = np.stack([infos[i]["opp_obs"] for i in needs_opp])
            batch_masks = np.stack([infos[i]["opp_masks"] for i in needs_opp])
            if self.opponent_model is not None:
                with torch.no_grad():
                    # Distilled student is a bare nn.Module; SB3 models are not.
                    is_student = isinstance(self.opponent_model, torch.nn.Module)
                    device = (
                        next(self.opponent_model.parameters()).device
                        if is_student
                        else self.opponent_model.device
                    )
                    obs_tensor = torch.as_tensor(batch_obs, device=device)
                    masks_tensor = torch.as_tensor(batch_masks, device=device)
                    # Match the FP16 weights applied in _load_opponent.
                    if not is_student and os.getenv("ENABLE_QUANTIZATION", "0") == "1":
                        obs_tensor = obs_tensor.half()
                    # StudentActor and MaskablePPO expose the same predict()
                    # signature, so a single call serves both (the original
                    # duplicated this call in two identical branches).
                    actions, _ = self.opponent_model.predict(
                        obs_tensor, action_masks=masks_tensor, deterministic=False
                    )
            else:
                # No model loaded: uniform random choice over legal actions.
                actions = [np.random.choice(np.where(m)[0]) for m in batch_masks]
            # 2. Step Opponents in parallel: fan out, then collect.
            for i, idx in enumerate(needs_opp):
                self.remotes[idx].send(("env_method", ("step_opponent", (actions[i],), {})))
            for i, idx in enumerate(needs_opp):
                # Worker replies with (obs, reward, term, trunc, info).
                obs[idx], rews[idx], terms[idx], truncs[idx], infos[idx] = self.remotes[idx].recv()
            # 3. Repeat if multi-turn
            needs_opp = [i for i, info in enumerate(infos) if info.get("needs_opponent", False)]
        return obs, rews, terms, truncs, infos