TwinLiteNet8 / predict.py
WEN0256's picture
Initial release: TwinLiteNet8 (0.44M params, 7-class orchard semantic seg, edge-deployment ready)
f5cc6c0 verified
"""TwinLiteNet8 inference — single image or directory.
Same interface as Segformer's predict.py for easy swap.
Trained at 640x360; this script resizes every input (larger or smaller) to
640x360 for inference, then bilinearly resizes the logits back to the original
resolution before taking the per-pixel argmax.
Usage:
python predict.py input.jpg --weights run_8class/twinlite8_best.pt
python predict.py --dir frames/ --out out/ --weights run_8class/twinlite8_best.pt
"""
from __future__ import annotations
import argparse, sys, os
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.TwinLite_8class import TwinLiteNet8
# One row per output channel, in channel order: (label, overlay colour).
# The colours are blended directly with BGR images in overlay() and written
# with cv2.imwrite, so each triple is effectively in B, G, R order.
# Channel 7 ("background") is masked out in predict() and never appears in
# a returned mask; its row only keeps the table aligned with the 8 channels.
_CLASS_TABLE = (
    ("tree", (60, 220, 60)),
    ("ground", (40, 100, 160)),
    ("person", (40, 40, 230)),
    ("sky", (230, 200, 60)),
    ("road", (140, 140, 140)),
    ("mountain", (180, 60, 180)),
    ("building", (50, 220, 220)),
    ("background", (100, 100, 100)),
)

# Human-readable label for each class id.
NAMES = [label for label, _ in _CLASS_TABLE]

# (8, 3) uint8 lookup table: PALETTE[class_id] -> colour triple.
PALETTE = np.array([colour for _, colour in _CLASS_TABLE], dtype=np.uint8)

# Network input size (width, height) that every image is resized to in predict().
TRAIN_W, TRAIN_H = 640, 360
def load_model(weights, device="cuda"):
    """Build an 8-class TwinLiteNet8 in eval mode and load its weights.

    Args:
        weights: Path of the checkpoint file handed to ``torch.load``.
        device: Device the network is moved to and the checkpoint is mapped onto.

    Returns:
        The network with the checkpoint's parameters loaded.

    The checkpoint may be either a bare state dict or a dict that stores the
    state dict under the key ``"model"``.
    """
    net = TwinLiteNet8(num_classes=8)
    net = net.to(device)
    net = net.eval()

    # NOTE(review): weights_only=False performs a full unpickle, which can run
    # arbitrary code embedded in the file -- only load checkpoints you trust.
    checkpoint = torch.load(weights, map_location=device, weights_only=False)

    if "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    net.load_state_dict(state_dict)
    return net
def predict(model, bgr_img, device="cuda"):
    """Segment one BGR uint8 image and return a class-id mask at its own size.

    Args:
        model: Network called on a (1, 3, TRAIN_H, TRAIN_W) float tensor;
            presumably the object returned by ``load_model``.
        bgr_img: Image array whose first two axes are height and width and
            which ``cv2.cvtColor`` can convert from BGR.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        (H, W) uint8 array of class ids. Channel 7 is suppressed before the
        argmax, so the ids that actually occur are 0..6.
    """
    orig_h, orig_w = bgr_img.shape[:2]

    # Bring the frame to the network's input size, then BGR -> RGB in [0, 1].
    resized = cv2.resize(bgr_img, (TRAIN_W, TRAIN_H))
    rgb01 = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # HWC -> CHW, add a batch axis, move to the target device.
    chw = rgb01.transpose(2, 0, 1)
    batch = torch.from_numpy(chw).unsqueeze(0).float().to(device)

    with torch.no_grad():
        scores = model(batch)
        # Resize the logits (not the labels) to the original resolution so the
        # argmax is taken on smoothly interpolated scores: cleaner boundaries.
        scores = F.interpolate(
            scores, size=(orig_h, orig_w), mode="bilinear", align_corners=False
        )
        # v2: channel 7 (background) was never trained -> push it far below
        # every other score so it can never win the argmax.
        scores[:, 7, :, :] = -1e9
        class_ids = scores.argmax(1)[0]

    return class_ids.cpu().numpy().astype(np.uint8)
def colorize(mask):
    """Turn a class-id mask into a colour image via a PALETTE lookup.

    Each id in ``mask`` is replaced by the matching row of ``PALETTE``, so an
    integer (H, W) mask yields an (H, W, 3) uint8 image.
    """
    colour_img = PALETTE[mask]
    return colour_img
def overlay(bgr, mask, alpha=0.45):
    """Blend the colourised mask over the image.

    Args:
        bgr: Image to draw on; must have the same shape and dtype as the
            colourised mask for ``cv2.addWeighted`` to accept it.
        mask: Class-id mask passed to ``colorize``.
        alpha: Weight of the colour layer; the image gets ``1 - alpha``.

    Returns:
        The weighted sum ``bgr * (1 - alpha) + colorize(mask) * alpha``.
    """
    tint = colorize(mask)
    image_weight = 1 - alpha
    return cv2.addWeighted(bgr, image_weight, tint, alpha, 0)
def main():
    """Command-line entry point: segment a single image and/or a directory.

    Inputs come from the optional positional ``input`` path and/or ``--dir``
    (every ``.jpg``/``.jpeg``/``.png``/``.bmp`` file directly inside it, sorted).
    For each readable image ``<stem>.<ext>`` two files are written to ``--out``:

    * ``<stem>_pred.png``    -- the uint8 class-id mask
    * ``<stem>_overlay.jpg`` -- the colour overlay on the original image

    and one summary line naming the most frequent class is printed. With
    neither ``input`` nor ``--dir`` the help text is printed and nothing runs.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input", nargs="?")
    ap.add_argument("--dir")
    ap.add_argument("--out", default=".")
    ap.add_argument("--weights", default="run_8class/twinlite8_best.pt")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    if not args.input and not args.dir:
        ap.print_help()
        return

    print(f"loading model from {args.weights} on {args.device} ...")
    model = load_model(args.weights, device=args.device)

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp"}
    paths = []
    if args.dir:
        paths = sorted(
            p for p in Path(args.dir).iterdir() if p.suffix.lower() in image_suffixes
        )
    if args.input:
        paths.append(Path(args.input))

    for p in paths:
        img = cv2.imread(str(p))
        if img is None:
            # cv2.imread returns None for anything it cannot decode.
            print(f" skip: {p}")
            continue

        mask = predict(model, img, device=args.device)

        outputs = (
            (out_dir / f"{p.stem}_pred.png", mask),
            (out_dir / f"{p.stem}_overlay.jpg", overlay(img, mask)),
        )
        for out_path, data in outputs:
            # cv2.imwrite reports failure through its return value rather than
            # by raising, so check it instead of silently losing the output.
            if not cv2.imwrite(str(out_path), data):
                print(f" write failed: {out_path}")

        # Per-class pixel counts; minlength keeps one slot per known class
        # even when the highest ids are absent from this mask.
        counts = np.bincount(mask.flatten(), minlength=len(NAMES))
        top = counts.argmax()
        print(f" {p.name:<50} top: {NAMES[top]} ({100*counts[top]/counts.sum():.1f}%)")

    print(f"\noutputs -> {out_dir.resolve()}")
# Run the command-line interface only when executed as a script, so importing
# this module (e.g. to reuse load_model / predict) has no side effects.
if __name__ == "__main__":
    main()