#!/bin/bash
# run_blip2.sh
# Launcher for BLIP2 + GRPO training.
#
# Fills in data/model paths and training hyperparameters, then invokes
# blips_reason.py with the combined reward function. Two earlier launch
# variants (protein_reason.py / blip2_reason.py) are kept below, commented
# out, for reference.

# Fail fast: abort on any command error, unset variable, or pipeline failure,
# so the final "completed" messages only print on a genuinely successful run.
set -euo pipefail

echo "Starting GRPO training..."

# ===== Basic path configuration =====
DATA_FILE=/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/DeepLocBinary/test.csv
DATASET_NAME=deeplocbinary
OUTPUT_DIR=./output
CACHE_DIR=./cache

# ===== Model configuration =====
BERT_PATH=/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft
PLM_MODEL=/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m
LLM_MODEL=/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged
SFT_CHECKPOINT=/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage2_07301646_2datasets_construct/epoch=09.ckpt/converted.ckpt

# ===== Training hyperparameters =====
BATCH_SIZE=4
EPOCHS=3
LR=1e-5

# ===== Reward-function weights =====
FORMAT_WEIGHT=0.2
ACCURACY_WEIGHT=0.6
REPETITION_WEIGHT=0.2

# ===== Run training =====
# All expansions are quoted so paths containing spaces survive word-splitting.
python blips_reason.py \
  --data_file_paths "${DATA_FILE}" \
  --dataset_name "${DATASET_NAME}" \
  --reward_funcs combined \
  --format_weight "${FORMAT_WEIGHT}" \
  --accuracy_weight "${ACCURACY_WEIGHT}" \
  --repetition_weight "${REPETITION_WEIGHT}" \
  --use_custom_prompts \
  --template_name classification \
  --max_seq_length 1000 \
  --output_dir "${OUTPUT_DIR}" \
  --per_device_train_batch_size "${BATCH_SIZE}" \
  --num_train_epochs "${EPOCHS}" \
  --learning_rate "${LR}" \
  --bert_name "${BERT_PATH}" \
  --plm_model "${PLM_MODEL}" \
  --llm_name "${LLM_MODEL}" \
  --sft_checkpoint "${SFT_CHECKPOINT}" \
  --plm_tune freeze \
  --llm_tune lora \
  --qformer_tune train \
  --lora_r 8 \
  --lora_alpha 16 \
  --lora_dropout 0.1 \
  --enable_flash \
  --cache_dir "${CACHE_DIR}"

# --- Alternative launcher (kept for reference): protein_reason.py variant ---
# python protein_reason.py \
#     --output_dir "./grpo_outputs" \
#     --model_name_or_path "Qwen/Qwen3-0.6B" \
#     --protein_model_name_or_path "facebook/esm2_t6_8M_UR50D" \
#     --qformer_model_name_or_path "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" \
#     --dataset_name "wanglab/protein_function" \
#     --sft_checkpoint "./checkpoints/best_model" \
#     --per_device_train_batch_size 4 \
#     --gradient_accumulation_steps 4 \
#     --num_train_epochs 3 \
#     --learning_rate 1e-6 \
#     --beta 0.04 \
#     --temperature 0.6 \
#     --top_p 0.95 \
#     --top_k 20 \
#     --max_completion_length 800 \
#     --num_generations 8 \
#     --reward_funcs "xmlcount" "soft_format" "strict_format" "correctness" \
#     --lora_r 32 \
#     --lora_alpha 64 \
#     --lora_dropout 0.05 \
#     --freeze_protein_modules \
#     --logging_steps 2 \
#     --eval_strategy "steps" \
#     --eval_steps 100 \
#     --save_steps 200 \
#     --report_to "wandb" \
#     --log_completions

# --- Alternative launcher (kept for reference): blip2_reason.py variant ---
# python blip2_reason.py \
#     --data_file_paths /oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/DeepLocBinary/test.csv \
#     --reward_funcs combined \
#     --format_weight 0.2 \
#     --accuracy_weight 0.6 \
#     --repetition_weight 0.2 \
#     --use_custom_prompts \
#     --template_name classification \
#     --max_seq_length 1000 \
#     --output_dir ./output \
#     --per_device_train_batch_size 4 \
#     --num_train_epochs 3 \
#     --learning_rate 1e-5

echo "GRPO training completed!"
echo "All training stages completed successfully!"