---
license: mit
tags:
- semantic-segmentation
- unet
- resnet34
- steel-defect-detection
- severstal
- pytorch
datasets:
- severstal-steel-defect-detection
metrics:
- dice
---
# ResNet34‑UNet for Steel Defect Segmentation
This repository hosts the **trained model weights** for a U‑Net with a ResNet34 backbone, fine‑tuned for semantic segmentation of surface defects on steel sheets. The model classifies and localises four defect types and outputs pixel‑wise probability maps.
> 🔗 **Full training and inference code is available on GitHub:**
> [https://github.com/zyxdtt/cv-course-project/tree/main/Semantic%20Segmentation](https://github.com/zyxdtt/cv-course-project/tree/main/Semantic%20Segmentation)
---
## 🧠 Model Description
| Property | Details |
|----------|---------|
| Architecture | U‑Net with a ResNet34 encoder (pretrained on ImageNet) |
| Input size | 256 × 1600 pixels (H × W); single‑channel grayscale images converted to 3‑channel RGB for the pretrained backbone |
| Output | 4 probability maps of 256 × 1600, one per defect class, from a per‑channel sigmoid |
| Loss function | `L = BinaryCrossEntropy + (1 - Dice)` |
| Optimiser | AdamW (initial LR = 1e-4) |
| Training epochs | 10 |
| Data augmentation | Random horizontal flip (p=0.5) |
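
The loss and augmentation listed above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the training code from the GitHub repo: the names `BCEDiceLoss` and `random_hflip`, the "soft" Dice computed on sigmoid probabilities, and the `smooth` constant are all hypothetical choices. The key point for augmentation is that the image and its masks must be flipped together so they stay aligned.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BCEDiceLoss(nn.Module):
    """Hypothetical sketch of L = BCE + (1 - Dice) for multi-label masks.

    Assumes a soft Dice on sigmoid probabilities, computed per image and
    class and then averaged; `smooth` keeps empty masks from dividing by zero.
    """

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # logits, targets: (N, 4, H, W); targets hold 0/1 floats
        bce = F.binary_cross_entropy_with_logits(logits, targets)
        probs = torch.sigmoid(logits)
        inter = (probs * targets).sum(dim=(2, 3))                # (N, 4)
        sizes = probs.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))  # (N, 4)
        dice = (2.0 * inter + self.smooth) / (sizes + self.smooth)
        return bce + (1.0 - dice.mean())


def random_hflip(image: torch.Tensor, mask: torch.Tensor, p: float = 0.5):
    """Flip image (3, H, W) and mask (4, H, W) together along the width axis."""
    if torch.rand(1).item() < p:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    return image, mask


if __name__ == "__main__":
    # Toy check on random data. In real training the optimiser would be
    # torch.optim.AdamW(model.parameters(), lr=1e-4), per the table above.
    logits = torch.randn(2, 4, 32, 160, requires_grad=True)
    targets = (torch.rand(2, 4, 32, 160) > 0.9).float()
    loss = BCEDiceLoss()(logits, targets)
    loss.backward()
    print(f"loss = {loss.item():.4f}")
```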
---
## 📊 Performance on Validation Set
The validation set consists of **1,333 images** (20% of the Severstal dataset). The evaluation metric is the **Dice coefficient**.
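
For reference, the standard Dice coefficient between a predicted mask $X$ and a ground‑truth mask $Y$ (each the set of defect pixels for one image and one class) is:

$$
\mathrm{Dice}(X, Y) = \frac{2\,|X \cap Y|}{|X| + |Y|}
$$

It ranges from 0 (no overlap) to 1 (identical masks). The Severstal Kaggle competition additionally defines Dice as 1 when both masks are empty; this card does not state whether its evaluation follows that convention, or whether scores are averaged per image or pooled over the whole validation set.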
### Overall Metrics
| Metric | Value |
|--------|-------|
| Best overall Dice | **0.6296** |
| Optimal probability threshold | **0.45** |
| Best validation loss | 0.4358 (epoch 10) |
### Per‑Class Dice (threshold = 0.45)
| Class | Dice |
|-------|------|
| Defect class 1 | 0.651 |
| Defect class 2 | 0.624 |
| Defect class 3 | 0.637 |
| Defect class 4 | 0.606 |
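
Per‑class scores like those in the table could be computed from the model's sigmoid outputs roughly as below. This is a hedged sketch: the function name `dice_per_class`, averaging Dice over images, and scoring an empty‑vs‑empty pair as 1.0 (the Severstal Kaggle convention) are assumptions, not details taken from the project code. Reporting the mean of the four per‑class values as the overall score is at least consistent with the table: 0.651, 0.624, 0.637 and 0.606 average to about 0.6295.
```python
import numpy as np


def dice_per_class(probs: np.ndarray, targets: np.ndarray,
                   threshold: float = 0.45) -> np.ndarray:
    """Mean Dice for each class over a batch of images (hypothetical helper).

    probs:   (N, C, H, W) sigmoid probabilities from the model
    targets: (N, C, H, W) binary ground-truth masks
    Returns: (C,) array with one Dice score per class.
    """
    preds = probs > threshold
    gt = targets > 0.5
    inter = np.logical_and(preds, gt).sum(axis=(2, 3))   # (N, C)
    sizes = preds.sum(axis=(2, 3)) + gt.sum(axis=(2, 3))  # (N, C)
    # Assumption: an (image, class) pair with both masks empty scores 1.0
    dice = np.where(sizes == 0, 1.0, 2.0 * inter / np.maximum(sizes, 1))
    return dice.mean(axis=0)


if __name__ == "__main__":
    # Synthetic stand-in for validation outputs (replace with real predictions)
    rng = np.random.default_rng(0)
    targets = (rng.random((8, 4, 64, 400)) > 0.8).astype(np.float32)
    probs = np.clip(targets + rng.normal(0.0, 0.3, targets.shape), 0.0, 1.0)
    per_class = dice_per_class(probs, targets, threshold=0.45)
    for c, d in enumerate(per_class, start=1):
        print(f"Defect class {c}: {d:.3f}")
    print(f"Overall (mean of classes): {per_class.mean():.4f}")
```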
### Threshold Robustness
The overall Dice score stays between 0.6293 and 0.6296 for thresholds from 0.3 to 0.7, reaching its maximum at 0.4, 0.45, and 0.5. This flat response suggests that the model's predictions are highly confident (most pixel probabilities lie close to 0 or 1), so the exact choice of threshold matters little.

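
A threshold sweep of this kind, together with a direct check of how many probabilities fall in the uncertain middle band, might look like the sketch below. It is illustrative only: `mean_dice`, `sweep_thresholds` and `uncertain_fraction` are hypothetical helpers, the 0.05 step is an assumption, and Dice is averaged over (image, class) pairs with empty‑vs‑empty counted as 1.0. If `uncertain_fraction` is small, moving the threshold anywhere in that band flips few pixels, which is what would keep the Dice curve flat.
```python
import numpy as np


def mean_dice(probs: np.ndarray, targets: np.ndarray, threshold: float) -> float:
    """Mean Dice over all (image, class) pairs; inputs are (N, C, H, W)."""
    preds = probs > threshold
    gt = targets > 0.5
    inter = np.logical_and(preds, gt).sum(axis=(2, 3))
    sizes = preds.sum(axis=(2, 3)) + gt.sum(axis=(2, 3))
    # Assumption: empty prediction + empty ground truth scores 1.0
    dice = np.where(sizes == 0, 1.0, 2.0 * inter / np.maximum(sizes, 1))
    return float(dice.mean())


def sweep_thresholds(probs, targets, thresholds):
    """Return {threshold: mean Dice} for each candidate threshold."""
    return {float(t): mean_dice(probs, targets, float(t)) for t in thresholds}


def uncertain_fraction(probs: np.ndarray, low: float = 0.3, high: float = 0.7) -> float:
    """Fraction of pixel probabilities strictly between `low` and `high`."""
    return float(np.mean((probs > low) & (probs < high)))


if __name__ == "__main__":
    # Synthetic, very confident predictions (replace with real validation outputs)
    rng = np.random.default_rng(0)
    targets = (rng.random((4, 4, 64, 400)) > 0.8).astype(np.float32)
    probs = np.where(targets > 0, 0.97, 0.02) + rng.normal(0.0, 0.02, targets.shape)
    thresholds = np.round(np.arange(0.30, 0.71, 0.05), 2)
    scores = sweep_thresholds(probs, targets, thresholds)
    for t, d in scores.items():
        print(f"threshold {t:.2f}: Dice {d:.4f}")
    print(f"best threshold: {max(scores, key=scores.get):.2f}")
    print(f"fraction of probabilities in (0.3, 0.7): {uncertain_fraction(probs):.4f}")
```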
---
## 🚀 Usage Example (PyTorch)
### Load the model and weights
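```python
import torch
from torchvision import transforms
from PIL import Image

# Assumes the model definition (model.py) from the GitHub repo is importable
from model import UNetWithResNet34

# Instantiate the model; no ImageNet download is needed since best.pth supplies the weights
model = UNetWithResNet34(num_classes=4, pretrained=False)
model.load_state_dict(torch.load("best.pth", map_location="cpu"))
model.eval()

# Preprocessing: Resize takes (height, width); ToTensor scales pixels to [0, 1].
# If training also normalised with ImageNet mean/std, add the same Normalize step here.
transform = transforms.Compose([
    transforms.Resize((256, 1600)),
    transforms.ToTensor(),
])

# Inference
image = Image.open("steel_sheet.png").convert("RGB")  # grayscale -> 3 channels
input_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 256, 1600)
with torch.no_grad():
    logits = model(input_tensor)
    # Assumes the model returns raw logits; skip this sigmoid if it already applies one
    probs = torch.sigmoid(logits)  # shape: (1, 4, 256, 1600)

# Binarise at the optimal validation threshold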
masks = (probs > 0.45).float()  # shape: (1, 4, 256, 1600)
```

### Visualise the masks
```python
import matplotlib.pyplot as plt

# Show the class-1 mask (batch 0, channel 0) as a NumPy array
plt.imshow(masks[0, 0].numpy(), cmap='gray')
plt.title("Defect Class 1 Prediction")
plt.axis('off')
plt.show()
```

---
## 📁 Files in this repository
| File | Description |
|------|-------------|
| `best.pth` | Model weights with the lowest validation loss (0.4358, epoch 10) |
| `config.json` | (Optional) Training hyperparameters |
| `README.md` | This file |

---
## 📝 Notes from the Test Report
- The model learns to detect the major defect regions but struggles with small or subtle defects.
- Defect sizes vary widely, from small spots to large continuous streaks.
- Multiple defect classes can appear on the same image.
- The loss curves show no sign of overfitting; the report suggests that further training with stronger augmentation or pseudo‑labeling could raise the Dice score above 0.85.

---
## 🔗 Related Resources
- Source code, training scripts, and design documents: [GitHub repository](https://github.com/zyxdtt/cv-course-project/tree/main/Semantic%20Segmentation)
- Dataset: Severstal Steel Defect Detection
- U‑Net paper: Ronneberger et al., MICCAI 2015
- ResNet paper: He et al., CVPR 2016

---
## 📄 License
MIT