# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "julius",
#     "librosa",
#     "soundfile",
#     "moshi==0.2.11",
# ]
# ///

"""An example script that illustrates how one can get per-word timestamps from
Kyutai STT models.
"""

import argparse
import dataclasses
import itertools
import math
import time

import julius
import moshi.models
import sphn
import torch


@dataclasses.dataclass
class TimestampedText:
    text: str
    timestamp: tuple[float, float]

    def __str__(self):
        return f"{self.text} ({self.timestamp[0]:.2f}:{self.timestamp[1]:.2f})"


def tokens_to_timestamped_text(
    text_tokens,
    tokenizer,
    frame_rate,
    end_of_padding_id,
    padding_token_id,
    offset_seconds,
) -> list[TimestampedText]:
    text_tokens = text_tokens.cpu().view(-1)

    # Normally, `end_of_padding` tokens indicate word boundaries:
    # everything between them should be a single word, and the time
    # offsets of those tokens correspond to the word's start and end
    # timestamps (minus the silence prefix and the audio delay).
    #
    # However, in rare cases some complications arise. First, for words
    # that are said quickly but are represented by multiple tokens, the
    # boundary might be omitted. Second, for the very last word the end
    # boundary might never appear. The code below handles those
    # situations a bit more carefully.
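    #
    # Illustration only (made-up token ids; 0 is assumed to be
    # `end_of_padding_id` and 3 the padding id, matching the call in `main`):
    # in the sequence [3, 0, 17, 3, 0, 42, 51, 3], the two zeros at frames 1
    # and 4 bound one word (token 17, padding filtered out), and the tokens
    # after the second zero form the final word (tokens 42 and 51). Each
    # word's start and end times come from those frame positions divided by
    # `frame_rate`.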

    sequence_timestamps = []

    def _tstmp(start_position, end_position):
        return (
            max(0, start_position / frame_rate - offset_seconds),
            max(0, end_position / frame_rate - offset_seconds),
        )

    def _decode(t):
        t = t[t > padding_token_id]
        return tokenizer.decode(t.numpy().tolist())

    def _decode_segment(start, end):
        nonlocal text_tokens
        nonlocal sequence_timestamps

        text = _decode(text_tokens[start:end])
        words_inside_segment = text.split()

        if len(words_inside_segment) == 0:
            return
        if len(words_inside_segment) == 1:
            # A single word within the boundaries: the general case.
            sequence_timestamps.append(
                TimestampedText(text=text, timestamp=_tstmp(start, end))
            )
        else:
            # A rare situation where multiple words are so close that no
            # `end_of_padding` separates them. We re-tokenize the words one
            # by one; each word is assigned as many frames as it has tokens.
            for adjacent_word in words_inside_segment[:-1]:
                n_tokens = len(tokenizer.encode(adjacent_word))
                sequence_timestamps.append(
                    TimestampedText(
                        text=adjacent_word, timestamp=_tstmp(start, start + n_tokens)
                    )
                )
                start += n_tokens

            # The last word takes everything up to the boundary.
            adjacent_word = words_inside_segment[-1]
            sequence_timestamps.append(
                TimestampedText(text=adjacent_word, timestamp=_tstmp(start, end))
            )

    (segment_boundaries,) = torch.where(text_tokens == end_of_padding_id)

    if not segment_boundaries.numel():
        return []

    for i in range(len(segment_boundaries) - 1):
        segment_start = int(segment_boundaries[i]) + 1
        segment_end = int(segment_boundaries[i + 1])

        _decode_segment(segment_start, segment_end)

    last_segment_start = int(segment_boundaries[-1]) + 1

    boundary_token = torch.tensor([tokenizer.eos_id()])
    (end_of_last_segment,) = torch.where(
        torch.isin(text_tokens[last_segment_start:], boundary_token)
    )
    if not end_of_last_segment.numel():
        # Upper bound: the end of the audio or one second past the last
        # boundary, whichever is smaller. `math.ceil` keeps the slice index an
        # integer, since `frame_rate` may be fractional (e.g. 12.5).
        last_segment_end = min(
            text_tokens.shape[-1], last_segment_start + math.ceil(frame_rate)
        )
    else:
        last_segment_end = last_segment_start + int(end_of_last_segment[0])
    _decode_segment(last_segment_start, last_segment_end)

    return sequence_timestamps
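

# With Mimi's 12.5 Hz frame rate (the value `main` passes in below), a word
# spanning frames 100..110 maps to 100 / 12.5 = 8.0 s .. 110 / 12.5 = 8.8 s,
# before `offset_seconds` is subtracted.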


def main(args):
    if args.vad and args.hf_repo is None:
        # The extra VAD prediction heads are only available in the 1B model;
        # use the PyTorch checkpoint repo, not the candle one.
        args.hf_repo = "kyutai/stt-1b-en_fr"

    info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
        args.hf_repo,
        moshi_weights=args.moshi_weight,
        mimi_weights=args.mimi_weight,
        tokenizer=args.tokenizer,
        config_path=args.config_path,
    )

    mimi = info.get_mimi(device=args.device)
    tokenizer = info.get_text_tokenizer()
    lm = info.get_moshi(
        device=args.device,
        dtype=torch.bfloat16,
    )
    lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)

    audio_silence_prefix_seconds = info.stt_config.get(
        "audio_silence_prefix_seconds", 1.0
    )
    audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
    padding_token_id = info.raw_config.get("text_padding_token_id", 3)

    audio, input_sample_rate = sphn.read(args.in_file)
    audio = torch.from_numpy(audio).to(args.device)
    audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
    if audio.shape[-1] % mimi.frame_size != 0:
        # Pad the audio up to a whole number of Mimi frames.
        to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
        audio = torch.nn.functional.pad(audio, (0, to_pad))

    text_tokens_accum = []

    # Feed silence before the audio and after it: the text stream lags the
    # audio by `audio_delay_seconds`, so the suffix lets it flush completely.
    n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
    n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
    silence_chunk = torch.zeros(
        (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
    )

    chunks = itertools.chain(
        itertools.repeat(silence_chunk, n_prefix_chunks),
        torch.split(audio[:, None], mimi.frame_size, dim=-1),
        itertools.repeat(silence_chunk, n_suffix_chunks),
    )

    start_time = time.time()
    nchunks = 0
    last_print_was_vad = False
    with mimi.streaming(1), lm_gen.streaming(1):
        for audio_chunk in chunks:
            nchunks += 1
            audio_tokens = mimi.encode(audio_chunk)
            if args.vad:
                text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
                if vad_heads:
                    pr_vad = vad_heads[2][0, 0, 0].cpu().item()
                    if pr_vad > 0.5 and not last_print_was_vad:
                        print(" [end of turn detected]")
                        last_print_was_vad = True
            else:
                text_tokens = lm_gen.step(audio_tokens)
            text_token = text_tokens[0, 0, 0].cpu().item()
            # Skip the `end_of_padding` (0) and padding (3) tokens when printing.
            if text_token not in (0, 3):
                _text = tokenizer.id_to_piece(text_token)  # type: ignore
                _text = _text.replace("▁", " ")
                print(_text, end="", flush=True)
                last_print_was_vad = False
            text_tokens_accum.append(text_tokens)

    utterance_tokens = torch.concat(text_tokens_accum, dim=-1)

    dt = time.time() - start_time
    print(
        f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
    )
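
    # The silence prefix and the model's audio delay shift everything the
    # model hears later in time, so the reported timestamps are shifted back
    # by the prefix duration plus `audio_delay_seconds`.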
    timed_text = tokens_to_timestamped_text(
        utterance_tokens,
        tokenizer,
        mimi.frame_rate,
        end_of_padding_id=0,
        padding_token_id=padding_token_id,
        offset_seconds=int(n_prefix_chunks / mimi.frame_rate) + audio_delay_seconds,
    )

    decoded = " ".join([str(t) for t in timed_text])
    print(decoded)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Example streaming STT w/ timestamps.")
    parser.add_argument("in_file", help="The file to transcribe.")
    parser.add_argument(
        "--hf-repo", type=str, help="HF repo to load the STT model from."
    )
    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument(
        "--moshi-weight", type=str, help="Path to a local checkpoint file."
    )
    parser.add_argument(
        "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
    )
    parser.add_argument(
        "--config-path", type=str, help="Path to a local config file.", default=None
    )
    parser.add_argument(
        "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device on which to run, defaults to 'cuda'.",
    )

    args = parser.parse_args()
    main(args)