Spaces:
Running on Zero
Running on Zero
File size: 3,159 Bytes
e340a84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | import random
import torch
from typing import Optional, Tuple
class KeyframeSelector:
    """Choose which frames of a (batched) sequence are keyframes.

    Frames are scanned left to right per batch element. A frame ``s > 0``
    becomes a new keyframe when one of these rules fires (checked in order):

    1. ``mode == "random"`` and a sampled target position exists: the frame
       is a keyframe iff ``s`` has reached that target. The target is
       re-sampled from ``[min_interval, max_interval]`` after every keyframe.
       While a target exists, rules 2 and 3 are NOT evaluated.
    2. At least ``max_interval`` frames have passed since the last keyframe.
    3. At least ``min_interval`` frames have passed, ``poses`` and
       ``motion_threshold`` are both given, and the Euclidean distance between
       ``poses[b, s, :3]`` and ``poses[b, last_keyframe, :3]`` exceeds
       ``motion_threshold``.

    Any ``mode`` other than ``"random"`` behaves as fixed-interval selection.
    """

    def __init__(
        self,
        min_interval: int = 8,
        max_interval: int = 8,
        force_first: bool = True,
        motion_threshold: Optional[float] = None,
        mode: str = "fixed",
    ):
        """Store the selection parameters.

        Args:
            min_interval: Lower bound for random intervals, and the minimum
                gap before the motion rule may fire.
            max_interval: Upper bound for random intervals, and the gap after
                which a keyframe is forced (when no random target is active).
            force_first: Always mark frame 0 as a keyframe.
            motion_threshold: Distance threshold for the motion rule;
                ``None`` disables that rule.
            mode: ``"random"`` enables randomly sampled intervals; any other
                value gives fixed-interval behaviour.
        """
        self.min_interval = int(min_interval)
        self.max_interval = int(max_interval)
        self.force_first = bool(force_first)
        self.motion_threshold = motion_threshold
        self.mode = mode

    def _sample_interval(self) -> int:
        """Draw a random interval in ``[min_interval, max_interval]`` (inclusive).

        Raises:
            ValueError: from ``random.randint`` if ``min_interval > max_interval``.
        """
        return random.randint(self.min_interval, self.max_interval)

    def select_keyframes(
        self,
        sequence_length: int,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
        poses: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute keyframe flags and reference-keyframe indices.

        Args:
            sequence_length: Number of frames per sequence.
            batch_size: Number of independent sequences.
            device: Device for the output tensors; defaults to CPU.
            poses: Optional tensor indexed as ``poses[b, s, :3]`` for the
                motion rule. NOTE(review): presumably (batch, seq, pose_dim)
                with a translation in the first three components — confirm
                against callers.

        Returns:
            ``(is_keyframe, keyframe_indices)``, both of shape
            ``(batch_size, sequence_length)``:

            * ``is_keyframe`` (bool): True where a frame is a keyframe.
            * ``keyframe_indices`` (long): for frame 0 this is 0; for every
              later frame it is the index of the most recent keyframe strictly
              *before* it (so a keyframe at ``s > 0`` points to the previous
              keyframe, not to itself).

            An empty sequence (``sequence_length == 0``) yields two
            ``(batch_size, 0)`` tensors.
        """
        if device is None:
            device = torch.device("cpu")
        is_keyframe = torch.zeros(
            batch_size, sequence_length, dtype=torch.bool, device=device
        )
        keyframe_indices = torch.zeros(
            batch_size, sequence_length, dtype=torch.long, device=device
        )
        # Nothing to select; indexing frame 0 below would raise IndexError.
        if sequence_length == 0:
            return is_keyframe, keyframe_indices

        is_random = self.mode == "random"
        use_motion = poses is not None and self.motion_threshold is not None

        for b in range(batch_size):
            # Frame 0 is the reference even when it is not flagged as a
            # keyframe (force_first=False and sequence_length > 1).
            last_keyframe_idx = 0
            next_keyframe_target: Optional[int] = None
            # A single-frame sequence always gets its only frame as keyframe.
            if self.force_first or sequence_length == 1:
                is_keyframe[b, 0] = True
                keyframe_indices[b, 0] = 0
                # NOTE(review): the first random target is only sampled here,
                # so with force_first=False in random mode the first gap is
                # governed by max_interval / motion rather than a random draw.
                if is_random:
                    next_keyframe_target = self._sample_interval()

            for s in range(1, sequence_length):
                # Recorded before the decision below, so a new keyframe at s
                # still points to the previous keyframe.
                keyframe_indices[b, s] = last_keyframe_idx
                frames_since_last = s - last_keyframe_idx

                if is_random and next_keyframe_target is not None:
                    new_keyframe = s >= next_keyframe_target
                elif frames_since_last >= self.max_interval:
                    new_keyframe = True
                elif frames_since_last >= self.min_interval and use_motion:
                    motion = torch.norm(
                        poses[b, s, :3] - poses[b, last_keyframe_idx, :3]
                    ).item()
                    new_keyframe = motion > self.motion_threshold
                else:
                    new_keyframe = False

                if new_keyframe:
                    is_keyframe[b, s] = True
                    last_keyframe_idx = s
                    if is_random:
                        next_keyframe_target = s + self._sample_interval()
        return is_keyframe, keyframe_indices
|