semantic-latent-diffusion-policy / scripts /run_semantic_hypothesis_cached.sh
GnanaPraveen's picture
Upload semantic latent diffusion policy checkpoints
f45eb20 verified
#!/bin/sh
# The file is started through /bin/sh, but everything below relies on
# bash-only syntax (arrays, [[ ]], here-strings). When the current shell is not
# bash, replace this process with bash running the same script and arguments.
[ -n "${BASH_VERSION:-}" ] || exec bash "$0" "$@"
# Strict mode: exit on an unhandled command failure (-e), on expansion of an
# unset variable (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail
# Minimal cached-R3M semantic LAT/LDP grid.
#
# The settings assigned below are taken from environment variables of the same
# name and fall back to a default when unset or empty. Every stage is opt-in:
# each RUN_* toggle defaults to 0 and a stage runs only when its toggle is
# exactly "1".
#
# Defaults match the current experiment convention:
# LAT batch_size=256
# LDP batch_size=128
# num_workers=8
# semantic objective=soft_topk
# Absolute directory that contains this script, and the repository root one
# level above it.
# CDPATH is cleared for each cd: when CDPATH is set, cd may resolve a relative
# directory through it and then prints the directory it chose on stdout, which
# would end up inside the captured value. "--" keeps a path that starts with
# "-" from being parsed as an option.
SCRIPT_DIR="$(CDPATH='' cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(CDPATH='' cd -- "${SCRIPT_DIR}/.." && pwd)"
# Print the path of the Python interpreter to use. Candidates, in order: the
# repository virtualenv, `python` on PATH, `python3` on PATH.
# Globals:  REPO_ROOT (read)
# Outputs:  the chosen interpreter on stdout
# Returns:  0 when an interpreter was found, 1 otherwise
resolve_python_bin() {
  local candidate
  if [[ -x "${REPO_ROOT}/.venv/bin/python" ]]; then
    printf '%s\n' "${REPO_ROOT}/.venv/bin/python"
    return 0
  fi
  for candidate in python python3; do
    if command -v "${candidate}" >/dev/null 2>&1; then
      command -v "${candidate}"
      return 0
    fi
  done
  return 1
}

# A non-empty PYTHON_BIN from the environment always wins.
if [[ -z "${PYTHON_BIN:-}" ]]; then
  # Report the failure explicitly: under `set -e` a bare failed command
  # substitution would end the script without printing any reason.
  if ! PYTHON_BIN="$(resolve_python_bin)"; then
    echo "No Python interpreter found: set PYTHON_BIN, create ${REPO_ROOT}/.venv, or put python/python3 on PATH." >&2
    exit 1
  fi
fi
# Each setting keeps a non-empty value from the environment and otherwise
# takes the default written after ":=".

# Input data, cached feature directory, and the root for all run outputs.
: "${DATASET_ROOT:=/home/bagel/bridge_processed_partial/bridge_processed}"
: "${CACHED_R3M_FEATURES_DIR:=/home/bagel/bridge_processed_partial/r3m_resnet34_features_bs128}"
: "${OUTPUT_ROOT:=${REPO_ROOT}/rold_semantic_ablation_runs}"

# Runtime settings shared by the stages.
: "${DEVICE:=cuda}"
: "${NUM_WORKERS:=8}"
: "${SEED:=0}"

# Values forwarded to the LAT training command.
: "${LAT_BATCH_SIZE:=256}"
: "${LAT_EPOCHS:=80}"
: "${LAT_LR:=2e-4}"
: "${LATENT_DIM:=32}"
: "${LAT_KL_WEIGHT:=1e-3}"
: "${LAT_SEMANTIC_WEIGHT:=0.05}"
: "${LAT_SEMANTIC_TOPK:=8}"
: "${LAT_SEMANTIC_TARGET_TEMPERATURE:=0.07}"
: "${MULTIDELTA_INDICES:=4,8,12,16}"

# Values forwarded to the LDP training command.
: "${LDP_BATCH_SIZE:=128}"
: "${LDP_EPOCHS:=80}"
: "${LDP_LR:=2e-4}"
: "${LDP_FLOW_STEPS:=50}"

# Values forwarded to the alignment evaluation command. EVAL_TOP_K is a
# whitespace-separated list that is split into separate arguments later.
: "${EVAL_BATCH_SIZE:=64}"
: "${EVAL_MAX_WINDOWS:=2048}"
: "${EVAL_TOP_K:=5 10}"
# Stage toggles. Each one defaults to 0 when unset or empty; the stages below
# run only when their toggle equals "1".
# LAT training stages.
: "${RUN_ROLD_LAT:=0}"
: "${RUN_ENDPOINT_LAT:=0}"
: "${RUN_MULTIDELTA_LAT:=0}"
# Alignment diagnostics over the LAT checkpoints.
: "${RUN_ALIGNMENT_EVAL:=0}"
# LDP training stages.
: "${RUN_ROLD_LDP:=0}"
: "${RUN_ENDPOINT_LDP:=0}"
: "${RUN_MULTIDELTA_LDP:=0}"
# Per-stage output directories, all under OUTPUT_ROOT unless overridden
# individually through the environment.
# The semantic runs carry the LAT semantic loss weight in their directory name.
# The suffix is derived from LAT_SEMANTIC_WEIGHT so that a run with a
# non-default weight gets its own directory instead of sharing the "w0.05"
# one, which the training commands (started with --resume_auto) would
# otherwise be pointed at. With the default weight (0.05) the names are
# unchanged.
ROLD_LAT_DIR="${ROLD_LAT_DIR:-${OUTPUT_ROOT}/lat_r3m_rold_baseline_cached}"
ENDPOINT_LAT_DIR="${ENDPOINT_LAT_DIR:-${OUTPUT_ROOT}/lat_r3m_softk_w${LAT_SEMANTIC_WEIGHT}}"
MULTIDELTA_LAT_DIR="${MULTIDELTA_LAT_DIR:-${OUTPUT_ROOT}/lat_r3m_softk_multidelta_w${LAT_SEMANTIC_WEIGHT}}"
ROLD_LDP_DIR="${ROLD_LDP_DIR:-${OUTPUT_ROOT}/ldp_rold_baseline_cached}"
ENDPOINT_LDP_DIR="${ENDPOINT_LDP_DIR:-${OUTPUT_ROOT}/ldp_monolithic_delta_w${LAT_SEMANTIC_WEIGHT}}"
MULTIDELTA_LDP_DIR="${MULTIDELTA_LDP_DIR:-${OUTPUT_ROOT}/ldp_monolithic_multidelta_w${LAT_SEMANTIC_WEIGHT}}"
ALIGNMENT_EVAL_DIR="${ALIGNMENT_EVAL_DIR:-${OUTPUT_ROOT}/alignment_eval_cached}"
# Write one progress line to stdout, tagged with the script's prefix.
# Arguments: words of the message; they are joined into a single string.
log() {
  printf '[semantic-hypothesis-cached] %s\n' "$*"
}
# Create the root output directory up front. With -p, missing parents are
# created as well and an already existing directory is not an error.
mkdir -p "${OUTPUT_ROOT}"
# Arguments passed to every ddm_actions.cli.train_lat_autoencoder invocation
# below. The list is assembled in groups; the resulting order is fixed.
common_lat_flags=()
# Dataset location and split names.
common_lat_flags+=(--dataset_root "${DATASET_ROOT}" --train_split train --val_split val)
# Runtime settings.
common_lat_flags+=(--device "${DEVICE}" --num_workers "${NUM_WORKERS}" --seed "${SEED}")
# Optimisation settings.
common_lat_flags+=(--epochs "${LAT_EPOCHS}" --batch_size "${LAT_BATCH_SIZE}" --learning_rate "${LAT_LR}")
# Model settings.
common_lat_flags+=(--latent_dim "${LATENT_DIM}" --kl_loss_weight "${LAT_KL_WEIGHT}")
# Observation encoder and the directory of cached features.
common_lat_flags+=(--obs_encoder_source r3m --r3m_model_id resnet34)
common_lat_flags+=(--cached_r3m_features_dir "${CACHED_R3M_FEATURES_DIR}")
# Passed last, without a value.
common_lat_flags+=(--resume_auto)
# Stage (RUN_ROLD_LAT=1): run ddm_actions.cli.train_lat_autoencoder with only
# the shared LAT flags, writing to ROLD_LAT_DIR.
case "${RUN_ROLD_LAT}" in
  1)
    log "training cached RoLD LAT -> ${ROLD_LAT_DIR}"
    mkdir -p "${ROLD_LAT_DIR}"
    "${PYTHON_BIN}" -u -m ddm_actions.cli.train_lat_autoencoder \
      --output_dir "${ROLD_LAT_DIR}" "${common_lat_flags[@]}"
    ;;
esac
# Stage (RUN_ENDPOINT_LAT=1): run ddm_actions.cli.train_lat_autoencoder with
# the shared LAT flags plus the soft_topk semantic-alignment options and the
# future_delta_feature target type, writing to ENDPOINT_LAT_DIR.
case "${RUN_ENDPOINT_LAT}" in
  1)
    log "training cached endpoint soft-topk LAT -> ${ENDPOINT_LAT_DIR}"
    mkdir -p "${ENDPOINT_LAT_DIR}"
    "${PYTHON_BIN}" -u -m ddm_actions.cli.train_lat_autoencoder \
      --output_dir "${ENDPOINT_LAT_DIR}" "${common_lat_flags[@]}" \
      --use_semantic_alignment --semantic_alignment_loss_type soft_topk \
      --semantic_loss_weight "${LAT_SEMANTIC_WEIGHT}" \
      --semantic_soft_topk "${LAT_SEMANTIC_TOPK}" \
      --semantic_target_temperature "${LAT_SEMANTIC_TARGET_TEMPERATURE}" \
      --semantic_target_type future_delta_feature
    ;;
esac
# Stage (RUN_MULTIDELTA_LAT=1): run ddm_actions.cli.train_lat_autoencoder with
# the shared LAT flags plus the soft_topk semantic-alignment options, the
# multi_delta_feature target type and MULTIDELTA_INDICES, writing to
# MULTIDELTA_LAT_DIR.
case "${RUN_MULTIDELTA_LAT}" in
  1)
    log "training cached multi-delta soft-topk LAT -> ${MULTIDELTA_LAT_DIR}"
    mkdir -p "${MULTIDELTA_LAT_DIR}"
    "${PYTHON_BIN}" -u -m ddm_actions.cli.train_lat_autoencoder \
      --output_dir "${MULTIDELTA_LAT_DIR}" "${common_lat_flags[@]}" \
      --use_semantic_alignment --semantic_alignment_loss_type soft_topk \
      --semantic_loss_weight "${LAT_SEMANTIC_WEIGHT}" \
      --semantic_soft_topk "${LAT_SEMANTIC_TOPK}" \
      --semantic_target_temperature "${LAT_SEMANTIC_TARGET_TEMPERATURE}" \
      --semantic_target_type multi_delta_feature \
      --semantic_multidelta_indices "${MULTIDELTA_INDICES}"
    ;;
esac
# Stage (RUN_ALIGNMENT_EVAL=1): run ddm_actions.cli.eval_latent_semantic_alignment
# on the val split over every LAT run whose best_recon.pt exists. At least two
# checkpoints are required; otherwise the script exits with status 1.
if [[ "${RUN_ALIGNMENT_EVAL}" == "1" ]]; then
  log "running LAT alignment diagnostics -> ${ALIGNMENT_EVAL_DIR}"
  mkdir -p "${ALIGNMENT_EVAL_DIR}"
  # Split the whitespace-separated EVAL_TOP_K into one array element per value.
  read -r -a top_k_values <<< "${EVAL_TOP_K}"

  # Candidate LAT runs, in the order they are passed on: directory and the
  # run name reported for it (same index in both arrays).
  lat_run_dirs=("${ROLD_LAT_DIR}" "${ENDPOINT_LAT_DIR}" "${MULTIDELTA_LAT_DIR}")
  lat_run_names=(rold_cached endpoint_softtopk multidelta_softtopk)

  # Keep only the runs that actually have a checkpoint on disk.
  ckpts=()
  names=()
  for lat_idx in "${!lat_run_dirs[@]}"; do
    lat_ckpt="${lat_run_dirs[lat_idx]}/best_recon.pt"
    [[ -f "${lat_ckpt}" ]] || continue
    ckpts+=("${lat_ckpt}")
    names+=("${lat_run_names[lat_idx]}")
  done

  if (( ${#ckpts[@]} < 2 )); then
    echo "Need at least two LAT checkpoints for alignment eval." >&2
    exit 1
  fi

  "${PYTHON_BIN}" -u -m ddm_actions.cli.eval_latent_semantic_alignment \
    --dataset_root "${DATASET_ROOT}" --output_dir "${ALIGNMENT_EVAL_DIR}" \
    --autoencoder_ckpt "${ckpts[@]}" \
    --run_name "${names[@]}" \
    --split val --device "${DEVICE}" \
    --batch_size "${EVAL_BATCH_SIZE}" --num_workers "${NUM_WORKERS}" \
    --max_windows "${EVAL_MAX_WINDOWS}" \
    --top_k "${top_k_values[@]}" \
    --cached_r3m_features_dir "${CACHED_R3M_FEATURES_DIR}"
fi
# Arguments passed to every ddm_actions.cli.train_latent_policy invocation
# below. The list is assembled in groups; the resulting order is fixed.
common_ldp_flags=()
# Dataset location and split names.
common_ldp_flags+=(--dataset_root "${DATASET_ROOT}")
common_ldp_flags+=(--train_split train --val_split val --test_split test)
# Runtime settings.
common_ldp_flags+=(--device "${DEVICE}" --num_workers "${NUM_WORKERS}")
# Optimisation settings.
common_ldp_flags+=(--epochs "${LDP_EPOCHS}" --batch_size "${LDP_BATCH_SIZE}" --learning_rate "${LDP_LR}")
# Model settings; hidden_dim is fixed at 256 here rather than configurable.
common_ldp_flags+=(--hidden_dim 256 --flow_steps "${LDP_FLOW_STEPS}")
common_ldp_flags+=(--seed "${SEED}")
# The decoded-action loss weight is fixed at 0.0.
common_ldp_flags+=(--decoded_action_loss_weight 0.0)
common_ldp_flags+=(--cached_r3m_features_dir "${CACHED_R3M_FEATURES_DIR}")
# Passed last, without a value.
common_ldp_flags+=(--resume_auto)
# Stage (RUN_ROLD_LDP=1): run ddm_actions.cli.train_latent_policy with the
# shared LDP flags on top of the LAT checkpoint ${ROLD_LAT_DIR}/best_recon.pt,
# writing to ROLD_LDP_DIR.
if [[ "${RUN_ROLD_LDP}" == "1" ]]; then
  ldp_lat_ckpt="${ROLD_LAT_DIR}/best_recon.pt"
  # Fail fast, before creating the output directory or starting Python, when
  # the LAT checkpoint is absent (the same -f test the alignment-eval stage
  # applies to its checkpoints).
  if [[ ! -f "${ldp_lat_ckpt}" ]]; then
    echo "Missing LAT checkpoint for RoLD LDP: ${ldp_lat_ckpt} (train it with RUN_ROLD_LAT=1 or set ROLD_LAT_DIR)." >&2
    exit 1
  fi
  log "training cached RoLD LDP -> ${ROLD_LDP_DIR}"
  mkdir -p "${ROLD_LDP_DIR}"
  "${PYTHON_BIN}" -u -m ddm_actions.cli.train_latent_policy \
    --output_dir "${ROLD_LDP_DIR}" \
    --autoencoder_ckpt "${ldp_lat_ckpt}" \
    "${common_ldp_flags[@]}"
fi
# Stage (RUN_ENDPOINT_LDP=1): run ddm_actions.cli.train_latent_policy with the
# shared LDP flags on top of the LAT checkpoint
# ${ENDPOINT_LAT_DIR}/best_recon.pt, writing to ENDPOINT_LDP_DIR.
if [[ "${RUN_ENDPOINT_LDP}" == "1" ]]; then
  ldp_lat_ckpt="${ENDPOINT_LAT_DIR}/best_recon.pt"
  # Fail fast, before creating the output directory or starting Python, when
  # the LAT checkpoint is absent (the same -f test the alignment-eval stage
  # applies to its checkpoints).
  if [[ ! -f "${ldp_lat_ckpt}" ]]; then
    echo "Missing LAT checkpoint for endpoint LDP: ${ldp_lat_ckpt} (train it with RUN_ENDPOINT_LAT=1 or set ENDPOINT_LAT_DIR)." >&2
    exit 1
  fi
  log "training endpoint soft-topk LDP -> ${ENDPOINT_LDP_DIR}"
  mkdir -p "${ENDPOINT_LDP_DIR}"
  "${PYTHON_BIN}" -u -m ddm_actions.cli.train_latent_policy \
    --output_dir "${ENDPOINT_LDP_DIR}" \
    --autoencoder_ckpt "${ldp_lat_ckpt}" \
    "${common_ldp_flags[@]}"
fi
# Stage (RUN_MULTIDELTA_LDP=1): run ddm_actions.cli.train_latent_policy with
# the shared LDP flags on top of the LAT checkpoint
# ${MULTIDELTA_LAT_DIR}/best_recon.pt, writing to MULTIDELTA_LDP_DIR.
if [[ "${RUN_MULTIDELTA_LDP}" == "1" ]]; then
  ldp_lat_ckpt="${MULTIDELTA_LAT_DIR}/best_recon.pt"
  # Fail fast, before creating the output directory or starting Python, when
  # the LAT checkpoint is absent (the same -f test the alignment-eval stage
  # applies to its checkpoints).
  if [[ ! -f "${ldp_lat_ckpt}" ]]; then
    echo "Missing LAT checkpoint for multi-delta LDP: ${ldp_lat_ckpt} (train it with RUN_MULTIDELTA_LAT=1 or set MULTIDELTA_LAT_DIR)." >&2
    exit 1
  fi
  log "training multi-delta soft-topk LDP -> ${MULTIDELTA_LDP_DIR}"
  mkdir -p "${MULTIDELTA_LDP_DIR}"
  "${PYTHON_BIN}" -u -m ddm_actions.cli.train_latent_policy \
    --output_dir "${MULTIDELTA_LDP_DIR}" \
    --autoencoder_ckpt "${ldp_lat_ckpt}" \
    "${common_ldp_flags[@]}"
fi
# Final status line. Because of `set -e`, this point is reached only when no
# enabled stage exited with a failure.
log "done"