File size: 2,562 Bytes
befde4f
5b5749d
c0b2962
5b5749d
 
c0b2962
befde4f
5b5749d
009d8e1
c0b2962
545058a
 
 
 
5b5749d
befde4f
eee069a
5b5749d
 
befde4f
009d8e1
 
c0b2962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b5749d
 
 
 
 
 
 
 
 
 
c0b2962
 
5b5749d
c0b2962
009d8e1
c0b2962
 
749b7d2
c0b2962
 
 
5b5749d
c0b2962
 
 
 
 
 
 
 
 
 
 
5b5749d
c0b2962
009d8e1
c0b2962
 
 
 
 
 
5b5749d
 
c0b2962
5b5749d
009d8e1
 
e46b751
c0b2962
5b5749d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import SegformerForSemanticSegmentation
from PIL import Image
import os

# Configuration

# All computation in this script is placed on the CPU.
device = torch.device("cpu")

# Class names predicted by the segmentation head. A name's position in this
# list is its integer label id (see id2label / label2id below).
target_list = [
    'Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket',
    'Hollowareas', 'Cavity', 'Spalling', 'Graffiti', 'Weathering',
    'Restformwork', 'ExposedRebars', 'Bearing', 'EJoint', 'Drainage',
    'PEquipment', 'JTape', 'WConccor',
]

# Id <-> name lookup tables, passed to the SegFormer model configuration.
id2label = dict(enumerate(target_list))
label2id = {name: index for index, name in id2label.items()}

# Dataset 

class MyDataset(Dataset):
    """Dataset of ``.jpg`` images yielding ``(image, label)`` pairs.

    Every file directly inside ``image_dir`` whose name ends in ``.jpg`` is
    one sample. The label is a float tensor of shape
    ``(len(target_list), H, W)``, one channel per class, matching the spatial
    size of the transformed image.

    NOTE(review): labels are all-zero placeholders. No annotation masks are
    loaded yet, so training on them only works as a smoke test. Wire in the
    real masks before relying on the model.

    Args:
        image_dir: directory containing the training images.
        transform: callable mapping an RGB PIL image to a ``(C, H, W)``
            tensor. It is required in practice, because the label size is
            read from the returned tensor's shape.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # os.listdir order is arbitrary; sorting makes index -> file stable.
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``.

        Raises:
            TypeError: if the (possibly absent) transform does not yield a tensor.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        # convert() materializes the pixels, so the file can be closed right away.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        if not torch.is_tensor(img):
            raise TypeError(
                "MyDataset needs a transform that returns a (C, H, W) tensor, "
                f"got {type(img).__name__}"
            )
        height, width = img.shape[-2:]
        label = torch.zeros(len(target_list), height, width)
        return img, label

# Preprocessing
#
# Resize -> float tensor in [0, 1] -> ImageNet mean/std normalization, built
# from PIL and torch only. The previous version called torchvision's
# `transforms`, which this script never imports.

IMAGE_SIZE = (512, 512)  # (height, width) every training image is resized to
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def rgb_bytes_to_tensor(raw, height, width, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Convert packed 8-bit RGB bytes into a normalized ``(3, height, width)`` tensor.

    Args:
        raw: ``height * width * 3`` bytes in row-major, interleaved RGB order
            (the layout ``PIL.Image.tobytes()`` produces for an ``"RGB"`` image).
        height: image height in pixels.
        width: image width in pixels.
        mean: per-channel mean, applied after scaling pixels to [0, 1].
        std: per-channel standard deviation, applied after scaling.

    Returns:
        float32 tensor ``(pixel / 255 - mean) / std`` with channels first.

    Raises:
        ValueError: if ``raw`` does not hold exactly ``height * width * 3`` bytes.
    """
    expected = height * width * 3
    if len(raw) != expected:
        raise ValueError(f"expected {expected} RGB bytes, got {len(raw)}")
    # bytearray gives torch a writable buffer (frombuffer warns on read-only bytes).
    pixels = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
    chw = pixels.view(height, width, 3).permute(2, 0, 1).contiguous()
    chw = chw.to(torch.float32).div(255.0)
    mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
    return (chw - mean_t) / std_t


def transform(img, size=IMAGE_SIZE):
    """Resize a PIL image and return it as a normalized ``(3, H, W)`` float tensor.

    Args:
        img: PIL image; it is converted to RGB first.
        size: target ``(height, width)``.
    """
    height, width = size
    # PIL's resize takes (width, height); bilinear is torchvision Resize's default too.
    resized = img.convert("RGB").resize((width, height), Image.BILINEAR)
    return rgb_bytes_to_tensor(resized.tobytes(), height, width)

# Training images are read from this directory, preprocessed by `transform`,
# and served in shuffled mini-batches of two.
TRAIN_DIR = "data/train"
dataset = MyDataset(TRAIN_DIR, transform=transform)
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=2)

# Model

# SegFormer loaded from the "nvidia/mit-b1" checkpoint, configured with one
# output label per entry of target_list, then placed on `device`.
segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b1",
    label2id=label2id,
    id2label=id2label,
    num_labels=len(target_list),
)
segformer = segformer.to(device)

class SegModel(nn.Module):
    """Wrap a SegFormer model so its logits come back at the input resolution.

    SegFormer emits logits at roughly 1/4 of the input height/width. The
    training loss compares them with full-resolution labels, so the logits are
    resized to the exact spatial size of the input batch.

    Args:
        segformer: model where ``segformer(pixel_values=x)`` returns an object
            with a ``.logits`` tensor of shape ``(B, num_labels, h, w)``.
        mode: interpolation mode passed to ``nn.functional.interpolate``.
            ``'nearest'`` keeps the previous behaviour; ``'bilinear'`` gives
            smoother upsampled logits.
    """

    # Modes for which interpolate() accepts an explicit align_corners flag.
    _LINEAR_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, segformer, mode='nearest'):
        super().__init__()
        self.segformer = segformer
        self.mode = mode

    def forward(self, x):
        """Return logits for ``x`` (B, C, H, W) resized to (B, num_labels, H, W)."""
        logits = self.segformer(pixel_values=x).logits
        # Resize to the exact input size rather than a fixed x4: the stride-4
        # patch embedding rounds up, so x4 overshoots when H or W % 4 != 0.
        align = False if self.mode in self._LINEAR_MODES else None
        return nn.functional.interpolate(
            logits, size=x.shape[-2:], mode=self.mode, align_corners=align
        )

model = SegModel(segformer)
model.to(device)  # nn.Module.to moves parameters in place (and returns self)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
# Sigmoid + binary cross-entropy per pixel and per class channel, so each
# class is scored independently against the (num_classes, H, W) labels.
criterion = nn.BCEWithLogitsLoss()

# Training

NUM_EPOCHS = 2  # demo only

# Hugging Face from_pretrained() returns the model in eval mode; switch to
# train mode so dropout / stochastic depth are active while fine-tuning.
model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    num_batches = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        # Without this guard, `loss` would be unbound (NameError) on an empty loader.
        print(f"Epoch {epoch+1}: no batches (is the dataset empty?)")
        continue
    # Report the mean over the epoch, not just the last batch's loss.
    print(f"Epoch {epoch+1} done, Loss: {running_loss / num_batches:.4f}")

# Save

# Single source of truth for the output path, so the message below always
# names the file that was actually written (it previously claimed
# "best_model_cpu.pth").
# NOTE(review): "sate" looks like a typo of "state"; kept so existing loaders
# of this file keep working. Rename deliberately if desired.
MODEL_PATH = "best_model_sate_dict.pth"
torch.save(model.state_dict(), MODEL_PATH)
print(f"✅ Model weights saved as {MODEL_PATH}")