File size: 5,192 Bytes
08c5e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Audio reference conditioning item for IC-LoRA voice cloning."""

import torch

from ltx_core.components.patchifiers import AudioPatchifier
from ltx_core.conditioning.item import ConditioningItem
from ltx_core.tools import AudioLatentTools
from ltx_core.types import AudioLatentShape, LatentState


class AudioConditionByReferenceLatent(ConditioningItem):
    """Conditions audio generation on a reference audio latent for voice cloning.

    Mirrors VideoConditionByReferenceLatent but for audio:

    - Patchifies the reference latent [B, C, T, F] into [B, ref_T, D] tokens and
      APPENDS them to the END of the latent sequence (IC-LoRA pattern).
    - Computes temporal positions for the reference via the audio patchifier's
      ``get_patch_grid_bounds`` and shifts them by ``position_offset``. The
      reference therefore lives in (nearly) the same coordinate space as the
      target, so RoPE barely attenuates target->ref attention. It does not sit
      exactly on top of the target at t=0.
    - Sets the reference denoise mask to ``1.0 - strength``
      (strength=1.0 -> mask=0 -> reference stays frozen/clean).
    - Builds an ASYMMETRIC attention mask:
      target->ref = 1 (attend) and ref->target = 0 (read-only).
      The mask is the structural signal that reference tokens are
      conditioning, not reconstruction targets.

    Args:
        latent: Reference audio latent [B, C, T, F] (pre-VAE-encoded).
        strength: Conditioning strength. 1.0 = full (ref kept clean),
            0.0 = none (ref fully denoised). Default 1.0.
        position_offset: Constant added to every reference position.
            It is expressed in the units returned by ``get_patch_grid_bounds``
            (the original notes treat these as seconds; confirm against
            AudioPatchifier). Default 0.5. Keep it small relative to the RoPE
            maximum position so target->ref attention is not noticeably decayed.
    """

    def __init__(
        self,
        latent: torch.Tensor,
        strength: float = 1.0,
        *,
        position_offset: float = 0.5,
    ):
        self.latent = latent
        self.strength = strength
        self.position_offset = position_offset

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: AudioLatentTools,
    ) -> LatentState:
        """Append reference audio tokens with positions and attention mask.

        Args:
            latent_state: Current state. Its ``latent`` holds the N target tokens
                along dim 1, and its ``positions`` are extended along dim 2.
            latent_tools: Audio latent tools whose ``patchifier`` is used to
                patchify the reference and compute its positions.

        Returns:
            A new LatentState with N + M tokens, where the last M tokens are the
            reference. It carries a float32 [B, N+M, N+M] attention mask.

        Raises:
            ValueError: If the reference latent is not 4-D [B, C, T, F].
        """
        if self.latent.ndim != 4:
            raise ValueError(
                "Reference audio latent must be 4-D [B, C, T, F], "
                f"got shape {tuple(self.latent.shape)}"
            )

        target_latent = latent_state.latent
        device = target_latent.device

        # Align the reference tokens with the target sequence. Otherwise torch.cat
        # raises on a device mismatch, or silently promotes the whole latent and
        # clean_latent sequence (e.g. bf16 target + fp32 reference -> fp32).
        tokens = latent_tools.patchifier.patchify(self.latent).to(
            device=device, dtype=target_latent.dtype
        )

        # Reference positions in the target's coordinate space, shifted by a small
        # offset. Per the original author's notes, an exact t=0 overlap made
        # reference content bleed into the start of the target. A small shift
        # keeps RoPE decay between target and reference minimal.
        batch, channels, frames, mel_bins = self.latent.shape
        ref_shape = AudioLatentShape(
            batch=batch,
            channels=channels,
            frames=frames,
            mel_bins=mel_bins,
        )
        positions = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=ref_shape,
            device=device,
        )
        positions = positions + self.position_offset

        # Denoise mask: 0 = frozen (strength=1.0), 1 = fully denoised (strength=0.0).
        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=device,
            dtype=torch.float32,
        )

        # ASYMMETRIC attention mask. Rows are attending (query) tokens and columns
        # are attended (key) tokens, the convention this block assumes:
        #
        #                       target (N)         ref (M)
        #   target (N)    existing mask or 1.0       1.0     target reads everything
        #   ref    (M)            0.0                1.0     ref reads only itself
        #
        # This makes reference tokens "read-only conditioning":
        # - Target tokens freely attend to the reference (voice-cloning signal).
        # - Reference tokens never see the noisy target, so they stay clean/stable.
        batch_size, num_ref = tokens.shape[:2]
        num_target = target_latent.shape[1]
        total = num_target + num_ref

        # A [0, 1] float32 mask. Per the original notes, the model's
        # _prepare_self_attention_mask turns it into a log-space bias in the
        # compute dtype before it reaches attention.
        mask = torch.zeros(
            (batch_size, total, total),
            device=device,
            dtype=torch.float32,
        )

        # Target -> target: keep any existing mask, otherwise full attention.
        if latent_state.attention_mask is not None:
            mask[:, :num_target, :num_target] = latent_state.attention_mask
        else:
            mask[:, :num_target, :num_target] = 1.0

        # Target -> ref: full attention (target can read the reference voice).
        mask[:, :num_target, num_target:] = 1.0

        # Ref -> target: blocked. mask[:, num_target:, :num_target] stays 0.0.

        # Ref -> ref: full self-attention within the reference.
        mask[:, num_target:, num_target:] = 1.0

        return LatentState(
            latent=torch.cat([target_latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
            attention_mask=mask,
        )