import logging
import os
from pathlib import Path

import torch
from datasets import DatasetDict
from datasets import load_dataset as hf_load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from trl import SFTConfig
from trl.trainer.model_config import ModelConfig
from trl.trainer.utils import get_kbit_device_map, get_quantization_config
from unsloth import FastLanguageModel

from linalg_zero.config.data import ScriptArguments, SFTModelConfig, SFTRunConfig
from linalg_zero.shared.system_prompts import (
    ANSWER_CLOSE,
    ANSWER_OPEN,
    THINK_CLOSE,
    THINK_OPEN,
    TOOL_CALL_CLOSE,
    TOOL_CALL_OPEN,
)

logger = logging.getLogger(__name__)


def is_using_deepspeed() -> bool:
    """Check if DeepSpeed is being used via environment variables."""
    return (
        os.environ.get("LOCAL_RANK") is not None
        or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true"
        or "deepspeed" in os.environ.get("ACCELERATE_CONFIG_FILE", "").lower()
    )


def ensure_tokenizer_has_defaults(tokenizer: PreTrainedTokenizer, model: PreTrainedModel) -> None:
    """Ensure pad/eos token ids and the padding side are set consistently on the tokenizer and model."""
    if getattr(tokenizer, "pad_token_id", None) is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    if tokenizer.padding_side != "right":
        tokenizer.padding_side = "right"

    if getattr(model, "config", None) is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.eos_token_id = tokenizer.eos_token_id

    if getattr(model, "generation_config", None) is not None:
        assert model.generation_config is not None, "Generation config is not set"
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id


def init_wandb_training(training_args: SFTRunConfig) -> None:
    """Initialize Weights & Biases for training logging."""
    try:
        # Set environment variables for wandb
        if training_args.wandb_entity is not None:
            os.environ["WANDB_ENTITY"] = training_args.wandb_entity
        if training_args.wandb_project is not None:
            os.environ["WANDB_PROJECT"] = training_args.wandb_project
        if training_args.wandb_run_group is not None:
            os.environ["WANDB_RUN_GROUP"] = training_args.wandb_run_group
        if training_args.wandb_run_id is not None:
            os.environ["WANDB_RUN_ID"] = training_args.wandb_run_id
            os.environ["WANDB_RESUME"] = "allow"
        logger.info("Set wandb environment variables from training args")
    except Exception:
        logger.exception("Failed to initialize wandb environment")


def get_tokenizer(model_args: ModelConfig, training_args: SFTRunConfig) -> PreTrainedTokenizer:
    """Get the tokenizer for the model."""
    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )

    if training_args.chat_template is not None:
        tokenizer.chat_template = training_args.chat_template

    return tokenizer


def load_model_for_evaluation(
    model_path: str,
    max_seq_length: int = 2048,
    dtype: torch.dtype | None = None,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load a trained model for evaluation/inference."""
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=False,
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
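
# Example sketch (illustrative only, not part of the training pipeline): generating with a
# checkpoint loaded via `load_model_for_evaluation`. The checkpoint path, prompt, and generation
# settings below are placeholder assumptions, not values used elsewhere in this repo.
#
#   model, tokenizer = load_model_for_evaluation("results/LinalgZero-SFT-merged")
#   prompt = tokenizer.apply_chat_template(
#       [{"role": "user", "content": "Compute the determinant of [[1, 2], [3, 4]]."}],
#       tokenize=False,
#       add_generation_prompt=True,
#   )
#   inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#   outputs = model.generate(**inputs, max_new_tokens=256)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))
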
def add_special_tokens_and_resize(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
) -> bool:
    """
    Add special reasoning/tool-calling tokens to tokenizer and resize model embeddings if needed.

    Returns True if any new tokens were added (regardless of whether a resize was needed),
    False if no new tokens were added.
    """
    special_tags = [THINK_OPEN, THINK_CLOSE, TOOL_CALL_OPEN, TOOL_CALL_CLOSE, ANSWER_OPEN, ANSWER_CLOSE]
    num_added = tokenizer.add_special_tokens({"additional_special_tokens": special_tags})

    if num_added > 0:
        tok_vocab = len(tokenizer)
        model_vocab = model.get_input_embeddings().weight.size(0)

        # Mark embeddings as trainable so new token rows can be updated.
        model._need_to_train_embeddings = True

        if tok_vocab > model_vocab:
            pad_to_multiple_of = 128
            logger.info(
                "Added %s special tokens; resizing embeddings %s -> %s (padded to multiple of %s).",
                num_added,
                model_vocab,
                tok_vocab,
                pad_to_multiple_of,
            )
            model.resize_token_embeddings(tok_vocab, pad_to_multiple_of=pad_to_multiple_of)
            return True
        else:
            logger.info(
                "Added %s special tokens but model vocab (%s) already >= tokenizer vocab (%s); "
                "skipping embedding resize.",
                num_added,
                model_vocab,
                tok_vocab,
            )
            return True
    else:
        logger.info("No new special tokens added (tokens likely already present). Skipping resize.")
        return False


def load_merged_model_for_sft(
    model_path: str,
    max_seq_length: int = 2048,
    dtype: torch.dtype | None = None,
    train_io_only: bool = False,
    add_special_tokens: bool = False,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Load a merged (non-LoRA) model for a light SFT touch-up.

    - `model_path` should point to the merged checkpoint directory
      (e.g. "results/LinalgZero-SFT-merged").
    - If `train_io_only` is True, all parameters are frozen except:
        * input embeddings (`embed_tokens`)
        * output head (`lm_head` / output embeddings)
    - If `add_special_tokens` is True, adds reasoning/tool-calling tokens and resizes embeddings.
    """
    # Load with the Unsloth wrapper for consistent config handling
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=False,
        load_in_8bit=False,
    )

    # Make sure pad / eos are wired correctly before training
    ensure_tokenizer_has_defaults(tokenizer, model)

    # Optionally add special tokens and resize embeddings
    if add_special_tokens:
        add_special_tokens_and_resize(model, tokenizer)

    if train_io_only:
        # Freeze everything
        for param in model.parameters():
            param.requires_grad = False

        # Unfreeze embeddings
        for param in model.get_input_embeddings().parameters():
            param.requires_grad = True

        # Unfreeze LM head / output embeddings
        output_layer = getattr(model, "lm_head", None)
        if output_layer is None:
            output_layer = model.get_output_embeddings()
        for param in output_layer.parameters():
            param.requires_grad = True

    return model, tokenizer
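
# Example sketch (illustrative only): loading a merged checkpoint for an embeddings/LM-head-only
# touch-up and inspecting what remains trainable. The checkpoint path is a placeholder taken from
# the docstring example above; the inspection loop is an assumption, not part of the pipeline.
#
#   model, tokenizer = load_merged_model_for_sft(
#       "results/LinalgZero-SFT-merged",
#       train_io_only=True,
#       add_special_tokens=True,
#   )
#   trainable = [name for name, p in model.named_parameters() if p.requires_grad]
#   logger.info("Trainable parameters: %s", trainable)  # expect only embed_tokens / lm_head rows
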
def get_unsloth_model(
    model_args: SFTModelConfig,
    training_args: SFTRunConfig,
    trl_training_args: SFTConfig,
    resume_path: str | None = None,
    use_vllm: bool = False,
) -> tuple[FastLanguageModel, PreTrainedTokenizer]:
    """Fetch the model and tokenizer for training."""
    # Checkpoint loading is handled by the Trainer via `resume_from_checkpoint`.
    # We keep `resume_path` for API compatibility but do not use it here.
    if resume_path is not None:
        logger.info(
            "Received resume_path=%s in get_unsloth_model, but checkpoint loading is "
            "handled by the Trainer. Ignoring this argument.",
            resume_path,
        )

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_args.model_name_or_path,
        max_seq_length=training_args.max_seq_length,
        load_in_4bit=model_args.load_in_4bit,
        load_in_8bit=model_args.load_in_8bit,
        max_lora_rank=model_args.lora_r,
        # enforce_eager=model_args.enforce_eager,
        fast_inference=use_vllm,
        gpu_memory_utilization=training_args.gpu_memory_utilization,
    )

    # Add special tokens and resize embeddings
    has_added_tokens = False
    if training_args.add_special_tokens:
        has_added_tokens = add_special_tokens_and_resize(model, tokenizer)

    model = FastLanguageModel.get_peft_model(
        model,
        r=model_args.lora_r,
        modules_to_save=["embed_tokens", "lm_head"] if has_added_tokens else None,
        target_modules=model_args.lora_target_modules,
        lora_alpha=model_args.lora_alpha,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        ensure_weight_tying=True,
    )

    if trl_training_args.chat_template_path is not None:
        template_path = Path(trl_training_args.chat_template_path)
        tokenizer.chat_template = template_path.read_text()
    if training_args.chat_template is not None:
        tokenizer.chat_template = training_args.chat_template

    has_user_template = training_args.chat_template is not None
    has_config_template = trl_training_args.chat_template_path is not None
    assert has_user_template ^ has_config_template, (
        "Exactly one of tokenizer.chat_template or chat_template_path must be set, not both or neither"
    )

    return model, tokenizer


def get_model(model_args: ModelConfig, training_args: SFTRunConfig) -> AutoModelForCausalLM:
    """Get the model."""
    torch_dtype = model_args.torch_dtype
    if torch_dtype not in (None, "auto"):
        assert torch_dtype is not None
        torch_dtype = getattr(torch, torch_dtype)

    quantization_config = get_quantization_config(model_args)
    using_deepspeed = is_using_deepspeed()

    device_map = None
    if quantization_config is not None and not using_deepspeed:
        device_map = get_kbit_device_map()
        logger.info(f"Setting device_map: {device_map}")
    else:
        # A device_map is not compatible with DeepSpeed ZeRO-3, so only set it for
        # quantized, non-DeepSpeed runs.
        logger.info("Not setting device_map (DeepSpeed detected or no quantization)")

    model_kwargs = {
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": torch_dtype,
        "use_cache": not training_args.gradient_checkpointing,
        "device_map": device_map,
        "quantization_config": quantization_config,
    }

    if model_args.model_name_or_path is None:
        raise ValueError("model_name_or_path must be set for loading the model")

    model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    return model


def load_dataset(args: ScriptArguments) -> DatasetDict:
    """Load the dataset produced during the distillation step, removing unnecessary columns for SFT."""

    def remove_redundant_columns(dataset: DatasetDict) -> DatasetDict:
        """Remove columns that are not needed for SFT training."""
        if dataset.column_names:
            splits = dict(dataset.column_names.items())
            # Remove any redundant columns not used during SFT training.
            # Only 'tools' and 'messages' are relevant; deduplicate names across splits.
            redundant_columns = {
                col
                for split in splits.values()
                if split is not None
                for col in split
                if col not in ("tools", "messages")
            }
            dataset = dataset.remove_columns(sorted(redundant_columns))
        return dataset

    dataset = hf_load_dataset(args.dataset_name, args.dataset_config)

    if args.take_n is not None:
        # DatasetDict has no `select`, so subsample each split individually.
        dataset = DatasetDict({split: ds.select(range(args.take_n)) for split, ds in dataset.items()})

    # Only the ["messages", "tools"] columns are relevant for SFT
    return remove_redundant_columns(dataset)
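
# Example sketch (illustrative only): minimal wiring of the non-Unsloth loaders defined above.
# How the config objects are produced is an assumption (they would normally come from a CLI/YAML
# argument parser elsewhere in the repo), and the SFT trainer call itself is omitted.
#
#   init_wandb_training(training_args)
#   tokenizer = get_tokenizer(model_args, training_args)
#   model = get_model(model_args, training_args)
#   dataset = load_dataset(script_args)
#   # `dataset` now holds only the `messages` and `tools` columns needed for SFT.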