Spaces:
Runtime error
Runtime error
| import torch | |
| import yaml | |
| import os | |
| import re | |
| from vita.model.vita_tts.utils import init_encoder_llm, load_checkpoint | |
class inferencePipeline:
    """Inference wrapper around the model built by ``init_encoder_llm``.

    The constructor builds the model from ``<model_path>/audiollm/train.yaml``,
    loads ``<model_path>/audiollm/final.pt``, moves the model to CUDA and puts
    it in eval mode. ``speech_dialogue`` runs one model step while threading
    caches/state between calls, and ``post_process`` normalizes generated text.
    """

    def __init__(self, args):
        """Build the model from its training config and load the checkpoint.

        Args:
            args: Namespace-like object. ``model_path`` (directory containing
                ``audiollm/train.yaml``, ``audiollm/global_cmvn`` and
                ``audiollm/final.pt``) and ``llm_path`` are read here;
                ``top_p``, ``top_k`` and ``temperature`` are read later by
                ``speech_dialogue``.
        """
        self.args = args
        audiollm_dir = os.path.join(self.args.model_path, "audiollm")
        with open(os.path.join(audiollm_dir, "train.yaml"), 'r', encoding='utf-8') as fin:
            configs = yaml.safe_load(fin)
        # Point the config at the CMVN file under model_path and at the
        # supplied LLM, regardless of what train.yaml itself contains.
        configs['cmvn_file'] = os.path.join(audiollm_dir, "global_cmvn")
        configs['model_conf']['llm_path'] = self.args.llm_path
        # Init asr model from configs
        self.model = init_encoder_llm(configs)
        load_checkpoint(self.model, os.path.join(audiollm_dir, "final.pt"))
        device = torch.device('cuda')
        self.model = self.model.to(device)
        self.model.eval()

    def speech_dialogue(self,
                        audio: tuple,
                        role: str = None,
                        stat: str = 'sl',
                        past_key_values=None,
                        last_id=None,
                        past_tokens=None,
                        adapter_cache=None,
                        encoder_cache=None,
                        pe_index=0):
        """Run one model step under ``torch.no_grad`` and CUDA autocast.

        Args:
            audio: Input features or ``None``. Despite the ``tuple``
                annotation, the code calls ``.to('cuda')`` and ``.size(1)``
                on it, so it must be a tensor (presumably batch x frames x
                feat_dim fbank features -- TODO confirm with the caller).
            role: System-prompt text. Only used when ``past_key_values`` is
                ``None``; it is prefixed with the ``<|im_start|>system``
                header and a newline.
            stat: State string handed to the model. ``'pre'`` only calls
                ``model.set_system_role`` and returns state ``'sl'``; any
                other value calls ``model.recognize``. The meaning of the
                individual states is defined by the model, not here.
            past_key_values, last_id, adapter_cache, encoder_cache, pe_index:
                Model state; presumably the values returned in the previous
                call's output dict.
            past_tokens: List of token ids generated so far (or ``None``).
                It is appended to IN PLACE when the returned state is
                ``'cs'``.

        Returns:
            dict with ``past_key_values``, ``stat``, ``last_id``,
            ``adapter_cache``, ``encoder_cache`` and ``pe_index``. When the
            returned state is ``'cs'`` it also contains ``hidden_state``,
            ``past_tokens`` and ``text`` (all accumulated tokens decoded with
            ``skip_special_tokens=True``).
        """
        with torch.no_grad():
            # Input fbank features (see docstring regarding the expected type).
            feats = audio
            if feats is not None:
                feats = feats.to('cuda')
                feats_lengths = torch.tensor([feats.size(1)]).to('cuda')
            else:
                feats_lengths = None

            extra_inputs = {
                'top_p': self.args.top_p,
                'top_k': self.args.top_k,
                'temperature': self.args.temperature,
                'past_key_values': past_key_values,
                'stat': stat,
                'last_id': last_id,
                'adapter_cache': adapter_cache,
                'encoder_cache': encoder_cache,
                'pe_index': pe_index,
            }
            if role is not None and past_key_values is None:
                # Only injected before any past_key_values exist. <|im_end|>
                # is deliberately not appended here; the original note says
                # it is added in chat_prefix.
                extra_inputs['role'] = '<|im_start|>system\n' + role  # + '<|im_end|>'

            with torch.autocast(device_type="cuda",
                                dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32):
                # preprocess system role first
                if stat == 'pre':
                    # feats are not passed: this step only encodes the role.
                    past_key_values = self.model.set_system_role(extra_inputs)
                    stat = 'sl'
                else:
                    (last_id, stat, past_key_values, adapter_cache,
                     encoder_cache, pe_index, hidden_state) = self.model.recognize(
                        feats,
                        feats_lengths,
                        extra_inputs=extra_inputs)

            outputs = dict(
                past_key_values=past_key_values,
                stat=stat,
                last_id=last_id,
                adapter_cache=adapter_cache,
                encoder_cache=encoder_cache,
                pe_index=pe_index,
            )
            # hidden_state is always bound here: 'cs' can only come back from
            # the recognize() branch ('pre' forces stat to 'sl').
            if stat == 'cs':
                if past_tokens is None:
                    past_tokens = []
                past_tokens.append(last_id[0][0])
                text = self.model.tokenizer.decode(past_tokens, skip_special_tokens=True)
                outputs['hidden_state'] = hidden_state
                outputs['text'] = text
                outputs['past_tokens'] = past_tokens
            return outputs

    def post_process(self, text):
        """Normalize punctuation and formatting of generated text.

        Steps:
        1. Replace the ideographic comma and ASCII parentheses with ``,``;
           replace full-width parentheses with the full-width comma.
        2. Remove newline, carriage-return and tab characters.
        3. Remove ``*``, ``_``, backticks and ``~``.
        4. Remove whitespace following ``.`` and ``:``.
        5. Rewrite numbered-list markers ``N.`` as ``N:``. In text containing
           CJK ideographs, the rewrite applies when the dot is followed by an
           ideograph or ASCII letter. In other text, it applies when the dot
           is followed by a non-digit word character, so decimals such as
           ``3.14`` are left intact.
        6. Ensure the text ends with sentence-final punctuation (ASCII or
           full-width ``. ? !``, or ``。``). A trailing comma, semicolon or
           colon (ASCII or full-width), or ``、``, is replaced by ``。``;
           otherwise ``。`` is appended. Empty text is returned unchanged.

        Parameters:
        - text (str): The input text string to be post-processed.

        Returns:
        - str: The post-processed text string.
        """
        # Full-width forms are written as escapes so they cannot be silently
        # folded into their ASCII look-alikes (e.g. U+FF08 vs '(').
        fw_lparen, fw_rparen = '\uff08', '\uff09'
        fw_comma = '\uff0c'
        # 。 plus full-width ? !, plus ASCII . ? !
        sentence_end = '。\uff1f\uff01.?!'
        # ASCII , ; : plus full-width , ; : plus 、
        soft_stop = ',\uff0c;\uff1b:\uff1a、'

        text = text.replace('、', ',')
        text = text.replace('(', ',').replace(')', ',')
        text = text.replace(fw_lparen, fw_comma).replace(fw_rparen, fw_comma)
        text = re.sub(r'[\n\r\t]', '', text)
        text = re.sub(r'[*_`~]', '', text)
        text = re.sub(r'([.:])\s+', r'\1', text)
        if re.search(r'[\u4e00-\u9fa5]', text):
            text = re.sub(r'(\d+)\.\s*([\u4e00-\u9fa5A-Za-z])', r'\1:\2', text)
        else:
            # [^\W\d] is \w minus digits: keeps "3.14" from becoming "3:14".
            text = re.sub(r'(\d+)\.\s*([^\W\d])', r'\1:\2', text)
        if text and text[-1] not in sentence_end:
            if text[-1] in soft_stop:
                text = text[:-1] + '。'
            else:
                text += '。'
        return text