{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/foxconnhy/miniconda3/envs/llamafactory/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "/home/jeff/.cache/huggingface/modules/transformers_modules/gemma_v1/speech_conformer_encoder.py:2798: FutureWarning: Please specify CheckpointImpl.NO_REENTRANT as CheckpointImpl.REENTRANT will soon be removed as the default and eventually deprecated.\n", " lambda i: encoder_checkpoint_wrapper(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "######################## speech lora #############\n", "######################## text lora #############\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00, 1.80it/s]\n", "Some weights of Gemma3OmniForConditionalGeneration were not initialized from the model checkpoint at /mnt/data-2t/jeff/codes/llm/cpp/gemma_v1 and are newly initialized: ['language_model.model.base_model.model.layers.0.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.v_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.1.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.q_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.12.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.q_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.15.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.o_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.18.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.o_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.20.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.k_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.23.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.k_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.26.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.mlp.up_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.29.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.mlp.up_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.31.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.mlp.gate_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.4.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.mlp.gate_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.7.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.v_proj.lora_B.text.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. 
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [
"Sequential(\n",
"  (0): Linear(in_features=1024, out_features=2560, bias=True)\n",
"  (1): GELU(approximate='none')\n",
"  (2): Linear(in_features=2560, out_features=2560, bias=True)\n",
")" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# Only the last expression is displayed: the projector is a 1024 -> 2560 MLP.\n",
"model.audio_tower\n",
"model.audio_projector" ] },
{ "cell_type": "code", "execution_count": 179, "metadata": {}, "outputs": [], "source": [
"from ASRDataset import *\n",
"pickup_dataset = MultiturnAudioDataset(split='train', processor=processor, json_path='/mnt/data-2t/jeff/codes/llm/cpp/sample_data/pickup_processed.json')" ] },
{ "cell_type": "code", "execution_count": 180, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"torch.Size([1, 256, 80])\n",
"torch.Size([1, 217, 80])\n",
"torch.Size([1, 77, 80])\n",
"torch.Size([1, 580, 80])\n" ] } ],
"source": [
"# Each item carries a [1, T, 80] log-mel feature tensor; T varies per utterance.\n",
"for i in range(len(pickup_dataset)):\n",
"    inp = pickup_dataset[i]\n",
"    print(inp['input_audio_embeds'].shape)" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 100, 2560])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# The tower downsamples time (580 frames -> 100 steps); the projector maps to 2560 dims.\n",
"inp = pickup_dataset[3]\n",
"fea, mask = model.audio_tower(inp['input_audio_embeds'], torch.ones(inp['input_audio_embeds'].shape[:2]))\n",
"model.audio_projector(fea).shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import copy\n",
"import torch\n",
"import torch.nn as nn\n",
"from speech_conformer_encoder import ConformerEncoder\n",
"\n",
"class Gemma3AudioEncoder(nn.Module):\n",
"    \"\"\"Standalone audio encoder (Conformer tower + MLP projector) for ONNX export.\"\"\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        audio_config = model.config.audio_config.to_diff_dict()\n",
"        for item in ['transformers_version', 'model_type', 'torch_dtype']:\n",
"            if item in audio_config:\n",
"                audio_config.pop(item)\n",
"        self.audio_tower = ConformerEncoder(**audio_config)\n",
"        self.audio_projector = nn.Sequential(\n",
"            nn.Linear(in_features=1024, out_features=2560, bias=True),\n",
"            nn.GELU(approximate='none'),\n",
"            nn.Linear(in_features=2560, out_features=2560, bias=True))\n",
"\n",
"    def forward(self, x, mask):\n",
"        x, _ = self.audio_tower(x, mask)\n",
"        x = self.audio_projector(x)\n",
"        return x\n",
"\n",
"audio_encoder = Gemma3AudioEncoder()\n",
"# Copy weights over from the loaded model (the deepcopy keeps any\n",
"# non-parameter state the embedding module may hold).\n",
"audio_encoder.audio_tower.encoder_embedding = copy.deepcopy(model.audio_tower.encoder_embedding)\n",
"audio_encoder.audio_projector.load_state_dict(model.audio_projector.state_dict())\n",
"audio_encoder.audio_tower.load_state_dict(model.audio_tower.state_dict())" ] },
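{ "cell_type": "markdown", "metadata": {}, "source": [ "Before exporting, it is worth confirming that the standalone wrapper reproduces the original `audio_tower` + `audio_projector`. A minimal sanity check (a sketch; it reuses `pickup_dataset` from above):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# The max absolute difference should be ~0 if the state dicts were copied correctly.\n",
"inp = pickup_dataset[3]\n",
"x = inp['input_audio_embeds']\n",
"mask = torch.ones(x.shape[:2])\n",
"with torch.no_grad():\n",
"    ref = model.audio_projector(model.audio_tower(x, mask)[0])\n",
"    out = audio_encoder(x, mask)\n",
"(ref - out).abs().max()" ] },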
"audio_encoder.audio_tower.load_state_dict(model.audio_tower.state_dict())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import onnx\n", "import onnxruntime as ort\n", "import onnxscript\n", "import os\n", "import requests\n", "import shutil\n", "import soundfile\n", "import subprocess\n", "import sys\n", "import torch\n", "\n", "from onnx import helper, numpy_helper, TensorProto\n", "from onnxruntime_genai.models.builder import create_model\n", "from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper\n", "from onnxscript import ir\n", "from torch.export import Dim, export\n", "def build_speech(outputdir='./onnx_files'):\n", " # TorchScript export\n", " dummy_inputs = (\n", " torch.randn((1,97,80)),\n", " torch.ones((1,97))\n", " #inputs[\"input_audio_embeds\"], # audio_embeds: torch.FloatTensor\n", " #inputs[\"audio_attention_mask\"], # audio_attention_mask: torch.BoolTensor\n", " # inputs[\"audio_embed_sizes\"], # audio_sizes: torch.LongTensor\n", " # inputs[\"input_mode\"], # audio_projection_mode: int\n", " )\n", " filename = \"phi-4-mm-speech.onnx\"\n", "\n", " temp_folder_1 = os.path.join(outputdir, \"speech_init_export\")\n", " os.makedirs(temp_folder_1, exist_ok=True)\n", "\n", " fpath_1 = os.path.join(temp_folder_1, filename)\n", " torch._dynamo.config.capture_scalar_outputs = True\n", " onnx_program = torch.onnx.export(audio_encoder, dummy_inputs, fpath_1,\n", " input_names=[\"audio_embeds\", \"audio_attention_mask\"], \n", " output_names=[\"audio_features\"],\n", " opset_version=20,\n", " dynamic_axes={\n", " \"audio_embeds\": {0:'B',1: \"L\"},\n", " \"audio_attention_mask\": {0:'B',1: \"L\"},\n", " },\n", " )\n", "\n", "build_speech()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "import onnxruntime as ort\n", "import numpy as np\n", "ort_sess = ort.InferenceSession(\"/mnt/data-2t/jeff/codes/llm/cpp/onnx_files/speech_init_export/phi-4-mm-speech.onnx\")" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2, 111, 2560)" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import warnings\n", "warnings.filterwarnings('ignore')\n", "from tqdm import tqdm\n", "import torch\n", "import numpy as np\n", "a=[]\n", "# for i in tqdm(range(10000)):\n", "# try:\n", "ort_sess.run(None, {\"audio_embeds\": np.array(torch.randn(1,97,80),dtype=np.float32),\n", " # \"audio_attention_mask\":np.ones((1,97),dtype=np.float32)\n", " }\n", " )\n", " # print(i)\n", " # a.append(i)\n", " # except:\n", " # pass\n", "ort_sess.run(None, {\"audio_embeds\": np.array(torch.randn(2,888,80),dtype=np.float32),\n", " # \"audio_attention_mask\":np.ones((2,97),dtype=np.float32)\n", " }\n", " )[0].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Python inference time check" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "total = 0\n", "_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000//2-80-230).T\n", "for i in range(100):\n", " now = time.time()\n", " inp = np.random.randn(np.random.randint(16240, 48240)).reshape(1,-1)#np.array(torch.randn(1,np.random.randint(100,300),80),dtype=np.float32)\n", " inp = _extract_features(inp,16000).reshape(1,-1,80)\n", " now = time.time()\n", " # inp = np.array(torch.randn(1,150,80),dtype=np.float32)\n", " ort_sess.run(None, {\"audio_embeds\": inp,\n", 
" # \"audio_attention_mask\":np.ones((1,97),dtype=np.float32)\n", " })\n", " total += time.time()-now\n", " \n", "total,total/100" ] }, { "cell_type": "code", "execution_count": 238, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "24240" ] }, "execution_count": 238, "metadata": {}, "output_type": "execute_result" } ], "source": [ "149*160+400" ] }, { "cell_type": "code", "execution_count": 233, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 40917)" ] }, "execution_count": 233, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.random.randn(np.random.randint(16240, 48240)).reshape(1,-1).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(10.608245611190796, 0.10608245611190796)" ] }, "execution_count": 218, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import time\n", "total = 0\n", "for i in range(100):\n", " tmp = torch.randn(1,np.random.randint(100,300),80)\n", " mask = torch.ones(tmp.shape[:2])\n", " now = time.time()\n", " audio_encoder(tmp,mask)\n", " total += time.time()-now\n", "total,total/100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# C++ ERROR check" ] }, { "cell_type": "code", "execution_count": 167, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 0.3246, 0.0295, 0.1076, ..., -0.1125, -0.0894, -0.3800],\n", " [ 0.3267, -0.2442, 0.2653, ..., 0.7783, -0.6049, -1.0858],\n", " [ 0.1797, 0.0438, 0.9673, ..., 0.5126, -0.5657, -0.7050],\n", " ...,\n", " [ 0.0261, -0.0324, 0.0230, ..., -0.1303, 0.0343, 0.1486],\n", " [ 0.1655, -0.3327, 0.4232, ..., 0.0513, 0.4222, -0.3645],\n", " [ 0.1147, -0.1201, 0.4198, ..., 0.6170, 0.0838, -0.1409]]],\n", " grad_fn=)" ] }, "execution_count": 167, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inp = pickup_dataset.__getitem__(3)\n", "fea,mask = model.audio_tower(inp['input_audio_embeds'],torch.ones(inp['input_audio_embeds'].shape[:2]))\n", "model.audio_projector(fea)" ] }, { "cell_type": "code", "execution_count": 201, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[[ 0.13004433, -0.06643961, 0.01333247, ..., -0.05643693,\n", " -0.23922557, 0.569423 ],\n", " [-0.75552 , -0.05047493, -0.82725084, ..., 0.32261163,\n", " -0.14968234, -0.7078437 ],\n", " [-0.6673857 , 0.33906737, -0.6191502 , ..., 0.04259709,\n", " -0.01194861, 0.27635992],\n", " ...,\n", " [ 0.02916821, -0.03163592, 0.02736526, ..., -0.12979224,\n", " 0.03317374, 0.15346158],\n", " [-0.8559882 , -0.5196625 , 0.2549707 , ..., 0.28192428,\n", " 1.4099622 , -0.15940394],\n", " [-0.20253824, -0.30478072, -0.6786582 , ..., 0.08860758,\n", " -0.12145798, 0.525889 ]]], dtype=float32)" ] }, "execution_count": 201, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inp = pickup_dataset.__getitem__(0)\n", "res = ort_sess.run(None, {\"audio_embeds\": np.array(inp['input_audio_embeds'],dtype=np.float32),\n", " # \"audio_attention_mask\":np.ones((2,97),dtype=np.float32)\n", " }\n", " )[0]\n", "res" ] }, { "cell_type": "code", "execution_count": 208, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[[ 0.130969 , -0.0697925, 0.0150866, ..., -0.0559536,\n", " -0.239062 , 0.567436 ],\n", " [-0.753288 , -0.0582227, -0.825365 , ..., 0.320587 ,\n", " -0.153626 , -0.709664 ],\n", " [-0.656874 , 0.342632 , -0.607641 , ..., 0.0383743,\n", " -0.0218912, 0.269968 ],\n", " ...,\n", " [ 0.0291714, -0.0316175, 0.027369 , ..., -0.129825 ,\n", " 0.033166 , 0.153453 ],\n", " 
{ "cell_type": "code", "execution_count": 208, "metadata": {}, "outputs": [ { "data": { "text/plain": [
"array([[[ 0.130969 , -0.0697925,  0.0150866, ..., -0.0559536,\n",
"         -0.239062 ,  0.567436 ],\n",
"        [-0.753288 , -0.0582227, -0.825365 , ...,  0.320587 ,\n",
"         -0.153626 , -0.709664 ],\n",
"        [-0.656874 ,  0.342632 , -0.607641 , ...,  0.0383743,\n",
"         -0.0218912,  0.269968 ],\n",
"        ...,\n",
"        [ 0.0291714, -0.0316175,  0.027369 , ..., -0.129825 ,\n",
"          0.033166 ,  0.153453 ],\n",
"        [-0.854555 , -0.530883 ,  0.258313 , ...,  0.279057 ,\n",
"          1.40658  , -0.159066 ],\n",
"        [-0.197598 , -0.306157 , -0.67907  , ...,  0.0915015,\n",
"         -0.124402 ,  0.52159  ]]])" ] }, "execution_count": 208, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# Encoder features computed by the C++ pipeline, dumped as one comma-separated line.\n",
"with open('/mnt/data-2t/jeff/codes/llm/cpp/inference/f0.txt') as f:\n",
"    content = f.readlines()\n",
"audio_fea_cpp = np.array([float(i) for i in content[0].split(',')]).reshape(1, -1, 2560)\n",
"audio_fea_cpp" ] },
{ "cell_type": "code", "execution_count": 202, "metadata": {}, "outputs": [ { "data": { "text/plain": [
"(array([[[0.917797, 1.33496 , 1.9894  , ..., 6.60723 , 6.95787 ,\n",
"          7.20139 ],\n",
"         [0.      , 0.      , 0.      , ..., 5.99914 , 6.11214 ,\n",
"          6.40908 ],\n",
"         [0.      , 0.      , 0.      , ..., 5.1184  , 5.36291 ,\n",
"          5.14623 ],\n",
"         ...,\n",
"         [0.      , 0.      , 0.      , ..., 6.25256 , 6.29312 ,\n",
"          7.05511 ],\n",
"         [0.      , 0.      , 0.      , ..., 6.49829 , 6.7198  ,\n",
"          7.08144 ],\n",
"         [0.      , 0.      , 1.08376 , ..., 5.43068 , 5.97577 ,\n",
"          6.35748 ]]]),\n",
" tensor([[[0.8826, 1.3054, 1.9652,  ..., 6.6069, 6.9578, 7.2011],\n",
"         [0.0000, 0.0000, 0.0000,  ..., 5.9991, 6.1121, 6.4091],\n",
"         [0.0000, 0.0000, 0.0000,  ..., 5.1147, 5.3624, 5.1428],\n",
"         ...,\n",
"         [0.0000, 0.0000, 0.0000,  ..., 6.2526, 6.2931, 7.0548],\n",
"         [0.0000, 0.0000, 0.0000,  ..., 6.4981, 6.7198, 7.0807],\n",
"         [0.0000, 0.0000, 1.1479,  ..., 5.4311, 5.9743, 6.3568]]]))" ] }, "execution_count": 202, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# C++ log-mel fbank vs. the Python dataset features for sample 0.\n",
"with open('/mnt/data-2t/jeff/codes/llm/cpp/inference/matrix_output.txt') as f:\n",
"    txtlines = f.readlines()\n",
"inp_emb_cpp = np.array([float(i) for l in txtlines for i in l.split(',')]).reshape(1, -1, 80)\n",
"inp_emb_cpp, pickup_dataset[0]['input_audio_embeds']" ] },
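{ "cell_type": "markdown", "metadata": {}, "source": [ "Eyeballing the printouts above suggests the C++ pipeline is close but not bit-identical. A quick way to quantify both gaps (a sketch; it assumes the C++ dumps were produced from the same sample 0 so the shapes line up):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Max absolute error of the C++ encoder features vs ONNX, and of the C++\n",
"# fbank frontend vs the Python preprocessor.\n",
"print('encoder features:', np.abs(audio_fea_cpp - res).max())\n",
"print('fbank frontend: ', np.abs(inp_emb_cpp - pickup_dataset[0]['input_audio_embeds'].numpy()).max())" ] },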
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Python preprocessor" ] },
{ "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(353, 80)" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# TODO for the C++ port:\n",
"# 1. take the model path and the input PCM from command-line args\n",
"# 2. add the model input preprocessor following the Python code below;\n",
"#    the wav input of _extract_features is a raw audio array\n",
"# 3. the ONNX model input is [batch, frames, feature_size] = [-1, -1, 80]\n",
"\n",
"def _extract_spectrogram(wav, fs):\n",
"    \"\"\"Extract magnitude spectrogram features from waveform.\n",
"    Args:\n",
"      wav (1D array): waveform of the input\n",
"      fs (int): sampling rate of the waveform, 16000.\n",
"    Output:\n",
"      spec (2D array): a TxF matrix of magnitude spectra,\n",
"        F = n_fft/2 + 1 = 257, and T is the number of frames.\n",
"    \"\"\"\n",
"    if wav.ndim > 1:\n",
"        wav = np.squeeze(wav)\n",
"\n",
"    # by default, extract the mean if stereo\n",
"    if len(wav.shape) == 2:\n",
"        wav = wav.mean(1)\n",
"\n",
"    preemphasis = 0.97\n",
"    n_fft = 512\n",
"    win_length = 400\n",
"    hop_length = 160\n",
"    fft_window = np.hamming(400)\n",
"\n",
"    # Spec 1: SpeechLib cuts the trailing samples that do not fill a full hop\n",
"    n_batch = (wav.shape[0] - win_length) // hop_length + 1\n",
"    # We don't use stride_tricks here since the input array may not satisfy the\n",
"    # memory layout requirement and we need writeable output; a list of views\n",
"    # copied once to the destination is more efficient than broadcasting.\n",
"    y_frames = np.array(\n",
"        [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],\n",
"        dtype=np.float32,\n",
"    )\n",
"\n",
"    # Spec 2: SpeechLib applies preemphasis within each frame\n",
"    y_frames_prev = np.roll(y_frames, 1, axis=1)\n",
"    y_frames_prev[:, 0] = y_frames_prev[:, 1]\n",
"    y_frames = (y_frames - preemphasis * y_frames_prev) * 32768\n",
"\n",
"    S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)\n",
"    spec = np.abs(S).astype(np.float32)\n",
"    return spec\n",
"\n",
"def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):\n",
"    \"\"\"Create a Mel filter-bank the same as SpeechLib FbankFC.\n",
"\n",
"    Args:\n",
"      sample_rate (int): Sample rate in Hz. number > 0 [scalar]\n",
"      n_fft (int): FFT size. int > 0 [scalar]\n",
"      n_mels (int): Mel filter size. int > 0 [scalar]\n",
"      fmin (float): lowest frequency (in Hz). If None use 0.0. float >= 0 [scalar]\n",
"      fmax (float): highest frequency (in Hz). If None use sample_rate / 2. float >= 0 [scalar]\n",
"\n",
"    Returns:\n",
"      out (numpy.ndarray): Mel transform matrix [shape=(n_mels, 1 + n_fft/2)]\n",
"    \"\"\"\n",
"\n",
"    bank_width = int(n_fft // 2 + 1)\n",
"    if fmax is None:\n",
"        fmax = sample_rate / 2\n",
"    if fmin is None:\n",
"        fmin = 0\n",
"    assert fmin >= 0, \"fmin cannot be negative\"\n",
"    assert fmin < fmax <= sample_rate / 2, \"fmax must be between (fmin, samplerate / 2]\"\n",
"\n",
"    def mel(f):\n",
"        return 1127.0 * np.log(1.0 + f / 700.0)\n",
"\n",
"    def bin2mel(fft_bin):\n",
"        return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))\n",
"\n",
"    def f2bin(f):\n",
"        return int((f * n_fft / sample_rate) + 0.5)\n",
"\n",
"    # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]\n",
"    klo = f2bin(fmin) + 1\n",
"    khi = f2bin(fmax)\n",
"\n",
"    khi = max(khi, klo)\n",
"\n",
"    # Spec 2: SpeechLib uses triangles in Mel space\n",
"    mlo = mel(fmin)\n",
"    mhi = mel(fmax)\n",
"    m_centers = np.linspace(mlo, mhi, n_mels + 2)\n",
"    ms = (mhi - mlo) / (n_mels + 1)\n",
"\n",
"    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)\n",
"    for m in range(0, n_mels):\n",
"        left = m_centers[m]\n",
"        center = m_centers[m + 1]\n",
"        right = m_centers[m + 2]\n",
"        for fft_bin in range(klo, khi):\n",
"            mbin = bin2mel(fft_bin)\n",
"            if left < mbin < right:\n",
"                matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms\n",
"\n",
"    return matrix\n",
"\n",
"def _extract_features(wav, fs):\n",
"    \"\"\"Extract log filterbank features from waveform.\n",
"    Args:\n",
"      wav (1D array): waveform of the input\n",
"      fs (int): sampling rate of the waveform; 16000 expected.\n",
"    Output:\n",
"      log_fbank (2D array): a TxD matrix of log Mel filterbank features.\n",
"        D=80, and T is the number of frames.\n",
"    \"\"\"\n",
"    spec = _extract_spectrogram(wav, fs)\n",
"    spec_power = spec**2\n",
"\n",
"    fbank_power = np.clip(spec_power.dot(_mel), 1.0, None)\n",
"    log_fbank = np.log(fbank_power).astype(np.float32)\n",
"\n",
"    return log_fbank\n",
"\n",
"## example\n",
"## 'arr' is a waveform array loaded elsewhere in the session;\n",
"## for an input of shape [1, 56832], the output shape is (353, 80)\n",
"_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000//2-80-230).T\n",
"output = _extract_features(arr, 16000)\n",
"output.shape" ] },
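{ "cell_type": "markdown", "metadata": {}, "source": [ "For reference, an end-to-end sketch of the full pipeline: read a 16 kHz wav, extract the log-mel fbank, and run the exported encoder. The wav path is a placeholder:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import soundfile\n",
"\n",
"# 'sample.wav' is a hypothetical path; any 16 kHz mono file works.\n",
"wav, fs = soundfile.read('sample.wav')\n",
"fbank = _extract_features(wav, fs).reshape(1, -1, 80)\n",
"feats = ort_sess.run(None, {'audio_embeds': fbank})[0]\n",
"fbank.shape, feats.shape" ] },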
{ "cell_type": "code", "execution_count": 227, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(256, 80)" ] }, "execution_count": 227, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# 'arr' is a waveform loaded elsewhere; 41239 samples -> 256 frames (checked below).\n",
"_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000//2-80-230).T\n",
"output = _extract_features(arr, 16000)\n",
"output.shape" ] },
{ "cell_type": "code", "execution_count": 228, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "256" ] }, "execution_count": 228, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# frame count for a 41239-sample waveform (win 400, hop 160); the timing loop\n",
"# above targets 100~300 frames\n",
"(41239-400)//160 + 1" ] },
{ "cell_type": "code", "execution_count": 229, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(16240, 48240)" ] }, "execution_count": 229, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# sample counts that give exactly 100 and 300 frames\n",
"99*160+400, 299*160+400" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ],
"metadata": { "kernelspec": { "display_name": "llamafactory", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version":
"3.10.16" } }, "nbformat": 4, "nbformat_minor": 2 }