import os
from typing import Any

# Must be set before `import unsloth` so the flag is picked up at import time
# (hence the I001 import-order suppression below).
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

import unsloth  # noqa: I001, F401

import logging
import sys

import transformers
from datasets import DatasetDict, load_dataset
from datasets.load import DownloadMode
from datasets.utils.logging import set_verbosity
from transformers.trainer_utils import get_last_checkpoint, set_seed
from trl.scripts.utils import TrlParser
from trl.trainer.sft_config import SFTConfig
from trl.trainer.sft_trainer import SFTTrainer

from linalg_zero.config.data import ScriptArguments, SFTModelConfig, SFTRunConfig
from linalg_zero.sft.callbacks import get_callbacks
from linalg_zero.sft.utils import (
    ensure_tokenizer_has_defaults,
    get_unsloth_model,
    init_wandb_training,
    load_merged_model_for_sft,
)
from linalg_zero.shared.utils import get_logger, setup_logging


def main(  # noqa: C901
    script_args: ScriptArguments, training_args: SFTRunConfig, trl_training_args: SFTConfig, model_args: SFTModelConfig
) -> None:
    """Main training function."""
    # Reproducibility
    set_seed(trl_training_args.seed)

    #################
    # Setup logging #
    #################
    # Log both to file and console
    setup_logging(level=logging.INFO, include_timestamp=True)
    logger = get_logger(__name__)

    # Adjust script logging level based on the node logging level (main process or replica)
    log_level = trl_training_args.get_process_log_level()
    logger.setLevel(log_level)
    set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.info(f"Model parameters: {model_args}")
    logger.info(f"Script parameters: {script_args}")
    logger.info(f"Training parameters: {training_args}")
    logger.info(f"TRL training parameters: {trl_training_args}")

    # Check for the last checkpoint
    last_checkpoint = None
    if trl_training_args.output_dir and os.path.isdir(trl_training_args.output_dir):
        last_checkpoint = get_last_checkpoint(trl_training_args.output_dir)
        if last_checkpoint is not None and trl_training_args.resume_from_checkpoint is None:
            logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}")

    # Initialize wandb if requested
    if trl_training_args.report_to and "wandb" in trl_training_args.report_to:
        init_wandb_training(training_args)
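
    # `init_wandb_training` is project-specific. A minimal equivalent, as a sketch
    # only (assuming the run config exposes a hypothetical `wandb_project` field):
    #
    #   import wandb
    #   wandb.init(project=training_args.wandb_project, config=vars(training_args))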

    ######################################
    # Load dataset, tokenizer, and model #
    ######################################
    logger.info(f"Loading dataset from {script_args.dataset_name}...")
    dataset = load_dataset(
        script_args.dataset_name, script_args.dataset_config, download_mode=DownloadMode.FORCE_REDOWNLOAD
    )
    if not isinstance(dataset, DatasetDict):
        raise TypeError(f"Expected dataset to be a DatasetDict, but got {type(dataset)}")

    # Model, tokenizer, dataset
    logger.info("Loading model and tokenizer...")
    if getattr(model_args, "use_peft", True):
        # Standard LoRA SFT on the base model
        model, tokenizer = get_unsloth_model(model_args, training_args, trl_training_args, resume_path=last_checkpoint)
    else:
        # Light touch-up on a merged model: train only I/O layers if requested
        max_seq_len = training_args.max_seq_length or trl_training_args.max_seq_length
        model, tokenizer = load_merged_model_for_sft(
            model_path=model_args.model_name_or_path,
            max_seq_length=max_seq_len,
            dtype=None,
            train_io_only=True,
            add_special_tokens=training_args.add_special_tokens,
        )

    # Ensure pad token and padding side are set consistently for SFT
    ensure_tokenizer_has_defaults(tokenizer, model)
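    # For reference, a typical implementation of such a helper looks like this
    # (sketch only; the actual project helper may differ):
    #
    #   if tokenizer.pad_token is None:
    #       tokenizer.pad_token = tokenizer.eos_token
    #   tokenizer.padding_side = "right"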

    def ensure_text(x: dict[str, Any]) -> dict[str, Any]:
        """Unbatched variant of the batched formatting below; kept for reference."""
        x["text"] = tokenizer.apply_chat_template(x["messages"], tools=x["tools"], tokenize=False)
        return x

    def formatting_prompts_func(examples: dict[str, Any]) -> dict[str, list[str]]:
        """Render each conversation and its tool spec into a single training string."""
        convos = examples["messages"]  # One batch of conversations
        tools = examples.get("tools")  # Matching batch of tool specs, if present
        texts = []
        for i, convo in enumerate(convos):
            example_tools = tools[i] if tools and isinstance(tools, list) else tools
            text = tokenizer.apply_chat_template(
                convo,
                tools=example_tools,  # Pass the i-th conversation's tools
                tokenize=False,
                add_generation_prompt=False,
            )
            texts.append(text)
        return {"text": texts}

    dataset = dataset.map(formatting_prompts_func, batched=True)
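    # Illustrative sanity check: log one rendered example so the chat-template
    # output can be inspected (assumes the train split is non-empty).
    logger.debug("Sample formatted text: %s", dataset[script_args.dataset_train_split][0]["text"][:500])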

    ##############################
    # Initialize the SFT Trainer #
    ##############################
    trl_training_args.max_eval_samples = training_args.max_eval_samples
    trl_training_args.eval_max_new_tokens = training_args.eval_max_new_tokens

    logger.info("Initializing SFT Trainer...")
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(dataset[script_args.dataset_test_split] if trl_training_args.eval_strategy != "no" else None),
        args=trl_training_args,
        callbacks=get_callbacks(training_args, model_args, script_args, dataset),
    )
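    # Note: `processing_class` is the current TRL name for what older versions
    # accepted as `tokenizer=`. The pre-rendered "text" column above is consumed
    # via SFTConfig's default `dataset_text_field="text"`.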

    #################
    # Training loop #
    #################
    logger.info("*** Starting Training ***")
    # An explicitly configured `resume_from_checkpoint` takes precedence over the
    # checkpoint auto-detected in the output directory above.
    checkpoint = None
    if trl_training_args.resume_from_checkpoint is not None:
        checkpoint = trl_training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    try:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        logger.info("Training completed successfully!")
    except KeyboardInterrupt:
        logger.info("Training interrupted by user.")
    except Exception:
        logger.exception("Training failed with an unexpected error")
        raise

    ####################################
    # Save model and create model card #
    ####################################
    logger.info("*** Saving Model ***")
    try:
        # Align the model's generation config with the tokenizer's eos token
        # to avoid unbounded generation in the transformers `pipeline()` function
        if trainer.model is not None and trainer.model.generation_config is not None:
            trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
            assert trainer.model.generation_config.pad_token_id == tokenizer.pad_token_id, "Pad token ID mismatch"

        # Restore the KV cache for fast inference before saving
        if trainer.model is not None:
            trainer.model.config.use_cache = True

        trainer.save_model(trl_training_args.output_dir)
        logger.info(f"Model saved to {trl_training_args.output_dir}")

        # Save everything else on the main process
        kwargs = {
            "dataset_name": script_args.dataset_name,
            "tags": ["linalg-zero", "sft", "tool-use", "linear-algebra"],
            "model_name": model_args.model_name_or_path,
        }
        if trainer.accelerator.is_main_process:
            trainer.create_model_card(**kwargs)
    except Exception:
        logger.exception("Failed to save model")
        raise
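    # For LoRA runs, `save_model` writes adapter weights. To export merged 16-bit
    # weights instead, Unsloth provides (sketch; not used by this script):
    #
    #   model.save_pretrained_merged(trl_training_args.output_dir, tokenizer, save_method="merged_16bit")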

    ############
    # Evaluate #
    ############
    if trl_training_args.do_eval:
        logger.info("*** Final Evaluation on Full Dataset ***")
        # Temporarily override max_eval_samples to evaluate on the full dataset
        original_max_eval_samples = getattr(trl_training_args, "max_eval_samples", None)
        trl_training_args.max_eval_samples = training_args.final_eval_max_samples
        try:
            metrics = trainer.evaluate()
            metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)
            logger.info("Evaluation completed successfully!")
        except Exception:
            logger.exception("Evaluation failed")
        finally:
            # Restore the original value even if evaluation fails
            trl_training_args.max_eval_samples = original_max_eval_samples

    ###############
    # Push to hub #
    ###############
    if trl_training_args.push_to_hub:
        logger.info("*** Pushing to Hub ***")
        try:
            trainer.push_to_hub(**kwargs)
            logger.info("Successfully pushed model to the Hugging Face Hub!")
        except Exception:
            logger.exception("Failed to push to hub")


if __name__ == "__main__":
    # Script entry point for SFT training.
    if "--config" not in sys.argv:
        sys.argv.append("--config")
        sys.argv.append("linalg_zero/config/sft/qwen2.5-3B/production_merged.yaml")
        # sys.argv.append("linalg_zero/config/sft/qwen2.5-3B/production_instruct.yaml")
        # sys.argv.append("linalg_zero/config/sft/qwen2.5-3B/production.yaml")

    parser = TrlParser([ScriptArguments, SFTRunConfig, SFTConfig, SFTModelConfig])
    script_args, training_args, trl_training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, trl_training_args, model_args)
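
# Example invocation (sketch; the module path is assumed from the package layout,
# and the default config above is injected when no --config flag is given):
#   python -m linalg_zero.sft.train --config linalg_zero/config/sft/qwen2.5-3B/production.yaml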