DKatheesrupan commited on
Commit
d73d1dc
Β·
verified Β·
1 Parent(s): 1336679

Upload train_vit_oxford_pets.py

Browse files
Files changed (1) hide show
  1. train_vit_oxford_pets.py +118 -0
train_vit_oxford_pets.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import numpy as np
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ ViTImageProcessor,
7
+ ViTForImageClassification,
8
+ TrainingArguments,
9
+ Trainer,
10
+ )
11
+ import evaluate
12
+ from huggingface_hub import notebook_login
13
+
14
+ # ── 1. Hugging Face Login ─────────────────────────────────
15
+ # Erstelle einen Token auf: https://huggingface.co/settings/tokens
16
+ # Typ: "write"
17
+ notebook_login() # gibt einen Login-Dialog aus
18
+
19
+ # ── 2. Dataset laden ──────────────────────────────────────
20
+ print("Lade Dataset...")
21
+ dataset = load_dataset("pcuenq/oxford-pets")
22
+ print(dataset)
23
+
24
+ # Labels extrahieren
25
+ label_names = dataset["train"].features["label"].names
26
+ id2label = {i: label for i, label in enumerate(label_names)}
27
+ label2id = {label: i for i, label in enumerate(label_names)}
28
+ num_labels = len(label_names)
29
+ print(f"Anzahl Klassen: {num_labels}")
30
+ print("Labels:", label_names)
31
+
32
+ # ── 3. Preprocessing ──────────────────────────────────────
33
+ MODEL_NAME = "google/vit-base-patch16-224-in21k"
34
+ processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
35
+
36
+ def preprocess(batch):
37
+ images = [img.convert("RGB") for img in batch["image"]]
38
+ inputs = processor(images=images, return_tensors="pt")
39
+ inputs["labels"] = batch["label"]
40
+ return inputs
41
+
42
+ dataset = dataset.map(preprocess, batched=True, batch_size=32)
43
+ dataset.set_format(type="torch", columns=["pixel_values", "labels"])
44
+
45
+ # Train/Val Split (falls kein eigener Val-Split vorhanden)
46
+ if "validation" not in dataset:
47
+ split = dataset["train"].train_test_split(test_size=0.15, seed=42)
48
+ train_ds = split["train"]
49
+ val_ds = split["test"]
50
+ else:
51
+ train_ds = dataset["train"]
52
+ val_ds = dataset["validation"]
53
+
54
+ print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")
55
+
56
+ # ── 4. Modell laden ───────────────────────────────────────
57
+ model = ViTForImageClassification.from_pretrained(
58
+ MODEL_NAME,
59
+ num_labels=num_labels,
60
+ id2label=id2label,
61
+ label2id=label2id,
62
+ ignore_mismatched_sizes=True,
63
+ )
64
+
65
+ # ── 5. Metriken ───────────────────────────────────────────
66
+ accuracy_metric = evaluate.load("accuracy")
67
+
68
+ def compute_metrics(eval_pred):
69
+ logits, labels = eval_pred
70
+ predictions = np.argmax(logits, axis=-1)
71
+ return accuracy_metric.compute(predictions=predictions, references=labels)
72
+
73
+ # ── 6. Training ───────────────────────────────────────────
74
+ # WICHTIG: Ersetze "DEIN_HF_USERNAME" mit deinem Hugging Face Benutzernamen!
75
+ HF_USERNAME = "DEIN_HF_USERNAME"
76
+ MODEL_REPO = f"{HF_USERNAME}/vit-oxford-pets"
77
+
78
+ training_args = TrainingArguments(
79
+ output_dir="./vit-oxford-pets",
80
+ num_train_epochs=5,
81
+ per_device_train_batch_size=32,
82
+ per_device_eval_batch_size=32,
83
+ warmup_steps=200,
84
+ weight_decay=0.01,
85
+ logging_dir="./logs",
86
+ logging_steps=50,
87
+ evaluation_strategy="epoch",
88
+ save_strategy="epoch",
89
+ load_best_model_at_end=True,
90
+ metric_for_best_model="accuracy",
91
+ push_to_hub=True,
92
+ hub_model_id=MODEL_REPO,
93
+ report_to="none",
94
+ )
95
+
96
+ trainer = Trainer(
97
+ model=model,
98
+ args=training_args,
99
+ train_dataset=train_ds,
100
+ eval_dataset=val_ds,
101
+ compute_metrics=compute_metrics,
102
+ )
103
+
104
+ # ── 7. Training starten ───────────────────────────────────
105
+ print("Starte Training...")
106
+ train_result = trainer.train()
107
+ print("Training abgeschlossen!")
108
+
109
+ # Trainings-Log fΓΌr README speichern
110
+ log_history = trainer.state.log_history
111
+ print("\nTrainings-Log:")
112
+ for entry in log_history:
113
+ if "eval_accuracy" in entry:
114
+ print(entry)
115
+
116
+ # ── 8. Modell auf Hugging Face hochladen ──────────────────
117
+ trainer.push_to_hub()
118
+ print(f"\nModell hochgeladen: https://huggingface.co/{MODEL_REPO}")