File size: 3,902 Bytes
b7c075a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/bin/bash
#
# Slurm batch job for the spatial_object task: builds the RLDS dataset when it
# is not present yet, then fine-tunes OpenVLA-OFT, resuming from the newest
# checkpoint found in the run directory.
#
# Resources requested below: 1 node, 4 GPUs on that node, one task with
# 56 CPUs, and a 24-hour wall-clock limit. stdout and stderr go to separate
# per-job files (%j in the file name is replaced by the Slurm job id).
# NOTE(review): the meaning of --comment=fact_off is site-specific — confirm.
#
# Keep the scheduler directives free of trailing comments; explanatory text
# belongs on its own comment line.
#SBATCH --account=nvr_lpr_rvp
#SBATCH --qos=normal
#SBATCH --partition=batch_long
#SBATCH --nodes=1
#SBATCH --gpus-per-node=4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=56
#SBATCH --time=24:00:00
#SBATCH --job-name=openvla_oft_spatial_object
#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_spatial_object_%j.out
#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_spatial_object_%j.err
#SBATCH --comment=fact_off

# Strict mode: stop on the first failing command (-e), on any use of an unset
# variable (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# ── Configuration ────────────────────────────────────────────────────────
# All inputs, outputs and tools live under one workspace directory.
WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay
OPENVLA_DIR=$WORKSPACE/yu/openvla-oft
CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft
# The environment's executables are referenced by absolute path.
PYTHON=$CONDA_ENV/bin/python
ACCELERATE=$CONDA_ENV/bin/accelerate

# DATA_ROOT   - source data handed to the RLDS builder
# RLDS_OUTPUT - where the builder writes, and where training reads from
# RUN_DIR     - where training runs are rooted and checkpoints are looked up
# MAX_STEPS   - training stops (or is skipped entirely) at this step count
DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/spatial_object/300/huggingface_data/spatial_object/conflict
RLDS_OUTPUT=$WORKSPACE/yu/rlds_spatial_object
RUN_DIR=$OPENVLA_DIR/runs/spatial_object
MAX_STEPS=40000

export HF_HOME=$WORKSPACE/hugging_face
# Forward a token from the submitting environment when there is one; the
# empty default keeps an unset variable from being an error in strict mode.
export HF_TOKEN="${HF_TOKEN:-}"
export TOKENIZERS_PARALLELISM=false
export WANDB_MODE=disabled
# Prepend the repo. The separator is added only when PYTHONPATH already has a
# value: a dangling ':' would put an empty entry (the current directory) on
# Python's module search path.
export PYTHONPATH="$OPENVLA_DIR${PYTHONPATH:+:$PYTHONPATH}"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# NOTE(review): Slurm presumably opens the --output/--error files before this
# script starts, so creating the logs directory here would only help later
# submissions — confirm it exists before the first one.
mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs"

# Script paths used further down are relative to the repo root.
cd "$OPENVLA_DIR"

# ── Step 1: Build RLDS TFRecords (once) ─────────────────────────────────
# The builder runs only when its output directory is missing.
# NOTE(review): only existence is checked, so a build that died half-way
# would presumably be skipped on the next submission — confirm the builder
# does not leave a partial directory behind.
if [[ -d "$RLDS_OUTPUT/conflict_maniskill" ]]; then
    echo "RLDS dataset already exists, skipping build."
else
    echo "============================================================"
    echo " Building RLDS dataset from parquet for spatial_object..."
    echo "============================================================"
    # Quoted so the interpreter path is always passed as a single word.
    "$PYTHON" prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \
        --data_root "$DATA_ROOT" \
        --output_dir "$RLDS_OUTPUT"
fi

# ── Step 2: Detect latest checkpoint ─────────────────────────────────────
#######################################
# Find the checkpoint directory with the highest step number.
# A checkpoint is any sub-directory whose name ends in "<digits>_chkpt".
# Globals:
#   LATEST_CKPT (written) - path of the newest checkpoint, "" when none
#   LATEST_STEP (written) - its step exactly as spelled in the name, 0 when none
# Arguments:
#   $1 - directory to scan (not recursive)
#######################################
find_latest_checkpoint() {
    local scan_dir=$1
    local candidate name
    # Anchored at the end of the name: a name may contain more than one
    # "<digits>_chkpt" fragment, and only the final one is this checkpoint's
    # own step.
    local -r step_re='([0-9]+)_chkpt$'

    LATEST_CKPT=""
    LATEST_STEP=0

    for candidate in "$scan_dir"/*_chkpt; do
        # Skips plain files, and the literal pattern left when nothing matches.
        [[ -d "$candidate" ]] || continue
        name=${candidate##*/}
        # A name without digits before the suffix is ignored rather than
        # failing the whole job under `set -e`.
        [[ "$name" =~ $step_re ]] || continue
        # 10# forces base 10 so a zero-padded step is not parsed as octal.
        if (( 10#${BASH_REMATCH[1]} > 10#$LATEST_STEP )); then
            LATEST_STEP=${BASH_REMATCH[1]}
            LATEST_CKPT=$candidate
        fi
    done
}

find_latest_checkpoint "$RUN_DIR"

# ── Step 3: Fine-tune (fresh or resumed) ─────────────────────────────────
# Nothing left to train once the newest checkpoint is at or past MAX_STEPS.
if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then
    echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do."
    exit 0
fi

#######################################
# Announce the starting point and build the matching model arguments.
# Globals:
#   RESUME_ARGS (written) - array of arguments for finetune.py; an array so a
#                           checkpoint path always stays a single word
# Arguments:
#   $1 - checkpoint directory to resume from, "" for a fresh start
#   $2 - step of that checkpoint (unused for a fresh start)
# Outputs:
#   banner on stdout
#######################################
prepare_resume_args() {
    local ckpt=$1 step=$2

    echo "============================================================"
    if [[ -n "$ckpt" ]]; then
        echo " Resuming from step $step: $ckpt"
        RESUME_ARGS=(--resume true --resume_step "$step" --vla_path "$ckpt")
    else
        echo " Starting fresh fine-tune from openvla/openvla-7b"
        RESUME_ARGS=(--vla_path openvla/openvla-7b)
    fi
    echo "============================================================"
}

prepare_resume_args "$LATEST_CKPT" "$LATEST_STEP"

# --num_processes matches the 4 GPUs requested in the Slurm header.
# NOTE(review): the wandb entity/project values look like placeholders, given
# that WANDB_MODE=disabled is exported — confirm.
"$ACCELERATE" launch \
    --mixed_precision bf16 \
    --num_processes 4 \
    --num_machines 1 \
    vla-scripts/finetune.py \
    "${RESUME_ARGS[@]}" \
    --data_root_dir "$RLDS_OUTPUT" \
    --dataset_name conflict_maniskill \
    --run_root_dir "$RUN_DIR" \
    --use_l1_regression true \
    --use_film false \
    --num_images_in_input 2 \
    --use_proprio true \
    --batch_size 2 \
    --grad_accumulation_steps 4 \
    --learning_rate 5e-4 \
    --max_steps "$MAX_STEPS" \
    --save_freq 5000 \
    --save_latest_checkpoint_only false \
    --image_aug true \
    --use_lora true \
    --lora_rank 32 \
    --merge_lora_during_training true \
    --wandb_entity disabled \
    --wandb_project disabled