tactile-vae / script /train_vae.sh
WitneyWW's picture
Initial upload of tactile_vae (code, model, config, inference)
3770c94 verified
Raw
History Blame Contribute Delete
7.15 kB
#!/usr/bin/env bash
#SBATCH --job-name=tactile_vae
#SBATCH --partition=ct
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:h100:1
#SBATCH --requeue
#SBATCH --output=/group2/ct/weihanx/tactile_world_model/slurm-logs/tactile_vae.%j.log
#SBATCH --error=/group2/ct/weihanx/tactile_world_model/slurm-logs/tactile_vae.%j.log
# Train tactile_vae.model.TactileVAE on the fota_unlabeled parquet dataset.
#
# Each run lives at <RUNS_DIR>/<RUN_ID>/. Re-launching with the same RUN_ID
# is meant to auto-resume from ckpt_last.pt. The resume itself happens inside
# train_vae.py; this script only reports whether that checkpoint exists.
# wandb is given the same run id, so metrics append to the same dashboard.
#
# Usage (sbatch): sbatch tactile_vae/script/train_vae.sh <run_id> [config.yaml]
# Usage (local):  ./tactile_vae/script/train_vae.sh <run_id> [config.yaml]
#
# Slurm notes:
#  - #SBATCH directives are only honoured before the first executable line,
#    so keep them at the very top.
#  - --output and --error name the same file, so stdout and stderr interleave
#    in one log. Slurm does not create the slurm-logs/ directory itself; it
#    must exist before submitting. The mkdir later in this script runs too
#    late to help the job that creates it.
#  - --requeue makes the job eligible for requeueing (e.g. after preemption or
#    node failure). Combined with a fixed RUN_ID, this is what makes resuming
#    useful.
#
# Optional environment overrides read by this script:
#   DEBUG=1          enable `set -x` command tracing, e.g.
#                    sbatch --export=ALL,DEBUG=1 tactile_vae/script/train_vae.sh <run_id>
#   CONDA_ENV        conda env name (default: twm)
#   PYTHON_BIN       interpreter to use instead of the env's bin/python
#   DISABLE_WANDB=1  disable wandb entirely
#   WANDB_DIR / WANDB_CACHE_DIR / WANDB_DATA_DIR  local wandb dirs (default under /tmp/$USER)
#
# Send stdout and stderr through a single pipe so the slurm log shows progress
# as it happens, not in one flush at the end. NFS-backed log files can
# otherwise look empty for minutes.
# NOTE(review): `stdbuf -oL` only changes C-stdio buffering in the command it
# starts, and `cat` typically does unbuffered read()/write() already. The
# line-by-line flushing of the Python trainer really comes from
# PYTHONUNBUFFERED=1 / `python -u` further down. Bash also does not wait for
# this process substitution when the script exits, so the last lines may race
# the end of the job; confirm this if log tails ever look truncated.
exec 1> >(stdbuf -oL -eL cat) 2>&1
# Strict mode: exit on any failing command, treat unset variables as errors,
# and fail a pipeline if any stage fails.
set -euo pipefail
# Opt-in tracing. NOTE: `set -x` prints commands with their values expanded,
# including any secrets that appear in the traced commands.
[[ "${DEBUG:-0}" == "1" ]] && set -x
# ============================================================
# Inputs (positional)
# ============================================================
# $1 = run id (required). It becomes both the run subdirectory under RUNS_DIR
#      and the wandb run id, so it must be a single, safe path component.
# $2 = config path (optional); the default is set below.
if [[ $# -lt 1 ]]; then
  echo "Usage: $0 <run_id> [config.yaml]" >&2
  echo " run_id : required. Both the output subdir name and the wandb run id." >&2
  echo " config : optional. Defaults to tactile_vae/config/train_vae.yaml." >&2
  exit 2
fi
RUN_ID="$1"
# Reject ids that would collapse or escape the run directory ("", ".", "..",
# or anything containing "/") and characters wandb forbids in run ids
# (\ # ? % :). Failing here is much cheaper than failing after env setup.
if [[ -z "$RUN_ID" || "$RUN_ID" == "." || "$RUN_ID" == ".." \
      || "$RUN_ID" == *[/\\#?%:]* ]]; then
  echo "ERROR: invalid run_id '$RUN_ID' (must be non-empty and contain none of / \\ # ? % :)" >&2
  exit 2
fi
CONFIG="${2:-tactile_vae/config/train_vae.yaml}"
# ============================================================
# Paths
# ============================================================
WORKDIR="/group2/ct/weihanx/tactile_world_model"
# Keep this aligned with train_vae.yaml default `runs_root: runs`. The trainer
# is launched with cwd=$WORKDIR, so a relative `runs` resolves to
# $WORKDIR/runs (assuming train_vae.py resolves it against its working
# directory; confirm there). Do not repeat WORKDIR's last path component here.
RUNS_DIR="$WORKDIR/runs"
RUN_DIR="$RUNS_DIR/$RUN_ID"
DATA_DIR="$WORKDIR/tactile_vae/data"
SPLITS_PATH="$WORKDIR/tactile_vae/dataset/splits.json"
# Conda env with all required deps installed. Override via env var if you
# prefer a different env (e.g. CONDA_ENV=samaudio311 sbatch ...).
# torch torchvision timm numpy pyarrow PIL pyyaml wandb
# `twm` is the project's standard env (matches tactile_jepa training).
CONDA_ENV="${CONDA_ENV:-twm}"
# Under sbatch, slurm-logs/ must already exist at submit time, because Slurm
# opens the --output file before this script runs. This mkdir only helps
# later submissions and local runs.
mkdir -p "$WORKDIR/slurm-logs"
mkdir -p "$RUNS_DIR"
# Everything created from here on gets no permissions for "other". The two
# directories above were created with the inherited umask.
umask 027
# ============================================================
# Startup banner: written before any slow step (conda activation,
# Python imports) so the slurm log shows signs of life right away.
# One printf emits one line per argument; the final "" is the
# trailing blank line.
# ============================================================
printf '%s\n' \
  "=== Tactile VAE training ===" \
  "Host: $(hostname)" \
  "Job ID: ${SLURM_JOB_ID:-N/A}" \
  "Start time: $(date)" \
  "Run ID: $RUN_ID" \
  "Workdir: $WORKDIR" \
  "Config: $CONFIG" \
  "Run dir: $RUN_DIR" \
  "Conda env: $CONDA_ENV" \
  ""
# ============================================================
# Environment knobs
# ============================================================
# Cap the OpenMP / MKL thread pools at 8, turn off tokenizers parallelism,
# dump Python tracebacks on fatal signals, and make print() flush per line.
export OMP_NUM_THREADS=8 MKL_NUM_THREADS=8
export TOKENIZERS_PARALLELISM=false
export PYTHONFAULTHANDLER=1 PYTHONUNBUFFERED=1
# ============================================================
# Weights & Biases (mirrors jepa_training.sh — same account)
# ============================================================
# The API key is deliberately NOT hard-coded: anything in this file ends up in
# git history and in every copy of the script. Supply it by exporting
# WANDB_API_KEY before submitting (sbatch forwards the caller's environment by
# default) or by running `wandb login` once, which stores it in ~/.netrc.
# NOTE(review): a plaintext key used to be committed here. Revoke it and
# generate a new one in the wandb settings.
export WANDB_ENTITY="weihanx-university-of-michigan"
export WANDB_USERNAME="weihanx@umich.edu"
export WANDB_PROJECT="tactile_vae"
export WANDB_MODE="online"
# Reuse RUN_ID so relaunches of the same run log under the same wandb run id.
export WANDB_RUN_ID="$RUN_ID"
export WANDB_NAME="$RUN_ID"
export WANDB_SERVICE_WAIT=300
export WANDB_INIT_TIMEOUT=300
export WANDB_START_METHOD="thread"
export WANDB_CONSOLE="wrap"
# Keep wandb metadata/cache off network storage to speed init/resume.
# Fall back to `id -un` so an unexported USER does not trip `set -u`.
wandb_user="${USER:-$(id -un)}"
export WANDB_DIR="${WANDB_DIR:-/tmp/$wandb_user/wandb/$RUN_ID}"
export WANDB_CACHE_DIR="${WANDB_CACHE_DIR:-/tmp/$wandb_user/wandb-cache}"
export WANDB_DATA_DIR="${WANDB_DATA_DIR:-/tmp/$wandb_user/wandb-data}"
mkdir -p "$WANDB_DIR" "$WANDB_CACHE_DIR" "$WANDB_DATA_DIR"
# Debug knob: disable wandb entirely to isolate startup stalls.
# Default is enabled; set DISABLE_WANDB=1 to disable.
DISABLE_WANDB="${DISABLE_WANDB:-0}"
if [[ "$DISABLE_WANDB" == "1" ]]; then
  unset WANDB_PROJECT WANDB_ENTITY WANDB_API_KEY WANDB_USERNAME
  export WANDB_MODE="disabled"
elif [[ -z "${WANDB_API_KEY:+set}" && ! -f "${HOME:-}/.netrc" ]]; then
  # Warn early instead of failing minutes later inside wandb.init().
  echo "WARNING: WANDB_API_KEY is not set and ~/.netrc does not exist;" \
       "online wandb logging will likely fail to authenticate." >&2
fi
echo "--- Wandb ---"
echo " project=${WANDB_PROJECT:-<disabled>} entity=${WANDB_ENTITY:-<disabled>}"
echo " run_id=$WANDB_RUN_ID name=$WANDB_NAME mode=$WANDB_MODE"
echo " dir=$WANDB_DIR"
echo " cache_dir=$WANDB_CACHE_DIR"
echo " data_dir=$WANDB_DATA_DIR"
# Test via ${var:+set} and print only the last 4 characters, so the full key
# never reaches the log, not even through `set -x` traces when DEBUG=1.
if [[ -n "${WANDB_API_KEY:+set}" ]]; then
  echo " api_key=...${WANDB_API_KEY: -4}"
elif [[ "$DISABLE_WANDB" == "1" ]]; then
  echo " api_key=<disabled>"
else
  echo " api_key=<unset; relying on ~/.netrc>"
fi
echo
# ============================================================
# Sanity checks (cheap; before conda activate)
# ============================================================
# The trainer is launched from $WORKDIR (see the `cd` before launch), so the
# path handed to it must be valid from there. A relative CONFIG that exists
# under $WORKDIR is used as-is. One that only exists relative to the
# submission directory is made absolute so it is still found after the `cd`.
if [[ "$CONFIG" != /* && -f "$WORKDIR/$CONFIG" ]]; then
  :  # already valid from the trainer's working directory
elif [[ -f "$CONFIG" ]]; then
  CONFIG="$(cd -- "$(dirname -- "$CONFIG")" && pwd)/$(basename -- "$CONFIG")"
  echo "Config: resolved to $CONFIG"
else
  echo "ERROR: config not found: $CONFIG (or $WORKDIR/$CONFIG)" >&2
  exit 2
fi
if [[ ! -d "$DATA_DIR" ]]; then
  echo "ERROR: data dir does not exist: $DATA_DIR" >&2
  exit 2
fi
if [[ ! -f "$SPLITS_PATH" ]]; then
  echo "ERROR: splits manifest not found: $SPLITS_PATH" >&2
  echo " Generate it with: python tactile_vae/dataset/make_splits.py" >&2
  exit 2
fi
# Informational only. No resume flag is passed to the trainer from here, so
# the actual resume decision presumably happens inside train_vae.py.
if [[ -f "$RUN_DIR/ckpt_last.pt" ]]; then
  echo "Resume: auto-resume from $RUN_DIR/ckpt_last.pt"
else
  echo "Resume: fresh run (no $RUN_DIR/ckpt_last.pt)"
fi
echo
# ============================================================
# Resolve Python interpreter
# ============================================================
# Fast path: call the env's python directly, avoiding the expensive
# `conda activate` startup on busy shared filesystems. If that binary is
# missing, fall back to full activation.
PYTHON_BIN="${PYTHON_BIN:-$HOME/miniconda3/envs/$CONDA_ENV/bin/python}"
if [[ -x "$PYTHON_BIN" ]]; then
  echo "[$(date +%H:%M:%S)] using env python directly: $PYTHON_BIN"
else
  echo "[$(date +%H:%M:%S)] env python not found; falling back to conda activate..."
  conda_sh="$HOME/miniconda3/etc/profile.d/conda.sh"
  if [[ ! -f "$conda_sh" ]]; then
    echo "ERROR: conda init script not found: $conda_sh" >&2
    exit 2
  fi
  # conda's shell functions and activate.d hooks are not written for `set -u`
  # and can abort on unbound variables, so relax it just for this step.
  set +u
  # shellcheck disable=SC1090  # path is computed at runtime
  source "$conda_sh"
  echo "[$(date +%H:%M:%S)] activating $CONDA_ENV..."
  conda activate "$CONDA_ENV"
  set -u
  PYTHON_BIN="$(command -v python)"
  echo "[$(date +%H:%M:%S)] env activated. python = $PYTHON_BIN"
fi
# CUDA_VISIBLE_DEVICES if Slurm set it, else the first `nvidia-smi -L` line,
# else "none".
echo "[$(date +%H:%M:%S)] GPU(s): ${CUDA_VISIBLE_DEVICES:-$(nvidia-smi -L 2>/dev/null | head -1 || echo none)}"
echo
# ============================================================
# Launch training (`-u` is also forced by PYTHONUNBUFFERED above)
# ============================================================
cd "$WORKDIR"
echo "[$(date +%H:%M:%S)] launching trainer..."
# Capture the trainer's exit status instead of letting `set -e` abort on the
# spot. That way the finish time and run-dir listing are logged for failed
# runs too. A non-zero status is re-raised below so Slurm still records the
# job as failed.
train_rc=0
"$PYTHON_BIN" -u tactile_vae/script/train_vae.py \
  --config "$CONFIG" \
  --run-id "$RUN_ID" || train_rc=$?
echo
if (( train_rc == 0 )); then
  echo "[$(date +%H:%M:%S)] Finished."
else
  echo "[$(date +%H:%M:%S)] Trainer exited with status $train_rc." >&2
fi
echo "Run dir contents:"
# ls fails (rather than printing nothing) when the directory is missing or
# unreadable, so say that instead of "empty".
ls -lh -- "$RUN_DIR" 2>/dev/null || echo " (missing or unreadable)"
if (( train_rc != 0 )); then
  exit "$train_rc"
fi