#!/usr/bin/env bash
# E-GRPO / source_code / hope / finetune_tempflow_multi.sh
# Provenance (Hugging Face page residue, kept as comments so the file parses as shell):
#   uploader: zhangsj0722 — "Upload folder using huggingface_hub" — commit 58a7e24 (verified)
# Example cluster spec (for reference):
# cluster_spec='{"am":["psx2s7cxrbvmlcvk-am-0...svc.cluster.local"],"index":"0","role":"worker","worker":["...worker-0...:3400","...worker-1...:3400"]}'
# echo "cluster spec is $cluster_spec"
# Positional arguments: work dir, python interpreter, training script,
# node count, and GPUs per node.
WORK_DIR=$1
PYTHON_BIN=$2
SCRIPT=$3
NNODES=$4
NPROC_PER_NODE=$5
# Echo each argument back for the job log.
for arg_name in WORK_DIR PYTHON_BIN SCRIPT NNODES NPROC_PER_NODE; do
  printf '%s is %s\n' "$arg_name" "${!arg_name}"
done
# Rendezvous port; overridable via the PORT environment variable.
PORT=${PORT:-29509}
# Make the repo root importable by the python -c helpers below.
# Fix: the original line ended in a continuation backslash, so PYTHONPATH was
# merged into the next assignment as a plain (non-exported) shell variable and
# child Python processes never saw it. Export it explicitly.
export PYTHONPATH="$(dirname "$0")/..:${PYTHONPATH:-}"
# AFO_ENV_CLUSTER_SPEC holds the job's cluster JSON; backslash-escape its
# double quotes so it survives interpolation into single-quoted python -c code.
cluster_spec=${AFO_ENV_CLUSTER_SPEC//\"/\\\"}
echo "cluster spec is $cluster_spec"
# Parse the worker list out of the cluster spec.
# Fix: the original invoked a nonstandard "json_parser" module; the stdlib
# json module does the same job. The spec's double quotes were backslash-
# escaped above, so the single-quoted interpolation yields valid JSON again
# once Python unescapes the literal.
worker_list=$($PYTHON_BIN -c "import json; data = json.loads('$cluster_spec'); print(data['worker'])")
# Strip Python list syntax ([ ], quotes, spaces), leaving host:port,host:port,...
worker_list_cleaned=$(echo "$worker_list" | tr -d "[]' ")
# Split on commas into an array (no word-splitting of unquoted expansions).
IFS=',' read -r -a worker_strs <<< "$worker_list_cleaned"
# The first worker is the rendezvous master; split host:port with parameter
# expansion instead of spawning cut twice.
master=${worker_strs[0]}
master_addr=${master%%:*}
master_port=${master##*:}
echo "worker list is $worker_list_cleaned"
echo "master is $master"
echo "master address is $master_addr"
echo "master port is $master_port"
# This node's rank comes from the spec's "index" field.
node_rank=$($PYTHON_BIN -c "import json; data = json.loads('$cluster_spec'); print(data['index'])")
echo "node rank is $node_rank"
dist_url="tcp://$master_addr:$master_port"
echo "dist url is $dist_url"
# Runtime environment for the training processes launched below.
export TOKENIZERS_PARALLELISM=false        # silence tokenizer parallelism in forked workers
export OMP_NUM_THREADS=1                   # one OpenMP thread per process; avoid CPU oversubscription
export NCCL_DEBUG=INFO                     # verbose NCCL logging for multi-node debugging
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1   # fail fast on NCCL collective errors instead of hanging
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True  # reduce CUDA memory fragmentation
export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1  # allow full-object torch.load of trusted checkpoints
### launch with DDP (multi-machines-multi-gpus)
# Enable the devtoolset-7 toolchain (cluster image uses SCL).
source scl_source enable devtoolset-7
# Log network interfaces for debugging rendezvous issues.
ifconfig
# Fix: the original line was `cd $WORK_DIR=/mnt/...`, which expands to a
# nonexistent "<dir>=<path>" directory — the cd silently failed and the script
# kept running from the wrong cwd. cd to the work dir and abort on failure.
cd "$WORK_DIR" || { echo "cannot cd to WORK_DIR '$WORK_DIR'" >&2; exit 1; }
# Launch the fine-tuning script under torch.distributed.run using the
# rendezvous info parsed from the cluster spec (node_rank, master_addr).
# NOTE(review): --master_port uses $PORT (env default 29509), not the
# $master_port parsed from the cluster spec above — confirm this is intended.
# NOTE(review): the final line below still ends in a continuation backslash,
# so additional arguments may follow in the rest of the file.
$PYTHON_BIN -m torch.distributed.run \
--nnodes=$NNODES --nproc_per_node=$NPROC_PER_NODE --node_rank=$node_rank --master_addr=$master_addr --master_port=$PORT \
$SCRIPT \
--seed 42 \
--pretrained_model_name_or_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/flux \
--hps_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/hps/HPS_v2.1_compressed.pt \
--hps_clip_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin \
--clip_score_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/ckpt/clip_score \
--data_json_path /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/rl_embeddings/videos2caption.json \
--gradient_checkpointing \
--train_batch_size 1 \
--num_latent_t 1 \
--sp_size 1 \
--train_sp_batch_size 1 \
--dataloader_num_workers 4 \
--max_train_steps 301 \
--learning_rate 2e-6 \
--mixed_precision bf16 \
--checkpointing_steps 50 \
--cfg 0.0 \
--output_dir /mnt/dolphinfs/ssd_pool/docker/user/hadoop-videogen-hl/hadoop-camera3d/zhangshengjun/checkpoints/G2RPO/save_exp/tempflow_hps_clip \
--h 1024 \
--w 1024 \
--t 1 \
--sampling_steps 16 \
--eta 0.7 \
--lr_warmup_steps 0 \
--sampler_seed 1223627 \
--max_grad_norm 1.0 \
--weight_decay 0.0001 \
--num_generations 12 \
--shift 3 \
--init_same_noise \
--clip_range 1e-4 \
--adv_clip_max 5.0 \
--eta_step_list 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 \
--granular_list 1 \