| |
| from typing import Callable, Optional |
|
|
| import numpy as np |
|
|
| from .motion_utils import NoiseType |
|
|
|
|
class ContextType:
    """Namespace holding the string identifiers of the known context types."""

    UNIFORM_WINDOW = "uniform window"
|
|
|
|
class ContextOptions:
    """Base class for context option objects.

    ``CONTEXT_TYPE`` is ``None`` here; subclasses set it to the identifier of
    the context type they describe.
    """

    CONTEXT_TYPE = None

    def __init__(self):
        """Create an options object; the base class stores no instance state."""
|
|
|
|
class UniformContextOptions(ContextOptions):
    """Option container for the uniform-window context type.

    The five constructor arguments are stored unchanged under attributes of
    the same name. Two further attributes start from fixed defaults and are
    changed through their setters:

    * ``sync_context_to_pe`` -- starts as ``False``.
    * ``noise_type`` -- starts as ``NoiseType.DEFAULT``.
    """

    CONTEXT_TYPE = ContextType.UNIFORM_WINDOW

    def __init__(
        self,
        context_length: int,
        context_stride: int,
        context_overlap: int,
        context_schedule: int,
        closed_loop: bool,
    ):
        # Defaults that callers adjust after construction via the setters.
        self.noise_type = NoiseType.DEFAULT
        self.sync_context_to_pe = False

        # Values supplied by the caller, kept as-is.
        self.context_length = context_length
        self.context_stride = context_stride
        self.context_overlap = context_overlap
        # NOTE(review): annotated as int, but the schedule names defined in
        # this file are strings -- confirm which type callers actually pass.
        self.context_schedule = context_schedule
        self.closed_loop = closed_loop

    def set_sync_context_to_pe(self, sync_context_to_pe: bool):
        """Overwrite the ``sync_context_to_pe`` flag."""
        self.sync_context_to_pe = sync_context_to_pe

    def set_noise_type(self, noise_type: str):
        """Overwrite the stored ``noise_type``."""
        self.noise_type = noise_type
|
|
|
|
class ContextSchedules:
    """Namespace of context schedule names.

    ``CONTEXT_SCHEDULE_LIST`` contains only ``UNIFORM``; the other two names
    are defined but not part of that list.
    """

    UNIFORM = "uniform"
    UNIFORM_CONSTANT = "uniform_constant"
    UNIFORM_V2 = "uniform v2"

    CONTEXT_SCHEDULE_LIST = [UNIFORM]
|
|
|
|
| |
def ordered_halving(val, print_final=False):
    """Return ``val`` with its 64-bit binary representation reversed, divided by 2**64.

    ``val`` is rendered as a zero-padded 64-digit binary string, the digits
    are reversed, the result is parsed back as an integer and scaled by
    ``2 ** 64``. For example 1 -> 0.5, 2 -> 0.25, 3 -> 0.75.

    Args:
        val: Integer to convert.
        print_final: When true, print the computed value before returning it.

    Returns:
        The scaled bit-reversed value as a float.
    """
    reversed_bits = "".join(reversed(format(val, "064b")))
    fraction = int(reversed_bits, 2) / 2 ** 64
    if print_final:
        print(f"$$$$ final: {fraction}")
    return fraction
|
|
|
|
| |
def uniform(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
    print_final: bool = False,
):
    """Yield lists of frame indices (context windows) for one step.

    If ``num_frames <= context_size`` a single window ``[0 .. num_frames - 1]``
    is yielded. Otherwise windows are generated for the spacings 1, 2, 4, ...;
    the number of spacings is ``context_stride`` capped at
    ``ceil(log2(num_frames / context_size)) + 1``. Every window contains
    ``context_size`` indices separated by the current spacing and reduced
    modulo ``num_frames``, so a window may wrap past the last frame.

    Args:
        step: Index fed to ``ordered_halving`` to derive the start offset.
        num_steps: Accepted but not used by this function.
        num_frames: Total number of frames.
        context_size: Number of indices per window.
        context_stride: Upper bound on how many power-of-two spacings are used.
        context_overlap: Subtracted from the window span to get the distance
            between consecutive window starts.
        closed_loop: When False, window starts stop ``context_overlap`` frames
            earlier; indices are still taken modulo ``num_frames``.
        print_final: Forwarded to ``ordered_halving`` (once per spacing).

    Yields:
        A list of int frame indices per window.
    """
    # Everything fits into one window: nothing to schedule.
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    max_spacings = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, max_spacings)

    end_adjust = 0 if closed_loop else -context_overlap

    for context_step in 1 << np.arange(context_stride):
        # Same value for the offset and the start; fetched once so the
        # optional debug print still happens once per spacing.
        fraction = ordered_halving(step, print_final)
        pad = int(round(num_frames * fraction))
        span = context_size * context_step

        first_start = int(fraction * context_step) + pad
        stop_before = num_frames + pad + end_adjust
        for start in range(first_start, stop_before, span - context_overlap):
            yield [frame % num_frames for frame in range(start, start + span, context_step)]
|
|
def uniform_v2(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
    print_final: bool = False,
):
    """Yield lists of frame indices (context windows) for one step.

    If ``num_frames <= context_size`` a single window ``[0 .. num_frames - 1]``
    is yielded. Otherwise windows are generated for the spacings 1, 2, 4, ...
    (``context_stride`` of them, capped at
    ``ceil(log2(num_frames / context_size)) + 1``). Window starts always stop
    ``context_overlap`` frames before the padded end, whatever ``closed_loop``
    is. For each start:

    * if the window span exceeds ``num_frames``, the window covers
      ``num_frames`` consecutive positions at the current spacing (modulo
      ``num_frames``) instead of the full span;
    * otherwise, if ``closed_loop`` is False and the window would wrap past
      the last frame, it is emitted as two windows: start to last frame, then
      frame 0 up to the wrapped end;
    * otherwise one window of ``context_size`` indices, modulo ``num_frames``.

    Args:
        step: Index fed to ``ordered_halving`` to derive the start offset.
        num_steps: Accepted but not used by this function.
        num_frames: Total number of frames.
        context_size: Number of indices per regular window.
        context_stride: Upper bound on how many power-of-two spacings are used.
        context_overlap: Subtracted from the window span to get the distance
            between consecutive window starts.
        closed_loop: When False, wrapping windows are split as described above.
        print_final: Forwarded to ``ordered_halving`` (a single call).

    Yields:
        A list of int frame indices per window.
    """
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    max_spacings = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, max_spacings)

    # One call provides both the offset and the per-spacing start fraction.
    fraction = ordered_halving(step, print_final)
    pad = int(round(num_frames * fraction))
    stop_before = num_frames + pad - context_overlap

    for context_step in 1 << np.arange(context_stride):
        span = context_size * context_step
        first_start = int(fraction * context_step) + pad

        for raw_start in range(first_start, stop_before, span - context_overlap):
            if span > num_frames:
                # Span longer than the clip: cap the range at num_frames
                # positions rather than the full span.
                yield [
                    frame % num_frames
                    for frame in range(raw_start, raw_start + num_frames, context_step)
                ]
                continue

            start = raw_start % num_frames
            wrapped_stop = (start + span) % num_frames

            if not closed_loop and start > wrapped_stop:
                # Open loop and the window crosses the end: emit the tail
                # part and the head part as two separate windows.
                yield list(range(start, num_frames, context_step))
                yield list(range(0, wrapped_stop, context_step))
                continue

            yield [frame % num_frames for frame in range(start, start + span, context_step)]
|
|
|
|
def uniform_constant(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
    print_final: bool = False,
):
    """Yield lists of frame indices (context windows) for one step.

    Window starts and spacings are computed exactly as in ``uniform``. The
    difference is in how a window is assembled: when ``closed_loop`` is False,
    a window in which some index (after reduction modulo ``num_frames``) is
    smaller than the index before it is dropped entirely instead of yielded.

    Args:
        step: Index fed to ``ordered_halving`` to derive the start offset.
        num_steps: Accepted but not used by this function.
        num_frames: Total number of frames.
        context_size: Number of indices per window.
        context_stride: Upper bound on how many power-of-two spacings are used.
        context_overlap: Subtracted from the window span to get the distance
            between consecutive window starts.
        closed_loop: When False, window starts stop ``context_overlap`` frames
            earlier and wrapping windows are skipped.
        print_final: Forwarded to ``ordered_halving`` (once per spacing).

    Yields:
        A list of int frame indices per kept window.
    """
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    max_spacings = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, max_spacings)

    end_adjust = 0 if closed_loop else -context_overlap

    for context_step in 1 << np.arange(context_stride):
        fraction = ordered_halving(step, print_final)
        pad = int(round(num_frames * fraction))
        span = context_size * context_step

        first_start = int(fraction * context_step) + pad
        stop_before = num_frames + pad + end_adjust
        for start in range(first_start, stop_before, span - context_overlap):
            window = []
            previous = -1
            for position in range(start, start + span, context_step):
                frame = position % num_frames
                # A drop in value means the window wrapped around the end.
                if not closed_loop and frame < previous:
                    break
                window.append(frame)
                previous = frame
            else:
                # Reached only when the loop above was not broken out of.
                yield window
|
|
|
|
| |
# Lookup table: schedule name -> generator function producing its windows.
UNIFORM_CONTEXT_MAPPING = {
    ContextSchedules.UNIFORM: uniform,
    ContextSchedules.UNIFORM_CONSTANT: uniform_constant,
    ContextSchedules.UNIFORM_V2: uniform_v2,
}
|
|
|
|
| |
def get_context_scheduler(name: str) -> Callable:
    """Return the schedule function registered under ``name``.

    Args:
        name: A key of ``UNIFORM_CONTEXT_MAPPING``.

    Returns:
        The callable stored in ``UNIFORM_CONTEXT_MAPPING`` for ``name``.

    Raises:
        ValueError: If ``name`` is not a key of ``UNIFORM_CONTEXT_MAPPING``.
    """
    context_func = UNIFORM_CONTEXT_MAPPING.get(name)
    if context_func is None:
        # The lookup is by schedule name, so the message says so and lists
        # the accepted names to make a typo easy to spot.
        raise ValueError(
            f"Unknown context schedule {name!r}; "
            f"expected one of {list(UNIFORM_CONTEXT_MAPPING)}"
        )
    return context_func
|
|
|
|
def get_total_steps(
    scheduler,
    timesteps: list[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the windows ``scheduler`` produces across all timesteps.

    ``scheduler`` is called once per position ``0 .. len(timesteps) - 1`` with
    the positional arguments ``(index, num_steps, num_frames, context_size,
    context_stride, context_overlap)``; the number of items it produces is
    summed.

    Note that ``closed_loop`` is accepted but NOT forwarded to ``scheduler``,
    so the scheduler runs with its own default for that argument.

    Args:
        scheduler: Callable returning an iterable of windows.
        timesteps: Only its length is used.
        num_steps, num_frames, context_size, context_stride, context_overlap:
            Passed through to ``scheduler`` unchanged.
        closed_loop: Unused here.

    Returns:
        Total number of windows over all timesteps.
    """
    total = 0
    for index in range(len(timesteps)):
        windows = scheduler(
            index,
            num_steps,
            num_frames,
            context_size,
            context_stride,
            context_overlap,
        )
        total += sum(1 for _ in windows)
    return total
|
|
|
|
def get_total_steps_fixed(
    scheduler,
    timesteps: list[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the windows ``scheduler`` produces across all timesteps.

    ``scheduler`` is called once per position in ``timesteps`` with
    ``(index, num_steps, num_frames, context_size, context_stride,
    context_overlap)`` positionally plus ``closed_loop`` as a keyword, and
    every item it produces is counted.

    Args:
        scheduler: Callable returning an iterable of windows.
        timesteps: Iterated only for its positions; the values are ignored.
        num_steps, num_frames, context_size, context_stride, context_overlap,
        closed_loop: Passed through to ``scheduler`` unchanged.

    Returns:
        Total number of windows over all timesteps.
    """
    return sum(
        1
        for index, _ in enumerate(timesteps)
        for _window in scheduler(
            index,
            num_steps,
            num_frames,
            context_size,
            context_stride,
            context_overlap,
            closed_loop=closed_loop,
        )
    )
|
|