"""TwinLiteNet8 inference — single image or directory. Same interface as Segformer's predict.py for easy swap. Trained at 640x360; this script auto-resizes any input down to 640x360 for inference, then upsamples the prediction back to original resolution. Usage: python predict.py input.jpg --weights run_8class/twinlite8_best.pt python predict.py --dir frames/ --out out/ --weights run_8class/twinlite8_best.pt """ from __future__ import annotations import argparse, sys, os from pathlib import Path import cv2 import numpy as np import torch import torch.nn.functional as F sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from model.TwinLite_8class import TwinLiteNet8 NAMES = ["tree", "ground", "person", "sky", "road", "mountain", "building", "background"] PALETTE = np.array([ [60, 220, 60], # tree [40, 100, 160], # ground [40, 40, 230], # person [230, 200, 60], # sky [140, 140, 140], # road [180, 60, 180], # mountain [50, 220, 220], # building [100, 100, 100], # background ], dtype=np.uint8) TRAIN_W, TRAIN_H = 640, 360 def load_model(weights, device="cuda"): model = TwinLiteNet8(num_classes=8).to(device).eval() ckpt = torch.load(weights, map_location=device, weights_only=False) model.load_state_dict(ckpt["model"] if "model" in ckpt else ckpt) return model def predict(model, bgr_img, device="cuda"): """BGR uint8 → (H,W) class id mask 0..7 at original resolution.""" H, W = bgr_img.shape[:2] inp_bgr = cv2.resize(bgr_img, (TRAIN_W, TRAIN_H)) rgb = cv2.cvtColor(inp_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 x = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0).float().to(device) with torch.no_grad(): logits = model(x) # Upsample logits to original resolution before argmax (cleaner boundaries) logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False) # v2: channel 7 (background) was never trained -> mask it out so it can't win argmax logits[:, 7, :, :] = -1e9 return logits.argmax(1)[0].cpu().numpy().astype(np.uint8) def colorize(mask): return PALETTE[mask] def overlay(bgr, mask, alpha=0.45): return cv2.addWeighted(bgr, 1 - alpha, colorize(mask), alpha, 0) def main(): ap = argparse.ArgumentParser() ap.add_argument("input", nargs="?") ap.add_argument("--dir") ap.add_argument("--out", default=".") ap.add_argument("--weights", default="run_8class/twinlite8_best.pt") ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") args = ap.parse_args() if not args.input and not args.dir: ap.print_help(); return print(f"loading model from {args.weights} on {args.device} ...") model = load_model(args.weights, device=args.device) out_dir = Path(args.out); out_dir.mkdir(parents=True, exist_ok=True) paths = [] if args.dir: paths = sorted(p for p in Path(args.dir).iterdir() if p.suffix.lower() in {".jpg",".jpeg",".png",".bmp"}) if args.input: paths.append(Path(args.input)) for p in paths: img = cv2.imread(str(p)) if img is None: print(f" skip: {p}"); continue mask = predict(model, img, device=args.device) cv2.imwrite(str(out_dir / f"{p.stem}_pred.png"), mask) cv2.imwrite(str(out_dir / f"{p.stem}_overlay.jpg"), overlay(img, mask)) counts = np.bincount(mask.flatten(), minlength=8) top = counts.argmax() print(f" {p.name:<50} top: {NAMES[top]} ({100*counts[top]/counts.sum():.1f}%)") print(f"\noutputs -> {out_dir.resolve()}") if __name__ == "__main__": main()