streaming-speech / scripts /test_llama.sh
NMCxyz's picture
Add files using upload-large-folder tool
9942354 verified
#!/bin/bash
# Launch a DeepSpeed run of ../omni_speech/train/train.py using the base model,
# speech encoder and JSONL train/dev files configured below.
#
# NOTE: zero2.json and every ../ path are relative, so they resolve against the
# current working directory rather than this file's location. Presumably the
# script is meant to be started from the directory that contains it
# (TODO confirm).

# Fail fast on command errors, unset variables and failing pipeline stages.
set -euo pipefail

# Site-specific locations of the base LLM and the speech encoder checkpoint.
readonly MODEL_PATH=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/llm/Llama-3.1-8B-Instruct
readonly SPEECH_ENCODER=/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/models/speech_encoder/whisper-medium
# Value forwarded to train.py as --version.
readonly PROMPT_VERSION=llama_3
# Training and dev manifests (JSONL).
readonly DATA_PATH=/data1/speech/anhnmt2/dataset/s2s/new/train_asr_eng_50000.jsonl
readonly DEV_PATH=/data1/speech/anhnmt2/dataset/s2s/new/dev_asr_eng_5000.jsonl
# Directory forwarded to train.py as --cache_dir.
readonly CACHE_DIR="../output/cached_asr"

# The arguments live in an array (instead of one backslash-continued line) so
# that each value reaches train.py as exactly one word, even if a path is later
# changed to contain spaces or glob characters, and so that a stray trailing
# space after a continuation backslash can no longer truncate the command.
train_args=(
  # DeepSpeed config, model and data locations.
  --deepspeed zero2.json
  --model_name_or_path "$MODEL_PATH"
  --version "$PROMPT_VERSION"
  --data_path "$DATA_PATH"
  --dev_path "$DEV_PATH"
  --cache_dir "$CACHE_DIR"

  # Speech encoder settings.
  --speech_encoder "$SPEECH_ENCODER"
  --mel_size 80
  --speech_encoder_hidden_size 1024
  --speech_encoder_type whisper

  # Output location, schedule and batch sizes.
  --bf16 True
  --output_dir ../checkpoints/llama-omni-pretrained-asr-test
  --num_train_epochs 10
  --tune_speech_projector True
  --per_device_train_batch_size 4
  --per_device_eval_batch_size 2
  --gradient_accumulation_steps 4

  # Evaluate and checkpoint every 2000 steps; keep at most one checkpoint.
  --evaluation_strategy "steps"
  --save_strategy "steps"
  --eval_steps 2000
  --save_steps 2000
  --save_total_limit 1

  # Optimizer settings.
  --learning_rate 1e-3
  --optim adamw_torch
  --weight_decay 0.
  --warmup_ratio 0.03

  # Logging, precision and runtime settings.
  --logging_steps 1
  --tf32 True
  --model_max_length 2048
  --gradient_checkpointing True
  --dataloader_num_workers 8
)

# The script's exit status is that of the deepspeed launcher.
deepspeed ../omni_speech/train/train.py "${train_args[@]}"