| | import torch |
| | import torchaudio |
| | from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC |
| |
|
| |
|
def max_alignment(s1, s2, skip_character="~", record=None):
    """
    Align s1 to s2 as best it can. Wherever a character from s1 is not found in
    s2, ``skip_character`` is used to replace that character.

    The result always has ``len(s1)`` characters. The kept characters form a
    longest common subsequence of s1 and s2. When leading characters are equal
    they are always matched. When dropping a character of s2 and replacing the
    current character of s1 keep the same number of characters, the
    replacement (``skip_character``) is chosen.

    The alignment is computed iteratively with a score table, so arbitrarily
    long strings do not hit Python's recursion limit.

    Args:
        s1: The string to align; must not contain ``skip_character``.
        s2: The string s1 is aligned against.
        skip_character: Placeholder written for characters of s1 that have no
            match in s2.
        record: Accepted for backward compatibility only and ignored. It was
            the memo table of an earlier recursive implementation.

    Returns:
        A string of length ``len(s1)``.

    Raises:
        AssertionError: If ``skip_character`` occurs in s1.
    """
    assert (
        skip_character not in s1
    ), f"Found the skip character {skip_character} in the provided string, {s1}"

    # Equal leading characters are always matched, so consume the common
    # prefix directly instead of putting it in the table.
    prefix_len = 0
    limit = min(len(s1), len(s2))
    while prefix_len < limit and s1[prefix_len] == s2[prefix_len]:
        prefix_len += 1
    prefix = s1[:prefix_len]
    a = s1[prefix_len:]
    b = s2[prefix_len:]
    n, m = len(a), len(b)
    if n == 0:
        return prefix
    if m == 0:
        return prefix + skip_character * n

    # score[i][j] = number of characters of a[i:] kept when aligned to b[j:].
    # Row n and column m stay 0 (one side exhausted).
    score = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        row, below = score[i], score[i + 1]
        ch = a[i]
        for j in range(m - 1, -1, -1):
            if ch == b[j]:
                row[j] = below[j + 1] + 1
            else:
                right = row[j + 1]
                down = below[j]
                row[j] = right if right > down else down

    # Walk the table from (0, 0), emitting one output character per char of a.
    out = [prefix]
    i = j = 0
    while i < n:
        if j == m:
            out.append(skip_character * (n - i))
            break
        if a[i] == b[j]:
            out.append(a[i])
            i += 1
            j += 1
        elif score[i][j + 1] > score[i + 1][j]:
            # Dropping b[j] keeps strictly more characters of s1.
            j += 1
        else:
            # Ties favour replacing a[i] with the skip character.
            out.append(skip_character)
            i += 1
    return "".join(out)
| |
|
| |
|
class Wav2VecAlignment:
    """
    Uses wav2vec2 to perform audio<->text alignment.

    The CTC model is kept on the CPU between calls and moved to ``device``
    only while :meth:`align` runs inference.
    """

    def __init__(self, device="cuda"):
        """
        Load the pretrained wav2vec2 CTC model, feature extractor and tokenizer.

        Args:
            device: Torch device that :meth:`align` runs the model on.
        """
        self.model = Wav2Vec2ForCTC.from_pretrained(
            "jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli"
        ).cpu()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-large-960h"
        )
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
            "jbetker/tacotron-symbols"
        )
        self.device = device

    def align(self, audio, expected_text, audio_sample_rate=24000):
        """
        Estimate where each character of ``expected_text`` starts in ``audio``.

        Args:
            audio: Audio tensor whose last dimension is time. It is fed to the
                model after resampling and ``logits[0]`` is used, so presumably
                shaped (1, samples) — TODO confirm with callers.
            expected_text: The transcript to align.
            audio_sample_rate: Sample rate of ``audio``; it is resampled to
                16000 Hz before inference.

        Returns:
            A list with one position per character of ``expected_text``. Each
            position is an index along the last dimension of the input
            ``audio``. Characters that are not found in the model's
            transcription get positions linearly interpolated between their
            neighbours.

        Raises:
            AssertionError: If the alignment is inconsistent. Before raising,
                the resampled audio and text are saved to
                'alignment_debug.pth' in the working directory.
        """
        orig_len = audio.shape[-1]

        with torch.no_grad():
            self.model = self.model.to(self.device)
            try:
                audio = audio.to(self.device)
                audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
                clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
                logits = self.model(clip_norm).logits
            finally:
                # Return the model to the CPU even if inference raises.
                self.model = self.model.cpu()

        logits = logits[0]
        # Per-frame argmax token ids, reused both for decoding and matching.
        top_tokens = logits.argmax(-1).tolist()
        pred_string = self.tokenizer.decode(top_tokens)

        # Characters of the transcript missing from the model's transcription
        # become '~' so they can be skipped during matching.
        fixed_expectation = max_alignment(expected_text.lower(), pred_string)
        # Number of input samples represented by one logit frame.
        w2v_compression = orig_len // logits.shape[0]
        expected_tokens = self.tokenizer.encode(fixed_expectation)
        expected_chars = list(fixed_expectation)
        if len(expected_tokens) == 0:
            return []
        if len(expected_tokens) == 1:
            return [0]
        # The first character is pinned to the start of the clip.
        expected_tokens.pop(0)
        expected_chars.pop(0)

        alignments = [0]

        def pop_till_you_win():
            """
            Pop the next token whose character is not '~'.

            Every '~' consumed on the way appends a -1 placeholder to
            ``alignments``. Returns None once the tokens run out.
            """
            if len(expected_tokens) == 0:
                return None
            popped = expected_tokens.pop(0)
            popped_char = expected_chars.pop(0)
            while popped_char == "~":
                alignments.append(-1)
                if len(expected_tokens) == 0:
                    return None
                popped = expected_tokens.pop(0)
                popped_char = expected_chars.pop(0)
            return popped

        # Assign each expected token to the first later frame whose argmax is
        # that token.
        next_expected_token = pop_till_you_win()
        for i, top in enumerate(top_tokens):
            if next_expected_token == top:
                alignments.append(i * w2v_compression)
                if len(expected_tokens) > 0:
                    next_expected_token = pop_till_you_win()
                else:
                    break

        pop_till_you_win()
        if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)):
            torch.save([audio, expected_text], "alignment_debug.pth")
            # Raised explicitly (not `assert`) so the check survives `python -O`.
            raise AssertionError(
                "Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to "
                "your current working directory. Please report this along with the file so it can get fixed."
            )

        # Sentinel so that a trailing run of -1 interpolates up to the clip end.
        alignments.append(orig_len)
        self._fill_missing(alignments)
        return alignments[:-1]

    @staticmethod
    def _fill_missing(alignments):
        """
        Replace each run of -1 entries in ``alignments`` in place.

        The new values are evenly spaced between the known values on either
        side of the run. The first and last entries must not be -1.
        """
        i = 0
        while i < len(alignments):
            if alignments[i] != -1:
                i += 1
                continue
            end = i
            while alignments[end] == -1:
                end += 1
            prev = alignments[i - 1]
            gap = alignments[end] - prev
            span = end - i + 1
            for j in range(i, end):
                alignments[j] = (j - i + 1) * gap // span + prev
            i = end

    def redact(self, audio, expected_text, audio_sample_rate=24000):
        """
        Cut the audio of every ``[bracketed]`` part of ``expected_text``.

        Args:
            audio: Audio tensor sliced as ``audio[:, start:stop]``; the last
                dimension is time.
            expected_text: Transcript of ``audio`` in which the parts to remove
                are enclosed in square brackets, e.g. "[I am sad,] Feed me.".
                Each "[" must be closed by exactly one "]", with no nesting.
            audio_sample_rate: Sample rate of ``audio``.

        Returns:
            ``audio`` itself if the text contains no "[". Otherwise, the
            concatenated audio of the non-bracketed parts, or an empty clip
            when nothing remains.

        Raises:
            AssertionError: If the brackets are unbalanced or nested, or if
                :meth:`align` fails.
        """
        if "[" not in expected_text:
            return audio
        splitted = expected_text.split("[")
        fully_split = [splitted[0]]
        for spl in splitted[1:]:
            # Exactly one "]" per "[": a missing one is an unclosed bracket, an
            # extra one would shift which pieces are kept and which are cut.
            assert (
                spl.count("]") == 1
            ), 'Every "[" character must be paired with a "]" with no nesting.'
            fully_split.extend(spl.split("]"))

        # Even-indexed pieces are kept and odd-indexed pieces are redacted.
        # Record kept pieces as half-open [start, end) character ranges of the
        # bracket-free text. Empty pieces (text starting or ending with a
        # bracket, or adjacent bracket groups) contribute no audio.
        kept_ranges = []
        position = 0
        for index, piece in enumerate(fully_split):
            if index % 2 == 0 and piece:
                kept_ranges.append((position, position + len(piece)))
            position += len(piece)

        if not kept_ranges:
            return audio[:, :0]

        bare_text = "".join(fully_split)
        alignments = self.align(audio, bare_text, audio_sample_rate)

        clip_end = audio.shape[-1]
        output_audio = []
        for start, end in kept_ranges:
            # A piece lasts until the next character starts, or until the end
            # of the clip when the piece ends the text.
            stop = alignments[end] if end < len(alignments) else clip_end
            output_audio.append(audio[:, alignments[start] : stop])
        return torch.cat(output_audio, dim=-1)
| |
|