"""Segformer85M — apple-orchard semantic segmentation inference.

Usage:
    python predict.py input.jpg                 # writes input_pred.png + input_overlay.jpg
    python predict.py --dir frames/ --out out/  # batch process a folder

Outputs are written to --out (default: the current directory):
    <stem>_pred.png     raw single-channel class-id mask (values 0..7)
    <stem>_overlay.jpg  color visualization blended over the input image

Weights default to Segformer85Mv2.pt; override with the SEGFORMER85M_WEIGHTS
environment variable or the --weights flag (e.g. to use Segformer85Mv1.pt).

Classes (id → name):
    0 tree  1 ground  2 person  3 sky  4 road  5 mountain  6 building  7 background
"""
from __future__ import annotations

import argparse
import os
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation

# ─── config ───
BASE_MODEL = "nvidia/segformer-b5-finetuned-ade-640-640"
# default v2 (best generalization); set env var or --weights to use v1
WEIGHTS_PATH = os.environ.get("SEGFORMER85M_WEIGHTS", "Segformer85Mv2.pt")
NAMES = ["tree", "ground", "person", "sky", "road", "mountain", "building", "background"]
# One row per class id, in OpenCV's BGR channel order (not RGB).
PALETTE = np.array([
    [60, 220, 60],    # tree       - green
    [40, 100, 160],   # ground     - brown
    [40, 40, 230],    # person     - red
    [230, 200, 60],   # sky        - cyan
    [140, 140, 140],  # road       - gray
    [180, 60, 180],   # mountain   - purple
    [50, 220, 220],   # building   - yellow
    [100, 100, 100],  # background - mid-gray
], dtype=np.uint8)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# File extensions picked up by --dir (compared case-insensitively).
_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def load_model(weights_path: str | Path = WEIGHTS_PATH, device: str = "cuda"):
    """Build the Segformer architecture with a len(NAMES)-class head and load fine-tuned weights.

    Args:
        weights_path: checkpoint file; either a raw state_dict or a dict that
            holds the state_dict under the "model" key.
        device: torch device the model and checkpoint tensors are placed on.

    Returns:
        The model in eval mode on ``device``.
    """
    model = SegformerForSemanticSegmentation.from_pretrained(
        BASE_MODEL,
        num_labels=len(NAMES),
        id2label={i: n for i, n in enumerate(NAMES)},
        label2id={n: i for i, n in enumerate(NAMES)},
        # The base checkpoint's head has a different label count; let HF
        # re-initialize it before the fine-tuned weights overwrite everything.
        ignore_mismatched_sizes=True,
    ).to(device)
    # SECURITY: weights_only=False unpickles arbitrary Python objects — only
    # load checkpoint files from a trusted source.
    ckpt = torch.load(weights_path, map_location=device, weights_only=False)
    state = ckpt["model"] if "model" in ckpt else ckpt
    model.load_state_dict(state)  # strict: every key must match
    model.eval()
    return model


def preprocess(bgr_img: np.ndarray) -> tuple[torch.Tensor, tuple[int, int]]:
    """Convert a BGR uint8 image into a normalized (1, 3, H32, W32) float tensor.

    The image is resized down to the nearest multiple of 32 on each side.

    Returns:
        (tensor, (H, W)) where (H, W) is the original image size.

    Raises:
        ValueError: if either side is smaller than 32 pixels.
    """
    H, W = bgr_img.shape[:2]
    H32, W32 = (H // 32) * 32, (W // 32) * 32
    if H32 == 0 or W32 == 0:
        raise ValueError(f"Image too small: {W}x{H}")
    rgb = cv2.cvtColor(cv2.resize(bgr_img, (W32, H32)), cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    rgb = (rgb - IMAGENET_MEAN) / IMAGENET_STD
    x = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0).float()
    return x, (H, W)


def predict(model, bgr_img: np.ndarray, device: str | torch.device | None = None) -> np.ndarray:
    """Segment one BGR image.

    Args:
        model: a model returned by ``load_model``.
        bgr_img: BGR uint8 image of shape (H, W, 3).
        device: where to run inference; defaults to the device holding the
            model's parameters, so CPU and multi-GPU models work without it.

    Returns:
        (H, W) uint8 mask of class ids 0..len(NAMES)-1 at the input resolution.
    """
    if device is None:
        device = next(model.parameters()).device
    x, (H, W) = preprocess(bgr_img)
    x = x.to(device)
    with torch.inference_mode():
        logits = model(pixel_values=x).logits
        # Upsample logits (not the argmax) so class boundaries are interpolated smoothly.
        logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
        return logits.argmax(1)[0].cpu().numpy().astype(np.uint8)


def colorize(mask: np.ndarray) -> np.ndarray:
    """Map a (H, W) class-id mask to a (H, W, 3) BGR uint8 image using PALETTE."""
    return PALETTE[mask]


def overlay(bgr_img: np.ndarray, mask: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """Blend the colorized mask over ``bgr_img``; ``alpha`` is the mask's weight."""
    return cv2.addWeighted(bgr_img, 1 - alpha, colorize(mask), alpha, 0)


def _collect_image_paths(dir_arg: str | None, input_arg: str | None) -> list[Path]:
    """Return sorted image files from ``dir_arg``, then ``input_arg`` unless already listed."""
    paths: list[Path] = []
    if dir_arg:
        paths = sorted(
            p for p in Path(dir_arg).iterdir()
            if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
        )
    if input_arg:
        single = Path(input_arg)
        if single not in paths:  # avoid processing (and overwriting) the same file twice
            paths.append(single)
    return paths


def main():
    """CLI entry point: segment one image and/or a folder, writing mask and overlay files."""
    ap = argparse.ArgumentParser(description="Segformer85M inference (8-class outdoor segmentation).")
    ap.add_argument("input", nargs="?", help="Single image path")
    ap.add_argument("--dir", help="Directory of images to process")
    ap.add_argument("--out", default=".", help="Output directory")
    ap.add_argument(
        "--weights",
        default=WEIGHTS_PATH,
        help="Path to model weights, e.g. Segformer85Mv1.pt or Segformer85Mv2.pt (default: %(default)s)",
    )
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    if not args.input and not args.dir:
        ap.print_help()
        return
    if args.dir and not Path(args.dir).is_dir():
        ap.error(f"--dir is not a directory: {args.dir}")

    # Gather inputs first so an empty folder doesn't pay for loading the model.
    paths = _collect_image_paths(args.dir, args.input)
    if not paths:
        print(f"no images found (supported: {', '.join(sorted(_IMAGE_SUFFIXES))})")
        return

    print(f"loading model from {args.weights} on {args.device} ...")
    model = load_model(args.weights, device=args.device)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    written_stems: set[str] = set()
    for p in paths:
        img = cv2.imread(str(p))
        if img is None:
            print(f"  skip (unreadable): {p}")
            continue
        h, w = img.shape[:2]
        if h < 32 or w < 32:
            # preprocess() would raise; skip instead of aborting the whole batch.
            print(f"  skip (smaller than 32x32, got {w}x{h}): {p}")
            continue

        mask = predict(model, img, device=args.device)

        # Outputs are named by stem only, so e.g. a.jpg and a.png collide.
        if p.stem in written_stems:
            print(f"  warning: {p.name} overwrites earlier outputs for stem '{p.stem}'")
        written_stems.add(p.stem)

        ok_pred = cv2.imwrite(str(out_dir / f"{p.stem}_pred.png"), mask)  # raw class-id mask
        ok_overlay = cv2.imwrite(str(out_dir / f"{p.stem}_overlay.jpg"), overlay(img, mask))  # visualization
        if not (ok_pred and ok_overlay):
            print(f"  warning: failed to write outputs for {p}")

        # quick stats
        counts = np.bincount(mask.ravel(), minlength=len(NAMES))
        top = counts.argmax()
        print(f"  {p.name:<40} top class: {NAMES[top]} ({100 * counts[top] / counts.sum():.1f}%)")

    print(f"\noutputs -> {out_dir.resolve()}")


if __name__ == "__main__":
    main()