File size: 3,693 Bytes
c7ebaa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python3
"""
BioRLHF SFT Training Example

This script demonstrates how to fine-tune a language model using
supervised fine-tuning (SFT) on biological reasoning tasks.

Requirements:
- CUDA-compatible GPU with 16GB+ VRAM (or use CPU with reduced batch size)
- PyTorch with CUDA support
- All BioRLHF dependencies installed

Usage:
    python train_sft.py [--config custom_config.json]
"""

import argparse
import json
from pathlib import Path

from biorlhf import SFTTrainingConfig, run_sft_training
from biorlhf.data.dataset import create_sft_dataset


def create_training_dataset(output_path: str = "training_dataset.json") -> str:
    """Return *output_path*, first generating the dataset file there if absent.

    Args:
        output_path: Destination JSON file for the SFT dataset.

    Returns:
        The same *output_path* string, for convenient chaining into a config.
    """
    dataset_file = Path(output_path)

    if not dataset_file.exists():
        print(f"Creating new dataset: {output_path}")
        # Generation includes both calibration and chain-of-thought examples.
        create_sft_dataset(
            output_path=output_path,
            include_calibration=True,
            include_chain_of_thought=True,
        )
    else:
        print(f"Using existing dataset: {output_path}")

    return output_path


def main():
    """Entry point: parse CLI arguments, build a config, and launch SFT training."""
    arg_parser = argparse.ArgumentParser(
        description="Fine-tune a model for biological reasoning"
    )
    arg_parser.add_argument(
        "--model", type=str, default="mistralai/Mistral-7B-v0.3",
        help="Base model to fine-tune",
    )
    arg_parser.add_argument(
        "--dataset", type=str, default=None,
        help="Path to training dataset (created if not provided)",
    )
    arg_parser.add_argument(
        "--output", type=str, default="./biorlhf_model",
        help="Output directory for trained model",
    )
    arg_parser.add_argument(
        "--epochs", type=int, default=3,
        help="Number of training epochs",
    )
    arg_parser.add_argument(
        "--batch-size", type=int, default=4,
        help="Training batch size per device",
    )
    arg_parser.add_argument(
        "--learning-rate", type=float, default=2e-4,
        help="Learning rate",
    )
    arg_parser.add_argument(
        "--no-wandb", action="store_true",
        help="Disable Weights & Biases logging",
    )
    arg_parser.add_argument(
        "--wandb-project", type=str, default="biorlhf",
        help="W&B project name",
    )
    arg_parser.add_argument(
        "--config", type=str, default=None,
        help="Path to JSON config file (overrides other args)",
    )

    opts = arg_parser.parse_args()

    if opts.config:
        # A JSON config file takes full precedence over the individual flags.
        with open(opts.config) as cfg_file:
            raw_config = json.load(cfg_file)
        train_config = SFTTrainingConfig(**raw_config)
    else:
        # No config file: assemble the config from CLI flags, creating a
        # default dataset on disk when none was supplied.
        data_path = opts.dataset
        if data_path is None:
            data_path = create_training_dataset()

        train_config = SFTTrainingConfig(
            model_name=opts.model,
            dataset_path=data_path,
            output_dir=opts.output,
            num_epochs=opts.epochs,
            batch_size=opts.batch_size,
            learning_rate=opts.learning_rate,
            use_wandb=not opts.no_wandb,
            wandb_project=opts.wandb_project,
        )

    # Echo the resolved configuration before starting the (long) training run.
    print("\nTraining Configuration:")
    print("-" * 40)
    for field, val in vars(train_config).items():
        print(f"  {field}: {val}")
    print("-" * 40)

    saved_to = run_sft_training(train_config)

    print(f"\nModel saved to: {saved_to}")
    print("\nTo evaluate the model, run:")
    print(f"  python evaluate_model.py --model {saved_to}")

# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()