|
|
#!/bin/bash |
|
|
|
|
|
# --- Process environment --------------------------------------------------
# Debug switch and log destination.
# NOTE(review): nothing in the visible part of this script reads DEBUG_MODE or
# LOG_PATH; presumably they are consumed by main.py — confirm.
export DEBUG_MODE=true \
  LOG_PATH="./debug_log_2b.txt"

# Expose only GPU index 0 to the launched process (single-device run).
export CUDA_VISIBLE_DEVICES=0

# NOTE(review): the launch command does not pass this port explicitly as far
# as is visible; verify that the launcher or main.py actually reads
# MAIN_PROCESS_PORT.
export MAIN_PROCESS_PORT=29507

# NCCL knobs: verbose INFO logging, InfiniBand transport off, peer-to-peer
# transport off.
# NOTE(review): NCCL_ASYNC_DISABLE is not a variable documented by NCCL as far
# as I know — confirm it is read by something, otherwise it has no effect.
export NCCL_DEBUG=INFO \
  NCCL_IB_DISABLE=1 \
  NCCL_P2P_DISABLE=1 \
  NCCL_ASYNC_DISABLE=1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Models ----------------------------------------------------------------
# Reasoner checkpoint id (looks like a Hugging Face hub id); forwarded as
# model.reasoner_model_name.
REASONER_MODEL='Qwen/Qwen2.5-1.5B-Instruct'
# Weaver checkpoint id; forwarded as model.weaver.weaver_model_name.
# Same base model as the reasoner in this run.
WEAVER_MODEL='Qwen/Qwen2.5-1.5B-Instruct'
# The literal string "null", forwarded as model.trigger.trigger_model_name.
# NOTE(review): presumably the config loader turns it into None (no trigger
# model) — confirm.
TRIGGER_MODEL='null'

# --- Data ------------------------------------------------------------------
# Picks the config file configs/latent_memory/<DATASET_NAME>.yaml.
DATASET_NAME='gsm8k'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Latent-memory augmentation ---------------------------------------------
# Forwarded as model.max_prompt_aug_num / model.max_inference_aug_num.
MAX_PROMPT_AUG_NUM=1
MAX_INFERENCE_AUG_NUM=5
# Forwarded as model.weaver.prompt_latents_len /
# model.weaver.inference_latents_len.
PROMPT_LATENTS_LEN=8
INFERENCE_LATENTS_LEN=8

# --- Generation ---------------------------------------------------------------
# Forwarded as run.generation.temperature. This must always be defined: the
# launch command expands ${TEMPERATURE} unquoted inside the --options
# key/value list. An unset or empty value expands to zero words, so every
# following key/value pair gets misaligned (e.g. the key
# "run.generation.max_response_length" would be taken as the temperature).
# The default can be overridden from the environment. 1.0 is chosen as a
# neutral value because the run pins do_sample False.
# NOTE(review): presumably ignored under greedy decoding — confirm in main.py.
TEMPERATURE="${TEMPERATURE:-1.0}"

# --- Checkpoint ---------------------------------------------------------------
# Placeholder path: replace <output_dir> with the real training output
# directory before running. It is forwarded as model.load_model_path.
LOAD_MODEL_PATH="<output_dir>/model.safetensors"
|
|
|
|
|
|
|
|
python -m accelerate.commands.launch \ |
|
|
--config_file=configs/zero2.yaml \ |
|
|
main.py \ |
|
|
--cfg-path configs/latent_memory/${DATASET_NAME}.yaml \ |
|
|
--options \ |
|
|
model.reasoner_model_name ${REASONER_MODEL} \ |
|
|
model.weaver.weaver_model_name ${WEAVER_MODEL} \ |
|
|
model.trigger.trigger_model_name ${TRIGGER_MODEL} \ |
|
|
model.weaver.prompt_latents_len ${PROMPT_LATENTS_LEN} \ |
|
|
model.weaver.inference_latents_len ${INFERENCE_LATENTS_LEN} \ |
|
|
model.max_prompt_aug_num ${MAX_PROMPT_AUG_NUM} \ |
|
|
model.max_inference_aug_num ${MAX_INFERENCE_AUG_NUM} \ |
|
|
model.load_model_path ${LOAD_MODEL_PATH} \ |
|
|
run.mode evaluate \ |
|
|
run.generation.eval_batch_size 4 \ |
|
|
run.generation.do_sample False \ |
|
|
run.generation.temperature ${TEMPERATURE} \ |
|
|
run.generation.max_response_length 1024 \ |