File size: 3,727 Bytes
f5cc6c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""TwinLiteNet8 inference — single image or directory.

Same interface as Segformer's predict.py for easy swap.
Trained at 640x360; this script resizes every input (larger or smaller) to
640x360 for inference, then resizes the predicted logits back to the original
resolution before taking the per-pixel argmax.

Usage:
    python predict.py input.jpg                  --weights run_8class/twinlite8_best.pt
    python predict.py --dir frames/ --out out/   --weights run_8class/twinlite8_best.pt
"""
from __future__ import annotations
import argparse, sys, os
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.TwinLite_8class import TwinLiteNet8

# Class names, indexed by the integer class id the model predicts (0..7).
NAMES = [
    "tree",
    "ground",
    "person",
    "sky",
    "road",
    "mountain",
    "building",
    "background",
]

# One color per class id, in the same order as NAMES. The triples are used
# directly on OpenCV images (see overlay), so they are in BGR channel order.
PALETTE = np.array(
    [
        (60, 220, 60),    # tree
        (40, 100, 160),   # ground
        (40, 40, 230),    # person
        (230, 200, 60),   # sky
        (140, 140, 140),  # road
        (180, 60, 180),   # mountain
        (50, 220, 220),   # building
        (100, 100, 100),  # background
    ],
    dtype=np.uint8,
)

# Network input size (width, height) used at training time; predict() resizes
# every image to this before the forward pass.
TRAIN_W, TRAIN_H = 640, 360


def load_model(weights, device="cuda"):
    """Build a TwinLiteNet8 (8 classes) and load trained weights into it.

    Args:
        weights: path to a file written by ``torch.save``. Accepted contents:
            a dict holding the state dict under ``"model"``, a bare state
            dict, or a pickled ``nn.Module`` whose state dict is used.
        device: torch device string; the model is moved there and the
            checkpoint tensors are mapped there.

    Returns:
        The model in eval mode on ``device``.
    """
    model = TwinLiteNet8(num_classes=8).to(device).eval()
    # SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
    # objects, so a malicious checkpoint can execute code. Only load files from
    # trusted sources. NOTE(review): consider weights_only=True if the
    # checkpoints contain only tensors/plain containers — confirm first.
    ckpt = torch.load(weights, map_location=device, weights_only=False)
    if isinstance(ckpt, torch.nn.Module):
        # A whole pickled model: `"model" in ckpt` would raise TypeError here.
        state = ckpt.state_dict()
    elif isinstance(ckpt, dict) and "model" in ckpt:
        # Training checkpoint wrapping the state dict under "model".
        state = ckpt["model"]
    else:
        # Bare state dict.
        state = ckpt
    model.load_state_dict(state)
    return model


def predict(model, bgr_img, device="cuda", *, ignore_classes=(7,)):
    """BGR uint8 → (H,W) class id mask 0..7 at original resolution.

    Preprocessing and postprocessing:

    - The image is resized to TRAIN_W x TRAIN_H, converted to RGB and scaled
      to [0, 1].
    - The model's logits are resized back to the input's (H, W) before the
      argmax, so class boundaries follow the original pixel grid.

    Args:
        model: callable (e.g. the module returned by load_model) taking a
            (1, 3, TRAIN_H, TRAIN_W) float32 tensor. It is assumed to return
            (1, C, h, w) logits; that layout is inferred from how the output
            is indexed below — TODO confirm against the model definition.
        bgr_img: uint8 image as read by cv2.imread: (H, W, 3) in BGR order.
            (H, W) or (H, W, 1) grayscale and (H, W, 4) BGRA are also
            accepted and converted to RGB.
        device: torch device the input tensor is moved to; must match the
            device the model lives on.
        ignore_classes: class ids that may never win the argmax. The default
            (7,) is the background channel, which the v2 model never trained.

    Returns:
        (H, W) uint8 numpy array of predicted class ids.
    """
    H, W = bgr_img.shape[:2]
    inp = cv2.resize(bgr_img, (TRAIN_W, TRAIN_H))
    # Pick the conversion from the channel count instead of assuming 3-channel
    # BGR. cv2.resize drops a trailing singleton channel, so (H, W, 1) input
    # arrives here 2-D.
    if inp.ndim == 2:
        code = cv2.COLOR_GRAY2RGB
    elif inp.shape[2] == 4:
        code = cv2.COLOR_BGRA2RGB
    else:
        code = cv2.COLOR_BGR2RGB
    rgb = cv2.cvtColor(inp, code).astype(np.float32) / 255.0
    # HWC -> 1CHW float32 batch of one.
    x = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)
        # Upsample logits to original resolution before argmax (cleaner boundaries).
        logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
        if ignore_classes:
            # Push ignored channels far below any real logit so they can't win the argmax.
            logits[:, list(ignore_classes), :, :] = -1e9
    return logits.argmax(1)[0].cpu().numpy().astype(np.uint8)


def colorize(mask):
    """Turn a class-id mask into a color image.

    Each pixel value is used as a row index into PALETTE. A (H, W) mask
    therefore becomes a (H, W, 3) uint8 image in PALETTE's channel order.
    """
    lookup_table = PALETTE
    colored = lookup_table[mask]
    return colored


def overlay(bgr, mask, alpha=0.45):
    """Blend the class colors of ``mask`` on top of the image ``bgr``.

    ``alpha`` is the weight of the color layer and the image gets
    ``1 - alpha``. The blend is ``cv2.addWeighted`` with gamma 0.
    cv2.addWeighted requires both inputs to have the same size and type, so
    ``bgr`` is presumably a (H, W, 3) uint8 image matching ``mask``'s H, W.
    """
    photo_weight = 1 - alpha
    color_layer = colorize(mask)
    return cv2.addWeighted(bgr, photo_weight, color_layer, alpha, 0)


# File extensions picked up when scanning --dir (compared case-insensitively).
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def _build_parser():
    """Return the command-line parser for this script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("input", nargs="?", help="single image to segment")
    ap.add_argument("--dir", help="directory of images to segment (non-recursive)")
    ap.add_argument("--out", default=".", help="output directory (created if missing)")
    ap.add_argument("--weights", default="run_8class/twinlite8_best.pt",
                    help="checkpoint path")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu",
                    help="torch device")
    return ap


def _collect_paths(args):
    """Return the images to process: supported files in --dir (sorted), then the positional input."""
    paths = []
    if args.dir:
        paths = sorted(p for p in Path(args.dir).iterdir() if p.suffix.lower() in _IMAGE_EXTS)
    if args.input:
        paths.append(Path(args.input))
    return paths


def _write_image(path, img):
    """Write ``img`` to ``path``; warn when cv2.imwrite reports failure by returning False."""
    if not cv2.imwrite(str(path), img):
        print(f"  warning: could not write {path}")


def main():
    """CLI entry point: segment each requested image and save its outputs.

    For every image it writes two files to --out:

    - ``<stem>_pred.png``: the raw class-id mask.
    - ``<stem>_overlay.jpg``: the color overlay.

    It also prints the most frequent class.
    """
    ap = _build_parser()
    args = ap.parse_args()

    if not args.input and not args.dir:
        ap.print_help()
        return

    print(f"loading model from {args.weights} on {args.device} ...")
    model = load_model(args.weights, device=args.device)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    for p in _collect_paths(args):
        img = cv2.imread(str(p))
        if img is None:
            print(f"  skip: {p}")
            continue
        mask = predict(model, img, device=args.device)
        _write_image(out_dir / f"{p.stem}_pred.png", mask)
        _write_image(out_dir / f"{p.stem}_overlay.jpg", overlay(img, mask))
        # Per-class pixel counts; minlength keeps one bin per known class.
        counts = np.bincount(mask.ravel(), minlength=len(NAMES))
        top = int(counts.argmax())
        print(f"  {p.name:<50}  top: {NAMES[top]} ({100 * counts[top] / counts.sum():.1f}%)")
    print(f"\noutputs -> {out_dir.resolve()}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when this module is imported.
    main()