import argparse
import ctypes
import math
import os
import sys
import time
from pathlib import Path

import faulthandler
import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
from transformers import WhisperFeatureExtractor

# Dump a traceback if the native runtimes crash the process, and set a default RKLLM log level.
faulthandler.enable()
os.environ.setdefault("RKLLM_LOG_LEVEL", "1")

# Make the repository root importable so the local rkllm_binding package resolves.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from rkllm_binding import (
    LLMCallState,
    RKLLMInferMode,
    RKLLMInferParam,
    RKLLMInput,
    RKLLMInputType,
    RKLLMResult,
    RKLLMRuntime,
)

# onnxruntime-style session wrapper around the RKNN runtime, used below as `ort`.
import ztu_somemodelruntime_ez_rknn_async as ort

DEFAULT_ENCODER_PATH = "rknn/audio_encoder.rknn"
DEFAULT_LLM_PATH = "rknn/language_model.rkllm"

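# Example invocation (the script filename here is illustrative; the flags match parse_args below):
#   python run_qwen3_asr_rknn.py --model-path ./Qwen3-ASR --audio-path sample.wav \
#       --force-language Chinese --save-text out.txt
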
def now() -> float:
    return time.perf_counter()


class StreamingTextCollector:
    """Accumulates streamed RKLLM output; optionally echoes chunks to stdout."""

    def __init__(self, stream_output: bool = True):
        self.stream_output = stream_output
        self.parts: list[str] = []
        self.error = False

    def __call__(self, result_ptr, userdata_ptr, state_enum):
        # Invoked by the RKLLM runtime with every partial result.
        state = LLMCallState(state_enum)
        result: RKLLMResult = result_ptr.contents

        if state == LLMCallState.RKLLM_RUN_NORMAL and result.text:
            chunk = result.text.decode("utf-8", errors="ignore")
            self.parts.append(chunk)
            if self.stream_output:
                print(chunk, end="", flush=True)
        elif state == LLMCallState.RKLLM_RUN_FINISH and self.stream_output:
            print("(finish)", flush=True)
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            self.error = True
            if self.stream_output:
                print("\nrun error", flush=True)
        return 0

    @property
    def text(self) -> str:
        return "".join(self.parts)


def load_waveform(audio_path: str, target_sr: int = 16000) -> np.ndarray:
    """Load audio as mono float32, resampling to target_sr when necessary."""
    audio, sr = sf.read(audio_path, dtype="float32", always_2d=False)
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 2:
        # Downmix multi-channel audio to mono.
        audio = audio.mean(axis=-1)
    if sr != target_sr:
        # Polyphase resampling with an exact rational rate ratio.
        divisor = math.gcd(int(sr), int(target_sr))
        up = int(target_sr // divisor)
        down = int(sr // divisor)
        audio = resample_poly(audio, up=up, down=down).astype(np.float32)
    return audio


def configure_feature_extractor_for_audio(feature_extractor: WhisperFeatureExtractor, waveform: np.ndarray) -> None:
    """Grow the extractor's chunk length so long inputs are not truncated to the 30 s default."""
    required_seconds = max(1, math.ceil(waveform.shape[0] / float(feature_extractor.sampling_rate)))
    if required_seconds <= feature_extractor.chunk_length:
        return
    feature_extractor.chunk_length = required_seconds
    feature_extractor.n_samples = int(required_seconds * feature_extractor.sampling_rate)
    feature_extractor.nb_max_frames = feature_extractor.n_samples // feature_extractor.hop_length


def extract_mel_features(model_path: str, audio_path: str) -> tuple[np.ndarray, int]:
    """Return the mel features [n_mels, frames] and the count of valid (non-padding) frames."""
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
    waveform = load_waveform(audio_path)
    configure_feature_extractor_for_audio(feature_extractor, waveform)
    outputs = feature_extractor(
        waveform,
        sampling_rate=16000,
        return_attention_mask=True,
        return_tensors="np",
    )
    input_features = outputs["input_features"][0].astype(np.float32)
    # The attention mask marks which mel frames correspond to real audio.
    feature_len = int(outputs["attention_mask"][0].sum())
    return input_features, feature_len


def split_mel_features(input_features: np.ndarray, feature_len: int, chunk_frames: int) -> list[tuple[np.ndarray, int]]:
    """Split the mel spectrogram into fixed-size chunks, zero-padding the final one."""
    chunks = []
    start = 0
    while start < feature_len:
        cur_len = min(chunk_frames, feature_len - start)
        # Every chunk is padded to chunk_frames so it matches the encoder's fixed input shape.
        chunk = np.zeros((input_features.shape[0], chunk_frames), dtype=np.float32)
        chunk[:, :cur_len] = input_features[:, start : start + cur_len]
        chunks.append((chunk, cur_len))
        start += cur_len
    return chunks


def get_chunk_output_length_value(length: int) -> int:
    """Ceil-halve a mel frame count three times to get the encoder output length.

    Matches three stride-2 downsampling stages, e.g. 100 -> 50 -> 25 -> 13.
    """
    value = int(length)
    value = (value + 1) // 2
    value = (value + 1) // 2
    value = (value + 1) // 2
    return value


def parse_args():
    parser = argparse.ArgumentParser(description="Run end-to-end Qwen3-ASR with RKNN audio encoder and RKLLM decoder.")
    parser.add_argument("--model-path", type=str, default=".", help="Path to the original Qwen3-ASR model directory.")
    parser.add_argument("--audio-path", type=str, required=True, help="Path to the input audio file.")
    parser.add_argument(
        "--encoder-model-path",
        type=str,
        default=DEFAULT_ENCODER_PATH,
        help="Path to the audio encoder model (.rknn).",
    )
    parser.add_argument(
        "--llm-model-path",
        type=str,
        default=DEFAULT_LLM_PATH,
        help="Path to the exported .rkllm text model.",
    )
    parser.add_argument("--chunk-frames", type=int, default=100, help="Fixed mel chunk length in frames.")
    parser.add_argument("--max-new-tokens", type=int, default=1024, help="Maximum number of new tokens to generate.")
    parser.add_argument("--max-context-len", type=int, default=4096, help="Maximum context length for RKLLM.")
    parser.add_argument("--top-k", type=int, default=5, help="Top-k used by RKLLM decoding.")
    parser.add_argument("--system-prompt", type=str, default="", help="Optional system prompt.")
    parser.add_argument(
        "--force-language",
        type=str,
        default=None,
        help="Optional language suffix, for example 'Chinese'. Appends 'language X<asr_text>' after the assistant prompt.",
    )
    parser.add_argument("--save-audio-features", type=str, default=None, help="Optional path to save concatenated audio features.")
    parser.add_argument("--save-text", type=str, default=None, help="Optional path to save the final decoded text.")
    parser.add_argument("--no-stream", action="store_true", help="Disable streaming stdout from the RKLLM callback.")
    return parser.parse_args()


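# With --force-language Chinese, the assistant turn below is pre-filled as
# "<|im_start|>assistant\nlanguage Chinese<asr_text>", forcing the transcription language.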
def build_chat_template(system_prompt: str, force_language) -> tuple[str, str, str]:
    """Return (system text, prompt prefix, prompt postfix) in Qwen chat-markup format."""
    assistant_prefix = ""
    if force_language:
        assistant_prefix = f"language {force_language}<asr_text>"
    return (
        f"<|im_start|>system\n{system_prompt or ''}<|im_end|>\n",
        "<|im_start|>user\n",
        f"<|im_end|>\n<|im_start|>assistant\n{assistant_prefix}",
    )


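# Shape sketch (illustrative; assumes chunk_frames=100 and a 128-bin mel front end):
#   mel [128, 437] -> 5 padded chunks [5, 128, 100] -> encoder output [5, 13, 2048]
#   -> trim each chunk to its valid length -> concat to [sum(valid_lens), 2048]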
def run_audio_encoder(
    session,
    input_features: np.ndarray,
    feature_len: int,
    chunk_frames: int,
) -> np.ndarray:
    """Encode all mel chunks in one batched RKNN call; return only the valid frames."""
    chunks = split_mel_features(input_features, feature_len, chunk_frames)
    if not chunks:
        return np.zeros((0, 2048), dtype=np.float32)

    batch_tensor = np.ascontiguousarray(
        np.stack([chunk for chunk, _ in chunks], axis=0),
        dtype=np.float32,
    )
    session_outputs = session.run(
        None,
        {"input_features": batch_tensor},
        run_options={"ztu_modelrt_dispatch_batch": True},
    )
    audio_features = np.asarray(session_outputs[0], dtype=np.float32)
    if len(session_outputs) >= 2:
        # Prefer valid lengths reported by the model itself when available.
        valid_lens = np.asarray(session_outputs[1]).reshape(-1)
        if valid_lens.size == 1 and len(chunks) > 1:
            valid_lens = np.repeat(valid_lens, len(chunks))
    else:
        # Otherwise derive them from the encoder's known downsampling.
        valid_lens = np.asarray(
            [get_chunk_output_length_value(chunk_len) for _, chunk_len in chunks],
            dtype=np.int32,
        )

    if audio_features.shape[0] != len(chunks):
        raise RuntimeError(
            f"Audio encoder batch mismatch: got {audio_features.shape[0]} outputs for {len(chunks)} inputs."
        )
    if valid_lens.size != len(chunks):
        raise RuntimeError(
            f"Audio encoder valid length mismatch: got {valid_lens.size} lengths for {len(chunks)} inputs."
        )

    # Drop each chunk's padded tail, then concatenate along the time axis.
    outputs = [audio_features[idx, : int(valid_len)] for idx, valid_len in enumerate(valid_lens)]
    return np.concatenate(outputs, axis=0)


def load_rkllm(
    llm_model_path: str,
    max_new_tokens: int,
    max_context_len: int,
    top_k: int,
    system_prompt: str,
    force_language,
    stream_output: bool,
):
    """Initialize the RKLLM runtime with decoding parameters and the chat template."""
    collector = StreamingTextCollector(stream_output=stream_output)
    rk_llm = RKLLMRuntime()
    param = rk_llm.create_default_param()

    param.model_path = llm_model_path.encode("utf-8")
    param.top_k = top_k
    param.max_new_tokens = max_new_tokens
    param.max_context_len = max_context_len
    param.skip_special_token = True
    # Reuse RKLLM's image markers for the audio token span.
    param.img_start = b"<|audio_start|>"
    param.img_end = b"<|audio_end|>"
    param.img_content = b"<|audio_pad|>"
    param.extend_param.base_domain_id = 1

    rk_llm.init(param, collector)

    system_text, prompt_prefix, prompt_postfix = build_chat_template(
        system_prompt=system_prompt,
        force_language=force_language,
    )
    rk_llm.set_chat_template(
        system_prompt=system_text,
        prompt_prefix=prompt_prefix,
        prompt_postfix=prompt_postfix,
    )
    return rk_llm, collector


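# RKLLM's multimodal input is image-oriented; the audio embeddings are handed over
# through its image-embedding fields (hence the <image> marker and image_* names below).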
def run_rkllm(
    rk_llm: RKLLMRuntime,
    audio_features: np.ndarray,
) -> None:
    """Run generation with the audio embeddings attached as multimodal input."""
    rkllm_input = RKLLMInput()
    rkllm_input.role = b"user"
    rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL

    # RKLLM reads the embeddings through a raw pointer, so the flat float32 buffer
    # must stay contiguous and alive until run() returns.
    flattened = np.ascontiguousarray(audio_features.reshape(-1), dtype=np.float32)
    rkllm_input.multimodal_input.prompt = b"<image>"
    rkllm_input.multimodal_input.image_embed = flattened.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    rkllm_input.multimodal_input.n_image_tokens = audio_features.shape[0]
    rkllm_input.multimodal_input.n_image = 1
    rkllm_input.multimodal_input.image_height = 1
    rkllm_input.multimodal_input.image_width = max(audio_features.shape[0], 1)

    infer_param = RKLLMInferParam()
    infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    infer_param.keep_history = 0

    rk_llm.run(rkllm_input, infer_param)


def main():
    args = parse_args()
    total_t0 = now()
    encoder_session = None
    rk_llm = None
    collector = None

    load_t0 = now()
    mel_t0 = now()
    input_features, feature_len = extract_mel_features(args.model_path, args.audio_path)
    mel_elapsed = now() - mel_t0

    # "schedule" is an option specific to the ztu runtime wrapper (presumably NPU core assignment).
    encoder_session = ort.InferenceSession(args.encoder_model_path, provider_options=[{"schedule": [0, 1, 2]}])
    rkllm_init_t0 = now()
    rk_llm, collector = load_rkllm(
        llm_model_path=args.llm_model_path,
        max_new_tokens=args.max_new_tokens,
        max_context_len=args.max_context_len,
        top_k=args.top_k,
        system_prompt=args.system_prompt,
        force_language=args.force_language,
        stream_output=not args.no_stream,
    )
    rkllm_init_elapsed = now() - rkllm_init_t0
    load_elapsed = now() - load_t0

    infer_t0 = now()
    encoder_t0 = now()
    audio_features = run_audio_encoder(
        session=encoder_session,
        input_features=input_features,
        feature_len=feature_len,
        chunk_frames=args.chunk_frames,
    )
    encoder_elapsed = now() - encoder_t0

    print(f"input_feature_len: {feature_len}")
    print(f"audio_features: {audio_features.shape}")
    print(f"time_mel_sec: {mel_elapsed:.3f}")
    print(f"time_rkllm_init_sec: {rkllm_init_elapsed:.3f}")
    print(f"time_load_total_sec: {load_elapsed:.3f}")
    print(f"time_audio_encoder_sec: {encoder_elapsed:.3f}")

    if args.save_audio_features:
        savepath = Path(args.save_audio_features)
        savepath.parent.mkdir(parents=True, exist_ok=True)
        np.save(savepath, audio_features)
        print(f"saved_audio_features: {savepath}")

    generate_t0 = now()
    run_rkllm(rk_llm=rk_llm, audio_features=audio_features)
    generate_elapsed = now() - generate_t0
    infer_elapsed = now() - infer_t0
    total_elapsed = now() - total_t0

    if collector and collector.error:
        raise RuntimeError("RKLLM generation failed.")
    text = collector.text if collector else ""

    print(f"time_generate_sec: {generate_elapsed:.3f}")
    print(f"time_infer_total_sec: {infer_elapsed:.3f}")
    print(f"time_total_sec: {total_elapsed:.3f}")

    if args.save_text:
        savepath = Path(args.save_text)
        savepath.parent.mkdir(parents=True, exist_ok=True)
        savepath.write_text(text, encoding="utf-8")
        print(f"saved_text: {savepath}")

    if args.no_stream:
        print(text)


if __name__ == "__main__":
    main()