| | from transformers import ( |
| | ViTForImageClassification, |
| | ViTImageProcessor, |
| | TrainingArguments, |
| | Trainer, |
| | ) |
| | from datasets import load_dataset |
| | from .utils import MODEL_DIR |
| |
|
| |
|
def train(train_size: int = 2000, test_size: int = 500) -> None:
    """Fine-tune a ViT image classifier on a subset of MNIST.

    Downloads MNIST, keeps the first ``train_size`` training and
    ``test_size`` test examples, fine-tunes ``google/vit-base-patch16-224``
    with a fresh 10-class head, and saves the resulting model and image
    processor to ``MODEL_DIR``.

    Args:
        train_size: Number of training examples to keep (default 2000,
            matching the previous hard-coded subset size).
        test_size: Number of evaluation examples to keep (default 500).
    """
    dataset = load_dataset("mnist")
    # Trainer looks for the target column under the name "labels".
    dataset = dataset.rename_column("label", "labels")

    # Work on a small subset so a training run finishes quickly.
    dataset["train"] = dataset["train"].select(range(train_size))
    dataset["test"] = dataset["test"].select(range(test_size))

    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

    def transform(examples):
        # MNIST images are grayscale; the ViT checkpoint expects 3-channel RGB.
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["labels"]
        return inputs

    # Apply preprocessing lazily on access rather than materializing
    # all processed tensors up front.
    dataset.set_transform(transform)

    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        num_labels=10,  # MNIST digits 0-9
        id2label={str(i): str(i) for i in range(10)},
        label2id={str(i): i for i in range(10)},
        # The checkpoint ships a 1000-way ImageNet head; allow replacing it
        # with the new 10-way head instead of raising a size-mismatch error.
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir="./results",
        # Keep the raw "image" column so set_transform can still read it.
        remove_unused_columns=False,
        per_device_train_batch_size=16,
        eval_strategy="steps",
        num_train_epochs=3,
        fp16=False,
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        learning_rate=2e-4,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    trainer.train()

    # Persist the fine-tuned weights together with the preprocessing config
    # so inference code can reload both from the same directory.
    model.save_pretrained(MODEL_DIR)
    processor.save_pretrained(MODEL_DIR)
| |
|