import logging
import os
from pathlib import Path
import torch
from datasets import DatasetDict
from datasets import load_dataset as hf_load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from trl.trainer.model_config import ModelConfig
from trl.trainer.utils import get_kbit_device_map, get_quantization_config
from unsloth import FastLanguageModel
from unsloth.tokenizer_utils import SFTConfig
from linalg_zero.config.data import ScriptArguments, SFTModelConfig, SFTRunConfig
from linalg_zero.shared.system_prompts import (
ANSWER_CLOSE,
ANSWER_OPEN,
THINK_CLOSE,
THINK_OPEN,
TOOL_CALL_CLOSE,
TOOL_CALL_OPEN,
)
# Module-level logger (stdlib convention: one logger per module, named after it).
logger = logging.getLogger(__name__)
def is_using_deepspeed() -> bool:
    """Return True when the environment indicates a DeepSpeed launch.

    Three independent signals are checked: a distributed LOCAL_RANK set by the
    launcher, the Accelerate DeepSpeed flag, or "deepspeed" appearing in the
    Accelerate config file path.
    """
    signals = (
        os.environ.get("LOCAL_RANK") is not None,
        os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true",
        "deepspeed" in os.environ.get("ACCELERATE_CONFIG_FILE", "").lower(),
    )
    return any(signals)
def ensure_tokenizer_has_defaults(tokenizer: PreTrainedTokenizer, model: PreTrainedModel) -> None:
    """Wire pad/eos token ids consistently across tokenizer, model config, and generation config.

    Args:
        tokenizer: Tokenizer to normalize; may lack a pad token.
        model: Model whose ``config`` / ``generation_config`` (when present)
            are synced to the tokenizer's pad/eos ids.
    """
    if getattr(tokenizer, "pad_token_id", None) is None:
        # Common causal-LM fallback: reuse EOS as the padding token.
        tokenizer.pad_token_id = tokenizer.eos_token_id
    if tokenizer.padding_side != "right":
        # Right padding is required for correct SFT loss masking.
        tokenizer.padding_side = "right"
    if getattr(model, "config", None) is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
    # Bind to a local instead of the original redundant `assert ... is not None`
    # (which was unreachable under the guard and stripped under `python -O`).
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        generation_config.pad_token_id = tokenizer.pad_token_id
        generation_config.eos_token_id = tokenizer.eos_token_id
def init_wandb_training(training_args: SFTRunConfig) -> None:
    """Initialize Weights & Biases for training logging."""
    try:
        # Map run-config fields onto the wandb environment variables; only
        # explicitly-configured values are exported.
        env_mapping = {
            "WANDB_ENTITY": training_args.wandb_entity,
            "WANDB_PROJECT": training_args.wandb_project,
            "WANDB_RUN_GROUP": training_args.wandb_run_group,
            "WANDB_RUN_ID": training_args.wandb_run_id,
        }
        for env_key, configured in env_mapping.items():
            if configured is not None:
                os.environ[env_key] = configured
        if training_args.wandb_run_id is not None:
            # A fixed run id implies the run may be resumed.
            os.environ["WANDB_RESUME"] = "allow"
        logger.info("Set wandb environment variables from training args")
    except Exception:
        # Best-effort setup: wandb misconfiguration must not abort training.
        logger.exception("Failed to initialize wandb environment")
def get_tokenizer(model_args: ModelConfig, training_args: SFTRunConfig) -> PreTrainedTokenizer:
    """Get the tokenizer for the model."""
    loaded: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    # A template from the run config overrides whatever shipped with the model.
    override_template = training_args.chat_template
    if override_template is not None:
        loaded.chat_template = override_template
    return loaded
def load_model_for_evaluation(
    model_path: str,
    max_seq_length: int = 2048,
    dtype: torch.dtype | None = None,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Load a trained model for evaluation/inference.
    """
    # Full-precision load: 4-bit quantization is disabled for evaluation.
    eval_model, eval_tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        load_in_4bit=False,
        max_seq_length=max_seq_length,
        dtype=dtype,
    )
    # Switch Unsloth into its inference-optimized mode before returning.
    FastLanguageModel.for_inference(eval_model)
    return eval_model, eval_tokenizer
def add_special_tokens_and_resize(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
) -> bool:
    """
    Add special reasoning/tool-calling tokens to tokenizer and resize model embeddings if needed.

    Returns True if any new tokens were added (regardless of whether a resize was needed),
    False if no new tokens were added.
    """
    special_tags = [THINK_OPEN, THINK_CLOSE, TOOL_CALL_OPEN, TOOL_CALL_CLOSE, ANSWER_OPEN, ANSWER_CLOSE]
    num_added = tokenizer.add_special_tokens({"additional_special_tokens": special_tags})
    # `add_special_tokens` returns an int, so a single `> 0` guard suffices
    # (the original `num_added and num_added > 0` double-check was redundant).
    if num_added <= 0:
        logger.info("No new special tokens added (tokens likely already present). Skipping resize.")
        return False
    tok_vocab = len(tokenizer)
    model_vocab = model.get_input_embeddings().weight.size(0)
    # Mark embeddings as trainable so new token rows can be updated.
    model._need_to_train_embeddings = True
    if tok_vocab > model_vocab:
        # Pad the embedding matrix to a multiple of 128 for GPU-friendly shapes.
        pad_to_multiple_of = 128
        logger.info(
            "Added %s special tokens; resizing embeddings %s -> %s (padded to multiple of %s).",
            num_added,
            model_vocab,
            tok_vocab,
            pad_to_multiple_of,
        )
        model.resize_token_embeddings(tok_vocab, pad_to_multiple_of=pad_to_multiple_of)
    else:
        logger.info(
            "Added %s special tokens but model vocab (%s) already >= tokenizer vocab (%s); "
            "skipping embedding resize.",
            num_added,
            model_vocab,
            tok_vocab,
        )
    return True
def load_merged_model_for_sft(
    model_path: str,
    max_seq_length: int = 2048,
    dtype: torch.dtype | None = None,
    train_io_only: bool = False,
    add_special_tokens: bool = False,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Load a merged (non-LoRA) model for a light SFT touch-up.

    - `model_path` should point to the merged checkpoint directory
      (e.g. "results/LinalgZero-SFT-merged").
    - If `train_io_only` is True, all parameters are frozen except:
        * input embeddings (`embed_tokens`)
        * output head (`lm_head` / output embeddings)
    - If `add_special_tokens` is True, adds reasoning/tool-calling tokens and resizes embeddings
    """
    # Load with Unsloth wrapper for consistent config handling; quantization
    # stays disabled because this is a merged full-precision checkpoint.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=False,
        load_in_8bit=False,
    )

    # Make sure pad / eos are wired correctly before training.
    ensure_tokenizer_has_defaults(tokenizer, model)

    # Optionally add special tokens and resize embeddings.
    if add_special_tokens:
        add_special_tokens_and_resize(model, tokenizer)

    if train_io_only:
        # Freeze the entire network first, then re-enable just the token I/O
        # layers (input embeddings and the output head).
        for parameter in model.parameters():
            parameter.requires_grad = False
        output_head = getattr(model, "lm_head", None)
        if output_head is None:
            output_head = model.get_output_embeddings()
        for io_module in (model.get_input_embeddings(), output_head):
            for parameter in io_module.parameters():
                parameter.requires_grad = True

    return model, tokenizer
def get_unsloth_model(
    model_args: SFTModelConfig,
    training_args: SFTRunConfig,
    trl_training_args: SFTConfig,
    resume_path: str | None = None,
    use_vllm: bool = False,
) -> tuple[FastLanguageModel, PreTrainedTokenizer]:
    """Fetch the model and optimizer for training.

    Raises:
        ValueError: if not exactly one of `training_args.chat_template` and
            `trl_training_args.chat_template_path` is configured.
    """
    # Validate the chat-template configuration up front. The original check ran
    # AFTER both templates had been applied: with both set, one silently
    # overwrote the other before the assert fired; with neither, the failure
    # only surfaced after the expensive model load. An explicit ValueError also
    # survives `python -O`, unlike `assert`.
    has_user_template = training_args.chat_template is not None
    has_config_template = trl_training_args.chat_template_path is not None
    if not (has_user_template ^ has_config_template):
        raise ValueError(
            "Exactly one of tokenizer.chat_template or chat_template_path must be set, not both or neither"
        )

    # Checkpoint loading is handled by the Trainer via `resume_from_checkpoint`.
    # We keep `resume_path` for API compatibility but do not use it here.
    if resume_path is not None:
        logger.info(
            "Received resume_path=%s in get_unsloth_model, but checkpoint loading is "
            "handled by the Trainer. Ignoring this argument.",
            resume_path,
        )

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_args.model_name_or_path,
        max_seq_length=training_args.max_seq_length,
        load_in_4bit=model_args.load_in_4bit,
        load_in_8bit=model_args.load_in_8bit,
        max_lora_rank=model_args.lora_r,
        # enforce_eager=model_args.enforce_eager,
        fast_inference=use_vllm,
        gpu_memory_utilization=training_args.gpu_memory_utilization,
    )

    # Add special tokens and resize embeddings.
    has_added_tokens = False
    if training_args.add_special_tokens:
        has_added_tokens = add_special_tokens_and_resize(model, tokenizer)

    model = FastLanguageModel.get_peft_model(
        model,
        r=model_args.lora_r,
        # When new tokens were added, embed/lm_head rows must stay trainable.
        modules_to_save=["embed_tokens", "lm_head"] if has_added_tokens else None,
        target_modules=model_args.lora_target_modules,
        lora_alpha=model_args.lora_alpha,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        ensure_weight_tying=True,
    )

    # Apply exactly one template source (validated above).
    if has_config_template:
        template_path = Path(trl_training_args.chat_template_path)
        tokenizer.chat_template = template_path.read_text()
    if has_user_template:
        tokenizer.chat_template = training_args.chat_template

    return model, tokenizer
def get_model(model_args: ModelConfig, training_args: SFTRunConfig) -> AutoModelForCausalLM:
    """Get the model.

    Raises:
        ValueError: if `model_args.model_name_or_path` is not set.
    """
    # Fail fast before doing any quantization / device-map work.
    if model_args.model_name_or_path is None:
        raise ValueError("model_name_or_path must be set for loading the model")

    torch_dtype = model_args.torch_dtype
    if isinstance(torch_dtype, str) and torch_dtype != "auto":
        # Resolve a string such as "bfloat16" to the torch.bfloat16 dtype.
        torch_dtype = getattr(torch, torch_dtype)

    quantization_config = get_quantization_config(model_args)
    using_deepspeed = is_using_deepspeed()

    device_map = None
    if quantization_config is not None and not using_deepspeed:
        device_map = get_kbit_device_map()
        logger.info("Setting device_map: %s", device_map)
    else:
        # device_map is not compatible with DeepSpeed ZeRO-3, and is only
        # needed when a k-bit quantization config is present.
        logger.info("Not setting device_map (DeepSpeed detected or no quantization)")

    model_kwargs = {
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": torch_dtype,
        # KV cache is useless (and wasteful) when gradient checkpointing is on.
        "use_cache": not training_args.gradient_checkpointing,
        "device_map": device_map,
        "quantization_config": quantization_config,
    }
    model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    return model
def load_dataset(args: ScriptArguments) -> DatasetDict:
    """Load the dataset produced during the distillation step, removing unnecessary columns for SFT.

    Args:
        args: Script arguments providing `dataset_name`, `dataset_config`,
            and an optional `take_n` subsampling limit.

    Returns:
        A DatasetDict whose splits keep only the 'tools' and 'messages' columns.
    """

    def remove_redundant_columns(dataset: DatasetDict) -> DatasetDict:
        """Drop every column except 'tools' and 'messages' from all splits."""
        if dataset.column_names:
            keep = {"tools", "messages"}
            # Deduplicate across splits: the same column name appears in every
            # split's column list, and remove_columns rejects repeated names.
            redundant = {
                col
                for split_cols in dataset.column_names.values()
                if split_cols is not None
                for col in split_cols
                if col not in keep
            }
            dataset = dataset.remove_columns(sorted(redundant))
        return dataset

    dataset = hf_load_dataset(args.dataset_name, args.dataset_config)
    if args.take_n is not None:
        # DatasetDict exposes no `.select()`, so subsample each split
        # individually; clamp to the split length so short splits don't raise.
        dataset = DatasetDict({
            split: ds.select(range(min(args.take_n, len(ds))))
            for split, ds in dataset.items()
        })

    # Only the ["messages", "tools"] columns are relevant for SFT.
    return remove_redundant_columns(dataset)
|