# TwinLiteNet8 / train_8class.py
# WEN0256's picture
# Initial release: TwinLiteNet8 (0.44M params, 7-class orchard semantic seg, edge-deployment ready)
# f5cc6c0 verified
"""TwinLiteNet8 — single-branch 8-class semantic seg, directly comparable to Segformer.
Classes: 0 tree 1 ground 2 person 3 sky 4 road 5 mountain 6 building 7 background
"""
from __future__ import annotations
import os, sys, json, re, time, random
from pathlib import Path
import numpy as np, cv2, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.TwinLite_8class import TwinLiteNet8
# ───────── config ─────────
# Dataset root; the image/mask directories below are resolved relative to it.
ROOT = Path(r"C:/Users/room104/Desktop/AGMOtree/semantic_segmantation")
# "Old" source: images and their masks (presumably pseudo-labels, going by the
# directory name — TODO confirm).
OLD_IMG = ROOT / "merged_dataset/train/images"
OLD_MSK = ROOT / "merged_dataset/train/masks_pseudo"
# "New" source: orchard navigation images and their masks.
NEW_IMG = ROOT / "orchard_nav/train/images"
NEW_MSK = ROOT / "orchard_nav/train/masks"
# Output directory for the log, checkpoints and history; created at import time.
OUT_DIR = Path(r"C:/Users/room104/Desktop/AGMOtree/TwinLiteNet_train/run_v2")
OUT_DIR.mkdir(parents=True, exist_ok=True)
# Class names, listed in label-id order (index 0 = tree ... index 7 = background).
NAMES = ["tree","ground","person","sky","road","mountain","building","background"]
# Number of output channels requested from the model and size of the confusion matrix.
NUM_CLASSES = 8
# Label value excluded from the loss and from the confusion matrix.
IGNORE_INDEX = 255
# Every image and mask is resized to this width/height before being returned by the dataset.
W_IN, H_IN = 640, 360
BATCH = 16
EPOCHS = 60
LR = 5e-4
# Worker processes for the training DataLoader.
NUM_WORKERS = 4
SEED = 42
# Device the model and the class-weight tensor are placed on.
DEVICE = "cuda"
# v2 design: background is NOT a real class. Pixels labeled 7 -> 255 (ignore_index)
# in the loader, so loss never trains channel 7. Weight 0 as belt-and-braces.
# At inference, channel 7 logit is set to -inf before argmax (see predict.py update).
# Per-class cross-entropy weights, in the same order as NAMES.
WEIGHTS = np.array([1.5, 0.5, 1.5, 1.0, 1.0, 1.0, 1.0, 0.0], dtype=np.float32)
# Seed the Python, NumPy and PyTorch RNGs so the shuffle-based split below is repeatable.
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
def frame_num(p):
    """Return the integer N of a path whose stem starts with ``frame_N``.

    Paths whose stem does not start with ``frame_<digits>`` yield -1.
    """
    found = re.match(r"frame_(\d+)", p.stem)
    if found is None:
        return -1
    return int(found.group(1))
class OrchardDS(Dataset):
    """Image/mask pairs for segmentation, resized to ``W_IN`` x ``H_IN``.

    Each item is ``(image, mask)``: a float32 tensor in channel-first RGB order
    scaled to [0, 1], and an int64 label map. The mask for an image is the file
    in ``mask_dir`` that has the image's stem and a ``.png`` suffix.

    Args:
        paths: sequence of image paths (objects with a ``.stem`` attribute).
        mask_dir: directory holding the masks; combined with ``/``.
        augment: when True, apply a random horizontal flip and HSV jitter.
        source: ``"old"`` enables the remap of label 7 to ``IGNORE_INDEX``.
    """

    def __init__(self, paths, mask_dir, augment=False, source="old"):
        self.paths, self.mask_dir = paths, mask_dir
        self.augment, self.source = augment, source

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _augment(img, msk):
        """Randomly flip the pair horizontally, then randomly jitter the image in HSV."""
        if random.random() < 0.5:
            img = np.ascontiguousarray(np.fliplr(img))
            msk = np.ascontiguousarray(np.fliplr(msk))
        if random.random() < 0.5:
            # int16 leaves headroom for the signed hue shift before wrapping.
            hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
            hue_shift = random.randint(-10, 10)
            hsv[..., 0] = (hsv[..., 0] + hue_shift) % 180
            sat_gain = random.uniform(0.7, 1.3)
            hsv[..., 1] = np.clip(hsv[..., 1] * sat_gain, 0, 255)
            val_gain = random.uniform(0.7, 1.3)
            hsv[..., 2] = np.clip(hsv[..., 2] * val_gain, 0, 255)
            img = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
        return img, msk

    def __getitem__(self, i):
        img_path = self.paths[i]
        mask_path = self.mask_dir / (img_path.stem + ".png")
        img = cv2.imread(str(img_path))
        msk = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

        # If either file cannot be read, substitute a black image whose mask is
        # entirely IGNORE_INDEX, so the sample contributes nothing to the loss.
        if img is None or msk is None:
            img = np.zeros((H_IN, W_IN, 3), dtype=np.uint8)
            msk = np.full((H_IN, W_IN), IGNORE_INDEX, dtype=np.uint8)

        if self.augment:
            img, msk = self._augment(img, msk)

        img = cv2.resize(img, (W_IN, H_IN))
        # Nearest-neighbour keeps label ids intact instead of blending them.
        msk = cv2.resize(msk, (W_IN, H_IN), interpolation=cv2.INTER_NEAREST)

        # v2: class 7 (background) means "stuff the model can't recognize", not
        # a real class, so old-source masks map it to IGNORE_INDEX.
        # New-source masks are left untouched (per the original author they
        # already use 255 for non-tree pixels — not verifiable from this file).
        if self.source == "old":
            msk[msk == 7] = IGNORE_INDEX

        # BGR -> RGB, HWC -> CHW, uint8 -> float32 in [0, 1].
        rgb = img[..., ::-1]
        chw = np.transpose(rgb, (2, 0, 1)).astype(np.float32) / 255.0
        return torch.from_numpy(chw).float(), torch.from_numpy(msk).long()
# ─── temporal split ───
# Old source: split by frame number. Frames up to 4500 train, later frames
# validate. A stem that does not match frame_<n> gets -1 from frame_num and
# therefore lands in the training list.
old_all = sorted(OLD_IMG.glob("*.jpg"))
old_train, old_val = [], []
for _path in old_all:
    (old_train if frame_num(_path) <= 4500 else old_val).append(_path)

# New source: shuffled split; validation takes the first tenth of the shuffled
# list, but never fewer than 20 images (or all of them if fewer exist).
new_all = sorted(NEW_IMG.glob("*.jpg"))
random.shuffle(new_all)
n_new_val = max(20, len(new_all) // 10)
new_val, new_train = new_all[:n_new_val], new_all[n_new_val:]

# Training mixes both sources with augmentation; validation sets stay separate
# and un-augmented.
train_ds = ConcatDataset(
    [
        OrchardDS(old_train, OLD_MSK, augment=True, source="old"),
        OrchardDS(new_train, NEW_MSK, augment=True, source="new"),
    ]
)
old_val_ds = OrchardDS(old_val, OLD_MSK, augment=False, source="old")
new_val_ds = OrchardDS(new_val, NEW_MSK, augment=False, source="new")

print("=== TwinLiteNet8 (single-branch, 8-class) ===")
print(f" old train: {len(old_train)} new train: {len(new_train)}")
print(f" old val: {len(old_val)} new val: {len(new_val)}")
# ─── eval ───
def confusion(preds, ys, n, ignore=IGNORE_INDEX):
    """Build an ``n`` x ``n`` confusion matrix from predicted and true labels.

    Args:
        preds: array of predicted class ids, same shape as ``ys``.
        ys: array of ground-truth class ids.
        n: number of classes; only labels in ``[0, n)`` are counted.
        ignore: ground-truth value whose pixels are excluded entirely.

    Returns:
        int64 array of shape ``(n, n)`` where entry ``[t, p]`` is the number of
        pixels with ground truth ``t`` predicted as ``p``.
    """
    cm = np.zeros((n, n), dtype=np.int64)
    preds = np.asarray(preds)
    ys = np.asarray(ys)
    valid = ys != ignore
    if not valid.any():
        return cm
    # Widen to int64 so ``t * n + p`` cannot overflow narrow label dtypes.
    t = ys[valid].astype(np.int64)
    p = preds[valid].astype(np.int64)
    # Labels outside [0, n) are skipped, as with the per-class loops this replaces.
    in_range = (t >= 0) & (t < n) & (p >= 0) & (p < n)
    # One pass: encode each (truth, pred) pair as a flat index and histogram it.
    counts = np.bincount(t[in_range] * n + p[in_range], minlength=n * n)
    cm += counts.reshape(n, n)
    return cm
def iou_from_cm(cm):
    """Return per-class IoU from a square confusion matrix.

    For class ``c`` the IoU is ``tp / (tp + fp + fn)``, with ``tp`` the diagonal
    entry, ``fp`` the rest of column ``c`` and ``fn`` the rest of row ``c``.
    Classes with an empty union get NaN.
    """
    tp = np.diag(cm)
    # column sum + row sum counts tp twice, so subtract it once: tp + fp + fn.
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    ious = np.full(cm.shape[0], np.nan)
    np.divide(tp, union, out=ious, where=union > 0)
    return ious
# ─── train ───
# Run log file inside the output directory.
log_path = OUT_DIR / "log.txt"


def log(m):
    """Print *m* to stdout (flushed) and append it as one line to ``log_path``."""
    print(m, flush=True)
    with open(log_path, "a", encoding="utf-8") as fh:
        fh.write(m + "\n")
def _predict_classes(model, x):
    """Return per-pixel argmax class ids with the background channel suppressed.

    Assumes ``model(x)`` returns a single logits tensor indexed as
    (batch, class, height, width), as the channel indexing below requires.
    """
    logits = model(x)
    # Never predict background: the loader maps it to IGNORE_INDEX and its loss
    # weight is 0, so that channel is untrained.
    logits[:, NAMES.index("background"), :, :] = -1e9
    return logits.argmax(1)


def _evaluate(model, old_val_loader, new_val_loader, device):
    """Run validation on both held-out sets.

    Returns:
        ``(iou_old, tree_recall_new)``: the per-class IoU array over the
        old-source validation set, and tree (class 0) recall over the
        new-source validation set (NaN when it has no tree pixels).
    """
    cm_old = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
    tree_tp = tree_fn = 0
    model.eval()
    with torch.no_grad():
        for x, y in old_val_loader:
            preds = _predict_classes(model, x.to(device, non_blocking=True))
            cm_old += confusion(preds.cpu().numpy(), y.numpy(), NUM_CLASSES)
        for x, y in new_val_loader:
            preds = _predict_classes(model, x.to(device, non_blocking=True)).cpu().numpy()
            # Only tree recall is measured on the new source.
            is_tree = y.numpy() == 0
            tree_tp += int(((preds == 0) & is_tree).sum())
            tree_fn += int(((preds != 0) & is_tree).sum())
    tree_total = tree_tp + tree_fn
    tree_recall_new = tree_tp / tree_total if tree_total > 0 else float("nan")
    return iou_from_cm(cm_old), tree_recall_new


def _benchmark_fps(model, device, warmup=20, iters=200):
    """Return batch-1 forward passes per second at the training resolution."""
    # CUDA kernels run asynchronously, so the clock is only read after a sync;
    # on other devices there is nothing to wait for.
    sync = torch.cuda.synchronize if device.type == "cuda" else (lambda: None)
    x = torch.randn(1, 3, H_IN, W_IN, device=device)
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):
            model(x)
        sync()
        t0 = time.perf_counter()
        for _ in range(iters):
            model(x)
        sync()
        elapsed = time.perf_counter() - t0
    return iters / elapsed


def main():
    """Train TwinLiteNet8, validate each epoch, checkpoint, then benchmark FPS.

    Writes ``log.txt``, ``twinlite8_last.pt``, ``twinlite8_best.pt`` (best
    old-source tree IoU) and ``history.json`` into ``OUT_DIR``.

    Raises:
        RuntimeError: if the training set is smaller than one full batch, so
            the drop_last loader would yield no batches.
    """
    device = torch.device(DEVICE)
    log_path.write_text("")  # start every run with an empty log

    # persistent_workers is only valid when worker processes exist.
    train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=True, drop_last=True,
                              persistent_workers=NUM_WORKERS > 0)
    old_val_loader = DataLoader(old_val_ds, batch_size=BATCH, shuffle=False,
                                num_workers=2, pin_memory=True, persistent_workers=True)
    new_val_loader = DataLoader(new_val_ds, batch_size=BATCH, shuffle=False,
                                num_workers=2, pin_memory=True, persistent_workers=True)
    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        raise RuntimeError(
            f"training set has {len(train_ds)} samples, fewer than one full batch of {BATCH}"
        )

    model = TwinLiteNet8(num_classes=NUM_CLASSES).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    log(f"model: TwinLiteNet8 params: {n_params/1e6:.3f}M")
    log(f"input: {W_IN}x{H_IN} batch: {BATCH} epochs: {EPOCHS} LR: {LR}")
    log(f"classes: {NAMES}")
    log(f"weights: {dict(zip(NAMES, [round(float(w),2) for w in WEIGHTS]))}")
    log(f"train: {len(train_ds)} old_val: {len(old_val_ds)} new_val: {len(new_val_ds)}")

    cw = torch.tensor(WEIGHTS, dtype=torch.float32, device=device)
    loss_fn = nn.CrossEntropyLoss(weight=cw, ignore_index=IGNORE_INDEX)
    optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    # The scheduler is stepped per batch, so the cosine spans every training step.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS * steps_per_epoch)

    best_tree = -1.0
    history = []
    for epoch in range(1, EPOCHS + 1):
        model.train()
        t0 = time.time()
        ep_loss = 0.0
        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            loss = loss_fn(model(x), y)
            optim.zero_grad()
            loss.backward()
            optim.step()
            sched.step()
            ep_loss += loss.item()
        train_loss = ep_loss / steps_per_epoch

        iou_old, tree_recall_new = _evaluate(model, old_val_loader, new_val_loader, device)
        # Mean over the first 7 classes only; background (index 7) is excluded.
        miou_7 = float(np.nanmean(iou_old[:7]))
        tree_old = float(iou_old[0])
        ground_old = float(iou_old[1])
        elapsed = time.time() - t0

        log(f"epoch {epoch:02d}/{EPOCHS} loss={train_loss:.4f} "
            f"mIoU(7)={miou_7:.3f} tree_old={tree_old:.3f} ground_old={ground_old:.3f} "
            f"tree_new_recall={tree_recall_new:.3f} ({elapsed:.0f}s)")
        log(f" per-class IoU: " + ", ".join(f"{n}={v:.3f}" for n, v in zip(NAMES, iou_old)))
        history.append({
            "epoch": epoch, "loss": float(train_loss),
            "miou_7": miou_7, "tree_iou_old": tree_old, "ground_iou_old": ground_old,
            "tree_recall_new": float(tree_recall_new),
            "per_class_iou": {n: float(v) for n, v in zip(NAMES, iou_old)},
        })

        # The same payload serves both the rolling and the best checkpoint.
        ckpt = {"model": model.state_dict(), "epoch": epoch,
                "tree_iou_old": tree_old, "miou_7": miou_7,
                "tree_recall_new": float(tree_recall_new)}
        torch.save(ckpt, OUT_DIR / "twinlite8_last.pt")
        # A NaN tree IoU never compares greater, so it cannot become "best".
        if tree_old > best_tree:
            best_tree = tree_old
            torch.save(ckpt, OUT_DIR / "twinlite8_best.pt")
            log(f" saved best (tree_old {tree_old:.3f})")
        # Rewritten every epoch so an interrupted run keeps its history so far.
        (OUT_DIR / "history.json").write_text(json.dumps(history, indent=2))

    log(f"\n=== DONE === best tree_old IoU: {best_tree:.3f}")

    # ─── FPS benchmark ───
    log(f"\n=== FPS BENCHMARK (RTX 3080, batch=1, 640x360) ===")
    fps = _benchmark_fps(model, device)
    log(f" TwinLiteNet8 @ 640x360 batch=1: {fps:.1f} FPS")
    log(f" Jetson Orin Nano estimate: ~{fps/4:.0f}-{fps/3:.0f} FPS")


if __name__ == "__main__":
    main()