"""Build an 8-way planet image classifier from a pretrained ViT and save it.

Loads the ``google/vit-base-patch16-224`` checkpoint and its image processor,
replaces the classification head with one output per planet (labels listed in
``planet_labels``), and writes both the model and the processor to
``output_dir``.

NOTE(review): the new classification head is freshly initialised, not trained,
so the saved model presumably still needs fine-tuning on planet images before
its predictions are meaningful.
"""
from transformers import ViTImageProcessor, ViTForImageClassification


model_name = "google/vit-base-patch16-224"
output_dir = "planet-image-classifier"

# Single source of truth for the label set: list index == class id. Both label
# maps and the head size are derived from it so they cannot drift apart.
planet_labels = [
    "Mercury",
    "Venus",
    "Earth",
    "Mars",
    "Jupiter",
    "Saturn",
    "Uranus",
    "Neptune",
]
id2label = dict(enumerate(planet_labels))
label2id = {label: idx for idx, label in id2label.items()}


processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(planet_labels),
    id2label=id2label,
    label2id=label2id,
    # The checkpoint's classifier is sized for a different number of classes
    # (ImageNet-1k per its model card). Without this flag from_pretrained
    # raises a size-mismatch error for classifier.weight/bias instead of
    # discarding them and initialising a new head of len(planet_labels).
    ignore_mismatched_sizes=True,
)


model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)