| import os |
| import io |
| import glob |
| import math |
| import tarfile |
| import torch |
| import torchaudio |
| import safetensors |
| from .configuration_whisper import WhisperVQConfig |
| from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration |
| from transformers import WhisperFeatureExtractor, WhisperTokenizerFast |
|
|
|
|
def load_quantize_encoder(model_path, device="cuda"):
    """Build a quantize-only ``WhisperVQEncoder`` and load its weights from *model_path*.

    Encoder tensors are read from every ``model*.safetensors`` shard found in
    the directory. Keys are stripped of their ``model.encoder.`` prefix. Two
    groups of tensors are skipped:

    * the ``layer_norm*`` weights;
    * any ``layers.<i>.*`` tensor with ``i >= config.quantize_position``.

    Presumably the quantize-only encoder does not define these, so they would
    otherwise trip the strict ``load_state_dict``. Confirm this against
    ``WhisperVQEncoder``.

    Args:
        model_path: Local directory holding the config and safetensors shards.
        device: Device the model is moved to. Defaults to ``"cuda"``, matching
            the previous hard-coded behavior.

    Returns:
        The encoder, in eval mode, on *device*.

    Raises:
        FileNotFoundError: If *model_path* contains no ``model*.safetensors`` file.
    """
    # Sorted so shards are always read in the same order.
    shard_paths = sorted(glob.glob(os.path.join(model_path, "model*.safetensors")))
    if not shard_paths:
        # Fail early with a clear message. Otherwise an empty state dict reaches
        # load_state_dict and surfaces as an opaque "missing keys" error.
        raise FileNotFoundError(f"no model*.safetensors files found in {model_path!r}")

    config = WhisperVQConfig.from_pretrained(model_path)
    config.quantize_encoder_only = True
    model = WhisperVQEncoder(config)

    prefix = "model.encoder."
    state_dict = {}
    for path in shard_paths:
        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if not key.startswith(prefix):
                    continue
                new_key = key[len(prefix):]
                if new_key.startswith("layer_norm"):
                    continue
                if new_key.startswith("layers"):
                    # Keys look like "layers.<id>.<param>"; drop layers past the quantizer.
                    layer_id = int(new_key.split(".")[1])
                    if layer_id >= config.quantize_position:
                        continue
                state_dict[new_key] = f.get_tensor(key)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model
|
|
|
|
# Module-wide cache of resamplers, keyed by the source sample rate. Each entry
# converts to 16 kHz and is placed on the CUDA device. Entries are created
# lazily on first use of a given rate and then reused for the process lifetime.
_resample_buffer: dict[int, torchaudio.transforms.Resample] = dict()
|
|
|
|
def _load_and_segment(utts):
    """Load, resample and window utterances for ``extract_speech_token``.

    Each element of *utts* is either a path, loaded with ``torchaudio.load``,
    or an ``(audio, sample_rate)`` tuple. Only the first channel (row 0) is
    kept when the waveform has more than one dimension. Audio is resampled to
    16 kHz on CUDA when needed and then cut into consecutive windows of at
    most 30 s.

    Must be called under ``torch.no_grad()``, because the tensor is converted
    with ``.numpy()``.

    Returns:
        ``(segments, owners)``. ``segments`` is a list of 1-D numpy arrays.
        ``owners[k]`` is the index in *utts* of the utterance ``segments[k]``
        came from. An empty waveform contributes no segments.
    """
    segments, owners = [], []
    segment_len = 30 * 16000  # 30 s window at 16 kHz
    for idx, utt in enumerate(utts):
        if isinstance(utt, tuple):
            audio, sample_rate = utt
        else:
            audio, sample_rate = torchaudio.load(utt)
        if audio.ndim > 1:
            # Keep only the first channel. A 1-D waveform is used as-is.
            audio = audio[0]
        if sample_rate != 16000:
            resampler = _resample_buffer.get(sample_rate)
            if resampler is None:
                # One GPU resampler per source rate, reused across calls.
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=16000,
                ).to("cuda")
                _resample_buffer[sample_rate] = resampler
            audio = resampler(audio.cuda())
        audio = audio.cpu().numpy()
        for offset in range(0, audio.shape[0], segment_len):
            segments.append(audio[offset:offset + segment_len])
            owners.append(idx)
    return segments, owners


def extract_speech_token(model: WhisperVQEncoder, feature_extractor: WhisperFeatureExtractor, utts):
    """Convert utterances into lists of discrete speech-token ids using *model*.

    Each utterance is processed in these steps:

    1. Load it from a path, or take it as an ``(audio, sample_rate)`` tuple.
    2. Reduce it to its first channel.
    3. Resample it to 16 kHz if needed.
    4. Split it into 30-second segments.

    The segments are featurized by *feature_extractor* and encoded by *model*
    in batches of 128. For each segment, the ``quantized_token_ids`` at
    positions covered by the downsampled attention mask are kept. They are
    concatenated back onto the utterance the segment came from. Requires CUDA.

    Args:
        model: Quantize-only encoder; its ``quantized_token_ids`` output is used.
        feature_extractor: Whisper feature extractor that builds the model inputs.
        utts: Iterable of paths or ``(audio, sample_rate)`` tuples. Tuple audio
            is presumably a torch tensor shaped (channels, samples) or
            (samples,); confirm with callers.

    Returns:
        One ``list[int]`` of token ids per utterance, in input order.
    """
    # Materialize so generators work: utts is iterated once and also sized.
    utts = list(utts)
    with torch.no_grad():
        segments, owners = _load_and_segment(utts)

        # Use the same `or 1` guard everywhere the pooling factor is applied.
        pooling_kernel_size = model.config.pooling_kernel_size or 1
        conv_stride = model.conv1.stride[0] * model.conv2.stride[0]
        # Raw samples per output token. Padding to a multiple of this
        # presumably keeps every padded segment aligned to whole tokens.
        pad_multiple = conv_stride * pooling_kernel_size * feature_extractor.hop_length

        all_speech_tokens = [[] for _ in utts]
        batch_size = 128
        for start in range(0, len(segments), batch_size):
            features = feature_extractor(
                segments[start:start + batch_size],
                sampling_rate=16000,
                return_attention_mask=True,
                return_tensors="pt",
                device="cuda",
                padding="longest",
                pad_to_multiple_of=pad_multiple,
            )
            features = features.to(device="cuda")
            outputs = model(**features)
            speech_tokens = outputs.quantized_token_ids
            # Downsample the frame-level mask to token resolution: conv strides, then pooling.
            attention_mask = features.attention_mask[:, ::conv_stride][:, ::pooling_kernel_size]
            assert attention_mask.shape == speech_tokens.shape, (
                f"attention mask {tuple(attention_mask.shape)} does not match "
                f"token ids {tuple(speech_tokens.shape)}"
            )
            batch_owners = owners[start:start + batch_size]
            for owner, tokens, mask in zip(batch_owners, speech_tokens, attention_mask):
                all_speech_tokens[owner].extend(tokens[mask.bool()].tolist())
        return all_speech_tokens
|
|