import os
import time
import argparse

import numpy as np
import torch
import librosa
import soundfile as sf
from ml_dtypes import bfloat16
from loguru import logger
from transformers import AutoConfig, AutoTokenizer

from model import SinusoidalPositionEncoder
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
from utils.ax_model_bin import AX_SenseVoiceSmall
from utils.ax_vad_bin import AX_Fsmn_vad
from utils.vad_utils import merge_vad
from utils.ax_cam_bin import (
    AX_SpeakerEmbeddingInference,
    chunk,
    distribute_spk,
    do_clustering,
    get_trans_sentence_sensevoice,
)
from utils.infer_func import InferManager

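# Paths to the GPTQ-Int4 Qwen3 LLM (axmodel) and its HuggingFace tokenizer;
# `embeds` is the pre-extracted token embedding table used to build LLM prefill inputs.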
llm_axmodel_path = "./ax_model/Qwen3-4B-Instruct-2507-GPTQ-Int4_8k_axmodel"
llm_hf_tokenizer_path = "./tokenizer_qwen3_int4"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embeds = np.load(os.path.join(llm_axmodel_path, "model.embed_tokens.weight.npy"))


def parse_args():
    parser = argparse.ArgumentParser(description="SenseVoice inference script")
    parser.add_argument("--output_dir", type=str, default="./output_dir", help="Output directory")
    parser.add_argument("--seq_len", type=int, default=132, help="Sequence length for model")
    parser.add_argument("--wav_file", type=str, default="wav/vad_example.wav", help="Input wav file")
    return parser.parse_args()


def run_model(prompt, max_seq_len=8191, slice_len=256, max_prefill_len=4095):
    """Summarize one chunk of the meeting transcript with the on-device LLM."""
    messages = [
        {
            "role": "system",
            # System prompt (Chinese): "Your name is Xiaohui. You are a professional
            # meeting-minutes analysis assistant, skilled at extracting key information
            # from meeting records (speakers' remarks logged in chronological order) and
            # producing a fitting summary. Based on the meeting record below, after
            # careful thought, summarize the participants and the content of the meeting."
            "content": "你叫小惠, 你是一个专业的会议记录分析助手, 善于从会议记录(按照时间先后记录不同人物的发言)中提取关键信息并生成合适的总结. \n 请你基于以下会议记录, 在深度思考后, 总结这段会议记录的参会人员以及内容摘要. ",
        },
        {
            "role": "user",
            "content": prompt,
        },
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    input_ids = model_inputs.input_ids

    token_ids = input_ids[0].cpu().numpy().tolist()
    token_len = len(token_ids)

    assert token_len <= max_prefill_len, f"Input token length {token_len} exceeds max prefill length {max_prefill_len}"

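    # Gather the embedding rows for the prompt tokens and cast to bfloat16 for prefill.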
    prefill_data = np.take(embeds, token_ids, axis=0)
    prefill_data = prefill_data.astype(bfloat16)

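    # Prefill the prompt, then decode autoregressively on the axmodel runtime.
    # `cfg` and `eos_token_id` are globals assigned in the __main__ block; `embeds` is loaded at module level.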
    imer = InferManager(cfg, llm_axmodel_path, max_seq_len=max_seq_len, max_prefill_len=max_prefill_len)
    token_ids = imer.prefill(tokenizer, token_ids, prefill_data, slice_len=slice_len)
    imer.decode(tokenizer, token_ids, embeds, slice_len=slice_len, eos_token_id=eos_token_id)

    print(f"\nInput tokens: {token_len}")
    print(f"Output tokens: {len(token_ids) - token_len}")
    print(f"KV cache total length: {max_seq_len}")

    output_text = tokenizer.decode(token_ids[token_len:], skip_special_tokens=True)
    return output_text


if __name__ == "__main__":
    args = parse_args()
    seq_len = args.seq_len
    model_path = args.output_dir
    os.makedirs(model_path, exist_ok=True)

    print("Initializing model...")

    ax_model_dir = "ax_model"
    total_inference_start = time.time()

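    # Load the VAD, speaker-embedding, and SenseVoice ASR models from the axmodel
    # directory, and precompute the positional encoding for SenseVoice.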
    model_vad = AX_Fsmn_vad(ax_model_dir)
    speaker_model = AX_SpeakerEmbeddingInference(model_dir=ax_model_dir)
    embed = SinusoidalPositionEncoder()
    position_encoding = embed.get_position_encoding(torch.randn(1, seq_len, 560)).numpy()
    model_bin = AX_SenseVoiceSmall(ax_model_dir, seq_len=seq_len)

    print("Loading tokenizer...")
    tokenizer_path = os.path.join(ax_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
    tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)

    wav_file = args.wav_file

    print("Running inference on example file...")

    withitn = True
    norm_type = "withitn" if withitn else "woitn"
    print(f"\nProcessing with text normalization: {norm_type}")

    language = "auto"
    print(f"\n--- Processing language: {language} ---")
    print(f"Processing file: {wav_file}")
    inference_start = time.time()

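    # Load the input audio and resample to 16 kHz if necessary.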
    speech, fs = librosa.load(wav_file, sr=None)
    if fs != 16000:
        print(f"Resampling audio from {fs}Hz to 16000Hz")
        speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
        fs = 16000
    audio_duration = librosa.get_duration(y=speech, sr=fs)
    speech_lengths = len(speech)

    try:
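        # Voice activity detection; merge adjacent segments up to 15 s for ASR.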
        vad_start_time = time.time()
        res_vad = model_vad(speech)[0]
        vad_segments = merge_vad(res_vad, 15 * 1000)
        vad_time_cost = time.time() - vad_start_time
        print(f"VAD processing time: {vad_time_cost:.2f} seconds")

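        # Convert VAD segments from milliseconds to seconds and chunk them for
        # speaker-embedding extraction.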
        vad_time = [[vad_t[0] / 1000, vad_t[1] / 1000] for vad_t in res_vad]
        chunks = [c for (st, ed) in vad_time for c in chunk(st, ed)]

        print("Extracting speaker embeddings...")
        speaker_start_time = time.time()

        embeddings = speaker_model(speech, fs, chunks=chunks)
        speaker_time_cost = time.time() - speaker_start_time
        print(f"Speaker embedding extraction time: {speaker_time_cost:.2f} seconds")
        print(f"Generated embeddings shape: {embeddings.shape}")

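        # Cluster the embeddings to estimate the number of speakers and assign a
        # speaker label to every chunk (speaker_num=None lets the clustering decide).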
        clustering_start_time = time.time()
        speaker_num, diar_results = do_clustering(chunks, embeddings, speaker_num=None)
        clustering_time_cost = time.time() - clustering_start_time
        print(f"Speaker clustering time: {clustering_time_cost:.2f} seconds")

        all_results = []
        all_metadata = {}

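        # Run SenseVoice ASR on each merged VAD segment; each segment is written to a
        # temporary wav file before being passed to the model.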
        asr_start_time = time.time()
        for i, segment in enumerate(vad_segments):
            segment_start, segment_end = segment

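            # Slice the segment out of the full waveform (segment times are in ms).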
            start_sample = int(segment_start / 1000 * fs)
            end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
            segment_speech = speech[start_sample:end_sample]

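            # ASR timestamps are relative to the segment; keep its offset in the recording.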
            time_offset_sec = segment_start / 1000.0

            segment_filename = f"temp_segment_{i}.wav"
            sf.write(segment_filename, segment_speech, fs)

            try:
                segment_res, segment_meta = model_bin(
                    segment_filename,
                    language,
                    withitn,
                    position_encoding,
                    tokenizer=tokenizer,
                    output_timestamp=True,
                    ban_emo_unk=False,
                    output_dir=model_path,
                    key=[f"{os.path.basename(wav_file)}_segment_{i}"],
                )

                if "merged_words" in segment_meta:
                    if "merged_words" not in all_metadata:
                        all_metadata["merged_words"] = []
                    all_metadata["merged_words"].extend(segment_meta["merged_words"])

                if "merged_timestamps" in segment_meta:
                    if "merged_timestamps" not in all_metadata:
                        all_metadata["merged_timestamps"] = []
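                    # Shift segment-relative timestamps to absolute time in the
                    # recording, clamped to the total audio duration.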
                    adjusted_timestamps = [
                        [min(ts[0] + time_offset_sec, audio_duration), min(ts[1] + time_offset_sec, audio_duration)]
                        for ts in segment_meta["merged_timestamps"]
                    ]
                    all_metadata["merged_timestamps"].extend(adjusted_timestamps)

                if os.path.exists(segment_filename):
                    os.remove(segment_filename)

            except Exception:
                # Remove the temporary segment file before re-raising the error.
                if os.path.exists(segment_filename):
                    os.remove(segment_filename)
                raise

        output_asr = {
            "merged_words": all_metadata.get("merged_words", []),
            "merged_timestamps": all_metadata.get("merged_timestamps", []),
        }

    except Exception:
        raise

    asr_time_cost = time.time() - asr_start_time
    print(f"ASR processing time: {asr_time_cost:.2f} seconds")

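    # Build sentence-level transcripts from the ASR output and attach the
    # diarization speaker labels to each sentence.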
    asr_timestamps = get_trans_sentence_sensevoice(output_asr)
    sentence_info_with_spk = distribute_spk(asr_timestamps, diar_results)
    inference_time_cost = time.time() - inference_start
    inference_time_cost_all = time.time() - total_inference_start
    rtf = inference_time_cost / audio_duration
    print(f"Inference time for {wav_file}: {inference_time_cost:.2f} seconds")
    print(f"Load model + inference time for {wav_file}: {inference_time_cost_all:.2f} seconds")
    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"RTF: {rtf:.2f}")

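    # Write the speaker-attributed transcript, one sentence per line.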
    output_trans_path = os.path.join(args.output_dir, f"{os.path.basename(wav_file)}.txt")

    try:
        with open(output_trans_path, 'w', encoding='utf-8') as f:
            for text_string, timeinterval, spk in sentence_info_with_spk:
                f.write(f'Speaker_{spk}: [{timeinterval[0]:.3f} {timeinterval[1]:.3f}] {text_string}\n')
    except Exception as e:
        logger.warning(f"Failed to write transcript to {output_trans_path}: {e}")

    logger.info("Starting LLM Inference for meeting summary...")

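    # Rebind the global `tokenizer` (previously the SentencePiece ASR tokenizer) to the
    # LLM's HuggingFace tokenizer used by run_model, and load the LLM config.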
    tokenizer = AutoTokenizer.from_pretrained(llm_hf_tokenizer_path)
    cfg = AutoConfig.from_pretrained(llm_hf_tokenizer_path, trust_remote_code=True)

    # Pass the EOS token id list to decode only when the config provides more than one.
    eos_token_id = None
    if isinstance(cfg.eos_token_id, list) and len(cfg.eos_token_id) > 1:
        eos_token_id = cfg.eos_token_id

    with open(output_trans_path, 'r', encoding='utf-8') as file:
        prompt = file.read()

    prompt_list = prompt.split('\n')

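    # Feed the transcript to the LLM in chunks of roughly 5000 characters and
    # collect a partial summary for each chunk.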
    output_text_l = []
    cur_prompt = ""
    for idx, prompt in enumerate(prompt_list):
        cur_prompt += prompt + "\n"

        if len(cur_prompt) < 5000 and idx != len(prompt_list) - 1:
            continue

        output_text = run_model(cur_prompt)
        output_text_l.append(output_text)
        print("\n")

        cur_prompt = ""

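    # Join the chunk summaries and save them as a Markdown file.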
    output_summary_path = os.path.join(args.output_dir, f"{os.path.basename(wav_file)}_summary.md")
    output_text_combined = "\n\n".join(output_text_l)
    with open(output_summary_path, 'w', encoding='utf-8') as f:
        f.write(output_text_combined)

    logger.info(f"LLM Inference completed. Summary saved to {output_summary_path}")