Alic22 committed on
Commit
befde4f
·
verified ·
1 Parent(s): 8bd1f13

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +105 -0
train.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torchvision import transforms
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ SegformerForSemanticSegmentation,
7
+ SegformerFeatureExtractor,
8
+ Trainer,
9
+ TrainingArguments
10
+ )
11
+ import evaluate
12
+
13
+ # ------------------------------
14
+ # 1️⃣ Parameters
15
+ # ------------------------------
16
+ DATA_DIR = "path_to_dataset" # Path to the image and mask folders
17
+ NUM_CLASSES = 3 # e.g. 3 classes: background, damage, border
18
+ IMAGE_SIZE = 256 # Side length (pixels) images and masks are resized to for training
19
+
20
# ------------------------------
# 2️⃣ Load dataset
# ------------------------------
# Assumption: dataset is in ImageFolder format with 'train' and 'validation'
# sub-folders.
# NOTE(review): preprocess() reads the segmentation masks from the "label"
# column -- confirm that the folder layout / metadata really provides mask
# images (not class ids derived from folder names) in that column.
dataset = load_dataset("imagefolder", data_dir=DATA_DIR)

# Transformations for the input images: resize, then convert to a float tensor.
train_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# Transformations for the masks. Nearest-neighbour resampling is required here:
# the default bilinear interpolation would average neighbouring class ids and
# produce label values that do not correspond to any class.
mask_transforms = transforms.Compose([
    transforms.Resize(
        (IMAGE_SIZE, IMAGE_SIZE),
        interpolation=transforms.InterpolationMode.NEAREST,
    ),
    transforms.PILToTensor(),  # keeps the integer class ids (no 0..1 scaling)
])
36
+
37
+ # Preprocessing-Funktion
38
def preprocess(batch):
    """Turn a batch of PIL images and masks into model-ready tensors.

    Args:
        batch: Column dict whose ``"image"`` and ``"label"`` entries are
            iterated as sequences of PIL images, i.e. the function must be
            called in batched mode.
            NOTE(review): assumes each mask is a single-channel image whose
            pixel values are class ids -- confirm against the dataset.

    Returns:
        The same dict with two added entries: ``"pixel_values"`` (list of
        float image tensors) and ``"labels"`` (list of long mask tensors).
    """
    # Force three channels: grayscale, RGBA or palette images would otherwise
    # yield tensors whose channel count does not match the 3-channel input
    # the pretrained encoder expects.
    batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in batch["image"]
    ]
    # Masks as LongTensor for CrossEntropyLoss; squeeze(0) drops the leading
    # channel dimension that PILToTensor adds for single-channel masks.
    batch["labels"] = [
        mask_transforms(mask).long().squeeze(0) for mask in batch["label"]
    ]
    return batch
43
+
44
# batched=True: preprocess() iterates over batch["image"] / batch["label"], so
# it has to receive whole columns (lists) instead of a single example.
# A modest batch_size keeps the number of decoded image tensors held in memory
# at once bounded.
dataset = dataset.map(preprocess, batched=True, batch_size=32)

# Expose only the model inputs, as torch tensors. The raw PIL columns must not
# reach the data collator (the Trainer keeps every column because
# remove_unused_columns=False), and nested Python lists would be slow to
# collate.
dataset = dataset.with_format("torch", columns=["pixel_values", "labels"])
45
+
46
+ # ------------------------------
47
+ # 3️⃣ Feature extractor & model
48
+ # ------------------------------
49
# NOTE(review): in this script the feature extractor is never applied to the
# images; it is only written next to the model when saving.
+ feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/mit-b1")
50
+
51
# Load the "nvidia/mit-b1" checkpoint, configured with NUM_CLASSES labels.
+ model = SegformerForSemanticSegmentation.from_pretrained(
52
+ "nvidia/mit-b1",
53
+ num_labels=NUM_CLASSES,
54
+ )
55
+
56
+ # ------------------------------
57
+ # 4️⃣ Metrics
58
+ # ------------------------------
59
# Metric object used by compute_metrics() during evaluation.
+ metric = evaluate.load("mean_iou")
60
+
61
def compute_metrics(p):
    """Compute mean-IoU metrics for one evaluation pass.

    Args:
        p: Evaluation output exposing ``predictions`` (raw class logits) and
            ``label_ids`` (ground-truth masks).
            NOTE(review): assumes ``predictions`` is laid out as
            (batch, num_labels, h, w) and ``label_ids`` as (batch, H, W) --
            confirm against the Trainer output.

    Returns:
        The dict returned by the ``mean_iou`` metric, with numpy arrays
        (per-class scores) converted to plain lists so they can be logged
        and serialised.
    """
    # float32 also covers half-precision logits produced under fp16.
    logits = torch.as_tensor(p.predictions, dtype=torch.float32)
    labels = np.asarray(p.label_ids)

    # SegFormer emits logits at a lower resolution than the input image, so
    # they must be resized to the mask size before taking the per-pixel argmax;
    # otherwise predictions and references have different shapes.
    upsampled = torch.nn.functional.interpolate(
        logits,
        size=labels.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )
    preds = upsampled.argmax(dim=1).numpy()

    results = metric.compute(
        predictions=preds,
        references=labels,
        num_labels=NUM_CLASSES,
        # mean_iou requires ignore_index; 255 lies outside the class-id range
        # 0..NUM_CLASSES-1 used here, so no labelled pixel is excluded.
        ignore_index=255,
    )
    return {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in results.items()
    }
64
+
65
# ------------------------------
# 5️⃣ TrainingArguments
# ------------------------------
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    learning_rate=5e-5,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` -- confirm against the installed version.
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=200,
    eval_steps=200,
    logging_steps=50,
    # Mixed precision only when a CUDA GPU is actually available; requesting
    # fp16 unconditionally makes the script fail on CPU-only machines.
    fp16=torch.cuda.is_available(),
    remove_unused_columns=False,  # important for segmentation
)
82
+
83
+ # ------------------------------
84
+ # 6️⃣ Trainer
85
+ # ------------------------------
86
# Wire model, arguments, the 'train' / 'validation' splits and the metric
# callback together.
+ trainer = Trainer(
87
+ model=model,
88
+ args=training_args,
89
+ train_dataset=dataset["train"],
90
+ eval_dataset=dataset["validation"],
91
+ compute_metrics=compute_metrics,
92
+ )
93
+
94
+ # ------------------------------
95
+ # 7️⃣ Start training
96
+ # ------------------------------
97
+ trainer.train()
98
+
99
+ # ------------------------------
100
+ # 8️⃣ Save model
101
+ # ------------------------------
102
# Model and feature-extractor files are written into the same directory.
+ trainer.save_model("my_segformer_model")
103
+ feature_extractor.save_pretrained("my_segformer_model")
104
+
105
+ print("✅ Training abgeschlossen und Modell gespeichert!")