# Masterarbeit (master's thesis) / train.py
# Author: Alic22 — last commit 545058a ("Update train.py", verified).
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import SegformerForSemanticSegmentation
from PIL import Image
import os
# Config
# All training runs on the CPU.
device = torch.device("cpu")

# Class names; a class's position in this list is its output channel index.
target_list = [
    'Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket',
    'Hollowareas', 'Cavity', 'Spalling', 'Graffiti', 'Weathering',
    'Restformwork', 'ExposedRebars', 'Bearing', 'EJoint', 'Drainage',
    'PEquipment', 'JTape', 'WConccor',
]

# name -> channel index, and the inverse mapping channel index -> name.
label2id = {name: index for index, name in enumerate(target_list)}
id2label = {index: name for name, index in label2id.items()}
# Dataset
class MyDataset(Dataset):
    """Segmentation dataset over the ``.jpg`` files of a single directory.

    Each item is ``(image, label)``. ``image`` is the RGB PIL image, passed through
    ``transform`` when one is given. ``label`` is a float tensor of zeros with shape
    ``(num_classes, H, W)``, where ``(H, W)`` is the image's spatial size.

    NOTE(review): the labels are placeholders. No ground-truth masks are read, so
    every target is all-zero and the model cannot learn any class from them.
    Load the real per-class masks here before serious training.

    Args:
        image_dir: Directory scanned (non-recursively) for files ending in ``.jpg``.
        transform: Optional callable applied to the RGB PIL image.
        num_classes: Number of label channels; ``None`` means ``len(target_list)``.
    """

    def __init__(self, image_dir, transform=None, num_classes=None):
        self.image_dir = image_dir
        # os.listdir order is arbitrary; sorting keeps the index -> file mapping
        # reproducible across runs and machines.
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))
        self.transform = transform
        self.num_classes = num_classes

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        # The context manager closes the file handle right away. convert() returns
        # an independent in-memory copy, so it stays usable after the file closes.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        height, width = self._spatial_size(img)
        num_classes = len(target_list) if self.num_classes is None else self.num_classes
        label = torch.zeros(num_classes, height, width)
        return img, label

    @staticmethod
    def _spatial_size(img):
        """Return ``(height, width)`` of a channels-first tensor or a PIL image."""
        if hasattr(img, "shape"):
            return tuple(img.shape[-2:])
        # PIL reports size as (width, height).
        width, height = img.size
        return height, width
# Preprocessing: resize -> to float tensor -> per-channel normalisation.
# Implemented with PIL + torch because the previous version called
# torchvision's `transforms` without ever importing it (NameError at runtime).
_IMAGE_SIZE = (512, 512)  # (height, width)
# Per-channel mean / std (ImageNet statistics, same values as before).
_NORM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_NORM_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def transform(img, size=_IMAGE_SIZE):
    """Resize a PIL image and return a normalised ``(3, height, width)`` float tensor.

    Mirrors torchvision's ``Compose([Resize(size), ToTensor(), Normalize(mean, std)])``
    in three steps:

    1. Bilinear resize to exactly ``size``.
    2. Scale 8-bit pixel values to [0, 1].
    3. Standardise each channel with the mean/std above.

    Args:
        img: PIL image; it is converted to RGB if it is not already.
        size: Target ``(height, width)``; defaults to 512x512.

    Returns:
        float32 tensor of shape ``(3, height, width)``.
    """
    height, width = size
    # PIL's resize expects (width, height).
    resized = img.convert("RGB").resize((width, height), Image.BILINEAR)
    # An RGB image's raw bytes are row-major interleaved uint8 (H x W x 3).
    # A bytearray is writable, which torch.frombuffer prefers.
    pixels = torch.frombuffer(bytearray(resized.tobytes()), dtype=torch.uint8)
    chw = pixels.view(height, width, 3).permute(2, 0, 1).float() / 255.0
    return (chw - _NORM_MEAN) / _NORM_STD
# Training data from data/train, served in shuffled mini-batches of two
# (DataLoader reshuffles at the start of every epoch).
dataset = MyDataset(image_dir="data/train", transform=transform)
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=2)
# Model
# Load pretrained "nvidia/mit-b1" weights from the Hugging Face Hub.
# The output layer is sized to one channel per entry in target_list, and the
# config gets the name <-> index mappings.
# NOTE(review): mit-b1 appears to be an encoder-only checkpoint, so the
# segmentation decode head is presumably freshly initialised (transformers
# logs a warning about newly initialised weights) — confirm.
segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b1",
    label2id=label2id,
    id2label=id2label,
    num_labels=len(target_list),
)
segformer = segformer.to(device)
class SegModel(nn.Module):
    """Wrap a SegFormer model and resize its logits to the input's resolution.

    SegFormer predicts logits at a reduced resolution (1/4 of the input). This
    wrapper interpolates them to exactly the spatial size of ``x``, so they line
    up with full-resolution per-pixel targets even when H or W is not a multiple
    of the downsampling factor.

    Args:
        segformer: Module whose forward returns an object with a ``logits``
            tensor of shape ``(batch, num_labels, h, w)``.
        upsample_mode: Interpolation mode for ``nn.functional.interpolate``.
            Defaults to ``"nearest"`` (the previous behaviour). ``"bilinear"``
            is the usual choice for segmentation logits.
    """

    # Modes for which interpolate() accepts an align_corners argument.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, segformer, upsample_mode="nearest"):
        super().__init__()
        self.segformer = segformer
        self.upsample_mode = upsample_mode

    def forward(self, x):
        logits = self.segformer(x).logits
        # align_corners must stay None for modes like "nearest", or interpolate raises.
        if self.upsample_mode in self._ALIGN_CORNERS_MODES:
            align_corners = False
        else:
            align_corners = None
        # Use the input's own (H, W) instead of a fixed 4x factor.
        return nn.functional.interpolate(
            logits,
            size=x.shape[-2:],
            mode=self.upsample_mode,
            align_corners=align_corners,
        )
# Wrap the backbone so its logits come back at input resolution, then move it to the device.
model = SegModel(segformer)
model = model.to(device)
# Adam over all parameters (backbone + head) with learning rate 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
# Element-wise sigmoid + binary cross-entropy: each class channel is scored independently.
criterion = nn.BCEWithLogitsLoss()
# Training
num_epochs = 2  # demo only

# Hugging Face from_pretrained() returns the model in eval mode, which disables
# dropout / stochastic depth. Switch to train mode before optimising.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean loss over the epoch rather than only the last batch.
    # Guard the empty-loader case, where no loss was ever computed.
    if num_batches == 0:
        print(f"Epoch {epoch+1} done, no batches (is the dataset empty?)")
    else:
        print(f"Epoch {epoch+1} done, Loss: {running_loss / num_batches:.4f}")
# SAVE
# Output file for the trained state_dict. The existing filename (including the
# "sate" spelling) is kept in case another script loads it by this name.
# NOTE(review): these are the weights after the final epoch, not a selected
# "best" checkpoint — the filename overstates this.
weights_path = "best_model_sate_dict.pth"
torch.save(model.state_dict(), weights_path)
# Report the path actually written. The old message named a different file
# (best_model_cpu.pth).
print(f"✅ Model weights saved as {weights_path}")