import random
from functools import lru_cache


@lru_cache(maxsize=256)
def make_arithmetic_codes(group_size: int, seed: int) -> tuple[float, ...]:
    """Return ``group_size`` evenly spaced codes in ``[0, 1)`` with a seeded offset.

    The codes are the cell midpoints ``(i + 0.5) / group_size`` for
    ``i in range(group_size)``. All of them are shifted by one offset drawn from
    ``random.Random(seed)`` and then wrapped back into ``[0, 1)``. Adjacent codes
    are therefore ``1 / group_size`` apart (modulo 1). The same
    ``(group_size, seed)`` pair always yields the same tuple.

    Results are memoised by ``lru_cache`` (at most 256 entries). The value is a
    tuple, so the shared cached object cannot be mutated by callers.

    Args:
        group_size: Number of codes to produce; must be at least 1.
        seed: Seed for a private ``random.Random`` instance that picks the
            shift; the module-level ``random`` state is neither read nor
            advanced.

    Returns:
        A tuple of ``group_size`` floats in ``[0, 1)``, in index order. The
        tuple is not necessarily sorted, because the shift can wrap values
        past 1.

    Raises:
        ValueError: If ``group_size`` is less than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be positive, got {group_size}")
    # A dedicated Random instance keeps the shift reproducible per seed
    # without disturbing (or depending on) the global RNG state.
    shift = random.Random(seed).random()
    # Positive operand, so `% 1.0` is exact and lands in [0, 1).
    return tuple(
        ((i + 0.5) / group_size + shift) % 1.0 for i in range(group_size)
    )


def get_arithmetic_code(group_size: int, seed: int, rollout_n: int) -> float:
    """Return the code assigned to rollout ``rollout_n`` within a group.

    The codes come from ``make_arithmetic_codes(group_size, seed)``. The index
    is ``rollout_n`` modulo ``group_size``, so indices past the end wrap
    around. Negative indices also wrap, via Python's floored ``%``.

    Args:
        group_size: Number of codes in the group; must be at least 1.
        seed: Seed forwarded to ``make_arithmetic_codes``.
        rollout_n: Index of the rollout; reduced modulo ``group_size``.

    Returns:
        The selected code, a float in ``[0, 1)``.

    Raises:
        ValueError: Propagated from ``make_arithmetic_codes`` when
            ``group_size`` is less than 1.
    """
    # Call positionally. lru_cache keys keyword and positional calls
    # separately, so a keyword call here would not share cache entries
    # (or the 256-slot budget) with positional callers.
    codes = make_arithmetic_codes(group_size, seed)
    return codes[rollout_n % len(codes)]