ResNet34‑UNet for Steel Defect Segmentation

This repository hosts the trained weights of a U‑Net with a ResNet34 backbone, fine‑tuned for semantic segmentation of surface defects on steel sheets. The model classifies and localises four defect types and outputs one pixel‑wise probability map per type.

🔗 Full training and inference code is available on GitHub:
https://github.com/zyxdtt/cv-course-project/tree/main/Semantic%20Segmentation


🧠 Model Description

| Property | Details |
|---|---|
| Architecture | U‑Net with a ResNet34 encoder (pretrained on ImageNet) |
| Input size | 256 × 1600 pixels (single-channel grayscale converted to 3-channel RGB for the pretrained backbone) |
| Output | 4 maps of 256 × 1600 (one per defect class); a sigmoid applied to the logits gives per-class probabilities |
| Loss function | L = BCE + (1 − Dice), i.e. binary cross-entropy plus Dice loss |
| Optimiser | AdamW (initial LR = 1e-4) |
| Training epochs | 10 |
| Data augmentation | Random horizontal flip (p = 0.5) |
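
The loss row can be made concrete with a short NumPy sketch. It assumes the model emits raw logits of shape (N, 4, H, W), that BCE is averaged over every pixel and class, and that the soft Dice term is computed over the whole batch with a smoothing constant `smooth=1.0`. The training code on GitHub may reduce per image or per class instead, so `bce_dice_loss` is a hypothetical illustration of L = BCE + (1 − Dice), not the repository's implementation.

```python
import numpy as np


def bce_dice_loss(logits, targets, smooth=1.0):
    """Hypothetical sketch of L = BCE + (1 - Dice) for (N, C, H, W) arrays.

    `logits` are raw model outputs; `targets` are binary masks (0/1).
    """
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)

    # Binary cross-entropy from logits in a numerically stable form:
    # max(x, 0) - x * t + log(1 + exp(-|x|)), averaged over all pixels and classes.
    bce = np.mean(np.maximum(logits, 0.0) - logits * targets
                  + np.logaddexp(0.0, -np.abs(logits)))

    # Soft Dice on sigmoid probabilities (no thresholding, so the same
    # formula stays differentiable when written in PyTorch).
    probs = 0.5 * (1.0 + np.tanh(0.5 * logits))  # equals 1 / (1 + exp(-x))
    intersection = np.sum(probs * targets)
    dice = (2.0 * intersection + smooth) / (np.sum(probs) + np.sum(targets) + smooth)

    return bce + (1.0 - dice)
```

In PyTorch the same two terms would typically be built from `torch.nn.functional.binary_cross_entropy_with_logits` and `torch.sigmoid`.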

📊 Performance on Validation Set

The validation set consists of 1,333 images (20% of the Severstal dataset). The evaluation metric is the Dice coefficient.
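
For reference, the Dice coefficient between a predicted mask A and a ground-truth mask B is 2|A ∩ B| / (|A| + |B|), ranging from 0 (no overlap) to 1 (identical masks). The helper below is a hypothetical sketch, not code from the repository. It scores a pair of empty masks as 1, the convention used by the Kaggle Severstal metric; whether this card's evaluation does the same is an assumption.

```python
import numpy as np


def dice_coefficient(pred, target):
    """Dice = 2|A ∩ B| / (|A| + |B|) for two binary masks of the same shape."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        # Assumed convention: both masks empty counts as a perfect match.
        return 1.0
    return float(2.0 * np.logical_and(pred, target).sum() / denom)
```

For example, a prediction covering two pixels that overlaps a one-pixel ground truth scores 2 · 1 / (2 + 1) ≈ 0.667.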

Overall Metrics

| Metric | Value |
|---|---|
| Best overall Dice | 0.6296 |
| Optimal probability threshold | 0.45 |
| Best validation loss | 0.4358 (epoch 10) |

Per‑Class Dice (threshold = 0.45)

| Class | Dice |
|---|---|
| Defect class 1 | 0.651 |
| Defect class 2 | 0.624 |
| Defect class 3 | 0.637 |
| Defect class 4 | 0.606 |

Threshold Robustness

The overall Dice score stays between 0.6293 and 0.6296 for every threshold from 0.3 to 0.7, peaking at 0.6296 for thresholds 0.4, 0.45, and 0.5. Such a flat curve suggests that few pixels receive probabilities between 0.3 and 0.7, i.e. the model's predictions are mostly confident (probabilities close to 0 or 1).
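
The threshold check above can be reproduced with a sweep. In this sketch, `per_class_dice` and `sweep_thresholds` are hypothetical helpers. They assume probabilities and binary targets stored as NumPy arrays of shape (N, 4, H, W), score each (image, class) pair with Dice (1 when both masks are empty), and define the overall Dice as the mean of the four per-class values. That definition is an assumption, though it is consistent with the per-class table: the four scores average to about 0.63.

```python
import numpy as np


def per_class_dice(probs, targets, threshold):
    """Mean Dice per class for (N, C, H, W) probabilities and binary targets."""
    preds = probs > threshold
    targets = targets.astype(bool)
    inter = np.logical_and(preds, targets).sum(axis=(2, 3))    # (N, C)
    denom = preds.sum(axis=(2, 3)) + targets.sum(axis=(2, 3))  # (N, C)
    # Assumed convention: an (image, class) pair with both masks empty scores 1.
    dice = np.where(denom == 0, 1.0, 2.0 * inter / np.maximum(denom, 1))
    return dice.mean(axis=0)                                   # (C,)


def sweep_thresholds(probs, targets, thresholds):
    """Return {threshold: overall Dice}, overall = mean of the per-class Dice."""
    return {t: float(per_class_dice(probs, targets, t).mean()) for t in thresholds}
```

With validation probabilities in `probs` and ground truth in `targets`, `sweep_thresholds(probs, targets, [round(0.3 + 0.05 * i, 2) for i in range(9)])` scores thresholds 0.3 to 0.7, and `max(scores, key=scores.get)` returns the first threshold with the highest overall Dice. For all 1,333 full-size images it is cheaper to accumulate the intersections and mask sums batch by batch than to hold every probability map in memory.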


🚀 Usage Example (PyTorch)

Load the model and weights

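```python
import torch
from torchvision import transforms
from PIL import Image

# Assume you have the model definition from the GitHub repo
from model import UNetWithResNet34

# Instantiate the model; ImageNet weights are not needed here because
# the checkpoint already contains the fine-tuned encoder weights
model = UNetWithResNet34(num_classes=4, pretrained=False)
model.load_state_dict(torch.load("best.pth", map_location="cpu"))
model.eval()  # disable dropout, use running BatchNorm statistics

# Preprocessing: Severstal images are already 256 x 1600, so Resize is a no-op for them.
# If the training pipeline normalised inputs (e.g. with ImageNet mean/std),
# add the same transforms.Normalize(...) step here.
transform = transforms.Compose([
    transforms.Resize((256, 1600)),  # (height, width)
    transforms.ToTensor(),           # HWC uint8 [0, 255] -> CHW float [0, 1]
])

# Inference
image = Image.open("steel_sheet.png").convert("RGB")  # grayscale -> 3 channels
input_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 256, 1600)

with torch.no_grad():
    logits = model(input_tensor)              # raw scores, shape: (1, 4, 256, 1600)
    probs = torch.sigmoid(logits)             # per-class probabilities, same shape

# Binarise at the optimal validation threshold
masks = (probs > 0.45).float()                # shape: (1, 4, 256, 1600)
```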
Visualise the masks

```python
import matplotlib.pyplot as plt

# Show the predicted mask for defect class 1 (channel 0)
plt.imshow(masks[0, 0].numpy(), cmap="gray")
plt.title("Defect Class 1 Prediction")
plt.axis("off")
plt.show()
```

📁 Files in this repository

| File | Description |
|---|---|
| best.pth | Model weights with the lowest validation loss (0.4358) |
| config.json | (Optional) Training hyperparameters |
| README.md | This file |

📝 Notes from the Test Report
- The model learns to detect the major defect regions but struggles with small or subtle defects.
- Defect sizes vary widely, from small spots to large continuous streaks.
- Multiple defect classes can appear in the same image.
- The loss curves show no sign of overfitting; further training with stronger augmentation or pseudo‑labelling could raise the Dice score above 0.85.
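
Because several defect classes can appear in one image, each of the four output channels gets its own sigmoid (as in the model description) rather than a softmax across classes. The NumPy comparison below uses made-up per-class logits purely for illustration. With independent sigmoids, two classes can both clear the 0.45 threshold, whereas a softmax makes the four scores compete and sum to 1.

```python
import numpy as np


def sigmoid(x):
    # Numerically stable logistic function, equal to 1 / (1 + exp(-x))
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def softmax(x):
    # Softmax across classes; subtracting the max avoids overflow in exp
    e = np.exp(x - np.max(x))
    return e / e.sum()


# Hypothetical logits for defect classes 1-4 (illustrative values only)
logits = np.array([2.0, 1.5, -3.0, -4.0])
threshold = 0.45

sigmoid_classes = np.flatnonzero(sigmoid(logits) > threshold) + 1  # independent decisions
softmax_classes = np.flatnonzero(softmax(logits) > threshold) + 1  # classes compete

print(sigmoid_classes)  # [1 2]
print(softmax_classes)  # [1]
```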

🔗 Related Resources
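- Source code, training scripts, and design documents: GitHub repository (https://github.com/zyxdtt/cv-course-project/tree/main/Semantic%20Segmentation)
- Dataset: Severstal Steel Defect Detection (Kaggle competition)
- U‑Net paper: Ronneberger et al., "U-Net: Convolutional Networks for Biomedical Image Segmentation", MICCAI 2015
- ResNet paper: He et al., "Deep Residual Learning for Image Recognition", CVPR 2016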

📄 License
MIT