# PIWM / src / game / game.py
# Commit c64c726 (musictimer): Initial Diamond CSGO AI deployment
from typing import Tuple, Union
import numpy as np
import pygame
from PIL import Image
import cv2
import os
from datetime import datetime
from csgo.action_processing import CSGOAction
from .dataset_env import DatasetEnv
from .play_env import PlayEnv
class Game:
    """Full-screen pygame loop that steps an environment from live input and records it.

    Each iteration gathers keyboard/mouse events, wraps them in a ``CSGOAction``,
    steps the environment, draws the current observation (plus an optional text
    header), and appends the drawn frame to an MP4 file and a PNG frame dump.
    """

    # Original hard-coded output locations, kept as defaults so that existing
    # callers see exactly the same behaviour as before.
    DEFAULT_VIDEO_DIR = "/home/alienware3/Documents/diamond/videos"
    DEFAULT_FRAME_DIR = "/home/alienware3/Documents/diamond/frames"

    def __init__(
        self,
        play_env: Union[PlayEnv, DatasetEnv],
        size: Tuple[int, int],
        mouse_multiplier: int,
        fps: int,
        verbose: bool,
        video_dir: str = DEFAULT_VIDEO_DIR,
        frame_dir: str = DEFAULT_FRAME_DIR,
    ) -> None:
        """Store the configuration, print the controls and wait for the user.

        Args:
            play_env: Environment to drive; must provide ``print_controls``,
                ``reset``, ``step`` and the ``next_*``/``prev_*`` switches
                used in :meth:`run`.
            size: ``(height, width)`` in pixels of the displayed observation.
            mouse_multiplier: Factor applied to relative mouse motion.
            fps: Target frame rate of the loop and of the recorded video.
            verbose: Whether to draw the information header above the image.
            video_dir: Directory that receives the recorded ``.mp4`` file.
            frame_dir: Directory that receives one ``.png`` per recorded frame.
        """
        self.env = play_env
        self.height, self.width = size
        self.mouse_multiplier = mouse_multiplier
        self.fps = fps
        self.verbose = verbose
        self.video_dir = video_dir
        self.frame_dir = frame_dir

        self.env.print_controls()
        print("\nControls:\n")
        print(" m : switch control (human/replay)")  # Not for main as Game can use either PlayEnv or DatasetEnv
        print(" . : pause/unpause")
        print(" e : step-by-step (when paused)")
        print(" ⏎ : reset env")
        print("Esc : quit")
        print("\n")
        input("Press enter to start")

    def run(self) -> None:
        """Run the interactive loop until Esc is pressed or the window is closed.

        The video writer is released and pygame is shut down even if the loop
        raises, so a crash does not leave an unfinalised video file or a
        grabbed, hidden mouse cursor behind.
        """
        pygame.init()

        video_writer = None
        video_filename = None
        # Only the number of recorded frames is needed (file names, summary),
        # so keep a counter instead of retaining every frame in memory.
        frame_count = 0

        try:
            header_height = 150 if self.verbose else 0
            header_width = 540
            font_size = 16

            screen = pygame.display.set_mode((0, 0), pygame.FULLSCREEN)
            pygame.mouse.set_visible(False)
            pygame.event.set_grab(True)
            clock = pygame.time.Clock()
            font = pygame.font.SysFont("mono", font_size)

            x_center, y_center = screen.get_rect().center
            x_header = x_center - header_width // 2
            y_header = y_center - self.height // 2 - header_height - 10
            header_rect = pygame.Rect(x_header, y_header, header_width, header_height)

            def clear_header():
                """Blank the header area and redraw its white outline."""
                pygame.draw.rect(screen, pygame.Color("black"), header_rect)
                pygame.draw.rect(screen, pygame.Color("white"), header_rect, 1)

            def draw_text(text, idx_line, idx_column, num_cols):
                """Render ``text`` at a (line, column) cell of the header grid."""
                x_pos = 5 + idx_column * int(header_width // num_cols)
                y_pos = 5 + idx_line * font_size
                assert (0 <= x_pos <= header_width) and (0 <= y_pos <= header_height)
                screen.blit(
                    font.render(text, True, pygame.Color("white")),
                    (x_header + x_pos, y_header + y_pos),
                )

            def to_uint8_hwc(batch):
                """Convert the single image of a (1, C, H, W) batch to an HxWxC uint8 array.

                Applies ``(x + 1) / 2 * 255``; presumably the input range is
                [-1, 1] -- confirm against the environment.
                """
                return batch[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()

            def draw_obs(obs, obs_low_res=None):
                """Blit ``obs`` (and optionally a low-res inset) and return the main frame.

                Returns:
                    The resized main observation as an ``(height, width, 3)``
                    RGB array. The inset never affects the returned frame, so
                    the recorded frame size stays constant.
                """
                assert obs.ndim == 4 and obs.size(0) == 1
                img = Image.fromarray(to_uint8_hwc(obs))
                main_image = np.array(img.resize((self.width, self.height), resample=Image.BICUBIC))
                # pygame surfaces are indexed (x, y), hence the axis swap.
                surface = pygame.surfarray.make_surface(main_image.transpose((1, 0, 2)))
                screen.blit(surface, (x_center - self.width // 2, y_center - self.height // 2))

                if obs_low_res is not None:
                    assert obs_low_res.ndim == 4 and obs_low_res.size(0) == 1
                    low_img = Image.fromarray(to_uint8_hwc(obs_low_res))
                    # Scale the inset by the same ratio the main image was scaled.
                    h = self.height * obs_low_res.size(2) // obs.size(2)
                    w = self.width * obs_low_res.size(3) // obs.size(3)
                    low_image = np.array(low_img.resize((w, h), resample=Image.BICUBIC))
                    low_surface = pygame.surfarray.make_surface(low_image.transpose((1, 0, 2)))
                    screen.blit(low_surface, (x_header + header_width - w - 5, y_header + 5 + font_size))

                return main_image

            def reset():
                """Reset the environment and all per-episode input state."""
                nonlocal obs, info, do_reset, ep_return, ep_length, keys_pressed, l_click, r_click
                obs, info = self.env.reset()
                pygame.event.clear()
                do_reset = False
                ep_return = 0
                ep_length = 0
                keys_pressed = []
                l_click = r_click = False

            obs, info, do_reset, ep_return, ep_length, keys_pressed, l_click, r_click = (None,) * 8

            reset()
            do_wait = False
            should_stop = False

            # Prepare the output locations once, outside the frame loop.
            os.makedirs(self.video_dir, exist_ok=True)
            os.makedirs(self.frame_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            video_filename = os.path.join(self.video_dir, f"game_recording_{timestamp}.mp4")
            recording_fps = self.fps

            while not should_stop:
                do_one_step = False
                mouse_x, mouse_y = 0, 0
                pygame.event.pump()

                for event in pygame.event.get():
                    if event.type == pygame.QUIT or (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE):
                        should_stop = True

                    if event.type == pygame.MOUSEMOTION:
                        mouse_x, mouse_y = event.rel
                        mouse_x *= self.mouse_multiplier
                        mouse_y *= self.mouse_multiplier

                    if event.type == pygame.MOUSEBUTTONDOWN:
                        if event.button == 1:
                            l_click = True
                        if event.button == 3:
                            r_click = True
                    elif event.type == pygame.MOUSEBUTTONUP:
                        if event.button == 1:
                            l_click = False
                        if event.button == 3:
                            r_click = False

                    if event.type == pygame.KEYDOWN:
                        keys_pressed.append(event.key)
                    elif event.type == pygame.KEYUP and event.key in keys_pressed:
                        keys_pressed.remove(event.key)

                    # Everything below handles control keys, which act on key-down only.
                    if event.type != pygame.KEYDOWN:
                        continue

                    if event.key == pygame.K_RETURN:
                        do_reset = True
                    if event.key == pygame.K_PERIOD:
                        do_wait = not do_wait
                        print("Game paused." if do_wait else "Game resumed.")
                    if event.key == pygame.K_e:
                        do_one_step = True
                    if event.key == pygame.K_m:
                        do_reset = self.env.next_mode()
                    if event.key == pygame.K_UP:
                        do_reset = self.env.next_axis_1()
                    if event.key == pygame.K_DOWN:
                        do_reset = self.env.prev_axis_1()
                    if event.key == pygame.K_RIGHT:
                        do_reset = self.env.next_axis_2()
                    if event.key == pygame.K_LEFT:
                        do_reset = self.env.prev_axis_2()

                if do_reset:
                    reset()

                if do_wait and not do_one_step:
                    # Throttle while paused so the loop does not spin a CPU core.
                    clock.tick(self.fps)
                    continue

                csgo_action = CSGOAction(keys_pressed, mouse_x, mouse_y, l_click, r_click)
                next_obs, rew, end, trunc, info = self.env.step(csgo_action)

                ep_return += rew.item()
                ep_length += 1

                if self.verbose and info is not None:
                    clear_header()
                    assert isinstance(info, dict) and "header" in info
                    header = info["header"]
                    num_cols = len(header)
                    for j, col in enumerate(header):
                        for i, row in enumerate(col):
                            draw_text(row, idx_line=i, idx_column=j, num_cols=num_cols)

                # `info` may be None (see the guard above), so check it before
                # the membership test to avoid a TypeError.
                draw_low_res = (
                    self.verbose
                    and info is not None
                    and "obs_low_res" in info
                    and self.width == 280
                )

                if draw_low_res:
                    current_frame = draw_obs(obs, info["obs_low_res"])
                    draw_text(" Pre-upsampling:", 0, 2, 3)
                else:
                    current_frame = draw_obs(obs, None)

                # Open the writer lazily so its size matches the first real frame.
                if video_writer is None:
                    h, w, c = current_frame.shape
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    video_writer = cv2.VideoWriter(video_filename, fourcc, recording_fps, (w, h))
                    print(f"Started recording video: {video_filename}")
                    print(f"Video dimensions: {w}x{h}, FPS: {recording_fps}")

                # OpenCV expects BGR channel order.
                frame_bgr = cv2.cvtColor(current_frame.astype(np.uint8), cv2.COLOR_RGB2BGR)
                cv2.imwrite(os.path.join(self.frame_dir, f"frame_{frame_count}.png"), frame_bgr)
                video_writer.write(frame_bgr)
                frame_count += 1

                pygame.display.flip()  # update screen
                clock.tick(self.fps)  # ensures game maintains the given frame rate

                if end or trunc:
                    reset()
                else:
                    obs = next_obs
        finally:
            try:
                # Finalize video recording, also on error paths.
                if video_writer is not None:
                    video_writer.release()
                    print(f"Video saved successfully: {video_filename}")
                    print(f"Total frames recorded: {frame_count}")
            finally:
                pygame.quit()