Spaces:
Running
Running
| import re | |
| from dataclasses import dataclass | |
| from typing import Any, List, Dict, Union | |
| import torch | |
| from zhconv import convert | |
| # 删除标点符号 | |
| def remove_punctuation(text: str or List[str]): | |
| punctuation = '!,.;:?、!,。;:?' | |
| if isinstance(text, str): | |
| text = re.sub(r'[{}]+'.format(punctuation), '', text).strip() | |
| return text | |
| elif isinstance(text, list): | |
| result_text = [] | |
| for t in text: | |
| t = re.sub(r'[{}]+'.format(punctuation), '', t).strip() | |
| result_text.append(t) | |
| return result_text | |
| else: | |
| raise Exception(f'不支持该类型{type(text)}') | |
| # 将繁体中文总成简体中文 | |
| def to_simple(text: str or List[str]): | |
| if isinstance(text, str): | |
| text = convert(text, 'zh-cn') | |
| return text | |
| elif isinstance(text, list): | |
| result_text = [] | |
| for t in text: | |
| t = convert(t, 'zh-cn') | |
| result_text.append(t) | |
| return result_text | |
| else: | |
| raise Exception(f'不支持该类型{type(text)}') | |
| class DataCollatorSpeechSeq2SeqWithPadding: | |
| processor: Any | |
| def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: | |
| # split inputs and labels since they have to be of different lengths and need different padding methods | |
| # first treat the audio inputs by simply returning torch tensors | |
| input_features = [{"input_features": feature["input_features"][0]} for feature in features] | |
| batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") | |
| # get the tokenized label sequences | |
| label_features = [{"input_ids": feature["labels"]} for feature in features] | |
| # pad the labels to max length | |
| labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") | |
| # replace padding with -100 to ignore loss correctly | |
| labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) | |
| # if bos token is appended in previous tokenization step, | |
| # cut bos token here as it's append later anyways | |
| if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item(): | |
| labels = labels[:, 1:] | |
| batch["labels"] = labels | |
| return batch | |