File size: 5,138 Bytes
32da9de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fd5923
32da9de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Segformer85Mv1 β€” apple-orchard semantic segmentation inference.

Usage:
    python predict.py input.jpg                     # writes input_pred.png + input_overlay.jpg
    python predict.py --dir frames/ --out out/      # batch process a folder

Classes (id → name):
    0 tree   1 ground   2 person   3 sky
    4 road   5 mountain 6 building 7 background
"""
from __future__ import annotations
import argparse
import os
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation

# ─── config ───
# Hub id of the base SegFormer whose architecture is instantiated before the
# fine-tuned checkpoint is loaded on top of it.
BASE_MODEL = "nvidia/segformer-b5-finetuned-ade-640-640"

# Checkpoint file. Defaults to v2 (best generalization); set the
# SEGFORMER85M_WEIGHTS environment variable or pass --weights to use v1.
WEIGHTS_PATH = os.getenv("SEGFORMER85M_WEIGHTS", "Segformer85Mv2.pt")

# One (class name, BGR color) row per class; the row index is the class id.
_CLASS_TABLE = (
    ("tree",       (60, 220, 60)),    # green
    ("ground",     (40, 100, 160)),   # brown
    ("person",     (40, 40, 230)),    # red
    ("sky",        (230, 200, 60)),   # cyan
    ("road",       (140, 140, 140)),  # gray
    ("mountain",   (180, 60, 180)),   # purple
    ("building",   (50, 220, 220)),   # yellow
    ("background", (100, 100, 100)),  # mid-gray
)
NAMES = [class_name for class_name, _ in _CLASS_TABLE]
PALETTE = np.array([bgr for _, bgr in _CLASS_TABLE], dtype=np.uint8)

# Per-channel statistics applied to RGB values scaled to [0, 1] in preprocess().
IMAGENET_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32)


def load_model(weights_path: str | Path = WEIGHTS_PATH, device: str = "cuda"):
    """Build the 8-class SegFormer and load the fine-tuned checkpoint into it.

    Args:
        weights_path: Checkpoint file. Either a bare state dict or a dict that
            stores the state dict under the ``"model"`` key.
        device: Torch device string the model and checkpoint are placed on.

    Returns:
        The model in eval mode on ``device``.

    Raises:
        FileNotFoundError: If ``weights_path`` is not an existing file.
    """
    weights_path = Path(weights_path)
    # Fail fast: verify the checkpoint exists before building the base model,
    # which is the expensive step (and may need to fetch BASE_MODEL).
    if not weights_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {weights_path}")

    model = SegformerForSemanticSegmentation.from_pretrained(
        BASE_MODEL,
        num_labels=len(NAMES),
        id2label=dict(enumerate(NAMES)),
        label2id={n: i for i, n in enumerate(NAMES)},
        # The head is resized to len(NAMES) labels, so shape mismatches
        # against the base checkpoint must be tolerated here.
        ignore_mismatched_sizes=True,
    ).to(device)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(weights_path, map_location=device, weights_only=False)
    state = ckpt["model"] if "model" in ckpt else ckpt
    model.load_state_dict(state)
    model.eval()
    return model


def preprocess(bgr_img: np.ndarray) -> tuple[torch.Tensor, tuple[int, int]]:
    """BGR uint8 image → normalized batch tensor with sides snapped to multiples of 32.

    The image is resized (not cropped) so each side becomes the largest
    multiple of 32 that does not exceed it, converted to RGB, scaled to
    [0, 1], normalized with IMAGENET_MEAN / IMAGENET_STD and laid out as
    (1, C, H32, W32).

    Returns:
        (tensor, (H, W)) where (H, W) is the size of the *original* image.

    Raises:
        ValueError: If either side is smaller than 32 pixels.
    """
    orig_h, orig_w = bgr_img.shape[:2]
    # Round each side down to the nearest multiple of 32.
    net_h = orig_h - orig_h % 32
    net_w = orig_w - orig_w % 32
    if not (net_h and net_w):
        raise ValueError(f"Image too small: {orig_w}x{orig_h}")

    resized = cv2.resize(bgr_img, (net_w, net_h))  # cv2 expects (width, height)
    rgb01 = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    normalized = (rgb01 - IMAGENET_MEAN) / IMAGENET_STD
    chw = normalized.transpose(2, 0, 1)  # HWC -> CHW
    batch = torch.from_numpy(chw).unsqueeze(0).float()
    return batch, (orig_h, orig_w)


def predict(model, bgr_img: np.ndarray, device: str = "cuda") -> np.ndarray:
    """Segment a single BGR image.

    The logits are upsampled bilinearly to the original image size before the
    per-pixel argmax, so the returned mask matches the input resolution.

    Returns:
        (H, W) uint8 array of class ids.
    """
    batch, full_size = preprocess(bgr_img)
    with torch.no_grad():
        raw_logits = model(pixel_values=batch.to(device)).logits
        upsampled = F.interpolate(
            raw_logits, size=full_size, mode="bilinear", align_corners=False
        )
    class_ids = torch.argmax(upsampled, dim=1)[0]  # drop the batch dimension
    return class_ids.cpu().numpy().astype(np.uint8)


def colorize(mask: np.ndarray) -> np.ndarray:
    """Class-id mask (H,W) → BGR color visualization (H,W,3)."""
    # Indexing PALETTE with the mask picks one palette row per pixel.
    bgr_vis = PALETTE[mask]
    return bgr_vis


def overlay(bgr_img: np.ndarray, mask: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """Alpha-blend the colorized mask onto the original image.

    ``alpha`` is the weight of the color layer; the image gets ``1 - alpha``.
    """
    color_layer = colorize(mask)
    image_weight = 1 - alpha
    return cv2.addWeighted(bgr_img, image_weight, color_layer, alpha, 0)


def _collect_image_paths(directory: str | None, single: str | None) -> list[Path]:
    """Return the inputs to process: sorted images from ``directory``, then ``single``."""
    image_exts = {".jpg", ".jpeg", ".png", ".bmp"}
    paths: list[Path] = []
    if directory:
        paths = sorted(p for p in Path(directory).iterdir() if p.suffix.lower() in image_exts)
    if single:
        paths.append(Path(single))
    return paths


def main():
    """Command-line entry point: segment one image and/or every image in a folder.

    For each readable input ``<stem>``, writes ``<stem>_pred.png`` (raw
    class-id mask) and ``<stem>_overlay.jpg`` (visualization) into ``--out``
    and prints the dominant class with its pixel share.
    """
    ap = argparse.ArgumentParser(description="Segformer85Mv1 inference (8-class outdoor segmentation).")
    ap.add_argument("input", nargs="?", help="Single image path")
    ap.add_argument("--dir", help="Directory of images to process")
    ap.add_argument("--out", default=".", help="Output directory")
    ap.add_argument("--weights", default=WEIGHTS_PATH, help="Path to the checkpoint .pt file (default: %(default)s)")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    if not args.input and not args.dir:
        ap.print_help()
        return
    if args.dir and not Path(args.dir).is_dir():
        ap.error(f"--dir is not a directory: {args.dir}")

    # Resolve the inputs before the (slow) model load so that a bad or empty
    # folder is reported immediately.
    paths = _collect_image_paths(args.dir, args.input)
    if not paths:
        print(f"no images found in {args.dir}")
        return

    print(f"loading model from {args.weights} on {args.device} ...")
    model = load_model(args.weights, device=args.device)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    for p in paths:
        img = cv2.imread(str(p))
        if img is None:
            print(f"  skip (unreadable): {p}")
            continue
        mask = predict(model, img, device=args.device)
        # cv2.imwrite reports failure through its return value, not an exception.
        wrote_mask = cv2.imwrite(str(out_dir / f"{p.stem}_pred.png"), mask)  # raw class-id mask
        wrote_overlay = cv2.imwrite(str(out_dir / f"{p.stem}_overlay.jpg"), overlay(img, mask))  # visualization
        if not (wrote_mask and wrote_overlay):
            print(f"  warning: could not write outputs for {p}")
        # quick stats
        counts = np.bincount(mask.ravel(), minlength=len(NAMES))
        top = counts.argmax()
        print(f"  {p.name:<40}  top class: {NAMES[top]} ({100*counts[top]/counts.sum():.1f}%)")

    print(f"\noutputs -> {out_dir.resolve()}")


if __name__ == "__main__":
    main()