#!/bin/bash
# run_blip2.sh
# Script to launch BLIP2 + GRPO training.
#
# NOTE(review): the original file began with web-upload residue
# ("yuccaaa's picture" / "Add files using upload-large-folder tool" /
# "acbfbc3 verified") and an `echo` placed BEFORE the shebang. The shebang
# must be the very first line of the file for the interpreter directive to
# take effect, and the residue lines are not valid shell commands.

# Fail fast: abort on command failure, unset variables, or a failed
# pipeline stage — otherwise the "completed" messages below would print
# even when the training run crashed.
set -euo pipefail

echo "Starting GRPO training..."

# ===== Basic path configuration =====
readonly DATA_FILE=/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/DeepLocBinary/test.csv
readonly DATASET_NAME=deeplocbinary
readonly OUTPUT_DIR=./output
readonly CACHE_DIR=./cache

# ===== Model configuration =====
readonly BERT_PATH=/nas/shared/kilab/wangyujia/ProtT3/plm_model/microsoft
readonly PLM_MODEL=/nas/shared/kilab/wangyujia/ProtT3/plm_model/esm2-150m
readonly LLM_MODEL=/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged
readonly SFT_CHECKPOINT=/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage2_07301646_2datasets_construct/epoch=09.ckpt/converted.ckpt

# ===== Training parameters =====
readonly BATCH_SIZE=4
readonly EPOCHS=3
readonly LR=1e-5

# ===== Reward function weights =====
# NOTE(review): these three sum to 1.0; presumably combined via
# --reward_funcs combined below — confirm in blips_reason.py.
readonly FORMAT_WEIGHT=0.2
readonly ACCURACY_WEIGHT=0.6
readonly REPETITION_WEIGHT=0.2

# ===== Run training =====
# All expansions are quoted (SC2086): several paths live on NAS/OSS mounts
# and quoting keeps the invocation safe if any path ever contains spaces.
python blips_reason.py \
  --data_file_paths "${DATA_FILE}" \
  --dataset_name "${DATASET_NAME}" \
  --reward_funcs combined \
  --format_weight "${FORMAT_WEIGHT}" \
  --accuracy_weight "${ACCURACY_WEIGHT}" \
  --repetition_weight "${REPETITION_WEIGHT}" \
  --use_custom_prompts \
  --template_name classification \
  --max_seq_length 1000 \
  --output_dir "${OUTPUT_DIR}" \
  --per_device_train_batch_size "${BATCH_SIZE}" \
  --num_train_epochs "${EPOCHS}" \
  --learning_rate "${LR}" \
  --bert_name "${BERT_PATH}" \
  --plm_model "${PLM_MODEL}" \
  --llm_name "${LLM_MODEL}" \
  --sft_checkpoint "${SFT_CHECKPOINT}" \
  --plm_tune freeze \
  --llm_tune lora \
  --qformer_tune train \
  --lora_r 8 \
  --lora_alpha 16 \
  --lora_dropout 0.1 \
  --enable_flash \
  --cache_dir "${CACHE_DIR}"

# --- Alternative invocation (kept for reference): protein_reason.py ---
# python protein_reason.py \
#     --output_dir "./grpo_outputs" \
#     --model_name_or_path "Qwen/Qwen3-0.6B" \
#     --protein_model_name_or_path "facebook/esm2_t6_8M_UR50D" \
#     --qformer_model_name_or_path "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" \
#     --dataset_name "wanglab/protein_function" \
#     --sft_checkpoint "./checkpoints/best_model" \
#     --per_device_train_batch_size 4 \
#     --gradient_accumulation_steps 4 \
#     --num_train_epochs 3 \
#     --learning_rate 1e-6 \
#     --beta 0.04 \
#     --temperature 0.6 \
#     --top_p 0.95 \
#     --top_k 20 \
#     --max_completion_length 800 \
#     --num_generations 8 \
#     --reward_funcs "xmlcount" "soft_format" "strict_format" "correctness" \
#     --lora_r 32 \
#     --lora_alpha 64 \
#     --lora_dropout 0.05 \
#     --freeze_protein_modules \
#     --logging_steps 2 \
#     --eval_strategy "steps" \
#     --eval_steps 100 \
#     --save_steps 200 \
#     --report_to "wandb" \
#     --log_completions

# --- Alternative invocation (kept for reference): blip2_reason.py ---
# python blip2_reason.py \
#     --data_file_paths /oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/DeepLocBinary/test.csv \
#     --reward_funcs combined \
#     --format_weight 0.2 \
#     --accuracy_weight 0.6 \
#     --repetition_weight 0.2 \
#     --use_custom_prompts \
#     --template_name classification \
#     --max_seq_length 1000 \
#     --output_dir ./output \
#     --per_device_train_batch_size 4 \
#     --num_train_epochs 3 \
#     --learning_rate 1e-5

echo "GRPO training completed!"
echo "All training stages completed successfully!"