---
library_name: transformers
tags: []
---

# 🐱🐶 ViT-Base Cats vs. Dogs Classifier

Fine-tuned `google/vit-base-patch16-224-in21k` on `microsoft/cats_vs_dogs`, achieving 99.49% validation accuracy in 3 epochs.


## Model Summary

| Property | Value |
|---|---|
| Base Model | google/vit-base-patch16-224-in21k |
| Dataset | microsoft/cats_vs_dogs |
| Task | Binary Image Classification (Cat / Dog) |
| Final Accuracy | 99.49% |
| Final Val Loss | 0.02304 |
| Total Training Steps | 2,811 |
| Training Time | ~23 min (1,401 s) |
| Throughput | 32.08 samples/sec |

## Training Configuration

```python
TrainingArguments(
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    fp16=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)
```
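The schedule these arguments imply can be recovered from the reported totals; a quick stdlib sketch (the 2,811-step, 3-epoch, batch-16, and 500-warmup figures come from the log in this card; the train-set size is derived from them, not measured):

```python
# Derive per-epoch schedule figures from the reported training totals.
total_steps = 2811      # global_step from the training log
epochs = 3
batch_size = 16
warmup_steps = 500

steps_per_epoch = total_steps / epochs                 # optimizer steps per epoch
implied_train_images = steps_per_epoch * batch_size    # images seen per epoch
warmup_fraction = warmup_steps / total_steps           # share of training spent in warmup

print(round(steps_per_epoch))        # 937
print(round(implied_train_images))   # 14992
print(f"{warmup_fraction:.1%}")      # 17.8%
```

The ~18% warmup share referenced in the observations below falls straight out of these numbers.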

## Epoch-by-Epoch Training Dynamics

The table below captures the full training dynamics across all 3 epochs, revealing how the model transitions from rapid generalization to fine refinement.

| Epoch | Train Loss | Val Loss | Val Accuracy | Δ Accuracy | Notes |
|---|---|---|---|---|---|
| 1 | 0.039424 | 0.035078 | 98.9589% | n/a | Strong generalization from ImageNet-21k pretraining |
| 2 | 0.002892 | 0.023039 | 99.4928% | +0.5339% | Sharpest accuracy gain; val loss hits minimum |
| 3 | 0.001170 | 0.023760 | 99.4928% | +0.0000% | Train loss continues falling; val loss slightly rises |

## Key Observations

1. Pretraining quality dominates early performance. The model achieves 98.96% accuracy after just one epoch, indicating that vit-base-patch16-224-in21k's ImageNet-21k representations transfer extremely well to the cats vs. dogs domain with minimal adaptation. The heavy lifting is done by the pretrained patch embeddings and attention heads, not gradient updates.

2. Epoch 2 is the inflection point. The largest accuracy jump (+0.53%) occurs in epoch 2, coinciding with the global minimum of validation loss (0.02304). This suggests that the classifier head and the final transformer blocks are being effectively fine-tuned during this window, while the warmup phase (500 steps ≈ first ~18% of training) stabilizes the learning rate.

3. Epoch 3 shows classic overfitting onset. Training loss continues to fall significantly (0.00289 → 0.00117, a 59.5% drop), yet validation loss increases from 0.02304 → 0.02376 and accuracy plateaus. This is a textbook bias-variance divergence signal. The load_best_model_at_end=True setting correctly rolls back to the epoch 2 checkpoint, which is what is saved and pushed to this repo.

4. The train/val loss gap is diagnostic. By epoch 3, train loss (0.00117) is ~20× lower than val loss (0.02376). For a dataset of this simplicity (binary, clear visual distinction), this gap reflects the model capacity (86M parameters) significantly exceeding task complexity, not a fundamental generalization failure. Regularization via weight_decay=0.01 contained this but did not fully prevent it.


## Ablation: What Would Change These Results?

The following ablation axes were identified during analysis. These were not all experimentally tested but are grounded in the observed training dynamics.

### A. Number of Epochs

| Epochs | Expected Outcome |
|---|---|
| 1 | ~98.96% – sufficient for most production use cases |
| 2 | ~99.49% – optimal (best val loss and accuracy) |
| 3 | ~99.49% – no gain; slight val loss increase |
| 4+ | Likely degradation without stronger regularization |

**Recommendation:** Stop at epoch 2. The improvement from epoch 1→2 justifies the compute; epoch 2→3 does not.
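This stop-at-epoch-2 recommendation is exactly what a patience-based early-stopping rule on validation loss would produce; a minimal hand-rolled sketch using the losses from the table above (not transformers' own `EarlyStoppingCallback`, just the same logic in miniature):

```python
def best_epoch(val_losses, patience=1):
    """Return the 1-indexed epoch with the lowest validation loss,
    stopping once the loss has failed to improve for `patience` epochs."""
    best_loss = float("inf")
    best_idx = 0
    since_improve = 0
    for i, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_idx = loss, i
            since_improve = 0
        else:
            since_improve += 1
            if since_improve >= patience:
                break
    return best_idx

# Validation losses from the table above (epochs 1-3).
print(best_epoch([0.035078, 0.023039, 0.023760]))  # -> 2
```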

### B. Batch Size

A batch size of 16 was used. Larger batch sizes (32, 64) would:

- Increase gradient stability, potentially smoothing the epoch 2 jump into a more gradual curve
- Reduce effective regularization from stochastic noise, possibly worsening the epoch 3 overfitting
- Improve throughput (currently 32.08 samples/sec)

### C. Warmup Steps

500 warmup steps (~18% of total training) is relatively aggressive for a fine-tuning scenario. Reducing to 100–200 steps might accelerate epoch 1 convergence but risks destabilizing pretrained weights early. Given the already-strong epoch 1 performance, this is unlikely to improve final accuracy meaningfully.

### D. Weight Decay

weight_decay=0.01 provided light L2 regularization. Given the epoch 3 overfitting signal, increasing to 0.05 or 0.1 may have allowed a third epoch to contribute meaningfully without val loss regression.
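Because transformers' default AdamW applies weight decay decoupled from the gradient step, its per-step effect is easy to reason about; a toy illustration (assumes a constant learning rate of 5e-5, transformers' default, which the warmup schedule only approximates):

```python
# Decoupled weight decay (AdamW): each step shrinks a weight by a factor
# of (1 - lr * weight_decay), independently of the gradient-based update.
def decay_factor(lr, weight_decay, steps):
    return (1.0 - lr * weight_decay) ** steps

# Over the full 2,811-step run with lr=5e-5 and weight_decay=0.01,
# decay alone shrinks weights by only ~0.14% -- a very light touch,
# consistent with the overfitting seen in epoch 3.
shrink = 1.0 - decay_factor(5e-5, 0.01, 2811)
print(f"{shrink:.2%}")  # 0.14%
```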

### E. Layer Freezing

No layers were frozen; the full ViT backbone was fine-tuned. Freezing the first 6 transformer blocks (of 12) would:

- Drastically reduce compute (~40% fewer gradient updates)
- Potentially reduce overfitting
- Risk slightly lower peak accuracy due to reduced adaptation capacity
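A freezing experiment like this is wired up by toggling `requires_grad` on the backbone parameters; a sketch assuming Hugging Face's ViT parameter naming (`vit.embeddings...`, `vit.encoder.layer.<idx>....` -- verify against `model.named_parameters()` for your checkpoint):

```python
def freeze_mask(param_name, n_frozen_blocks=6):
    """Return True for parameters that should be frozen.

    Assumes HF ViT naming: 'vit.embeddings...' for patch/position
    embeddings and 'vit.encoder.layer.<idx>....' for transformer blocks.
    """
    prefix = "vit.encoder.layer."
    if param_name.startswith("vit.embeddings"):
        return True  # also freeze patch/position embeddings
    if param_name.startswith(prefix):
        block = int(param_name[len(prefix):].split(".")[0])
        return block < n_frozen_blocks
    return False

# With a real model this would be applied as:
#   for name, p in model.named_parameters():
#       p.requires_grad = not freeze_mask(name)
print(freeze_mask("vit.encoder.layer.3.attention.attention.query.weight"))  # True
print(freeze_mask("vit.encoder.layer.9.output.dense.weight"))               # False
```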

## Compute Profile

```
total_flos          : 3.48e18
train_runtime       : 1401.07 s (~23.4 min)
train_samples/sec   : 32.08
train_steps/sec     : 2.006
global_step         : 2,811
avg_train_loss      : 0.04228
```

FLOPs of 3.48 × 10¹⁸ for 3 epochs of ViT-Base fine-tuning on this dataset is consistent with expectations for full-backbone fine-tuning with fp16 on a single A100/V100-class GPU.
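The logged throughput figures cross-check against each other (small differences are expected, since the trainer counts actual dataset samples rather than steps × batch size):

```python
# Cross-check samples/sec and steps/sec against global_step, batch size, and runtime.
global_step = 2811
batch_size = 16
runtime_s = 1401.07

samples_seen = global_step * batch_size       # ~44,976 images processed over 3 epochs
samples_per_sec = samples_seen / runtime_s    # ~32.1, vs 32.08 reported
steps_per_sec = global_step / runtime_s       # ~2.006, matches the log

print(f"{samples_per_sec:.2f}")  # 32.10
print(f"{steps_per_sec:.3f}")    # 2.006
```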


## How to Use

```python
from transformers import AutoModelForImageClassification, ViTImageProcessor
from PIL import Image
import torch

model = AutoModelForImageClassification.from_pretrained("AlaminI/vit-cats-vs-dogs_classifier")
processor = ViTImageProcessor.from_pretrained("AlaminI/vit-cats-vs-dogs_classifier")

image = Image.open("image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted = logits.argmax(-1).item()
# Use the label mapping stored in the model config rather than hard-coding it.
print(f"Prediction: {model.config.id2label[predicted]}")
```
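To report a confidence score alongside the label, apply a softmax over the two logits (on the real tensor above, `logits.softmax(-1)` does this); a stdlib sketch with made-up logit values, not actual model output:

```python
import math

def softmax(xs):
    # Numerically stable softmax: shift by the max before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Illustrative two-class logits (cat, dog) -- hypothetical values.
probs = softmax([4.2, -1.3])
print([round(p, 4) for p in probs])  # [0.9959, 0.0041]
```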

## Dataset

- Source: microsoft/cats_vs_dogs
- Split used: built-in train/test split
- Preprocessing: RGB conversion + ViTImageProcessor (resize to 224×224, normalize with the processor's stored image mean/std)

## Citation

If you use this model, please cite the base model and dataset:

```bibtex
@misc{vit-cats-vs-dogs,
  author    = {your_username},
  title     = {ViT-Base Fine-tuned on Cats vs. Dogs},
  year      = {2025},
  publisher = {Hugging Face},
  url       = {https://huggingface.co/AlaminI/vit-cats-vs-dogs_classifier}
}
```

## License

This model is released under the Apache 2.0 License, consistent with the base model (google/vit-base-patch16-224-in21k).
