Spaces:
Configuration error
Configuration error
| """An example script that illustrates how one can prompt Kyutai STT models.""" | |
| import argparse | |
| import itertools | |
| import math | |
| from collections import deque | |
| import julius | |
| import moshi.models | |
| import sphn | |
| import torch | |
| import tqdm | |
| class PromptHook: | |
| def __init__(self, tokenizer, prefix, padding_tokens=(0, 3)): | |
| self.tokenizer = tokenizer | |
| self.prefix_enforce = deque(self.tokenizer.encode(prefix)) | |
| self.padding_tokens = padding_tokens | |
| def on_token(self, token): | |
| if not self.prefix_enforce: | |
| return | |
| token = token.item() | |
| if token in self.padding_tokens: | |
| pass | |
| elif token == self.prefix_enforce[0]: | |
| self.prefix_enforce.popleft() | |
| else: | |
| assert False | |
| def on_logits(self, logits): | |
| if not self.prefix_enforce: | |
| return | |
| mask = torch.zeros_like(logits, dtype=torch.bool) | |
| for t in self.padding_tokens: | |
| mask[..., t] = True | |
| mask[..., self.prefix_enforce[0]] = True | |
| logits[:] = torch.where(mask, logits, float("-inf")) | |
| def main(args): | |
| info = moshi.models.loaders.CheckpointInfo.from_hf_repo( | |
| args.hf_repo, | |
| moshi_weights=args.moshi_weight, | |
| mimi_weights=args.mimi_weight, | |
| tokenizer=args.tokenizer, | |
| config_path=args.config_path, | |
| ) | |
| mimi = info.get_mimi(device=args.device) | |
| tokenizer = info.get_text_tokenizer() | |
| lm = info.get_moshi( | |
| device=args.device, | |
| dtype=torch.bfloat16, | |
| ) | |
| if args.prompt_text: | |
| prompt_hook = PromptHook(tokenizer, args.prompt_text) | |
| lm_gen = moshi.models.LMGen( | |
| lm, | |
| temp=0, | |
| temp_text=0.0, | |
| on_text_hook=prompt_hook.on_token, | |
| on_text_logits_hook=prompt_hook.on_logits, | |
| ) | |
| else: | |
| lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0) | |
| audio_silence_prefix_seconds = info.stt_config.get( | |
| "audio_silence_prefix_seconds", 1.0 | |
| ) | |
| audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0) | |
| padding_token_id = info.raw_config.get("text_padding_token_id", 3) | |
| def _load_and_process(path): | |
| audio, input_sample_rate = sphn.read(path) | |
| audio = torch.from_numpy(audio).to(args.device).mean(axis=0, keepdim=True) | |
| audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate) | |
| if audio.shape[-1] % mimi.frame_size != 0: | |
| to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size | |
| audio = torch.nn.functional.pad(audio, (0, to_pad)) | |
| return audio | |
| n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate) | |
| n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate) | |
| silence_chunk = torch.zeros( | |
| (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device | |
| ) | |
| audio = _load_and_process(args.file) | |
| if args.prompt_file: | |
| audio_prompt = _load_and_process(args.prompt_file) | |
| else: | |
| audio_prompt = None | |
| chain = [itertools.repeat(silence_chunk, n_prefix_chunks)] | |
| if audio_prompt is not None: | |
| chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1)) | |
| # adding a bit (0.8s) of silence to separate prompt and the actual audio | |
| chain.append(itertools.repeat(silence_chunk, 10)) | |
| chain += [ | |
| torch.split(audio[:, None, :], mimi.frame_size, dim=-1), | |
| itertools.repeat(silence_chunk, n_suffix_chunks), | |
| ] | |
| chunks = itertools.chain(*chain) | |
| text_tokens_accum = [] | |
| with mimi.streaming(1), lm_gen.streaming(1): | |
| for audio_chunk in tqdm.tqdm(chunks): | |
| audio_tokens = mimi.encode(audio_chunk) | |
| text_tokens = lm_gen.step(audio_tokens) | |
| if text_tokens is not None: | |
| text_tokens_accum.append(text_tokens) | |
| utterance_tokens = torch.concat(text_tokens_accum, dim=-1) | |
| text_tokens = utterance_tokens.cpu().view(-1) | |
| # if we have an audio prompt and we don't want to have it in the transcript, | |
| # we should cut the corresponding number of frames from the output tokens. | |
| # However, there is also some amount of padding that happens before it | |
| # due to silence_prefix and audio_delay. Normally it is ignored in detokenization, | |
| # but now we should account for it to find the position of the prompt transcript. | |
| if args.cut_prompt_transcript and audio_prompt is not None: | |
| prompt_frames = audio_prompt.shape[1] // mimi.frame_size | |
| no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds | |
| no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate) | |
| text_tokens = text_tokens[prompt_frames + no_prompt_offset :] | |
| text = tokenizer.decode( | |
| text_tokens[text_tokens > padding_token_id].numpy().tolist() | |
| ) | |
| print(text) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Example streaming STT w/ a prompt.") | |
| parser.add_argument( | |
| "--file", | |
| required=True, | |
| help="File to transcribe.", | |
| ) | |
| parser.add_argument( | |
| "--prompt_file", | |
| required=False, | |
| help="Audio of the prompt.", | |
| ) | |
| parser.add_argument( | |
| "--prompt_text", | |
| required=False, | |
| help="Text of the prompt.", | |
| ) | |
| parser.add_argument( | |
| "--cut-prompt-transcript", | |
| action="store_true", | |
| help="Cut the prompt from the output transcript", | |
| ) | |
| parser.add_argument( | |
| "--hf-repo", type=str, help="HF repo to load the STT model from. " | |
| ) | |
| parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.") | |
| parser.add_argument( | |
| "--moshi-weight", type=str, help="Path to a local checkpoint file." | |
| ) | |
| parser.add_argument( | |
| "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi." | |
| ) | |
| parser.add_argument( | |
| "--config-path", type=str, help="Path to a local config file.", default=None | |
| ) | |
| parser.add_argument( | |
| "--device", | |
| type=str, | |
| default="cuda", | |
| help="Device on which to run, defaults to 'cuda'.", | |
| ) | |
| args = parser.parse_args() | |
| main(args) | |