WireSegHR / README.md
MRiabov's picture
Update model card
e3db6f5 verified
---
license: apache-2.0
language:
- en
pipeline_tag: image-segmentation
library_name: transformers
tags:
- wire
- drone
- computer_vision
---
# WireSegHR (Segmentation Only)
This repository contains the segmentation-only implementation of the two-stage [WireSegHR model](https://arxiv.org/abs/2304.00221), training on the WireSegHR dataset plus the [TTPLA dataset](https://github.com/R3ab/ttpla_dataset).
## Quick Start
1) Get secrets necessary for fetching of the dataset:
You'll need a GDrive service account to fetch the WireSegHR dataset using scripts in this repo. Get a GDrive key as described [in this short README](scripts/drive-viewer-key-readme.md), and put it in `/secrets/drive-json.json`
2) Run:
```bash
scripts/setup.sh
```
This installs dependencies and merges the TTPLA dataset into the WireSegHR dataset format.
3) Train and run a quick inference check:
```bash
python3 train.py --config configs/default.yaml
python3 infer.py --config configs/default.yaml --image /path/to/image.jpg
```
The default config `default.yaml` is suitable for a 24GB VRAM GPU with support for bf16 (e.g., RTX 3090/4090).
<!-- For a quick RTX GPU setup, I recommend [vast.ai](https://cloud.vast.ai/?ref_id=162850) -->
## Project Overview
- Two-stage, global-to-local segmentation with a shared encoder and a fine decoder conditioned on the coarse stage.
- Full training loop with AMP (optional), Poly LR, periodic evaluation, checkpointing, and test visualizations (`train.py`).
- Dataset utilities under `src/wireseghr/data/` and model components under `src/wireseghr/model/`.
- Paper text and figures live in `paper-tex/` (`paper-tex/sections/` contains the Method, Results, etc.).
## Notes
- This is a segmentation-only codebase. Inpainting is out of scope here.
- Defaults locked: SegFormer MiT-B3 encoder, patch size 768, MinMax 6×6, global+binary mask conditioning with patch-cropped global map.
### Backbone Source
- HuggingFace Transformers SegFormer (e.g., `nvidia/mit-b3`). We set `num_channels` to match input channels.
- Alternative: TorchVision ResNet-50 (`backbone: resnet50`). The stem is adapted to the requested `in_channels`, and we expose features from `layer1`..`layer4` at strides 1/4, 1/8, 1/16, 1/32 with channels [256, 512, 1024, 2048].
## Dataset Convention
- Flat directories with numeric filenames; images are `.jpg`/`.jpeg`, masks are `.png`.
- Example (after split 85/5/10):
- `dataset/train/images/1.jpg, 2.jpg, ...` and `dataset/train/gts/1.png, 2.png, ...`
- `dataset/val/images/...` and `dataset/val/gts/...`
- `dataset/test/images/...` and `dataset/test/gts/...`
- Masks are binary: foreground = white (255), background = black (0).
- The loader strictly enforces numeric stems and 1:1 pairing of naming and will raise on file name mismatches.
Update `configs/default.yaml` with your paths under `data.train_images`, `data.train_masks`, etc. Defaults point to `dataset/train/images`, `dataset/train/gts`, and validation to `dataset/val/...`.
## Inference
- Single image (optionally save outputs to a directory):
```bash
python3 infer.py \
--config configs/default.yaml \
--ckpt ckpt_5000.pt \
--image dataset/test/images/123.jpg \
--out outputs/infer
```
- Compute metrics for a single image (requires a GT mask):
```bash
python3 infer.py \
--config configs/default.yaml \
--ckpt ckpt_5000.pt \
--image dataset/test/images/123.jpg \
--out outputs/infer \
--metrics \
--mask dataset/test/gts/123.png
```
- Run inference over the entire directory with metrics (images_dir sets the image directory, masks_dir sets the ground truth mask directory):
```bash
python3 infer.py \
--config configs/default.yaml \
--ckpt ckpt_5000.pt \
--images_dir dataset/test/images \
--out outputs/infer \
--metrics \
--masks_dir dataset/test/gts
```
Notes:
- Predictions are saved as 0/255 PNGs. For metrics, predictions are binarized with `> 0` to match training logic.
- Masks are matched by filename stem: `images/123.jpg``gts/123.png`.
## Benchmarking and Metrics
Benchmark mode times the model on a directory of images and reports coarse/fine/total latency statistics. When `--metrics` is provided, it also computes IoU/F1/Precision/Recall over the benchmark set (both fine and coarse outputs).
Example (uses `data.test_images` and `data.test_masks` from the config by default):
```bash
python3 infer.py \
--config configs/default.yaml \
--benchmark \
--ckpt ckpt_5000.pt \
--bench_warmup 2 \
--bench_limit 0 \
--bench_report_json outputs/bench_report.json \
--metrics
```
If your ground truth directory is different from `data.test_masks`, please override it with `--bench_masks_dir`:
```bash
python3 infer.py \
--config configs/default.yaml \
--benchmark \
--ckpt ckpt_5000.pt \
--bench_warmup 2 \
--bench_limit 0 \
--bench_report_json outputs/bench_report.json \
--metrics \
--bench_masks_dir /path/to/gts
```
You will see output like:
```
[WireSegHR][bench] Results (ms):
Coarse avg=50.16 p50=44.48 p95=76.78
Fine avg=534.38 p50=419.52 p95=1187.66
Total avg=584.54 p50=464.73 p95=1300.07
Target < 1000 ms per 3000x4000 image: YES
[WireSegHR][bench][Fine] IoU=0.6098 F1=0.7576 P=0.6418 R=0.9244
[WireSegHR][bench][Coarse] IoU=0.5315 F1=0.6941 P=0.5467 R=0.9502
```
**These metrics were obtained after 5000 iterations*
Optional: you can save a JSON timing report with `--bench_report_json`. Schema:
- `summary`
- `avg_ms`, `p50_ms`, `p95_ms`
- `avg_coarse_ms`, `avg_fine_ms`
- `images`
- `per_image`: list of objects with
- `path`, `H`, `W`, `t_coarse_ms`, `t_fine_ms`, `t_total_ms`
### Utils:
- Export your model to inference-only weights by scripts/strip_checkpoint.py
### Example:
Input:
<img src="images/026_input.png" alt="Input" width="50%"/>
Fine prediction:
<img src="images/026_fine_pred.png" alt="Fine prediction" width="50%"/>