File size: 6,238 Bytes
3770c94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#!/usr/bin/env bash
#SBATCH --job-name=tactile_vae_pl
#SBATCH --partition=sharedp
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=2
#SBATCH --gres=gpu:h100:4
#SBATCH --requeue
#SBATCH --exclude=mfmc10
#SBATCH --output=/group2/ct/weihanx/tactile_world_model/slurm-logs/tactile_vae_pl.%j.log
#SBATCH --error=/group2/ct/weihanx/tactile_world_model/slurm-logs/tactile_vae_pl.%j.log

# Train TactileVAE with PyTorch Lightning (train_vae_pl.py).
#
# Each run lives at <RUNS_DIR>/<RUN_ID>/. Re-launching with the same RUN_ID
# auto-resumes from checkpoints/last.ckpt (Lightning) or ckpt_last.pt (compat).
# wandb keeps the same run id so metrics append to the same dashboard.
#
# Usage (sbatch):  sbatch tactile_vae/script/train_vae_pl.sh <run_id> [config.yaml]
# Usage (local):   ./tactile_vae/script/train_vae_pl.sh        <run_id> [config.yaml]
#
# Diagnostics: set DEBUG=1 to enable `set -x` command tracing.
#   sbatch --export=ALL,DEBUG=1 tactile_vae/script/train_vae_pl.sh <run_id>

# Route stdout and stderr of the whole script (and of every child process)
# through one pipe, so both streams end up interleaved in the same place for
# local runs as well as under Slurm.
# NOTE(review): `cat` does its own unbuffered read/write, so the
# `stdbuf -oL -eL` wrapper is likely a no-op; python output is already
# unbuffered via `-u` and PYTHONUNBUFFERED further down. Kept as-is —
# confirm before removing.
exec 1> >(stdbuf -oL -eL cat) 2>&1

# Strict mode: abort on command failures, on unset variables, and on a
# failure in any stage of a pipeline. DEBUG=1 additionally traces every
# command; anything assigned after this point (including secrets) would be
# echoed into the job log.
set -euo pipefail
[[ "${DEBUG:-0}" == "1" ]] && set -x

# ============================================================
#  Inputs (positional)
# ============================================================
if [[ $# -lt 1 ]]; then
  echo "Usage: $0 <run_id> [config.yaml]" >&2
  echo "  run_id   : required. Both the output subdir name and the wandb run id." >&2
  echo "  config   : optional. Defaults to tactile_vae/config/train_vae.yaml." >&2
  exit 2
fi
RUN_ID="$1"
CONFIG="${2:-tactile_vae/config/train_vae.yaml}"

# RUN_ID becomes a path component (runs/<RUN_ID>) and the wandb run id.
# Reject values that would resolve outside RUNS_DIR (empty, ".", "..", "/")
# and characters wandb does not allow in run ids (\ # ? % :).
case "$RUN_ID" in
  "" | . | .. | *[/\\#?%:]*)
    echo "ERROR: invalid run_id '$RUN_ID' (must be non-empty, not '.' or '..', and contain none of / \\ # ? % :)" >&2
    exit 2
    ;;
esac

# ============================================================
#  Paths
# ============================================================
# Everything hangs off the shared project root; each run owns the
# subdirectory named after its RUN_ID.
WORKDIR="/group2/ct/weihanx/tactile_world_model"
RUNS_DIR="${WORKDIR}/runs"
RUN_DIR="${RUNS_DIR}/${RUN_ID}"

# Dataset inputs consumed by the trainer.
DATA_DIR="${WORKDIR}/tactile_vae/data"
SPLITS_PATH="${WORKDIR}/tactile_vae/dataset/splits_subset.json"

# Conda environment name; override by exporting CONDA_ENV before launch.
CONDA_ENV="${CONDA_ENV:-twm}"

mkdir -p -- "${WORKDIR}/slurm-logs"
mkdir -p -- "${RUNS_DIR}"
# From here on, newly created files/dirs (run outputs, checkpoints, ...)
# get no permissions for "other" and no group write.
umask 027

# ============================================================
#  Print startup info
# ============================================================
# One heredoc instead of a run of echos; the trailing empty line keeps the
# blank separator after the banner.
cat <<EOF
=== Tactile VAE (PyTorch Lightning) ===
Host:       $(hostname)
Job ID:     ${SLURM_JOB_ID:-N/A}
Start time: $(date)
Run ID:     $RUN_ID
Workdir:    $WORKDIR
Config:     $CONFIG
Run dir:    $RUN_DIR
Conda env:  $CONDA_ENV

EOF

# ============================================================
#  Environment knobs
# ============================================================
# Pin OpenMP / MKL thread pools to 8 threads each.
export OMP_NUM_THREADS=8 MKL_NUM_THREADS=8
# Turn off tokenizer-library parallelism.
export TOKENIZERS_PARALLELISM="false"
# Python: dump tracebacks on fatal signals and keep stdout/stderr unbuffered
# so training logs stream into the job log promptly.
export PYTHONFAULTHANDLER=1 PYTHONUNBUFFERED=1

# ============================================================
#  Weights & Biases
# ============================================================
# The API key is deliberately NOT hard-coded here: this file is shared and
# versioned, and with DEBUG=1 every assignment is traced into the job log.
# Provide it by exporting WANDB_API_KEY in the submitting shell (sbatch
# forwards the caller's environment by default), or run `wandb login` once
# so wandb reads the key from ~/.netrc.
# NOTE(review): the key that used to be committed in this file should be
# treated as leaked and rotated.
export WANDB_ENTITY="weihanx-university-of-michigan"
export WANDB_USERNAME="weihanx@umich.edu"
export WANDB_PROJECT="tactile_vae"
export WANDB_MODE="online"
# Same run id on every relaunch so metrics append to the same wandb run.
export WANDB_RUN_ID="$RUN_ID"
export WANDB_NAME="$RUN_ID"
export WANDB_SERVICE_WAIT=300
export WANDB_INIT_TIMEOUT=300
export WANDB_START_METHOD="thread"
export WANDB_CONSOLE="wrap"
# Scratch dirs under /tmp/<user>. USER can be missing in minimal
# environments; fall back to `id -un` instead of tripping `set -u`.
wandb_tmp_root="/tmp/${USER:-$(id -un)}"
export WANDB_DIR="${WANDB_DIR:-$wandb_tmp_root/wandb/$RUN_ID}"
export WANDB_CACHE_DIR="${WANDB_CACHE_DIR:-$wandb_tmp_root/wandb-cache}"
export WANDB_DATA_DIR="${WANDB_DATA_DIR:-$wandb_tmp_root/wandb-data}"
mkdir -p "$WANDB_DIR" "$WANDB_CACHE_DIR" "$WANDB_DATA_DIR"

DISABLE_WANDB="${DISABLE_WANDB:-0}"
if [[ "$DISABLE_WANDB" == "1" ]]; then
  unset WANDB_PROJECT WANDB_ENTITY WANDB_API_KEY WANDB_USERNAME
  export WANDB_MODE="disabled"
fi

echo "--- Wandb ---"
echo "  project=${WANDB_PROJECT:-<disabled>}  entity=${WANDB_ENTITY:-<disabled>}"
echo "  run_id=$WANDB_RUN_ID  name=$WANDB_NAME  mode=$WANDB_MODE"
echo "  dir=$WANDB_DIR"
echo "  cache_dir=$WANDB_CACHE_DIR"
echo "  data_dir=$WANDB_DATA_DIR"
if [[ -n "${WANDB_API_KEY:-}" ]]; then
  # Show only the last 4 characters: enough to tell keys apart, not to reuse.
  echo "  api_key=****${WANDB_API_KEY: -4} (from environment)"
elif [[ "$WANDB_MODE" == "disabled" ]]; then
  echo "  api_key=<disabled>"
else
  echo "  api_key=<unset; relying on ~/.netrc from 'wandb login'>"
  if [[ ! -f "$HOME/.netrc" ]]; then
    echo "WARNING: WANDB_API_KEY is unset and ~/.netrc does not exist; wandb online login will likely fail." >&2
  fi
fi
echo

# ============================================================
#  Sanity checks
# ============================================================
# The trainer is launched after `cd "$WORKDIR"`, so a relative CONFIG must
# still point at the right file there. Pin it to an absolute path now:
# prefer $WORKDIR/$CONFIG (the file the trainer would have opened), else a
# path relative to the submission directory. Previously the second form
# passed this check but was then not found by the trainer.
if [[ "$CONFIG" != /* ]]; then
  if [[ -f "$WORKDIR/$CONFIG" ]]; then
    CONFIG="$WORKDIR/$CONFIG"
  elif [[ -f "$CONFIG" ]]; then
    CONFIG="$PWD/$CONFIG"
  fi
fi
if [[ ! -f "$CONFIG" ]]; then
  echo "ERROR: config not found: $CONFIG (or $WORKDIR/$CONFIG)" >&2
  exit 2
fi
echo "Config (resolved): $CONFIG"
if [[ ! -d "$DATA_DIR" ]]; then
  echo "ERROR: data dir does not exist: $DATA_DIR" >&2
  exit 2
fi
if [[ ! -f "$SPLITS_PATH" ]]; then
  echo "ERROR: splits manifest not found: $SPLITS_PATH" >&2
  echo "  Generate it with: python tactile_vae/dataset/make_splits.py" >&2
  exit 2
fi

# Report resume state. This is informational only: no resume flag is passed
# to the trainer, so picking up the checkpoint is presumably handled inside
# train_vae_pl.py. If both checkpoints exist, the Lightning one is reported.
lightning_ckpt="$RUN_DIR/checkpoints/last.ckpt"
compat_ckpt="$RUN_DIR/ckpt_last.pt"
resume_note="fresh run (no checkpoint found)"
if [[ -f "$lightning_ckpt" ]]; then
  resume_note="auto-resume from $lightning_ckpt (Lightning)"
elif [[ -f "$compat_ckpt" ]]; then
  resume_note="auto-resume from $compat_ckpt (compat)"
fi
printf 'Resume:     %s\n\n' "$resume_note"

# ============================================================
#  Resolve Python interpreter
# ============================================================
# Prefer the env's interpreter directly (no need to source conda in batch
# jobs); fall back to `conda activate` only when it is not executable.
# MINICONDA_ROOT (default ~/miniconda3) and PYTHON_BIN may be overridden.
MINICONDA_ROOT="${MINICONDA_ROOT:-$HOME/miniconda3}"
PYTHON_BIN="${PYTHON_BIN:-$MINICONDA_ROOT/envs/$CONDA_ENV/bin/python}"
if [[ -x "$PYTHON_BIN" ]]; then
  echo "[$(date +%H:%M:%S)] using env python directly: $PYTHON_BIN"
else
  echo "[$(date +%H:%M:%S)] env python not found; falling back to conda activate..."
  conda_sh="$MINICONDA_ROOT/etc/profile.d/conda.sh"
  if [[ ! -f "$conda_sh" ]]; then
    echo "ERROR: conda init script not found: $conda_sh (set MINICONDA_ROOT or PYTHON_BIN)" >&2
    exit 2
  fi
  # conda's shell functions and per-env activate.d hooks may reference unset
  # variables, which would abort this script under `set -u`; relax it only
  # for the activation itself.
  set +u
  # shellcheck disable=SC1090  # path is only known at runtime
  source "$conda_sh"
  echo "[$(date +%H:%M:%S)] activating $CONDA_ENV..."
  conda activate "$CONDA_ENV"
  set -u
  PYTHON_BIN="$(command -v python)"
  echo "[$(date +%H:%M:%S)] env activated. python = $PYTHON_BIN"
fi

# GPUs visible on this (batch) node. nvidia-smi output is captured first and
# only used when it exits 0: the old `nvidia-smi -L | head -1 || echo none`
# pipeline, under pipefail, printed nvidia-smi's output followed by "none"
# whenever nvidia-smi wrote something but exited non-zero.
if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
  gpu_info="$CUDA_VISIBLE_DEVICES"
elif gpu_out="$(nvidia-smi -L 2>/dev/null)" && [[ -n "$gpu_out" ]]; then
  mapfile -t gpu_lines <<<"$gpu_out"
  gpu_info="${gpu_lines[0]} (${#gpu_lines[@]} total)"
else
  gpu_info="none"
fi
echo "[$(date +%H:%M:%S)] GPU(s): $gpu_info"
echo

# ============================================================
#  Launch training
# ============================================================
cd "$WORKDIR"

# Build the trainer command as an array so every argument survives intact.
train_cmd=(
  "$PYTHON_BIN" -u tactile_vae/script/train_vae_pl.py
  --config "$CONFIG"
  --run-id "$RUN_ID"
)

# Under sbatch this script executes once, on the first allocated node only.
# Lightning's SLURM integration takes its world size from SLURM_NTASKS and
# expects one process per task, started by srun. Launching a single process
# leaves the remaining tasks/nodes idle, and rank 0 can hang waiting for
# ranks that never start. Local runs and single-task jobs still launch the
# interpreter directly. --kill-on-bad-exit tears down all ranks if one dies.
# NOTE(review): Lightning also expects Trainer(num_nodes) == --nodes and
# Trainer(devices) == --ntasks-per-node. The #SBATCH header requests
# gpu:h100:4 per node but --ntasks-per-node=2 — confirm this matches the
# config, otherwise half of the allocated GPUs sit idle.
if [[ -n "${SLURM_JOB_ID:-}" ]] && (( ${SLURM_NTASKS:-1} > 1 )); then
  train_cmd=(srun --kill-on-bad-exit=1 "${train_cmd[@]}")
fi

echo "[$(date +%H:%M:%S)] launching trainer (PyTorch Lightning)..."
echo "  command: ${train_cmd[*]}"
"${train_cmd[@]}"

echo
echo "[$(date +%H:%M:%S)] Finished."
echo "Run dir contents:"
# ls fails only when the run dir is missing or unreadable (an empty dir
# still prints "total 0"), so say that rather than "(empty)".
ls -lh "$RUN_DIR" 2>/dev/null || echo "  (missing or unreadable)"