Segformer85Mv1 / predict.py
WEN0256's picture
Add Segformer85Mv2 (fine-tuned on Orchard Navigation, autumn+different camera). v1 unchanged.
2fd5923 verified
"""Segformer85Mv1 β€” apple-orchard semantic segmentation inference.
Usage:
python predict.py input.jpg # writes input_pred.png + input_overlay.jpg
python predict.py --dir frames/ --out out/ # batch process a folder
Classes (id β†’ name):
0 tree 1 ground 2 person 3 sky
4 road 5 mountain 6 building 7 background
"""
from __future__ import annotations
import argparse
import os
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation
# ─── config ───
BASE_MODEL = "nvidia/segformer-b5-finetuned-ade-640-640"
WEIGHTS_PATH = os.environ.get("SEGFORMER85M_WEIGHTS", "Segformer85Mv2.pt") # default v2 (best generalization); set env var or --weights to use v1
NAMES = ["tree", "ground", "person", "sky", "road", "mountain", "building", "background"]
PALETTE = np.array([
[60, 220, 60], # tree - green
[40, 100, 160], # ground - brown
[40, 40, 230], # person - red
[230, 200, 60], # sky - cyan
[140, 140, 140], # road - gray
[180, 60, 180], # mountain - purple
[50, 220, 220], # building - yellow
[100, 100, 100], # background - mid-gray
], dtype=np.uint8)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
def load_model(weights_path: str | Path = WEIGHTS_PATH, device: str = "cuda"):
    """Build the 8-class SegFormer-B5 and load fine-tuned weights into it.

    Works with either fine-tuned checkpoint (e.g. Segformer85Mv1.pt, or the
    default Segformer85Mv2.pt). The checkpoint may be a bare ``state_dict`` or a
    dict that stores it under the ``"model"`` key.

    Args:
        weights_path: Path to the ``.pt`` checkpoint.
        device: Torch device string the model and weights are placed on.

    Returns:
        The ``SegformerForSemanticSegmentation`` model in eval mode on ``device``.

    Raises:
        FileNotFoundError: if ``weights_path`` does not exist. This is checked
            before the base model is loaded (which may involve a Hub download),
            so a mistyped path fails immediately instead of after that work.
    """
    if isinstance(weights_path, (str, os.PathLike)) and not Path(weights_path).exists():
        raise FileNotFoundError(f"weights file not found: {weights_path}")
    # Start from the pretrained backbone, swapping its classification head for one
    # sized to our classes (ignore_mismatched_sizes tolerates the shape change).
    model = SegformerForSemanticSegmentation.from_pretrained(
        BASE_MODEL,
        num_labels=len(NAMES),
        id2label={i: n for i, n in enumerate(NAMES)},
        label2id={n: i for i, n in enumerate(NAMES)},
        ignore_mismatched_sizes=True,
    ).to(device)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects,
    # which can execute code. Only load checkpoints from trusted sources. Left as-is
    # because the training checkpoint may hold non-tensor entries that weights_only=True
    # would reject -- TODO(review): switch to weights_only=True if the files allow it.
    ckpt = torch.load(weights_path, map_location=device, weights_only=False)
    # Training checkpoints wrap the weights under "model"; otherwise ckpt is the state_dict.
    state = ckpt["model"] if "model" in ckpt else ckpt
    model.load_state_dict(state)
    model.eval()
    return model
def preprocess(bgr_img: np.ndarray) -> tuple[torch.Tensor, tuple[int, int]]:
    """Convert a BGR uint8 image into a normalized model input batch.

    Each side is snapped down to a multiple of 32 before normalization.

    Args:
        bgr_img: (H, W, C) image in OpenCV BGR channel order.

    Returns:
        ``(tensor, (H, W))`` where ``tensor`` is a (1, 3, H32, W32) float32 tensor
        of ImageNet-normalized RGB values and ``(H, W)`` is the original size.

    Raises:
        ValueError: if either side is shorter than 32 pixels.
    """
    orig_h, orig_w = bgr_img.shape[:2]
    new_h = orig_h - orig_h % 32
    new_w = orig_w - orig_w % 32
    if not (new_h and new_w):
        raise ValueError(f"Image too small: {orig_w}x{orig_h}")
    resized = cv2.resize(bgr_img, (new_w, new_h))  # cv2 dsize is (width, height)
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    normalized = (rgb.astype(np.float32) / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    chw = normalized.transpose(2, 0, 1)  # HWC -> CHW
    return torch.from_numpy(chw)[None].float(), (orig_h, orig_w)
def predict(model, bgr_img: np.ndarray, device: str = "cuda") -> np.ndarray:
    """Segment a single BGR image.

    Args:
        model: Segmentation model whose output exposes ``.logits`` (e.g. from load_model).
        bgr_img: Image in OpenCV BGR channel order.
        device: Device the input batch is moved to; must match the model's device.

    Returns:
        (H, W) uint8 array of per-pixel class ids (0..7 for a model from load_model),
        at the input image's original resolution.
    """
    pixel_values, out_size = preprocess(bgr_img)
    with torch.no_grad():
        raw = model(pixel_values=pixel_values.to(device)).logits
        # Upsample class scores to the original size, then pick the best class per pixel.
        full = F.interpolate(raw, size=out_size, mode="bilinear", align_corners=False)
        labels = full.argmax(dim=1)[0]
    return labels.cpu().numpy().astype(np.uint8)
def colorize(mask: np.ndarray) -> np.ndarray:
    """Render a class-id mask as a BGR color image.

    Args:
        mask: (H, W) integer array; each value indexes a row of ``PALETTE``.

    Returns:
        (H, W, 3) uint8 image in BGR channel order.
    """
    lut = PALETTE  # row i holds the BGR color of class id i
    return lut[mask]
def overlay(bgr_img: np.ndarray, mask: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """Blend the colorized prediction on top of the source image.

    Args:
        bgr_img: Original image; must match ``colorize(mask)`` in size and dtype.
        mask: (H, W) class-id mask for ``bgr_img``.
        alpha: Weight of the color layer; the image gets ``1 - alpha``.

    Returns:
        ``(1 - alpha) * bgr_img + alpha * colors``, computed by ``cv2.addWeighted``.
    """
    colors = colorize(mask)
    image_weight = 1 - alpha
    return cv2.addWeighted(bgr_img, image_weight, colors, alpha, 0)
# File extensions picked up by --dir (matched case-insensitively).
_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".bmp"})


def _collect_paths(ap: argparse.ArgumentParser, args: argparse.Namespace) -> list[Path]:
    """Return images to process: sorted --dir files first, then the positional input."""
    paths: list[Path] = []
    if args.dir:
        src_dir = Path(args.dir)
        if not src_dir.is_dir():
            ap.error(f"--dir is not a directory: {src_dir}")  # exits with status 2
        paths = sorted(
            p for p in src_dir.iterdir()
            if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
        )
    if args.input:
        paths.append(Path(args.input))
    return paths


def _write_image(path: Path, image: np.ndarray) -> bool:
    """Write ``image`` via cv2.imwrite, printing a warning if OpenCV reports failure."""
    ok = bool(cv2.imwrite(str(path), image))
    if not ok:
        print(f" warning: failed to write {path}")
    return ok


def main():
    """CLI entry point: segment one image and/or every image in a folder.

    For each image it writes ``<stem>_pred.png`` (raw class-id mask) and
    ``<stem>_overlay.jpg`` (color overlay) into --out, then prints the dominant class.
    Unreadable or too-small images are reported and skipped; the batch continues.
    """
    ap = argparse.ArgumentParser(description="Segformer85M inference (8-class outdoor segmentation).")
    ap.add_argument("input", nargs="?", help="Single image path")
    ap.add_argument("--dir", help="Directory of images to process")
    ap.add_argument("--out", default=".", help="Output directory")
    ap.add_argument(
        "--weights",
        default=WEIGHTS_PATH,
        help="Path to the fine-tuned checkpoint, e.g. Segformer85Mv1.pt or Segformer85Mv2.pt "
             "(default: %(default)s; can also be set via SEGFORMER85M_WEIGHTS)",
    )
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()
    if not args.input and not args.dir:
        ap.print_help()
        return

    # Resolve inputs before the (slow) model load so bad arguments fail fast.
    paths = _collect_paths(ap, args)
    if not paths:
        print("no images found to process")
        return

    print(f"loading model from {args.weights} on {args.device} ...")
    model = load_model(args.weights, device=args.device)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    seen_stems: set[str] = set()
    for p in paths:
        img = cv2.imread(str(p))
        if img is None:
            print(f" skip (unreadable): {p}")
            continue
        try:
            mask = predict(model, img, device=args.device)
        except ValueError as err:  # preprocess rejects images under 32 px on a side
            print(f" skip ({err}): {p}")
            continue
        # Outputs are named by stem only, so e.g. a.jpg and a.png would collide.
        if p.stem in seen_stems:
            print(f" warning: {p.name} overwrites earlier outputs for stem '{p.stem}'")
        seen_stems.add(p.stem)
        _write_image(out_dir / f"{p.stem}_pred.png", mask)  # raw class-id mask
        _write_image(out_dir / f"{p.stem}_overlay.jpg", overlay(img, mask))  # visualization
        # quick stats: the class covering the most pixels
        counts = np.bincount(mask.ravel(), minlength=len(NAMES))
        top = int(counts.argmax())
        print(f" {p.name:<40} top class: {NAMES[top]} ({100*counts[top]/counts.sum():.1f}%)")
    print(f"\noutputs -> {out_dir.resolve()}")


if __name__ == "__main__":
    main()