#!/usr/bin/env bash
# Re-exec under bash if started via plain `sh script`; this guard line is
# POSIX so it is safe to run before the switch. Everything after it relies on
# bash features (arrays, [[ ]], BASH_SOURCE, pipefail).
[ -n "${BASH_VERSION:-}" ] || exec bash "$0" "$@"
set -euo pipefail

# Minimal cached-R3M semantic LAT/LDP grid.
#
# Defaults match the current experiment convention:
#   LAT batch_size=256
#   LDP batch_size=128
#   num_workers=8
#   semantic objective=soft_topk
#
# Every setting below may be overridden from the environment. Each stage runs
# only when its RUN_* variable is exactly "1" (all default to 0), e.g.
#   RUN_ROLD_LAT=1 RUN_ROLD_LDP=1 bash <this script>

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"

# Interpreter lookup order: $PYTHON_BIN, repo-local .venv, `python`, `python3`.
if [[ -z "${PYTHON_BIN:-}" ]]; then
  if [[ -x "${REPO_ROOT}/.venv/bin/python" ]]; then
    PYTHON_BIN="${REPO_ROOT}/.venv/bin/python"
  elif command -v python >/dev/null 2>&1; then
    PYTHON_BIN="$(command -v python)"
  elif command -v python3 >/dev/null 2>&1; then
    PYTHON_BIN="$(command -v python3)"
  else
    # Without this branch a failing `$(command -v python3)` made `set -e`
    # abort the script with no explanation.
    echo "[semantic-hypothesis-cached] no python interpreter found (tried ${REPO_ROOT}/.venv, python, python3); set PYTHON_BIN" >&2
    exit 1
  fi
fi

# --- Data / output locations -------------------------------------------------
DATASET_ROOT="${DATASET_ROOT:-/home/bagel/bridge_processed_partial/bridge_processed}"
CACHED_R3M_FEATURES_DIR="${CACHED_R3M_FEATURES_DIR:-/home/bagel/bridge_processed_partial/r3m_resnet34_features_bs128}"
OUTPUT_ROOT="${OUTPUT_ROOT:-${REPO_ROOT}/rold_semantic_ablation_runs}"

# --- Shared runtime settings -------------------------------------------------
DEVICE="${DEVICE:-cuda}"
NUM_WORKERS="${NUM_WORKERS:-8}"
SEED="${SEED:-0}"

# --- LAT (autoencoder) hyperparameters --------------------------------------
LAT_BATCH_SIZE="${LAT_BATCH_SIZE:-256}"
LAT_EPOCHS="${LAT_EPOCHS:-80}"
LAT_LR="${LAT_LR:-2e-4}"
LATENT_DIM="${LATENT_DIM:-32}"
LAT_KL_WEIGHT="${LAT_KL_WEIGHT:-1e-3}"
LAT_SEMANTIC_WEIGHT="${LAT_SEMANTIC_WEIGHT:-0.05}"
LAT_SEMANTIC_TOPK="${LAT_SEMANTIC_TOPK:-8}"
LAT_SEMANTIC_TARGET_TEMPERATURE="${LAT_SEMANTIC_TARGET_TEMPERATURE:-0.07}"
# Passed verbatim as one comma-separated argument to the multi-delta LAT run.
MULTIDELTA_INDICES="${MULTIDELTA_INDICES:-4,8,12,16}"

# --- LDP (latent policy) hyperparameters ------------------------------------
LDP_BATCH_SIZE="${LDP_BATCH_SIZE:-128}"
LDP_EPOCHS="${LDP_EPOCHS:-80}"
LDP_LR="${LDP_LR:-2e-4}"
LDP_FLOW_STEPS="${LDP_FLOW_STEPS:-50}"

# --- Alignment-eval settings -------------------------------------------------
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-64}"
EVAL_MAX_WINDOWS="${EVAL_MAX_WINDOWS:-2048}"
# Whitespace-separated list; split into separate --top_k values at eval time.
EVAL_TOP_K="${EVAL_TOP_K:-5 10}"

# --- Stage toggles (set to 1 to enable) -------------------------------------
RUN_ROLD_LAT="${RUN_ROLD_LAT:-0}"
RUN_ENDPOINT_LAT="${RUN_ENDPOINT_LAT:-0}"
RUN_MULTIDELTA_LAT="${RUN_MULTIDELTA_LAT:-0}"
RUN_ALIGNMENT_EVAL="${RUN_ALIGNMENT_EVAL:-0}"
RUN_ROLD_LDP="${RUN_ROLD_LDP:-0}"
RUN_ENDPOINT_LDP="${RUN_ENDPOINT_LDP:-0}"
RUN_MULTIDELTA_LDP="${RUN_MULTIDELTA_LDP:-0}"

ROLD_LAT_DIR="${ROLD_LAT_DIR:-${OUTPUT_ROOT}/lat_r3m_rold_baseline_cached}"
ENDPOINT_LAT_DIR="${ENDPOINT_LAT_DIR:-${OUTPUT_ROOT}/lat_r3m_softk_w0.05}"
MULTIDELTA_LAT_DIR="${MULTIDELTA_LAT_DIR:-${OUTPUT_ROOT}/lat_r3m_softk_multidelta_w0.05}"
ROLD_LDP_DIR="${ROLD_LDP_DIR:-${OUTPUT_ROOT}/ldp_rold_baseline_cached}"
ENDPOINT_LDP_DIR="${ENDPOINT_LDP_DIR:-${OUTPUT_ROOT}/ldp_monolithic_delta_w0.05}"
MULTIDELTA_LDP_DIR="${MULTIDELTA_LDP_DIR:-${OUTPUT_ROOT}/ldp_monolithic_multidelta_w0.05}"
ALIGNMENT_EVAL_DIR="${ALIGNMENT_EVAL_DIR:-${OUTPUT_ROOT}/alignment_eval_cached}"

# Progress messages go to stdout; fatal errors go to stderr and abort the run.
log() { echo "[semantic-hypothesis-cached] $*"; }
die() { echo "[semantic-hypothesis-cached] $*" >&2; exit 1; }

#######################################
# Succeed if at least one argument is exactly "1" (the RUN_* convention).
# Arguments: flag values to inspect
# Returns:   0 if any flag is "1", 1 otherwise
#######################################
any_enabled() {
  local flag
  for flag in "$@"; do
    if [[ "${flag}" == "1" ]]; then
      return 0
    fi
  done
  return 1
}

#######################################
# Train one LAT autoencoder.
# Globals:   PYTHON_BIN, common_lat_flags (read)
# Arguments: $1  - output directory (created if missing)
#            $2… - extra CLI flags, appended after the shared LAT flags
#######################################
train_lat() {
  local out_dir=$1
  shift
  mkdir -p "${out_dir}"
  "${PYTHON_BIN}" -u -m ddm_actions.cli.train_lat_autoencoder \
    --output_dir "${out_dir}" \
    "${common_lat_flags[@]}" \
    "$@"
}

#######################################
# Train one latent policy (LDP) on top of a trained LAT checkpoint.
# Globals:   PYTHON_BIN, common_ldp_flags (read)
# Arguments: $1 - LDP output directory (created if missing)
#            $2 - LAT run directory expected to contain best_recon.pt
# Exits non-zero, before creating $1 or launching Python, if the LAT
# checkpoint does not exist yet.
#######################################
train_ldp() {
  local out_dir=$1
  local lat_ckpt="$2/best_recon.pt"
  [[ -f "${lat_ckpt}" ]] \
    || die "missing LAT checkpoint ${lat_ckpt}; train the matching LAT first"
  mkdir -p "${out_dir}"
  "${PYTHON_BIN}" -u -m ddm_actions.cli.train_latent_policy \
    --output_dir "${out_dir}" \
    --autoencoder_ckpt "${lat_ckpt}" \
    "${common_ldp_flags[@]}"
}

#######################################
# Run the latent/semantic alignment diagnostics over every LAT run whose
# best_recon.pt exists. Exits non-zero if fewer than two such runs exist.
# Globals: *_LAT_DIR, ALIGNMENT_EVAL_DIR, EVAL_*, DATASET_ROOT, DEVICE,
#          NUM_WORKERS, CACHED_R3M_FEATURES_DIR, PYTHON_BIN (read)
#######################################
run_alignment_eval() {
  local -a top_k_values=() ckpts=() names=()
  local -a run_names=(rold_cached endpoint_softtopk multidelta_softtopk)
  local -a run_dirs=("${ROLD_LAT_DIR}" "${ENDPOINT_LAT_DIR}" "${MULTIDELTA_LAT_DIR}")
  local i ckpt

  log "running LAT alignment diagnostics -> ${ALIGNMENT_EVAL_DIR}"
  mkdir -p "${ALIGNMENT_EVAL_DIR}"
  # EVAL_TOP_K is whitespace-separated, e.g. "5 10" -> --top_k 5 10.
  read -r -a top_k_values <<< "${EVAL_TOP_K}"

  for i in "${!run_names[@]}"; do
    ckpt="${run_dirs[i]}/best_recon.pt"
    if [[ -f "${ckpt}" ]]; then
      ckpts+=("${ckpt}")
      names+=("${run_names[i]}")
    fi
  done

  if (( ${#ckpts[@]} < 2 )); then
    die "Need at least two LAT checkpoints for alignment eval (found ${#ckpts[@]})."
  fi

  "${PYTHON_BIN}" -u -m ddm_actions.cli.eval_latent_semantic_alignment \
    --dataset_root "${DATASET_ROOT}" \
    --output_dir "${ALIGNMENT_EVAL_DIR}" \
    --autoencoder_ckpt "${ckpts[@]}" \
    --run_name "${names[@]}" \
    --split val \
    --device "${DEVICE}" \
    --batch_size "${EVAL_BATCH_SIZE}" \
    --num_workers "${NUM_WORKERS}" \
    --max_windows "${EVAL_MAX_WINDOWS}" \
    --top_k "${top_k_values[@]}" \
    --cached_r3m_features_dir "${CACHED_R3M_FEATURES_DIR}"
}

mkdir -p "${OUTPUT_ROOT}"

common_lat_flags=(
  --dataset_root "${DATASET_ROOT}"
  --train_split train
  --val_split val
  --device "${DEVICE}"
  --num_workers "${NUM_WORKERS}"
  --seed "${SEED}"
  --epochs "${LAT_EPOCHS}"
  --batch_size "${LAT_BATCH_SIZE}"
  --learning_rate "${LAT_LR}"
  --latent_dim "${LATENT_DIM}"
  --kl_loss_weight "${LAT_KL_WEIGHT}"
  --obs_encoder_source r3m
  --r3m_model_id resnet34
  --cached_r3m_features_dir "${CACHED_R3M_FEATURES_DIR}"
  --resume_auto
)

# Soft-top-k semantic-alignment flags shared by the endpoint and multi-delta
# LAT variants; only the semantic target type differs between them.
semantic_flags=(
  --use_semantic_alignment
  --semantic_alignment_loss_type soft_topk
  --semantic_loss_weight "${LAT_SEMANTIC_WEIGHT}"
  --semantic_soft_topk "${LAT_SEMANTIC_TOPK}"
  --semantic_target_temperature "${LAT_SEMANTIC_TARGET_TEMPERATURE}"
)

common_ldp_flags=(
  --dataset_root "${DATASET_ROOT}"
  --train_split train
  --val_split val
  --test_split test
  --device "${DEVICE}"
  --num_workers "${NUM_WORKERS}"
  --epochs "${LDP_EPOCHS}"
  --batch_size "${LDP_BATCH_SIZE}"
  --learning_rate "${LDP_LR}"
  --hidden_dim 256
  --flow_steps "${LDP_FLOW_STEPS}"
  --seed "${SEED}"
  --decoded_action_loss_weight 0.0
  --cached_r3m_features_dir "${CACHED_R3M_FEATURES_DIR}"
  --resume_auto
)

# All RUN_* flags default to 0, so a bare invocation would otherwise do
# nothing but print "done".
if ! any_enabled "${RUN_ROLD_LAT}" "${RUN_ENDPOINT_LAT}" "${RUN_MULTIDELTA_LAT}" \
  "${RUN_ALIGNMENT_EVAL}" "${RUN_ROLD_LDP}" "${RUN_ENDPOINT_LDP}" "${RUN_MULTIDELTA_LDP}"; then
  log "no stages selected; set one or more RUN_* variables to 1"
fi

# --- Stage 1: LAT autoencoders ----------------------------------------------
if [[ "${RUN_ROLD_LAT}" == "1" ]]; then
  log "training cached RoLD LAT -> ${ROLD_LAT_DIR}"
  train_lat "${ROLD_LAT_DIR}"
fi

if [[ "${RUN_ENDPOINT_LAT}" == "1" ]]; then
  log "training cached endpoint soft-topk LAT -> ${ENDPOINT_LAT_DIR}"
  train_lat "${ENDPOINT_LAT_DIR}" \
    "${semantic_flags[@]}" \
    --semantic_target_type future_delta_feature
fi

if [[ "${RUN_MULTIDELTA_LAT}" == "1" ]]; then
  log "training cached multi-delta soft-topk LAT -> ${MULTIDELTA_LAT_DIR}"
  train_lat "${MULTIDELTA_LAT_DIR}" \
    "${semantic_flags[@]}" \
    --semantic_target_type multi_delta_feature \
    --semantic_multidelta_indices "${MULTIDELTA_INDICES}"
fi

# --- Stage 2: alignment diagnostics -----------------------------------------
if [[ "${RUN_ALIGNMENT_EVAL}" == "1" ]]; then
  run_alignment_eval
fi

# --- Stage 3: latent policies on top of each LAT ----------------------------
if [[ "${RUN_ROLD_LDP}" == "1" ]]; then
  log "training cached RoLD LDP -> ${ROLD_LDP_DIR}"
  train_ldp "${ROLD_LDP_DIR}" "${ROLD_LAT_DIR}"
fi

if [[ "${RUN_ENDPOINT_LDP}" == "1" ]]; then
  log "training endpoint soft-topk LDP -> ${ENDPOINT_LDP_DIR}"
  train_ldp "${ENDPOINT_LDP_DIR}" "${ENDPOINT_LAT_DIR}"
fi

if [[ "${RUN_MULTIDELTA_LDP}" == "1" ]]; then
  log "training multi-delta soft-topk LDP -> ${MULTIDELTA_LDP_DIR}"
  train_ldp "${MULTIDELTA_LDP_DIR}" "${MULTIDELTA_LAT_DIR}"
fi

log "done"