#!/usr/bin/env bash
#
# Semantic-alignment ablation driver using cached R3M ResNet-34 features.
#
# Stages (each runs only when its RUN_* variable is exactly "1"; all default
# to 0, and enabled stages run in this order):
#   RUN_ROLD_LAT        LAT autoencoder, RoLD baseline (no semantic-alignment flags)
#   RUN_ENDPOINT_LAT    LAT autoencoder + soft-top-k loss on future_delta_feature
#   RUN_MULTIDELTA_LAT  LAT autoencoder + soft-top-k loss on multi_delta_feature
#   RUN_ALIGNMENT_EVAL  latent/semantic alignment diagnostics over whichever
#                       LAT best_recon.pt checkpoints exist (needs >= 2)
#   RUN_ROLD_LDP        latent policy trained on the RoLD LAT checkpoint
#   RUN_ENDPOINT_LDP    latent policy trained on the endpoint LAT checkpoint
#   RUN_MULTIDELTA_LDP  latent policy trained on the multi-delta LAT checkpoint
#
# Settings defined below as VAR="${VAR:-default}" can be overridden from the
# environment, e.g.:
#   RUN_ROLD_LAT=1 RUN_ROLD_LDP=1 bash <this script>
#
# The script relies on bash features (arrays, [[ ]], here-strings). The
# shebang requests bash; the guard below also re-executes under bash when the
# script is started through a non-bash `sh` (e.g. `sh <this script>`).
[ -n "${BASH_VERSION:-}" ] || exec bash "$0" "$@"
set -euo pipefail
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
# Absolute directory containing this script, and the repository root one
# level above it.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)"

# Print the path of the first available Python interpreter, in order of
# preference: the repo-local virtualenv (${REPO_ROOT}/.venv/bin/python), then
# `python` on PATH, then `python3` on PATH.
# Returns 1 without printing anything when none is available.
resolve_python_bin() {
  if [[ -x "${REPO_ROOT}/.venv/bin/python" ]]; then
    printf '%s\n' "${REPO_ROOT}/.venv/bin/python"
    return 0
  fi
  local candidate
  for candidate in python python3; do
    if command -v "${candidate}" >/dev/null 2>&1; then
      command -v "${candidate}"
      return 0
    fi
  done
  return 1
}

# A caller-supplied PYTHON_BIN wins. Otherwise auto-detect one, and fail with
# an explicit message instead of the silent `set -e` exit that a failing
# `$(command -v python3)` assignment would cause.
if [[ -z "${PYTHON_BIN:-}" ]]; then
  if ! PYTHON_BIN="$(resolve_python_bin)"; then
    printf '%s\n' "No Python interpreter found (tried ${REPO_ROOT}/.venv/bin/python, python, python3); set PYTHON_BIN." >&2
    exit 1
  fi
fi
|
|
# ---------------------------------------------------------------------------
# Configuration. Every value below can be overridden from the environment;
# the default after ":-" applies when the variable is unset or empty.
# ---------------------------------------------------------------------------

# Inputs, outputs and runtime settings shared by all stages.
DATASET_ROOT="${DATASET_ROOT:-/home/bagel/bridge_processed_partial/bridge_processed}"
CACHED_R3M_FEATURES_DIR="${CACHED_R3M_FEATURES_DIR:-/home/bagel/bridge_processed_partial/r3m_resnet34_features_bs128}"
OUTPUT_ROOT="${OUTPUT_ROOT:-${REPO_ROOT}/rold_semantic_ablation_runs}"
DEVICE="${DEVICE:-cuda}"
NUM_WORKERS="${NUM_WORKERS:-8}"
SEED="${SEED:-0}"

# LAT (train_lat_autoencoder) hyper-parameters. The LAT_SEMANTIC_* values are
# only passed to the endpoint and multi-delta variants; MULTIDELTA_INDICES is
# passed verbatim to --semantic_multidelta_indices.
LAT_BATCH_SIZE="${LAT_BATCH_SIZE:-256}"
LAT_EPOCHS="${LAT_EPOCHS:-80}"
LAT_LR="${LAT_LR:-2e-4}"
LATENT_DIM="${LATENT_DIM:-32}"
LAT_KL_WEIGHT="${LAT_KL_WEIGHT:-1e-3}"
LAT_SEMANTIC_WEIGHT="${LAT_SEMANTIC_WEIGHT:-0.05}"
LAT_SEMANTIC_TOPK="${LAT_SEMANTIC_TOPK:-8}"
LAT_SEMANTIC_TARGET_TEMPERATURE="${LAT_SEMANTIC_TARGET_TEMPERATURE:-0.07}"
MULTIDELTA_INDICES="${MULTIDELTA_INDICES:-4,8,12,16}"

# Latent-policy (train_latent_policy) hyper-parameters.
LDP_BATCH_SIZE="${LDP_BATCH_SIZE:-128}"
LDP_EPOCHS="${LDP_EPOCHS:-80}"
LDP_LR="${LDP_LR:-2e-4}"
LDP_FLOW_STEPS="${LDP_FLOW_STEPS:-50}"

# Alignment-eval settings. EVAL_TOP_K is a list of k values (space-separated);
# each one is passed as a separate --top_k value.
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-64}"
EVAL_MAX_WINDOWS="${EVAL_MAX_WINDOWS:-2048}"
EVAL_TOP_K="${EVAL_TOP_K:-5 10}"

# Stage toggles: a stage runs only when its flag is exactly "1".
RUN_ROLD_LAT="${RUN_ROLD_LAT:-0}"
RUN_ENDPOINT_LAT="${RUN_ENDPOINT_LAT:-0}"
RUN_MULTIDELTA_LAT="${RUN_MULTIDELTA_LAT:-0}"
RUN_ALIGNMENT_EVAL="${RUN_ALIGNMENT_EVAL:-0}"
RUN_ROLD_LDP="${RUN_ROLD_LDP:-0}"
RUN_ENDPOINT_LDP="${RUN_ENDPOINT_LDP:-0}"
RUN_MULTIDELTA_LDP="${RUN_MULTIDELTA_LDP:-0}"

# Output directories. The semantic variants embed LAT_SEMANTIC_WEIGHT in their
# default names ("w0.05" with the default weight, i.e. the historical names),
# so a run with a different weight gets its own directory. Otherwise it would
# share the default-weight run's directory and, given --resume_auto,
# presumably resume from its checkpoints.
ROLD_LAT_DIR="${ROLD_LAT_DIR:-${OUTPUT_ROOT}/lat_r3m_rold_baseline_cached}"
ENDPOINT_LAT_DIR="${ENDPOINT_LAT_DIR:-${OUTPUT_ROOT}/lat_r3m_softk_w${LAT_SEMANTIC_WEIGHT}}"
MULTIDELTA_LAT_DIR="${MULTIDELTA_LAT_DIR:-${OUTPUT_ROOT}/lat_r3m_softk_multidelta_w${LAT_SEMANTIC_WEIGHT}}"

ROLD_LDP_DIR="${ROLD_LDP_DIR:-${OUTPUT_ROOT}/ldp_rold_baseline_cached}"
ENDPOINT_LDP_DIR="${ENDPOINT_LDP_DIR:-${OUTPUT_ROOT}/ldp_monolithic_delta_w${LAT_SEMANTIC_WEIGHT}}"
MULTIDELTA_LDP_DIR="${MULTIDELTA_LDP_DIR:-${OUTPUT_ROOT}/ldp_monolithic_multidelta_w${LAT_SEMANTIC_WEIGHT}}"
ALIGNMENT_EVAL_DIR="${ALIGNMENT_EVAL_DIR:-${OUTPUT_ROOT}/alignment_eval_cached}"
|
|
# Write a progress line to stdout, tagged with the script's prefix; all
# arguments are joined with single spaces.
log() {
  printf '[semantic-hypothesis-cached] %s\n' "$*"
}
|
|
mkdir -p "${OUTPUT_ROOT}"

# Flags shared by every LAT (train_lat_autoencoder) run: data splits,
# optimisation settings, the R3M resnet34 encoder with cached features, and
# --resume_auto.
common_lat_flags=(
  --dataset_root "${DATASET_ROOT}"
  --train_split train
  --val_split val
  --device "${DEVICE}"
  --num_workers "${NUM_WORKERS}"
  --seed "${SEED}"
  --epochs "${LAT_EPOCHS}"
  --batch_size "${LAT_BATCH_SIZE}"
  --learning_rate "${LAT_LR}"
  --latent_dim "${LATENT_DIM}"
  --kl_loss_weight "${LAT_KL_WEIGHT}"
  --obs_encoder_source r3m
  --r3m_model_id resnet34
  --cached_r3m_features_dir "${CACHED_R3M_FEATURES_DIR}"
  --resume_auto
)

# Soft-top-k semantic-alignment flags used by both semantic variants; they
# differ only in --semantic_target_type (plus the multi-delta indices).
semantic_lat_flags=(
  --use_semantic_alignment
  --semantic_alignment_loss_type soft_topk
  --semantic_loss_weight "${LAT_SEMANTIC_WEIGHT}"
  --semantic_soft_topk "${LAT_SEMANTIC_TOPK}"
  --semantic_target_temperature "${LAT_SEMANTIC_TARGET_TEMPERATURE}"
)

# run_lat_training MESSAGE OUTPUT_DIR [EXTRA_FLAG...]
# Log MESSAGE, create OUTPUT_DIR, then run the LAT trainer with
# --output_dir OUTPUT_DIR, the common flags, and any EXTRA_FLAGs (in that order).
run_lat_training() {
  local message=$1 out_dir=$2
  shift 2
  log "${message}"
  mkdir -p "${out_dir}"
  "${PYTHON_BIN}" -u -m ddm_actions.cli.train_lat_autoencoder \
    --output_dir "${out_dir}" \
    "${common_lat_flags[@]}" \
    "$@"
}

# Baseline: no semantic-alignment flags.
if [[ ${RUN_ROLD_LAT} == 1 ]]; then
  run_lat_training "training cached RoLD LAT -> ${ROLD_LAT_DIR}" "${ROLD_LAT_DIR}"
fi

# Semantic target: single future-delta feature.
if [[ ${RUN_ENDPOINT_LAT} == 1 ]]; then
  run_lat_training "training cached endpoint soft-topk LAT -> ${ENDPOINT_LAT_DIR}" \
    "${ENDPOINT_LAT_DIR}" \
    "${semantic_lat_flags[@]}" \
    --semantic_target_type future_delta_feature
fi

# Semantic target: features at the MULTIDELTA_INDICES offsets.
if [[ ${RUN_MULTIDELTA_LAT} == 1 ]]; then
  run_lat_training "training cached multi-delta soft-topk LAT -> ${MULTIDELTA_LAT_DIR}" \
    "${MULTIDELTA_LAT_DIR}" \
    "${semantic_lat_flags[@]}" \
    --semantic_target_type multi_delta_feature \
    --semantic_multidelta_indices "${MULTIDELTA_INDICES}"
fi
|
|
if [[ "${RUN_ALIGNMENT_EVAL}" == "1" ]]; then
  log "running LAT alignment diagnostics -> ${ALIGNMENT_EVAL_DIR}"
  mkdir -p "${ALIGNMENT_EVAL_DIR}"

  # EVAL_TOP_K lists the k values to evaluate. Split on whitespace and/or
  # commas, so "5 10" and "5,10" (the comma style MULTIDELTA_INDICES uses)
  # both work; each value becomes its own --top_k argument.
  IFS=$' \t,' read -r -a top_k_values <<< "${EVAL_TOP_K}"
  # An empty list would otherwise be forwarded as a bare --top_k.
  if [[ "${#top_k_values[@]}" -eq 0 ]]; then
    echo "EVAL_TOP_K must contain at least one k value." >&2
    exit 1
  fi

  # Collect every LAT checkpoint that exists. names[] is kept index-aligned
  # with ckpts[] because both lists are passed positionally below.
  ckpts=()
  names=()
  if [[ -f "${ROLD_LAT_DIR}/best_recon.pt" ]]; then
    ckpts+=("${ROLD_LAT_DIR}/best_recon.pt")
    names+=(rold_cached)
  fi
  if [[ -f "${ENDPOINT_LAT_DIR}/best_recon.pt" ]]; then
    ckpts+=("${ENDPOINT_LAT_DIR}/best_recon.pt")
    names+=(endpoint_softtopk)
  fi
  if [[ -f "${MULTIDELTA_LAT_DIR}/best_recon.pt" ]]; then
    ckpts+=("${MULTIDELTA_LAT_DIR}/best_recon.pt")
    names+=(multidelta_softtopk)
  fi
  if [[ "${#ckpts[@]}" -lt 2 ]]; then
    echo "Need at least two LAT checkpoints for alignment eval." >&2
    exit 1
  fi

  "${PYTHON_BIN}" -u -m ddm_actions.cli.eval_latent_semantic_alignment \
    --dataset_root "${DATASET_ROOT}" \
    --output_dir "${ALIGNMENT_EVAL_DIR}" \
    --autoencoder_ckpt "${ckpts[@]}" \
    --run_name "${names[@]}" \
    --split val \
    --device "${DEVICE}" \
    --batch_size "${EVAL_BATCH_SIZE}" \
    --num_workers "${NUM_WORKERS}" \
    --max_windows "${EVAL_MAX_WINDOWS}" \
    --top_k "${top_k_values[@]}" \
    --cached_r3m_features_dir "${CACHED_R3M_FEATURES_DIR}"
fi
|
|
# Flags shared by every latent-policy (train_latent_policy) run.
# LDP_HIDDEN_DIM and LDP_DECODED_ACTION_LOSS_WEIGHT are optional overrides;
# their defaults (256 and 0.0) are the values this script always used.
common_ldp_flags=(
  --dataset_root "${DATASET_ROOT}"
  --train_split train
  --val_split val
  --test_split test
  --device "${DEVICE}"
  --num_workers "${NUM_WORKERS}"
  --epochs "${LDP_EPOCHS}"
  --batch_size "${LDP_BATCH_SIZE}"
  --learning_rate "${LDP_LR}"
  --hidden_dim "${LDP_HIDDEN_DIM:-256}"
  --flow_steps "${LDP_FLOW_STEPS}"
  --seed "${SEED}"
  --decoded_action_loss_weight "${LDP_DECODED_ACTION_LOSS_WEIGHT:-0.0}"
  --cached_r3m_features_dir "${CACHED_R3M_FEATURES_DIR}"
  --resume_auto
)

# train_ldp MESSAGE OUTPUT_DIR LAT_DIR
# Log MESSAGE, then require LAT_DIR/best_recon.pt to exist. If it is missing,
# exit with a clear error before creating anything. Otherwise create
# OUTPUT_DIR and train a latent policy on top of that LAT checkpoint.
train_ldp() {
  local message=$1 out_dir=$2 lat_ckpt="$3/best_recon.pt"
  log "${message}"
  if [[ ! -f "${lat_ckpt}" ]]; then
    echo "Missing LAT checkpoint ${lat_ckpt}; train the matching LAT first." >&2
    exit 1
  fi
  mkdir -p "${out_dir}"
  "${PYTHON_BIN}" -u -m ddm_actions.cli.train_latent_policy \
    --output_dir "${out_dir}" \
    --autoencoder_ckpt "${lat_ckpt}" \
    "${common_ldp_flags[@]}"
}

if [[ "${RUN_ROLD_LDP}" == "1" ]]; then
  train_ldp "training cached RoLD LDP -> ${ROLD_LDP_DIR}" \
    "${ROLD_LDP_DIR}" "${ROLD_LAT_DIR}"
fi

if [[ "${RUN_ENDPOINT_LDP}" == "1" ]]; then
  train_ldp "training endpoint soft-topk LDP -> ${ENDPOINT_LDP_DIR}" \
    "${ENDPOINT_LDP_DIR}" "${ENDPOINT_LAT_DIR}"
fi

if [[ "${RUN_MULTIDELTA_LDP}" == "1" ]]; then
  train_ldp "training multi-delta soft-topk LDP -> ${MULTIDELTA_LDP_DIR}" \
    "${MULTIDELTA_LDP_DIR}" "${MULTIDELTA_LAT_DIR}"
fi

log "done"
|
|