---
license: mit
tags:
- semantic-segmentation
- unet
- resnet34
- steel-defect-detection
- severstal
- pytorch
datasets:
- severstal-steel-defect-detection
metrics:
- dice
---

# ResNet34-UNet for Steel Defect Segmentation

This repository hosts the **trained model weights** for a U-Net with a ResNet34 encoder, fine-tuned for semantic segmentation of surface defects on steel sheets. The model classifies and localises four defect types and outputs a pixel-wise probability map for each.

> 🔗 **Full training and inference code is available on GitHub:**
> [https://github.com/zyxdtt/cv-course-project/tree/main/Semantic%20Segmentation](https://github.com/zyxdtt/cv-course-project/tree/main/Semantic%20Segmentation)

---

## 🧠 Model Description

| Property | Details |
|----------|---------|
| Architecture | U-Net with a ResNet34 encoder (pretrained on ImageNet) |
| Input size | 256 × 1600 pixels (single-channel grayscale images converted to 3-channel RGB for the pretrained backbone) |
| Output | 4 probability maps (height 256, width 1600), one per defect class, with sigmoid activation |
| Loss function | `L = BinaryCrossEntropy + (1 - Dice)` |
| Optimiser | AdamW (initial LR = 1e-4) |
| Training epochs | 10 |
| Data augmentation | Random horizontal flip (p = 0.5) |

---

## 📊 Performance on Validation Set

The validation set consists of **1,333 images** (20% of the Severstal dataset). The evaluation metric is the **Dice coefficient**.

### Overall Metrics

| Metric | Value |
|--------|-------|
| Best overall Dice | **0.6296** |
| Optimal probability threshold | **0.45** |
| Best validation loss | 0.4358 (epoch 10) |

### Per-Class Dice (threshold = 0.45)

| Class | Dice |
|-------|------|
| Defect class 1 | 0.651 |
| Defect class 2 | 0.624 |
| Defect class 3 | 0.637 |
| Defect class 4 | 0.606 |

### Threshold Robustness

The Dice score stays between 0.6293 and 0.6296 for every threshold from 0.3 to 0.7, peaking at 0.4, 0.45, and 0.5. This flat curve suggests that the model's predicted probabilities are mostly close to 0 or 1, so moving the threshold within this range changes very few pixels.
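The Dice coefficient between a predicted mask X and a ground-truth mask Y is 2|X ∩ Y| / (|X| + |Y|). The sketch below shows one way to compute a mean Dice over all (image, class) pairs and to sweep the binarisation threshold, in the spirit of the robustness check above. It is a NumPy illustration under stated assumptions, not the repository's evaluation code: the function names `dice_score`, `mean_dice`, and `sweep_thresholds` are hypothetical, averaging over (image, class) pairs is an assumption, and scoring an empty prediction on an empty ground truth as 1.0 follows the Kaggle Severstal convention, which may differ from how the numbers above were produced.

```python
import numpy as np


def dice_score(pred: np.ndarray, target: np.ndarray) -> float:
    """Dice = 2|X ∩ Y| / (|X| + |Y|) for two binary masks of the same shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    denom = int(pred.sum() + target.sum())
    if denom == 0:
        # Assumption: both masks empty counts as a perfect match (Kaggle Severstal convention)
        return 1.0
    inter = int(np.logical_and(pred, target).sum())
    return 2.0 * inter / denom


def mean_dice(probs: np.ndarray, targets: np.ndarray, threshold: float) -> float:
    """Mean Dice over every (image, class) pair (assumed averaging scheme).

    probs:   sigmoid outputs, shape (N, C, H, W)
    targets: binary ground-truth masks, same shape
    """
    masks = probs > threshold
    n_images, n_classes = probs.shape[:2]
    scores = [
        dice_score(masks[n, c], targets[n, c])
        for n in range(n_images)
        for c in range(n_classes)
    ]
    return float(np.mean(scores))


def sweep_thresholds(probs, targets, thresholds):
    """Evaluate mean Dice at each threshold; return ({threshold: dice}, best threshold)."""
    results = {t: mean_dice(probs, targets, t) for t in thresholds}
    best = max(results, key=results.get)  # ties go to the first threshold listed
    return results, best


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Toy data: 2 images, 4 classes, 8 x 32 masks (real inputs are 256 x 1600)
    targets = rng.random((2, 4, 8, 32)) > 0.7
    # Confident "predictions": near 1 on defects, near 0 elsewhere, plus small noise
    probs = np.where(targets, 0.95, 0.05) + rng.normal(0.0, 0.02, targets.shape)
    results, best = sweep_thresholds(probs, targets, [0.3, 0.4, 0.45, 0.5, 0.6, 0.7])
    for t, d in results.items():
        print(f"threshold={t:.2f}  mean Dice={d:.4f}")
    print("best threshold:", best)
```

On this synthetic data every probability sits far from the 0.3 to 0.7 range, so each threshold yields the same mean Dice, a toy version of the flat curve reported for this model.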
---

## 🚀 Usage Example (PyTorch)

### Load the model and weights

```python
import torch
from torchvision import transforms
from PIL import Image

# Assumes the model definition from the GitHub repo is importable
from model import UNetWithResNet34

# Instantiate the model; ImageNet weights are unnecessary since best.pth overwrites them
model = UNetWithResNet34(num_classes=4, pretrained=False)
model.load_state_dict(torch.load("best.pth", map_location="cpu"))
model.eval()

# Preprocessing (Severstal images are already 256 x 1600; Resize is a safeguard).
# If training used ImageNet mean/std normalisation, add the same
# transforms.Normalize step here.
transform = transforms.Compose([
    transforms.Resize((256, 1600)),
    transforms.ToTensor(),
])

# Inference
image = Image.open("steel_sheet.png").convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 256, 1600)

with torch.no_grad():
    logits = model(input_tensor)
    probs = torch.sigmoid(logits)  # shape: (1, 4, 256, 1600)

# Binarize at the optimal validation threshold
masks = (probs > 0.45).float()  # shape: (1, 4, 256, 1600)
```

### Visualise the masks

```python
import matplotlib.pyplot as plt

# Show the mask for defect class 1 (channel index 0), converted to NumPy for matplotlib
plt.imshow(masks[0, 0].numpy(), cmap='gray')
plt.title("Defect Class 1 Prediction")
plt.axis('off')
plt.show()
```

---

## 📁 Files in this repository

| File | Description |
|------|-------------|
| `best.pth` | Model weights achieving the lowest validation loss (0.4358) |
| `config.json` | (Optional) Training hyperparameters |
| `README.md` | This file |

---

## 📝 Notes from the Test Report

- The model learns to detect the major defect regions but struggles with small or subtle defects.
- Defect sizes vary widely, from small spots to large continuous streaks.
- Multiple defect classes can appear in the same image.
- The loss curves show no sign of overfitting; further training with stronger augmentation or pseudo-labeling could improve the Dice score above 0.85.

---

## 🔗 Related Resources

- Source code, training scripts, and design documents: [GitHub repository](https://github.com/zyxdtt/cv-course-project/tree/main/Semantic%20Segmentation)
- Dataset: Severstal Steel Defect Detection
- U-Net paper: Ronneberger et al., MICCAI 2015
- ResNet paper: He et al., CVPR 2016

---

## 📄 License

MIT