File size: 2,842 Bytes
026a224 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from contextlib import contextmanager
import os
import re
import sys
from time import perf_counter
from time import time

import torch

from rb import ReplayBuffer, SocketManager
@contextmanager
def timer(header):
    """Print ``header`` followed by the seconds spent inside the ``with`` body.

    Uses ``perf_counter`` (monotonic, high resolution) instead of
    ``time.time``, which can jump when the system clock is adjusted and so
    is unsuitable for measuring durations.

    The elapsed time is printed only when the body finishes without raising.
    """
    start = perf_counter()
    yield
    print(header, perf_counter() - start)
class DataLoader:
    """Iterable over ReplayBuffer samples, one training iteration per pass.

    Every ``iter()`` call runs one iteration. It first either waits for
    selfplay (``is_selfplay=True``) or reports the iteration as finished over
    a local ZeroMQ socket (``is_selfplay=False``). It then calls
    ``ReplayBuffer.add_iter`` for that iteration and yields ``self._total``
    samples, freeing each one when the consumer asks for the next.

    Args:
        port: Port passed to ``SocketManager.run``. In non-selfplay mode a
            ZeroMQ DEALER socket also connects to ``tcp://127.0.0.1:<port>``.
        cpus: Forwarded to ``ReplayBuffer.run``.
        batch_size: Forwarded to ``ReplayBuffer.run``. It also divides
            ``moves_per_iter`` to give the number of samples per iteration.
        sgf_prefix: Passed to ``ReplayBuffer.add_iter`` while iterating.
        is_selfplay: Chooses between waiting on the SocketManager and the
            ZeroMQ announcement described above.
        max_iters: Assigned to ``ReplayBuffer.max_iters`` (default 20). It
            also bounds how many trailing iterations ``load`` replays.
        moves_per_iter: Assigned to ``ReplayBuffer.moves_per_iter``
            (default 5000 games * 60 moves per game).
    """

    def __init__(self, port, cpus, batch_size, sgf_prefix, is_selfplay,
                 *, max_iters=20, moves_per_iter=5000 * 60):
        # Random seed in [0, 2**32) for the replay buffer.
        seed = torch.randint(0, 2**32, [1]).item()

        rb = ReplayBuffer()
        rb.max_iters = max_iters
        # moves_per_iter = #games * #moves_per_game
        rb.moves_per_iter = moves_per_iter
        rb.run(seed, cpus, batch_size)
        # Assign right after run() succeeds so __del__ can still terminate
        # the buffer if a later step of __init__ raises.
        self._rb = rb

        sock = SocketManager()
        sock.run(port)
        self._sock = sock

        isometries = 1
        # total = moves_per_iter * #isometries / batch_size
        self._total = rb.moves_per_iter * isometries // batch_size
        self._sgf_prefix = sgf_prefix
        self._iter = 0
        self._is_selfplay = is_selfplay

        if not is_selfplay:
            import zmq  # only needed when there is no live selfplay

            self._ctx = zmq.Context()
            zsock = self._ctx.socket(zmq.DEALER)
            zsock.setsockopt(zmq.LINGER, 0)
            zsock.setsockopt(zmq.ROUTING_ID, b"0")
            zsock.connect(f"tcp://127.0.0.1:{port}")
            # NOTE(review): presumably registers this client with the
            # SocketManager -- confirm against the rb module.
            zsock.send_multipart([b"0", b"0"])
            self._zsock = zsock

    def load(self, sgf_prefix, epoch_ckpt):
        """Replay ``epoch_ckpt`` already-played iterations from SGF files.

        Advances ``self._iter`` by ``epoch_ckpt``. Only the first replayed
        iteration and the last ``rb.max_iters`` ones are fed to the replay
        buffer; the others are skipped. For each fed iteration,
        ``self._total`` samples are drawn and freed immediately. Presumably
        used when resuming from a checkpoint -- TODO confirm with callers.

        Args:
            sgf_prefix: Directory holding files named ``iter-<k>-<n>.sgf``.
            epoch_ckpt: Number of iterations to replay.

        Raises:
            ValueError: If an iteration that must be fed has no SGF files.
            OSError: If ``sgf_prefix`` cannot be listed.
        """
        rb, sock = self._rb, self._sock
        # One listing serves every iteration, since the directory can hold
        # many files. Skip it entirely when nothing will be loaded.
        names = os.listdir(sgf_prefix) if epoch_ckpt > 0 else []
        for i in range(epoch_ckpt):
            if self._iter > 0:
                sock.notify()  # same per-iteration notify as in __iter__
            if i == 0 or i + rb.max_iters >= epoch_ckpt:
                print(f"[{i:3d}] Load selfplay")
                # nodes = 1 + the largest <n> among iter-<iter>-<n>.sgf. It
                # fills the role that ``sock.nodes`` plays in __iter__. The
                # dot is escaped so only a literal ".sgf" matches.
                pattern = re.compile(rf"iter-{self._iter}-(\d+)\.sgf")
                last = max(
                    (int(m.group(1)) for f in names if (m := pattern.search(f))),
                    default=None,
                )
                if last is None:
                    raise ValueError(
                        f"no SGF files for iteration {self._iter} "
                        f"in {sgf_prefix!r}"
                    )
                rb.add_iter(sgf_prefix, self._iter, last + 1)
                for _ in range(self._total):
                    rb.sample().free()
            self._iter += 1

    def __del__(self):
        """Call ``terminate()`` on the buffer and socket manager, then release ZeroMQ.

        Tolerates a partially constructed instance: ``__del__`` also runs
        when ``__init__`` raised before every attribute was assigned.
        """
        rb = getattr(self, "_rb", None)
        if rb is not None:
            rb.terminate()
        sock = getattr(self, "_sock", None)
        if sock is not None:
            sock.terminate()
        ctx = getattr(self, "_ctx", None)
        if ctx is not None:
            # Closes every socket of the context (the DEALER one included)
            # without lingering, then terminates the context. The original
            # code never released either.
            ctx.destroy(linger=0)

    def __iter__(self):
        """Run one iteration and yield its samples.

        Yields:
            ``self._total`` objects returned by ``ReplayBuffer.sample``. Each
            is freed once the consumer requests the next one.

        In selfplay mode, exits the process with status 0 when ``sock.wait()``
        returns a true value (the original code marks this case as SIGINT).
        """
        rb, sock = self._rb, self._sock
        if self._iter > 0:
            sock.notify()

        if self._is_selfplay:
            with timer("[{:3d}] Time for selfplay:".format(self._iter)):
                if sock.wait():
                    sys.exit(0)  # SIGINT
        else:
            # NOTE(review): one more than moves_per_iter -- presumably tells
            # the SocketManager this iteration is complete; confirm.
            finished = rb.moves_per_iter + 1
            self._zsock.send_multipart(
                [bytes(str(self._iter), "utf-8"), bytes(str(finished), "utf-8")]
            )

        with timer("[{:3d}] Time for training:".format(self._iter)):
            rb.add_iter(self._sgf_prefix, self._iter, sock.nodes)
            for _ in range(self._total):
                sample = rb.sample()
                yield sample
                # NOTE: if the consumer stops early, the generator is closed
                # at the yield above, so the last sample is not freed here
                # and self._iter is not advanced.
                sample.free()
        self._iter += 1
|