"""Stage 1 학습 — projector만 학습하여 시각 특징을 LLM 임베딩 공간으로 정렬. 사용 예: python -m src.train \\ --data-path data/coco_subset/manifest.json \\ --output-dir checkpoints/stage1 \\ --batch-size 8 --epochs 1 --lr 1e-3 """ from __future__ import annotations import argparse import math import os import random import torch from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader from tqdm import tqdm from .config import TrainConfig from .dataset import VQACollator, VQADataset from .model import MiniLLaVA def set_seed(seed: int): random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def cosine_lr_lambda(total_steps: int, warmup_steps: int): def fn(step: int): if step < warmup_steps: return step / max(1, warmup_steps) progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) return 0.5 * (1.0 + math.cos(math.pi * progress)) return fn def maybe_apply_lora(model: MiniLLaVA, cfg: TrainConfig): """Stage 2: 기존 projector는 그대로 학습 가능 + LLM에 LoRA 어댑터 추가.""" if not cfg.use_lora: return model from peft import LoraConfig, get_peft_model lora_cfg = LoraConfig( r=cfg.lora_r, lora_alpha=cfg.lora_alpha, lora_dropout=cfg.lora_dropout, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM", ) model.llm = get_peft_model(model.llm, lora_cfg) # PEFT 가 base LLM을 자동 freeze. projector는 외부라 trainable 유지. return model def parse_args() -> TrainConfig: p = argparse.ArgumentParser() p.add_argument("--data-path", type=str, required=True) p.add_argument("--output-dir", type=str, default="checkpoints/stage1") p.add_argument("--batch-size", type=int, default=8) p.add_argument("--grad-accum-steps", type=int, default=1) p.add_argument("--epochs", type=int, default=1) p.add_argument("--lr", type=float, default=1e-3) p.add_argument("--weight-decay", type=float, default=0.0) p.add_argument("--warmup-ratio", type=float, default=0.03) p.add_argument("--max-text-length", type=int, default=512) p.add_argument("--log-every", type=int, default=20) p.add_argument("--save-every", type=int, default=500) p.add_argument("--seed", type=int, default=42) p.add_argument("--use-lora", action="store_true", help="Stage 2: LoRA adapter on LLM + projector 동시 학습") p.add_argument("--lora-r", type=int, default=16) p.add_argument("--lora-alpha", type=int, default=32) p.add_argument("--lora-dropout", type=float, default=0.05) p.add_argument("--init-projector", type=str, default=None, help="기존 projector ckpt에서 시작 (Stage 1 → Stage 2 이어 학습)") args = p.parse_args() return TrainConfig(**vars(args)) def main(): cfg = parse_args() set_seed(cfg.seed) os.makedirs(cfg.output_dir, exist_ok=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[device] {device}") print("[init] loading MiniLLaVA ...") model = MiniLLaVA(freeze_vision=True, freeze_llm=not cfg.use_lora) if cfg.init_projector and os.path.exists(cfg.init_projector): print(f"[init] loading existing projector → {cfg.init_projector}") model.load_projector(cfg.init_projector, map_location="cpu") model = maybe_apply_lora(model, cfg) model.to(device) print(f"[init] trainable params: {model.num_trainable():,}") print(f"[data] loading {cfg.data_path}") dataset = VQADataset( cfg.data_path, model.tokenizer, model.image_processor, cfg.max_text_length ) collator = VQACollator(pad_token_id=model.tokenizer.pad_token_id) loader = DataLoader( dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=True, collate_fn=collator, ) print(f"[data] {len(dataset)} samples, {len(loader)} batches/epoch") optimizer = AdamW( model.trainable_parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay ) total_steps = (len(loader) // cfg.grad_accum_steps) * cfg.epochs warmup_steps = int(total_steps * cfg.warmup_ratio) scheduler = LambdaLR(optimizer, cosine_lr_lambda(total_steps, warmup_steps)) global_step = 0 model.train() if hasattr(model, "vision"): model.vision.eval() for epoch in range(cfg.epochs): pbar = tqdm(loader, desc=f"epoch {epoch + 1}/{cfg.epochs}") running_loss = 0.0 for step, batch in enumerate(pbar): batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} outputs = model(**batch) loss = outputs.loss / cfg.grad_accum_steps loss.backward() running_loss += loss.item() * cfg.grad_accum_steps if (step + 1) % cfg.grad_accum_steps == 0: torch.nn.utils.clip_grad_norm_(model.trainable_parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad(set_to_none=True) global_step += 1 if global_step % cfg.log_every == 0: avg = running_loss / (cfg.log_every * cfg.grad_accum_steps) pbar.set_postfix( loss=f"{avg:.4f}", lr=f"{scheduler.get_last_lr()[0]:.2e}" ) running_loss = 0.0 if global_step % cfg.save_every == 0: ckpt = os.path.join( cfg.output_dir, f"projector_step{global_step}.pt" ) model.save_projector(ckpt) final_path = os.path.join(cfg.output_dir, "projector.pt") model.save_projector(final_path) print(f"[done] saved → {final_path}") if cfg.use_lora: lora_dir = os.path.join(cfg.output_dir, "lora_adapter") model.llm.save_pretrained(lora_dir) print(f"[done] saved LoRA → {lora_dir}") if __name__ == "__main__": main()