File size: 2,842 Bytes
026a224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from contextlib import contextmanager
import os
import re
from time import time
import torch
from rb import ReplayBuffer, SocketManager


@contextmanager
def timer(header):
    """Print ``header`` followed by the wall-clock seconds spent in the block.

    The elapsed time is measured with ``time.time()`` and printed only when
    the managed block finishes without raising; if the block raises, the
    exception propagates and nothing is printed.

    Args:
        header: Value printed in front of the elapsed time.
    """
    began = time()
    yield
    elapsed = time() - began
    print(header, elapsed)


class DataLoader:
    """Iterable that yields sampled batches from a ``ReplayBuffer``.

    Construction starts a ``ReplayBuffer`` and a ``SocketManager``.  Each
    pass over the object (``for batch in loader``) handles one iteration:
    it synchronises with the peer behind the socket, registers that
    iteration's data with the replay buffer and yields a fixed number of
    samples.  ``load`` replays earlier iterations when resuming.
    """

    def __init__(self, port, cpus, batch_size, sgf_prefix, is_selfplay):
        """Start the replay buffer and the socket manager.

        Args:
            port: Port handed to ``SocketManager.run``.  When ``is_selfplay``
                is false, a ZeroMQ DEALER socket is additionally connected
                to ``tcp://127.0.0.1:<port>``.
            cpus: Forwarded to ``ReplayBuffer.run``.
            batch_size: Forwarded to ``ReplayBuffer.run``; also fixes how
                many samples are drawn per iteration.
            sgf_prefix: Location passed to ``ReplayBuffer.add_iter`` while
                iterating.
            is_selfplay: If true, each iteration blocks on
                ``SocketManager.wait``; otherwise a message is sent over the
                ZeroMQ socket instead of waiting.
        """
        seed = torch.randint(0, 2**32, [1]).item()
        rb = ReplayBuffer()
        rb.max_iters = 20
        # moves_per_iter = #games * #moves_per_game
        rb.moves_per_iter = 5000 * 60
        rb.run(seed, cpus, batch_size)
        # Store each resource as soon as it is running so that __del__ can
        # still shut it down if a later step of the constructor raises.
        self._rb = rb
        sock = SocketManager()
        sock.run(port)
        self._sock = sock
        # total = #moves_per_game * #isom / batch_size
        self._total = int(rb.moves_per_iter * 1 / batch_size)
        self._sgf_prefix = sgf_prefix
        self._iter = 0
        self._is_selfplay = is_selfplay
        if not is_selfplay:
            # Imported lazily: pyzmq is only needed outside self-play mode.
            import zmq

            self._ctx = zmq.Context()
            zsock = self._ctx.socket(zmq.DEALER)
            zsock.setsockopt(zmq.LINGER, 0)
            zsock.setsockopt(zmq.ROUTING_ID, b"0")
            zsock.connect(f"tcp://127.0.0.1:{port}")
            # NOTE(review): initial message to the peer; its meaning is
            # defined by the receiving side -- confirm against that protocol.
            zsock.send_multipart([b"0", b"0"])
            self._zsock = zsock

    @staticmethod
    def _count_nodes(sgf_prefix, iteration):
        """Return 1 + the highest file index among ``iter-<iteration>-<k>.sgf``.

        Raises:
            ValueError: If ``sgf_prefix`` contains no matching file.
        """
        # The dot is escaped so that only a literal ".sgf" suffix matches.
        pattern = re.compile(rf"iter-{iteration}-(\d+)\.sgf")
        indices = [
            int(m.group(1))
            for name in os.listdir(sgf_prefix)
            if (m := pattern.search(name))
        ]
        if not indices:
            raise ValueError(
                f"no SGF files matching {pattern.pattern!r} found in {sgf_prefix!r}"
            )
        return 1 + max(indices)

    def load(self, sgf_prefix, epoch_ckpt):
        """Advance the iteration counter by ``epoch_ckpt`` iterations.

        For the first iteration and for every iteration ``i`` with
        ``i + max_iters >= epoch_ckpt`` the SGF files found in ``sgf_prefix``
        are added to the replay buffer, and ``self._total`` samples are drawn
        and immediately freed.  Other iterations only advance the counter
        (and notify the socket manager).

        Args:
            sgf_prefix: Directory that is listed for ``iter-<n>-<k>.sgf``
                files and passed to ``ReplayBuffer.add_iter``.
            epoch_ckpt: Number of iterations to step through.

        Raises:
            ValueError: If an iteration that must be loaded has no matching
                SGF file in ``sgf_prefix``.
        """
        rb, sock = self._rb, self._sock
        for i in range(epoch_ckpt):
            if self._iter > 0:
                sock.notify()
            if i == 0 or i + rb.max_iters >= epoch_ckpt:
                print(f"[{i:3d}] Load selfplay")
                nodes = self._count_nodes(sgf_prefix, self._iter)
                rb.add_iter(sgf_prefix, self._iter, nodes)
                # Draw and discard one iteration's worth of samples.
                for _ in range(self._total):
                    rb.sample().free()
            self._iter += 1

    def __del__(self):
        """Shut down every resource the constructor managed to start.

        Attributes are looked up defensively because ``__del__`` also runs
        for instances whose ``__init__`` raised part-way through.  Each
        resource is removed from the instance before being released, so
        calling this more than once is harmless.
        """
        state = self.__dict__
        shutdown_plan = (
            ("_rb", "terminate"),
            ("_sock", "terminate"),
            ("_zsock", "close"),
            ("_ctx", "term"),
        )
        for attr, method in shutdown_plan:
            resource = state.pop(attr, None)
            if resource is not None:
                getattr(resource, method)()

    def __iter__(self):
        """Yield ``self._total`` samples for the current iteration.

        Before sampling, the socket manager is notified (except on the very
        first iteration).  In self-play mode the call then blocks on
        ``SocketManager.wait`` and exits the process with status 0 if it
        returns a truthy value; otherwise a two-part message is sent over
        the ZeroMQ socket.  The iteration counter is incremented once all
        samples have been consumed.

        Each sample is freed only when the consumer requests the next one;
        if iteration is abandoned early, the last yielded sample is not
        freed here and the iteration counter is not advanced.

        Raises:
            SystemExit: In self-play mode when ``SocketManager.wait``
                returns a truthy value.
        """
        rb, sock = self._rb, self._sock
        if self._iter > 0:
            sock.notify()
        if self._is_selfplay:
            with timer(f"[{self._iter:3d}] Time for selfplay:"):
                if sock.wait():
                    # Raised directly instead of calling the site-provided
                    # ``exit`` helper, which is not always available.
                    raise SystemExit(0)  # SIGINT
        else:
            # NOTE(review): presumably reports the iteration as already
            # complete to the peer -- confirm against the receiving side.
            finished = rb.moves_per_iter + 1
            self._zsock.send_multipart(
                [str(self._iter).encode("utf-8"), str(finished).encode("utf-8")]
            )
        with timer(f"[{self._iter:3d}] Time for training:"):
            rb.add_iter(self._sgf_prefix, self._iter, sock.nodes)
            for _ in range(self._total):
                sample = rb.sample()
                yield sample
                sample.free()
        self._iter += 1