| | from dataclasses import dataclass, field |
| | from typing import List, Literal, Union |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from fish_speech.tokenizer import ( |
| | IM_END_TOKEN, |
| | MODALITY_TOKENS, |
| | FishTokenizer, |
| | ) |
| |
|
| |
|
| | def restore_ndarray(obj, to_tensor: bool = False): |
| | if isinstance(obj, dict) and "__ndarray__" in obj: |
| | obj = np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"]) |
| |
|
| | if to_tensor and isinstance(obj, np.ndarray): |
| | obj = torch.from_numpy(obj.copy()) |
| |
|
| | return obj |
| |
|
| |
|
| | @dataclass |
| | class BasePart: |
| | type: Literal["text", "vq", "audio"] | None = None |
| | cal_loss: bool = False |
| |
|
| |
|
| | @dataclass(kw_only=True) |
| | class VQPart(BasePart): |
| | type = "vq" |
| | codes: torch.Tensor |
| |
|
| | def __post_init__(self: "VQPart"): |
| | self.type = "vq" |
| | self.codes = restore_ndarray(self.codes, to_tensor=True) |
| |
|
| |
|
| | @dataclass(kw_only=True) |
| | class TextPart(BasePart): |
| | type = "text" |
| | text: str | None = None |
| | tokens: list[int] | None = None |
| |
|
| | def __post_init__(self: "TextPart"): |
| | self.type = "text" |
| | if self.text is None and self.tokens is None: |
| | raise ValueError("Either text or tokens must be provided") |
| |
|
| |
|
| | @dataclass(kw_only=True) |
| | class AudioPart(BasePart): |
| | type = "audio" |
| | features: torch.Tensor |
| |
|
| | def __post_init__(self: "AudioPart"): |
| | self.type = "audio" |
| | self.features = restore_ndarray(self.features, to_tensor=True) |
| |
|
| |
|
| | @dataclass(kw_only=True) |
| | class EncodedMessage: |
| | tokens: torch.Tensor |
| | labels: torch.Tensor |
| | vq_mask_tokens: torch.Tensor | None = None |
| | vq_mask_labels: torch.Tensor | None = None |
| | vq_parts: list[torch.Tensor] |
| | vq_require_losses: torch.Tensor | None = None |
| | audio_parts: list[torch.Tensor] |
| | audio_masks: torch.Tensor | None = None |
| | metadata: dict | None = None |
| |
|
| |
|
| | @dataclass |
| | class ContentSequence: |
| | """ |
| | Flexible sequence of content parts that supports interleaved multimodal format. |
| | Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|> |
| | """ |
| |
|
| | parts: list[BasePart] = field(default_factory=list) |
| | modality: Literal["text", "voice", "interleave"] | None = None |
| | metadata: dict | None = None |
| |
|
| | def __init__( |
| | self: "ContentSequence", |
| | parts: list[BasePart | dict] | None = None, |
| | modality: Literal["text", "voice", "interleave"] | None = None, |
| | metadata: dict | None = None, |
| | ): |
| | self.modality = modality |
| | self.metadata = metadata or {} |
| |
|
| | fixed_parts = [] |
| | for part in parts or []: |
| | if isinstance(part, dict): |
| | if part["type"] == "vq": |
| | part = VQPart(**part) |
| | elif part["type"] == "audio": |
| | part = AudioPart(**part) |
| | elif part["type"] == "text": |
| | part = TextPart(**part) |
| | else: |
| | raise ValueError(f"Unsupported part type: {part['type']}") |
| | fixed_parts.append(part) |
| |
|
| | self.parts = fixed_parts |
| |
|
| | |
| | if self.modality and not ( |
| | len(self.parts) > 0 |
| | and isinstance(self.parts[0], dict) is False |
| | and isinstance(self.parts[0], TextPart) |
| | and self.parts[0].text is not None |
| | and self.parts[0].text.startswith(MODALITY_TOKENS[self.modality]) |
| | ): |
| | modality_token = MODALITY_TOKENS[self.modality] |
| | self.parts.insert(0, TextPart(text=modality_token)) |
| |
|
| | def append( |
| | self: "ContentSequence", |
| | part_or_parts: Union[BasePart, List[BasePart]], |
| | add_end: bool = False, |
| | speaker: Union[str, int] | None = None, |
| | ): |
| | """ |
| | Append a part or list of parts to the sequence. |
| | |
| | Args: |
| | part_or_parts: A single part or list of parts to add |
| | add_end: Whether to add the IM_END_TOKEN after these parts |
| | speaker: Optional speaker identifier (name or ID) to add before the parts |
| | """ |
| | |
| | parts_to_add = ( |
| | [part_or_parts] if not isinstance(part_or_parts, list) else part_or_parts |
| | ) |
| |
|
| | |
| | if speaker is not None: |
| | speaker_token = f"<|speaker:{speaker}|>" |
| | self.parts.append(TextPart(text=speaker_token)) |
| |
|
| | |
| | self.parts.extend(parts_to_add) |
| |
|
| | |
| | if add_end: |
| | self.parts.append( |
| | TextPart(text=IM_END_TOKEN, cal_loss=self.parts[-1].cal_loss) |
| | ) |
| |
|
| | def encode( |
| | self: "ContentSequence", |
| | tokenizer: FishTokenizer, |
| | add_shift: bool = True, |
| | ignore_loss_tokens: list[str] = [], |
| | ) -> EncodedMessage: |
| | """ |
| | Encode the sequence parts into tokens for the model. |
| | |
| | Args: |
| | tokenizer: The tokenizer to use |
| | add_shift: Whether to shift tokens for next-token prediction |
| | ignore_loss_tokens: List of token strings to ignore when calculating loss |
| | |
| | Returns: |
| | EncodedMessage with tensors ready for the model |
| | """ |
| | all_tokens = [] |
| | all_labels = [] |
| |
|
| | |
| | vq_parts = [] |
| | vq_masks = [] |
| | vq_require_losses = [] |
| |
|
| | audio_parts = [] |
| | audio_masks = [] |
| |
|
| | ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens] |
| |
|
| | for part in self.parts: |
| | if isinstance(part, TextPart): |
| | if part.tokens is None: |
| | assert part.text is not None |
| | tokens = tokenizer.encode(part.text) |
| | else: |
| | tokens = part.tokens |
| |
|
| | tokens = torch.tensor(tokens, dtype=torch.int) |
| | elif isinstance(part, VQPart): |
| | curr_codes = part.codes.clone().to(torch.int) |
| | tokens = torch.tensor( |
| | [ |
| | tokenizer.semantic_id_to_token_id[int(i.item())] |
| | for i in curr_codes[0].int() |
| | ], |
| | dtype=torch.int, |
| | ) |
| | vq_parts.append(curr_codes) |
| | vq_require_losses.append(part.cal_loss) |
| | else: |
| | raise ValueError(f"Unsupported part type: {type(part)}") |
| |
|
| | all_tokens.append(tokens) |
| |
|
| | |
| | if isinstance(part, VQPart): |
| | vq_masks.append(torch.ones_like(tokens, dtype=torch.bool)) |
| | audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) |
| | elif isinstance(part, AudioPart): |
| | vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) |
| | audio_mask = torch.ones_like(tokens, dtype=torch.bool) |
| | audio_mask[0] = False |
| | audio_mask[-1] = False |
| | audio_masks.append(audio_mask) |
| | else: |
| | vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) |
| | audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) |
| |
|
| | |
| | if part.cal_loss and not isinstance(part, AudioPart): |
| | all_labels.append(tokens.clone()) |
| | else: |
| | all_labels.append(torch.full_like(tokens, -100)) |
| |
|
| | |
| | tokens = torch.cat(all_tokens, dim=0) |
| | labels = torch.cat(all_labels, dim=0) |
| | vq_masks = torch.cat(vq_masks, dim=0) |
| | audio_masks = torch.cat(audio_masks, dim=0) |
| | vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool) |
| |
|
| | |
| | vq_mask_tokens = vq_masks |
| | vq_mask_labels = vq_masks |
| |
|
| | if add_shift: |
| | tokens = tokens[:-1] |
| | labels = labels[1:] |
| | vq_masks = vq_masks[:-1] |
| | vq_mask_tokens = vq_mask_tokens[:-1] |
| | vq_mask_labels = vq_mask_labels[1:] |
| | audio_masks = audio_masks[:-1] |
| |
|
| | |
| | for i in ignore_loss_token_ids: |
| | assert i != -100 and i is not None |
| | labels[labels == i] = -100 |
| |
|
| | assert tokens.dtype in [ |
| | torch.int, |
| | torch.long, |
| | ], f"Invalid dtype: {tokens.dtype}" |
| |
|
| | return EncodedMessage( |
| | tokens=tokens, |
| | labels=labels, |
| | vq_parts=vq_parts, |
| | vq_mask_tokens=vq_mask_tokens, |
| | vq_mask_labels=vq_mask_labels, |
| | vq_require_losses=vq_require_losses, |
| | audio_parts=audio_parts, |
| | audio_masks=audio_masks, |
| | metadata=self.metadata, |
| | ) |
| |
|
| | def encode_for_inference( |
| | self: "ContentSequence", |
| | tokenizer: FishTokenizer, |
| | num_codebooks: int, |
| | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| | encoded = self.encode(tokenizer, add_shift=False) |
| | tokens = encoded.tokens |
| | values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int) |
| | values[0] = tokens |
| |
|
| | if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and ( |
| | encoded.audio_parts is None or len(encoded.audio_parts) == 0 |
| | ): |
| | return values, None, None |
| |
|
| | audio_parts = audio_masks = None |
| | if encoded.vq_parts is not None and len(encoded.vq_parts) > 0: |
| | vq_parts = encoded.vq_parts |
| | vq_parts = torch.cat(vq_parts, dim=1) |
| | values[0, encoded.vq_mask_tokens] = ( |
| | vq_parts[0] + tokenizer.semantic_begin_id |
| | ) |
| | values[1:, encoded.vq_mask_tokens] = vq_parts |
| |
|
| | if encoded.audio_parts is not None and len(encoded.audio_parts) > 0: |
| | audio_parts = torch.cat(encoded.audio_parts, dim=0) |
| | audio_masks = encoded.audio_masks[None, :] |
| |
|
| | return values, audio_masks, audio_parts |
| |
|
| | def visualize( |
| | self: "ContentSequence", |
| | tokenizer: FishTokenizer, |
| | ignore_loss_tokens: list[str] = [], |
| | merge_semantic_tokens: bool = False, |
| | ): |
| | """ |
| | Visualize the encoded sequence with color-coded tokens. |
| | Blue/cyan tokens contribute to loss, green tokens do not. |
| | """ |
| | encoded = self.encode( |
| | tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens |
| | ) |
| |
|
| | |
| | colors = { |
| | "blue": "\033[94m", |
| | "cyan": "\033[96m", |
| | "green": "\033[92m", |
| | "dark_green": "\033[32m", |
| | } |
| | blue_idx = 0 |
| | green_idx = 0 |
| |
|
| | def print_in_blue(x): |
| | nonlocal blue_idx |
| | color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"] |
| | print(f"{color}{x}\033[0m", end="") |
| | blue_idx += 1 |
| |
|
| | def print_in_green(x): |
| | nonlocal green_idx |
| | color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"] |
| | print(f"{color}{x}\033[0m", end="") |
| | green_idx += 1 |
| |
|
| | def print_semantic_token(x, count): |
| | val = f"[<|semantic|>x{count}]" |
| | if x == -100: |
| | print_in_green(val) |
| | else: |
| | print_in_blue(val) |
| |
|
| | count_semantic_tokens = 0 |
| | semantic_label = None |
| |
|
| | for tok, lab in zip(encoded.tokens, encoded.labels): |
| | token_id = int(tok.item()) |
| |
|
| | if merge_semantic_tokens: |
| | if ( |
| | tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id |
| | and (semantic_label is None or semantic_label == lab) |
| | ): |
| | count_semantic_tokens += 1 |
| | semantic_label = lab |
| | continue |
| | elif count_semantic_tokens > 0: |
| | print_semantic_token(semantic_label, count_semantic_tokens) |
| | count_semantic_tokens = 0 |
| | semantic_label = None |
| |
|
| | val = tokenizer.decode([int(tok.item())]) |
| |
|
| | if lab == -100: |
| | print_in_green(val) |
| | else: |
| | print_in_blue(val) |
| |
|
| | if merge_semantic_tokens and count_semantic_tokens > 0: |
| | print_semantic_token(semantic_label, count_semantic_tokens) |
| |
|
| | print() |
| |
|