# connect-4-API / main2.py
# Gruhit Patel
# init-backend
# 1fab54b verified
from model import Model
from config import Config
from arena import get_move_for_bot
from game import Connect4
import pygame
from view_board import draw_board, draw_winning_line
import sys
import torch
def play_game(model: Model):
    """Run an interactive Connect-4 game between the AI and a human.

    The AI moves first; its column is chosen by ``get_move_for_bot`` using
    ``model`` and ``Config.tree_iter``. The human answers by clicking inside
    the window: the clicked x-coordinate is mapped to a board column. While
    the game is running, moving the mouse shows a coloured disc in the strip
    above the board.

    Args:
        model: Network handed to ``get_move_for_bot`` for every AI move.

    Returns:
        Never returns normally. The loop runs until the window is closed,
        at which point the process exits through ``sys.exit()``.
    """
    # Pixel size of one board cell. The window dimensions, the click-to-column
    # mapping and the hover strip all derive from this single value.
    # NOTE(review): presumably must match the cell size used by view_board's
    # draw_board -- confirm before changing.
    cell = 100
    width = Config.col * cell
    hover_centre_y = cell // 2
    hover_radius = cell // 2

    board = Connect4(row=Config.row, col=Config.col)

    pygame.init()
    # One extra row on top is reserved for the hover disc.
    screen = pygame.display.set_mode((width, (Config.row + 1) * cell))

    ai_turn = True
    game_end = False

    while True:
        draw_board(screen, board.board)
        draw_winning_line(screen, board.winning_start, board.winning_end)

        if ai_turn and not game_end:
            act = get_move_for_bot(board, model, Config.tree_iter)
            # drop_piece hands back the board to continue with, plus a value
            # that is not None once the move has ended the game.
            board, win = board.drop_piece(act)
            if win is not None:
                print("AI has WON")
                print("Board \n")
                print(board)
                print("Winner is...", win)
                game_end = True
            ai_turn = False
            draw_board(screen, board.board)

        pygame.display.update()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                # Shut pygame down cleanly before terminating the process.
                pygame.quit()
                sys.exit()

            if game_end:
                # Once the game is over only QUIT is still handled.
                continue

            if event.type == pygame.MOUSEBUTTONDOWN and not ai_turn:
                # The `not ai_turn` guard matters when several clicks arrive in
                # the same event batch: only the first one is the human's move,
                # the rest would otherwise be played in place of the AI.
                act = event.pos[0] // cell
                board, win = board.drop_piece(act)
                ai_turn = True
                if win is not None:
                    print("Human has Won")
                    print("Board \n")
                    print(board)
                    game_end = True
            elif event.type == pygame.MOUSEMOTION:
                # Clear the whole hover strip (full window width, not a fixed
                # 700 px) and redraw the disc under the cursor.
                pygame.draw.rect(screen, (0, 0, 0), (0, 0, width, cell))
                # Colour depends on board.player_1; original note: if the AI
                # takes turn 1 then the human's turn is second.
                if board.player_1 == -1:
                    disc_colour = (230, 230, 20)
                else:
                    disc_colour = (52, 186, 235)
                pygame.draw.circle(
                    screen,
                    disc_colour,
                    (event.pos[0], hover_centre_y),
                    hover_radius,
                )
                pygame.display.update()
if __name__ == "__main__":
    # Build the network from the hyper-parameters in Config, load the trained
    # weights, and start the interactive game.
    model = Model(
        n_action=Config.n_action,
        num_hidden=Config.num_hidden,
        num_resblock=Config.num_res_block,
        rate=Config.rate,
        row=Config.row,
        col=Config.col,
        device=Config.device,
    )

    # This is LR = .01 model
    # model_path = './Models/C4GruhitSPatel/FullBuffer5x5V1/TargetModel_500.pt'
    # This is LR = .001 model
    model_path = "./Models/C4GruhitML/C4CyclicLRV3/TargetModel_500.pt"

    # Load onto the CPU first so a checkpoint saved from a CUDA run can still
    # be read on a machine without a GPU; load_state_dict then copies the
    # values into the model's own parameters, wherever they live.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Inference only: switch the module to evaluation mode.
    model.eval()

    play_game(model)