| import base64 |
| import httpx |
| import re |
| import requests |
| import torch |
| import torchaudio.functional as F |
| import torchaudio |
| import uroman as ur |
| import logging |
| import traceback |
|
|
|
|
def convert_to_list_with_punctuation_mixed(text):
    """Tokenize Chinese text that may contain embedded English words.

    Chinese is split into single characters while English words stay whole.
    Punctuation and other symbols are appended to the preceding token, or
    kept as a standalone token when they appear at the very start.

    :param text: input string
    :return: list of tokens (chars / words, with punctuation attached)
    """
    stripped = text.strip()
    if not stripped:
        return []

    # English words | single CJK characters | anything else (punctuation etc.)
    token_pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
    word_pattern = re.compile(r'^[a-zA-Z]+[a-zA-Z0-9]*$')

    out = []
    for tok in re.findall(token_pattern, stripped):
        if not tok.strip():
            continue
        if word_pattern.match(tok) or ('\u4e00' <= tok <= '\u9fff'):
            # English word or single Chinese character: its own token
            out.append(tok)
        elif out:
            # punctuation/symbol: glue onto the previous token
            out[-1] += tok
        else:
            # leading punctuation has nothing to attach to
            out.append(tok)
    return out
|
|
def split_and_merge_punctuation(text):
    """Tokenize English text word-by-word, keeping each word intact.

    The text is split on whitespace; inside each whitespace-delimited
    element, runs of punctuation are attached to the preceding word, and
    leading punctuation (with no preceding word) stays standalone — the
    same rule used for Chinese in convert_to_list_with_punctuation_mixed.

    :param text: input string
    :return: list of word tokens with punctuation attached
    """
    result = []
    for ele in text.split():
        # Split into alternating word / punctuation runs.
        parts = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', ele)

        merged_parts = []
        for part in parts:
            # Bug fix: classify each part by its content rather than by
            # index parity. The old `i % 2 == 0` scheme assumed elements
            # always start with a word, so "(hello)" was mis-merged into
            # ['(hello', ')'] instead of keeping the word boundary sane.
            if re.match(r'^[a-zA-Z0-9]+$', part):
                merged_parts.append(part)
            elif merged_parts:
                merged_parts[-1] += part
            else:
                merged_parts.append(part)

        result.extend(merged_parts)

    return result
|
|
|
|
def get_aligned_result_text_with_punctuation(alignment_result, text, language):
    """Map alignment entries back to display tokens with punctuation.

    English stays at word level; Chinese stays at character level but with
    embedded English words kept whole. Each alignment item gets the matching
    token as its "transcript"; all other keys of the item are preserved.

    :param alignment_result: list of dicts, each with at least "start"/"end"
    :param text: original transcript text
    :param language: "EN" or "ZH"
    :return: new list of alignment dicts with updated "transcript" values
    :raises ValueError: if language is not "EN" or "ZH"
    """
    logging.info("start change text to text_tokens")
    if language == "EN":
        text_tokens = split_and_merge_punctuation(text)
    elif language == "ZH":
        text_tokens = convert_to_list_with_punctuation_mixed(text)
    else:
        raise ValueError(f"Unsupported language: {language}")

    logging.info(f"Text tokens count: {len(text_tokens)}, Alignment result count: {len(alignment_result)}")

    punctuations = set(',.!?;:()[]<>\'\"…·,。;:!?()【】《》''""\、')

    logging.info("start get align result text with punctuation")
    updated_alignment_result = []
    token_idx = 0
    for align_item in alignment_result:
        if token_idx >= len(text_tokens):
            # More alignment items than tokens: stop rather than index out of range.
            logging.warning(f"Text tokens exhausted at index {token_idx}, but alignment has more items")
            break

        text_token = text_tokens[token_idx]

        if language == "ZH":
            # Defensive catch-all: fold any standalone punctuation token into
            # the current one. The ZH tokenizer normally attaches punctuation
            # to the previous token already, so this rarely triggers.
            # Bug fix: a leftover debug `assert False, "???"` at the top of
            # this loop crashed the first time it *did* trigger (and was
            # silently stripped under `python -O`); it has been removed so
            # the merge logic actually runs.
            while token_idx + 1 < len(text_tokens) and text_tokens[token_idx + 1] in punctuations:
                text_token += text_tokens[token_idx + 1]
                token_idx += 1

        updated_item = {
            "start": align_item["start"],
            "end": align_item["end"],
            "transcript": text_token
        }
        # Carry over any extra keys (e.g. scores) without clobbering the above.
        updated_item.update({key: align_item[key] for key in align_item if key not in ["start", "end", "transcript"]})

        updated_alignment_result.append(updated_item)
        token_idx += 1

    logging.info("end get align result text with punctuation")
    return updated_alignment_result
|
|
|
|
class AlignmentModel:
    """Forced-alignment wrapper around torchaudio's MMS_FA pipeline.

    Loads the MMS forced-alignment acoustic model, compiles it with
    torch.compile, and provides batched forced alignment of waveforms
    against transcripts (Chinese transcripts are romanized with uroman
    before tokenization).
    """

    def __init__(self, device, model_dir='/data-mnt/data/wy/X-Codec-2.0/checkpoints'):
        """
        Initialize the alignment model and load the required resources.

        :param device: device identifier accepted by torch.device (e.g. "cuda:0")
        :param model_dir: directory where the MMS_FA checkpoint is stored/downloaded
        """
        self.device = torch.device(device)
        self.bundle = torchaudio.pipelines.MMS_FA
        model = self.bundle.get_model(with_star=False, dl_kwargs={'model_dir': model_dir}).to(self.device)

        # NOTE(review): "reduce-overhead" mode presumably trades warm-up time
        # for per-call latency; the first few forward passes will be slow.
        print("Compiling the model... This may take a moment.")
        self.align_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")

        # uroman romanizes non-Latin text (used here for Chinese) so the
        # MMS_FA character dictionary can tokenize it.
        self.uroman = ur.Uroman()
        self.DICTIONARY = self.bundle.get_dict()

    def align(self, emission, tokens):
        """
        Run forced alignment between emissions and target tokens.

        :param emission: model output log-probabilities for one utterance
        :param tokens: target token ids (batch dimension first)
        :return: (alignments, scores) for the single utterance; scores are
                 exponentiated from log-probabilities into probabilities
        """
        alignments, scores = F.forced_align(
            log_probs=emission,
            targets=tokens,
            blank=0
        )
        alignments, scores = alignments[0], scores[0]
        scores = scores.exp()  # log-probs -> probabilities
        return alignments, scores

    def unflatten(self, list_, lengths):
        """
        Split a flat list into consecutive sublists of the given lengths.

        :param list_: flat list
        :param lengths: length of each sublist; must sum to len(list_)
        :return: list of sublists
        """
        assert len(list_) == sum(lengths)
        i = 0
        ret = []
        for l in lengths:
            ret.append(list_[i:i + l])
            i += l
        return ret

    def preview_word(self, waveform, spans, num_frames, transcript, sample_rate):
        """
        Compute start/end times (plus pause and duration) for each word.

        :param waveform: audio waveform, shape (1, num_samples)
        :param spans: per-word token spans, parallel to transcript
        :param num_frames: number of emission frames
        :param transcript: list of word strings
        :param sample_rate: audio sample rate in Hz
        :return: list of dicts with transcript/start/end/pause/duration (seconds)
        """
        end = 0
        alignment_result = []
        for span, trans in zip(spans, transcript):
            # samples per emission frame
            ratio = waveform.size(1) / num_frames
            x0 = int(ratio * span[0].start)
            x1 = int(ratio * span[-1].end)
            align_info = {
                "transcript": trans,
                "start": round(x0 / sample_rate, 3),
                "end": round(x1 / sample_rate, 3)
            }
            # pause = silence since the previous word ended
            align_info["pause"] = round(align_info["start"] - end, 3)
            align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
            end = align_info["end"]
            alignment_result.append(align_info)
        return alignment_result

    def make_wav_batch(self, wav_list):
        """
        Zero-pad the waveforms in wav_list to a common length.

        :param wav_list: list of 1-D waveform tensors
        :return: (padded tensor of shape (batch, max_len) on self.device,
                  tensor of original lengths on self.device)
        """
        wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
        max_length = max(wav_lengths)

        wavs_tensors = torch.zeros(len(wav_list), max_length, device=self.device)
        for i, wav in enumerate(wav_list):
            wav = wav.to(self.device)
            wavs_tensors[i, :wav_lengths[i]] = wav
        return wavs_tensors, wav_lengths.to(self.device)

    def get_target(self, transcript, language):
        """
        Build target token ids for a transcript, keeping English words intact.

        For ZH, the text is split into English words / single CJK characters /
        other symbols; CJK characters are romanized with uroman. Then, for
        both languages, punctuation is stripped and every remaining character
        is looked up in the MMS_FA dictionary; unknown characters and '-' map
        to the '*' wildcard token.

        :param transcript: raw transcript text
        :param language: "ZH" or "EN"
        :return: int32 tensor of shape (1, num_tokens) on self.device
        """
        original_transcript = transcript

        if language == "ZH":
            # English words | single CJK characters | other symbols
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)

            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):
                    # English word: keep whole, lowercase
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':
                    # CJK character: romanize so the Latin dictionary covers it
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:
                    processed_parts.append(token)

            transcript = ' '.join(processed_parts)

        elif language == "EN":
            # English needs no romanization
            pass
        else:
            assert False, f"Unsupported language: {language}"

        # Strip punctuation; each whitespace-separated item becomes a "word".
        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        TRANSCRIPT = transcript.lower().split()

        star_token = self.DICTIONARY['*']
        tokenized_transcript = []

        for word in TRANSCRIPT:
            word_tokens = []
            for c in word:
                if c in self.DICTIONARY and c != '-':
                    word_tokens.append(self.DICTIONARY[c])
                else:
                    # out-of-dictionary characters (and '-') become the wildcard
                    word_tokens.append(star_token)
            tokenized_transcript.extend(word_tokens)

        logging.info(f"Original transcript: {original_transcript}")
        logging.info(f"Processed transcript: {transcript}")
        logging.info(f"Final TRANSCRIPT: {TRANSCRIPT}")

        return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)

    def get_alignment_result(self, emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language):
        """
        Turn raw alignment output into per-word timing information.

        Re-derives the same word list as get_target (this preprocessing must
        stay in sync with it), groups aligned token spans word-by-word, and
        converts the spans into timestamps via preview_word.

        :param emission_padded: padded emission tensor for this utterance
        :param emission_length: number of valid emission frames
        :param aligned_tokens: frame-level alignment from self.align
        :param alignment_scores: per-frame scores from self.align
        :param transcript: raw transcript text
        :param waveform: 1-D waveform tensor for this utterance
        :param language: "ZH" or "EN"
        :return: list of per-word timing dicts (see preview_word)
        """
        original_transcript = transcript

        if language == "ZH":
            # Same tokenization/romanization as get_target.
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)

            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:
                    processed_parts.append(token)

            transcript = ' '.join(processed_parts)
        elif language == "EN":
            pass
        else:
            assert False, f"Unsupported language: {language}"

        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        # Drop padding frames before computing frame ratios.
        emission = emission_padded[:emission_length, :].unsqueeze(0)
        TRANSCRIPT = transcript.lower().split()

        token_spans = F.merge_tokens(aligned_tokens, alignment_scores)

        # One sublist of spans per word; word length must equal its token count.
        word_spans = self.unflatten(token_spans, [len(word) for word in TRANSCRIPT])

        num_frames = emission.size(1)

        logging.info(f"Original transcript for alignment: {original_transcript}")
        logging.info(f"Processed TRANSCRIPT: {TRANSCRIPT}")

        return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames, TRANSCRIPT, self.bundle.sample_rate)

    def batch_alignment(self, wav_list, transcript_list, language_list):
        """
        Run forced alignment on a batch of utterances.

        :param wav_list: list of 1-D waveform tensors
        :param transcript_list: list of transcript strings
        :param language_list: list of language codes ("ZH"/"EN"), parallel to the others
        :return: list of alignment results (one per utterance)
        """
        wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)
        logging.info("start alignment model forward")
        with torch.inference_mode():
            emission, emission_lengths = self.align_model(wavs_tensors.to(self.device), wavs_lengths_tensor)
            # Append a zero logit column serving as the '*' wildcard class.
            star_dim = torch.zeros((emission.shape[0], emission.size(1), 1), dtype=emission.dtype, device=self.device)
            emission = torch.cat((emission, star_dim), dim=-1)

        logging.info("end alignment model forward")

        target_list = [self.get_target(transcript, language) for transcript, language in zip(transcript_list, language_list)]

        logging.info("align success")
        # Align each utterance on its unpadded emission slice.
        align_results = [
            self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
            for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
        ]

        logging.info("get align result")
        batch_aligned_tokens = [align_result[0] for align_result in align_results]
        batch_alignment_scores = [align_result[1] for align_result in align_results]

        alignment_result_list = [
            self.get_alignment_result(emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language)
            for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language
            in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores, transcript_list, wav_list, language_list)
        ]
        logging.info("get align result success")
        return alignment_result_list
|
|
|
|
async def batch_get_alignment_result_remote(alignment_url, audio_path, transcript, language):
    """Fetch alignment results from a remote alignment service.

    POSTs the audio path, transcript and language to `alignment_url` and
    returns the service's "results" payload. Failures are logged and
    swallowed (best-effort), returning None.

    :param alignment_url: URL of the alignment HTTP endpoint
    :param audio_path: path to the audio file, as understood by the service
    :param transcript: transcript text to align
    :param language: language code (e.g. "EN" / "ZH")
    :return: the "results" field of the JSON response, or None on failure
    """
    payload = {
        "audio_path": audio_path,
        "transcript": transcript,
        "language": language,
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(alignment_url, json=payload, timeout=300)
            response.raise_for_status()
            data = response.json()
            return data['results']

    # Bug fix: the request uses httpx, so network/HTTP failures raise
    # httpx.HTTPError — the old `requests.exceptions.RequestException`
    # handler could never match and was dead code.
    except httpx.HTTPError as e:
        logging.error(f"Failed to connect to alignment service: {e}")
        traceback.print_exc()

    except Exception as e:
        logging.error(f"An error occurred in remote alignment: {e}")
        traceback.print_exc()

    return None
|
|
|