"""Qwen3-ASR-1.7B-RKLLM / run_qwen3_asr_e2e.py

Run Qwen3-ASR end to end: Whisper-style mel features -> RKNN audio encoder ->
RKLLM decoder.
"""
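# Example invocation (paths are illustrative):
#   python run_qwen3_asr_e2e.py --audio-path sample.wav --model-path . \
#       --encoder-model-path rknn/audio_encoder.rknn \
#       --llm-model-path rknn/language_model.rkllm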
import argparse
import ctypes
import math
import os
import sys
import time
from pathlib import Path
import faulthandler
import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
from transformers import WhisperFeatureExtractor
faulthandler.enable()
os.environ.setdefault("RKLLM_LOG_LEVEL", "1")
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from rkllm_binding import ( # noqa: E402
LLMCallState,
RKLLMInferMode,
RKLLMInferParam,
RKLLMInput,
RKLLMInputType,
RKLLMResult,
RKLLMRuntime,
)
import ztu_somemodelruntime_ez_rknn_async as ort
DEFAULT_ENCODER_PATH = "rknn/audio_encoder.rknn"
DEFAULT_LLM_PATH = "rknn/language_model.rkllm"
def now() -> float:
return time.perf_counter()
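# Callback object handed to RKLLM: accumulates streamed text chunks and
# optionally echoes them to stdout as they arrive.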
class StreamingTextCollector:
def __init__(self, stream_output: bool = True):
self.stream_output = stream_output
self.parts: list[str] = []
self.error = False
def __call__(self, result_ptr, userdata_ptr, state_enum):
state = LLMCallState(state_enum)
result: RKLLMResult = result_ptr.contents
if state == LLMCallState.RKLLM_RUN_NORMAL and result.text:
chunk = result.text.decode("utf-8", errors="ignore")
self.parts.append(chunk)
if self.stream_output:
print(chunk, end="", flush=True)
elif state == LLMCallState.RKLLM_RUN_FINISH and self.stream_output:
print("(finish)", flush=True)
elif state == LLMCallState.RKLLM_RUN_ERROR:
self.error = True
if self.stream_output:
print("\nrun error", flush=True)
return 0
@property
def text(self) -> str:
return "".join(self.parts)
def load_waveform(audio_path: str, target_sr: int = 16000) -> np.ndarray:
audio, sr = sf.read(audio_path, dtype="float32", always_2d=False)
audio = np.asarray(audio, dtype=np.float32)
if audio.ndim == 2:
audio = audio.mean(axis=-1)
if sr != target_sr:
divisor = math.gcd(int(sr), int(target_sr))
up = int(target_sr // divisor)
down = int(sr // divisor)
audio = resample_poly(audio, up=up, down=down).astype(np.float32)
return audio
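# WhisperFeatureExtractor defaults to a 30 s window; grow it so longer clips
# are featurized in full instead of being truncated.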
def configure_feature_extractor_for_audio(feature_extractor: WhisperFeatureExtractor, waveform: np.ndarray) -> None:
required_seconds = max(1, math.ceil(waveform.shape[0] / float(feature_extractor.sampling_rate)))
if required_seconds <= feature_extractor.chunk_length:
return
feature_extractor.chunk_length = required_seconds
feature_extractor.n_samples = int(required_seconds * feature_extractor.sampling_rate)
feature_extractor.nb_max_frames = feature_extractor.n_samples // feature_extractor.hop_length
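# Compute log-mel features plus the number of valid (unpadded) mel frames.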
def extract_mel_features(model_path: str, audio_path: str) -> tuple[np.ndarray, int]:
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
waveform = load_waveform(audio_path)
configure_feature_extractor_for_audio(feature_extractor, waveform)
outputs = feature_extractor(
waveform,
sampling_rate=16000,
return_attention_mask=True,
return_tensors="np",
)
input_features = outputs["input_features"][0].astype(np.float32)
feature_len = int(outputs["attention_mask"][0].sum())
return input_features, feature_len
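# Slice the mel spectrogram into fixed-size chunks, zero-padding the last one
# to chunk_frames so every encoder input has the same static shape.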
def split_mel_features(input_features: np.ndarray, feature_len: int, chunk_frames: int) -> list[tuple[np.ndarray, int]]:
chunks = []
start = 0
while start < feature_len:
cur_len = min(chunk_frames, feature_len - start)
chunk = np.zeros((input_features.shape[0], chunk_frames), dtype=np.float32)
chunk[:, :cur_len] = input_features[:, start : start + cur_len]
chunks.append((chunk, cur_len))
start += cur_len
return chunks
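# Three successive ceil-halvings, i.e. ceil(length / 8): presumably mirrors
# three stride-2 temporal downsampling stages in the audio encoder, giving the
# number of valid output frames per chunk when the model does not report it.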
def get_chunk_output_length_value(length: int) -> int:
value = int(length)
value = (value + 1) // 2
value = (value + 1) // 2
value = (value + 1) // 2
return value
def parse_args():
parser = argparse.ArgumentParser(description="Run end-to-end Qwen3-ASR with RKNN audio encoder and RKLLM decoder.")
parser.add_argument("--model-path", type=str, default=".", help="Path to the original Qwen3-ASR model directory.")
parser.add_argument("--audio-path", type=str, required=True, help="Path to the input audio file.")
parser.add_argument(
"--encoder-model-path",
type=str,
default=DEFAULT_ENCODER_PATH,
help="Path to the audio encoder model (.rknn).",
)
parser.add_argument(
"--llm-model-path",
type=str,
default=DEFAULT_LLM_PATH,
help="Path to the exported .rkllm text model.",
)
parser.add_argument("--chunk-frames", type=int, default=100, help="Fixed mel chunk length.")
parser.add_argument("--max-new-tokens", type=int, default=1024, help="Maximum number of new tokens to generate.")
parser.add_argument("--max-context-len", type=int, default=4096, help="Maximum context length for RKLLM.")
parser.add_argument("--top-k", type=int, default=5, help="Top-k used by RKLLM decoding.")
parser.add_argument("--system-prompt", type=str, default="", help="Optional system prompt.")
parser.add_argument(
"--force-language",
type=str,
default=None,
help="Optional language suffix, for example 'Chinese'. Appends 'language X<asr_text>' after the assistant prompt.",
)
parser.add_argument("--save-audio-features", type=str, default=None, help="Optional path to save concatenated audio features.")
parser.add_argument("--save-text", type=str, default=None, help="Optional path to save the final decoded text.")
parser.add_argument("--no-stream", action="store_true", help="Disable streaming stdout from the RKLLM callback.")
return parser.parse_args()
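# Qwen-style chat markup. When a language is forced, the assistant turn is
# pre-seeded with "language X<asr_text>" so generation continues from there.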
def build_chat_template(system_prompt: str, force_language) -> tuple[str, str, str]:
assistant_prefix = ""
if force_language:
assistant_prefix = f"language {force_language}<asr_text>"
return (
f"<|im_start|>system\n{system_prompt or ''}<|im_end|>\n",
"<|im_start|>user\n",
f"<|im_end|>\n<|im_start|>assistant\n{assistant_prefix}",
)
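# Encode all mel chunks as one batched tensor. The runtime-specific
# "ztu_modelrt_dispatch_batch" run option presumably fans the batch out across
# encoder instances; valid output lengths come from the model's second output
# when present, else from the ceil(length / 8) fallback above.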
def run_audio_encoder(
session,
input_features: np.ndarray,
feature_len: int,
chunk_frames: int,
) -> np.ndarray:
chunks = split_mel_features(input_features, feature_len, chunk_frames)
if not chunks:
return np.zeros((0, 2048), dtype=np.float32)
batch_tensor = np.ascontiguousarray(
np.stack([chunk for chunk, _ in chunks], axis=0),
dtype=np.float32,
)
session_outputs = session.run(
None,
{
"input_features": batch_tensor,
#"feature_len": np.asarray([chunk_len for _, chunk_len in chunks], dtype=np.int32),
}, run_options={"ztu_modelrt_dispatch_batch": True}
)
audio_features = np.asarray(session_outputs[0], dtype=np.float32)
if len(session_outputs) >= 2:
valid_lens = np.asarray(session_outputs[1]).reshape(-1)
if valid_lens.size == 1 and len(chunks) > 1:
valid_lens = np.repeat(valid_lens, len(chunks))
else:
valid_lens = np.asarray(
[get_chunk_output_length_value(chunk_len) for _, chunk_len in chunks],
dtype=np.int32,
)
if audio_features.shape[0] != len(chunks):
raise RuntimeError(
f"Audio encoder batch mismatch: got {audio_features.shape[0]} outputs for {len(chunks)} inputs."
)
if valid_lens.size != len(chunks):
raise RuntimeError(
f"Audio encoder valid length mismatch: got {valid_lens.size} lengths for {len(chunks)} inputs."
)
outputs = [audio_features[idx, : int(valid_len)] for idx, valid_len in enumerate(valid_lens)]
return np.concatenate(outputs, axis=0)
def load_rkllm(
llm_model_path: str,
max_new_tokens: int,
max_context_len: int,
top_k: int,
system_prompt: str,
force_language,
stream_output: bool,
):
collector = StreamingTextCollector(stream_output=stream_output)
rk_llm = RKLLMRuntime()
param = rk_llm.create_default_param()
param.model_path = llm_model_path.encode("utf-8")
param.top_k = top_k
param.max_new_tokens = max_new_tokens
param.max_context_len = max_context_len
param.skip_special_token = True
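    # RKLLM has no native audio modality; the image-embedding hooks are reused
    # to mark the audio token span in the prompt.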
param.img_start = b"<|audio_start|>"
param.img_end = b"<|audio_end|>"
param.img_content = b"<|audio_pad|>"
    param.extend_param.base_domain_id = 1  # the default 4 GB memory domain is not enough
rk_llm.init(param, collector)
system_text, prompt_prefix, prompt_postfix = build_chat_template(
system_prompt=system_prompt,
force_language=force_language,
)
rk_llm.set_chat_template(
system_prompt=system_text,
prompt_prefix=prompt_prefix,
prompt_postfix=prompt_postfix,
)
return rk_llm, collector
def run_rkllm(
rk_llm: RKLLMRuntime,
audio_features: np.ndarray,
) -> None:
rkllm_input = RKLLMInput()
rkllm_input.role = b"user"
rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL
# RKLLM multimodal prompt must contain the literal "<image>" placeholder.
flattened = np.ascontiguousarray(audio_features.reshape(-1), dtype=np.float32)
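    # image_embed is a raw pointer into `flattened`; the array must stay alive
    # until run() returns, which holds here because it is a local variable.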
rkllm_input.multimodal_input.prompt = b"<image>"
rkllm_input.multimodal_input.image_embed = flattened.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
rkllm_input.multimodal_input.n_image_tokens = audio_features.shape[0]
rkllm_input.multimodal_input.n_image = 1
rkllm_input.multimodal_input.image_height = 1
rkllm_input.multimodal_input.image_width = max(audio_features.shape[0], 1)
infer_param = RKLLMInferParam()
infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
infer_param.keep_history = 0
rk_llm.run(rkllm_input, infer_param)
def main():
args = parse_args()
total_t0 = now()
encoder_session = None
rk_llm = None
collector = None
load_t0 = now()
mel_t0 = now()
input_features, feature_len = extract_mel_features(args.model_path, args.audio_path)
mel_elapsed = now() - mel_t0
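    # Runtime-specific option: "schedule" presumably distributes inference
    # across NPU cores 0-2 (e.g. the three cores on RK3588).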
encoder_session = ort.InferenceSession(args.encoder_model_path, provider_options=[{"schedule": [0,1,2]}])
rkllm_init_t0 = now()
rk_llm, collector = load_rkllm(
llm_model_path=args.llm_model_path,
max_new_tokens=args.max_new_tokens,
max_context_len=args.max_context_len,
top_k=args.top_k,
system_prompt=args.system_prompt,
force_language=args.force_language,
stream_output=not args.no_stream,
)
rkllm_init_elapsed = now() - rkllm_init_t0
load_elapsed = now() - load_t0
infer_t0 = now()
encoder_t0 = now()
audio_features = run_audio_encoder(
session=encoder_session,
input_features=input_features,
feature_len=feature_len,
chunk_frames=args.chunk_frames,
)
encoder_elapsed = now() - encoder_t0
print(f"input_feature_len: {feature_len}")
print(f"audio_features: {audio_features.shape}")
print(f"time_mel_sec: {mel_elapsed:.3f}")
print(f"time_rkllm_init_sec: {rkllm_init_elapsed:.3f}")
print(f"time_load_total_sec: {load_elapsed:.3f}")
print(f"time_audio_encoder_sec: {encoder_elapsed:.3f}")
if args.save_audio_features:
savepath = Path(args.save_audio_features)
savepath.parent.mkdir(parents=True, exist_ok=True)
np.save(savepath, audio_features)
print(f"saved_audio_features: {savepath}")
generate_t0 = now()
run_rkllm(rk_llm=rk_llm, audio_features=audio_features)
generate_elapsed = now() - generate_t0
infer_elapsed = now() - infer_t0
total_elapsed = now() - total_t0
if collector and collector.error:
raise RuntimeError("RKLLM generation failed.")
text = collector.text if collector else ""
print(f"time_generate_sec: {generate_elapsed:.3f}")
print(f"time_infer_total_sec: {infer_elapsed:.3f}")
print(f"time_total_sec: {total_elapsed:.3f}")
if args.save_text:
savepath = Path(args.save_text)
savepath.parent.mkdir(parents=True, exist_ok=True)
savepath.write_text(text, encoding="utf-8")
print(f"saved_text: {savepath}")
if args.no_stream:
print(text)
if __name__ == "__main__":
main()