| | |
| | from typing import Callable, List, Optional |
| |
|
| | import numpy as np |
| |
|
| |
|
def ordered_halving(val):
    """Map an integer to a fraction by mirroring its 64-bit binary form.

    ``val`` is rendered as a zero-padded 64-character binary string, the
    string is reversed, and the reversed digits are read back as an integer
    which is then divided by ``2**64``.  For ``0 <= val < 2**64`` the result
    lies in ``[0, 1)``; consecutive inputs give 0, 1/2, 1/4, 3/4, 1/8, ...

    A negative ``val`` raises ``ValueError`` because the reversed string
    ends with the minus sign and is not a valid base-2 literal.
    """
    bits = format(val, "064b")
    mirrored = int("".join(reversed(bits)), 2)
    return mirrored / (1 << 64)
| |
|
| |
|
def uniform(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield lists of frame indices (context windows) for one step.

    If ``num_frames`` does not exceed ``context_size`` a single window
    holding every frame index is yielded.  Otherwise windows are produced at
    several dilation levels (index spacing 1, 2, 4, ...).  Each window holds
    ``context_size`` indices taken modulo ``num_frames``, and the start
    positions are shifted by an offset derived from
    ``ordered_halving(step)``.

    Args:
        step: Index of the current step; drives the start offset.
        num_steps: Accepted for signature compatibility; not used here.
        num_frames: Total number of frames to cover.
        context_size: Number of indices per window.  Must be an int in
            practice: the ``None`` default fails at the first comparison.
        context_stride: Upper bound on the number of dilation levels.
        context_overlap: Indices shared between consecutive windows at
            spacing 1 (the hop is ``context_size * spacing - context_overlap``).
        closed_loop: When false, the range of start positions is shortened
            by ``context_overlap``.

    Yields:
        list[int]: Frame indices of one context window.
    """
    # Short clip: one window covering everything is enough.
    if num_frames <= context_size:
        yield [*range(num_frames)]
        return

    # Never use more dilation levels than the clip length can support.
    level_cap = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, level_cap)

    for level in range(context_stride):
        spacing = 1 << level
        fraction = ordered_halving(step)
        shift = int(round(num_frames * fraction))

        span = context_size * spacing
        first_start = int(fraction * spacing) + shift
        end_start = num_frames + shift
        if not closed_loop:
            end_start -= context_overlap
        hop = span - context_overlap

        for start in range(first_start, end_start, hop):
            # Indices wrap around the end of the clip.
            yield [idx % num_frames for idx in range(start, start + span, spacing)]
| |
|
| |
|
def get_context_scheduler(name: str) -> Callable:
    """Return the context scheduler registered under ``name``.

    Only ``"uniform"`` is recognised; any other name raises ``ValueError``.
    """
    if name != "uniform":
        raise ValueError(f"Unknown context_overlap policy {name}")
    return uniform
| |
|
| |
|
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` produces over all timesteps.

    The scheduler is invoked once per position in ``timesteps`` (it receives
    the position ``0 .. len(timesteps) - 1``, not the timestep value) and the
    number of windows it yields is summed.

    Args:
        scheduler: Callable taking, positionally, ``(step, num_steps,
            num_frames, context_size, context_stride, context_overlap,
            closed_loop)`` and returning an iterable of windows.
        timesteps: Sequence whose length sets how many steps are counted.
        num_steps: Forwarded to the scheduler unchanged.
        num_frames: Forwarded to the scheduler unchanged.
        context_size: Forwarded to the scheduler unchanged.
        context_stride: Forwarded to the scheduler unchanged.
        context_overlap: Forwarded to the scheduler unchanged.
        closed_loop: Forwarded to the scheduler unchanged.

    Returns:
        int: Total number of windows across every step.
    """
    return sum(
        # Count lazily; the windows themselves are not needed.
        sum(
            1
            for _ in scheduler(
                i,
                num_steps,
                num_frames,
                context_size,
                context_stride,
                context_overlap,
                # Forward closed_loop so the count matches the windows the
                # caller will actually generate with the same setting.
                closed_loop,
            )
        )
        for i in range(len(timesteps))
    )
| |
|