"""Training modules for BioRLHF.

The public names listed in ``__all__`` are not imported eagerly when this
package is imported. They are resolved on first access by the package's
module-level ``__getattr__``, which defers importing the torch-dependent
trainer submodules until one of their names is actually used.
"""

__all__ = [
    # SFT -- defined in biorlhf.training.sft
    "SFTTrainingConfig",
    "run_sft_training",
    # DPO -- defined in biorlhf.training.dpo
    "DPOTrainingConfig",
    "run_dpo_training",
    # GRPO -- defined in biorlhf.training.grpo
    "BioGRPOConfig",
    "run_grpo_training",
]
# Lazily exported public name -> absolute path of the submodule defining it.
# These submodules are torch-dependent, so each one is imported only when one
# of its names is first accessed on this package.
_LAZY_IMPORTS = {
    "SFTTrainingConfig": "biorlhf.training.sft",
    "run_sft_training": "biorlhf.training.sft",
    "DPOTrainingConfig": "biorlhf.training.dpo",
    "run_dpo_training": "biorlhf.training.dpo",
    "BioGRPOConfig": "biorlhf.training.grpo",
    "run_grpo_training": "biorlhf.training.grpo",
}


def __getattr__(name: str) -> object:
    """Lazily resolve a public attribute of this package (PEP 562).

    Python calls this hook only when normal module attribute lookup fails.
    On the first access to a name in ``_LAZY_IMPORTS``, the defining
    submodule is imported. The resolved object is then stored in the module
    globals, so later lookups find it directly and skip this hook.

    Args:
        name: The attribute name being looked up.

    Returns:
        The object exported under ``name`` by its defining submodule.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
        ImportError: If the defining submodule cannot be imported, or does
            not define ``name`` (matching ``from module import name``).
    """
    module_path = _LAZY_IMPORTS.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    import importlib

    module = importlib.import_module(module_path)
    try:
        value = getattr(module, name)
    except AttributeError as err:
        # Keep the ImportError contract of the original `from ... import`
        # form. An AttributeError here would make hasattr() silently report
        # a broken submodule as a missing attribute.
        raise ImportError(
            f"cannot import name {name!r} from {module_path!r}"
        ) from err

    # Cache on the module so subsequent accesses are plain dict lookups.
    globals()[name] = value
    return value


def __dir__():
    """Return the module's attributes plus the lazily exported names.

    Without this, ``dir()`` and interactive tab completion would not list
    names that only exist after ``__getattr__`` has resolved them.
    """
    return sorted(set(globals()) | set(_LAZY_IMPORTS))