File size: 561 Bytes
1faccd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import random
from functools import lru_cache


@lru_cache(maxsize=256)
def make_arithmetic_codes(group_size: int, seed: int) -> tuple[float, ...]:
    if group_size < 1:
        raise ValueError(f"group_size must be positive, got {group_size}")

    shift = random.Random(seed).random()
    return tuple(((i + 0.5) / group_size + shift) % 1.0 for i in range(group_size))


def get_arithmetic_code(group_size: int, seed: int, rollout_n: int) -> float:
    codes = make_arithmetic_codes(group_size=group_size, seed=seed)
    return codes[rollout_n % len(codes)]