# streaming-speech/scripts/finetune_llm_speech_decoder.sh
# (uploaded by NMCxyz via upload-large-folder tool; commit 9942354, verified)
#!/bin/bash
# Launch SFT fine-tuning of the Qwen2.5-3B LLM + speech decoder on top of a
# pretrained Whisper-medium speech projector, via the DeepSpeed launcher.
# NOTE: it currently supports batch size = 1 only.
set -euo pipefail

# --- Paths / config (edit for your environment) ------------------------------
MODEL_PATH="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct"
SPEECH_ENCODER="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium"
SPEECH_ADAPTER="/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr/speech_projector.bin"
PROMPT_VERSION="qwen"
DATA_PATH="/data1/speech/anhnmt2/dataset/s2s/english/qna/train_20250106_fc_mixed_tgt_units.jsonl"
DEV_PATH="/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250106_fc_mixed_tgt_units.jsonl"
CACHE_DIR="../output/cached_sft_speech_decoder_all_20250103"

# All expansions quoted (SC2086): empty or space-containing values would
# otherwise silently drop or split arguments on the command line.
deepspeed ../omni_speech/train/train_mem.py \
  --deepspeed zero2.json \
  --model_name_or_path "$MODEL_PATH" \
  --version "$PROMPT_VERSION" \
  --data_path "$DATA_PATH" \
  --dev_path "$DEV_PATH" \
  --cache_dir "$CACHE_DIR" \
  --speech_encoder "$SPEECH_ENCODER" \
  --mel_size 80 \
  --speech_encoder_hidden_size 1024 \
  --speech_encoder_type whisper \
  --pretrain_speech_projector "$SPEECH_ADAPTER" \
  --bf16 True \
  --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc_speech_decoder_fixed_all \
  --num_train_epochs 3 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 4 \
  --evaluation_strategy "steps" \
  --save_strategy "steps" \
  --eval_steps 2000 \
  --save_steps 2000 \
  --save_total_limit 1 \
  --learning_rate 1e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --tf32 True \
  --model_max_length 1024 \
  --gradient_checkpointing True \
  --dataloader_num_workers 8 \
  --has_tgt_units True \
  --ctc_loss_weight 2.0
# --- Alternate (disabled) config: MOSS 100K phase-3 data, LR 1e-4, CTC weight 10 ---
# MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/half-streaming-speech-nlp/checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc
# SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
# PROMPT_VERSION=qwen
# DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/moss/moss_100K_phase3_tgt_units_processed.jsonl
# # DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/english/qna/dev_20250106_fc_mixed_tgt_units.jsonl
# CACHE_DIR="../output/cached_sft_speech_decoder_all_20250103"
# deepspeed ../omni_speech/train/train_mem.py \
# --deepspeed zero2.json \
# --model_name_or_path $MODEL_PATH \
# --version $PROMPT_VERSION \
# --data_path $DATA_PATH \
# --cache_dir $CACHE_DIR \
# --speech_encoder $SPEECH_ENCODER \
# --mel_size 80 \
# --speech_encoder_hidden_size 1024 \
# --speech_encoder_type whisper \
# --bf16 True \
# --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-sft-fc_speech_decoder_all \
# --num_train_epochs 5 \
# --per_device_train_batch_size 1 \
# --per_device_eval_batch_size 1 \
# --gradient_accumulation_steps 4 \
# --evaluation_strategy "no" \
# --save_strategy "steps" \
# --save_steps 10000 \
# --save_total_limit 1 \
# --learning_rate 1e-4 \
# --weight_decay 0. \
# --warmup_ratio 0.03 \
# --logging_steps 1 \
# --tf32 True \
# --model_max_length 2048 \
# --gradient_checkpointing True \
# --dataloader_num_workers 8 \
# --has_tgt_units True \
# --ctc_loss_weight 10.0