"""TwinLiteNet8 — single-branch 8-class semantic segmentation, directly comparable to Segformer.

Classes: 0 tree, 1 ground, 2 person, 3 sky, 4 road, 5 mountain, 6 building, 7 background
"""
| from __future__ import annotations |
| import os, sys, json, re, time, random |
| from pathlib import Path |
| import numpy as np, cv2, torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader, ConcatDataset |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from model.TwinLite_8class import TwinLiteNet8 |
|
|
| |
# ---------------------------------------------------------------------------
# Paths. Defaults are the original workstation locations; set the environment
# variables AGMO_SEG_ROOT / TWINLITE_OUT_DIR to run elsewhere without edits.
# ---------------------------------------------------------------------------
ROOT = Path(os.environ.get(
    "AGMO_SEG_ROOT",
    r"C:/Users/room104/Desktop/AGMOtree/semantic_segmantation"))
OLD_IMG = ROOT / "merged_dataset/train/images"
OLD_MSK = ROOT / "merged_dataset/train/masks_pseudo"  # presumably pseudo-labels (per dir name)
NEW_IMG = ROOT / "orchard_nav/train/images"
NEW_MSK = ROOT / "orchard_nav/train/masks"

OUT_DIR = Path(os.environ.get(
    "TWINLITE_OUT_DIR",
    r"C:/Users/room104/Desktop/AGMOtree/TwinLiteNet_train/run_v2"))
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Class names, indexed by label id in the masks.
NAMES = ["tree", "ground", "person", "sky", "road", "mountain", "building", "background"]
NUM_CLASSES = 8
IGNORE_INDEX = 255  # label value skipped by the loss and by the metrics

# Training hyper-parameters.
W_IN, H_IN = 640, 360  # network input width x height
BATCH = 16
EPOCHS = 60
LR = 5e-4
NUM_WORKERS = 4
SEED = 42
DEVICE = "cuda"

# Per-class CrossEntropy weights, indexed like NAMES: tree and person are
# up-weighted, ground down-weighted, and background (index 7) gets 0.0 so its
# pixels add nothing to the loss.
WEIGHTS = np.array([1.5, 0.5, 1.5, 1.0, 1.0, 1.0, 1.0, 0.0], dtype=np.float32)

# Catch an edited class list early instead of failing inside the loss.
if not (len(NAMES) == NUM_CLASSES == len(WEIGHTS)):
    raise ValueError(
        f"NAMES ({len(NAMES)}), NUM_CLASSES ({NUM_CLASSES}) and "
        f"WEIGHTS ({len(WEIGHTS)}) must have the same length")

# Seed every RNG used below (random split of new data, augmentation, init).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
|
|
|
|
def frame_num(p):
    """Return N when ``p``'s stem begins with ``frame_N`` (e.g. ``frame_0123``), else -1."""
    found = re.match(r"frame_(\d+)", p.stem)
    if found is None:
        return -1
    return int(found.group(1))
|
|
|
|
class OrchardDS(Dataset):
    """Image/mask dataset for the orchard segmentation data.

    Each item is ``(image, mask)``:
      * ``image`` — float32 tensor ``(3, H_IN, W_IN)``, RGB, scaled to [0, 1];
      * ``mask``  — int64 tensor ``(H_IN, W_IN)`` of class ids or ``IGNORE_INDEX``.

    Args:
        paths: image paths; the mask for ``p`` is ``mask_dir / (p.stem + ".png")``.
        mask_dir: directory holding single-channel PNG masks.
        augment: if true, apply a random horizontal flip and a random HSV jitter
            (each with probability 0.5).
        source: ``"old"`` remaps label 7 (background) to ``IGNORE_INDEX``;
            any other value leaves the mask labels untouched.

    If either file cannot be read, a black image with an all-ignore mask is
    returned instead of raising.
    """

    def __init__(self, paths, mask_dir, augment=False, source="old"):
        self.paths = paths
        self.mask_dir = mask_dir
        self.augment = augment
        self.source = source

    def __len__(self):
        return len(self.paths)

    def _read_pair(self, img_path):
        """Load the BGR image and grayscale mask, or a blank fully-ignored pair."""
        image = cv2.imread(str(img_path))
        label = cv2.imread(str(self.mask_dir / (img_path.stem + ".png")),
                           cv2.IMREAD_GRAYSCALE)
        if image is not None and label is not None:
            return image, label
        blank_image = np.zeros((H_IN, W_IN, 3), dtype=np.uint8)
        blank_label = np.full((H_IN, W_IN), IGNORE_INDEX, dtype=np.uint8)
        return blank_image, blank_label

    @staticmethod
    def _jitter_hsv(image):
        """Shift hue by up to ±10 and scale saturation and value by [0.7, 1.3]."""
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.int16)
        hue_shift = random.randint(-10, 10)
        hsv[..., 0] = (hsv[..., 0] + hue_shift) % 180  # OpenCV 8-bit hue range is 0..179
        for channel in (1, 2):  # saturation first, then value
            hsv[..., channel] = np.clip(hsv[..., channel] * random.uniform(0.7, 1.3), 0, 255)
        return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

    def __getitem__(self, i):
        image, label = self._read_pair(self.paths[i])

        if self.augment:
            if random.random() < 0.5:
                # Flip image and mask together so labels stay aligned.
                image = np.ascontiguousarray(np.fliplr(image))
                label = np.ascontiguousarray(np.fliplr(label))
            if random.random() < 0.5:
                image = self._jitter_hsv(image)

        image = cv2.resize(image, (W_IN, H_IN))
        # Nearest neighbour keeps mask values as valid class ids.
        label = cv2.resize(label, (W_IN, H_IN), interpolation=cv2.INTER_NEAREST)

        if self.source == "old":
            label[label == 7] = IGNORE_INDEX

        # BGR -> RGB, HWC -> CHW, uint8 -> [0, 1] float.
        chw = np.transpose(image[..., ::-1], (2, 0, 1)).astype(np.float32) / 255.0
        return torch.from_numpy(chw).float(), torch.from_numpy(label).long()
|
|
|
|
| |
# Old data: split by frame number. Frames numbered above 4500 form the
# validation set; everything else (including files without a frame number,
# for which frame_num returns -1) is used for training.
old_all = sorted(OLD_IMG.glob("*.jpg"))
old_train = [p for p in old_all if frame_num(p) <= 4500]
old_val = [p for p in old_all if frame_num(p) > 4500]

# New data: seeded random split (relies on the module-level random.seed(SEED)).
# At least 20 images, or 10% of the set if larger, go to validation.
new_all = sorted(NEW_IMG.glob("*.jpg"))
random.shuffle(new_all)
n_new_val = max(20, len(new_all) // 10)
new_val = new_all[:n_new_val]
new_train = new_all[n_new_val:]

train_ds = ConcatDataset([
    OrchardDS(old_train, OLD_MSK, augment=True, source="old"),
    OrchardDS(new_train, NEW_MSK, augment=True, source="new"),
])
old_val_ds = OrchardDS(old_val, OLD_MSK, augment=False, source="old")
new_val_ds = OrchardDS(new_val, NEW_MSK, augment=False, source="new")

# Only report when run as a script. DataLoader workers started with the
# "spawn" method (the Windows default) re-import this module as "__mp_main__",
# which would otherwise repeat this summary once per worker process.
if __name__ == "__main__":
    print("=== TwinLiteNet8 (single-branch, 8-class) ===")
    print(f" old train: {len(old_train)} new train: {len(new_train)}")
    print(f" old val: {len(old_val)} new val: {len(new_val)}")
|
|
|
|
| |
def confusion(preds, ys, n, ignore=IGNORE_INDEX):
    """Build an ``n x n`` confusion matrix (rows = true class, cols = predicted).

    Args:
        preds: integer array of predicted class ids.
        ys: integer array of ground-truth ids, same shape as ``preds``.
        n: number of classes.
        ignore: label value excluded from the count.

    Returns:
        ``(n, n)`` int64 array. Pixels labelled ``ignore``, and pixels whose
        label or prediction falls outside ``[0, n)``, are not counted.
    """
    preds = np.asarray(preds).ravel()
    ys = np.asarray(ys).ravel()
    keep = (ys != ignore) & (ys >= 0) & (ys < n) & (preds >= 0) & (preds < n)
    if not keep.any():
        return np.zeros((n, n), dtype=np.int64)
    # One bincount over (true * n + pred) replaces n*n full-array mask passes.
    # Cast before multiplying so uint8 labels cannot overflow.
    flat = ys[keep].astype(np.int64) * n + preds[keep].astype(np.int64)
    return np.bincount(flat, minlength=n * n).reshape(n, n).astype(np.int64)
|
|
def iou_from_cm(cm):
    """Per-class IoU from a confusion matrix (rows = true, cols = predicted).

    IoU_c = TP / (TP + FP + FN) = cm[c, c] / (col_sum_c + row_sum_c - cm[c, c]).
    Classes with an empty union (never labelled nor predicted) get NaN.
    """
    hits = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - hits
    result = np.full(cm.shape[0], np.nan)
    np.divide(hits, union, out=result, where=union > 0)
    return result
|
|
|
|
| |
# Plain-text training log inside the run directory.
log_path = OUT_DIR / "log.txt"


def log(m):
    """Print ``m`` (flushed immediately) and append it as one line to ``log_path``."""
    print(m, flush=True)
    with open(log_path, "a", encoding="utf-8") as fh:
        fh.write(m + "\n")
|
|
|
|
_TREE = NAMES.index("tree")
_GROUND = NAMES.index("ground")
_BACKGROUND = NAMES.index("background")


def _on_cuda():
    """True when ``DEVICE`` names a CUDA device."""
    return torch.device(DEVICE).type == "cuda"


def _make_loader(dataset, *, shuffle, num_workers, drop_last=False):
    """Create a DataLoader using the script's batch size.

    ``persistent_workers`` is only valid with worker processes, and pinned
    memory only helps host-to-GPU copies, so both are enabled conditionally.
    """
    return DataLoader(dataset, batch_size=BATCH, shuffle=shuffle,
                      num_workers=num_workers, pin_memory=_on_cuda(),
                      drop_last=drop_last, persistent_workers=num_workers > 0)


def _predict(model, x):
    """Per-pixel argmax class ids for ``x``, never predicting background.

    The background logit is overwritten with a large negative value so the
    argmax is taken over the remaining classes only.
    """
    logits = model(x)
    logits[:, _BACKGROUND, :, :] = -1e9
    return logits.argmax(1)


def _train_one_epoch(model, loader, loss_fn, optim, sched):
    """Run one training epoch, stepping the scheduler per batch; return the mean loss."""
    model.train()
    total = 0.0
    for x, y in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        loss = loss_fn(model(x), y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        sched.step()  # schedule length is EPOCHS * len(loader) iterations
        total += loss.item()
    return total / len(loader)


def _evaluate(model, old_val_loader, new_val_loader):
    """Return ``(cm_old, tree_tp, tree_fn)``.

    ``cm_old`` is the confusion matrix on the old validation split. The tree
    counts are pixel-level true positives and false negatives for the tree
    class on the new validation split (used for recall).
    """
    model.eval()
    cm_old = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
    tree_tp = tree_fn = 0
    with torch.no_grad():
        for x, y in old_val_loader:
            preds = _predict(model, x.to(DEVICE, non_blocking=True))
            # Labels never need the GPU: compare on the CPU copy directly.
            cm_old += confusion(preds.cpu().numpy(), y.numpy(), NUM_CLASSES)
        for x, y in new_val_loader:
            preds = _predict(model, x.to(DEVICE, non_blocking=True)).cpu().numpy()
            is_tree = y.numpy() == _TREE
            tree_tp += int((is_tree & (preds == _TREE)).sum())
            tree_fn += int((is_tree & (preds != _TREE)).sum())
    return cm_old, tree_tp, tree_fn


def _benchmark_fps(model, warmup=20, iters=200):
    """Measure batch-1 forward passes per second at ``H_IN x W_IN``."""
    model.eval()
    x = torch.randn(1, 3, H_IN, W_IN, device=DEVICE)
    with torch.no_grad():
        for _ in range(warmup):
            model(x)
        if _on_cuda():
            torch.cuda.synchronize()  # kernel launches are async; drain before timing
        t0 = time.perf_counter()
        for _ in range(iters):
            model(x)
        if _on_cuda():
            torch.cuda.synchronize()
        return iters / (time.perf_counter() - t0)


def main():
    """Train TwinLiteNet8, validate every epoch, save checkpoints, then benchmark FPS.

    Writes ``log.txt``, ``history.json``, ``twinlite8_last.pt`` and
    ``twinlite8_best.pt`` (highest tree IoU on the old validation split)
    into ``OUT_DIR``.

    Raises:
        ValueError: if the training set is smaller than one batch
            (``drop_last=True`` would leave no iterations).
    """
    log_path.write_text("")
    train_loader = _make_loader(train_ds, shuffle=True, num_workers=NUM_WORKERS,
                                drop_last=True)
    old_val_loader = _make_loader(old_val_ds, shuffle=False, num_workers=2)
    new_val_loader = _make_loader(new_val_ds, shuffle=False, num_workers=2)
    if len(train_loader) == 0:
        raise ValueError(f"training set has {len(train_ds)} samples, "
                         f"fewer than one batch of {BATCH}")

    model = TwinLiteNet8(num_classes=NUM_CLASSES).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    log(f"model: TwinLiteNet8 params: {n_params/1e6:.3f}M")
    log(f"input: {W_IN}x{H_IN} batch: {BATCH} epochs: {EPOCHS} LR: {LR}")
    log(f"classes: {NAMES}")
    log(f"weights: {dict(zip(NAMES, [round(float(w),2) for w in WEIGHTS]))}")
    log(f"train: {len(train_ds)} old_val: {len(old_val_ds)} new_val: {len(new_val_ds)}")

    cw = torch.tensor(WEIGHTS, dtype=torch.float32, device=DEVICE)
    loss_fn = nn.CrossEntropyLoss(weight=cw, ignore_index=IGNORE_INDEX)
    optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS * len(train_loader))

    best_tree = -1.0
    history = []
    for epoch in range(1, EPOCHS + 1):
        t0 = time.time()
        train_loss = _train_one_epoch(model, train_loader, loss_fn, optim, sched)
        cm_old, tree_tp, tree_fn = _evaluate(model, old_val_loader, new_val_loader)

        iou_old = iou_from_cm(cm_old)
        miou_7 = float(np.nanmean(iou_old[:_BACKGROUND]))  # mean over the 7 non-background classes
        tree_old = float(iou_old[_TREE])
        ground_old = float(iou_old[_GROUND])
        n_tree = tree_tp + tree_fn
        tree_recall_new = tree_tp / n_tree if n_tree > 0 else float("nan")
        elapsed = time.time() - t0

        log(f"epoch {epoch:02d}/{EPOCHS} loss={train_loss:.4f} "
            f"mIoU(7)={miou_7:.3f} tree_old={tree_old:.3f} ground_old={ground_old:.3f} "
            f"tree_new_recall={tree_recall_new:.3f} ({elapsed:.0f}s)")
        log(f" per-class IoU: " + ", ".join(f"{n}={v:.3f}" for n, v in zip(NAMES, iou_old)))

        history.append({
            "epoch": epoch, "loss": float(train_loss),
            "miou_7": miou_7, "tree_iou_old": tree_old, "ground_iou_old": ground_old,
            "tree_recall_new": float(tree_recall_new),
            "per_class_iou": {n: float(v) for n, v in zip(NAMES, iou_old)},
        })
        metrics = {"epoch": epoch, "tree_iou_old": tree_old, "miou_7": miou_7,
                   "tree_recall_new": float(tree_recall_new)}
        torch.save({"model": model.state_dict(), **metrics}, OUT_DIR / "twinlite8_last.pt")
        if tree_old > best_tree:
            best_tree = tree_old
            torch.save({"model": model.state_dict(), **metrics}, OUT_DIR / "twinlite8_best.pt")
            log(f" saved best (tree_old {tree_old:.3f})")
        (OUT_DIR / "history.json").write_text(json.dumps(history, indent=2))

    log(f"\n=== DONE === best tree_old IoU: {best_tree:.3f}")

    # Benchmark the final-epoch weights; report the device actually used.
    device_name = torch.cuda.get_device_name() if _on_cuda() else DEVICE
    log(f"\n=== FPS BENCHMARK ({device_name}, batch=1, {W_IN}x{H_IN}) ===")
    fps = _benchmark_fps(model)
    log(f" TwinLiteNet8 @ {W_IN}x{H_IN} batch=1: {fps:.1f} FPS")
    # Rough rule-of-thumb scaling (fps/4 .. fps/3), not a measurement.
    log(f" Jetson Orin Nano estimate: ~{fps/4:.0f}-{fps/3:.0f} FPS")
|
|
|
|
# Train only when executed directly. The guard is also required because
# DataLoader worker processes may re-import this module (e.g. with the
# "spawn" start method), and they must not start training themselves.
if __name__ == "__main__":
    main()
|
|