#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import random
from contextlib import nullcontext
from pathlib import Path
from typing import Any

import torch
import torch.nn.functional as F
from PIL import Image
from peft import LoraConfig, TaskType, get_peft_model
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import SiglipModel, SiglipProcessor
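

# Dataset yielding (pixel_values, input_ids, attention_mask) per image-label
# pair. Preprocessing runs per sample so DataLoader workers can parallelize it.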
class CustomDataset(Dataset):
    def __init__(self, records: list[dict[str, Any]], processor: SiglipProcessor, max_length: int) -> None:
        self.records = records
        self.image_processor = processor.image_processor
        self.tokenizer = processor.tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        record = self.records[idx]
        image_path = record["image_path"]
        label = record["label_ko"]
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image_inputs = self.image_processor(images=image, return_tensors="pt")
        text_inputs = self.tokenizer(
            label,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_attention_mask=True,
        )
        input_ids = text_inputs["input_ids"][0]
        if "attention_mask" in text_inputs:
            attention_mask = text_inputs["attention_mask"][0]
        else:
            # Some tokenizers (SigLIP's included) may omit the attention mask
            # even when requested; rebuild it from the pad token ids.
            pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
            attention_mask = (input_ids != pad_id).long()
        return {
            "pixel_values": image_inputs["pixel_values"][0],
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
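

# Accepts either a JSON array or JSONL. Resolves relative image paths against
# data_root, drops records with empty labels, and reports how many records
# pointed at missing images.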
def load_records(data_file: Path, data_root: Path) -> list[dict[str, Any]]:
    text = data_file.read_text(encoding="utf-8").strip()
    if not text:
        raise ValueError(f"Empty data file: {data_file}")
    if text.lstrip().startswith("["):
        raw_records = json.loads(text)
    else:
        raw_records = []
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            raw_records.append(json.loads(line))
    records: list[dict[str, Any]] = []
    missing = 0
    for rec in raw_records:
        image_path = rec.get("image_path")
        label = rec.get("label_ko")
        if not image_path or not label:
            continue
        path = Path(image_path)
        if not path.is_absolute():
            path = (data_root / path).resolve()
        if not path.exists():
            missing += 1
            continue
        label_text = str(label).strip()
        if not label_text:
            continue
        records.append({"image_path": path, "label_ko": label_text})
    if missing:
        print(f"Skipped {missing} records with missing images.")
    if not records:
        raise ValueError("No valid records found after filtering.")
    return records
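

# Freeze the full SigLIP backbone, then attach LoRA adapters to the attention
# projections; the listed module names match both the vision and text towers.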
def prepare_model_and_processor(
    model_id: str,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
) -> tuple[SiglipModel, SiglipProcessor]:
    processor = SiglipProcessor.from_pretrained(model_id)
    base_model = SiglipModel.from_pretrained(model_id)
    for param in base_model.parameters():
        param.requires_grad = False
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    )
    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()
    return model, processor
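

# Symmetric CLIP-style InfoNCE over in-batch pairs, reusing the model's
# logit_scale. Note that SigLIP itself is pretrained with a pairwise sigmoid
# loss; using the softmax contrastive objective here is a deliberate choice.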
def clip_contrastive_loss(
    model: SiglipModel,
    image_embeds: torch.Tensor,
    text_embeds: torch.Tensor,
) -> torch.Tensor:
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    logit_scale = model.logit_scale.exp().clamp(max=100)
    logits_per_text = logit_scale * text_embeds @ image_embeds.t()
    logits_per_image = logits_per_text.t()
    labels = torch.arange(logits_per_text.size(0), device=logits_per_text.device)
    loss_t = F.cross_entropy(logits_per_text, labels)
    loss_i = F.cross_entropy(logits_per_image, labels)
    return (loss_t + loss_i) / 2
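

# Mean contrastive loss over the validation loader; gradients are disabled to
# avoid building graphs during evaluation.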
@torch.no_grad()
def evaluate(
    model: SiglipModel,
    data_loader: DataLoader,
    device: torch.device,
    autocast_context,
) -> float:
    model.eval()
    total_loss = 0.0
    steps = 0
    for batch in data_loader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        with autocast_context:
            image_embeds = model.get_image_features(pixel_values=batch["pixel_values"])
            text_embeds = model.get_text_features(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
            )
            loss = clip_contrastive_loss(model, image_embeds, text_embeds)
        total_loss += loss.item()
        steps += 1
    return total_loss / max(steps, 1)
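

# Quick qualitative check: score one validation image against its own label
# and a distractor query; the true label should score higher.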
@torch.no_grad()
def run_similarity_test(
    model: SiglipModel,
    processor: SiglipProcessor,
    sample: dict[str, Any],
    device: torch.device,
    autocast_context,
) -> None:
    model.eval()
    image_path = sample["image_path"]
    label = sample["label_ko"]
    queries = [label, "unrelated item icon"]
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_inputs = processor.image_processor(images=image, return_tensors="pt")
    text_inputs = processor.tokenizer(
        queries,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=processor.tokenizer.model_max_length,
    )
    image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    with autocast_context:
        image_features = model.get_image_features(**image_inputs)
        text_features = model.get_text_features(**text_inputs)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    scores = (text_features @ image_features.T).squeeze(-1).cpu().tolist()
    print("Similarity test (higher is better):")
    for query, score in zip(queries, scores):
        print(f"- {query}: {score:.4f}")
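

# Stack per-sample tensors into batch tensors; all text was padded to the same
# max_length in the dataset, so a plain stack is safe.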
def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    return {
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
    }
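

# Command-line interface; defaults are sized for a single 24GB GPU.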
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="LoRA fine-tuning for KoCLIP image-text retrieval.")
    parser.add_argument("--data-file", type=Path, required=True, help="Path to JSONL or JSON data file.")
    parser.add_argument(
        "--data-root",
        type=Path,
        default=Path.cwd(),
        help="Root directory for relative image paths.",
    )
    parser.add_argument("--output-dir", type=Path, default=Path("outputs/ko-clip-lora"))
    parser.add_argument(
        "--model-id",
        type=str,
        default="google/siglip-base-patch16-256-multilingual",
    )
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="Per-device batch size (64-128 typical on 24GB; reduce if OOM).",
    )
    parser.add_argument("--grad-accum-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--val-ratio", type=float, default=0.1)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lora-r", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=16)
    parser.add_argument("--lora-dropout", type=float, default=0.1)
    return parser.parse_args()
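

# End-to-end driver: seed, split, build loaders, train with bf16 autocast on
# CUDA, checkpoint the best adapter by validation loss, then run a smoke test.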
def main() -> None:
    args = parse_args()
    if not 0.1 <= args.val_ratio <= 0.15:
        raise ValueError("--val-ratio must be between 0.10 and 0.15")
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type != "cuda":
        print("WARNING: CUDA not available; using fp32. bf16 requires a GPU.")
    torch.backends.cuda.matmul.allow_tf32 = True
    records = load_records(args.data_file, args.data_root)
    train_records, val_records = train_test_split(
        records,
        test_size=args.val_ratio,
        random_state=args.seed,
        shuffle=True,
    )
    print(f"Loaded {len(records)} samples (train={len(train_records)}, val={len(val_records)}).")
    model, processor = prepare_model_and_processor(
        args.model_id,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )
    model.to(device)
    max_length = processor.tokenizer.model_max_length
    train_dataset = CustomDataset(train_records, processor, max_length)
    val_dataset = CustomDataset(val_records, processor, max_length)
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        raise RuntimeError("No trainable parameters found. Check LoRA target_modules.")
    optimizer = torch.optim.AdamW(trainable_params, lr=args.learning_rate, weight_decay=args.weight_decay)
    if device.type == "cuda":
        # torch.cuda.amp.autocast is deprecated; torch.autocast is the current API.
        autocast_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    else:
        autocast_context = nullcontext()
    best_val = float("inf")
    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss = 0.0
        steps = 0
        for step, batch in enumerate(train_loader, start=1):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            with autocast_context:
                image_embeds = model.get_image_features(pixel_values=batch["pixel_values"])
                text_embeds = model.get_text_features(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                )
                loss = clip_contrastive_loss(model, image_embeds, text_embeds)
            total_loss += loss.item()
            loss = loss / args.grad_accum_steps
            loss.backward()
            # Step on every accumulation boundary and on the final partial window.
            if step % args.grad_accum_steps == 0 or step == len(train_loader):
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            steps += 1
        train_loss = total_loss / max(steps, 1)
        val_loss = evaluate(model, val_loader, device, autocast_context)
        print(f"Epoch {epoch:02d} | train loss: {train_loss:.4f} | val loss: {val_loss:.4f}")
        if val_loss < best_val:
            best_val = val_loss
            best_dir = output_dir / "best_model"
            best_dir.mkdir(parents=True, exist_ok=True)
            # save_pretrained on a PeftModel writes only the adapter weights.
            model.save_pretrained(best_dir)
    if val_records:
        run_similarity_test(model, processor, val_records[0], device, autocast_context)


if __name__ == "__main__":
    main()
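

# Example invocation (script name and data paths are illustrative):
#   python train_siglip_lora.py \
#       --data-file data/labels.jsonl \
#       --data-root data \
#       --batch-size 64 --epochs 10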