#!/bin/bash
#
# BioRLHF Training Script - Ecosystem Improved Model
# ====================================================
#
# This script fine-tunes a model on a combined dataset that includes:
# - Original KMP study data (363 examples)
# - Ecosystem failure-based examples (15 examples)
# - Calibration training
# - Adversarial resistance
# - Protocol completeness
# - Fact drilling
#
# Requirements:
# - CUDA-capable GPU (recommended: A100, V100, or 4090)
# - 24GB+ VRAM for Mistral-7B with 4-bit quantization
# - Python environment with: torch, transformers, peft, trl, bitsandbytes
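#
# A minimal environment setup might look like this (a sketch; the Python
# version and the extra accelerate dependency are assumptions, not pinned
# by this repo):
#   conda create -n biorlhf python=3.10 -y
#   conda activate biorlhf
#   pip install torch transformers peft trl bitsandbytes accelerate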
#
# Usage:
# ./scripts/train_ecosystem_improved.sh
#
# Or on HPC with SLURM:
# sbatch scripts/train_ecosystem_improved.sh
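#
# If invoking directly, the script may first need the executable bit:
#   chmod +x scripts/train_ecosystem_improved.sh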
#
# ==============================================================================
# SLURM Configuration (#SBATCH lines are ordinary comments to bash and only
# take effect when the script is submitted via sbatch)
# ==============================================================================
#SBATCH --job-name=biorlhf_ecosystem
#SBATCH --output=logs/biorlhf_ecosystem_%j.out
#SBATCH --error=logs/biorlhf_ecosystem_%j.err
#SBATCH --time=4:00:00
#SBATCH --gres=gpu:1
#SBATCH --mem=48G
#SBATCH --cpus-per-task=8
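#
# Note: SLURM does not create the logs/ directory named in --output/--error;
# create it before submitting (e.g. `mkdir -p logs`) or the job cannot write
# its output files.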
# ==============================================================================
# Environment Setup
# ==============================================================================
echo "============================================================"
echo "BioRLHF Ecosystem Training"
echo "============================================================"
echo "Start time: $(date)"
echo "Host: $(hostname)"
echo ""
# Activate conda environment (adjust path as needed)
# source /path/to/conda/etc/profile.d/conda.sh
# conda activate biorlhf
# Set working directory (under sbatch, $0 points at a spooled copy of the
# script, so prefer SLURM's submit directory when it is available)
cd "${SLURM_SUBMIT_DIR:-$(dirname "$0")/..}" || exit 1
echo "Working directory: $(pwd)"
# Check GPU
echo ""
echo "GPU Information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv 2>/dev/null || echo "No GPU detected"
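# Optional deeper check (a sketch; assumes torch is installed in this environment):
# python3 -c "import torch; print('CUDA available:', torch.cuda.is_available())"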
echo ""
# ==============================================================================
# Training Configuration
# ==============================================================================
# Model settings
MODEL="mistralai/Mistral-7B-v0.3"
DATASET="data/combined_training.json"
OUTPUT_DIR="./ecosystem_improved_model"
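# If the base model is gated on the Hugging Face Hub, authenticate first
# (e.g. `huggingface-cli login`) or the download will fail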
# Training hyperparameters (optimized based on prior BioRLHF experiments)
EPOCHS=10 # More epochs for better fact memorization
BATCH_SIZE=4 # Adjust based on GPU memory
GRAD_ACCUM=4 # Effective batch size = 16
LEARNING_RATE=2e-4 # Standard for LoRA fine-tuning
MAX_LENGTH=1024 # Sufficient for most examples
# LoRA configuration (higher rank for domain knowledge)
LORA_R=64 # Higher rank for better capacity
LORA_ALPHA=128 # Alpha = 2 * r
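# For reference, these roughly correspond to a peft LoraConfig inside the
# trainer (a sketch; the actual wiring lives in sft_train_v2.py):
#   LoraConfig(r=64, lora_alpha=128, task_type="CAUSAL_LM")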
# Logging
WANDB_PROJECT="biorlhf"
WANDB_RUN="ecosystem_improved_$(date +%Y%m%d_%H%M%S)"
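# To train without a W&B account, wandb's standard env var keeps logs local:
# export WANDB_MODE=offline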
# ==============================================================================
# Pre-training Checks
# ==============================================================================
echo "============================================================"
echo "Configuration:"
echo "============================================================"
echo "Model: $MODEL"
echo "Dataset: $DATASET"
echo "Output: $OUTPUT_DIR"
echo "Epochs: $EPOCHS"
echo "Batch size: $BATCH_SIZE (effective: $((BATCH_SIZE * GRAD_ACCUM)))"
echo "LoRA r/α: $LORA_R / $LORA_ALPHA"
echo "Max length: $MAX_LENGTH"
echo ""
# Check if dataset exists
if [ ! -f "$DATASET" ]; then
    echo "ERROR: Dataset not found at $DATASET"
    echo "Run: python scripts/merge_training_data.py"
    exit 1
fi
# Count examples
EXAMPLE_COUNT=$(python3 -c "import json; print(len(json.load(open('$DATASET'))))")
echo "Dataset contains $EXAMPLE_COUNT examples"
echo ""
# ==============================================================================
# Run Training
# ==============================================================================
echo "============================================================"
echo "Starting Training..."
echo "============================================================"
python3 sft_train_v2.py \
    --model "$MODEL" \
    --dataset "$DATASET" \
    --output_dir "$OUTPUT_DIR" \
    --epochs "$EPOCHS" \
    --batch_size "$BATCH_SIZE" \
    --grad_accum "$GRAD_ACCUM" \
    --lr "$LEARNING_RATE" \
    --max_length "$MAX_LENGTH" \
    --lora_r "$LORA_R" \
    --lora_alpha "$LORA_ALPHA" \
    --use_4bit \
    --wandb_project "$WANDB_PROJECT" \
    --wandb_run "$WANDB_RUN"
# Capture exit status immediately so later commands don't clobber it
TRAIN_EXIT=$?
if [ "$TRAIN_EXIT" -eq 0 ]; then
echo ""
echo "============================================================"
echo "✅ Training Complete!"
echo "============================================================"
echo "Model saved to: $OUTPUT_DIR"
echo "End time: $(date)"
echo ""
echo "Next steps:"
echo "1. Evaluate on SpaceOmicsBench: python evaluate_model.py --model $OUTPUT_DIR"
echo "2. Evaluate on CAMELOT: python evaluate_model.py --model $OUTPUT_DIR --benchmark camelot"
echo "3. Compare with baseline: python compare_models.py"
else
    echo ""
    echo "============================================================"
    echo "❌ Training Failed!"
    echo "============================================================"
    echo "Check the error messages above."
    exit 1
fi