# mini-llava-demo / src/train.py
# Uploaded by AD-Styles via huggingface_hub (commit 95e4119, verified).
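"""Stage 1 training — train only the projector to align visual features with the LLM embedding space.

With ``--use-lora`` the script runs Stage 2 instead: LoRA adapters on the LLM are
trained together with the projector, optionally starting from a Stage 1 projector
checkpoint given by ``--init-projector``.

Usage example:
python -m src.train \\
--data-path data/coco_subset/manifest.json \\
--output-dir checkpoints/stage1 \\
--batch-size 8 --epochs 1 --lr 1e-3
"""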
from __future__ import annotations
import argparse
import math
import os
import random
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from .config import TrainConfig
from .dataset import VQACollator, VQADataset
from .model import MiniLLaVA
def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch for reproducible runs.

    The PyTorch CPU generator is always seeded; all CUDA devices are seeded
    as well when CUDA is available.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def cosine_lr_lambda(total_steps: int, warmup_steps: int):
    """Build a ``LambdaLR`` multiplier: linear warmup, then cosine decay to zero.

    Args:
        total_steps: Total number of scheduler steps (optimizer updates).
        warmup_steps: Number of initial steps during which the multiplier
            ramps linearly from 0 up to 1.

    Returns:
        ``fn(step) -> float`` giving the LR multiplier for ``step``. Once
        ``step`` reaches ``total_steps`` the multiplier stays at 0.0 instead
        of climbing back up the cosine curve.
    """
    def fn(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        # Clamp so stepping past total_steps cannot restart the cosine cycle.
        progress = min(1.0, progress)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return fn
def maybe_apply_lora(model: MiniLLaVA, cfg: TrainConfig):
    """Stage 2: wrap ``model.llm`` with LoRA adapters when ``cfg.use_lora`` is set.

    The projector is left untouched, so it stays trainable. LoRA is applied to
    the attention projections (q/k/v/o) using ``cfg.lora_r``,
    ``cfg.lora_alpha`` and ``cfg.lora_dropout``. The same ``model`` object is
    returned in both cases; when LoRA is disabled it is returned unchanged.
    """
    if cfg.use_lora:
        # Imported lazily so Stage 1 runs do not require ``peft``.
        from peft import LoraConfig, get_peft_model

        attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
        adapter_config = LoraConfig(
            task_type="CAUSAL_LM",
            target_modules=attention_projections,
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            lora_dropout=cfg.lora_dropout,
        )
        # PEFT freezes the base LLM weights; the projector lives outside
        # model.llm and therefore keeps requires_grad=True.
        model.llm = get_peft_model(model.llm, adapter_config)
    return model
def parse_args(argv: list[str] | None = None) -> TrainConfig:
    """Parse command-line options into a ``TrainConfig``.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before, so existing callers are
            unaffected; pass a list to drive the parser programmatically.

    Returns:
        ``TrainConfig(**vars(args))``, with each option passed under its
        argparse dest name (e.g. ``data_path``, ``grad_accum_steps``).

    Exits via ``ArgumentParser.error`` (usage message, status 2) when a count
    option is < 1 or ``--warmup-ratio`` lies outside [0, 1].
    """
    p = argparse.ArgumentParser()
    p.add_argument("--data-path", type=str, required=True)
    p.add_argument("--output-dir", type=str, default="checkpoints/stage1")
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument("--grad-accum-steps", type=int, default=1)
    p.add_argument("--epochs", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--weight-decay", type=float, default=0.0)
    p.add_argument("--warmup-ratio", type=float, default=0.03)
    p.add_argument("--max-text-length", type=int, default=512)
    p.add_argument("--log-every", type=int, default=20)
    p.add_argument("--save-every", type=int, default=500)
    p.add_argument("--seed", type=int, default=42)
    # Stage 2: train LoRA adapters on the LLM together with the projector.
    p.add_argument("--use-lora", action="store_true",
                   help="Stage 2: LoRA adapter on LLM + projector 동시 학습")
    p.add_argument("--lora-r", type=int, default=16)
    p.add_argument("--lora-alpha", type=int, default=32)
    p.add_argument("--lora-dropout", type=float, default=0.05)
    # Start from an existing projector checkpoint (continue Stage 1 → Stage 2).
    p.add_argument("--init-projector", type=str, default=None,
                   help="기존 projector ckpt에서 시작 (Stage 1 → Stage 2 이어 학습)")
    args = p.parse_args(argv)
    # main() uses these as the DataLoader batch size and as divisor / modulus
    # operands; reject bad values here instead of crashing mid-training.
    for dest in ("batch_size", "grad_accum_steps", "log_every", "save_every"):
        if getattr(args, dest) < 1:
            p.error(f"--{dest.replace('_', '-')} must be >= 1")
    if not 0.0 <= args.warmup_ratio <= 1.0:
        p.error("--warmup-ratio must be in [0, 1]")
    return TrainConfig(**vars(args))
def _save_checkpoint(
    model: MiniLLaVA, cfg: TrainConfig, tag: str = ""
) -> tuple[str, str | None]:
    """Save the projector (plus the LoRA adapter in Stage 2); return both paths."""
    projector_path = os.path.join(cfg.output_dir, f"projector{tag}.pt")
    model.save_projector(projector_path)
    lora_dir = None
    if cfg.use_lora:
        # Without the adapter, a Stage 2 checkpoint cannot reproduce the model.
        lora_dir = os.path.join(cfg.output_dir, f"lora_adapter{tag}")
        model.llm.save_pretrained(lora_dir)
    return projector_path, lora_dir


def main():
    """Run Stage 1 (projector only) or Stage 2 (projector + LoRA) training.

    Reads options via ``parse_args`` and trains for ``cfg.epochs`` epochs.
    Training uses gradient accumulation over ``cfg.grad_accum_steps``
    micro-batches, warmup followed by cosine LR decay, and gradient clipping
    at norm 1.0. Outputs written to ``cfg.output_dir``:

    * ``projector_step{N}.pt`` every ``cfg.save_every`` optimizer steps;
    * ``projector.pt`` at the end;
    * in Stage 2, matching ``lora_adapter_step{N}`` / ``lora_adapter``
      directories.

    Raises:
        FileNotFoundError: if ``--init-projector`` is given but does not exist.
    """
    cfg = parse_args()
    # Fail before loading any weights: silently skipping a missing Stage 1
    # checkpoint would start Stage 2 from a randomly initialised projector.
    if cfg.init_projector and not os.path.exists(cfg.init_projector):
        raise FileNotFoundError(f"--init-projector not found: {cfg.init_projector}")
    set_seed(cfg.seed)
    os.makedirs(cfg.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[device] {device}")

    print("[init] loading MiniLLaVA ...")
    # Stage 1 freezes the LLM here; in Stage 2 the base LLM weights are frozen
    # by PEFT inside maybe_apply_lora instead.
    model = MiniLLaVA(freeze_vision=True, freeze_llm=not cfg.use_lora)
    if cfg.init_projector:
        print(f"[init] loading existing projector → {cfg.init_projector}")
        model.load_projector(cfg.init_projector, map_location="cpu")
    model = maybe_apply_lora(model, cfg)
    model.to(device)
    print(f"[init] trainable params: {model.num_trainable():,}")

    print(f"[data] loading {cfg.data_path}")
    dataset = VQADataset(
        cfg.data_path, model.tokenizer, model.image_processor, cfg.max_text_length
    )
    collator = VQACollator(pad_token_id=model.tokenizer.pad_token_id)
    loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=2,
        # Pinned host memory only speeds up host→GPU copies; on CPU it only warns.
        pin_memory=device.type == "cuda",
        collate_fn=collator,
    )
    num_batches = len(loader)
    print(f"[data] {len(dataset)} samples, {num_batches} batches/epoch")

    optimizer = AdamW(
        model.trainable_parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    accum = cfg.grad_accum_steps
    # A trailing partial accumulation window is flushed at each epoch end, so
    # every epoch performs ceil(num_batches / accum) optimizer steps.
    total_steps = math.ceil(num_batches / accum) * cfg.epochs
    warmup_steps = int(total_steps * cfg.warmup_ratio)
    scheduler = LambdaLR(optimizer, cosine_lr_lambda(total_steps, warmup_steps))

    # First index of the trailing partial window (== num_batches if none).
    tail_start = (num_batches // accum) * accum
    tail_size = num_batches - tail_start

    global_step = 0
    running_loss = 0.0   # sum of unscaled micro-batch losses since last log
    running_batches = 0  # number of micro-batches summed into running_loss
    model.train()
    if hasattr(model, "vision"):
        # Keep the frozen vision tower in eval mode so any dropout /
        # normalization layers it contains run in inference mode.
        model.vision.eval()
    for epoch in range(cfg.epochs):
        pbar = tqdm(loader, desc=f"epoch {epoch + 1}/{cfg.epochs}")
        for step, batch in enumerate(pbar):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            outputs = model(**batch)
            # Average over the real window size so a short trailing window is
            # not down-weighted relative to full windows.
            window_size = accum if step < tail_start else tail_size
            (outputs.loss / window_size).backward()
            running_loss += outputs.loss.item()
            running_batches += 1

            end_of_window = (step + 1) % accum == 0 or step + 1 == num_batches
            if not end_of_window:
                continue
            torch.nn.utils.clip_grad_norm_(model.trainable_parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            # Zeroing at every window end (including the epoch's last one)
            # keeps gradients from leaking across epoch boundaries.
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            if global_step % cfg.log_every == 0:
                avg = running_loss / max(1, running_batches)
                pbar.set_postfix(
                    loss=f"{avg:.4f}", lr=f"{scheduler.get_last_lr()[0]:.2e}"
                )
                running_loss = 0.0
                running_batches = 0
            if global_step % cfg.save_every == 0:
                _save_checkpoint(model, cfg, f"_step{global_step}")

    final_path, lora_dir = _save_checkpoint(model, cfg)
    print(f"[done] saved → {final_path}")
    if lora_dir is not None:
        print(f"[done] saved LoRA → {lora_dir}")
# Script entry point, e.g. `python -m src.train --data-path ...` (see the module docstring).
if __name__ == "__main__":
    main()