import os
import queue
import threading
import time
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

import click
import hydra
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer

from fish_speech.conversation import (
    CODEBOOK_PAD_TOKEN_ID,
    Conversation,
    Message,
    TextPart,
    VQPart,
)
from fish_speech.models.text2semantic.llama import BaseModelArgs
from fish_speech.text import clean_text, split_text
from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Enable the FX graph cache on torch versions that support it
    # (reduces recompilation time across runs)
    torch._inductor.config.fx_graph_cache = True

from torch.nn.attention import SDPBackend, sdpa_kernel

from fish_speech.models.text2semantic.llama import (
    BaseTransformer,
    DualARTransformer,
    NaiveTransformer,
)

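# Sampling helpers. `multinomial_sample_one_no_sync` draws one sample from
# `probs_sort` via the exponential-race (Gumbel-max style) trick: with
# q ~ Exp(1) elementwise, argmax(probs / q) is distributed according to
# `probs`. Unlike torch.multinomial, this avoids a CUDA device-to-host sync.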
def multinomial_sample_one_no_sync(
    probs_sort,
):
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    # Apply the repetition penalty to the logits of previously generated tokens
    if previous_tokens is not None:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    # Top-p (nucleus) filtering: drop tokens outside the smallest set whose
    # cumulative probability exceeds top_p
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[0] = False  # always keep at least one option
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=0, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

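# The *_agent variants below duplicate the sampling helpers for the batched
# agent path: penalties and top-p filtering are applied along the last dim
# instead of dim 0. They are kept as separate functions, presumably so that
# torch.compile can specialize the batched and single-sample graphs
# independently.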
def multinomial_sample_one_no_sync_agent(
    probs_sort,
):
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    # Batched variant of logits_to_probs: operates along the last dimension
    if previous_tokens is not None:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=-1, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=-1, index=previous_tokens, src=score)

    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[..., 0] = False  # always keep at least one option
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def sample_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs_agent(
        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync_agent(probs)
    return idx_next, probs

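# Single-step decoders. The DualAR decoders run in two stages: the "slow"
# transformer (model.forward_generate) predicts the next main/semantic token
# plus a hidden state, then the "fast" transformer (model.forward_generate_fast)
# consumes that hidden state to emit one token per residual codebook. The fast
# model's KV cache is reset on every step because it only attends over the
# num_codebooks positions of the current frame.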
def decode_one_token_ar_agent(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)
    logits = x.logits
    hidden_states = x.hidden_states

    # Sample the main (semantic/text) token near-greedily
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample_agent(
            logits,
            previous_tokens=None,
            **sampling_kwargs_main,
        )[0]
    ]

    # Reset the fast transformer's KV cache before decoding the codebooks
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample_agent(
            logits,
            previous_tokens=(
                previous_tokens[:, codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    # Pad the codebook rows wherever the main token is not a semantic token
    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :],
        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
        CODEBOOK_PAD_TOKEN_ID,
    )

    return codebooks

def decode_one_token_naive_agent(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)

    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,
            **sampling_kwargs,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample_agent(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[:, i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    codebooks = torch.stack(codebooks, dim=1)
    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :],
        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
        CODEBOOK_PAD_TOKEN_ID,
    )

    return codebooks

def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)

    sampling_kwargs_main = sampling_kwargs.copy()

    codebooks = [
        sample(
            x.logits,
            previous_tokens=(
                previous_tokens[0] if previous_tokens is not None else None
            ),
            **sampling_kwargs_main,
        )[0]
    ]

    hidden_states = x.hidden_states

    # Reset the fast transformer's KV cache before decoding the codebooks
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    # Map the sampled token into codebook space by subtracting the id of the
    # first semantic token; non-semantic tokens are clamped to 0
    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    model.forward_generate_fast(hidden_states, input_pos)
    a = codebooks[0] - model.tokenizer.semantic_begin_id
    a[a < 0] = 0
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)

    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=0)

    return codebooks

def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list = None,  # accepted for interface parity with the AR decoder
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)

    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,
            **sampling_kwargs_main,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)

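# Decoding loops. decode_n_tokens drives token-by-token generation: each
# iteration feeds the previous step's output back in, restricts the repetition
# penalty to a sliding 16-token window, and stops as soon as the model emits
# <|im_end|>.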
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    semantic_ids: list,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # Only the last `win_size` tokens are considered for the repetition penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with (
            torch.backends.cuda.sdp_kernel(
                enable_flash=False, enable_mem_efficient=False, enable_math=True
            )
            if torch.cuda.is_available()
            else nullcontext()
        ):  # force the math SDPA backend on CUDA; eager attention elsewhere
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_ids=semantic_ids,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        # Stop as soon as the model emits <|im_end|>
        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
            break

    return previous_tokens[:, : i + 1]

@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate
    as many tokens as requested.
    """

    T = prompt.size(1)

    semantic_ids = [
        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
    ]

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    # Allocate a tensor of the final shape and fill in the prompt tokens
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use the uncompiled decoder for prefill to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    next_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        decode_one_token=decode_one_token,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    )

    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq

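# Batched, streaming variant of decode_n_tokens used by the agent path: it
# yields every decoded step to the caller (moved to CPU), tracks which batch
# elements have emitted <|im_end|>, and stops early once the fraction of
# finished samples crosses `early_stop_threshold`.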
def decode_n_tokens_agent(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    semantic_ids: list,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive_agent,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    batch_size = cur_token.size(0)
    previous_tokens = torch.zeros(
        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )
    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
    finished = finished | (cur_token[:, 0, -1] == im_end_id)
    start_time = time.time()

    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
        # Only the last `win_size` tokens are considered for the repetition penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :, :win_size]
        else:
            window = previous_tokens[:, :, i - win_size : i]

        with sdpa_kernel(SDPBackend.MATH):
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_ids=semantic_ids,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
        previous_tokens[:, :, i : i + 1] = next_token.view(
            batch_size, model.config.num_codebooks + 1, -1
        )

        yield cur_token.cpu()

        # Stop when every sample is done, or when enough of the batch has
        # finished to cross the early-stop threshold
        finished = finished | (cur_token[:, 0, -1] == im_end_id)
        if finished.all() or (
            0 < early_stop_threshold < 1
            and finished.sum() >= round(batch_size * early_stop_threshold)
        ):
            break

    total_time = time.time() - start_time
    generated_tokens = i + 1
    tokens_per_second = (generated_tokens / total_time) * batch_size
    logger.info(
        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
    )

@torch.no_grad()
@torch.inference_mode()
def generate_agent(
    *,
    model: BaseTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    semantic_ids: list,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive_agent,
    num_samples: int = 1,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate
    as many tokens as requested, yielding tokens as they are decoded.
    """

    T = prompt.size(1)
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    input_pos = torch.arange(0, T, device=device)

    prefill_decode = (
        decode_one_token_naive_agent
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar_agent
    )
    next_token = prefill_decode(
        model,
        prompt,
        input_pos,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    ).view(num_samples, codebook_dim, -1)
    yield next_token.cpu()

    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    yield from decode_n_tokens_agent(
        model,
        next_token,
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        semantic_ids=semantic_ids,
        decode_one_token=decode_one_token,
        early_stop_threshold=early_stop_threshold,
        **sampling_kwargs,
    )

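# encode_tokens packs a text (and, for voice cloning, reference VQ codes) into
# the model's chat format: a user message carrying the text, followed by an
# assistant message that starts with <|voice|> and optionally embeds the
# reference codes as a VQPart.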
def encode_tokens(
    tokenizer,
    string,
    device="cuda",
    prompt_tokens=None,
    num_codebooks=4,
):
    string = clean_text(string)

    messages = []
    messages.append(
        Message(
            role="user",
            parts=[TextPart(text=string)],
            cal_loss=False,
        )
    )

    if prompt_tokens is not None:
        if prompt_tokens.ndim == 3:
            assert (
                prompt_tokens.shape[0] == 1
            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
            prompt_tokens = prompt_tokens[0]

        assert prompt_tokens.ndim == 2, "Prompt tokens should be a 2D tensor"

        if prompt_tokens.shape[0] > num_codebooks:
            logger.warning(
                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, taking the first {num_codebooks} codebooks"
            )
            prompt_tokens = prompt_tokens[:num_codebooks]

        vq_part = VQPart(codes=prompt_tokens.to(device))

        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>"), vq_part],
                cal_loss=False,
            )
        )
    else:
        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>")],
                cal_loss=False,
                add_im_end=False,
            )
        )

    conversation = Conversation(messages=messages)

    encoded = conversation.encode_for_inference(
        tokenizer=tokenizer,
        num_codebooks=num_codebooks,
    )

    return encoded.to(device)

def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
        checkpoint_path, load_weights=True, is_agent=is_agent
    )

    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        decode_one_token = (
            decode_one_token_ar_agent if is_agent else decode_one_token_ar
        )
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = (
            decode_one_token_naive_agent if is_agent else decode_one_token_naive
        )
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token,
            fullgraph=True,
            backend="inductor" if torch.cuda.is_available() else "aot_eager",
            mode="reduce-overhead" if torch.cuda.is_available() else None,
        )

    return model.eval(), decode_one_token

@dataclass
class GenerateResponse:
    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None
    text: Optional[str] = None

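# generate_long is the high-level entry point for long-form TTS: it splits the
# input into chunks (split_text), encodes each chunk, and generates them one by
# one while carrying a sliding window of previously generated segments as
# context (bounded by `max_length`). It yields a GenerateResponse per segment
# and a final action="next" marker after each sample.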
def generate_long(
    *,
    model,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 150,
    prompt_text: Optional[str | list[str]] = None,
    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
):
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    use_prompt = prompt_text is not None and prompt_tokens is not None
    if use_prompt and isinstance(prompt_text, str):
        prompt_text = [prompt_text]
        prompt_tokens = [prompt_tokens]

    assert use_prompt is False or len(prompt_text) == len(
        prompt_tokens
    ), "Prompt text and tokens must have the same length"

    # Parameter bytes, used for the bandwidth estimate below
    model_size = sum(
        p.numel() * p.dtype.itemsize for p in model.parameters() if p.requires_grad
    )
    tokenizer = model.tokenizer
    im_end_id = tokenizer.get_token_id("<|im_end|>")

    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]
    encoded_prompts = [
        Conversation(
            messages=[
                Message(
                    role="system",
                    parts=[TextPart(text="Speak out the provided text.")],
                    cal_loss=False,
                )
            ]
        )
        .encode_for_inference(
            tokenizer=tokenizer,
            num_codebooks=model.config.num_codebooks,
        )
        .to(device)
    ]

    if use_prompt:
        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
            encoded_prompts.append(
                encode_tokens(
                    tokenizer,
                    string=t,
                    device=device,
                    prompt_tokens=c,
                    num_codebooks=model.config.num_codebooks,
                )
            )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                device=device,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    # Move sampling parameters to the device as tensors so that changing them
    # does not trigger recompilation
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            # Walk back through previous segments (newest first) to find how
            # much history fits in the context window
            lengths = reversed([seg.size(1) for seg in global_encoded])

            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024 - sum(
                    t.shape[1] for t in encoded_prompts
                ):
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Keep the first two segments plus the most recent ones
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            if use_prompt:
                partial_encoded = encoded_prompts + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Strip the prompt; keep only the newly generated codebook rows
            codes = y[1:, prompt_length + 1 :].clone()
            assert (codes >= 0).all(), "Negative code found"

            decoded = y[:, prompt_length:].clone()

            global_encoded.append(decoded)
            assert (codes >= 0).all(), f"Negative code found: {codes}"
            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
            seg_idx += 1

    yield GenerateResponse(action="next")

@dataclass
class WrappedGenerateResponse:
    status: Literal["success", "error"]
    response: Optional[GenerateResponse | Exception] = None


@dataclass
class GenerateRequest:
    request: dict
    response_queue: queue.Queue

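# launch_thread_safe_queue pins the model to a dedicated worker thread and
# exposes a queue.Queue of GenerateRequest items; each request carries its own
# response queue, and a None item shuts the worker down. This lets a server
# share one model across many request handlers.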
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    input_queue = queue.Queue()
    init_event = threading.Event()

    def worker():
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:  # sentinel: shut the worker down
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                response_queue.put(WrappedGenerateResponse(status="error", response=e))

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue

def launch_thread_safe_queue_agent(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    input_queue = queue.Queue()
    init_event = threading.Event()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    config = BaseModelArgs.from_pretrained(checkpoint_path)

    def worker():
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile, is_agent=True
        )

        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:  # sentinel: shut the worker down
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for token in generate_agent(
                    model=model,
                    decode_one_token=decode_one_token,
                    **kwargs,
                ):
                    response_queue.put(token)

                response_queue.put("stop")
            except Exception as e:
                import traceback

                logger.exception(f"Error in worker: {traceback.format_exc()}")
                response_queue.put("error")

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue, tokenizer, config

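# CLI entry point. A typical invocation (the script path depends on where this
# file lives in the repo):
#   python generate.py --text "Hello" \
#       --checkpoint-path checkpoints/fish-speech-1.4 --num-samples 2 --compile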
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.4",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=100)
def main(
    text: str,
    prompt_text: Optional[list[str]],
    prompt_tokens: Optional[list[Path]],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
) -> None:
    precision = torch.half if half else torch.bfloat16

    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt texts ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    if prompt_tokens is not None:
        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]

    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    idx = 0
    codes = []

    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to codes_{idx}.npy")
            logger.info("Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()