#!/usr/bin/env python3
"""
BioRLHF SFT Training Example
This script demonstrates how to fine-tune a language model using
supervised fine-tuning (SFT) on biological reasoning tasks.
Requirements:
- CUDA-compatible GPU with 16GB+ VRAM (or use CPU with reduced batch size)
- PyTorch with CUDA support
- All BioRLHF dependencies installed
Usage:
python train_sft.py [--config custom_config.json]
"""
import argparse
import json
from pathlib import Path

from biorlhf import SFTTrainingConfig, run_sft_training
from biorlhf.data.dataset import create_sft_dataset


def create_training_dataset(output_path: str = "training_dataset.json") -> str:
    """Create a training dataset if one doesn't exist."""
    path = Path(output_path)
    if path.exists():
        print(f"Using existing dataset: {output_path}")
        return output_path

    print(f"Creating new dataset: {output_path}")
    create_sft_dataset(
        output_path=output_path,
        include_calibration=True,
        include_chain_of_thought=True,
    )
    return output_path


def main():
    """Run SFT training."""
    parser = argparse.ArgumentParser(
        description="Fine-tune a model for biological reasoning"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mistral-7B-v0.3",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to training dataset (created if not provided)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./biorlhf_model",
        help="Output directory for trained model",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Training batch size per device",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=2e-4,
        help="Learning rate",
    )
    parser.add_argument(
        "--no-wandb",
        action="store_true",
        help="Disable Weights & Biases logging",
    )
    parser.add_argument(
        "--wandb-project",
        type=str,
        default="biorlhf",
        help="W&B project name",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to JSON config file (overrides other args)",
    )
    args = parser.parse_args()

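    # Illustrative shape of a --config JSON file; the keys mirror the
    # SFTTrainingConfig fields built below, though SFTTrainingConfig may
    # accept additional fields:
    #   {
    #     "model_name": "mistralai/Mistral-7B-v0.3",
    #     "dataset_path": "training_dataset.json",
    #     "output_dir": "./biorlhf_model",
    #     "num_epochs": 3,
    #     "batch_size": 4,
    #     "learning_rate": 2e-4,
    #     "use_wandb": true,
    #     "wandb_project": "biorlhf"
    #   }
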
    # Load config from file if provided
    if args.config:
        with open(args.config) as f:
            config_dict = json.load(f)
        config = SFTTrainingConfig(**config_dict)
    else:
        # Create or use dataset
        dataset_path = args.dataset
        if dataset_path is None:
            dataset_path = create_training_dataset()

        # Build config from arguments
        config = SFTTrainingConfig(
            model_name=args.model,
            dataset_path=dataset_path,
            output_dir=args.output,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            use_wandb=not args.no_wandb,
            wandb_project=args.wandb_project,
        )
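
    # Note: when --config is given, it takes the first branch above, so the
    # other CLI flags (e.g. --model, --epochs) are ignored, matching the
    # "overrides other args" wording in the --config help text.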
print("\nTraining Configuration:")
print("-" * 40)
for key, value in vars(config).items():
print(f" {key}: {value}")
print("-" * 40)
# Run training
output_path = run_sft_training(config)
print(f"\nModel saved to: {output_path}")
print("\nTo evaluate the model, run:")
print(f" python evaluate_model.py --model {output_path}")


if __name__ == "__main__":
    main()