"""Instruction fine-tuning (SFT) script: LoRA/QLoRA training with HF Trainer.

Pipeline: load a base model (local dir or HF hub repo) -> format instruction
data (chatml/alpaca/custom) with response-only loss masking -> train a PEFT
adapter -> optionally merge the adapter back into the base model.

Driven by a YAML config; supports JSONL logging, optional Weights & Biases,
early stopping, and checkpoint resumption.
"""

import argparse
import inspect  # Added for Transformers version compatibility
import json
import math
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import yaml
from datasets import DatasetDict, load_dataset
from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    EarlyStoppingCallback,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import (
    LoraConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)

try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    wandb = None


# --------------------------
# Helpers
# --------------------------
def _dtype_from_str(s: str) -> torch.dtype:
    """Map a config string (e.g. 'bf16', 'float16') to a torch dtype.

    Raises:
        ValueError: if the string is not a recognized dtype alias.
    """
    s = (s or "").lower()
    if s in ("float16", "fp16"):
        return torch.float16
    if s in ("bfloat16", "bf16"):
        return torch.bfloat16
    if s in ("float32", "fp32"):
        return torch.float32
    raise ValueError(f"Unknown torch_dtype: {s}")


def _now_iso() -> str:
    """Return the current local time as an ISO-8601-like timestamp string."""
    return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())


def _safe_exp(x: float) -> float:
    """exp(x) clamped at exp(50) to avoid overflow when computing perplexity."""
    x = min(float(x), 50.0)
    return float(math.exp(x))


def _ensure_dir(p: Path) -> Path:
    """Create directory `p` (and parents) if missing; return it."""
    p.mkdir(parents=True, exist_ok=True)
    return p


def _looks_like_model_dir(p: Path) -> bool:
    """Heuristic: does `p` look like a Hugging Face model/adapter directory?

    True if it contains a config.json or any weight files (*.safetensors /
    pytorch_model*.bin). Also matches PEFT adapter dirs via adapter weights.
    """
    if not p.exists() or not p.is_dir():
        return False
    if (p / "config.json").exists():
        return True
    if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
        return True
    return False


def _infer_target_modules(model) -> List[str]:
    """Auto-detect LoRA target module names from the model's module names.

    Tries known attention-projection groups first (LLaMA-style, MPT-style,
    GPT-NeoX-style, GPT-2-style); falls back to any individually present
    names; raises if nothing matches.

    Raises:
        ValueError: when no known module names are found.
    """
    names = set()
    for n, _ in model.named_modules():
        # Only the leaf name matters for PEFT's target_modules matching.
        names.add(n.split(".")[-1])
    for group in [
        ["q_proj", "k_proj", "v_proj", "o_proj"],
        ["Wqkv", "out_proj"],
        ["query_key_value", "dense"],
        ["c_attn", "c_proj"],
    ]:
        if all(x in names for x in group):
            return group
    fallback = [
        x
        for x in [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "c_attn",
            "c_proj",
            "out_proj",
            "dense",
        ]
        if x in names
    ]
    if fallback:
        return fallback
    raise ValueError(
        "Could not auto-infer target_modules. Set peft.target_modules explicitly."
    )


def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
    """Return model.attn_implementation from config, or None if unset."""
    return cfg.get("model", {}).get("attn_implementation", None)


# --------------------------
# Wandb Integration
# --------------------------
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
    """Initialize Wandb if enabled in configuration.

    Returns the `wandb` module on success, else None (disabled, not
    installed, or init failure — all are non-fatal).
    """
    wandb_cfg = cfg.get("wandb", {})
    if not wandb_cfg.get("enabled", False):
        print("Wandb logging disabled")
        return None
    if not WANDB_AVAILABLE:
        print("Wandb not available. Install with: pip install wandb")
        return None

    # Extract wandb configuration
    project = wandb_cfg.get("project", "sft-training")
    entity = wandb_cfg.get("entity", None)
    name = wandb_cfg.get("name", None)
    tags = wandb_cfg.get("tags", [])
    notes = wandb_cfg.get("notes", None)

    # Initialize wandb; failures are logged and swallowed so training
    # proceeds without remote logging.
    try:
        wandb.init(
            project=project,
            entity=entity,
            name=name,
            tags=tags,
            notes=notes,
            dir=str(run_dir),
            config={
                "model": cfg.get("model", {}),
                "data": cfg.get("data", {}),
                "peft": cfg.get("peft", {}),
                "train": cfg.get("train", {}),
                "run_dir": str(run_dir),
            },
        )
        print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
        return wandb
    except Exception as e:
        print(f"Failed to initialize Wandb: {e}")
        return None


def finish_wandb():
    """Finish Wandb run if active."""
    if WANDB_AVAILABLE and wandb.run is not None:
        wandb.finish()
        print("Wandb run finished")


# --------------------------
# JSONL Logger Callback
# --------------------------
class JsonlLoggerCallback(TrainerCallback):
    """Trainer callback that appends train/eval events to JSONL files.

    Writes `logs/train.jsonl` (per-log-step metrics, progress %, ETA) and
    `logs/eval.jsonl` (eval metrics plus perplexity) under `run_dir`.
    """

    def __init__(self, run_dir: Path):
        self.run_dir = run_dir
        self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
        self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
        self.start_time = None  # set at on_train_begin; basis for ETA

    def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
        """Estimate remaining wall time as 'HH:MM:SS', or None if unknown."""
        if self.start_time is None or global_step <= 0 or max_steps <= 0:
            return None
        elapsed = time.time() - self.start_time
        sec_per_step = elapsed / global_step
        remaining = max(0, max_steps - global_step) * sec_per_step
        h = int(remaining // 3600)
        m = int((remaining % 3600) // 60)
        s = int(remaining % 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
        progress_pct = (
            (100.0 * state.global_step / max_steps) if max_steps > 0 else None
        )
        epoch_pct = None
        if (
            state.epoch is not None
            and args.num_train_epochs
            and args.num_train_epochs > 0
        ):
            epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
        payload = {
            "ts": _now_iso(),
            "event": "train_log",
            "step": int(state.global_step),
            "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
            "progress_pct": (
                round(progress_pct, 2) if progress_pct is not None else None
            ),
            "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
            "eta": self._eta(int(state.global_step), max_steps),
            "max_grad_norm": getattr(args, "max_grad_norm", None),
            **logs,
        }
        with self.train_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics:
            return
        eval_loss = metrics.get("eval_loss", None)
        # Perplexity = exp(loss), clamped to avoid overflow on divergence.
        ppl = _safe_exp(eval_loss) if eval_loss is not None else None
        payload = {
            "ts": _now_iso(),
            "event": "eval",
            "step": int(state.global_step),
            "epoch": float(state.epoch) if state.epoch is not None else None,
            **metrics,
            "perplexity": ppl,
        }
        with self.eval_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")


# --------------------------
# Data Pipeline (Instruction Formatting)
# --------------------------
def format_instruction(
    example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
) -> Dict[str, Any]:
    """
    Format instruction data for training.
    Supports multiple formats: chatml, alpaca, custom templates.
    Returns both formatted text and the response start position for loss masking.
    """
    data_cfg = cfg["data"]
    format_type = data_cfg.get("format_type", "chatml")

    # Get field names from config
    input_field = data_cfg.get("input_field", "input")
    output_field = data_cfg.get("output_field", "output")
    instruction_field = data_cfg.get("instruction_field", "instruction")

    # Extract text from example
    instruction = example.get(instruction_field, "")
    input_text = example.get(input_field, "")
    output_text = example.get(output_field, "")

    if format_type == "chatml":
        # ChatML format with special tokens
        system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        user_content = instruction
        if input_text:
            user_content = f"{instruction}\n\n{input_text}"
        messages.append({"role": "user", "content": user_content})
        messages.append({"role": "assistant", "content": output_text})

        # Apply chat template
        formatted_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        # Add EOS token if not present
        if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
            formatted_text += tokenizer.eos_token

        # Find where the assistant response starts for loss masking
        # Try multiple possible markers for robustness
        markers = ["<|im_start|>assistant", "<|assistant|>", "Assistant:", "assistant\n"]
        response_start_pos = -1
        for marker in markers:
            idx = formatted_text.find(marker)
            if idx != -1:
                # Find the newline after the marker
                newline_idx = formatted_text.find("\n", idx)
                if newline_idx != -1:
                    response_start_pos = newline_idx + 1
                    break
        # Fallback: find where the actual output starts
        if response_start_pos == -1:
            output_idx = formatted_text.find(output_text)
            if output_idx != -1:
                response_start_pos = output_idx
            else:
                # Last resort: split at last occurrence of newline before end
                response_start_pos = formatted_text.rfind(
                    "\n", 0, len(formatted_text) - len(output_text)
                ) + 1

    elif format_type == "alpaca":
        # Alpaca format
        if input_text:
            prefix = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
        else:
            prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
        formatted_text = prefix + output_text
        # Add EOS token
        if tokenizer.eos_token:
            formatted_text += tokenizer.eos_token
        # Response starts after the prefix
        response_start_pos = len(prefix)

    elif format_type == "custom":
        # Custom template from config
        template = data_cfg.get("custom_template", "{instruction}\n{input}\n{output}")
        # For custom format, use system_prompt as instruction if instruction field is empty
        if not instruction:
            instruction = data_cfg.get("system_prompt", "")
        # For custom templates, we need to find where {output} starts
        template_parts = template.split("{output}")
        prefix = template_parts[0].format(instruction=instruction, input=input_text)
        formatted_text = prefix + output_text
        # Add EOS token if not already in template
        if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
            formatted_text += tokenizer.eos_token
        # Response starts after the prefix
        response_start_pos = len(prefix)

    else:
        raise ValueError(f"Unsupported format_type: {format_type}")

    return {"text": formatted_text, "response_start_pos": response_start_pos}


def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
    """
    Build datasets for instruction fine-tuning.

    Loads JSONL train (and optional eval) data, formats each example via
    `format_instruction`, tokenizes, and masks prompt tokens with -100 so
    the loss is computed only over the response.

    Returns:
        (train_dataset, eval_dataset) — eval_dataset may be None.
    """
    data_cfg = cfg["data"]
    train_path = data_cfg["train_jsonl"]
    eval_path = data_cfg.get("eval_jsonl", None)
    split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
    max_length = int(data_cfg.get("max_length", 2048))
    shuffle = bool(data_cfg.get("shuffle", True))
    num_proc = int(data_cfg.get("num_proc", 4))

    # Ensure tokenizer has pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load datasets
    ds = load_dataset("json", data_files={"train": train_path})
    if eval_path:
        ds_eval = load_dataset("json", data_files={"eval": eval_path})
        dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
    else:
        if 0.0 < split_ratio < 1.0:
            split = ds["train"].train_test_split(
                test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
            )
            dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
        else:
            dsd = DatasetDict({"train": ds["train"], "eval": None})

    # Format instructions and track response start positions
    def format_fn(examples):
        formatted_examples = []
        response_start_positions = []
        for i in range(len(examples[list(examples.keys())[0]])):
            example = {k: examples[k][i] for k in examples.keys()}
            formatted = format_instruction(example, cfg, tokenizer)
            formatted_examples.append(formatted["text"])
            response_start_positions.append(formatted["response_start_pos"])
        return {
            "text": formatted_examples,
            "response_start_pos": response_start_positions,
        }

    formatted_train = dsd["train"].map(
        format_fn,
        batched=True,
        num_proc=num_proc,
        remove_columns=dsd["train"].column_names,
        desc="Formatting train instructions",
    )
    formatted_eval = None
    if dsd["eval"] is not None:
        formatted_eval = dsd["eval"].map(
            format_fn,
            batched=True,
            num_proc=num_proc,
            remove_columns=dsd["eval"].column_names,
            desc="Formatting eval instructions",
        )

    # Tokenize and apply loss masking
    def tokenize_and_mask_fn(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding=False,
            max_length=max_length,
            return_overflowing_tokens=False,
        )
        # Apply loss masking - CRITICAL for SFT
        labels = []
        attention_masks = []
        for i in range(len(tokenized["input_ids"])):
            input_ids = tokenized["input_ids"][i]
            response_start_pos = examples["response_start_pos"][i]
            # Get the instruction part (before response)
            full_text = examples["text"][i]
            instruction_text = full_text[:response_start_pos]
            # Create labels masked by default
            label_ids = [-100] * len(input_ids)
            # Find where response starts using character-based ratio
            # This is more reliable than tokenizing prefix separately
            # because separate tokenization can add different special tokens
            char_ratio = response_start_pos / max(len(full_text), 1)
            response_start_idx = int(len(input_ids) * char_ratio)
            # Ensure we have valid bounds (at least position 1, at most len-1)
            response_start_idx = max(1, min(response_start_idx, len(input_ids) - 1))
            # Unmask response tokens (including EOS)
            for j in range(response_start_idx, len(input_ids)):
                label_ids[j] = input_ids[j]
            # Create attention mask (1 for real tokens, 0 for padding)
            attention_mask = [1] * len(input_ids)
            labels.append(label_ids)
            attention_masks.append(attention_mask)
        tokenized["labels"] = labels
        tokenized["attention_mask"] = attention_masks
        return tokenized

    tokenized_train = formatted_train.map(
        tokenize_and_mask_fn,
        batched=True,
        num_proc=num_proc,
        desc="Tokenizing and masking train",
    )
    tokenized_eval = None
    if formatted_eval is not None:
        tokenized_eval = formatted_eval.map(
            tokenize_and_mask_fn,
            batched=True,
            num_proc=num_proc,
            desc="Tokenizing and masking eval",
        )
    if shuffle:
        tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
    return tokenized_train, tokenized_eval


# --------------------------
# Model Loading + PEFT
# --------------------------
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
    """Load tokenizer + base model from `base_dir`, honoring quantization,
    dtype, attention implementation, and multimodal (Mistral3) handling.

    Returns:
        (model, tokenizer)
    """
    model_cfg = cfg["model"]
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
    device_map = model_cfg.get("device_map", "auto")

    tokenizer = AutoTokenizer.from_pretrained(
        str(base_dir),
        use_fast=use_fast,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))
    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
            bnb_4bit_use_double_quant=bool(
                model_cfg.get("bnb_4bit_use_double_quant", True)
            ),
            bnb_4bit_compute_dtype=_dtype_from_str(
                model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
            ),
        )
    attn_impl = _choose_attn_impl(cfg)

    # First check the model type to determine loading strategy
    try:
        config = AutoConfig.from_pretrained(str(base_dir), trust_remote_code=True)
        model_type = config.model_type
        architectures = getattr(config, 'architectures', [])
        # Handle Mistral3 (multimodal) models
        if model_type == "mistral3" or (architectures and "Mistral3" in architectures[0]):
            print(f"[info] Detected Mistral3 model architecture, loading with specific class")
            from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
            try:
                model = Mistral3ForConditionalGeneration.from_pretrained(
                    str(base_dir),
                    config=config,
                    device_map=device_map,
                    low_cpu_mem_usage=True,
                    # bitsandbytes sets its own compute dtype under 4-bit
                    torch_dtype=(torch_dtype if not use_4bit else None),
                    quantization_config=quant_cfg,
                    attn_implementation=attn_impl,
                )
            except Exception as e:
                if attn_impl is not None:
                    print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
                    print("[warn] Falling back to default attention implementation.")
                    model = Mistral3ForConditionalGeneration.from_pretrained(
                        str(base_dir),
                        config=config,
                        device_map=device_map,
                        low_cpu_mem_usage=True,
                        torch_dtype=(torch_dtype if not use_4bit else None),
                        quantization_config=quant_cfg,
                    )
                else:
                    raise e
        else:
            # Standard AutoModelForCausalLM loading for other models
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    str(base_dir),
                    device_map=device_map,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    torch_dtype=(torch_dtype if not use_4bit else None),
                    quantization_config=quant_cfg,
                    attn_implementation=attn_impl,
                )
            except Exception as e:
                if attn_impl is not None:
                    print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
                    print("[warn] Falling back to default attention implementation.")
                    model = AutoModelForCausalLM.from_pretrained(
                        str(base_dir),
                        device_map=device_map,
                        trust_remote_code=True,
                        low_cpu_mem_usage=True,
                        torch_dtype=(torch_dtype if not use_4bit else None),
                        quantization_config=quant_cfg,
                    )
                else:
                    raise e
    except Exception as e:
        print(f"[error] Failed to load model: {e}")
        raise e

    # Ensure all parameters are off meta device
    print("[info] Ensuring all parameters are materialized...")
    meta_params = []
    for name, param in model.named_parameters():
        if param.device.type == 'meta':
            meta_params.append(name)
    if meta_params:
        print(f"[warn] Found {len(meta_params)} parameters on meta device")

    # For multimodal models, freeze vision components if doing text-only training
    if hasattr(model, 'vision_tower'):
        print("[info] Freezing vision tower for text-only training")
        for param in model.vision_tower.parameters():
            param.requires_grad = False

    return model, tokenizer


def apply_peft(cfg: Dict[str, Any], model):
    """Wrap `model` with a LoRA adapter per cfg['peft'].

    Also enables gradient checkpointing and 4-bit (k-bit) training prep
    when configured. Returns (model, lora_config); lora_config is None
    when PEFT is disabled.
    """
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]
    if not bool(peft_cfg.get("enabled", True)):
        return model, None

    use_4bit = bool(model_cfg.get("use_4bit", False))
    gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))

    # For multimodal models, ensure vision tower doesn't use gradient checkpointing
    if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        if hasattr(model, 'vision_tower'):
            print("[info] Disabling gradient checkpointing for vision tower")
            # Only enable gradient checkpointing on language model
            if hasattr(model, 'language_model'):
                model.language_model.gradient_checkpointing_enable()
            elif hasattr(model, 'lm_head'):
                model.gradient_checkpointing_enable()
        else:
            model.gradient_checkpointing_enable()
        if hasattr(model, "config"):
            # use_cache is incompatible with gradient checkpointing
            model.config.use_cache = False

    if use_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )

    target_modules = peft_cfg.get("target_modules", "auto")
    if target_modules == "auto":
        target_modules = _infer_target_modules(model)

    # For multimodal models, ensure we only target language model modules
    if hasattr(model, 'vision_tower') and isinstance(target_modules, list):
        print(f"[info] Filtering target modules to exclude vision tower")
        # Filter out any vision tower modules
        target_modules = [m for m in target_modules if 'vision' not in m.lower()]

    print(f"[info] LoRA target modules: {target_modules}")
    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=target_modules,
        modules_to_save=None,  # Don't update any additional modules
    )
    model = get_peft_model(model, lora_config)
    return model, lora_config


# --------------------------
# Merge Logic
# --------------------------
def merge_adapter(
    cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
):
    """Merge a trained LoRA adapter into the base model and save to `final_dir`.

    Loads the base on CPU in the configured merged dtype, applies the
    adapter, merges + unloads, and writes sharded safetensors plus the
    tokenizer.
    """
    print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
    model_cfg = cfg["model"]
    merge_cfg = cfg.get("merge", {})
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))

    base = AutoModelForCausalLM.from_pretrained(
        str(base_dir),
        torch_dtype=merged_dtype,
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=trust_remote_code,
    )
    merged = PeftModel.from_pretrained(base, str(adapter_dir))
    merged = merged.merge_and_unload()

    _ensure_dir(final_dir)
    merged.save_pretrained(
        str(final_dir),
        safe_serialization=True,
        max_shard_size=max_shard_size,
    )
    tok = AutoTokenizer.from_pretrained(
        str(base_dir), trust_remote_code=trust_remote_code
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.save_pretrained(str(final_dir))
    print("--- Merge complete ---")


# --------------------------
# Main
# --------------------------
def main():
    """CLI entry point: parse args + YAML config, then train and/or merge."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument(
        "--merge-only", action="store_true", help="Skip training, just merge adapter"
    )
    args = ap.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")
    # Persist the resolved config alongside the run for reproducibility.
    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)

    model_cfg = cfg["model"]
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)

    # Local model path -> load directly; no download
    if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
        base_dir = repo_path
        print(f"Using local model at: {base_dir}")
    elif repo_path.exists() and repo_path.is_dir():
        # BUGFIX: was f"... {base_dir}" — base_dir is unbound on this path.
        raise ValueError(
            f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {repo_path}"
        )
    else:
        # HF repo_id -> download into run_dir/base_local_dir
        base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
        if not _looks_like_model_dir(base_dir):
            print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
            snapshot_download(
                repo_id=repo_id,
                revision=model_cfg.get("revision", None),
                local_dir=str(base_dir),
                local_dir_use_symlinks=False,
            )

    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"

    # Merge-only
    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return

    # Initialize Wandb
    wandb_run = setup_wandb(cfg, run_dir)

    # Training
    set_seed(int(cfg["run"].get("seed", 42)))
    model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
    model, _ = apply_peft(cfg, model)
    train_ds, eval_ds = build_datasets(cfg, tokenizer)

    tr_cfg = cfg["train"]
    dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_fp16 = dtype == torch.float16
    use_bf16 = dtype == torch.bfloat16
    max_steps = int(tr_cfg.get("max_steps", 0))
    num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))

    # --- Dynamic evaluation strategy parameter handling ---
    # Transformers renamed `evaluation_strategy` -> `eval_strategy`; pick
    # whichever this installed version accepts.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = (
        "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
    )

    # Setup reporting based on wandb availability
    report_to = []
    if wandb_run is not None:
        report_to.append("wandb")

    ta_kwargs = dict(
        output_dir=str(ckpt_dir),
        max_steps=max_steps if max_steps > 0 else -1,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
        per_device_eval_batch_size=int(
            tr_cfg.get(
                "per_device_eval_batch_size",
                tr_cfg.get("per_device_train_batch_size", 1),
            )
        ),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
        learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
        optim=str(
            tr_cfg.get(
                "optim",
                (
                    "paged_adamw_8bit"
                    if bool(model_cfg.get("use_4bit", False))
                    else "adamw_torch"
                ),
            )
        ),
        max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
        logging_steps=int(tr_cfg.get("logging_steps", 10)),
        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        save_steps=int(tr_cfg.get("save_steps", 200)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
        eval_steps=int(tr_cfg.get("eval_steps", 200)),
        load_best_model_at_end=(
            bool(tr_cfg.get("load_best_model_at_end", True))
            if eval_ds is not None
            else False
        ),
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to=report_to,
        remove_unused_columns=False,
    )
    # Set the correct argument name for this transformers version
    ta_kwargs[eval_key] = str(
        tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
    )
    training_args = TrainingArguments(**ta_kwargs)

    # Setup callbacks
    callbacks = [JsonlLoggerCallback(run_dir)]

    # Add early stopping callback if enabled
    early_stopping_cfg = tr_cfg.get("early_stopping", {})
    if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
            early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
        )
        callbacks.append(early_stopping_callback)
        print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
              f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=default_data_collator,
        callbacks=callbacks,
    )

    # Resume
    resume_from = tr_cfg.get("resume_from_checkpoint", None)
    if resume_from == "auto":
        last = get_last_checkpoint(str(ckpt_dir))
        resume_from = last if last else None
    if resume_from:
        print(f"Resuming from {resume_from}")

    print("Starting instruction fine-tuning...")
    trainer.train(resume_from_checkpoint=resume_from)
    trainer.save_model(str(best_adapter_dir))
    print(f"Saved best adapter -> {best_adapter_dir}")

    if eval_ds is not None:
        metrics = trainer.evaluate()
        eval_loss = metrics.get("eval_loss", None)
        metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
        with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")

    if bool(cfg.get("merge", {}).get("enabled", False)):
        # Free trainer/model memory before loading the base model for merging.
        del trainer, model
        torch.cuda.empty_cache()
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
    else:
        print("Merge disabled. Run with --merge-only later if needed.")

    # Finish Wandb run
    finish_wandb()


if __name__ == "__main__":
    main()