import argparse
import os
import subprocess
from pathlib import Path

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    set_seed,
)

from config import PATHS, TRAINING_CONFIG
from dataset import LocalJsonlInstructionDataset
from utils import ensure_dirs, setup_logger


# ==============================
# šŸ”„ BACKUP CALLBACK
# ==============================
class BackupCallback(TrainerCallback):
    """Archive each freshly saved checkpoint into backups/ as a .tar.gz."""

    def on_save(self, args, state, control, **kwargs):
        try:
            checkpoint_dir = os.path.join(
                args.output_dir, f"checkpoint-{state.global_step}"
            )
            if not os.path.exists(checkpoint_dir):
                return

            os.makedirs("backups", exist_ok=True)
            backup_name = f"backup_step{state.global_step}.tar.gz"
            backup_path = os.path.join("backups", backup_name)

            print(f"\n[BACKUP] Creating backup for step {state.global_step}...")
            subprocess.run(
                ["tar", "-czf", backup_path, checkpoint_dir],
                check=True,
            )
            print(f"[BACKUP] Saved: {backup_path}")
        except Exception as e:
            # A failed backup should never abort training.
            print(f"[BACKUP ERROR] {e}")


# ==============================
# MODEL PATH RESOLUTION
# ==============================
def _is_valid_hf_model_dir(path: Path) -> bool:
    return path.exists() and (path / "config.json").exists()


def _resolve_model_path(logger) -> Path:
    primary = PATHS.model_dir
    fallback = Path("./hf_release/MINDI-1.0-420M")

    if _is_valid_hf_model_dir(primary):
        return primary

    if _is_valid_hf_model_dir(fallback):
        logger.warning(
            "Primary model missing → using fallback %s",
            fallback.resolve(),
        )
        return fallback

    raise FileNotFoundError("No valid model directory found.")


# ==============================
# MODEL + TOKENIZER
# ==============================
def _build_model_and_tokenizer(model_path: Path):
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        local_files_only=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        local_files_only=True,
        use_safetensors=True,  # load weights from safetensors, never pickle-based .bin
    )

    # Wrap the base model with LoRA adapters on every linear layer.
    lora_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules="all-linear",
    )
    model = get_peft_model(model, lora_cfg)
    return model, tokenizer


# ==============================
# CHECKPOINT RESUME (SAFE)
# ==============================
def get_latest_checkpoint(checkpoint_dir):
    if not os.path.exists(checkpoint_dir):
        return None

    checkpoints = [
        d
        for d in os.listdir(checkpoint_dir)
        # Guard against stray "checkpoint-*" dirs without a numeric step suffix.
        if d.startswith("checkpoint-") and d.split("-")[-1].isdigit()
    ]
    if not checkpoints:
        return None

    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))
    return os.path.join(checkpoint_dir, checkpoints[-1])


def safe_train(trainer, checkpoint_dir, logger, resume=True):
    latest_checkpoint = get_latest_checkpoint(checkpoint_dir) if resume else None

    if latest_checkpoint:
        logger.info(f"Trying resume from: {latest_checkpoint}")
        try:
            trainer.train(resume_from_checkpoint=latest_checkpoint)
            return
        except Exception as e:
            logger.warning(f"Resume failed → starting fresh: {e}")

    trainer.train()


# ==============================
# MAIN TRAIN
# ==============================
def train(resume: bool):
    ensure_dirs([
        PATHS.data_dir,
        PATHS.output_dir,
        PATHS.logs_dir,
        PATHS.checkpoint_dir,
        PATHS.lora_output_dir,
        PATHS.tokenizer_output_dir,
    ])

    logger = setup_logger("train", PATHS.logs_dir / "train.log")
    set_seed(42)

    model_path = _resolve_model_path(logger)
    logger.info("Loading model from %s", model_path)

    model, tokenizer = _build_model_and_tokenizer(model_path)
    model.print_trainable_parameters()

    train_dataset = LocalJsonlInstructionDataset(
        tokenizer, max_length=TRAINING_CONFIG.max_length
    )

    training_args = TrainingArguments(
        output_dir=str(PATHS.checkpoint_dir),
        num_train_epochs=TRAINING_CONFIG.num_train_epochs,
        per_device_train_batch_size=TRAINING_CONFIG.per_device_train_batch_size,
        gradient_accumulation_steps=TRAINING_CONFIG.gradient_accumulation_steps,
        learning_rate=TRAINING_CONFIG.learning_rate,
        fp16=torch.cuda.is_available(),
        logging_steps=50,
        save_steps=250,
        save_total_limit=3,
        report_to="none",
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        callbacks=[BackupCallback()],
    )

    logger.info("Starting training...")
    # Honor the --no-resume flag: previously `resume` was parsed but never used.
    safe_train(trainer, str(PATHS.checkpoint_dir), logger, resume=resume)

    trainer.model.save_pretrained(str(PATHS.lora_output_dir))
    tokenizer.save_pretrained(str(PATHS.tokenizer_output_dir))
    print("\nāœ… Training complete.")


# ==============================
# ENTRY
# ==============================
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--no-resume", action="store_true")
    args = parser.parse_args()
    train(resume=not args.no_resume)
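
# ------------------------------------------------------------------
# Usage sketch (assumes this script is saved as train.py, with config.py,
# dataset.py, and utils.py from the same project importable alongside it):
#
#   python train.py              # resume from the latest checkpoint, if any
#   python train.py --no-resume  # ignore existing checkpoints, train from scratch
#
# Checkpoints land in PATHS.checkpoint_dir every 250 steps (3 kept), with a
# .tar.gz copy of each under backups/.
# ------------------------------------------------------------------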