File size: 5,774 Bytes
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
from typing import Any, Dict, List, Literal, NamedTuple
import numpy as np
# Immutable result of a frame-sampler call.
#   indices         -- selected source frame indices, in sampling order
#   additional_info -- free-form per-call metadata; MultiClipsFrameSampler stores
#                      "clip_n_frames" and "clip_n_latent_frames" here
FrameSamplerOutput = NamedTuple(
    "FrameSamplerOutput",
    [
        ("indices", List[int]),
        ("additional_info", Dict[str, Any]),
    ],
)
class MultiClipsFrameSampler:
    """
    Deterministic sampler used by Lance inference for image/video inputs.

    The inference dataset always builds a single clip covering the full video.
    This sampler keeps the public behavior that matters for inference: sample
    at a target FPS, optionally clamp to ``max_duration``, and return a frame
    count compatible with the VAE temporal downsample factor.

    Args:
        temporal: Temporal downsample factor. Per-clip frame counts are
            multiples of this, plus one extra frame on the first clip when
            ``length_type == "kn+1"``.
        sample_fps: Target sampling rate, in frames per second.
        truncate: If True, trim the clips toward a whole number of seconds
            (capped at ``max_duration``) before sampling. If False, only the
            duration used to compute the frame count is capped.
        max_duration: Upper bound on the sampled duration, in seconds.
        length_type: ``"kn"`` or ``"kn+1"`` frame-count layout.
        assert_seconds: If True, round the duration to whole seconds and use
            ``seconds * sample_fps`` frames (+1 for ``"kn+1"``). Otherwise the
            count is derived from the fractional duration.
    """

    def __init__(
        self,
        temporal: int = 4,
        sample_fps: int = 12,
        truncate: bool = False,
        max_duration: int = 12,
        length_type: Literal["kn", "kn+1"] = "kn+1",
        assert_seconds: bool = True,
    ):
        self.temporal = temporal
        self.sample_fps = sample_fps
        self.truncate = truncate
        self.max_duration = max_duration
        self.length_type = length_type
        self.assert_seconds = assert_seconds

    def __call__(self, frames_info: Dict[str, Any]) -> "FrameSamplerOutput":
        """
        Select the frame indices to sample for one input.

        Args:
            frames_info: Mapping with ``"clip_indices"`` (a sequence of
                ``(start, end)`` frame-index pairs, ``end`` exclusive) and
                ``"fps"`` (the source frame rate).

        Returns:
            FrameSamplerOutput. ``indices`` holds source frame indices.
            ``additional_info`` holds the per-clip ``"clip_n_frames"`` and
            ``"clip_n_latent_frames"`` lists.
        """
        clip_indices = frames_info["clip_indices"]
        origin_fps = frames_info["fps"]
        if self.truncate:
            clip_indices = self.truncate_to_bucket(clip_indices, origin_fps)
        n_frames = self._target_n_frames(clip_indices, origin_fps)
        clip_n_frames = self.split_n_frames_by_clip(n_frames, clip_indices)
        sample_indices = self.sample_frame_indices(clip_indices, clip_n_frames)
        # Ceil division: e.g. with temporal=4, a 4k+1-frame clip counts k+1
        # latent frames.
        clip_n_latent_frames = [(n + self.temporal - 1) // self.temporal for n in clip_n_frames]
        return FrameSamplerOutput(
            indices=sample_indices,
            additional_info={
                "clip_n_frames": clip_n_frames,
                "clip_n_latent_frames": clip_n_latent_frames,
            },
        )

    def _target_n_frames(self, clip_indices, fps):
        """Return the total number of frames to sample from ``clip_indices`` at ``fps``."""
        duration = sum((end - start) / fps for start, end in clip_indices)
        if self.assert_seconds:
            # Whole-second mode: round the duration to seconds before scaling.
            duration_sec = int(round(duration))
            if not self.truncate:
                duration_sec = min(duration_sec, self.max_duration)
            n_frames = duration_sec * self.sample_fps
            if self.length_type == "kn+1":
                n_frames += 1
            return n_frames
        if not self.truncate:
            duration = min(duration, self.max_duration)
        n_frames = int(round(duration * self.sample_fps))
        if self.length_type == "kn+1":
            # Round down to the largest k * temporal + 1 <= n_frames, never below
            # 1 (k = 0). Without the clamp, a zero count turns into
            # 1 - temporal, a negative sample count that np.linspace rejects.
            n_frames = max(n_frames - 1, 0) // self.temporal * self.temporal + 1
        return n_frames

    def truncate_to_bucket(self, clip_indices, fps):
        """
        Trim clips so their total duration becomes a whole number of seconds.

        The target is ``min(int(total_duration), self.max_duration)``. The
        excess is removed from whichever of the first and last clip is longer.
        The last clip is shortened from its end; the first clip is shortened
        from its start.

        Args:
            clip_indices: Sequence of ``(start, end)`` frame-index pairs.
            fps: Source frame rate, used to convert frame counts to seconds.

        Returns:
            A new list of ``(start, end)`` tuples. Input shorter than one
            second is returned unchanged, because trimming it to zero seconds
            would leave nothing to sample.

        NOTE(review): only one clip is trimmed. If the excess is longer than
        that clip, the clip is reduced to zero length and the total can still
        exceed the target.
        """
        clip_indices = [tuple(index) for index in clip_indices]
        durations = [(end - start) / fps for start, end in clip_indices]
        duration = sum(durations)
        target_duration = min(int(duration), self.max_duration)
        cutoff = duration - target_duration
        if cutoff <= 0 or target_duration <= 0:
            return clip_indices
        if durations[-1] > durations[0]:
            start, end = clip_indices[-1]
            keep = round((durations[-1] - cutoff) * fps)
            # Clamp into [start, end] so the clip can never come out inverted.
            clip_indices[-1] = (start, min(max(start + keep, start), end))
        else:
            start, end = clip_indices[0]
            keep = round((durations[0] - cutoff) * fps)
            clip_indices[0] = (max(min(end - keep, end), start), end)
        return clip_indices

    def split_n_frames_by_clip(self, n_frames, clip_indices):
        """
        Distribute ``n_frames`` across clips in proportion to clip length.

        Allocation is done in latent frames (``n_frames // self.temporal``).
        Each clip gets the floor of its proportional share, and any leftover
        latent frames go one each to the leading clips. With ``"kn+1"`` the
        first clip also receives the extra leading frame. A zero total clip
        length is not handled and would divide by zero.

        Returns:
            List of per-clip frame counts.
        """
        n_latent_frames = n_frames // self.temporal
        clip_lengths = [end - start for start, end in clip_indices]
        total_length = sum(clip_lengths)
        clip_n_latent_frames = [int(length / total_length * n_latent_frames) for length in clip_lengths]
        n_remains = n_latent_frames - sum(clip_n_latent_frames)
        for i in range(n_remains):
            clip_n_latent_frames[i] += 1
        clip_n_frames = [n * self.temporal for n in clip_n_latent_frames]
        if self.length_type == "kn+1":
            clip_n_frames[0] += 1
        return clip_n_frames

    @staticmethod
    def sample_frame_indices(clip_indices, clip_n_frames):
        """
        Pick ``clip_n_frames[i]`` evenly spaced frame indices from clip ``i``.

        Within each clip the picks run from its first frame to its last frame
        (``np.linspace`` cast to int). A count larger than the clip length
        therefore repeats frames.

        Returns:
            Flat list of source frame indices, ordered clip by clip.
        """
        all_sample_indices = []
        offset = 0  # position of the current clip within the concatenated clips
        for (start, end), n_frames in zip(clip_indices, clip_n_frames):
            length = end - start
            # Positions are computed in concatenated-clip coordinates and then
            # shifted back. Keep it that way: computing from 0 can cast
            # differently for some offsets and would change which frames are picked.
            positions = np.linspace(offset, offset + length - 1, n_frames, dtype=int) - offset
            all_sample_indices.extend(np.arange(start, end)[positions].tolist())
            offset += length
        return all_sample_indices
|