DeMemWM / scripts /dememwm_full_eval.slurm
BonanDing's picture
Clean DeMemWM deterministic memory slot handling
93d7b0a
#!/usr/bin/env bash
#SBATCH --job-name=dememwm_full_eval
#SBATCH --partition=gpu
#SBATCH --time=1-00:00:00
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=256G
#SBATCH --gres=gpu:1
#SBATCH --chdir=/share_1/users/bonan_ding/DeMemWM
#SBATCH --output=/share_1/users/bonan_ding/DeMemWM/slurm_logs/%x_%j.out
#SBATCH --error=/share_1/users/bonan_ding/DeMemWM/slurm_logs/%x_%j.err
# Full DeMemWM evaluation script for DeMemWM/H200.
# Submit from the remote repo after training has produced a checkpoint:
# sbatch --export=ALL,CHECKPOINT=/share_1/users/bonan_ding/DeMemWM/outputs/<run>/train/checkpoints/last.ckpt scripts/dememwm_full_eval.slurm
# or:
# CHECKPOINT=/path/to/last.ckpt sbatch --export=ALL scripts/dememwm_full_eval.slurm
# The checkpoint may also be passed as the first script argument.
#
# NOTE(review): the --output/--error files live in slurm_logs/. This script
# only runs `mkdir -p .../slurm_logs` after the job has started, so that
# directory presumably has to exist before the first submission. Confirm
# against the cluster's Slurm behaviour.
set -euo pipefail

# The checkpoint comes from $CHECKPOINT or, failing that, from the first argument.
CHECKPOINT=${CHECKPOINT:-${1:-}}
if [[ -z "${CHECKPOINT}" ]]; then
  echo "ERROR: set CHECKPOINT=/path/to/dememwm.ckpt, e.g." >&2
  echo " sbatch --export=ALL,CHECKPOINT=/share_1/users/bonan_ding/DeMemWM/outputs/<run>/train/checkpoints/last.ckpt scripts/dememwm_full_eval.slurm" >&2
  exit 2
fi
# -s: the file must exist and be non-empty.
if [[ ! -s "${CHECKPOINT}" ]]; then
  echo "ERROR: checkpoint does not exist or is empty: ${CHECKPOINT}" >&2
  exit 2
fi
# Pin a relative path to the directory it was just validated in. The script
# later runs `cd "${REPO}"`, and REPO can be overridden. Manual runs can also
# start anywhere. Without this, the relative path handed to `load=` could name
# a different file, or no file at all.
if [[ "${CHECKPOINT}" != /* ]]; then
  CHECKPOINT="${PWD}/${CHECKPOINT}"
fi
# ---- Locations: every value below can be overridden from the environment ----
: "${REPO:=/share_1/users/bonan_ding/DeMemWM}"
: "${DATA_DIR:=/share_1/users/bonan_ding/worldmem_data/minecraft}"
: "${FEATURE_DIR:=/share_1/users/bonan_ding/worldmem_data/minecraft/vae_features}"
# One tag per run. Jobs use the Slurm job id; manual runs use a timestamp.
: "${RUN_TAG:=dememwm_full_eval_${SLURM_JOB_ID:-manual_$(date +%Y%m%d_%H%M%S)}}"
: "${RUN_ROOT:=${REPO}/outputs/${RUN_TAG}}"
: "${EVAL_OUT:=${RUN_ROOT}/eval}"
: "${LOG_DIR:=${REPO}/slurm_logs/${RUN_TAG}}"
mkdir -p "${EVAL_OUT}" "${LOG_DIR}" "${REPO}/slurm_logs"

# ---- Evaluation / sampling knobs ----
: "${DATASET_N_FRAMES:=300}"
: "${N_FRAMES_VALID:=216}"
: "${CONTEXT_FRAMES:=116}"
: "${N_TOKENS:=8}"
: "${SAMPLING_TIMESTEPS:=20}"
: "${VAL_BATCH_SIZE:=1}"
: "${VAL_LIMIT:=16}"
: "${LOG_VIDEO:=true}"
: "${SEED:=42}"
: "${ABLATION_BRANCH:=A_plus_D_plus_R_normal}"

# ---- DeMemWM memory-shape knobs for the current 18x32 latent grid ----
# Token budgets with the DEFAULT values below. The figures imply pooled sizes
# round up (32/6 -> 6, 32/3 -> 11):
#   anchor : ratio 6 -> 3x6 slots per prefix, 4 prefixes -> 72 tokens
#   revisit: ratio 3 -> 6x11 slots per frame, 2 frames   -> 132 tokens
# Overriding these knobs changes the budgets accordingly.
: "${ANCHOR_DOWNSAMPLE_RATIO:=6}"
: "${REVISIT_MAX_FRAMES:=2}"
: "${REVISIT_DOWNSAMPLE_RATIO:=3}"
cd "${REPO}"

# Bootstrap conda. ~/.bashrc and conda's hook/activation scripts may reference
# unset variables. Under `set -u` that is fatal for a non-interactive shell,
# and a trailing `|| true` does not rescue it. Relax nounset only until the
# environment is active.
set +u
source ~/.bashrc >/dev/null 2>&1 || true
if command -v conda >/dev/null 2>&1; then
  eval "$(conda shell.bash hook)"
elif [[ -f "${HOME}/.conda/etc/profile.d/conda.sh" ]]; then
  source "${HOME}/.conda/etc/profile.d/conda.sh"
elif [[ -f /share_0/conda/etc/profile.d/conda.sh ]]; then
  source /share_0/conda/etc/profile.d/conda.sh
fi
if ! command -v conda >/dev/null 2>&1; then
  echo "ERROR: conda not found; cannot activate the 'worldmem' environment" >&2
  exit 1
fi
conda activate worldmem
set -u

# Resolve the interpreter strictly from PATH. The builtin `type -P` ignores
# aliases and functions (e.g. ones defined in ~/.bashrc) and does not depend
# on an external `which` binary being installed.
PY=$(type -P python)
export PYTHONPATH="./:${PYTHONPATH:-}"
export HYDRA_FULL_ERROR=1
export PYTHONWARNINGS=ignore
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK:-16}"
export WANDB_MODE=offline
export NCCL_P2P_DISABLE=1
wandb offline >/dev/null 2>&1 || true

# Run diagnostics. These are best-effort where a tool may be missing.
echo "JOB_ID=${SLURM_JOB_ID:-manual}"
echo "RUN_TAG=${RUN_TAG}"
echo "RUN_ROOT=${RUN_ROOT}"
echo "CHECKPOINT=${CHECKPOINT}"
echo "ABLATION_BRANCH=${ABLATION_BRANCH}"
echo "HOST=$(hostname)"
echo "START=$(date --iso-8601=seconds)"
echo "PWD=$PWD"
echo "PY=${PY}"
"${PY}" --version
nvidia-smi || true
nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits > "${LOG_DIR}/gpu_memory_before_mb.txt" || true
git branch --show-current || true
git rev-parse HEAD || true
# Overrides passed to `python -m main`. They are held in a bash array so each
# entry reaches the program as exactly one argument. Groups are appended in
# order, so the final sequence reads top to bottom. The `+`/`++` prefixes are
# presumably Hydra's add / add-or-override override syntax (HYDRA_FULL_ERROR
# is exported by this script); confirm against main's config.
EVAL_ARGS=()

# Run identity, output location, task selection and checkpoint.
EVAL_ARGS+=(
  "+name=eval_${RUN_TAG}"
  "+output_dir=${EVAL_OUT}/"
  "experiment.tasks=[validation]"
  "wandb.mode=offline"
  "dataset.validation_multiplier=1"
  "+dataset.seed=${SEED}"
  "+customized_load=true"
  "+seperate_load=false"
  "algorithm=dememwm_memory_dit"
  "load=${CHECKPOINT}"
)

# Dataset selection, locations and clip lengths.
EVAL_ARGS+=(
  "dataset=video_minecraft_latent"
  "dataset.save_dir=${DATA_DIR}"
  "dataset.precomputed_feature_dir=${FEATURE_DIR}"
  "dataset.n_frames=${DATASET_N_FRAMES}"
  "+dataset.n_frames_valid=${N_FRAMES_VALID}"
  "+dataset.customized_validation=true"
  "+dataset.memory_condition_length=0"
  "++dataset.angle_range=180"
  "++dataset.pos_range=1000000000"
)

# Model input shape, context, sampling and training-stage switches.
EVAL_ARGS+=(
  "++algorithm.n_tokens=${N_TOKENS}"
  "algorithm.x_shape=[16,18,32]"
  "++algorithm.context_frames=${CONTEXT_FRAMES}"
  "++algorithm.log_video=${LOG_VIDEO}"
  "++algorithm.diffusion.sampling_timesteps=${SAMPLING_TIMESTEPS}"
  "++algorithm.dememwm.debug_force_all_streams=false"
  "++algorithm.dememwm.training_stage=stage_2"
)

# Anchor memory stream.
EVAL_ARGS+=(
  "++algorithm.dememwm.anchor.enabled=true"
  "++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]"
  "++algorithm.dememwm.anchor.diverse_selection=true"
  "++algorithm.dememwm.anchor.compress.downsample_ratio=${ANCHOR_DOWNSAMPLE_RATIO}"
  "++algorithm.dememwm.anchor.allow_generated_as_anchor=false"
)

# Dynamic memory stream.
EVAL_ARGS+=(
  "++algorithm.dememwm.dynamic.enabled=true"
  "++algorithm.dememwm.dynamic.exclude_latest_local_frames=4"
  "++algorithm.dememwm.dynamic.recent_frames=8"
)

# Revisit memory stream (pose / field-of-view retrieval settings).
EVAL_ARGS+=(
  "++algorithm.dememwm.revisit.enabled=true"
  "++algorithm.dememwm.revisit.deterministic_pose_retrieval=true"
  "++algorithm.dememwm.revisit.fov_overlap_threshold=0.30"
  "++algorithm.dememwm.revisit.high_quality_fov_threshold=0.70"
  "++algorithm.dememwm.revisit.pose_preselect_topk=64"
  "++algorithm.dememwm.revisit.fov_yaw_samples=25"
  "++algorithm.dememwm.revisit.fov_pitch_samples=20"
  "++algorithm.dememwm.revisit.fov_depth_samples=20"
  "++algorithm.dememwm.revisit.plucker_weight=0.10"
  "++algorithm.dememwm.revisit.max_frames=${REVISIT_MAX_FRAMES}"
  "++algorithm.dememwm.revisit.compress.downsample_ratio=${REVISIT_DOWNSAMPLE_RATIO}"
)

# Logging and the evaluation ablation branch.
EVAL_ARGS+=(
  "++algorithm.dememwm.stage_policy.noise_bucket_logging=true"
  "++algorithm.dememwm.eval_ablation.enabled=true"
  "++algorithm.dememwm.eval_ablation.branch=${ABLATION_BRANCH}"
)

# Memory cache policy.
EVAL_ARGS+=(
  "++algorithm.dememwm.cache.enabled=true"
  "++algorithm.dememwm.cache.device=cpu"
  "++algorithm.dememwm.cache.keep_raw_latents=all"
  "++algorithm.dememwm.cache.keep_compressed_records=true"
  "++algorithm.dememwm.cache.eviction_policy=none"
  "++algorithm.dememwm.cache.no_evict=true"
  "++algorithm.dememwm.cache.clear_between_videos=true"
  "++algorithm.dememwm.cache.max_records=null"
  "++algorithm.dememwm.cache.on_capacity_exceeded=warn"
)

# Validation loop size.
EVAL_ARGS+=(
  "experiment.validation.batch_size=${VAL_BATCH_SIZE}"
  "experiment.validation.limit_batch=${VAL_LIMIT}"
)

# Keep an exact, one-override-per-line record of what was launched.
printf '%s\n' "${EVAL_ARGS[@]}" >"${LOG_DIR}/eval_args.txt"
echo "Launching evaluation..."
# Record the evaluator's exit status instead of letting `set -e` abort here.
# That way the post-run GPU snapshot and the manifest (with timing and exit
# code) are still written when evaluation fails. The status is re-raised below.
EVAL_STATUS=0
SECONDS=0
if [[ -n "${SLURM_JOB_ID:-}" ]]; then
  # Inside a Slurm allocation: launch as a job step.
  srun "${PY}" -m main "${EVAL_ARGS[@]}" || EVAL_STATUS=$?
else
  # Manual run (SLURM_JOB_ID unset): run directly rather than letting srun
  # try to request a fresh allocation without this script's GPU request.
  "${PY}" -m main "${EVAL_ARGS[@]}" || EVAL_STATUS=$?
fi
EVAL_DURATION_SECONDS=${SECONDS}
nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits > "${LOG_DIR}/gpu_memory_after_mb.txt" || true
cat > "${RUN_ROOT}/eval_manifest.txt" <<MANIFEST
RUN_TAG=${RUN_TAG}
RUN_ROOT=${RUN_ROOT}
EVAL_OUT=${EVAL_OUT}
CHECKPOINT=${CHECKPOINT}
ABLATION_BRANCH=${ABLATION_BRANCH}
EVAL_DURATION_SECONDS=${EVAL_DURATION_SECONDS}
EVAL_EXIT_CODE=${EVAL_STATUS}
GPU_MEMORY_BEFORE_MB_FILE=${LOG_DIR}/gpu_memory_before_mb.txt
GPU_MEMORY_AFTER_MB_FILE=${LOG_DIR}/gpu_memory_after_mb.txt
JOB_ID=${SLURM_JOB_ID:-manual}
FINISHED=$(date --iso-8601=seconds)
MANIFEST
# Propagate failure so the Slurm job state reflects the evaluation result.
# The DONE marker is printed only on success.
if (( EVAL_STATUS != 0 )); then
  echo "ERROR: evaluation failed with exit status ${EVAL_STATUS} (manifest: ${RUN_ROOT}/eval_manifest.txt)" >&2
  exit "${EVAL_STATUS}"
fi
echo "DEMEMWM_FULL_EVAL_DONE $(date --iso-8601=seconds)"