#!/usr/bin/env bash
#
# Three-stage protein-LLM training pipeline:
#   1. Contrastive pre-training   (train_contrastive.py)
#   2. Supervised fine-tuning     (train_protein_qwen.py)
#   3. GRPO reinforcement training (protein_reason.py)
#
# Fail fast: abort the whole pipeline if any command exits non-zero,
# if an unset variable is referenced, or if any pipeline stage fails.
set -euo pipefail
| | echo "Starting contrastive pre-training..." |
| |
|
| | python train_contrastive.py \ |
| | --text_model_name "Qwen/Qwen3-1.7B" \ |
| | --protein_model_name "facebook/esm2_t6_8M_UR50D" \ |
| | --qformer_model_name "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" \ |
| | --dataset_name "wanglab/protein_descriptions" \ |
| | --output_dir "./contrastive_outputs" \ |
| | --num_epochs 10 \ |
| | --batch_size 32 \ |
| | --learning_rate 1e-4 \ |
| | --temperature 0.07 \ |
| | --freeze_protein_model \ |
| | --freeze_text_model \ |
| | --max_length_protein 1024 \ |
| | --max_length_text 512 \ |
| | --eval_dataset \ |
| | --use_wandb \ |
| | --wandb_project "protein-llm-contrastive" \ |
| | --logging_steps 100 \ |
| | --eval_steps 500 \ |
| | --save_steps 1000 |
| |
|
| | echo "Contrastive pre-training completed!" |
| |
|
| | |
| | |
| | |
| | |
| | echo "Starting supervised fine-tuning..." |
| |
|
| | python train_protein_qwen.py \ |
| | --model_type "protein-llm" \ |
| | --text_model_name "Qwen/Qwen3-1.7B" \ |
| | --protein_model_name "facebook/esm2_t6_8M_UR50D" \ |
| | --qformer_model_name "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" \ |
| | --dataset_type "protein_function" \ |
| | --protein_function_data_dir_huggingface "wanglab/protein_function" \ |
| | --text_model_finetune True \ |
| | --protein_model_finetune False \ |
| | --num_query_tokens 32 \ |
| | --seed 23 \ |
| | --batch_size 4 \ |
| | --max_epochs 5 \ |
| | --learning_rate 5e-5 \ |
| | --weight_decay 0.01 \ |
| | --gradient_accumulation_steps 8 \ |
| | --max_length_protein 1024 \ |
| | --max_length_text 1024 \ |
| | --lora_rank 32 \ |
| | --lora_alpha 64 \ |
| | --lora_dropout 0.05 \ |
| | --num_gpus 1 \ |
| | --strategy "ddp" \ |
| | --wandb_project "esm2-qwen3-1.7b-finetune" \ |
| | --checkpoint_dir "./checkpoints" \ |
| | --log_dir "./logs" \ |
| | --cache_dir "/model-weights" |
| |
|
| | echo "Supervised fine-tuning completed!" |
| |
|
| | |
| | |
| | |
| | |
| | echo "Starting GRPO training..." |
| |
|
| | python protein_reason.py \ |
| | --output_dir "./grpo_outputs" \ |
| | --model_name_or_path "Qwen/Qwen3-0.6B" \ |
| | --protein_model_name_or_path "facebook/esm2_t6_8M_UR50D" \ |
| | --qformer_model_name_or_path "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" \ |
| | --dataset_name "wanglab/protein_function" \ |
| | --sft_checkpoint "./checkpoints/best_model" \ |
| | --per_device_train_batch_size 4 \ |
| | --gradient_accumulation_steps 4 \ |
| | --num_train_epochs 3 \ |
| | --learning_rate 1e-6 \ |
| | --beta 0.04 \ |
| | --temperature 0.6 \ |
| | --top_p 0.95 \ |
| | --top_k 20 \ |
| | --max_completion_length 800 \ |
| | --num_generations 8 \ |
| | --reward_funcs "xmlcount" "soft_format" "strict_format" "correctness" \ |
| | --lora_r 32 \ |
| | --lora_alpha 64 \ |
| | --lora_dropout 0.05 \ |
| | --freeze_protein_modules \ |
| | --logging_steps 2 \ |
| | --eval_strategy "steps" \ |
| | --eval_steps 100 \ |
| | --save_steps 200 \ |
| | --report_to "wandb" \ |
| | --log_completions |
| |
|
| | echo "GRPO training completed!" |
| |
|
| | echo "All training stages completed successfully!" |