# Provenance: uploaded to the Hugging Face Hub via huggingface_hub (commit 463f868, verified).
import os
import time

import numpy as np
import torch
from sb3_contrib import MaskablePPO
from stable_baselines3.common.vec_env import SubprocVecEnv

from ai.student_model import StudentActor  # Safe import now
class BatchedSubprocVecEnv(SubprocVecEnv):
    """
    A specialized SubprocVecEnv that handles batched opponent inference
    in the main process (GPU) instead of individual workers (CPU).
    Uses standard worker commands (env_method) for better Windows stability.
    """

    def __init__(self, env_fns, opponent_model_path=None):
        """
        Args:
            env_fns: list of env-constructor callables, forwarded verbatim
                to SubprocVecEnv.
            opponent_model_path: optional path to a MaskablePPO checkpoint
                used for opponent moves; hot-reloaded when its mtime changes.
        """
        super().__init__(env_fns)
        self.opponent_model = None
        self.opponent_path = opponent_model_path
        # mtime of the last successfully loaded SB3 checkpoint (0 = never).
        self.last_load_time = 0
        # Separate mtime cache for the distilled student checkpoint, so the
        # two load paths never invalidate each other's cache.
        self._distilled_load_time = 0
        self._load_opponent()

    def _load_opponent(self):
        """Load (or hot-reload) the opponent model.

        Preference order:
          1. Distilled student at checkpoints/student_model.pth when the
             USE_DISTILLED_OPPONENT env var is "1".
          2. MaskablePPO checkpoint at self.opponent_path.
        Both paths are mtime-cached, so repeated calls are cheap no-ops.
        Load failures never raise; with no model loaded, opponents fall
        back to uniform-random legal actions (see _handle_opponent_moves).
        """
        use_distilled = os.getenv("USE_DISTILLED_OPPONENT", "0") == "1"
        if use_distilled:
            # Path to student model
            distilled_path = os.path.join("checkpoints", "student_model.pth")
            if os.path.exists(distilled_path):
                try:
                    mtime = os.path.getmtime(distilled_path)
                    # Fix: cache by mtime. The original re-read and unpickled
                    # the checkpoint on every call, which is expensive in the
                    # per-step hot path.
                    if self.opponent_model is not None and mtime <= self._distilled_load_time:
                        return
                    print(f" [BatchedEnv] Loading DISTILLED opponent from {distilled_path}...", flush=True)
                    # SECURITY NOTE: weights_only=False executes arbitrary
                    # pickled code -- only load checkpoints you trust.
                    self.opponent_model = torch.load(distilled_path, weights_only=False)
                    self.opponent_model.eval()  # Inference mode
                    # Move to CUDA if available
                    if torch.cuda.is_available():
                        self.opponent_model.to("cuda")
                    self._distilled_load_time = mtime
                    print(" [BatchedEnv] Loaded Distilled Student (Fast Mode).", flush=True)
                    return
                except Exception as e:
                    print(f" [BatchedEnv] Failed to load distilled model: {e}. Falling back...", flush=True)
        if self.opponent_path and os.path.exists(self.opponent_path):
            try:
                mtime = os.path.getmtime(self.opponent_path)
                if mtime > self.last_load_time:
                    print(f" [BatchedEnv] Loading opponent model from {self.opponent_path}...", flush=True)
                    # For opponent, we can use CPU for inference if GPU is packed,
                    # but GPU is preferred. Let's stick to cuda for now.
                    self.opponent_model = MaskablePPO.load(self.opponent_path, device="cuda")
                    # Apply Dynamic Quantization if enabled
                    if os.getenv("ENABLE_QUANTIZATION", "0") == "1":
                        print(" [BatchedEnv] Applying dynamic quantization to opponent...", flush=True)
                        try:
                            # SB3 models wrap the policy in .policy
                            # True dynamic quantization is CPU-only in PyTorch;
                            # for GPU inference we use half precision (FP16)
                            # instead of INT8/TensorRT for simplicity/safety.
                            self.opponent_model.policy.to(torch.float16)  # Use FP16 weights
                            print(" [BatchedEnv] Converted opponent to FP16.", flush=True)
                        except Exception as qe:
                            print(f" [BatchedEnv] Quantization failed: {qe}", flush=True)
                    self.last_load_time = mtime
                    print(f" [BatchedEnv] Successfully loaded opponent model.", flush=True)
            except Exception as e:
                print(f" [BatchedEnv] Warning: Failed to load opponent: {e}", flush=True)
                # We don't crash the whole thing, we just fallback to random

    def reset(self):
        """Reset all workers; if any env opens with the opponent to move,
        resolve those opponent moves before returning the first observation.
        Rewards/terminations produced during these pre-game opponent moves
        are intentionally discarded -- only the observation matters here."""
        obs = super().reset()
        # Pull infos to check for immediate opponent moves after reset
        infos = self.env_method("get_current_info")
        needs_opp = [i for i, info in enumerate(infos) if info.get("needs_opponent", False)]
        if needs_opp:
            # Convert to mutable lists/arrays for assignment in _handle_opponent_moves
            obs_list = list(obs)
            rews_list = [0.0] * self.num_envs
            terms_list = [False] * self.num_envs
            truncs_list = [False] * self.num_envs
            infos_list = list(infos)
            obs_list, _, _, _, infos_list = self._handle_opponent_moves(
                obs_list, rews_list, terms_list, truncs_list, infos_list, needs_opp
            )
            obs = np.stack(obs_list)
        return obs

    def step_wait(self):
        """Collect worker step results, resolve any pending opponent moves
        in a single batched inference, and inject timing diagnostics into
        infos[0]. Returns a 5-tuple (Gymnasium API) or 4-tuple (old SB3 API)
        matching whatever the parent class returned."""
        start_wait = time.perf_counter()
        # 1. Main Agent Step (Distributed)
        # super().step_wait() returns (obs, rews, terms, truncs, infos) for
        # Gymnasium-style envs and (obs, rews, terms, infos) for old SB3 envs;
        # we support both and mirror the arity on return.
        res = super().step_wait()
        if len(res) == 5:
            obs, rews, terms, truncs, infos = res
        else:  # Handle old SB3 API if it returns 4 elements
            obs, rews, terms, infos = res
            truncs = np.zeros_like(terms, dtype=bool)
        # Convert to mutable structures
        obs_list = list(obs)
        rews_list = list(rews)
        terms_list = list(terms)
        truncs_list = list(truncs)
        infos_list = list(infos)
        # 2. Opponent Moves
        needs_opp = [i for i, info in enumerate(infos_list) if info.get("needs_opponent", False)]
        start_opp = time.perf_counter()
        if needs_opp:
            obs_list, rews_list, terms_list, truncs_list, infos_list = self._handle_opponent_moves(
                obs_list, rews_list, terms_list, truncs_list, infos_list, needs_opp
            )
        obs = np.stack(obs_list)
        rews = np.array(rews_list)
        terms = np.array(terms_list)
        truncs = np.array(truncs_list)
        infos = tuple(infos_list)
        opp_time = time.perf_counter() - start_opp
        total_time = time.perf_counter() - start_wait
        # Inject Debug Timing
        if len(infos) > 0:
            # Ensure infos[0] is mutable (dict) before modifying
            if not isinstance(infos[0], dict):
                infos_list[0] = dict(infos[0])  # Convert to dict if it's a tuple/other immutable
                infos = tuple(infos_list)  # Update the infos tuple
            infos[0]["t_opp_batch"] = opp_time
            infos[0]["t_total_batch"] = total_time
            # Aggregate worker times if available (from gym_env instrumentation)
            eng_times = [info.get("time_engine", 0) for info in infos]
            obs_times = [info.get("time_obs", 0) for info in infos]
            # Fix: len(infos) > 0 is guaranteed by the guard above, so divide
            # by the true count instead of biasing the average with +1e-6.
            infos[0]["t_worker_eng_avg"] = sum(eng_times) / len(eng_times)
            infos[0]["t_worker_obs_avg"] = sum(obs_times) / len(obs_times)
        if len(res) == 5:
            return obs, rews, terms, truncs, infos
        else:
            return obs, rews, terms, infos

    def _handle_opponent_moves(self, obs, rews, terms, truncs, infos, needs_opp):
        """Resolve pending opponent moves for the env indices in needs_opp.

        Runs one batched model inference per round of moves, dispatches
        the chosen actions to the workers via `step_opponent`, and repeats
        until no env reports "needs_opponent" (multi-turn opponents).
        All six arguments are mutable sequences updated in place by env
        index; the same five data sequences are also returned.
        """
        # Fix: hot-reload check once per call instead of once per loop
        # iteration -- the mtime caches make repeats redundant anyway.
        self._load_opponent()
        while needs_opp:
            # 1. Batch Inference
            batch_obs = np.stack([infos[i]["opp_obs"] for i in needs_opp])
            batch_masks = np.stack([infos[i]["opp_masks"] for i in needs_opp])
            if self.opponent_model:
                with torch.no_grad():
                    # Check if model is SB3 or Custom Torch Student
                    is_student = isinstance(self.opponent_model, torch.nn.Module)
                    device = next(self.opponent_model.parameters()).device if is_student else self.opponent_model.device
                    obs_tensor = torch.as_tensor(batch_obs, device=device)
                    masks_tensor = torch.as_tensor(batch_masks, device=device)
                    # Match the FP16 weights produced by the quantization path
                    # in _load_opponent (SB3 model only; the student keeps FP32).
                    if not is_student and os.getenv("ENABLE_QUANTIZATION", "0") == "1":
                        obs_tensor = obs_tensor.half()
                    # Fix: StudentActor and MaskablePPO expose the same
                    # predict() signature, so one call replaces the original
                    # duplicated if/else branches.
                    actions, _ = self.opponent_model.predict(
                        obs_tensor, action_masks=masks_tensor, deterministic=False
                    )
            else:
                # No model available: sample uniformly among legal actions.
                actions = [np.random.choice(np.where(m)[0]) for m in batch_masks]
            # 2. Step Opponents in parallel: fan out all sends first, then
            # collect all replies, so workers run concurrently.
            for i, idx in enumerate(needs_opp):
                self.remotes[idx].send(("env_method", ("step_opponent", (actions[i],), {})))
            for idx in needs_opp:
                res = self.remotes[idx].recv()
                # res is (obs, reward, term, trunc, info)
                obs[idx], rews[idx], terms[idx], truncs[idx], infos[idx] = res
            # 3. Repeat if multi-turn
            needs_opp = [i for i, info in enumerate(infos) if info.get("needs_opponent", False)]
        return obs, rews, terms, truncs, infos