| | --- |
| | license: mit |
| | --- |
| | # Model Card for ViT Fine-Tuned on CIFAR10 |
| |
|
| | <!-- Provide a quick summary of what the model is/does. --> |
| |
|
| | This is a toy experiment in fine-tuning ViT with Hugging Face transformers. |
| |
|
| | ## Model Details |
| |
|
| | The model was fine-tuned on CIFAR10 for 1000 steps and achieved 98.7% accuracy on the test split. |
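For reference, top-1 accuracy on the test split can be computed along the following lines. This is a minimal sketch, not the script that produced the 98.7% figure: the helper names `top1_accuracy` and `make_vit_predictor` are hypothetical, and the evaluation is assumed to run one image at a time on CPU.

```python
from typing import Callable, Iterable, Tuple


def top1_accuracy(predict: Callable[[object], int],
                  samples: Iterable[Tuple[object, int]]) -> float:
    """Return the fraction of (image, label) pairs where predict(image) == label."""
    correct = 0
    total = 0
    for image, label in samples:
        correct += int(predict(image) == label)
        total += 1
    if total == 0:
        raise ValueError("no samples to evaluate")
    return correct / total


def make_vit_predictor(checkpoint: str = "verypro/vit-base-patch16-224-cifar10"):
    """Build a predict(image) -> class index function (hypothetical helper)."""
    # Imported lazily so top1_accuracy stays usable without these packages.
    import torch
    from transformers import ViTForImageClassification, ViTImageProcessor

    image_processor = ViTImageProcessor.from_pretrained(checkpoint)
    model = ViTForImageClassification.from_pretrained(checkpoint)
    model.eval()

    def predict(image) -> int:
        inputs = image_processor(image, return_tensors="pt")
        with torch.no_grad():  # no gradients are needed for evaluation
            logits = model(**inputs).logits
        return logits.argmax(-1).item()

    return predict
```

With `torchvision` installed, passing `make_vit_predictor()` and `(test_dataset[i] for i in range(len(test_dataset)))` to `top1_accuracy` would score the full test split, where `test_dataset` is `datasets.CIFAR10(root='./data', train=False, download=True)`. Running one image at a time is slow, so batching the inputs is the natural next step.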
| |
|
| | ### Model Description |
| |
|
| | <!-- Provide a longer summary of what this model is. --> |
| |
|
| |
|
| |
|
| | - **Developed by:** verypro |
| | - **Model type:** Vision Transformer |
| | - **License:** MIT |
| | - **Finetuned from model:** google/vit-base-patch16-224 |
| |
|
| | ## Uses |
| |
|
| | <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. --> |
| |
|
| | ```python |
| | from transformers import ViTImageProcessor, ViTForImageClassification |
| | from torchvision import datasets |
| | |
| | # Initialize the model and image processor |
| | image_processor = ViTImageProcessor.from_pretrained('verypro/vit-base-patch16-224-cifar10') |
| | model = ViTForImageClassification.from_pretrained('verypro/vit-base-patch16-224-cifar10') |
| | |
| | |
| | # Load the CIFAR10 dataset |
| | test_dataset = datasets.CIFAR10(root='./data', train=False, download=True) |
| | |
| | sample = test_dataset[0] |
| | image = sample[0] |
| | gt_label = sample[1] |
| | |
| | # Save the original image and print its label |
| | image.save("original.png") |
| | print(f"Ground truth class: '{test_dataset.classes[gt_label]}'") |
| | |
| | inputs = image_processor(image, return_tensors="pt") |
| | outputs = model(**inputs) |
| | |
| | logits = outputs.logits |
| | print(logits) |
| | |
| | predicted_class_idx = logits.argmax(-1).item() |
| | predicted_class_label = test_dataset.classes[predicted_class_idx] |
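| | # Note: the value printed as "confidence" is the raw logit of the predicted class, not a probability. |
| | print(f"Predicted class: '{predicted_class_label}', confidence: {logits[0, predicted_class_idx]:.2f}") |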
| | ``` |
| |
|
| | The output of the code snippet above should look like this: |
| |
|
| | ```bash |
| | Ground truth class: 'cat' |
| | tensor([[-1.1497, -0.1080, -0.7349, 9.2517, -1.3094, 0.5403, -0.9521, -1.0223, |
| | -1.4102, -1.5389]], grad_fn=<AddmmBackward0>) |
| | Predicted class: 'cat', confidence: 9.25 |
| | ``` |
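The value printed as `confidence` is the raw logit of the predicted class, not a probability. A minimal sketch of converting the logits above into probabilities with a softmax follows. It uses only the standard library and the rounded values copied from the printed tensor, so the result is approximate; with PyTorch, `logits.softmax(-1)` would do the same on the tensor directly.

```python
import math

# Logits copied (rounded) from the example output above.
logits = [-1.1497, -0.1080, -0.7349, 9.2517, -1.3094,
          0.5403, -0.9521, -1.0223, -1.4102, -1.5389]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)
predicted_class_idx = max(range(len(probs)), key=probs.__getitem__)
print(f"class index: {predicted_class_idx}, "
      f"probability: {probs[predicted_class_idx]:.4f}")
```

For these logits the predicted index is 3 (`cat` in the CIFAR10 class list), with a probability very close to 1.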
| |
|