Spaces:
Sleeping
Sleeping
File size: 2,562 Bytes
befde4f 5b5749d c0b2962 5b5749d c0b2962 befde4f 5b5749d 009d8e1 c0b2962 545058a 5b5749d befde4f eee069a 5b5749d befde4f 009d8e1 c0b2962 5b5749d c0b2962 5b5749d c0b2962 009d8e1 c0b2962 749b7d2 c0b2962 5b5749d c0b2962 5b5749d c0b2962 009d8e1 c0b2962 5b5749d c0b2962 5b5749d 009d8e1 e46b751 c0b2962 5b5749d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import SegformerForSemanticSegmentation
from PIL import Image
import os
# Config
# All tensors and the model live on the CPU.
device = torch.device("cpu")

# Damage / component class names. A class's position in this list is its
# channel index in the label tensors and in the model config (label2id/id2label).
target_list = [
    'Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket',
    'Hollowareas', 'Cavity', 'Spalling', 'Graffiti', 'Weathering',
    'Restformwork', 'ExposedRebars', 'Bearing', 'EJoint', 'Drainage',
    'PEquipment', 'JTape', 'WConccor',
]

# Forward (name -> index) and inverse (index -> name) lookups.
label2id = dict(zip(target_list, range(len(target_list))))
id2label = {index: name for name, index in label2id.items()}
# Dataset
class MyDataset(Dataset):
    """Dataset over the ``.jpg`` files found directly in ``image_dir``.

    Each item is ``(image, label)``. ``image`` is whatever ``transform``
    returns, or the RGB PIL image when no transform is given. ``label`` is a
    float tensor of shape ``(len(target_list), H, W)`` whose spatial size
    matches the image.

    NOTE(review): ``label`` is an all-zero placeholder; no mask files are
    read. Training against it can only teach "no class present anywhere".
    Load the real per-class masks here before using this for training.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sorted so the index -> file mapping is reproducible; os.listdir
        # order is arbitrary and platform dependent.
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Close the file handle right after decoding; convert() materializes
        # the pixel data, so the returned image no longer needs the file.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        # Spatial size: trailing (H, W) dims of a tensor, or PIL's (W, H) size
        # when no transform produced a tensor.
        if hasattr(img, "shape"):
            height, width = img.shape[-2], img.shape[-1]
        else:
            width, height = img.size
        label = torch.zeros(len(target_list), height, width)
        return img, label
# Preprocessing equivalent to torchvision's
#   Compose([Resize((512, 512)), ToTensor(), Normalize(mean, std)])
# implemented with PIL + torch only: the original referenced `transforms`
# without ever importing it, which raised NameError at import time.
_RESIZE_HW = (512, 512)
_NORM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_NORM_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def transform(img):
    """Resize a PIL image to 512x512 and return a channel-normalized tensor.

    Args:
        img: PIL image; it is converted to ``"RGB"`` (3 channels) first.

    Returns:
        float32 tensor of shape ``(3, 512, 512)``. Pixels are scaled to
        ``[0, 1]``, then each channel has ``_NORM_MEAN`` subtracted and is
        divided by ``_NORM_STD``.
    """
    height, width = _RESIZE_HW
    # PIL's resize takes (width, height); bilinear is torchvision Resize's default.
    img = img.convert("RGB").resize((width, height), Image.BILINEAR)
    # Interleaved HWC uint8 bytes -> CHW float in [0, 1] (what ToTensor does).
    pixels = torch.frombuffer(bytearray(img.tobytes()), dtype=torch.uint8)
    tensor = pixels.view(height, width, 3).permute(2, 0, 1).float().div(255.0)
    return (tensor - _NORM_MEAN) / _NORM_STD
# Training data: every .jpg directly under data/train, reshuffled each epoch,
# two samples per batch.
train_dir = "data/train"
dataset = MyDataset(train_dir, transform=transform)
loader = DataLoader(dataset, shuffle=True, batch_size=2)
# Model
# SegFormer built from the "nvidia/mit-b1" checkpoint, with one output channel
# per entry of target_list and the label mappings stored in its config.
# NOTE(review): this checkpoint is presumably encoder-only, so the decode head
# is likely freshly initialized; confirm against the loader's warnings.
segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b1",
    label2id=label2id,
    id2label=id2label,
    num_labels=len(target_list),
)
segformer = segformer.to(device)
class SegModel(nn.Module):
    """Wrap a segmentation model so its logits come back at input resolution.

    Args:
        segformer: callable taking a pixel tensor ``x`` of shape
            ``(B, C_in, H, W)`` and returning an object with a ``.logits``
            tensor of shape ``(B, num_classes, h, w)``.
        mode: interpolation mode for ``nn.functional.interpolate``. Defaults
            to ``"nearest"``, matching the previous ``nn.Upsample``.
            NOTE(review): ``"bilinear"`` is the more common choice for logits.
    """

    def __init__(self, segformer, mode="nearest"):
        super().__init__()
        self.segformer = segformer
        self.mode = mode

    def forward(self, x):
        """Return logits of shape ``(B, num_classes, H, W)`` for input ``x``."""
        logits = self.segformer(x).logits
        # Resize to the input's exact (H, W) rather than a fixed x4 factor.
        # A fixed factor only reproduces the input size when H and W are
        # multiples of 4; for those sizes the result is identical.
        return nn.functional.interpolate(logits, size=x.shape[-2:], mode=self.mode)
# Wrap the SegFormer in SegModel (which resizes its logits) and move it to the
# configured device.
model = SegModel(segformer)
model = model.to(device)
# Element-wise sigmoid + binary cross-entropy: each class channel is scored
# independently (presumably multi-label masks; confirm with the label source).
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Training
# from_pretrained() returns the wrapped SegFormer in eval mode. Switch the whole
# model to train mode so dropout / drop-path layers are active while fitting.
model.train()
for epoch in range(2):  # demo only: two epochs
    running_loss = 0.0
    num_batches = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean batch loss of the epoch, not just the last batch's.
    # An empty loader (no .jpg files) yields NaN instead of a NameError on `loss`.
    mean_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1} done, Loss: {mean_loss:.4f}")
# SAVE
# The path is defined once so the log message always names the file actually
# written (it previously claimed "best_model_cpu.pth").
# NOTE(review): the "sate" spelling is kept in case code that loads the weights
# depends on this name. The file holds the weights after the final epoch, not
# a validation-selected "best" checkpoint.
checkpoint_path = "best_model_sate_dict.pth"
torch.save(model.state_dict(), checkpoint_path)
print(f"✅ Model weights saved as {checkpoint_path}")
|