|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import shutil |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torchvision.transforms as T |
|
|
from datasets import load_dataset |
|
|
from huggingface_hub import HfApi |
|
|
from peft import LoraConfig, PeftModel, get_peft_model |
|
|
from PIL import Image |
|
|
from safetensors import safe_open |
|
|
from transformers import ( |
|
|
AutoImageProcessor, |
|
|
Dinov2ForImageClassification, |
|
|
Trainer, |
|
|
TrainingArguments, |
|
|
) |
|
|
|
|
|
# Module-level logger; its level and format are configured by
# logging.basicConfig in the __main__ block.
logger = logging.getLogger(__name__)

# Hugging Face Hub id of the pretrained DINOv2 checkpoint. It is used to load
# both the image processor and the Dinov2ForImageClassification backbone.
MODEL = "facebook/dinov2-base-imagenet1k-1-layer"
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='Train a DINOv2 model for font classification')
    # Required: the script always reads <data_dir>/train to discover the label
    # set (even with --epochs 0), so a missing value could only crash later.
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Directory containing the font dataset')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save the model')
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='Path to checkpoint to resume training from')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Training and evaluation batch size')
    parser.add_argument('--epochs', type=int, default=1,
                        help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='Learning rate for training')
    parser.add_argument('--lora_rank', type=int, default=8,
                        help='LoRA rank for parameter-efficient fine-tuning')
    parser.add_argument('--lora_alpha', type=int, default=16,
                        help='LoRA alpha parameter')
    parser.add_argument('--lora_dropout', type=float, default=0.1,
                        help='LoRA dropout rate')
    parser.add_argument('--test_size', type=float, default=0.1,
                        help='Proportion of data to use for validation')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility')
    parser.add_argument('--log_level', type=str, default='INFO',
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        help='Logging level')
    parser.add_argument('--huggingface_model_name', type=str, default=None,
                        help='Name of the model to push to the Hub')
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def load_checkpoint_with_size_mismatch_handling(base_model, checkpoint_path, peft_config):
    """
    Load a PEFT (LoRA) adapter checkpoint, tolerating tensor-shape mismatches.

    The function first tries ``PeftModel.from_pretrained``. That fails when a
    checkpoint tensor has a different shape than the model's. The most common
    cause is a classifier head trained for a different number of labels. In
    that case the function falls back to:

    1. building a fresh ``Dinov2ForImageClassification`` with the label count
       of ``base_model`` and wrapping it with ``peft_config``;
    2. reading ``adapter_model.safetensors`` from ``checkpoint_path``;
    3. dropping every checkpoint tensor whose shape differs from the matching
       adapter tensor of the new model (e.g. the classifier head, or the LoRA
       matrices if the rank changed);
    4. loading the remaining tensors with ``set_peft_model_state_dict``, which
       maps the saved key names back onto the live adapter parameters.

    Parameters that were skipped or absent keep their fresh initialisation and
    are re-learned during training.

    Args:
        base_model: The base model, already built with the new classifier size.
            PEFT injects adapter layers into the model it wraps in place, so
            this object should not be reused after a failed first attempt.
            The fallback therefore builds its own fresh base model.
        checkpoint_path: Directory containing the saved adapter.
        peft_config: LoraConfig used to wrap the fresh base model in the
            fallback path.

    Returns:
        PeftModel with every shape-compatible checkpoint tensor loaded.

    Raises:
        ValueError: if the fallback path cannot find
            ``adapter_model.safetensors`` inside ``checkpoint_path``.
    """
    # Local import: only this fallback path needs PEFT's state-dict helpers.
    from peft import get_peft_model_state_dict, set_peft_model_state_dict

    logger.info(f"Loading checkpoint with automatic size mismatch handling: {checkpoint_path}")

    try:
        model = PeftModel.from_pretrained(
            base_model,
            checkpoint_path,
            is_trainable=True
        )
    except Exception as e:  # broad on purpose: any failure triggers the tolerant path below
        logger.warning(f"Standard loading failed ({e}), using fallback loading method")
    else:
        logger.info("Successfully loaded checkpoint without size mismatches")
        return model

    # base_model may have been partially modified by the failed attempt, so
    # start again from a clean copy with the same number of labels.
    fresh_base = Dinov2ForImageClassification.from_pretrained(
        MODEL,
        num_labels=base_model.config.num_labels,
        ignore_mismatched_sizes=True,
    )
    model = get_peft_model(fresh_base, peft_config)

    checkpoint_file = os.path.join(checkpoint_path, "adapter_model.safetensors")
    if not os.path.exists(checkpoint_file):
        raise ValueError(f"Checkpoint file {checkpoint_file} does not exist")

    with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
        checkpoint_state_dict = {key: f.get_tensor(key) for key in f.keys()}

    # Adapter tensors of the new model, keyed the same way PEFT writes them to
    # disk (adapter name stripped). Comparing against the plain
    # model.state_dict() would never match these keys.
    expected_state_dict = get_peft_model_state_dict(model)

    compatible_state_dict = {}
    mismatched = []
    for key, tensor in checkpoint_state_dict.items():
        expected = expected_state_dict.get(key)
        if expected is not None and expected.shape != tensor.shape:
            # load_state_dict(strict=False) still raises on shape mismatches,
            # so these tensors must be filtered out up front.
            mismatched.append(f"{key}: checkpoint {tuple(tensor.shape)} vs model {tuple(expected.shape)}")
            continue
        compatible_state_dict[key] = tensor

    load_result = set_peft_model_state_dict(model, compatible_state_dict)

    # Adapter parameters that received no checkpoint value (absent or mismatched).
    reinitialized = sorted(set(expected_state_dict) - set(compatible_state_dict))

    logger.info(
        f"Loaded checkpoint: {len(mismatched)} tensors skipped for shape mismatch, "
        f"{len(load_result.unexpected_keys)} checkpoint keys not present in the model, "
        f"{len(reinitialized)} adapter tensors left freshly initialized"
    )
    if mismatched:
        logger.info(f"Skipped because of a shape mismatch: {mismatched}")
    if load_result.unexpected_keys:
        logger.info(f"Keys in the checkpoint that the model does not have: {load_result.unexpected_keys}")
    if reinitialized:
        logger.info(f"Adapter parameters not restored (will be learned from scratch): {reinitialized}")

    return model
|
|
|
|
|
def get_inference_transform(processor: AutoImageProcessor, size: int):
    """Get the raw validation transform for direct inference on PIL images.

    The returned ``T.Compose`` pipeline applies these steps in order:

    1. convert to RGB;
    2. centre the image on a black square canvas whose side is the longer
       image edge (aspect ratio preserved, any odd pixel goes right/bottom);
    3. resize with ``T.Resize(size)``;
    4. convert to a tensor;
    5. normalise with ``processor.image_mean`` / ``processor.image_std``.

    Args:
        processor: Image processor that supplies the normalisation statistics.
        size: Target passed to ``T.Resize``. The image is square at that point,
            so both edges end up ``size`` pixels long.
    """
    def _pad_to_square(image):
        width, height = image.size
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top
        return T.Pad((left, top, right, bottom), fill=0)(image)

    steps = [
        T.Lambda(lambda image: image.convert('RGB')),
        _pad_to_square,
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=processor.image_mean, std=processor.image_std),
    ]
    return T.Compose(steps)
|
|
|
|
|
|
|
|
|
|
|
def get_transform(processor: AutoImageProcessor, size: int):
    """Build the per-example preprocessing function used with ``Dataset.map``.

    Each example's ``"image"`` is run through the Hugging Face ``processor``.
    The resulting ``pixel_values`` tensor, with its leading batch dimension
    squeezed away, is stored back on the example.

    Args:
        processor: Image processor loaded for ``MODEL``. Its own resize/crop
            settings determine the output resolution.
        size: Kept for interface compatibility; not used by this pipeline.

    Returns:
        A callable ``transform(example, train=True)`` that adds
        ``"pixel_values"`` to ``example`` and returns it. ``train`` is accepted
        for API compatibility; no train-only augmentation is applied.
    """
    # NOTE(review): training preprocessing comes from ``processor`` here, while
    # get_inference_transform() pads to a square and resizes to ``size``
    # instead. If the deployed inference path uses get_inference_transform,
    # the two pipelines differ. Confirm this is intended.

    def transform(example, train=True):
        inputs = processor(images=example["image"], return_tensors="pt")
        example["pixel_values"] = inputs["pixel_values"].squeeze(0)
        return example

    return transform
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported here because only the script entry point needs it.
    from transformers import set_seed

    args = parse_args()

    # Configure root logging from the CLI before anything else logs.
    logging.basicConfig(
        level=args.log_level,
        format='%(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Seed Python, NumPy and torch before any random initialisation happens
    # (new classifier head, LoRA matrices, optional train/test split).
    # Previously --seed was parsed but never applied.
    set_seed(args.seed)

    # Fail with a clear message instead of trying to list "None/train".
    if not args.data_dir:
        raise ValueError("--data_dir is required")

    logger.info(f"Loading dataset from {args.data_dir}")

    # One label per sub-directory of <data_dir>/train. Plain files and hidden
    # entries (e.g. macOS ".DS_Store") are skipped because imagefolder does not
    # treat them as classes. Since "." sorts first, counting them would shift
    # every id2label entry and add a phantom output unit.
    label_names = sorted(
        entry.name
        for entry in Path(args.data_dir, "train").iterdir()
        if entry.is_dir() and not entry.name.startswith(".")
    )
    logger.info(f"Found {len(label_names)} labels")

    if len(label_names) <= 1:
        raise ValueError(f"Expected at least 2 labels, got {label_names=}, imagefolder will not label the dataset if there are less than 2 labels.")

    # Both stay None when --epochs is 0 (no training / evaluation data needed).
    train_dataset = None
    test_dataset = None

    logger.info("Setting up image processor and augmentations")
    processor = AutoImageProcessor.from_pretrained(MODEL)
    size = processor.size["shortest_edge"]

    if args.epochs > 0:
        dataset = load_dataset(
            "imagefolder",
            data_dir=args.data_dir,
        )

        # The ClassLabel built by imagefolder defines the integer ids the model
        # is trained on, so it is the authoritative label order.
        dataset_label_names = list(dataset["train"].features["label"].names)
        if dataset_label_names != label_names:
            logger.warning(
                f"Label folders {label_names} do not match the dataset labels "
                f"{dataset_label_names}; using the dataset labels"
            )
            label_names = dataset_label_names

        # Without a test/ folder imagefolder yields no "test" split. Use a
        # "validation" split if present, otherwise hold out --test_size of
        # the training images.
        if "test" not in dataset:
            if "validation" in dataset:
                dataset["test"] = dataset["validation"]
            else:
                logger.info(
                    f"No test split found; holding out {args.test_size} of the training data for validation"
                )
                dataset = dataset["train"].train_test_split(test_size=args.test_size, seed=args.seed)

        logger.info(f"Train size: {len(dataset['train'])}, Validation size: {len(dataset['test'])}")

        transform = get_transform(processor, size)

        logger.info("Applying data transformations")
        train_dataset = dataset["train"].map(
            lambda x: transform(x, train=True),
            remove_columns=["image"],
            desc="Transforming training data"
        )
        test_dataset = dataset["test"].map(
            lambda x: transform(x, train=False),
            remove_columns=["image"],
            desc="Transforming test data"
        )

        # Yield torch tensors for exactly the columns the collator reads.
        train_dataset.set_format(type="torch", columns=["pixel_values", "label"])
        test_dataset.set_format(type="torch", columns=["pixel_values", "label"])

        logger.info("Data preprocessing complete")

    logger.info("Loading DINOv2 model")
    # ignore_mismatched_sizes lets the pretrained head be replaced by one
    # sized for our label set.
    base = Dinov2ForImageClassification.from_pretrained(
        MODEL,
        num_labels=len(label_names),
        ignore_mismatched_sizes=True,
    )

    logger.info("Configuring LoRA adapters")
    peft_cfg = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=["query", "value"],
        lora_dropout=args.lora_dropout,
        bias="none",
        # The new classifier head is trained fully and saved with the adapter.
        modules_to_save=["classifier"],
    )

    if args.checkpoint:
        model = load_checkpoint_with_size_mismatch_handling(base, args.checkpoint, peft_cfg)
    else:
        model = get_peft_model(base, peft_cfg)

    model.print_trainable_parameters()

    def collate(batch):
        """Stack per-example tensors into a batch; the model expects the key "labels"."""
        pixel_values = torch.stack([item["pixel_values"] for item in batch])
        labels = torch.tensor([item["label"] for item in batch])
        return {"pixel_values": pixel_values, "labels": labels}

    def compute_metrics(eval_pred):
        """Return top-1 accuracy over the evaluation set."""
        predictions, labels = eval_pred
        # Models can return extra outputs next to the logits; keep the logits.
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        predictions = predictions.argmax(axis=-1)
        accuracy = float((predictions == labels).mean())
        return {"accuracy": accuracy}

    logger.info("Setting up training arguments")

    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    training = args.epochs > 0
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        label_names=["labels"],
        eval_strategy="steps" if training else "no",
        eval_steps=500,
        save_strategy="steps" if training else "no",
        save_steps=500,
        num_train_epochs=args.epochs,
        learning_rate=args.learning_rate,
        weight_decay=0.05,
        fp16=device.type == "cuda",
        save_total_limit=3,
        logging_dir=os.path.join(args.output_dir, "logs") if args.output_dir else None,
        logging_steps=10,
        report_to="tensorboard",
        # Best-model tracking needs evaluation and checkpoints, which only
        # exist when training runs.
        load_best_model_at_end=training,
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,
        seed=args.seed,
        # Informational only: Trainer does not read this field. Weights were
        # restored by load_checkpoint_with_size_mismatch_handling above.
        resume_from_checkpoint=args.checkpoint if args.checkpoint else None,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=collate,
        compute_metrics=compute_metrics,
    )

    logger.info("Starting training")
    if args.checkpoint:
        logger.info(
            f"Starting from weights in checkpoint: {args.checkpoint} "
            f"(optimizer and scheduler state are not restored)"
        )

    if training:
        trainer.train()
        logger.info("Training complete")
        # Eval only runs every eval_steps; report the final model explicitly.
        eval_metrics = trainer.evaluate()
        logger.info(f"Final evaluation metrics: {eval_metrics}")

    if args.output_dir:
        logger.info("Saving result model to the output directory")
        trainer.save_model(os.path.join(args.output_dir, "result_model"))

    if args.huggingface_model_name:
        logger.info(f"Pushing model to the Hub: {args.huggingface_model_name}")

        # Companion files live next to this script, not necessarily in the CWD.
        script_dir = Path(__file__).resolve().parent

        with tempfile.TemporaryDirectory() as tmp:
            # Merge the LoRA weights into the base model and drop the PEFT wrappers.
            merged = trainer.model.merge_and_unload()
            merged.config.id2label = dict(enumerate(label_names))
            merged.config.label2id = {name: i for i, name in enumerate(label_names)}
            merged.config.pipeline_tag = "image-classification"
            merged.save_pretrained(tmp, safe_serialization=True)
            processor.save_pretrained(tmp)

            shutil.copy(script_dir / "train_model.py", tmp)
            shutil.copy(script_dir / "handler.py", tmp)
            Path(tmp, "requirements.txt").write_text("\n".join([
                "torchvision>=0.19",
                "Pillow>=10",
                "datasets>=2.19",
                "peft>=0.10",
                "safetensors>=0.4",
            ]))

            HfApi().upload_folder(
                repo_id=args.huggingface_model_name,
                folder_path=tmp,
                commit_message="Add merged model + processor",
                # None lets huggingface_hub fall back to HF_TOKEN / the cached login.
                token=os.environ.get("HUGGINGFACE_API_KEY"),
            )
|
|
|