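#!/bin/bash
# dememwm_full_eval.slurm -- evaluate a DeMemWM checkpoint on the validation task.
#
# NOTE(review): the first 21 lines of this copy arrived as empty table cells.
# They presumably held the original header comments and the #SBATCH directives;
# restore those directives from the original script before submitting with sbatch.
#
# Usage:
#   sbatch --export=ALL,CHECKPOINT=/path/to/last.ckpt scripts/dememwm_full_eval.slurm
#
# CHECKPOINT may also be passed as the first positional argument. Every other
# UPPER_CASE setting in this script keeps its default unless the same name is
# already set in the environment.
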
# Strict mode: abort on a failing command, on expansion of an unset variable,
# and when any stage of a pipeline fails.
set -euo pipefail

# Checkpoint to evaluate. The CHECKPOINT environment variable takes precedence;
# otherwise the first positional argument is used.
CHECKPOINT=${CHECKPOINT:-${1:-}}
if [[ -z "${CHECKPOINT}" ]]; then
  echo "ERROR: set CHECKPOINT=/path/to/dememwm.ckpt, e.g." >&2
  echo " sbatch --export=ALL,CHECKPOINT=/share_1/users/bonan_ding/DeMemWM/outputs/<run>/train/checkpoints/last.ckpt scripts/dememwm_full_eval.slurm" >&2
  exit 2
fi
# -s is true only for a path that exists and has non-zero size, so both a
# missing checkpoint and a zero-byte one are rejected before anything else runs.
if [[ ! -s "${CHECKPOINT}" ]]; then
  echo "ERROR: checkpoint does not exist or is empty: ${CHECKPOINT}" >&2
  exit 2
fi


# Default locations of the repository checkout, the Minecraft dataset and the
# precomputed feature directory. Each one can be overridden from the environment.
REPO=${REPO:-/share_1/users/bonan_ding/DeMemWM}
DATA_DIR=${DATA_DIR:-/share_1/users/bonan_ding/worldmem_data/minecraft}
FEATURE_DIR=${FEATURE_DIR:-/share_1/users/bonan_ding/worldmem_data/minecraft/vae_features}


# Per-run layout. RUN_TAG is keyed on the Slurm job id when there is one and on
# a "manual_<timestamp>" suffix otherwise, so separate runs get separate trees.
RUN_TAG=${RUN_TAG:-dememwm_full_eval_${SLURM_JOB_ID:-manual_$(date +%Y%m%d_%H%M%S)}}
RUN_ROOT=${RUN_ROOT:-${REPO}/outputs/${RUN_TAG}}
EVAL_OUT=${EVAL_OUT:-${RUN_ROOT}/eval}
LOG_DIR=${LOG_DIR:-${REPO}/slurm_logs/${RUN_TAG}}
# RUN_ROOT is listed explicitly because the end-of-run manifest is written
# directly into it. It is only created implicitly when EVAL_OUT keeps its
# default beneath RUN_ROOT; with EVAL_OUT overridden, the job would otherwise
# fail on the manifest after the whole evaluation had already run.
mkdir -p "${RUN_ROOT}" "${EVAL_OUT}" "${LOG_DIR}" "${REPO}/slurm_logs"


# Evaluation settings. Each keeps the value already present in the environment
# and otherwise falls back to the default shown here; all of them are forwarded
# to the Python entry point as Hydra overrides in EVAL_ARGS.
DATASET_N_FRAMES=${DATASET_N_FRAMES:-300}
N_FRAMES_VALID=${N_FRAMES_VALID:-216}
CONTEXT_FRAMES=${CONTEXT_FRAMES:-116}
N_TOKENS=${N_TOKENS:-8}
SAMPLING_TIMESTEPS=${SAMPLING_TIMESTEPS:-20}
VAL_BATCH_SIZE=${VAL_BATCH_SIZE:-1}
VAL_LIMIT=${VAL_LIMIT:-16}
LOG_VIDEO=${LOG_VIDEO:-true}
SEED=${SEED:-42}
ABLATION_BRANCH=${ABLATION_BRANCH:-A_plus_D_plus_R_normal}


# Memory-stream settings, forwarded as
#   algorithm.dememwm.anchor.compress.downsample_ratio,
#   algorithm.dememwm.revisit.max_frames and
#   algorithm.dememwm.revisit.compress.downsample_ratio.
ANCHOR_DOWNSAMPLE_RATIO=${ANCHOR_DOWNSAMPLE_RATIO:-6}
REVISIT_MAX_FRAMES=${REVISIT_MAX_FRAMES:-2}
REVISIT_DOWNSAMPLE_RATIO=${REVISIT_DOWNSAMPLE_RATIO:-3}


# Work from the repository root: the relative PYTHONPATH entry and `-m main`
# used later are resolved against the current directory.
cd "${REPO}"

# ~/.bashrc and conda's hook/activate scripts are not written for `set -u`.
# An unset-variable reference inside them terminates a non-interactive shell
# outright, and `|| true` cannot catch that, so nounset is suspended while
# they run and restored afterwards.
nounset_was_on=0
if [[ $- == *u* ]]; then
  nounset_was_on=1
  set +u
fi

# Best effort: a missing or failing ~/.bashrc is not an error.
source ~/.bashrc >/dev/null 2>&1 || true

# Make the `conda` shell function available, trying in order: a conda already
# on PATH, a per-user install, then the shared install.
if command -v conda >/dev/null 2>&1; then
  eval "$(conda shell.bash hook)"
elif [[ -f "${HOME}/.conda/etc/profile.d/conda.sh" ]]; then
  source "${HOME}/.conda/etc/profile.d/conda.sh"
elif [[ -f /share_0/conda/etc/profile.d/conda.sh ]]; then
  source /share_0/conda/etc/profile.d/conda.sh
fi

# Fail with a clear message instead of a bare "command not found". The exit
# status stays 127, which is what the unguarded call would have produced.
if ! command -v conda >/dev/null 2>&1; then
  echo "ERROR: conda not found on PATH, in ${HOME}/.conda, or in /share_0/conda" >&2
  exit 127
fi
conda activate worldmem

if (( nounset_was_on )); then
  set -u
fi
unset nounset_was_on

# Resolve the interpreter of the activated environment once; `command -v` is
# the shell builtin and does not depend on an external `which`.
PY=$(command -v python)

# Put the current directory first on the Python import path, keeping whatever
# PYTHONPATH was already set (empty when unset, so `set -u` is satisfied).
export PYTHONPATH="./:${PYTHONPATH:-}"
# Ask Hydra for complete stack traces.
export HYDRA_FULL_ERROR=1
# Silence Python warnings.
export PYTHONWARNINGS=ignore
# One OpenMP thread per CPU allocated by Slurm; 16 when run outside Slurm.
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK:-16}"
# Keep Weights & Biases offline, via the environment and via the CLI setting.
export WANDB_MODE=offline
# Disable NCCL peer-to-peer GPU transport.
export NCCL_P2P_DISABLE=1
# Best effort: the wandb CLI may be absent.
wandb offline >/dev/null 2>&1 || true

# Print the run context to the job log so a result can be traced back to the
# job, host, checkpoint, interpreter and code revision that produced it.
echo "JOB_ID=${SLURM_JOB_ID:-manual}"
echo "RUN_TAG=${RUN_TAG}"
echo "RUN_ROOT=${RUN_ROOT}"
echo "CHECKPOINT=${CHECKPOINT}"
echo "ABLATION_BRANCH=${ABLATION_BRANCH}"
echo "HOST=$(hostname)"
echo "START=$(date --iso-8601=seconds)"
echo "PWD=$PWD"
echo "PY=${PY}"
"${PY}" --version
# The GPU and git probes are informational only, so each is allowed to fail
# without aborting the job under `set -e`.
nvidia-smi || true
# Per-GPU memory in use (MiB, one value per line) before the evaluation starts.
nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits > "${LOG_DIR}/gpu_memory_before_mb.txt" || true
git branch --show-current || true
git rev-parse HEAD || true

| EVAL_ARGS=( |
| "+name=eval_${RUN_TAG}" |
| "+output_dir=${EVAL_OUT}/" |
| "experiment.tasks=[validation]" |
| "wandb.mode=offline" |
| "dataset.validation_multiplier=1" |
| "+dataset.seed=${SEED}" |
| "+customized_load=true" |
| "+seperate_load=false" |
| "algorithm=dememwm_memory_dit" |
| "load=${CHECKPOINT}" |
| "dataset=video_minecraft_latent" |
| "dataset.save_dir=${DATA_DIR}" |
| "dataset.precomputed_feature_dir=${FEATURE_DIR}" |
| "dataset.n_frames=${DATASET_N_FRAMES}" |
| "+dataset.n_frames_valid=${N_FRAMES_VALID}" |
| "+dataset.customized_validation=true" |
| "+dataset.memory_condition_length=0" |
| "++dataset.angle_range=180" |
| "++dataset.pos_range=1000000000" |
| "++algorithm.n_tokens=${N_TOKENS}" |
| "algorithm.x_shape=[16,18,32]" |
| "++algorithm.context_frames=${CONTEXT_FRAMES}" |
| "++algorithm.log_video=${LOG_VIDEO}" |
| "++algorithm.diffusion.sampling_timesteps=${SAMPLING_TIMESTEPS}" |
| "++algorithm.dememwm.debug_force_all_streams=false" |
| "++algorithm.dememwm.training_stage=stage_2" |
| "++algorithm.dememwm.anchor.enabled=true" |
| "++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]" |
| "++algorithm.dememwm.anchor.diverse_selection=true" |
| "++algorithm.dememwm.anchor.compress.downsample_ratio=${ANCHOR_DOWNSAMPLE_RATIO}" |
| "++algorithm.dememwm.anchor.allow_generated_as_anchor=false" |
| "++algorithm.dememwm.dynamic.enabled=true" |
| "++algorithm.dememwm.dynamic.exclude_latest_local_frames=4" |
| "++algorithm.dememwm.dynamic.recent_frames=8" |
| "++algorithm.dememwm.revisit.enabled=true" |
| "++algorithm.dememwm.revisit.deterministic_pose_retrieval=true" |
| "++algorithm.dememwm.revisit.fov_overlap_threshold=0.30" |
| "++algorithm.dememwm.revisit.high_quality_fov_threshold=0.70" |
| "++algorithm.dememwm.revisit.pose_preselect_topk=64" |
| "++algorithm.dememwm.revisit.fov_yaw_samples=25" |
| "++algorithm.dememwm.revisit.fov_pitch_samples=20" |
| "++algorithm.dememwm.revisit.fov_depth_samples=20" |
| "++algorithm.dememwm.revisit.plucker_weight=0.10" |
| "++algorithm.dememwm.revisit.max_frames=${REVISIT_MAX_FRAMES}" |
| "++algorithm.dememwm.revisit.compress.downsample_ratio=${REVISIT_DOWNSAMPLE_RATIO}" |
| "++algorithm.dememwm.stage_policy.noise_bucket_logging=true" |
| "++algorithm.dememwm.eval_ablation.enabled=true" |
| "++algorithm.dememwm.eval_ablation.branch=${ABLATION_BRANCH}" |
| "++algorithm.dememwm.cache.enabled=true" |
| "++algorithm.dememwm.cache.device=cpu" |
| "++algorithm.dememwm.cache.keep_raw_latents=all" |
| "++algorithm.dememwm.cache.keep_compressed_records=true" |
| "++algorithm.dememwm.cache.eviction_policy=none" |
| "++algorithm.dememwm.cache.no_evict=true" |
| "++algorithm.dememwm.cache.clear_between_videos=true" |
| "++algorithm.dememwm.cache.max_records=null" |
| "++algorithm.dememwm.cache.on_capacity_exceeded=warn" |
| "experiment.validation.batch_size=${VAL_BATCH_SIZE}" |
| "experiment.validation.limit_batch=${VAL_LIMIT}" |
| ) |
| |
# Save the exact override list, one per line, next to the logs.
printf '%s\n' "${EVAL_ARGS[@]}" > "${LOG_DIR}/eval_args.txt"
echo "Launching evaluation..."
# SECONDS is bash's built-in elapsed-time counter; zeroing it here makes the
# value read after the run equal to the evaluation's wall-clock duration.
SECONDS=0
if command -v srun >/dev/null 2>&1; then
  srun "${PY}" -m main "${EVAL_ARGS[@]}"
else
  # Manual runs (the "manual_" run tag) may happen on a machine without Slurm;
  # launch the interpreter directly instead of dying on a missing srun.
  "${PY}" -m main "${EVAL_ARGS[@]}"
fi
EVAL_DURATION_SECONDS=${SECONDS}
# Per-GPU memory in use (MiB) after the run; informational, allowed to fail.
nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits > "${LOG_DIR}/gpu_memory_after_mb.txt" || true

# Write a KEY=VALUE summary of the finished run. Under `set -e` this point is
# reached only when the evaluation command succeeded, so the presence of the
# manifest marks a completed run. The heredoc delimiter is unquoted on
# purpose: variables and $(date ...) are expanded as the file is written.
cat > "${RUN_ROOT}/eval_manifest.txt" <<MANIFEST
RUN_TAG=${RUN_TAG}
RUN_ROOT=${RUN_ROOT}
EVAL_OUT=${EVAL_OUT}
CHECKPOINT=${CHECKPOINT}
ABLATION_BRANCH=${ABLATION_BRANCH}
EVAL_DURATION_SECONDS=${EVAL_DURATION_SECONDS}
GPU_MEMORY_BEFORE_MB_FILE=${LOG_DIR}/gpu_memory_before_mb.txt
GPU_MEMORY_AFTER_MB_FILE=${LOG_DIR}/gpu_memory_after_mb.txt
JOB_ID=${SLURM_JOB_ID:-manual}
FINISHED=$(date --iso-8601=seconds)
MANIFEST

# Completion marker for the job log.
echo "DEMEMWM_FULL_EVAL_DONE $(date --iso-8601=seconds)"
