Fabrice-TIERCELIN commited on
Commit
bdde32d
·
verified ·
1 Parent(s): 67a6ee7

Upload keyframe_cond.py

Browse files
packages/ltx-core/src/ltx_core/keyframe_cond.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.components.patchifiers import get_pixel_coords
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.tools import VideoLatentTools
6
+ from ltx_core.types import LatentState, VideoLatentShape
7
+
8
+
9
+ class VideoConditionByKeyframeIndex(ConditioningItem):
10
+ """
11
+ Conditions video generation on keyframe latents at a specific frame index.
12
+ Appends keyframe tokens to the latent state with positions offset by frame_idx,
13
+ and sets denoise strength according to the strength parameter.
14
+ """
15
+
16
+ def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float):
17
+ self.keyframes = keyframes
18
+ self.frame_idx = frame_idx
19
+ self.strength = strength
20
+
21
+ def apply_to(
22
+ self,
23
+ latent_state: LatentState,
24
+ latent_tools: VideoLatentTools,
25
+ ) -> LatentState:
26
+ tokens = latent_tools.patchifier.patchify(self.keyframes)
27
+ latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
28
+ output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape),
29
+ device=self.keyframes.device,
30
+ )
31
+ positions = get_pixel_coords(
32
+ latent_coords=latent_coords,
33
+ scale_factors=latent_tools.scale_factors,
34
+ causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False,
35
+ )
36
+
37
+ positions[:, 0, ...] += self.frame_idx
38
+ positions = positions.to(dtype=torch.float32)
39
+ positions[:, 0, ...] /= latent_tools.fps
40
+
41
+ denoise_mask = torch.full(
42
+ size=(*tokens.shape[:2], 1),
43
+ fill_value=1.0 - self.strength,
44
+ device=self.keyframes.device,
45
+ dtype=self.keyframes.dtype,
46
+ )
47
+
48
+ return LatentState(
49
+ latent=torch.cat([latent_state.latent, tokens], dim=1),
50
+ denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
51
+ positions=torch.cat([latent_state.positions, positions], dim=2),
52
+ clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
53
+ )