#!/bin/bash
#
# BioRLHF Training Script - Ecosystem Improved Model
# ====================================================
#
# This script trains a model on the combined dataset including:
# - Original KMP study data (363 examples)
# - Ecosystem failure-based examples (15 examples)
#   - Calibration training
#   - Adversarial resistance
#   - Protocol completeness
#   - Fact drilling
#
# Requirements:
# - CUDA-capable GPU (recommended: A100, V100, or 4090)
# - 24GB+ VRAM for Mistral-7B with 4-bit quantization
# - Python environment with: torch, transformers, peft, trl, bitsandbytes
#
# Usage:
#   ./scripts/train_ecosystem_improved.sh
#
# Or on HPC with SLURM:
#   sbatch scripts/train_ecosystem_improved.sh
#

# ==============================================================================
# SLURM Configuration (parsed by sbatch on HPC clusters; ignored when run directly)
# ==============================================================================
#SBATCH --job-name=biorlhf_ecosystem
#SBATCH --output=logs/biorlhf_ecosystem_%j.out
#SBATCH --error=logs/biorlhf_ecosystem_%j.err
#SBATCH --time=4:00:00
#SBATCH --gres=gpu:1
#SBATCH --mem=48G
#SBATCH --cpus-per-task=8
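
# NOTE: SLURM opens the --output/--error files before this script runs, so
# logs/ must already exist at submission time (e.g. run `mkdir -p logs` once).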

# ==============================================================================
# Environment Setup
# ==============================================================================
echo "============================================================"
echo "BioRLHF Ecosystem Training"
echo "============================================================"
echo "Start time: $(date)"
echo "Host: $(hostname)"
echo ""

# Activate conda environment (adjust path as needed)
# source /path/to/conda/etc/profile.d/conda.sh
# conda activate biorlhf
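
# Fail fast if core Python dependencies are missing. A minimal sketch based on
# the package list in the header; extend as the training script requires.
python3 - <<'EOF' || exit 1
import importlib.util

for pkg in ("torch", "transformers", "peft", "trl", "bitsandbytes"):
    if importlib.util.find_spec(pkg) is None:
        raise SystemExit(f"Missing required package: {pkg}")
EOF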

# Set working directory. Under SLURM, $0 points at a spooled copy of this
# script, so prefer the submission directory (assumed to be the repo root,
# per the sbatch usage line in the header).
cd "${SLURM_SUBMIT_DIR:-$(dirname "$0")/..}" || exit 1
echo "Working directory: $(pwd)"

# Check GPU
echo ""
echo "GPU Information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv 2>/dev/null || echo "No GPU detected"
echo ""

# ==============================================================================
# Training Configuration
# ==============================================================================

# Model settings
MODEL="mistralai/Mistral-7B-v0.3"
DATASET="data/combined_training.json"
OUTPUT_DIR="./ecosystem_improved_model"
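
# Warn before overwriting a previous run (kept as a warning, not a hard stop)
if [ -d "$OUTPUT_DIR" ] && [ -n "$(ls -A "$OUTPUT_DIR" 2>/dev/null)" ]; then
    echo "WARNING: $OUTPUT_DIR already exists and is not empty"
fi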

# Training hyperparameters (optimized based on prior BioRLHF experiments)
EPOCHS=10           # More epochs for better fact memorization
BATCH_SIZE=4        # Adjust based on GPU memory
GRAD_ACCUM=4        # Effective batch size = 16
LEARNING_RATE=2e-4  # Standard for LoRA fine-tuning
MAX_LENGTH=1024     # Sufficient for most examples

# LoRA configuration (higher rank for domain knowledge)
LORA_R=64           # Higher rank for better capacity
LORA_ALPHA=128      # Alpha = 2 * r
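# In PEFT, LoRA deltas are scaled by alpha/r, so alpha = 2r applies a 2x scale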

# Logging
WANDB_PROJECT="biorlhf"
WANDB_RUN="ecosystem_improved_$(date +%Y%m%d_%H%M%S)"
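
# W&B supports offline logging for air-gapped compute nodes; offline runs can
# be uploaded later with `wandb sync`. Defaults to online unless overridden.
export WANDB_MODE="${WANDB_MODE:-online}"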

# ==============================================================================
# Pre-training Checks
# ==============================================================================
echo "============================================================"
echo "Configuration:"
echo "============================================================"
echo "Model:        $MODEL"
echo "Dataset:      $DATASET"
echo "Output:       $OUTPUT_DIR"
echo "Epochs:       $EPOCHS"
echo "Batch size:   $BATCH_SIZE (effective: $((BATCH_SIZE * GRAD_ACCUM)))"
echo "LoRA r/α:     $LORA_R / $LORA_ALPHA"
echo "Max length:   $MAX_LENGTH"
echo ""

# Check if dataset exists
if [ ! -f "$DATASET" ]; then
    echo "ERROR: Dataset not found at $DATASET"
    echo "Run: python scripts/merge_training_data.py"
    exit 1
fi

# Count examples
EXAMPLE_COUNT=$(python3 -c "import json; print(len(json.load(open('$DATASET'))))")
echo "Dataset contains $EXAMPLE_COUNT examples"
echo ""

# ==============================================================================
# Run Training
# ==============================================================================
echo "============================================================"
echo "Starting Training..."
echo "============================================================"

python3 sft_train_v2.py \
    --model "$MODEL" \
    --dataset "$DATASET" \
    --output_dir "$OUTPUT_DIR" \
    --epochs $EPOCHS \
    --batch_size $BATCH_SIZE \
    --grad_accum $GRAD_ACCUM \
    --lr $LEARNING_RATE \
    --max_length $MAX_LENGTH \
    --lora_r $LORA_R \
    --lora_alpha $LORA_ALPHA \
    --use_4bit \
    --wandb_project "$WANDB_PROJECT" \
    --wandb_run "$WANDB_RUN"

TRAIN_STATUS=$?

# Check exit status (captured immediately so later commands can't clobber it)
if [ "$TRAIN_STATUS" -eq 0 ]; then
    echo ""
    echo "============================================================"
    echo "✅ Training Complete!"
    echo "============================================================"
    echo "Model saved to: $OUTPUT_DIR"
    echo "End time: $(date)"
    echo ""
    echo "Next steps:"
    echo "1. Evaluate on SpaceOmicsBench: python evaluate_model.py --model $OUTPUT_DIR"
    echo "2. Evaluate on CAMELOT: python evaluate_model.py --model $OUTPUT_DIR --benchmark camelot"
    echo "3. Compare with baseline: python compare_models.py"
else
    echo ""
    echo "============================================================"
    echo "❌ Training Failed!"
    echo "============================================================"
    echo "Check the error messages above."
    exit 1
fi