#!/usr/bin/env python3
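"""LoRA fine-tuning of a multilingual SigLIP model for Korean image-text retrieval.

Freezes the base model, injects LoRA adapters into the attention projections,
trains with a symmetric contrastive loss, and keeps the adapter checkpoint with
the best validation loss.
"""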
from __future__ import annotations
import argparse
import json
import random
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import torch
import torch.nn.functional as F
from PIL import Image
from peft import LoraConfig, TaskType, get_peft_model
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import SiglipModel, SiglipProcessor
class CustomDataset(Dataset):
    """Pairs each image with its tokenized Korean label at a fixed length."""

    def __init__(self, records: list[dict[str, Any]], processor: SiglipProcessor, max_length: int) -> None:
self.records = records
self.image_processor = processor.image_processor
self.tokenizer = processor.tokenizer
self.max_length = max_length
def __len__(self) -> int:
return len(self.records)
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
record = self.records[idx]
image_path = record["image_path"]
label = record["label_ko"]
with Image.open(image_path) as img:
image = img.convert("RGB")
image_inputs = self.image_processor(images=image, return_tensors="pt")
text_inputs = self.tokenizer(
label,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_length,
return_attention_mask=True,
)
input_ids = text_inputs["input_ids"][0]
        if "attention_mask" in text_inputs:
            attention_mask = text_inputs["attention_mask"][0]
        else:
            # Some tokenizers (SigLIP's included) may omit the attention mask even
            # when asked for one; derive it from the pad token instead.
            pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
            attention_mask = (input_ids != pad_id).long()
return {
"pixel_values": image_inputs["pixel_values"][0],
"input_ids": input_ids,
"attention_mask": attention_mask,
}
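# Expected data format: JSONL with one object per line (a top-level JSON array
# of the same objects also works). The path below is illustrative:
#   {"image_path": "icons/sword.png", "label_ko": "<Korean label text>"}
# Relative image paths are resolved against --data-root.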
def load_records(data_file: Path, data_root: Path) -> list[dict[str, Any]]:
    """Load JSON/JSONL records, resolve image paths, and drop invalid entries."""
text = data_file.read_text(encoding="utf-8").strip()
if not text:
raise ValueError(f"Empty data file: {data_file}")
    if text.startswith("["):
raw_records = json.loads(text)
else:
raw_records = []
for line in text.splitlines():
line = line.strip()
if not line:
continue
raw_records.append(json.loads(line))
records: list[dict[str, Any]] = []
missing = 0
for rec in raw_records:
image_path = rec.get("image_path")
label = rec.get("label_ko")
if not image_path or not label:
continue
path = Path(image_path)
if not path.is_absolute():
path = (data_root / path).resolve()
if not path.exists():
missing += 1
continue
label_text = str(label).strip()
if not label_text:
continue
records.append({"image_path": path, "label_ko": label_text})
if missing:
print(f"Skipped {missing} records with missing images.")
if not records:
raise ValueError("No valid records found after filtering.")
return records
def prepare_model_and_processor(
model_id: str,
lora_r: int,
lora_alpha: int,
lora_dropout: float,
) -> tuple[SiglipModel, SiglipProcessor]:
    """Freeze the base SigLIP model and attach LoRA adapters to its attention projections."""
processor = SiglipProcessor.from_pretrained(model_id)
base_model = SiglipModel.from_pretrained(model_id)
for param in base_model.parameters():
param.requires_grad = False
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias="none",
task_type=TaskType.FEATURE_EXTRACTION,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
return model, processor
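# Symmetric InfoNCE (CLIP-style) objective: with L2-normalized embeddings and a
# learned temperature t = exp(logit_scale),
#   logits_per_text[i, j] = t * <text_i, image_j>
#   L = (CE(logits_per_text) + CE(logits_per_text.T)) / 2
# where the target for row i is index i (matched pairs sit on the diagonal).
# Note that this trains SigLIP with a softmax contrastive loss rather than its
# native sigmoid loss; see the sketch after the function.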
def clip_contrastive_loss(
model: SiglipModel,
image_embeds: torch.Tensor,
text_embeds: torch.Tensor,
) -> torch.Tensor:
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
logit_scale = model.logit_scale.exp().clamp(max=100)
logits_per_text = logit_scale * text_embeds @ image_embeds.t()
logits_per_image = logits_per_text.t()
labels = torch.arange(logits_per_text.size(0), device=logits_per_text.device)
loss_t = F.cross_entropy(logits_per_text, labels)
loss_i = F.cross_entropy(logits_per_image, labels)
return (loss_t + loss_i) / 2
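# Alternative (editor's sketch, not used in training): SigLIP's native pairwise
# sigmoid loss, which needs no in-batch softmax. This assumes the PEFT wrapper
# forwards `logit_scale` and `logit_bias` attribute access to the underlying
# SiglipModel, as it does for `logit_scale` above.
def siglip_sigmoid_loss(
    model: SiglipModel,
    image_embeds: torch.Tensor,
    text_embeds: torch.Tensor,
) -> torch.Tensor:
    image_embeds = F.normalize(image_embeds, p=2, dim=-1)
    text_embeds = F.normalize(text_embeds, p=2, dim=-1)
    logits = model.logit_scale.exp() * text_embeds @ image_embeds.t() + model.logit_bias
    # +1 on the diagonal (matched pairs), -1 everywhere else.
    labels = 2 * torch.eye(logits.size(0), device=logits.device) - 1
    return -F.logsigmoid(labels * logits).sum() / logits.size(0)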
@torch.no_grad()
def evaluate(
model: SiglipModel,
data_loader: DataLoader,
device: torch.device,
autocast_context,
) -> float:
    """Return the mean contrastive loss over the validation loader."""
model.eval()
total_loss = 0.0
steps = 0
for batch in data_loader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
with autocast_context:
image_embeds = model.get_image_features(pixel_values=batch["pixel_values"])
text_embeds = model.get_text_features(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
loss = clip_contrastive_loss(model, image_embeds, text_embeds)
total_loss += loss.item()
steps += 1
return total_loss / max(steps, 1)
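# A rough retrieval check (editor's note): validation loss tracks training, but
# retrieval quality is usually reported as recall@k. Within a batch:
#   preds = (text_embeds @ image_embeds.t()).argmax(dim=-1)
#   recall_at_1 = (preds == torch.arange(len(preds), device=preds.device)).float().mean()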
@torch.no_grad()
def run_similarity_test(
model: SiglipModel,
processor: SiglipProcessor,
sample: dict[str, Any],
device: torch.device,
autocast_context,
) -> None:
    """Score one image against its true label and an unrelated English query."""
model.eval()
image_path = sample["image_path"]
label = sample["label_ko"]
queries = [label, "unrelated item icon"]
with Image.open(image_path) as img:
image = img.convert("RGB")
image_inputs = processor.image_processor(images=image, return_tensors="pt")
text_inputs = processor.tokenizer(
queries,
return_tensors="pt",
padding=True,
truncation=True,
max_length=processor.tokenizer.model_max_length,
)
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
with autocast_context:
image_features = model.get_image_features(**image_inputs)
text_features = model.get_text_features(**text_inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
scores = (text_features @ image_features.T).squeeze(-1).cpu().tolist()
print("Similarity test (higher is better):")
for query, score in zip(queries, scores):
print(f"- {query}: {score:.4f}")
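# Stacking without padding logic is safe here: __getitem__ pads every text to a
# fixed max_length and the image processor emits fixed-size pixel tensors.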
def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
return {
"pixel_values": torch.stack([item["pixel_values"] for item in batch]),
"input_ids": torch.stack([item["input_ids"] for item in batch]),
"attention_mask": torch.stack([item["attention_mask"] for item in batch]),
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="LoRA fine-tuning for KoCLIP image-text retrieval.")
parser.add_argument("--data-file", type=Path, required=True, help="Path to JSONL or JSON data file.")
parser.add_argument(
"--data-root",
type=Path,
default=Path.cwd(),
help="Root directory for relative image paths.",
)
parser.add_argument("--output-dir", type=Path, default=Path("outputs/ko-clip-lora"))
parser.add_argument(
"--model-id",
type=str,
default="google/siglip-base-patch16-256-multilingual",
)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument(
"--batch-size",
type=int,
default=64,
help="Per-device batch size (64-128 typical on 24GB; reduce if OOM).",
)
parser.add_argument("--grad-accum-steps", type=int, default=1)
parser.add_argument("--learning-rate", type=float, default=1e-4)
parser.add_argument("--weight-decay", type=float, default=0.01)
parser.add_argument("--val-ratio", type=float, default=0.1)
parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--lora-r", type=int, default=8)
parser.add_argument("--lora-alpha", type=int, default=16)
parser.add_argument("--lora-dropout", type=float, default=0.1)
return parser.parse_args()
def main() -> None:
args = parse_args()
if not 0.1 <= args.val_ratio <= 0.15:
raise ValueError("--val-ratio must be between 0.10 and 0.15")
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
print("WARNING: CUDA not available; using fp32. bf16 requires GPU.")
torch.backends.cuda.matmul.allow_tf32 = True
records = load_records(args.data_file, args.data_root)
train_records, val_records = train_test_split(
records,
test_size=args.val_ratio,
random_state=args.seed,
shuffle=True,
)
print(f"Loaded {len(records)} samples (train={len(train_records)}, val={len(val_records)}).")
model, processor = prepare_model_and_processor(
args.model_id,
lora_r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
)
model.to(device)
max_length = processor.tokenizer.model_max_length
train_dataset = CustomDataset(train_records, processor, max_length)
val_dataset = CustomDataset(val_records, processor, max_length)
pin_memory = device.type == "cuda"
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn,
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn,
)
trainable_params = [p for p in model.parameters() if p.requires_grad]
if not trainable_params:
raise RuntimeError("No trainable parameters found. Check LoRA target_modules.")
optimizer = torch.optim.AdamW(trainable_params, lr=args.learning_rate, weight_decay=args.weight_decay)
    if device.type == "cuda":
        # torch.cuda.amp.autocast is deprecated; use the device-agnostic API.
        autocast_context = torch.autocast("cuda", dtype=torch.bfloat16)
else:
autocast_context = nullcontext()
best_val = float("inf")
output_dir = args.output_dir
output_dir.mkdir(parents=True, exist_ok=True)
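    # Note: gradient accumulation scales the optimizer's effective batch size,
    # but each forward pass still contrasts only within its micro-batch, so the
    # pool of in-batch negatives does not grow with --grad-accum-steps.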
for epoch in range(1, args.epochs + 1):
model.train()
total_loss = 0.0
steps = 0
for step, batch in enumerate(train_loader, start=1):
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
with autocast_context:
image_embeds = model.get_image_features(pixel_values=batch["pixel_values"])
text_embeds = model.get_text_features(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
loss = clip_contrastive_loss(model, image_embeds, text_embeds)
total_loss += loss.item()
loss = loss / args.grad_accum_steps
loss.backward()
if step % args.grad_accum_steps == 0 or step == len(train_loader):
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
steps += 1
train_loss = total_loss / max(steps, 1)
val_loss = evaluate(model, val_loader, device, autocast_context)
print(f"Epoch {epoch:02d} | train loss: {train_loss:.4f} | val loss: {val_loss:.4f}")
if val_loss < best_val:
best_val = val_loss
best_dir = output_dir / "best_model"
best_dir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(best_dir)
if val_records:
run_similarity_test(model, processor, val_records[0], device, autocast_context)
if __name__ == "__main__":
main()
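# Example invocation (paths are illustrative):
#   python train.py --data-file data/labels.jsonl --data-root data \
#       --output-dir outputs/ko-clip-lora --epochs 10 --batch-size 64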