File size: 7,151 Bytes
3770c94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#!/usr/bin/env bash
#SBATCH --job-name=tactile_vae
#SBATCH --partition=ct
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:h100:1
#SBATCH --requeue
#SBATCH --output=/group2/ct/weihanx/tactile_world_model/slurm-logs/tactile_vae.%j.log
#SBATCH --error=/group2/ct/weihanx/tactile_world_model/slurm-logs/tactile_vae.%j.log

# Train tactile_vae.model.TactileVAE on the fota_unlabeled parquet dataset.
#
# Each run lives at <RUNS_DIR>/<RUN_ID>/. Re-launching with the same RUN_ID
# auto-resumes from ckpt_last.pt; wandb keeps the same run id, so metrics
# append to the same dashboard.
#
# Usage (sbatch):  sbatch tactile_vae/script/train_vae.sh <run_id> [config.yaml]
# Usage (local):   ./tactile_vae/script/train_vae.sh        <run_id> [config.yaml]
#
# Diagnostics: set DEBUG=1 to enable `set -x` command tracing.
#   sbatch --export=ALL,DEBUG=1 tactile_vae/script/train_vae.sh <run_id>

# Send everything the script prints through a `cat` started under
# `stdbuf -oL -eL` (line-buffered), with stderr folded into stdout. The goal is
# a slurm log that shows progress as it happens instead of one big flush at the
# end; on NFS-backed log files the log can otherwise look empty for minutes.
exec > >(stdbuf -oL -eL cat) 2>&1

# Strict mode: abort on errors, on unset variables, and on failures anywhere
# in a pipeline.
set -o errexit -o nounset -o pipefail

# Opt-in command tracing for diagnostics (DEBUG=1).
if [[ "${DEBUG:-0}" == "1" ]]; then
  set -x
fi

# ============================================================
#  Inputs (positional)
# ============================================================
#   $1  run_id : required. Names the per-run output subdirectory and is also
#                exported as the wandb run id.
#   $2  config : optional. Defaults to tactile_vae/config/train_vae.yaml.
# Exits with status 2 on missing or invalid arguments.
if [[ $# -lt 1 ]]; then
  echo "Usage: $0 <run_id> [config.yaml]" >&2
  echo "  run_id   : required. Both the output subdir name and the wandb run id." >&2
  echo "  config   : optional. Defaults to tactile_vae/config/train_vae.yaml." >&2
  exit 2
fi
RUN_ID="$1"
CONFIG="${2:-tactile_vae/config/train_vae.yaml}"

# RUN_ID is spliced into a filesystem path as a single component and reused as
# the wandb run id. Reject values that are empty, that would resolve outside
# (or nest below) the runs directory, or that contain characters wandb does
# not accept in a run id ( / \ # ? % : ).
case "$RUN_ID" in
  "" | . | .. | *[/\\#?%:]*)
    echo "ERROR: invalid run_id '$RUN_ID'" >&2
    echo "  run_id must be non-empty, not '.' or '..', and must not contain / \\ # ? % :" >&2
    exit 2
    ;;
esac

# ============================================================
#  Paths
# ============================================================
# Project root on shared storage; every other path below hangs off it.
WORKDIR="/group2/ct/weihanx/tactile_world_model"

# Training inputs: the dataset directory and the split manifest.
DATA_DIR="$WORKDIR/tactile_vae/data"
SPLITS_PATH="$WORKDIR/tactile_vae/dataset/splits.json"

# Per-run outputs live at <RUNS_DIR>/<RUN_ID>/.
# Keep this aligned with train_vae.yaml default `runs_root: runs`.
# NOTE(review): the trainer is started from $WORKDIR, where a relative
# `runs_root: runs` would be $WORKDIR/runs rather than the path below —
# confirm how the trainer resolves runs_root.
RUNS_DIR="$WORKDIR/tactile_world_model/runs"
RUN_DIR="$RUNS_DIR/$RUN_ID"

# Conda env that has the required deps installed:
#   torch torchvision timm numpy pyarrow PIL pyyaml wandb
# `twm` is the project's standard env (matches tactile_jepa training). To use
# another env, set CONDA_ENV in the environment
# (e.g. CONDA_ENV=samaudio311 sbatch ...).
: "${CONDA_ENV:=twm}"

# Create the log and runs directories first, then tighten the umask so files
# created from here on are not group-writable and not accessible to others.
mkdir -p "$WORKDIR/slurm-logs"
mkdir -p "$RUNS_DIR"
umask 027

# ============================================================
#  Startup banner — emitted FIRST, ahead of anything slow
#  (conda activate / python imports), so the slurm log has
#  content almost immediately after the job starts.
# ============================================================
printf '%s\n' "=== Tactile VAE training ==="
printf 'Host:       %s\n' "$(hostname)"
printf 'Job ID:     %s\n' "${SLURM_JOB_ID:-N/A}"   # N/A when not run under slurm
printf 'Start time: %s\n' "$(date)"
printf 'Run ID:     %s\n' "$RUN_ID"
printf 'Workdir:    %s\n' "$WORKDIR"
printf 'Config:     %s\n' "$CONFIG"
printf 'Run dir:    %s\n' "$RUN_DIR"
printf 'Conda env:  %s\n' "$CONDA_ENV"
printf '\n'

# ============================================================
#  Environment knobs
# ============================================================
# Cap OpenMP / MKL worker threads at 8 each.
export OMP_NUM_THREADS=8 MKL_NUM_THREADS=8
# Turn off tokenizers' internal parallelism.
export TOKENIZERS_PARALLELISM="false"
# Python runtime switches: PYTHONFAULTHANDLER enables the faulthandler module;
# PYTHONUNBUFFERED makes `print()` output reach the log line by line.
export PYTHONFAULTHANDLER=1 PYTHONUNBUFFERED=1

# ============================================================
#  Weights & Biases (mirrors jepa_training.sh — same account)
# ============================================================
# Credentials are deliberately NOT stored in this script. Provide the key via
# the environment (export WANDB_API_KEY before launching), or run `wandb login`
# once so the client can use its stored credentials.
# Any key that was ever committed to this file must be treated as leaked and
# rotated in the W&B account settings.
#
# Environment read here:
#   WANDB_API_KEY   optional; passed through to the trainer when non-empty.
#   WANDB_RESUME    optional; defaults to "allow" so a re-launch with the same
#                   RUN_ID continues the existing wandb run.
#   WANDB_DIR / WANDB_CACHE_DIR / WANDB_DATA_DIR   optional overrides.
#   DISABLE_WANDB   set to 1 to switch wandb off entirely.

# Pause command tracing (DEBUG=1) while the key is examined so its value can
# never end up in the log; tracing is restored right after.
wandb_xtrace_was_on=0
if [[ $- == *x* ]]; then
  wandb_xtrace_was_on=1
fi
{ set +x; } 2>/dev/null
if [[ -n "${WANDB_API_KEY:-}" ]]; then
  export WANDB_API_KEY
  wandb_key_status="<set via environment>"
else
  # Do not export an empty key; leave authentication to stored credentials.
  unset WANDB_API_KEY
  wandb_key_status="<not set; relying on stored 'wandb login' credentials>"
fi
if [[ "$wandb_xtrace_was_on" == "1" ]]; then
  set -x
fi

export WANDB_ENTITY="weihanx-university-of-michigan"
export WANDB_USERNAME="weihanx@umich.edu"
export WANDB_PROJECT="tactile_vae"
export WANDB_MODE="online"
# Run id and display name both follow RUN_ID.
export WANDB_RUN_ID="$RUN_ID"
export WANDB_NAME="$RUN_ID"
export WANDB_RESUME="${WANDB_RESUME:-allow}"
export WANDB_SERVICE_WAIT=300
export WANDB_INIT_TIMEOUT=300
export WANDB_START_METHOD="thread"
export WANDB_CONSOLE="wrap"

# Keep wandb metadata/cache off network storage to speed init/resume.
# USER is not guaranteed to be set in every batch environment, and the script
# runs with `set -u`, so fall back to `id -un`.
wandb_tmp_root="/tmp/${USER:-$(id -un)}"
export WANDB_DIR="${WANDB_DIR:-$wandb_tmp_root/wandb/$RUN_ID}"
export WANDB_CACHE_DIR="${WANDB_CACHE_DIR:-$wandb_tmp_root/wandb-cache}"
export WANDB_DATA_DIR="${WANDB_DATA_DIR:-$wandb_tmp_root/wandb-data}"
mkdir -p "$WANDB_DIR" "$WANDB_CACHE_DIR" "$WANDB_DATA_DIR"

# Debug knob: disable wandb entirely to isolate startup stalls.
# Default is enabled; set DISABLE_WANDB=1 to disable.
DISABLE_WANDB="${DISABLE_WANDB:-0}"
if [[ "$DISABLE_WANDB" == "1" ]]; then
  unset WANDB_PROJECT WANDB_ENTITY WANDB_API_KEY WANDB_USERNAME
  export WANDB_MODE="disabled"
  wandb_key_status="<disabled>"
fi

# Summarize the effective settings. Only the key's status is printed — never
# any part of the key itself.
echo "--- Wandb ---"
echo "  project=${WANDB_PROJECT:-<disabled>}  entity=${WANDB_ENTITY:-<disabled>}"
echo "  run_id=$WANDB_RUN_ID  name=$WANDB_NAME  mode=$WANDB_MODE  resume=$WANDB_RESUME"
echo "  dir=$WANDB_DIR"
echo "  cache_dir=$WANDB_CACHE_DIR"
echo "  data_dir=$WANDB_DATA_DIR"
echo "  api_key=$wandb_key_status"
echo

# ============================================================
#  Sanity checks (cheap; before conda activate)
# ============================================================
# Each failed check exits with status 2.
#
# Config lookup. The trainer is started after `cd "$WORKDIR"`, so a relative
# CONFIG is interpreted relative to WORKDIR at that point. Lookup order:
#   1. relative CONFIG that exists under $WORKDIR  -> passed through unchanged;
#   2. CONFIG that exists as given (absolute, or relative to the directory the
#      script was launched from) -> a relative path is rewritten to an
#      absolute one so it still resolves after the cd;
#   3. otherwise -> error.
if [[ "$CONFIG" != /* && -f "$WORKDIR/$CONFIG" ]]; then
  :  # found under WORKDIR; the relative path stays valid after the cd
elif [[ -f "$CONFIG" ]]; then
  if [[ "$CONFIG" != /* ]]; then
    CONFIG="$PWD/$CONFIG"
    echo "Config:     resolved to $CONFIG"
  fi
else
  echo "ERROR: config not found: $CONFIG (or $WORKDIR/$CONFIG)" >&2
  exit 2
fi
if [[ ! -d "$DATA_DIR" ]]; then
  echo "ERROR: data dir does not exist: $DATA_DIR" >&2
  exit 2
fi
if [[ ! -f "$SPLITS_PATH" ]]; then
  echo "ERROR: splits manifest not found: $SPLITS_PATH" >&2
  echo "  Generate it with: python tactile_vae/dataset/make_splits.py" >&2
  exit 2
fi

# Informational only: report whether a previous checkpoint exists for this run.
if [[ -f "$RUN_DIR/ckpt_last.pt" ]]; then
  echo "Resume:     auto-resume from $RUN_DIR/ckpt_last.pt"
else
  echo "Resume:     fresh run (no $RUN_DIR/ckpt_last.pt)"
fi
echo

# ============================================================
#  Resolve Python interpreter
# ============================================================
# Fast path: call env python directly to avoid expensive `conda activate`
# startup on busy shared filesystems. Fallback to full activation if needed.
# PYTHON_BIN may be preset in the environment to force a specific interpreter.
PYTHON_BIN="${PYTHON_BIN:-$HOME/miniconda3/envs/$CONDA_ENV/bin/python}"
if [[ -x "$PYTHON_BIN" ]]; then
  echo "[$(date +%H:%M:%S)] using env python directly: $PYTHON_BIN"
else
  echo "[$(date +%H:%M:%S)] env python not found; falling back to conda activate..."
  conda_profile="$HOME/miniconda3/etc/profile.d/conda.sh"
  if [[ ! -f "$conda_profile" ]]; then
    echo "ERROR: neither $PYTHON_BIN nor $conda_profile exists." >&2
    echo "  Set PYTHON_BIN to a usable interpreter or install the '$CONDA_ENV' env." >&2
    exit 2
  fi
  # conda's shell hooks and env activate.d scripts may read unset variables,
  # which aborts the job under `set -u`. Relax nounset only for the
  # source/activate steps and restore it afterwards.
  nounset_was_on=0
  if [[ $- == *u* ]]; then
    nounset_was_on=1
  fi
  set +u
  # shellcheck source=/dev/null
  source "$conda_profile"
  echo "[$(date +%H:%M:%S)] activating $CONDA_ENV..."
  conda activate "$CONDA_ENV"
  if [[ "$nounset_was_on" == "1" ]]; then
    set -u
  fi
  # `command -v` is the shell builtin lookup; no dependency on `which`.
  PYTHON_BIN="$(command -v python)"
  echo "[$(date +%H:%M:%S)] env activated. python = $PYTHON_BIN"
fi
# Report CUDA_VISIBLE_DEVICES when set; otherwise the first device listed by
# nvidia-smi, or "none" if that lookup fails.
echo "[$(date +%H:%M:%S)] GPU(s): ${CUDA_VISIBLE_DEVICES:-$(nvidia-smi -L 2>/dev/null | head -1 || echo none)}"
echo

# ============================================================
#  Launch training
# ============================================================
# The trainer script path below is relative, so run from the project root.
cd "$WORKDIR"

# Interpreter gets `-u` explicitly (PYTHONUNBUFFERED, exported earlier, forces
# the same thing).
trainer_cmd=(
  "$PYTHON_BIN" -u tactile_vae/script/train_vae.py
  --config "$CONFIG"
  --run-id "$RUN_ID"
)
echo "[$(date +%H:%M:%S)] launching trainer..."
"${trainer_cmd[@]}"

# Only reached when the trainer exits successfully (errexit is active).
echo
echo "[$(date +%H:%M:%S)] Finished."
echo "Run dir contents:"
if ! ls -lh "$RUN_DIR" 2>/dev/null; then
  # Printed whenever the listing fails, e.g. the run dir does not exist.
  echo "  (empty)"
fi