#!/bin/bash
#
# Pretraining launch script (ASR stage) for an Omni-style speech LLM:
# Whisper-medium speech encoder + Qwen2.5-3B-Instruct backbone, trained
# with DeepSpeed ZeRO-2. Only the speech projector is tuned
# (--tune_speech_projector True); encoder and LLM stay frozen.
#
# NOTE(review): the shebang must be the first bytes of the file — the
# original had blank lines before it, so direct execution fell back to
# the caller's default shell instead of bash.

set -euo pipefail

# Backbone LLM checkpoint.
readonly MODEL_PATH="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Qwen2.5-3B-Instruct"

# Frozen speech encoder (Whisper-medium; hidden size 1024, 80 mel bins below).
readonly SPEECH_ENCODER="/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium"

# Conversation/prompt template matching the Qwen backbone.
readonly PROMPT_VERSION="qwen"

# Training data: 5M-sample English ASR jsonl.
readonly DATA_PATH="/data1/speech/anhnmt2/dataset/s2s/english/asr/dataset/train_asr_eng_5M.jsonl"

# Dev/eval data: LibriSpeech + SPGISpeech ASR jsonl.
readonly DEV_PATH="/data1/speech/anhnmt2/dataset/s2s/english/asr/dataset/dev_asr_libri_spgi.jsonl"

# Cache directory for preprocessed features.
readonly CACHE_DIR="../output/cached_asr_full"

# Noise list for audio augmentation.
# NOTE(review): declared but not passed to the trainer below — confirm
# whether a --augment_path flag was intended, or remove.
readonly AUGMENT_PATH="/data1/speech/anhnmt2/dataset/s2s/augment/noise_list_non_speech.txt"

# Effective batch size = 16 (per device) x 2 (grad accum) x num_gpus.
deepspeed ../omni_speech/train/train_mem.py \
  --deepspeed zero2.json \
  --model_name_or_path "$MODEL_PATH" \
  --version "$PROMPT_VERSION" \
  --data_path "$DATA_PATH" \
  --dev_path "$DEV_PATH" \
  --cache_dir "$CACHE_DIR" \
  --speech_encoder "$SPEECH_ENCODER" \
  --mel_size 80 \
  --speech_encoder_hidden_size 1024 \
  --speech_encoder_type whisper \
  --bf16 True \
  --output_dir ../checkpoints/omni_whisper-medium_Qwen2.5-3B_pretrained-asr-5M \
  --num_train_epochs 4 \
  --tune_speech_projector True \
  --per_device_train_batch_size 16 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 2 \
  --evaluation_strategy "steps" \
  --save_strategy "steps" \
  --eval_steps 2000 \
  --save_steps 2000 \
  --save_total_limit 1 \
  --learning_rate 1e-3 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --tf32 True \
  --model_max_length 4096 \
  --gradient_checkpointing True \
  --dataloader_num_workers 8
|
|
|
|
|
|
|
|
|