# mindi-backup/train.py
import argparse
from pathlib import Path
import os
import subprocess
import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    TrainerCallback,
    set_seed,
)
from config import PATHS, TRAINING_CONFIG
from dataset import LocalJsonlInstructionDataset
from utils import ensure_dirs, setup_logger
# ==============================
# 🔥 BACKUP CALLBACK
# ==============================
class BackupCallback(TrainerCallback):
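    """Tar up each checkpoint as it is written, so a crash loses at most one save interval."""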
    def on_save(self, args, state, control, **kwargs):
        try:
            checkpoint_dir = os.path.join(
                args.output_dir,
                f"checkpoint-{state.global_step}",
            )
            if not os.path.exists(checkpoint_dir):
                return
            os.makedirs("backups", exist_ok=True)
            backup_name = f"backup_step{state.global_step}.tar.gz"
            backup_path = os.path.join("backups", backup_name)
            print(f"\n[BACKUP] Creating backup for step {state.global_step}...")
            subprocess.run(
                ["tar", "-czf", backup_path, checkpoint_dir],
                check=True,
            )
            print(f"[BACKUP] Saved: {backup_path}")
        except Exception as e:
            # Never let a failed backup kill the training run.
            print(f"[BACKUP ERROR] {e}")
# ==============================
# MODEL PATH RESOLUTION
# ==============================
def _is_valid_hf_model_dir(path: Path) -> bool:
    return path.exists() and (path / "config.json").exists()
def _resolve_model_path(logger) -> Path:
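    """Return the primary model directory, falling back to the bundled release copy."""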
    primary = PATHS.model_dir
    fallback = Path("./hf_release/MINDI-1.0-420M")
    if _is_valid_hf_model_dir(primary):
        return primary
    if _is_valid_hf_model_dir(fallback):
        logger.warning(
            "Primary model missing → using fallback %s",
            fallback.resolve(),
        )
        return fallback
    raise FileNotFoundError("No valid model directory found.")
# ==============================
# BUILD MODEL (FIXED)
# ==============================
def _build_model_and_tokenizer(model_path: Path):
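    """Load tokenizer and base model from a local directory, then wrap the model with LoRA adapters."""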
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        local_files_only=True,
    )
    if tokenizer.pad_token is None:
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    # 🔥 FIXED MODEL LOADING
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_safetensors=True,  # load *.safetensors weights instead of pickle-based *.bin
    )
    # LoRA: train small low-rank adapter matrices on top of the frozen base model.
    lora_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules="all-linear",
    )
    model = get_peft_model(model, lora_cfg)
    return model, tokenizer
# ==============================
# CHECKPOINT RESUME (SAFE)
# ==============================
def get_latest_checkpoint(checkpoint_dir):
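    """Return the highest-numbered checkpoint-* directory, or None if there is none."""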
    if not os.path.exists(checkpoint_dir):
        return None
    checkpoints = [
        d for d in os.listdir(checkpoint_dir)
        if d.startswith("checkpoint-")
    ]
    if not checkpoints:
        return None
    checkpoints = sorted(
        checkpoints,
        key=lambda x: int(x.split("-")[-1]),
    )
    return os.path.join(checkpoint_dir, checkpoints[-1])
def safe_train(trainer, checkpoint_dir, logger, resume=True):
    """Resume from the latest checkpoint when possible; fall back to a fresh run."""
    latest_checkpoint = get_latest_checkpoint(checkpoint_dir) if resume else None
    if latest_checkpoint:
        logger.info(f"Trying resume from: {latest_checkpoint}")
        try:
            trainer.train(resume_from_checkpoint=latest_checkpoint)
            return
        except Exception as e:
            logger.warning(f"Resume failed → starting fresh: {e}")
    trainer.train()
# ==============================
# MAIN TRAIN
# ==============================
def train(resume: bool):
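    """Run the full pipeline: resolve the model, build the trainer, train, and save LoRA + tokenizer."""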
    ensure_dirs([
        PATHS.data_dir,
        PATHS.output_dir,
        PATHS.logs_dir,
        PATHS.checkpoint_dir,
        PATHS.lora_output_dir,
        PATHS.tokenizer_output_dir,
    ])
    logger = setup_logger("train", PATHS.logs_dir / "train.log")
    set_seed(42)
    model_path = _resolve_model_path(logger)
    logger.info("Loading model from %s", model_path)
    model, tokenizer = _build_model_and_tokenizer(model_path)
    model.print_trainable_parameters()
    train_dataset = LocalJsonlInstructionDataset(
        tokenizer,
        max_length=TRAINING_CONFIG.max_length,
    )
    training_args = TrainingArguments(
        output_dir=str(PATHS.checkpoint_dir),
        num_train_epochs=TRAINING_CONFIG.num_train_epochs,
        per_device_train_batch_size=TRAINING_CONFIG.per_device_train_batch_size,
        gradient_accumulation_steps=TRAINING_CONFIG.gradient_accumulation_steps,
        learning_rate=TRAINING_CONFIG.learning_rate,
        fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
        logging_steps=50,
        save_steps=250,
        save_total_limit=3,
        report_to="none",
        remove_unused_columns=False,  # keep fields of the custom dataset that Trainer does not know about
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        callbacks=[BackupCallback()],
    )
    logger.info("Starting training...")
    safe_train(trainer, str(PATHS.checkpoint_dir), logger, resume=resume)
    trainer.model.save_pretrained(str(PATHS.lora_output_dir))
    tokenizer.save_pretrained(str(PATHS.tokenizer_output_dir))
    print("\n✅ Training complete.")
# ==============================
# ENTRY
# ==============================
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Start training from scratch instead of resuming the latest checkpoint.",
    )
    args = parser.parse_args()
    train(resume=not args.no_resume)