#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import os
import time

import librosa
import numpy as np
import soundfile as sf
import torch
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
from loguru import logger
from ml_dtypes import bfloat16
from transformers import AutoConfig, AutoTokenizer

from model import SinusoidalPositionEncoder
from utils.ax_cam_bin import (
    AX_SpeakerEmbeddingInference,
    chunk,
    distribute_spk,
    do_clustering,
    get_trans_sentence_sensevoice,
)
from utils.ax_model_bin import AX_SenseVoiceSmall
from utils.ax_vad_bin import AX_Fsmn_vad
from utils.infer_func import InferManager
from utils.vad_utils import merge_vad
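# On-device LLM assets: a quantized Qwen3 axmodel plus its HuggingFace tokenizer.
# The token embedding table is loaded separately and applied on the host in
# run_model(), since the axmodel consumes embeddings rather than token ids.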
llm_axmodel_path = "./ax_model/Qwen3-4B-Instruct-2507-GPTQ-Int4_8k_axmodel"
llm_hf_tokenizer_path = "./tokenizer_qwen3_int4"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embeds = np.load(os.path.join(llm_axmodel_path, "model.embed_tokens.weight.npy"))
def parse_args():
parser = argparse.ArgumentParser(description="SenseVoice inference script")
parser.add_argument("--output_dir", type=str, default="./output_dir", help="Output directory")
parser.add_argument("--seq_len", type=int, default=132, help="Sequence length for model") #68 ,132
parser.add_argument("--wav_file", type=str, default="wav/vad_example.wav",help="Input wav file")
return parser.parse_args()
def run_model(prompt, max_seq_len=8191, slice_len=256, max_prefill_len=4095):
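    """Summarize one chunk of the meeting transcript with the on-device LLM.

    Relies on module-level state prepared in __main__: `tokenizer`, `cfg`, and
    `eos_token_id` (set during the LLM stage), plus the host-side embedding
    table `embeds` and `device` defined above.
    """
    # The system prompt (in Chinese) names the assistant "小惠" and instructs it to
    # summarize the attendees and key content of the meeting transcript.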
messages = [
{
"role": "system",
"content": "你叫小惠, 你是一个专业的会议记录分析助手, 善于从会议记录(按照时间先后记录不同人物的发言)中提取关键信息并生成合适的总结. \n 请你基于以下会议记录, 在深度思考后, 总结这段会议记录的参会人员以及内容摘要. ",
},
{
"role": "user",
"content": prompt,
},
]
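    # Render the messages with the model's chat template, then tokenize the prompt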
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
input_ids = model_inputs.input_ids
    # Tokenize, embed, and run prefill + decode on the axmodel
token_ids = input_ids[0].cpu().numpy().tolist()
token_len = len(token_ids)
assert token_len <= max_prefill_len, f"Input token length {token_len} exceeds max prefill length {max_prefill_len}"
    # Gather prompt token embeddings on the host; the axmodel expects bfloat16
    prefill_data = np.take(embeds, token_ids, axis=0)
    prefill_data = prefill_data.astype(bfloat16)
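    # Prefill pushes the prompt embeddings through the model in slices of
    # `slice_len` to populate the KV cache; decode then generates autoregressively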
imer = InferManager(cfg, llm_axmodel_path, max_seq_len=max_seq_len, max_prefill_len=max_prefill_len) # prefill + decode max length
token_ids = imer.prefill(tokenizer, token_ids, prefill_data, slice_len=slice_len)
imer.decode(tokenizer, token_ids, embeds, slice_len=slice_len, eos_token_id=eos_token_id)
    # Report input/output token counts and the KV-cache capacity
    print(f"\nInput tokens: {token_len}")
    print(f"Output tokens: {len(token_ids) - token_len}")
    print(f"KV cache length: {max_seq_len}")
    # Decode only the newly generated tokens
    output_text = tokenizer.decode(token_ids[token_len:], skip_special_tokens=True)
return output_text
if __name__ == "__main__":
args = parse_args()
seq_len = args.seq_len
model_path = args.output_dir
os.makedirs(model_path, exist_ok=True)
print(f"Initializing model...")
ax_model_dir = "ax_model"
total_inference_start = time.time()
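    # Load the pipeline stages: FSMN VAD, speaker-embedding model, and SenseVoice ASR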
model_vad = AX_Fsmn_vad(ax_model_dir)
speaker_model = AX_SpeakerEmbeddingInference(model_dir=ax_model_dir)
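    # Precompute sinusoidal positional encodings for the fixed ASR input length;
    # they are reused for every SenseVoice call below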
embed = SinusoidalPositionEncoder()
position_encoding = embed.get_position_encoding(torch.randn(1, seq_len, 560)).numpy()
model_bin = AX_SenseVoiceSmall(ax_model_dir, seq_len=seq_len)
# build tokenizer
print(f"Loading tokenizer...")
tokenizer = None
tokenizer_path = os.path.join(ax_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)
# Set up audio file for processing
    wav_file = args.wav_file
    print("Running inference on input file...")
    withitn = True  # apply inverse text normalization (ITN) to the ASR output
    norm_type = "withitn" if withitn else "woitn"
    print(f"\nProcessing with text normalization: {norm_type}")
language = "auto"
print(f"\n--- Processing language: {language} ---")
print(f"Processing file: {wav_file}")
inference_start = time.time()
    # Load the audio at its native sample rate
    speech, fs = librosa.load(wav_file, sr=None)
    # The pipeline expects 16 kHz input; resample if necessary
if fs != 16000:
print(f"Resampling audio from {fs}Hz to 16000Hz")
speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
fs = 16000
audio_duration = librosa.get_duration(y=speech, sr=fs)
speech_lengths = len(speech)
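    # Pipeline: VAD -> speaker embeddings -> clustering (diarization) -> per-segment ASR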
try:
        # VAD inference: detect voiced segments in the audio
vad_start_time = time.time()
res_vad = model_vad(speech)[0]
        vad_segments = merge_vad(res_vad, 15 * 1000)  # merge short segments, e.g. [[0, 6480], [6480, 23670], [23670, 38210], [38210, 49910], [49910, 59820], [59820, 70550]]
vad_time_cost = time.time() - vad_start_time
print(f"VAD processing time: {vad_time_cost:.2f} seconds")
        # Speaker embedding extraction: convert VAD times to seconds and split each
        # voiced region into smaller chunks for the speaker model
        vad_time = [[vad_t[0] / 1000, vad_t[1] / 1000] for vad_t in res_vad]
        chunks = [c for (st, ed) in vad_time for c in chunk(st, ed)]
# Extract speaker embeddings for each chunk
print("Extracting speaker embeddings...")
speaker_start_time = time.time()
embeddings = speaker_model(speech, fs, chunks=chunks)
speaker_time_cost = time.time() - speaker_start_time
print(f"Speaker embedding extraction time: {speaker_time_cost:.2f} seconds")
print(f"Generated embeddings shape: {embeddings.shape}")
clustering_start_time = time.time()
speaker_num, diar_results = do_clustering(chunks, embeddings, speaker_num=None)
clustering_time_cost = time.time() - clustering_start_time
print(f"Speaker clustering time: {clustering_time_cost:.2f} seconds")
# print(f"VAD segments detected: {len(vad_segments)}")
# 存储所有分片结果
all_results = []
all_metadata = {}
# 遍历每个VAD片段并处理
asr_start_time = time.time()
for i, segment in enumerate(vad_segments):
segment_start, segment_end = segment
            # Extract this segment from the full audio
start_sample = int(segment_start / 1000 * fs)
end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
segment_speech = speech[start_sample:end_sample]
            # Time offset of this segment relative to the full audio (ms -> s)
            time_offset_sec = segment_start / 1000.0
            # Write the segment to a temporary wav file for the ASR model
            segment_filename = f"temp_segment_{i}.wav"
            sf.write(segment_filename, segment_speech, fs)
            # Run ASR on the current segment
try:
segment_res, segment_meta = model_bin(
segment_filename,
language,
withitn,
position_encoding,
tokenizer=tokenizer,
output_timestamp=True,
ban_emo_unk=False,
output_dir=model_path,
key=[f"{os.path.basename(wav_file)}_segment_{i}"]
)
if "merged_words" in segment_meta:
if "merged_words" not in all_metadata:
all_metadata["merged_words"] = []
all_metadata["merged_words"].extend(segment_meta["merged_words"])
if "merged_timestamps" in segment_meta:
if "merged_timestamps" not in all_metadata:
all_metadata["merged_timestamps"] = []
                    # Shift timestamps by the segment offset, clamping to the total duration
                    adjusted_timestamps = [
                        [min(ts[0] + time_offset_sec, audio_duration), min(ts[1] + time_offset_sec, audio_duration)]
                        for ts in segment_meta["merged_timestamps"]
                    ]
                    all_metadata["merged_timestamps"].extend(adjusted_timestamps)
            finally:
                # Clean up the temporary segment file whether or not ASR succeeded
                if os.path.exists(segment_filename):
                    os.remove(segment_filename)
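        # Flatten the accumulated metadata for sentence splitting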
output_asr = {
"merged_words": all_metadata.get("merged_words", []),
"merged_timestamps": all_metadata.get("merged_timestamps", [])
}
    except Exception:
        logger.exception(f"Pipeline failed while processing {wav_file}")
        raise
asr_time_cost = time.time() - asr_start_time
print(f"ASR processing time: {asr_time_cost:.2f} seconds")
asr_timestamps = get_trans_sentence_sensevoice(output_asr)
sentence_info_with_spk = distribute_spk(asr_timestamps, diar_results)
inference_time_cost = time.time() - inference_start
inference_time_cost_all = time.time() - total_inference_start
rtf = inference_time_cost / audio_duration
print(f"Inference time for {wav_file}: {inference_time_cost:.2f} seconds")
print(f"load model + Inference time for {wav_file}: {inference_time_cost_all:.2f} seconds")
print(f"Audio duration: {audio_duration:.2f} seconds")
print(f"RTF: {rtf:.2f}")
# Save Results
    output_trans_path = os.path.join(args.output_dir, f"{os.path.basename(wav_file)}.txt")
    try:
        with open(output_trans_path, 'w', encoding='utf-8') as f:
            for text_string, timeinterval, spk in sentence_info_with_spk:
                f.write(f'Speaker_{spk}: [{timeinterval[0]:.3f} {timeinterval[1]:.3f}] {text_string}\n')
    except Exception as e:
        logger.error(f"Failed to write transcript to {output_trans_path}: {e}")
########################### LLM Inference #############################
logger.info("Starting LLM Inference for meeting summary...")
    # Load the HF tokenizer and config for the LLM; note this rebinds `tokenizer`
    # (previously the SenseVoice BPE tokenizer) for the LLM stage
    tokenizer = AutoTokenizer.from_pretrained(llm_hf_tokenizer_path)
    cfg = AutoConfig.from_pretrained(llm_hf_tokenizer_path, trust_remote_code=True)
    # Use the config's EOS ids only when it lists more than one stop token
    eos_token_id = None
    if isinstance(cfg.eos_token_id, list) and len(cfg.eos_token_id) > 1:
        eos_token_id = cfg.eos_token_id
with open(output_trans_path, 'r', encoding='utf-8') as file:
prompt = file.read()
prompt_list = prompt.split('\n')
output_text_l = []
cur_prompt = ""
    for idx, line in enumerate(prompt_list):
        cur_prompt += line + "\n"
        if len(cur_prompt) < 5000 and idx != len(prompt_list) - 1:
            continue
output_text = run_model(cur_prompt)
output_text_l.append(output_text)
print("\n")
cur_prompt = ""
    # TODO: For very long inputs, consider a second LLM pass that summarizes the
    # per-chunk summaries, keeping the combined prompt under the model's input limit.
    # print("\n\nRunning final summary...")
    # prompt = "\n".join(output_text_l)
    # final_output_text = run_model(prompt)
    # print("Final summary:\n", final_output_text)
    # Save the chunk summaries as a Markdown file, preserving the LLM's raw output format
    output_summary_path = os.path.join(args.output_dir, f"{os.path.basename(wav_file)}_summary.md")
output_text_combined = "\n\n".join(output_text_l)
with open(output_summary_path, 'w', encoding='utf-8') as f:
f.write(output_text_combined)
logger.info(f"LLM Inference completed. Summary saved to {output_summary_path}")