File size: 1,864 Bytes
1fab54b
 
 
 
 
 
 
 
 
 
e153fcc
1fab54b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e153fcc
 
 
 
 
 
 
 
 
1fab54b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from game import Connect4
from model import Model
from config import Config
from pydantic import BaseModel
from typing import List, Union
import numpy as np
from arena import get_move_for_bot
import torch
import time

class Request(BaseModel):
    """Request body accepted by the ``/get_move`` endpoint.

    ``randomMoves`` and ``mctsIterations`` default to ``None`` explicitly so
    that clients may omit them regardless of the installed pydantic major
    version: pydantic v2 treats a ``Union[None, T]`` field without a default
    as required, whereas pydantic v1 made it implicitly optional.
    """

    # Board grid as nested lists of ints; the handler turns it into a NumPy array.
    board: List[List[int]]
    # Side to move; the handler swaps the two players when this is "yellow".
    currentPlayer: str
    # Forwarded to get_move_for_bot as ``random_move``; None when omitted.
    randomMoves: Union[None, bool] = None
    # Forwarded to get_move_for_bot as ``tree_iters``; None when omitted.
    mctsIterations: Union[None, int] = None

# FastAPI application object that every route and middleware below attaches to
app = FastAPI()

# CORS policy: any origin and any request header, GET/POST only, credentials on.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive -- confirm whether an explicit origin list should be used instead.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)

# Build the network from the shared Config values
model = Model(
    n_action = Config.n_action,
    num_hidden = Config.num_hidden,
    num_resblock = Config.num_res_block,
    rate = Config.rate,
    row = Config.row,
    col = Config.col,
    device = Config.device
)
# Restore the trained weights. map_location remaps tensors that were saved on
# a different device (e.g. a GPU checkpoint opened on a CPU-only host) onto
# the configured device instead of failing during deserialization.
model.load_state_dict(
    torch.load(Config.checkpoint_path, map_location=Config.device)
)
# The server only runs inference, so switch the model to evaluation mode
model.eval()

# Middleware that reports how long each request took to handle
@app.middleware("http")
async def add_process_time_header(req, call_next):
    """Print the time spent producing the response for each HTTP request.

    Despite its name, this middleware does not add a header; the elapsed time
    is only printed to stdout and the response is returned unchanged.

    Args:
        req: The incoming request, passed through to ``call_next``.
        call_next: Awaitable callable that takes the request and yields the
            response.

    Returns:
        The response obtained from ``call_next``, unmodified.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # (or go negative) when the system wall clock is adjusted mid-request.
    start_time = time.perf_counter()
    response = await call_next(req)
    process_time = time.perf_counter() - start_time
    print(f'Time taken for response = {process_time:.2f} seconds')
    return response

@app.get("/")
def root():
    """Return a fixed placeholder message for GET requests to ``/``."""
    payload = {"message": "This is a temporary response"}
    return payload

@app.post("/get_move")
def get_move(req: Request):
    """Pick the bot's move for the position described in the request body.

    The submitted board is converted to a NumPy array and installed on a new
    ``Connect4`` instance, the two players are swapped when the current
    player is "yellow", and ``get_move_for_bot`` is asked for an action using
    the module-level ``model``.

    Args:
        req: Parsed request body (board, current player, search options).

    Returns:
        A dict ``{'move': <int>}`` holding the chosen action cast to ``int``.
    """
    grid = np.array(req.board)
    state = Connect4()
    state.board = grid

    # Exchange the player attributes when the request says yellow is to move
    if req.currentPlayer == "yellow":
        state.player_1, state.player_2 = state.player_2, state.player_1

    # The search options are taken directly from the request and may be None
    chosen = get_move_for_bot(
        state=state,
        model=model,
        tree_iters=req.mctsIterations,
        random_move=req.randomMoves,
    )

    return {'move': int(chosen)}