"""TwinLiteNet8 — single-branch 8-class semantic seg, directly comparable to Segformer. Classes: 0 tree 1 ground 2 person 3 sky 4 road 5 mountain 6 building 7 background """ from __future__ import annotations import os, sys, json, re, time, random from pathlib import Path import numpy as np, cv2, torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader, ConcatDataset sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from model.TwinLite_8class import TwinLiteNet8 # ───────── config ───────── ROOT = Path(r"C:/Users/room104/Desktop/AGMOtree/semantic_segmantation") OLD_IMG = ROOT / "merged_dataset/train/images" OLD_MSK = ROOT / "merged_dataset/train/masks_pseudo" NEW_IMG = ROOT / "orchard_nav/train/images" NEW_MSK = ROOT / "orchard_nav/train/masks" OUT_DIR = Path(r"C:/Users/room104/Desktop/AGMOtree/TwinLiteNet_train/run_v2") OUT_DIR.mkdir(parents=True, exist_ok=True) NAMES = ["tree","ground","person","sky","road","mountain","building","background"] NUM_CLASSES = 8 IGNORE_INDEX = 255 W_IN, H_IN = 640, 360 BATCH = 16 EPOCHS = 60 LR = 5e-4 NUM_WORKERS = 4 SEED = 42 DEVICE = "cuda" # v2 design: background is NOT a real class. Pixels labeled 7 → 255 (ignore_index) # in the loader, so loss never trains channel 7. Weight 0 as belt-and-braces. # At inference, channel 7 logit is set to -inf before argmax (see predict.py update). 
# Per-class loss weights, indexed like NAMES; the last entry (background) is 0.
WEIGHTS = np.array([1.5, 0.5, 1.5, 1.0, 1.0, 1.0, 1.0, 0.0], dtype=np.float32)

# Class id of "background": remapped to IGNORE_INDEX for old-source masks and
# masked out of the logits before argmax during evaluation.
BACKGROUND_ID = 7

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def frame_num(p):
    """Return the integer N of a ``frame_N...`` file stem, or -1 if it does not match."""
    m = re.match(r"frame_(\d+)", p.stem)
    return int(m.group(1)) if m else -1


class OrchardDS(Dataset):
    """Image / mask pairs resized to ``W_IN`` x ``H_IN``.

    Args:
        paths: sequence of image paths (``pathlib.Path``-like; ``.stem`` is used).
        mask_dir: directory containing ``<image stem>.png`` masks.
        augment: when true, apply a random horizontal flip and HSV jitter.
        source: ``"old"`` remaps mask value 7 to ``IGNORE_INDEX``; any other
            value leaves the mask values untouched.

    Each item is ``(image, mask)``: a float32 tensor of shape (3, H_IN, W_IN)
    scaled to [0, 1] with the channel order reversed from what ``cv2.imread``
    returns, and an int64 tensor of shape (H_IN, W_IN).
    """

    def __init__(self, paths, mask_dir, augment=False, source="old"):
        self.paths = paths
        self.mask_dir = mask_dir
        self.augment = augment
        self.source = source

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        ip = self.paths[i]
        img = cv2.imread(str(ip))
        msk = cv2.imread(str(self.mask_dir / (ip.stem + ".png")), cv2.IMREAD_GRAYSCALE)
        if img is None or msk is None:
            # Best-effort fallback: a black image whose mask is entirely ignored,
            # so the sample contributes nothing to the loss. Warn instead of
            # staying silent so a broken dataset path does not go unnoticed.
            import warnings

            warnings.warn(
                f"could not read image or mask for {ip.name}; "
                "substituting an all-ignore sample"
            )
            img = np.zeros((H_IN, W_IN, 3), dtype=np.uint8)
            msk = np.full((H_IN, W_IN), IGNORE_INDEX, dtype=np.uint8)

        if self.augment:
            # Horizontal flip, applied to image and mask together.
            if random.random() < 0.5:
                img = np.ascontiguousarray(img[:, ::-1])
                msk = np.ascontiguousarray(msk[:, ::-1])
            # HSV jitter: hue shift wraps modulo 180, saturation/value are scaled
            # and clipped. int16 gives headroom for the intermediate values.
            if random.random() < 0.5:
                hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
                hsv[..., 0] = (hsv[..., 0] + random.randint(-10, 10)) % 180
                hsv[..., 1] = np.clip(hsv[..., 1] * random.uniform(0.7, 1.3), 0, 255)
                hsv[..., 2] = np.clip(hsv[..., 2] * random.uniform(0.7, 1.3), 0, 255)
                img = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

        img = cv2.resize(img, (W_IN, H_IN))
        # Nearest-neighbour keeps mask values as valid class ids.
        msk = cv2.resize(msk, (W_IN, H_IN), interpolation=cv2.INTER_NEAREST)

        # v2: remap class 7 (background) -> IGNORE_INDEX so it does NOT train.
        # The user's intent: "background = stuff the model can't recognize", not a real class.
        if self.source == "old":
            msk[msk == BACKGROUND_ID] = IGNORE_INDEX
        # new-source masks already have 255 for non-tree pixels, no change needed.

        # Reverse the channel axis, move it first (HWC -> CHW), scale to [0, 1].
        img = img[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) / 255.0
        return (torch.from_numpy(img).float(), torch.from_numpy(msk).long())


# ─── temporal split ───
# Old data is split by frame number; new data is shuffled and the first slice
# (at least 20 images, otherwise 10%) is held out for validation.
old_all = sorted(OLD_IMG.glob("*.jpg"))
old_train = [p for p in old_all if frame_num(p) <= 4500]
old_val = [p for p in old_all if frame_num(p) > 4500]

new_all = sorted(NEW_IMG.glob("*.jpg"))
random.shuffle(new_all)
n_new_val = max(20, len(new_all) // 10)
new_val = new_all[:n_new_val]
new_train = new_all[n_new_val:]

train_ds = ConcatDataset([
    OrchardDS(old_train, OLD_MSK, augment=True, source="old"),
    OrchardDS(new_train, NEW_MSK, augment=True, source="new"),
])
old_val_ds = OrchardDS(old_val, OLD_MSK, augment=False, source="old")
new_val_ds = OrchardDS(new_val, NEW_MSK, augment=False, source="new")

print("=== TwinLiteNet8 (single-branch, 8-class) ===")
print(f" old train: {len(old_train)} new train: {len(new_train)}")
print(f" old val: {len(old_val)} new val: {len(new_val)}")


# ─── eval ───
def confusion(preds, ys, n, ignore=IGNORE_INDEX):
    """Return the ``n`` x ``n`` confusion matrix (rows = target, cols = prediction).

    Args:
        preds: integer array of predicted class ids.
        ys: integer array of target class ids, same shape as ``preds``.
        n: number of classes.
        ignore: target value to exclude from the count.

    Returns:
        int64 array of shape (n, n). Pixels whose target equals ``ignore``, or
        whose target or prediction lies outside ``[0, n)``, are not counted.
    """
    valid = ys != ignore
    p = preds[valid]
    t = ys[valid]
    in_range = (t >= 0) & (t < n) & (p >= 0) & (p < n)
    # Encode each (target, pred) pair as one index and count them in a single
    # pass instead of n * n boolean reductions. Cast first so narrow integer
    # dtypes cannot overflow in the multiplication.
    flat = t[in_range].astype(np.int64) * n + p[in_range].astype(np.int64)
    return np.bincount(flat, minlength=n * n).reshape(n, n).astype(np.int64)


def iou_from_cm(cm):
    """Return per-class IoU from a confusion matrix; NaN where a class never occurs.

    IoU is tp / (tp + fp + fn), which equals diag / (column sum + row sum - diag).
    """
    tp = np.diag(cm).astype(np.float64)
    denom = cm.sum(axis=0) + cm.sum(axis=1) - np.diag(cm)
    ious = np.full(cm.shape[0], np.nan)
    np.divide(tp, denom, out=ious, where=denom > 0)
    return ious


# ─── train ───
log_path = OUT_DIR / "log.txt"


def log(m):
    """Print ``m`` and append it as one line to ``log_path``."""
    print(m, flush=True)
    with log_path.open("a", encoding="utf-8") as f:
        f.write(m + "\n")


def _predict(model, x):
    """Return per-pixel argmax class ids with the background channel masked out."""
    logits = model(x)
    logits[:, BACKGROUND_ID, :, :] = -1e9  # never predict background — that channel is untrained
    return logits.argmax(1)


def _make_loader(ds, workers, **kwargs):
    """Build a DataLoader; persistent workers are only valid when workers > 0."""
    return DataLoader(ds, batch_size=BATCH, num_workers=workers, pin_memory=True,
                      persistent_workers=workers > 0, **kwargs)


def main():
    """Train TwinLiteNet8, evaluate every epoch, save checkpoints, then benchmark FPS.

    Writes ``log.txt``, ``history.json``, ``twinlite8_last.pt`` and
    ``twinlite8_best.pt`` (best tree IoU on the old validation split) to OUT_DIR.

    Raises:
        RuntimeError: if the training set yields no full batch.
    """
    log_path.write_text("")  # start each run with an empty log

    train_loader = _make_loader(train_ds, NUM_WORKERS, shuffle=True, drop_last=True)
    old_val_loader = _make_loader(old_val_ds, 2, shuffle=False)
    new_val_loader = _make_loader(new_val_ds, 2, shuffle=False)

    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        # drop_last=True leaves nothing when there are fewer samples than BATCH;
        # fail early instead of dividing by zero after the first epoch.
        raise RuntimeError(
            f"training set has {len(train_ds)} samples, "
            f"which is less than one batch of {BATCH}"
        )

    model = TwinLiteNet8(num_classes=NUM_CLASSES).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    log(f"model: TwinLiteNet8 params: {n_params/1e6:.3f}M")
    log(f"input: {W_IN}x{H_IN} batch: {BATCH} epochs: {EPOCHS} LR: {LR}")
    log(f"classes: {NAMES}")
    log(f"weights: {dict(zip(NAMES, [round(float(w),2) for w in WEIGHTS]))}")
    log(f"train: {len(train_ds)} old_val: {len(old_val_ds)} new_val: {len(new_val_ds)}")

    cw = torch.tensor(WEIGHTS, dtype=torch.float32, device=DEVICE)
    loss_fn = nn.CrossEntropyLoss(weight=cw, ignore_index=IGNORE_INDEX)
    optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    # The scheduler is stepped once per batch, so T_max is the total step count.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS * steps_per_epoch)

    best_tree = -1.0
    history = []
    for epoch in range(1, EPOCHS + 1):
        model.train()
        t0 = time.time()
        ep_loss = 0.0
        n_steps = 0
        n_skipped = 0
        for x, y in train_loader:
            x = x.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)
            loss = loss_fn(model(x), y)
            if not torch.isfinite(loss):
                # A batch with no usable pixels (everything ignored or weighted 0)
                # can give a non-finite loss; back-propagating it would corrupt
                # the weights, so the batch is skipped.
                n_skipped += 1
                continue
            optim.zero_grad()
            loss.backward()
            optim.step()
            sched.step()
            ep_loss += loss.item()
            n_steps += 1
        train_loss = ep_loss / n_steps if n_steps else float("nan")
        if n_skipped:
            log(f" skipped {n_skipped} batch(es) with non-finite loss")

        model.eval()
        cm_old = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
        tree_tp = tree_fn = 0
        with torch.no_grad():
            # Old validation split: full confusion matrix.
            for x, y in old_val_loader:
                preds = _predict(model, x.to(DEVICE))
                cm_old += confusion(preds.cpu().numpy(), y.numpy(), NUM_CLASSES)
            # New validation split: only recall on class 0 (tree) is measured.
            for x, y in new_val_loader:
                preds = _predict(model, x.to(DEVICE)).cpu().numpy()
                ys = y.numpy()
                tm = ys == 0
                tree_tp += int(((preds == 0) & tm).sum())
                tree_fn += int(((preds != 0) & tm).sum())

        iou_old = iou_from_cm(cm_old)
        # Mean over the classes below BACKGROUND_ID, i.e. excluding background.
        miou_7 = float(np.nanmean(iou_old[:BACKGROUND_ID]))
        tree_old = float(iou_old[0])
        ground_old = float(iou_old[1])
        n_tree = tree_tp + tree_fn
        tree_recall_new = tree_tp / n_tree if n_tree > 0 else float("nan")
        elapsed = time.time() - t0

        log(f"epoch {epoch:02d}/{EPOCHS} loss={train_loss:.4f} "
            f"mIoU(7)={miou_7:.3f} tree_old={tree_old:.3f} ground_old={ground_old:.3f} "
            f"tree_new_recall={tree_recall_new:.3f} ({elapsed:.0f}s)")
        log(f" per-class IoU: " + ", ".join(f"{n}={v:.3f}" for n, v in zip(NAMES, iou_old)))

        history.append({
            "epoch": epoch,
            "loss": float(train_loss),
            "miou_7": miou_7,
            "tree_iou_old": tree_old,
            "ground_iou_old": ground_old,
            "tree_recall_new": float(tree_recall_new),
            "per_class_iou": {n: float(v) for n, v in zip(NAMES, iou_old)},
        })

        # The same payload is written as "last" every epoch and as "best" when
        # tree IoU on the old validation split improves.
        ckpt = {
            "model": model.state_dict(),
            "epoch": epoch,
            "tree_iou_old": tree_old,
            "miou_7": miou_7,
            "tree_recall_new": float(tree_recall_new),
        }
        torch.save(ckpt, OUT_DIR / "twinlite8_last.pt")
        if tree_old > best_tree:
            best_tree = tree_old
            torch.save(ckpt, OUT_DIR / "twinlite8_best.pt")
            log(f" saved best (tree_old {tree_old:.3f})")

        # Rewritten every epoch so an interrupted run still leaves a history file.
        (OUT_DIR / "history.json").write_text(json.dumps(history, indent=2))

    log(f"\n=== DONE === best tree_old IoU: {best_tree:.3f}")

    # ─── FPS benchmark ───
    log(f"\n=== FPS BENCHMARK (RTX 3080, batch=1, 640x360) ===")
    model.eval()
    x = torch.randn(1, 3, H_IN, W_IN, device=DEVICE)
    with torch.no_grad():
        # Warm-up passes, excluded from the timing.
        for _ in range(20):
            model(x)
        # Wait for queued GPU work so the timed region starts and ends clean.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        N = 200
        for _ in range(N):
            model(x)
        torch.cuda.synchronize()
        fps = N / (time.perf_counter() - t0)
    log(f" TwinLiteNet8 @ 640x360 batch=1: {fps:.1f} FPS")
    log(f" Jetson Orin Nano estimate: ~{fps/4:.0f}-{fps/3:.0f} FPS")


if __name__ == "__main__":
    main()