File size: 2,729 Bytes
e2bfccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env bash
set -euo pipefail

# Output locations. REPOBRIDGE_OUTPUT_DIR and TAOTERN_CHECKPOINT_DIR override
# the defaults; checkpoints default to a subdirectory of the output root.
ROOT_OUTPUT_DIR="${REPOBRIDGE_OUTPUT_DIR:-$(pwd)/results/200m-until-selection}"
ROOT_CHECKPOINT_DIR="${TAOTERN_CHECKPOINT_DIR:-${ROOT_OUTPUT_DIR}/checkpoints}"

# Batch geometry shared by both phases; each value may be preset in the
# environment and is only filled in here when unset or empty.
: "${SEQ_LEN:=512}"
: "${BATCH_SIZE:=8}"

# Per-phase token targets and evaluation batch counts.
: "${PILOT_TOKENS:=300000000}"
: "${SERIOUS_TOKENS:=1000000000}"
: "${PILOT_EVAL_BATCHES:=64}"
: "${SERIOUS_EVAL_BATCHES:=128}"

# Optimizer settings forwarded to the suite script by run_phase.
: "${LEARNING_RATE:=0.0003}"
: "${WEIGHT_DECAY:=0.01}"

#######################################
# Integer ceiling division: prints ceil(numerator / denominator).
# Arguments:
#   $1 - numerator, a non-negative decimal integer
#   $2 - denominator, a positive decimal integer
# Outputs:
#   Writes the rounded-up quotient to stdout; diagnostics go to stderr.
# Returns:
#   0 on success; 1 if an argument is missing, is not a plain decimal
#   integer, or the denominator is zero.
#######################################
ceil_div() {
  local numerator="${1-}"
  local denominator="${2-}"

  # Only plain digit strings are accepted: anything else would be evaluated
  # as an arithmetic expression by $(( )) below.
  if [[ ! "$numerator" =~ ^[0-9]+$ || ! "$denominator" =~ ^[0-9]+$ ]]; then
    printf 'ceil_div: expected two non-negative integers, got "%s" and "%s"\n' \
      "$numerator" "$denominator" >&2
    return 1
  fi

  # Force base 10 so zero-padded input such as "08" is not parsed as octal.
  numerator=$((10#$numerator))
  denominator=$((10#$denominator))

  if ((denominator == 0)); then
    printf 'ceil_div: denominator must be greater than zero\n' >&2
    return 1
  fi

  echo $(( (numerator + denominator - 1) / denominator ))
}

#######################################
# Runs one phase of the 200M until-selection schedule: prints a banner with
# the phase parameters, then launches scripts/remote/run_200m_base_suite.sh
# with phase-specific settings passed through its environment.
# Globals (read):
#   BATCH_SIZE, SEQ_LEN, LEARNING_RATE, WEIGHT_DECAY,
#   ROOT_OUTPUT_DIR, ROOT_CHECKPOINT_DIR
# Arguments:
#   $1 - phase name; also used as the output/checkpoint subdirectory name
#   $2 - target token count for the phase
#   $3 - number of evaluation batches
# Outputs:
#   Banner on stdout, followed by whatever the suite script prints.
# Returns:
#   Exit status of the suite script.
#######################################
run_phase() {
  local phase_name="$1" token_budget="$2" eval_batch_count="$3"
  local rule='============================================================'
  local step_tokens=$((BATCH_SIZE * SEQ_LEN))

  # Declared separately from the assignment so a ceil_div failure is not
  # masked by the exit status of `local`.
  local step_count
  step_count="$(ceil_div "$token_budget" "$step_tokens")"

  printf '\n%s\n200M until-selection phase: %s\n' "$rule" "$phase_name"
  printf 'target_tokens=%s batch=%s seq_len=%s train_steps=%s eval_batches=%s\n%s\n' \
    "$token_budget" "$BATCH_SIZE" "$SEQ_LEN" "$step_count" "$eval_batch_count" "$rule"

  # The suite script receives the batch size under the plural name BATCH_SIZES.
  env \
    REPOBRIDGE_OUTPUT_DIR="${ROOT_OUTPUT_DIR}/${phase_name}" \
    TAOTERN_CHECKPOINT_DIR="${ROOT_CHECKPOINT_DIR}/${phase_name}" \
    BATCH_SIZES="${BATCH_SIZE}" \
    SEQ_LEN="${SEQ_LEN}" \
    TRAIN_STEPS="${step_count}" \
    EVAL_BATCHES="${eval_batch_count}" \
    LEARNING_RATE="${LEARNING_RATE}" \
    WEIGHT_DECAY="${WEIGHT_DECAY}" \
    bash scripts/remote/run_200m_base_suite.sh
}

# Main flow: create the output directories, record the planned schedule in
# run_plan.json, then run the pilot and serious phases back to back.
mkdir -p "$ROOT_OUTPUT_DIR" "$ROOT_CHECKPOINT_DIR"

# Compute the per-phase step counts before writing the plan. A failing $(...)
# inside a here-document does not trip `set -e`, so doing this math inline
# could silently produce a run_plan.json with an empty "train_steps" value.
plan_tokens_per_step=$((BATCH_SIZE * SEQ_LEN))
if ((plan_tokens_per_step <= 0)); then
  printf 'BATCH_SIZE * SEQ_LEN must be positive (BATCH_SIZE=%s SEQ_LEN=%s)\n' \
    "$BATCH_SIZE" "$SEQ_LEN" >&2
  exit 1
fi
pilot_train_steps="$(ceil_div "$PILOT_TOKENS" "$plan_tokens_per_step")"
serious_train_steps="$(ceil_div "$SERIOUS_TOKENS" "$plan_tokens_per_step")"

# Numeric settings are interpolated unquoted so they appear as JSON numbers.
cat > "$ROOT_OUTPUT_DIR/run_plan.json" <<JSON

{

  "stopping_point": "selection_after_1b_all_four_variants",

  "batch_size": $BATCH_SIZE,

  "seq_len": $SEQ_LEN,

  "learning_rate": $LEARNING_RATE,

  "weight_decay": $WEIGHT_DECAY,

  "phases": [

    {

      "name": "pilot_300m",

      "target_tokens_per_variant": $PILOT_TOKENS,

      "train_steps": $pilot_train_steps,

      "eval_batches": $PILOT_EVAL_BATCHES

    },

    {

      "name": "serious_1b",

      "target_tokens_per_variant": $SERIOUS_TOKENS,

      "train_steps": $serious_train_steps,

      "eval_batches": $SERIOUS_EVAL_BATCHES

    }

  ],

  "variants": [

    "attention_196m",

    "pure_ssm_196m_hadamard",

    "pure_ssm_196m_nomix",

    "hybrid_ssm_first_199m"

  ]

}

JSON

# Under `set -e` a failing phase aborts the script, so the final message is
# printed only when both phases exit successfully.
run_phase "pilot_300m" "$PILOT_TOKENS" "$PILOT_EVAL_BATCHES"
run_phase "serious_1b" "$SERIOUS_TOKENS" "$SERIOUS_EVAL_BATCHES"

echo "Selection gate reached after pilot_300m and serious_1b completed for all four variants."