from contextlib import contextmanager
import os
import re
import sys
from time import time

import torch

from rb import ReplayBuffer, SocketManager


@contextmanager
def timer(header):
    """Print ``header`` followed by the wall-clock seconds spent in the block.

    Nothing is printed if the block raises.
    """
    time_start = time()
    yield
    print(header, (time() - time_start))


class DataLoader:
    """Iterable that yields replay-buffer samples, one iteration per pass.

    Each ``for sample in loader`` pass handles iteration ``self._iter``.

    * In self-play mode it blocks on ``SocketManager.wait()``.
    * Otherwise it sends the iteration number over a ZeroMQ DEALER socket.

    In both cases it then adds that iteration's SGF games to the replay
    buffer and yields ``self._total`` samples from it.
    """

    def __init__(
        self,
        port: int,
        cpus: int,
        batch_size: int,
        sgf_prefix: str,
        is_selfplay: bool,
        max_iters: int = 20,
        moves_per_iter: int = 5000 * 60,
    ):
        """Start the replay buffer and socket manager.

        Args:
            port: TCP port given to ``SocketManager.run``. When not in
                self-play mode, a ZeroMQ DEALER socket also connects to
                ``tcp://127.0.0.1:<port>``.
            cpus: forwarded to ``ReplayBuffer.run``.
            batch_size: forwarded to ``ReplayBuffer.run``. Also used to derive
                the number of samples yielded per iteration.
            sgf_prefix: directory passed to ``ReplayBuffer.add_iter`` on
                every iteration of ``__iter__``.
            is_selfplay: selects the ``sock.wait()`` path (True) or the
                ZeroMQ path (False) in ``__iter__``.
            max_iters: assigned to ``ReplayBuffer.max_iters``. ``load`` also
                uses it to decide which iterations to re-read. Presumably this
                is the replay window size, counted in iterations (TODO
                confirm).
            moves_per_iter: assigned to ``ReplayBuffer.moves_per_iter``.
        """
        seed = torch.randint(0, 2**32, [1]).item()

        rb = ReplayBuffer()
        rb.max_iters = max_iters
        # moves_per_iter = #games * #moves_per_game
        rb.moves_per_iter = moves_per_iter
        rb.run(seed, cpus, batch_size)
        # Store immediately so __del__ can still terminate the buffer if
        # anything below raises.
        self._rb = rb

        sock = SocketManager()
        sock.run(port)
        self._sock = sock

        # total = #moves_per_game * #isom / batch_size
        # NOTE(review): the "* 1" appears to be the #isom factor (number of
        # board isomorphisms used) — confirm.
        self._total = int(rb.moves_per_iter * 1 / batch_size)
        self._sgf_prefix = sgf_prefix
        self._iter = 0
        self._is_selfplay = is_selfplay

        if not is_selfplay:
            import zmq

            self._ctx = zmq.Context()
            zsock = self._ctx.socket(zmq.DEALER)
            zsock.setsockopt(zmq.LINGER, 0)
            zsock.setsockopt(zmq.ROUTING_ID, b"0")
            zsock.connect(f"tcp://127.0.0.1:{port}")
            # Initial message to the peer; presumably a registration/handshake
            # (TODO confirm the protocol on the other side).
            zsock.send_multipart([b"0", b"0"])
            self._zsock = zsock

    def load(self, sgf_prefix, epoch_ckpt):
        """Fast-forward the loader through ``epoch_ckpt`` iterations.

        Presumably this is used when resuming from a checkpoint. For each
        iteration it does the following:

        * Call ``sock.notify()``, except before iteration 0.
        * Re-read the SGF files of that iteration from ``sgf_prefix``. This
          happens only for the very first loop index and for the last
          ``rb.max_iters`` indices.
        * Draw and immediately free ``self._total`` samples.
        * Advance ``self._iter`` by one.

        Raises:
            ValueError: if ``sgf_prefix`` contains no SGF file for an
                iteration that has to be re-read.
        """
        rb, sock = self._rb, self._sock
        for i in range(epoch_ckpt):
            if self._iter > 0:
                sock.notify()
            if i == 0 or i + rb.max_iters >= epoch_ckpt:
                print(f"[{i:3d}] Load selfplay")
                # Files are named "...iter-<iteration>-<node>.sgf"; the node
                # count is the highest node index plus one.
                pattern = re.compile(rf"iter-{self._iter}-(\d+)\.sgf")
                node_ids = [
                    int(m.group(1))
                    for f in os.listdir(sgf_prefix)
                    if (m := pattern.search(f))
                ]
                if not node_ids:
                    raise ValueError(
                        f"no SGF files for iteration {self._iter} in {sgf_prefix!r}"
                    )
                rb.add_iter(sgf_prefix, self._iter, 1 + max(node_ids))
            # Samples are drawn for every iteration. Presumably this keeps the
            # buffer's sampling state in step with the original run
            # (TODO confirm).
            for _ in range(self._total):
                rb.sample().free()
            self._iter += 1

    def __del__(self):
        """Terminate the replay buffer, socket manager and ZeroMQ resources."""
        # __init__ may have failed part-way; only tear down what exists.
        rb = getattr(self, "_rb", None)
        if rb is not None:
            rb.terminate()
        sock = getattr(self, "_sock", None)
        if sock is not None:
            sock.terminate()
        ctx = getattr(self, "_ctx", None)
        if ctx is not None:
            # destroy() closes every socket of the context, including one that
            # was created but never stored, so term() cannot block.
            ctx.destroy(linger=0)

    def __iter__(self):
        """Run one iteration and yield ``self._total`` replay-buffer samples.

        Steps:

        1. Unless this is the first iteration, call ``sock.notify()``.
        2. In self-play mode, block on ``sock.wait()`` and exit the process if
           it returns truthy. Otherwise send ``[iteration,
           moves_per_iter + 1]`` as UTF-8 strings on the DEALER socket.
        3. Add iteration ``self._iter`` (``sock.nodes`` nodes) from
           ``sgf_prefix`` to the buffer.
        4. Yield the samples. Each sample is freed when the consumer resumes
           the generator, or when it closes the generator early.
        """
        rb, sock = self._rb, self._sock
        if self._iter > 0:
            sock.notify()
        if self._is_selfplay:
            with timer("[{:3d}] Time for selfplay:".format(self._iter)):
                if sock.wait():
                    # NOTE(review): the original comment says SIGINT; presumably
                    # wait() returns truthy when interrupted.
                    sys.exit(0)  # SIGINT
        else:
            finished = rb.moves_per_iter + 1
            self._zsock.send_multipart(
                [bytes(str(self._iter), "utf-8"), bytes(str(finished), "utf-8")]
            )
        with timer("[{:3d}] Time for training:".format(self._iter)):
            rb.add_iter(self._sgf_prefix, self._iter, sock.nodes)
            for _ in range(self._total):
                sample = rb.sample()
                try:
                    yield sample
                finally:
                    # Also runs on generator close() so the sample never leaks.
                    sample.free()
        self._iter += 1