| """TwinLiteNet8 inference — single image or directory. |
| |
| Same interface as Segformer's predict.py for easy swap. |
| Trained at 640x360; this script auto-resizes any input down to 640x360 for |
| inference, then upsamples the prediction back to original resolution. |
| |
| Usage: |
| python predict.py input.jpg --weights run_8class/twinlite8_best.pt |
| python predict.py --dir frames/ --out out/ --weights run_8class/twinlite8_best.pt |
| """ |
| from __future__ import annotations |
| import argparse, sys, os |
| from pathlib import Path |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from model.TwinLite_8class import TwinLiteNet8 |
|
|
# Class names indexed by class id (mask value i <-> NAMES[i]).
NAMES = [
    "tree", "ground", "person", "sky",
    "road", "mountain", "building", "background",
]

# One overlay colour per class id (row i <-> NAMES[i]). The rows are blended
# directly onto OpenCV images read with cv2.imread, so channel order is B, G, R.
PALETTE = np.array(
    (
        (60, 220, 60),     # tree
        (40, 100, 160),    # ground
        (40, 40, 230),     # person
        (230, 200, 60),    # sky
        (140, 140, 140),   # road
        (180, 60, 180),    # mountain
        (50, 220, 220),    # building
        (100, 100, 100),   # background
    ),
    dtype=np.uint8,
)

# Network input resolution (width, height) the model was trained at.
TRAIN_W = 640
TRAIN_H = 360
|
|
|
|
def load_model(weights, device="cuda"):
    """Build an 8-class TwinLiteNet8, load ``weights`` into it, return it in eval mode.

    Args:
        weights: checkpoint path (anything ``torch.load`` accepts).
        device: torch device string; the model and loaded tensors are mapped there.

    Supported checkpoint layouts:
        * a dict with a ``"model"`` entry holding the state dict,
        * a bare state dict,
        * a whole pickled ``nn.Module`` (its ``state_dict()`` is used).

    Raises:
        RuntimeError: from ``load_state_dict`` when keys or shapes do not match.
    """
    model = TwinLiteNet8(num_classes=8).to(device).eval()
    # SECURITY: weights_only=False runs the full pickle machinery, so a
    # malicious checkpoint file can execute arbitrary code. Only load
    # checkpoints from trusted sources.
    ckpt = torch.load(weights, map_location=device, weights_only=False)
    if isinstance(ckpt, torch.nn.Module):
        # Whole-model pickle (permitted by weights_only=False). A membership
        # test like `"model" in ckpt` would raise TypeError on a Module.
        state = ckpt.state_dict()
    elif isinstance(ckpt, dict) and "model" in ckpt:
        # Training-style checkpoint that wraps the weights under "model".
        state = ckpt["model"]
    else:
        state = ckpt
    model.load_state_dict(state)
    return model
|
|
|
|
def predict(model, bgr_img, device="cuda"):
    """Segment one image and return a per-pixel class-id mask at its original size.

    The image is resized to TRAIN_W x TRAIN_H and converted to RGB in [0, 1],
    then run through ``model``. The logits are bilinearly upsampled back to
    the input's height and width before the argmax.

    Args:
        model: callable that returns a 4-D logits tensor (1, C, h, w) with at
            least len(NAMES) channels. The 4-D rank is required by
            F.interpolate; the channel count is assumed to match NAMES.
        bgr_img: uint8 image as produced by cv2.imread. 3-channel BGR is the
            normal case; 4-channel BGRA and single-channel grayscale are also
            accepted.
        device: torch device the input tensor is moved to.

    Returns:
        (H, W) uint8 array of class ids indexing NAMES. The "background" class
        is suppressed and therefore never appears in the mask.
    """
    H, W = bgr_img.shape[:2]
    resized = cv2.resize(bgr_img, (TRAIN_W, TRAIN_H))
    # Bring any OpenCV channel layout to RGB: BGR2RGB alone raises on
    # 1- or 4-channel inputs.
    if resized.ndim == 2 or resized.shape[2] == 1:
        rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
    elif resized.shape[2] == 4:
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGRA2RGB)
    else:
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    rgb = rgb.astype(np.float32) / 255.0
    # HWC -> 1xCxHxW float32 batch.
    x = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0).to(device)

    background_id = NAMES.index("background")
    with torch.no_grad():
        logits = model(x)
        logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
        # Never emit "background": force its score below every other class.
        logits[:, background_id, :, :] = float("-inf")
        mask = logits.argmax(1)[0]
    return mask.cpu().numpy().astype(np.uint8)
|
|
|
|
def colorize(mask):
    """Turn an integer class-id mask into a colour image.

    Each pixel value selects a row of PALETTE, so an (H, W) mask becomes an
    (H, W, 3) uint8 image. Ids >= len(PALETTE) raise IndexError.
    """
    colour_lookup = PALETTE
    return colour_lookup[mask]
|
|
|
|
def overlay(bgr, mask, alpha=0.45):
    """Blend the colourised ``mask`` over the image ``bgr``.

    Computes ``bgr * (1 - alpha) + colorize(mask) * alpha`` per pixel with
    cv2.addWeighted, which saturates to the output dtype. ``bgr`` must match
    the colour image in size and type: same height/width as ``mask``, uint8,
    and 3 channels.
    """
    colour = colorize(mask)
    image_weight = 1 - alpha
    return cv2.addWeighted(bgr, image_weight, colour, alpha, 0)
|
|
|
|
def main():
    """CLI entry point: segment a single image and/or every image in a directory.

    For each readable input image, writes ``<stem>_pred.png`` (raw class-id
    mask) and ``<stem>_overlay.jpg`` (colour overlay) into ``--out``, and
    prints the class covering the most pixels. Prints help and returns when
    neither a positional input nor ``--dir`` is given. A missing ``--dir`` or
    ``--weights`` path is reported as a usage error (exit status 2).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input", nargs="?")
    ap.add_argument("--dir")
    ap.add_argument("--out", default=".")
    ap.add_argument("--weights", default="run_8class/twinlite8_best.pt")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    if not args.input and not args.dir:
        ap.print_help()
        return

    # Fail fast on bad user-supplied paths: a clean usage error rather than a
    # traceback, and before paying for the model load.
    if args.dir and not Path(args.dir).is_dir():
        ap.error(f"--dir {args.dir} is not a directory")
    if not Path(args.weights).is_file():
        ap.error(f"--weights {args.weights} does not exist")

    paths = []
    if args.dir:
        image_suffixes = {".jpg", ".jpeg", ".png", ".bmp"}
        paths = sorted(
            p for p in Path(args.dir).iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        )
    if args.input:
        paths.append(Path(args.input))
    if not paths:
        # Only reachable with --dir and no positional input.
        print(f"no images found in {args.dir}")
        return

    print(f"loading model from {args.weights} on {args.device} ...")
    model = load_model(args.weights, device=args.device)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    for p in paths:
        img = cv2.imread(str(p))
        if img is None:  # missing or undecodable file
            print(f" skip: {p}")
            continue
        mask = predict(model, img, device=args.device)
        cv2.imwrite(str(out_dir / f"{p.stem}_pred.png"), mask)
        cv2.imwrite(str(out_dir / f"{p.stem}_overlay.jpg"), overlay(img, mask))
        # Per-class pixel histogram; minlength keeps one bin per known class.
        counts = np.bincount(mask.ravel(), minlength=len(NAMES))
        top = counts.argmax()
        print(f" {p.name:<50} top: {NAMES[top]} ({100*counts[top]/counts.sum():.1f}%)")
    print(f"\noutputs -> {out_dir.resolve()}")
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported
# (e.g. by code that only wants load_model / predict).
if __name__ == "__main__":
    main()
|
|