import io
import re

import librosa
import torch
import torchaudio
from cachetools import LRUCache, cached

CACHE_MAXSIZE = 10000  # max number of encode results kept in the LRU cache
MICRO_BATCH_SIZE = 8  # decode chunk size; bounds peak GPU memory
ASR_SAMPLE_RATE = 16000  # Hz; the input rate the ASR model expects
HUGE_GAP_THRESHOLD = 4000  # ms; silence gap above which a clip is flagged


@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_encode(model, audios_list: list[bytes | torch.Tensor]):
    # Decode raw bytes with librosa at the model's sample rate; tensors are
    # assumed to already be (1, num_samples) at that rate.
    audios: list[torch.Tensor] = [
        (
            torch.from_numpy(
                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
            )[None]
            if isinstance(audio, bytes)
            else audio
        )
        for audio in audios_list
    ]

    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
    max_length = lengths.max().item()

    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")

    # Right-pad every clip to the longest one so the batch can be stacked.
    padded = torch.stack(
        [
            torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
            for audio in audios
        ]
    ).to(model.device)

    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
    features, feature_lengths = features.cpu(), feature_lengths.cpu()

    # Trim the padding back off so each clip keeps its true feature length.
    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]


@cached(
    cache=LRUCache(maxsize=CACHE_MAXSIZE),
    key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
    # Bytes are hashable, so (device, tuple(audios)) works as a cache key;
    # tensor inputs must call batch_encode directly.
    return batch_encode(model, audios)
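
# Usage sketch (hypothetical `model` exposing .device, .spec_transform, and
# .encode(), as used above; `wav_bytes` is the raw content of an audio file):
#
#     features = cached_vqgan_batch_encode(model, [wav_bytes])
#     features = cached_vqgan_batch_encode(model, [wav_bytes])  # LRU cache hit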


@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_vqgan_decode(model, features):
    lengths = torch.tensor(
        [feature.shape[-1] for feature in features], device=model.device
    )
    max_length = lengths.max().item()
    # Right-pad every feature to the longest one so the batch can be stacked.
    padded = torch.stack(
        [
            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
            for feature in features
        ]
    ).to(model.device)

    # Decode in micro-batches to keep peak GPU memory bounded.
    audios, audio_lengths = [], []
    for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
        audio, audio_length = model.decode(
            padded[i : i + MICRO_BATCH_SIZE],
            feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
        )
        audios.append(audio)
        audio_lengths.append(audio_length)
    audios = torch.cat(audios, dim=0)
    audio_lengths = torch.cat(audio_lengths, dim=0)
    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()

    # Trim padding and return numpy arrays at each clip's true length.
    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
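
# Round-trip sketch (same hypothetical `model`; `audio_bytes_list` is a list
# of raw audio file contents):
#
#     features = cached_vqgan_batch_encode(model, audio_bytes_list)
#     waveforms = batch_vqgan_decode(model, features)  # list of numpy arrays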


@torch.no_grad()
def batch_asr(model, lock, audios, sr, language="auto"):
    # Resample every clip to the rate the ASR model expects.
    resampled_audios = []
    for audio in audios:
        audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    # Serialize calls to the shared ASR model via the caller's lock.
    with lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        text = r["text"]
        # Strip special tokens such as <|en|> from the transcript.
        text = re.sub(r"<\|.*?\|>", "", text)
        duration = len(audio) / sr * 1000  # ms
        huge_gap = False

        if "timestamp" in r and len(r["timestamp"]) > 2:
            # Flag the clip if two consecutive segments are separated by more
            # than HUGE_GAP_THRESHOLD ms of silence.
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
                    huge_gap = True
                    break

            # Also flag a long stretch of silence after the last segment.
            if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results
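
# Usage sketch (assumes an ASR model object whose generate() accepts the
# keyword arguments used above, plus a shared lock; `waveform` is a 1-D float
# tensor sampled at `sr` Hz):
#
#     import threading
#
#     lock = threading.Lock()
#     out = batch_asr(asr_model, lock, [waveform], sr=44100)
#     print(out[0]["text"], out[0]["duration"], out[0]["huge_gap"])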