| """ |
| Command-line interface for BioRLHF. |
| |
| This module provides CLI entry points for training and evaluating models. |
| """ |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| from biorlhf.training.sft import SFTTrainingConfig, run_sft_training |
| from biorlhf.evaluation.evaluate import evaluate_model as _evaluate_model |
| from biorlhf.training.grpo import BioGRPOConfig, run_grpo_training |
|
|
|
|
def train():
    """CLI entry point for supervised fine-tuning (SFT) of a BioRLHF model.

    Builds an ``SFTTrainingConfig`` either from command-line flags or, when
    ``--config`` is given, entirely from a JSON file (the other CLI flags are
    then ignored). The dataset path stored in the resulting config is checked
    for existence before ``run_sft_training`` is called.

    Exit status:
        2 -- argument error reported by argparse (neither ``--dataset`` nor
             ``--config`` supplied, or a non-positive epochs / batch size /
             learning rate).
        1 -- the config file cannot be read, parsed or applied; the dataset
             is missing; or training raises (traceback printed to stderr).
    """
    parser = argparse.ArgumentParser(
        description="Train a BioRLHF model using supervised fine-tuning",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model and data
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mistral-7B-v0.3",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to training dataset JSON file (required unless --config is given)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./biorlhf_model",
        help="Output directory for trained model",
    )

    # Training hyperparameters
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Training batch size per device",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=2e-4,
        help="Learning rate",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=1024,
        help="Maximum sequence length",
    )

    # LoRA
    parser.add_argument(
        "--lora-r",
        type=int,
        default=64,
        help="LoRA rank",
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=128,
        help="LoRA alpha",
    )

    # Runtime and logging
    parser.add_argument(
        "--no-quantization",
        action="store_true",
        help="Disable 4-bit quantization",
    )
    parser.add_argument(
        "--no-wandb",
        action="store_true",
        help="Disable Weights & Biases logging",
    )
    parser.add_argument(
        "--wandb-project",
        type=str,
        default="biorlhf",
        help="W&B project name",
    )
    parser.add_argument(
        "--wandb-run-name",
        type=str,
        default="sft_training",
        help="W&B run name",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to JSON config file (overrides other args)",
    )

    args = parser.parse_args()

    if args.config:
        # The JSON file replaces every other CLI option, including --dataset.
        try:
            with open(args.config, encoding="utf-8") as f:
                config_dict = json.load(f)
            config = SFTTrainingConfig(**config_dict)
        except (OSError, ValueError, TypeError) as e:
            # OSError: unreadable/missing file. ValueError: malformed JSON
            # (JSONDecodeError), or any ValueError the config class raises.
            # TypeError: top level is not an object, or unknown keys.
            print(f"Error: Could not load config from {args.config}: {e}", file=sys.stderr)
            sys.exit(1)
    else:
        if args.dataset is None:
            parser.error("--dataset is required unless --config is given")
        # Fail fast on impossible values instead of after loading the model.
        if args.epochs < 1:
            parser.error("--epochs must be >= 1")
        if args.batch_size < 1:
            parser.error("--batch-size must be >= 1")
        if args.learning_rate <= 0:
            parser.error("--learning-rate must be > 0")
        config = SFTTrainingConfig(
            model_name=args.model,
            dataset_path=args.dataset,
            output_dir=args.output,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            max_length=args.max_length,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            use_4bit=not args.no_quantization,
            use_wandb=not args.no_wandb,
            wandb_project=args.wandb_project,
            wandb_run_name=args.wandb_run_name,
        )

    # Check the dataset the run will actually use. With --config this path
    # comes from the JSON file rather than from --dataset.
    if not config.dataset_path or not Path(config.dataset_path).exists():
        print(f"Error: Dataset not found at {config.dataset_path}", file=sys.stderr)
        sys.exit(1)

    print("BioRLHF Training")
    print("=" * 50)
    print(f"Model: {config.model_name}")
    print(f"Dataset: {config.dataset_path}")
    print(f"Output: {config.output_dir}")
    print("=" * 50)

    try:
        output_path = run_sft_training(config)
    except Exception as e:
        # Top-level CLI boundary: keep the traceback for debugging, then exit non-zero.
        import traceback

        traceback.print_exc()
        print(f"Error during training: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"\nModel saved to: {output_path}")
|
|
|
|
def evaluate():
    """CLI entry point for evaluating a fine-tuned BioRLHF model on a test set.

    Validates the generation settings and input paths, runs
    ``_evaluate_model``, prints a metrics summary and, when ``--output`` is
    given, writes the metrics plus per-question results as JSON.

    Exit status:
        2 -- argument error (missing required flag, negative ``--temperature``,
             or ``--max-tokens`` < 1).
        1 -- model or test-set path missing, evaluation raises (traceback
             printed to stderr), or the results file cannot be written.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a BioRLHF model on a test set",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the fine-tuned model directory",
    )
    parser.add_argument(
        "--test-set",
        type=str,
        required=True,
        help="Path to test questions JSON file",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="mistralai/Mistral-7B-v0.3",
        help="Base model name",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output path for detailed results JSON",
    )
    parser.add_argument(
        "--no-quantization",
        action="store_true",
        help="Disable 4-bit quantization",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="Generation temperature (0 for greedy)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate",
    )

    args = parser.parse_args()

    # Reject impossible generation settings before any model weights are loaded.
    if args.temperature < 0:
        parser.error("--temperature must be >= 0")
    if args.max_tokens < 1:
        parser.error("--max-tokens must be >= 1")

    if not Path(args.model).exists():
        print(f"Error: Model not found at {args.model}", file=sys.stderr)
        sys.exit(1)

    if not Path(args.test_set).exists():
        print(f"Error: Test set not found at {args.test_set}", file=sys.stderr)
        sys.exit(1)

    print("BioRLHF Evaluation")
    print("=" * 50)
    print(f"Model: {args.model}")
    print(f"Test Set: {args.test_set}")
    print("=" * 50)

    try:
        results = _evaluate_model(
            model_path=args.model,
            test_questions_path=args.test_set,
            base_model=args.base_model,
            use_4bit=not args.no_quantization,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
        )
    except Exception as e:
        # Top-level CLI boundary: keep the traceback for debugging, then exit non-zero.
        import traceback

        traceback.print_exc()
        print(f"Error during evaluation: {e}", file=sys.stderr)
        sys.exit(1)

    print("\nResults:")
    print("-" * 30)
    print(f"Overall Accuracy: {results.overall_accuracy:.1%}")
    print(f"Factual Accuracy: {results.factual_accuracy:.1%}")
    print(f"Reasoning Accuracy: {results.reasoning_accuracy:.1%}")
    print(f"Calibration Accuracy: {results.calibration_accuracy:.1%}")
    print(f"Total: {results.correct_answers}/{results.total_questions}")

    if args.output:
        output_data = {
            "model_path": args.model,
            "test_set": args.test_set,
            "metrics": {
                "overall_accuracy": results.overall_accuracy,
                "factual_accuracy": results.factual_accuracy,
                "reasoning_accuracy": results.reasoning_accuracy,
                "calibration_accuracy": results.calibration_accuracy,
                "total_questions": results.total_questions,
                "correct_answers": results.correct_answers,
            },
            "detailed_results": results.detailed_results,
        }
        output_path = Path(args.output)
        try:
            # Serialize before opening the file so a non-serializable result
            # cannot leave a truncated JSON file behind.
            payload = json.dumps(output_data, indent=2)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            output_path.write_text(payload, encoding="utf-8")
        except (OSError, TypeError, ValueError) as e:
            print(f"Error saving results to {args.output}: {e}", file=sys.stderr)
            sys.exit(1)

        print(f"\nDetailed results saved to: {args.output}")
|
|
|
|
def grpo_train():
    """CLI entry point for GRPO training with biological verifiers.

    Builds a ``BioGRPOConfig`` either from command-line flags or, when
    ``--config`` is given, entirely from a JSON file (the other CLI flags are
    then ignored). The config is passed to ``run_grpo_training``.

    Exit status:
        2 -- argument error (``--num-generations`` < 2, negative ``--beta``,
             or non-positive ``--learning-rate``).
        1 -- the config file cannot be read, parsed or applied, or training
             raises (traceback printed to stderr).
    """
    parser = argparse.ArgumentParser(
        description="Train a BioGRPO model with composable biological verifiers",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mistral-7B-v0.3",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--sft-model",
        type=str,
        default=None,
        help="Path to SFT checkpoint (recommended)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./biogrpo_model",
        help="Output directory",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=8,
        help="G value: number of completions per prompt",
    )
    parser.add_argument(
        "--beta",
        type=float,
        default=0.04,
        help="KL penalty coefficient",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-6,
        help="Learning rate",
    )
    parser.add_argument(
        "--lora-r",
        type=int,
        default=32,
        help="LoRA rank",
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=64,
        help="LoRA alpha",
    )
    parser.add_argument(
        "--verifiers",
        type=str,
        nargs="+",
        default=None,
        help="Active verifiers (e.g., V1 V2 V3 V4). Default: all",
    )
    parser.add_argument(
        "--pathway-db",
        type=str,
        default="hallmark",
        choices=["hallmark", "kegg", "reactome", "mitocarta"],
        help="Pathway database for GeneLab questions",
    )
    parser.add_argument(
        "--no-wandb",
        action="store_true",
        help="Disable W&B logging",
    )
    parser.add_argument(
        "--wandb-project",
        type=str,
        default="biogrpo",
        help="W&B project name",
    )
    parser.add_argument(
        "--wandb-run-name",
        type=str,
        default="grpo_v1",
        help="W&B run name",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to JSON config file (overrides other args)",
    )

    args = parser.parse_args()

    if args.config:
        # The JSON file replaces every other CLI option.
        try:
            with open(args.config, encoding="utf-8") as f:
                config_dict = json.load(f)
            config = BioGRPOConfig(**config_dict)
        except (OSError, ValueError, TypeError) as e:
            # OSError: unreadable/missing file. ValueError: malformed JSON
            # (JSONDecodeError), or any ValueError the config class raises.
            # TypeError: top level is not an object, or unknown keys.
            print(f"Error: Could not load config from {args.config}: {e}", file=sys.stderr)
            sys.exit(1)
    else:
        # GRPO advantages are computed relative to the other completions of
        # the same prompt, so a group of one carries no learning signal.
        if args.num_generations < 2:
            parser.error("--num-generations must be >= 2")
        if args.beta < 0:
            parser.error("--beta must be >= 0")
        if args.learning_rate <= 0:
            parser.error("--learning-rate must be > 0")
        config = BioGRPOConfig(
            model_name=args.model,
            sft_model_path=args.sft_model,
            output_dir=args.output,
            num_generations=args.num_generations,
            beta=args.beta,
            learning_rate=args.learning_rate,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            active_verifiers=args.verifiers,
            pathway_db=args.pathway_db,
            use_wandb=not args.no_wandb,
            wandb_project=args.wandb_project,
            wandb_run_name=args.wandb_run_name,
        )

    try:
        output_path = run_grpo_training(config)
    except Exception as e:
        # Top-level CLI boundary: keep the traceback for debugging, then exit non-zero.
        import traceback

        traceback.print_exc()
        print(f"Error during GRPO training: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"\nModel saved to: {output_path}")
|
|
|
|
if __name__ == "__main__":
    # This module only defines console-script entry points; running it
    # directly just points the user at the installed commands instead.
    print(
        "Use 'biorlhf-train', 'biorlhf-evaluate', or 'biorlhf-grpo' commands "
        "after installation."
    )
|
|