# Spaces: Sleeping
import os

import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import SegformerForSemanticSegmentation
# Config
device = torch.device("cpu")

# Class names of the segmentation head. The position of each name in this
# list is its channel index in the model output and in the label masks.
target_list = [
    "Crack", "ACrack", "Wetspot", "Efflorescence", "Rust", "Rockpocket",
    "Hollowareas", "Cavity", "Spalling", "Graffiti", "Weathering",
    "Restformwork", "ExposedRebars",
    "Bearing", "EJoint", "Drainage", "PEquipment", "JTape", "WConccor",
]

# name -> channel index, and its inverse (channel index -> name).
label2id = dict(zip(target_list, range(len(target_list))))
id2label = {idx: name for name, idx in label2id.items()}
# Dataset
class MyDataset(Dataset):
    """Dataset over the ``.jpg`` files of one directory.

    Each item is ``(image, label)``. ``label`` has one channel per entry of
    ``target_list`` and the same height/width as the (transformed) image.

    NOTE(review): no annotation files are read. Every label is an all-zero
    mask, so training has no positive targets. TODO: load real masks.
    """

    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir: Directory scanned (non-recursively) for names ending
                in ``.jpg`` (case-sensitive).
            transform: Optional callable applied to the RGB PIL image.
        """
        self.image_dir = image_dir
        # Sorted so index -> file is reproducible; os.listdir order is arbitrary.
        self.images = sorted(
            f for f in os.listdir(image_dir) if f.endswith(".jpg")
        )
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file.

        Previously nothing was returned, so DataLoader batches were ``None``.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Context manager closes the underlying file; convert() yields a new,
        # fully loaded image that stays valid after the close.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # Tensors are (..., H, W); an untransformed PIL image exposes (W, H)
        # via .size and has no .shape.
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            width, height = img.size
        label = torch.zeros(len(target_list), height, width)
        return img, label
# Input preprocessing.
# The original used torchvision's `transforms.Compose([...])` but never
# imported `transforms`, so the script died with a NameError. This function
# performs the same pipeline with PIL + torch only:
#   1. resize to 512x512 with bilinear resampling,
#   2. scale pixels to [0, 1] as float32 (C, H, W),
#   3. normalize each channel with the mean/std below.
IMAGE_SIZE = (512, 512)  # (width, height), as PIL's resize() expects
_NORM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
_NORM_STD = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)


def transform(img):
    """Resize a PIL image and return a normalized ``float32`` tensor.

    Args:
        img: ``PIL.Image.Image``. It is converted to RGB if it is not already.

    Returns:
        Tensor of shape ``(3, 512, 512)``.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    img = img.resize(IMAGE_SIZE, Image.BILINEAR)
    width, height = img.size
    # bytearray gives torch.frombuffer a writable buffer (no warning).
    pixels = torch.frombuffer(bytearray(img.tobytes()), dtype=torch.uint8)
    # PIL's RGB byte layout is row-major (H, W, C); reorder to (C, H, W).
    chw = pixels.view(height, width, 3).permute(2, 0, 1).float().div(255)
    return (chw - _NORM_MEAN) / _NORM_STD


dataset = MyDataset("data/train", transform=transform)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
# Model
# Pretrained "nvidia/mit-b1" checkpoint, configured with one output channel
# per entry of target_list and the matching name <-> index mappings.
segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b1",
    label2id=label2id,
    id2label=id2label,
    num_labels=len(target_list),
)
segformer = segformer.to(device)
class SegModel(nn.Module):
    """Wrap a SegFormer-style model so its logits match the input resolution.

    The wrapped model's ``.logits`` come out at a reduced spatial size. The
    original code upsampled them by a fixed factor of 4, which only matches
    full-resolution label masks when H and W are divisible by 4. The logits
    are now resized to exactly the input's (H, W). For 512x512 inputs the
    result is identical to the old ``nn.Upsample(scale_factor=4)``.

    Args:
        segformer: Module whose forward(x) returns an object with a
            ``.logits`` tensor of shape ``(N, num_labels, h, w)``.
        mode: Interpolation mode for ``nn.functional.interpolate``. The
            default ``"nearest"`` matches the original behaviour.
    """

    def __init__(self, segformer, mode="nearest"):
        super().__init__()
        self.segformer = segformer
        self.mode = mode

    def forward(self, x):
        """Return logits of shape ``(N, num_labels, H, W)`` for ``x`` of shape ``(N, C, H, W)``."""
        logits = self.segformer(x).logits
        return nn.functional.interpolate(logits, size=x.shape[-2:], mode=self.mode)
model = SegModel(segformer)
model = model.to(device)

# Adam over every parameter of the wrapper, i.e. all SegFormer weights.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Sigmoid + binary cross-entropy applied element-wise: each class channel of
# every pixel is scored independently of the others.
criterion = nn.BCEWithLogitsLoss()
# Training
# Fail fast with a clear message: with no batches, `loss` below would never
# be assigned and the old per-epoch print raised a NameError.
if len(loader) == 0:
    raise RuntimeError("Training DataLoader is empty - no images to train on.")

# Hugging Face from_pretrained() returns the model in eval mode (dropout
# disabled); switch the wrapper and the wrapped SegFormer to training mode.
model.train()

for epoch in range(2):  # demo only
    running_loss = 0.0
    num_batches = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean loss over the epoch rather than only the last batch.
    print(f"Epoch {epoch+1} done, Loss: {running_loss / num_batches:.4f}")
# SAVE
# One constant feeds both the save and the log message. Previously the
# message claimed "best_model_cpu.pth" while a different file was written.
# NOTE(review): "sate" looks like a typo for "state". The name is kept as-is
# in case another script loads this exact path.
MODEL_PATH = "best_model_sate_dict.pth"
torch.save(model.state_dict(), MODEL_PATH)
print(f"✅ Model weights saved as {MODEL_PATH}")