#!/bin/bash

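# Debug flag and log file for this run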
export DEBUG_MODE=true
export LOG_PATH="./debug_log_2b.txt"

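# GPUs to use and the main-process port for the accelerate launcher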
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export MAIN_PROCESS_PORT=29507

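# NCCL / torch.distributed logging and transport settings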
export NCCL_DEBUG=WARN
export NCCL_IB_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_ASYNC_DISABLE=1
export TORCH_DISTRIBUTED_DEBUG=OFF

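# Base models: a 1.5B reasoner and a 7B weaver; no trigger model is configured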
REASONER_MODEL="Qwen/Qwen2.5-1.5B-Instruct"
WEAVER_MODEL="Qwen/Qwen2.5-7B-Instruct"
TRIGGER_MODEL=null

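# Dataset and training method (GRPO on GSM8K)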
DATASET_NAME="gsm8k"
DATASET_MODE="grpo"
TRAIN_METHOD="grpo"

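# Latent augmentation counts and latent sequence lengths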
MAX_PROMPT_AUG_NUM=1
MAX_INFERENCE_AUG_NUM=5
PROMPT_LATENTS_LEN=8
INFERENCE_LATENTS_LEN=8

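# No weaver checkpoint is loaded (load path set to null)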
LOAD_WEAVER_PATH=null

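# One launcher process per visible GPU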
NUM_PROCS=8

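# Launch multi-GPU training via accelerate with the ZeRO-2 config;
# key/value pairs after --options override entries in the YAML config at --cfg-path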
uv run python -m accelerate.commands.launch \
    --num_processes=${NUM_PROCS} \
    --main_process_port=${MAIN_PROCESS_PORT} \
    --config_file=configs/zero2.yaml \
    main.py \
    --cfg-path configs/latent_memory/${DATASET_NAME}.yaml \
    --options \
    model.reasoner_model_name ${REASONER_MODEL} \
    model.weaver.weaver_model_name ${WEAVER_MODEL} \
    model.trigger.trigger_model_name ${TRIGGER_MODEL} \
    model.weaver.prompt_latents_len ${PROMPT_LATENTS_LEN} \
    model.weaver.inference_latents_len ${INFERENCE_LATENTS_LEN} \
    model.max_prompt_aug_num ${MAX_PROMPT_AUG_NUM} \
    model.max_inference_aug_num ${MAX_INFERENCE_AUG_NUM} \
    model.load_model_path ${LOAD_WEAVER_PATH} \
    datasets.${DATASET_NAME}.mode ${DATASET_MODE} \
    run.mode train \
    run.train_weaver True \
    run.train_trigger False \
    run.train_weaver_method ${TRAIN_METHOD} \
    run.generation.do_sample True \
    run.generation.temperature 1.0 \
    run.generation.max_response_length 512 \