TaoNet-mini-T2 / code /TaoTrain /scripts /remote /run_pre_200m_stability_gate.sh
StarMist0012's picture
Add files using upload-large-folder tool
e2bfccc verified
#!/usr/bin/env bash
#
# Pre-200M stability gate for the stabilized pure-SSM candidate
# (recorded as "pure_ssm_196m_stabilized_m256_h32" in run_plan.json).
#
# Runs, in order, aborting on the first failing command (set -e):
#   1. a pretrain run with loss/eval/grad telemetry
#   2. an activation scale probe on the resulting checkpoint
#   3. sample generation from a few fixed prompts
#   4. a tiny SFT overfit probe
# Results go under OUTPUT_BASE, checkpoints under CHECKPOINT_BASE.
#
# Every setting below may be overridden from the environment; an unset or
# empty variable falls back to the default shown. The two location roots are
# read from differently named variables:
#   REPOBRIDGE_OUTPUT_DIR   -> OUTPUT_BASE
#                              (default: $REMOTE_REPO/results/pre-200m-stability-gate)
#   TAOTERN_CHECKPOINT_DIR  -> CHECKPOINT_BASE (default: $OUTPUT_BASE/checkpoints)
set -euo pipefail

# Inputs: datasets, tokenizer, external SSM repo and Python interpreter.
: "${DATA_PATH:=/home/student/Data/TaoData/pretrain.jsonl}"
: "${SFT_DATA_PATH:=/home/student/Data/TaoData/sft.jsonl}"
: "${TOKENIZER_PATH:=/home/student/YouZheng/tokenizers/taodata_pilot_8k/tokenizer.model}"
: "${SSM_REPO_PATH:=/home/student/YouZheng/gamma_ssm_repo}"
: "${PYTHON_BIN:=/home/student/.venv/bin/python}"

# Locations: repo root defaults to the directory the script is launched from.
: "${REMOTE_REPO:=$(pwd)}"
OUTPUT_BASE="${REPOBRIDGE_OUTPUT_DIR:-$REMOTE_REPO/results/pre-200m-stability-gate}"
CHECKPOINT_BASE="${TAOTERN_CHECKPOINT_DIR:-$OUTPUT_BASE/checkpoints}"

# Pretrain shape, token budget and optimisation settings.
: "${SEQ_LEN:=512}"
: "${BATCH_SIZE:=8}"
: "${TARGET_TOKENS:=20000000}"
: "${MAX_TOKENS:=50000000}"
: "${MAX_RECORDS:=120000}"
: "${EVAL_BATCHES:=64}"
: "${LEARNING_RATE:=0.0008}"
: "${WEIGHT_DECAY:=0.01}"
: "${TRAIN_LOG_EVERY:=250}"

# Tiny SFT overfit probe.
: "${SFT_SANITY_SAMPLES:=4}"
: "${SFT_SANITY_STEPS:=120}"
: "${SFT_SANITY_LR:=0.00005}"
#######################################
# Integer ceiling division: prints ceil(numerator / denominator).
# Arguments:
#   $1 - numerator, a non-negative decimal integer (no leading zeros)
#   $2 - denominator, a positive decimal integer (no leading zeros)
# Outputs:
#   The rounded-up quotient on stdout; a diagnostic on stderr on bad input.
# Returns:
#   0 on success, 1 if either argument is out of range or not an integer.
# Validation replaces bash's own arithmetic failures ("division by 0",
# "value too great for base" for inputs like 2e7, octal reading of 010) with
# an explicit message. Negative numerators are rejected because the
# (n + d - 1) / d trick is wrong for them: bash truncates toward zero.
#######################################
ceil_div() {
  local numerator="${1-}"
  local denominator="${2-}"
  local -r nonneg_int_re='^(0|[1-9][0-9]*)$'
  local -r pos_int_re='^[1-9][0-9]*$'

  if ! [[ $numerator =~ $nonneg_int_re ]]; then
    printf 'ceil_div: numerator must be a non-negative integer, got %q\n' "$numerator" >&2
    return 1
  fi
  if ! [[ $denominator =~ $pos_int_re ]]; then
    printf 'ceil_div: denominator must be a positive integer, got %q\n' "$denominator" >&2
    return 1
  fi

  printf '%s\n' "$(( (numerator + denominator - 1) / denominator ))"
}
# TRAIN_STEPS, unless given explicitly, defaults to
# ceil(TARGET_TOKENS / (BATCH_SIZE * SEQ_LEN)): enough steps of
# BATCH_SIZE x SEQ_LEN tokens to cover the token target, rounded up.
TRAIN_STEPS="${TRAIN_STEPS:-$(ceil_div "$TARGET_TOKENS" "$(( BATCH_SIZE * SEQ_LEN ))")}"

# Make the repo's src/ tree and the external SSM repo importable. This
# overwrites (does not extend) any PYTHONPATH inherited from the caller.
export PYTHONPATH="$REMOTE_REPO/src:$SSM_REPO_PATH"

mkdir -p "$OUTPUT_BASE" "$OUTPUT_BASE"/{configs,diagnostics} "$CHECKPOINT_BASE"

# Machine-readable record of what this gate is about to run.
cat > "$OUTPUT_BASE/run_plan.json" <<JSON
{
"purpose": "pre_200m_stability_gate_before_4b_sft_chatbot_run",
"candidate": "pure_ssm_196m_stabilized_m256_h32",
"target_tokens": $TARGET_TOKENS,
"train_steps": $TRAIN_STEPS,
"batch_size": $BATCH_SIZE,
"seq_len": $SEQ_LEN,
"learning_rate": $LEARNING_RATE,
"weight_decay": $WEIGHT_DECAY,
"checks": [
"bounded pretrain loss/eval/grad telemetry",
"activation scale probe",
"sample generation",
"tiny SFT overfit probe"
]
}
JSON

# Human-readable banner for the console log.
cat <<BANNER

============================================================
Pre-200M stability gate: pure SSM stabilized candidate
target_tokens=$TARGET_TOKENS batch=$BATCH_SIZE seq_len=$SEQ_LEN train_steps=$TRAIN_STEPS eval_batches=$EVAL_BATCHES
============================================================
BANNER
"$PYTHON_BIN" scripts/benchmark_taonet_real_tokens.py \
--data-path "$DATA_PATH" \
--text-field text \
--tokenizer-type sentencepiece \
--tokenizer-path "$TOKENIZER_PATH" \
--max-records "$MAX_RECORDS" \
--max-tokens "$MAX_TOKENS" \
--eval-fraction 0.1 \
--architectures taonet_ssm \
--batch-sizes "$BATCH_SIZE" \
--seq-len "$SEQ_LEN" \
--hidden-dim 1024 \
--num-layers 18 \
--num-heads 8 \
--d-latent-kv 768 \
--d-rope 128 \
--hidden-dim-ff 3072 \
--dropout 0.0 \
--ssm-core dplr \
--ssm-hidden-dims 32 \
--ssm-mixer-dims 256 \
--ssm-num-lanes-list 2 \
--ssm-lane-combine channel \
--ssm-lane-modes split \
--ssm-split-mixes none \
--ssm-rank 1 \
--ssm-kernel-mode conv \
--no-ssm-finite-tail-correction \
--ssm-gate-types channel \
--dtype bf16 \
--device cuda \
--warmup 1 \
--repeats 2 \
--backward \
--train-steps "$TRAIN_STEPS" \
--train-log-every "$TRAIN_LOG_EVERY" \
--learning-rate "$LEARNING_RATE" \
--weight-decay "$WEIGHT_DECAY" \
--max-grad-norm 1.0 \
--eval-batches "$EVAL_BATCHES" \
--ssm-local-shift \
--ssm-local-shift-per-channel \
--ssm-local-shift-init 0.1 \
--ssm-branch-rms-norm \
--ssm-branch-clip-value 1.0 \
--block-residual-rms-norm \
--block-residual-rms-target 1.0 \
--output-dir "$OUTPUT_BASE/pretrain" \
--resume-completed \
--incremental-output \
--save-case-checkpoints \
--checkpoint-dir "$CHECKPOINT_BASE/pretrain"
# Every later stage loads this checkpoint. If the pretrain step did not
# leave it behind, stop here with exit status 2 instead of failing inside a
# diagnostic.
PRETRAIN_CKPT="$CHECKPOINT_BASE/pretrain/latest.pt"
[[ -f "$PRETRAIN_CKPT" ]] || {
  printf 'Expected pretrain checkpoint missing: %s\n' "$PRETRAIN_CKPT" >&2
  exit 2
}
"$PYTHON_BIN" scripts/diagnostics/activation_probe.py \
--checkpoint "$PRETRAIN_CKPT" \
--tokenizer-path "$TOKENIZER_PATH" \
--data-path "$DATA_PATH" \
--text-field text \
--output "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json" \
--batch-size 2 \
--seq-len "$SEQ_LEN" \
--device cuda \
--dtype bfloat16
"$PYTHON_BIN" scripts/diagnostics/generate_checkpoint_samples.py \
--checkpoint "$PRETRAIN_CKPT" \
--tokenizer-path "$TOKENIZER_PATH" \
--output "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json" \
--max-new-tokens 80 \
--temperature 0.8 \
--top-p 0.9 \
--prompt "The purpose of artificial intelligence is" \
--prompt "In a small village," \
--prompt "<user>Hello, who are you?<assistant>"
# Config for the tiny SFT overfit probe. It starts from the pretrain
# checkpoint (checkpoint_path). The model fields mirror the architecture and
# stabilizer flags given to benchmark_taonet_real_tokens.py, presumably so
# that checkpoint loads into an identical model.
#
# Exponent-only floats are written as 1.0e-N rather than 1e-N. PyYAML's
# YAML 1.1 resolver only recognises a float when it contains a '.', so
# "1e-8" would load as the *string* "1e-8". 1.0e-N is a float under both
# YAML 1.1 and YAML 1.2 loaders and has the same value.
#
# NOTE(review): $SFT_SANITY_LR is interpolated verbatim. An override such as
# SFT_SANITY_LR=5e-5 hits the same PyYAML pitfall; prefer 5.0e-5.
# NOTE(review): vocab_size is fixed at 8192, presumably matching the default
# taodata_pilot_8k tokenizer. Overriding TOKENIZER_PATH with a different
# vocabulary would need this changed too.
# NOTE(review): as written, every key starts at column 0. model:, dataset:,
# optimizer: and scheduler: therefore parse as null-valued keys, and their
# fields sit at the top level. Confirm this flat layout is what the SFT config
# loader expects.
SFT_CONFIG="$OUTPUT_BASE/configs/sft_sanity.yaml"
cat > "$SFT_CONFIG" <<YAML
model:
architecture_type: taonet_ssm
vocab_size: 8192
hidden_dim: 1024
num_layers: 18
num_heads: 8
max_seq_length: $SEQ_LEN
d_latent_kv: 768
d_rope: 128
hidden_dim_ff: 3072
dropout: 0.0
gqa_groups: 1
use_factorized_embedding: false
d_embed_rank: 96
init_std: 0.02
ssm_core: dplr
ssm_hidden_dim: 32
ssm_mixer_dim: 256
ssm_num_lanes: 2
ssm_lane_combine: channel
ssm_lane_mode: split
ssm_split_mix: none
ssm_rank: 1
ssm_max_low_rank_scale: 0.1
ssm_finite_tail_correction: false
ssm_discretization: bilinear
ssm_kernel_mode: conv
ssm_kernel_threshold: 64
ssm_dt_min: 1.0e-3
ssm_dt_max: 1.0e-1
ssm_dt_init: 1.0e-2
ssm_use_d: true
ssm_activation: gelu
ssm_gate: true
ssm_input_gate: true
ssm_gate_type: channel
ssm_use_padding_mask: false
ssm_layer_scale_init: 0.1
ssm_branch_rms_norm: true
ssm_branch_clip_value: 1.0
block_residual_rms_norm: true
block_residual_rms_target: 1.0
ssm_local_shift: true
ssm_local_shift_init: 0.1
ssm_local_shift_per_channel: true
dataset:
split: train
instruction_column: input
response_column: output
local: true
jsonl_path: $SFT_DATA_PATH
samples_per_chunk: 2000
tokenizer_type: sentencepiece
tokenizer_path: $TOKENIZER_PATH
tokenizer_threads: 8
checkpoint_path: $PRETRAIN_CKPT
user_token: "<user>"
assistant_token: "<assistant>"
response_loss_only: true
batch_size: $BATCH_SIZE
num_epochs: 100000
max_steps: $SFT_SANITY_STEPS
gradient_accumulation_steps: 1
max_grad_norm: 1.0
optimizer:
optimizer_type: adamw
learning_rate: $SFT_SANITY_LR
weight_decay: 0.0
betas: [0.9, 0.999]
eps: 1.0e-8
scheduler:
scheduler_type: linearWarmup
warmup_steps: 0
dtype: bfloat16
device: cuda
checkpoint_dir: $CHECKPOINT_BASE/sft_sanity
save_every_steps: 5000
save_best_model: false
keep_last_n_checkpoints: 1
eval_every_steps: 5000
eval_samples: 32
log_every_steps: 10
aim_repo: $OUTPUT_BASE/.aim-sft-sanity
seed: 43
num_workers: 0
pin_memory: true
YAML
"$PYTHON_BIN" scripts/diagnostics/sft_sanity_check.py \
--config "$SFT_CONFIG" \
--checkpoint "$PRETRAIN_CKPT" \
--output "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json" \
--samples "$SFT_SANITY_SAMPLES" \
--steps "$SFT_SANITY_STEPS" \
--lr "$SFT_SANITY_LR" \
--log-every 20 \
--device cuda \
--dtype bfloat16 \
--ssm-branch-rms-norm \
--ssm-branch-clip-value 1.0 \
--block-residual-rms-norm \
--block-residual-rms-target 1.0
# Only reached when every earlier command succeeded, since set -e aborts the
# script otherwise. Write an index of the artifact paths this gate uses, then
# echo it to the console log.
cat > "$OUTPUT_BASE/gate_summary.json" <<JSON
{
"state": "completed",
"pretrain_checkpoint": "$PRETRAIN_CKPT",
"pretrain_results": "$OUTPUT_BASE/pretrain/taonet_real_token_benchmark.json",
"activation_probe": "$OUTPUT_BASE/diagnostics/activation_probe_pretrain_latest.json",
"generation_samples": "$OUTPUT_BASE/diagnostics/generation_samples_pretrain_latest.json",
"sft_sanity": "$OUTPUT_BASE/diagnostics/sft_sanity_pretrain_latest.json"
}
JSON
printf '%s\n' 'Pre-200M stability gate completed.'
cat < "$OUTPUT_BASE/gate_summary.json"