"""LoRA fine-tuning entry point: loads a local base model, trains on a local
JSONL instruction dataset, backs up checkpoints, and saves the adapter weights."""

import argparse
import os
import subprocess
from pathlib import Path

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    set_seed,
)

from config import PATHS, TRAINING_CONFIG
from dataset import LocalJsonlInstructionDataset
from utils import ensure_dirs, setup_logger


class BackupCallback(TrainerCallback):
    """Compress each saved checkpoint into backups/ right after the Trainer saves it."""

    def on_save(self, args, state, control, **kwargs):
        try:
            checkpoint_dir = os.path.join(
                args.output_dir,
                f"checkpoint-{state.global_step}",
            )

            if not os.path.exists(checkpoint_dir):
                return

            os.makedirs("backups", exist_ok=True)

            backup_name = f"backup_step{state.global_step}.tar.gz"
            backup_path = os.path.join("backups", backup_name)

            print(f"\n[BACKUP] Creating backup for step {state.global_step}...")

            subprocess.run(
                ["tar", "-czf", backup_path, checkpoint_dir],
                check=True,
            )

            print(f"[BACKUP] Saved: {backup_path}")

        except Exception as e:
            # A failed backup is reported but never aborts training.
            print(f"[BACKUP ERROR] {e}")
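

# Note: the backup above shells out to a `tar` binary, so it assumes a POSIX-like
# environment. A pure-Python alternative (not used here) would be
# shutil.make_archive(archive_base, "gztar", root_dir=checkpoint_dir), where
# archive_base is a hypothetical name for the output path minus its .tar.gz suffix.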


def _is_valid_hf_model_dir(path: Path) -> bool:
    """A directory is a usable HF model dir if it exists and contains config.json."""
    return path.exists() and (path / "config.json").exists()


def _resolve_model_path(logger) -> Path:
    """Return the primary model directory, or the bundled fallback release if the
    primary is missing; raise if neither is usable."""
    primary = PATHS.model_dir
    fallback = Path("./hf_release/MINDI-1.0-420M")

    if _is_valid_hf_model_dir(primary):
        return primary

    if _is_valid_hf_model_dir(fallback):
        logger.warning(
            "Primary model missing → using fallback %s",
            fallback.resolve(),
        )
        return fallback

    raise FileNotFoundError("No valid model directory found.")


def _build_model_and_tokenizer(model_path: Path):
    """Load the local tokenizer and base model, then wrap the model in LoRA adapters."""
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        local_files_only=True,
    )

    # Many causal LM tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_safetensors=True,
    )

    # Attach LoRA adapters to every linear layer; only the adapter weights are trained.
    lora_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules="all-linear",
    )

    model = get_peft_model(model, lora_cfg)
    return model, tokenizer


def get_latest_checkpoint(checkpoint_dir):
    """Return the highest-numbered checkpoint-* directory in checkpoint_dir, or None."""
    if not os.path.exists(checkpoint_dir):
        return None

    checkpoints = [
        d for d in os.listdir(checkpoint_dir)
        if d.startswith("checkpoint-")
    ]

    if not checkpoints:
        return None

    # Sort numerically by global step so checkpoint-1000 sorts after checkpoint-999.
    checkpoints.sort(key=lambda x: int(x.split("-")[-1]))

    return os.path.join(checkpoint_dir, checkpoints[-1])


def safe_train(trainer, checkpoint_dir, logger, resume=True):
    """Resume from the newest checkpoint when requested; fall back to a fresh run
    if no checkpoint exists or resuming fails."""
    if resume:
        latest_checkpoint = get_latest_checkpoint(checkpoint_dir)

        if latest_checkpoint:
            logger.info(f"Trying resume from: {latest_checkpoint}")
            try:
                trainer.train(resume_from_checkpoint=latest_checkpoint)
                return
            except Exception as e:
                logger.warning(f"Resume failed → starting fresh: {e}")

    trainer.train()


def train(resume: bool):
    """Run LoRA fine-tuning end to end and save the adapter and tokenizer."""
    ensure_dirs([
        PATHS.data_dir,
        PATHS.output_dir,
        PATHS.logs_dir,
        PATHS.checkpoint_dir,
        PATHS.lora_output_dir,
        PATHS.tokenizer_output_dir,
    ])

    logger = setup_logger("train", PATHS.logs_dir / "train.log")
    set_seed(42)

    model_path = _resolve_model_path(logger)
    logger.info("Loading model from %s", model_path)

    model, tokenizer = _build_model_and_tokenizer(model_path)
    model.print_trainable_parameters()

    train_dataset = LocalJsonlInstructionDataset(
        tokenizer,
        max_length=TRAINING_CONFIG.max_length,
    )

    training_args = TrainingArguments(
        output_dir=str(PATHS.checkpoint_dir),
        num_train_epochs=TRAINING_CONFIG.num_train_epochs,
        per_device_train_batch_size=TRAINING_CONFIG.per_device_train_batch_size,
        gradient_accumulation_steps=TRAINING_CONFIG.gradient_accumulation_steps,
        learning_rate=TRAINING_CONFIG.learning_rate,
        fp16=torch.cuda.is_available(),
        logging_steps=50,
        save_steps=250,
        save_total_limit=3,
        report_to="none",
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        callbacks=[BackupCallback()],
    )

    logger.info("Starting training...")

    safe_train(trainer, str(PATHS.checkpoint_dir), logger, resume=resume)

    # Save only the LoRA adapter weights and the tokenizer, not the full base model.
    trainer.model.save_pretrained(str(PATHS.lora_output_dir))
    tokenizer.save_pretrained(str(PATHS.tokenizer_output_dir))

    print("\n✅ Training complete.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Start training from scratch instead of resuming from the latest checkpoint.",
    )
    args = parser.parse_args()

    train(resume=not args.no_resume)
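
# Example invocations (assuming this file is saved as train.py; paths and training
# hyperparameters come from config.py):
#   python train.py              # resume from the newest checkpoint, if any
#   python train.py --no-resume  # ignore existing checkpoints and train from scratch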