| |
| """ |
| End-to-end training script for FSD-Level5-CoT on SADC driving data. |
| |
| This script: |
1. Downloads a subset of the SADC dataset (streaming → disk)
| 2. Builds the FSD model from fsd_model/ |
| 3. Trains end-to-end with gradient accumulation, warmup, eval, logging |
| 4. Pushes the trained model to Hugging Face Hub |
| |
| Dataset: jHaselberger/SADC-Situation-Awareness-for-Driver-Centric-Driving-Style-Adaptation |
| Model: Reality123b/FSD-Level5-CoT |
| |
| Usage: |
| # Default (5000 train, 1000 val, 5 epochs) |
| python train_sadc_e2e.py |
| |
| # Custom |
| python train_sadc_e2e.py --train_samples 10000 --val_samples 2000 --epochs 10 --batch_size 4 |
| |
| # Quick test run |
| python train_sadc_e2e.py --train_samples 100 --val_samples 50 --epochs 1 |
| """ |
|
|
| import os |
| import sys |
| import time |
| import json |
| import math |
| import argparse |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| import numpy as np |
|
|
|
|
| |
| |
| |
|
|
# Hugging Face identifiers: the source dataset and the destination model repo.
DATASET_NAME = "jHaselberger/SADC-Situation-Awareness-for-Driver-Centric-Driving-Style-Adaptation"
HUB_MODEL_ID = "Reality123b/FSD-Level5-CoT"

# Geometry shared by the dataset wrapper and the model builder.
BEV_SIZE = 100  # the BEV grid is BEV_SIZE x BEV_SIZE cells
BEV_FEATURE_DIM = 128
PLANNING_D_MODEL = 128
IMG_H = 120
IMG_W = 160
NUM_WAYPOINTS = 20
COT_ACTOR_QUERIES = 32
COT_ROAD_QUERIES = 16

# Cap applied to the target speed of generated waypoints: 20 mph in m/s.
_MPH_TO_MS = 0.44704
MAX_SPEED_MS = 20.0 * _MPH_TO_MS

# SADC ``road_type`` string -> integer navigation command (position in the tuple).
_ROAD_TYPES = ("misc", "rural", "federal", "highway", "city", "parking", "intersection")
ROAD_TYPE_MAP = {road_type: index for index, road_type in enumerate(_ROAD_TYPES)}
|
|
|
|
| |
| |
| |
|
|
def download_sadc_subset(train_samples, val_samples, output_dir, train_split, val_split):
    """Fetch (or reuse) a small on-disk subset of the SADC dataset.

    The first ``train_samples`` rows of ``train_split`` and the first
    ``val_samples`` rows of ``val_split`` are streamed from the Hub and saved
    with ``save_to_disk`` under ``output_dir/train`` and ``output_dir/val``.

    A previously saved subset is reused only when it holds at least the
    requested number of rows in both splits; it is then truncated to exactly
    the requested size. Otherwise both splits are streamed again and the cache
    is overwritten, so a larger ``--train_samples`` is never silently ignored.

    NOTE(review): the cache does not record which splits it was built from,
    and a split with fewer rows than requested will be re-streamed every run.

    Args:
        train_samples: Maximum number of training rows to keep.
        val_samples: Maximum number of validation rows to keep.
        output_dir: Directory holding the cached ``train``/``val`` subsets.
        train_split: Dataset split to stream training rows from.
        val_split: Dataset split to stream validation rows from.

    Returns:
        ``(train_ds, val_ds)`` as Hugging Face ``datasets.Dataset`` objects.
    """
    from datasets import load_from_disk

    os.makedirs(output_dir, exist_ok=True)
    train_path = os.path.join(output_dir, "train")
    val_path = os.path.join(output_dir, "val")

    if os.path.exists(train_path) and os.path.exists(val_path):
        cached_train = load_from_disk(train_path)
        cached_val = load_from_disk(val_path)
        if len(cached_train) >= train_samples and len(cached_val) >= val_samples:
            print(f"[Download] Found existing subset at {output_dir}, skipping download.")
            return (
                cached_train.select(range(max(0, train_samples))),
                cached_val.select(range(max(0, val_samples))),
            )
        # A smaller cache (e.g. from an earlier run with fewer samples) would
        # silently shrink this run, so fetch the requested amount instead.
        print(
            f"[Download] Existing subset at {output_dir} is too small "
            f"({len(cached_train)} train / {len(cached_val)} val); re-downloading."
        )
        del cached_train, cached_val

    train_ds = _stream_subset(train_split, train_samples, train_path, "train", log_every=1000)
    val_ds = _stream_subset(val_split, val_samples, val_path, "val", log_every=500)
    return train_ds, val_ds


def _stream_subset(split, num_samples, save_path, label, log_every):
    """Stream the first ``num_samples`` rows of ``split``, save them to ``save_path`` and return them."""
    from itertools import islice

    from datasets import load_dataset, Dataset as HFDataset

    print(f"[Download] Streaming {num_samples} {label} samples from '{split}'...")
    stream = load_dataset(DATASET_NAME, split=split, streaming=True)
    rows = []
    # islice stops after exactly num_samples rows without pulling an extra one.
    for row in islice(stream, max(0, num_samples)):
        rows.append(row)
        if len(rows) % log_every == 0:
            print(f" ... {len(rows)}/{num_samples}")
    subset = HFDataset.from_list(rows)
    subset.save_to_disk(save_path)
    print(f" Saved {len(subset)} {label} samples.")
    return subset
|
|
|
|
| |
| |
| |
|
|
class SADCDrivingDataset(Dataset):
    """Map SADC Hugging Face rows to FSD model ``(inputs, targets)`` pairs.

    Each row supplies (at most) one camera ``frame`` plus scalar driving
    signals. Everything else the model consumes -- six camera views, camera
    calibration, ultrasonic readings -- and the dense BEV targets are
    synthesized here from that frame, those signals and a fixed sensor rig.
    """

    # ImageNet channel statistics used to normalize the frame.
    _IMAGE_MEAN = (0.485, 0.456, 0.406)
    _IMAGE_STD = (0.229, 0.224, 0.225)
    _NUM_CAMERAS = 6
    # Yaw (degrees) applied to each synthetic camera's extrinsic, in camera order.
    _CAMERA_YAWS_DEG = (-45, 45, -135, 135, -90, 90)
    _NUM_ULTRASONIC = 20

    def __init__(self, hf_dataset, img_size=(IMG_H, IMG_W)):
        """
        Args:
            hf_dataset: Indexable dataset whose rows are dicts (e.g. ``datasets.Dataset``).
            img_size: ``(height, width)`` every camera frame is resized to.
        """
        self.ds = hf_dataset
        self.img_h, self.img_w = img_size
        # Created on first use so torchvision is only required when frames exist.
        self._transform = None
        # The sensor rig does not depend on the row, so build it once instead
        # of on every __getitem__ call.
        self._intrinsics = self._build_intrinsics()
        self._extrinsics = self._build_extrinsics()
        self._us_placements = self._build_us_placements()

    def __len__(self):
        return len(self.ds)

    # ------------------------------------------------------------------ rig

    def _build_intrinsics(self):
        """Pinhole intrinsics shared by all cameras: f = 200 px, principal point at the image centre."""
        K = torch.zeros(self._NUM_CAMERAS, 3, 3)
        K[:, 0, 0] = 200.0
        K[:, 1, 1] = 200.0
        K[:, 0, 2] = self.img_w / 2
        K[:, 1, 2] = self.img_h / 2
        K[:, 2, 2] = 1.0
        return K

    def _build_extrinsics(self):
        """4x4 poses: identity plus a rotation in the first two axes by each camera's yaw."""
        E = torch.eye(4).repeat(self._NUM_CAMERAS, 1, 1)
        for i, yaw_deg in enumerate(self._CAMERA_YAWS_DEG):
            yaw = math.radians(yaw_deg)
            cos_y, sin_y = math.cos(yaw), math.sin(yaw)
            E[i, 0, 0] = cos_y
            E[i, 0, 1] = -sin_y
            E[i, 1, 0] = sin_y
            E[i, 1, 1] = cos_y
        return E

    def _build_us_placements(self):
        """Placement rows (6 values each) for the 20 synthetic ultrasonic sensors.

        Indices 0-6 sit at x = +2.25, 7-13 at x = -2.25, 14-16 at y = +0.9 and
        17-19 at y = -0.9. Columns presumably are (x, y, z, yaw_deg, 0, 0) --
        TODO confirm against ``fsd_model``.
        """
        placements = torch.zeros(self._NUM_ULTRASONIC, 6)
        for i in range(7):
            offset = i - 3
            placements[i] = torch.tensor([2.25, offset * 0.3, 0.4, offset * 10, 0, 0])
            placements[7 + i] = torch.tensor([-2.25, offset * 0.3, 0.4, 180 + offset * 10, 0, 0])
        for i in range(3):
            placements[14 + i] = torch.tensor([(1 - i) * 1.0, 0.9, 0.6, -90, 0, 0])
            placements[17 + i] = torch.tensor([(1 - i) * 1.0, -0.9, 0.6, 90, 0, 0])
        return placements

    # ------------------------------------------------------------- per row

    @staticmethod
    def _signal(row, key):
        """Read scalar ``row[key]`` as a float; missing, null or non-finite values become 0.0.

        ``row.get(key, 0.0)`` alone is not enough: a present-but-null value
        crashed in ``float(None)`` and NaNs flowed into targets and the loss.
        """
        value = row.get(key)
        if value is None:
            return 0.0
        value = float(value)
        return value if math.isfinite(value) else 0.0

    def _load_frame(self, img):
        """Return the frame as a normalized ``(3, img_h, img_w)`` tensor, or zeros if unusable."""
        if img is None:
            return torch.zeros(3, self.img_h, self.img_w)
        if self._transform is None:
            from torchvision import transforms

            self._transform = transforms.Compose([
                transforms.Resize((self.img_h, self.img_w)),
                transforms.ToTensor(),
                transforms.Normalize(list(self._IMAGE_MEAN), list(self._IMAGE_STD)),
            ])
        try:
            if hasattr(img, "convert"):
                img = img.convert("RGB")
            return self._transform(img)
        except Exception:
            # Best effort: one undecodable frame should not abort the epoch.
            return torch.zeros(3, self.img_h, self.img_w)

    def _ultrasonic_distances(self, lane_center):
        """Synthesize 20 noisy ultrasonic readings, clamped to [0.3, 5.0].

        Sensors 14-19 are centred on ``max(0.5, |lane_center|)``; the others on
        fixed values (3.0 for 0-6, 3.5 for 7-13).
        """
        base = max(0.5, abs(lane_center))
        distances = torch.empty(self._NUM_ULTRASONIC, 1)
        distances[:7] = torch.clamp(torch.randn(7, 1) * 0.5 + 3.0, 0.3, 5.0)
        distances[7:14] = torch.clamp(torch.randn(7, 1) * 0.5 + 3.5, 0.3, 5.0)
        distances[14:17] = torch.clamp(base + torch.randn(3, 1) * 0.2, 0.3, 5.0)
        distances[17:20] = torch.clamp(base + torch.randn(3, 1) * 0.2, 0.3, 5.0)
        return distances

    @staticmethod
    def _waypoints(speed_ms, lane_center, curvature):
        """Constant-speed rollout of NUM_WAYPOINTS points at time steps dt = 0.5, 1.0, ...

        Columns: speed * dt; lateral offset ramped to -lane_center over the
        first 3 time units; curvature * speed * dt (presumably heading); target
        speed capped at MAX_SPEED_MS.
        """
        waypoints = torch.zeros(NUM_WAYPOINTS, 4)
        capped_speed = min(speed_ms, MAX_SPEED_MS)
        for t in range(NUM_WAYPOINTS):
            dt = (t + 1) * 0.5
            waypoints[t, 0] = speed_ms * dt
            waypoints[t, 1] = -lane_center * min(1.0, dt / 3.0)
            waypoints[t, 2] = curvature * speed_ms * dt
            waypoints[t, 3] = capped_speed
        return waypoints

    @staticmethod
    def _behavior_label(steering, ax, speed_ms):
        """Heuristic behavior class id (ids are interpreted by ``fsd_model``, not visible here)."""
        if abs(steering) > 0.3:
            return 1 if steering > 0 else 2
        if abs(ax) < 0.1 and speed_ms < 0.5:
            return 5
        return 0

    @staticmethod
    def _bev_targets():
        """Placeholder BEV targets: seg class 1 in the middle half, occupancy in the outer quarters, empty heatmap."""
        bev = BEV_SIZE
        seg = torch.zeros(bev, bev, dtype=torch.long)
        seg[bev // 4 : 3 * bev // 4, :] = 1
        heatmap = torch.zeros(10, bev, bev)
        occupancy = torch.zeros(1, bev, bev)
        occupancy[:, : bev // 4, :] = 1.0
        occupancy[:, 3 * bev // 4 :, :] = 1.0
        return seg, heatmap, occupancy

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` dicts of tensors for row ``idx``."""
        row = self.ds[idx]

        # Replicate the single frame to every camera; views after the first get
        # small Gaussian noise so they are not exact copies.
        frame = self._load_frame(row.get("frame", None))
        camera_images = frame.unsqueeze(0).repeat(self._NUM_CAMERAS, 1, 1, 1)
        camera_images[1:] += torch.randn_like(camera_images[1:]) * 0.01

        speed_ms = self._signal(row, "v_kmph") / 3.6  # km/h -> m/s
        ax = self._signal(row, "ax_mpss")
        steering = self._signal(row, "steering_rack_pos_m")
        yaw_rate = self._signal(row, "yaw_rate_radps")
        lane_center = self._signal(row, "d_lanecenter_m")
        curvature = self._signal(row, "lane_curvature_radpm")

        ego_state = torch.tensor(
            [speed_ms, ax, steering, yaw_rate, 0.0, lane_center], dtype=torch.float32
        )
        # Unknown or missing road types map to 0 ("misc").
        nav_cmd = ROAD_TYPE_MAP.get(str(row.get("road_type", "misc")), 0)

        gt_seg, gt_heatmap, gt_occ = self._bev_targets()

        inputs = {
            "camera_images": camera_images,
            # Clone the shared rig tensors so no two samples alias the same storage.
            "camera_intrinsics": self._intrinsics.clone(),
            "camera_extrinsics": self._extrinsics.clone(),
            "ultrasonic_distances": self._ultrasonic_distances(lane_center),
            "ultrasonic_placements": self._us_placements.clone(),
            "ego_state": ego_state,
            "nav_command": torch.tensor(nav_cmd, dtype=torch.long),
        }
        targets = {
            "gt_steering": torch.tensor(steering * 20.0),
            "gt_throttle": torch.tensor(max(0.0, ax / 3.0)).clamp(0, 1),
            "gt_brake": torch.tensor(max(0.0, -ax / 8.0)).clamp(0, 1),
            "gt_waypoints": self._waypoints(speed_ms, lane_center, curvature),
            "gt_behavior": torch.tensor(
                self._behavior_label(steering, ax, speed_ms), dtype=torch.long
            ),
            "gt_segmentation": gt_seg,
            "gt_heatmap": gt_heatmap,
            "gt_occupancy": gt_occ,
        }
        return inputs, targets
|
|
|
|
def collate_fn(batch):
    """Stack a list of ``(inputs, targets)`` samples into batched dicts.

    Every key of the first sample's dicts is stacked along a new leading
    batch dimension; key order follows the first sample.
    """
    input_dicts, target_dicts = zip(*batch)

    def stack_by_key(dicts):
        keys = list(dicts[0])
        batched = {}
        for key in keys:
            batched[key] = torch.stack([sample[key] for sample in dicts])
        return batched

    return stack_by_key(input_dicts), stack_by_key(target_dicts)
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def evaluate(model, loss_fn, val_loader, device, max_batches=50):
    """Return the mean total validation loss over at most ``max_batches`` batches.

    Puts ``model`` in eval mode and leaves it there; callers that keep training
    must call ``model.train()`` afterwards.

    Batches that raise an out-of-memory ``RuntimeError`` are skipped. Any other
    ``RuntimeError`` (e.g. a shape mismatch) propagates instead of being
    silently dropped, matching the training loop's handling.

    Args:
        model: Module called as ``model(**inputs)``.
        loss_fn: Callable ``loss_fn(output, targets)`` returning a dict with a ``"total"`` tensor.
        val_loader: Iterable of ``(inputs, targets)`` dicts of tensors.
        device: Device each batch is moved to.
        max_batches: Maximum number of batches to evaluate.

    Returns:
        Mean of the per-batch ``"total"`` losses as a float, or ``inf`` if no
        batch was evaluated.
    """
    model.eval()
    losses = []
    for i, (inputs, targets) in enumerate(val_loader):
        if i >= max_batches:
            break
        inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}
        targets = {k: v.to(device, non_blocking=True) for k, v in targets.items()}
        try:
            output = model(**inputs)
            losses.append(loss_fn(output, targets)["total"].item())
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return float(np.mean(losses)) if losses else float("inf")
|
|
|
|
def train(args, train_ds, val_ds):
    """Build the FSD model and train it on the SADC subsets.

    Uses AdamW with a OneCycle LR schedule and gradient accumulation over
    ``args.grad_accum`` micro-batches. Validation runs every
    ``args.eval_every`` optimizer steps and after every epoch. The best
    checkpoint is saved to ``<save_dir>/best`` and the last one to
    ``<save_dir>/final``. Run metadata is written to
    ``<save_dir>/training_meta.json``, and with ``args.push_to_hub`` the
    checkpoint and metadata are uploaded to ``args.hub_model_id``.

    Args:
        args: Parsed command-line arguments (see ``parse_args``).
        train_ds: Dataset of training rows (wrapped by ``SADCDrivingDataset``).
        val_ds: Dataset of validation rows.

    Returns:
        The best validation loss observed (``inf`` if none was computed).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n[Train] Device: {device}")
    if device.type == "cuda":
        print(f" GPU: {torch.cuda.get_device_name()}")
        # The device-properties attribute is ``total_memory`` (not ``total_mem``).
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    tracker = _init_trackio()
    train_loader, val_loader = _build_loaders(args, train_ds, val_ds, device)
    model, loss_fn = _build_model_and_loss(device)
    total_params = model.count_parameters()["total"]
    print(f" Total parameters: {total_params:,}")

    # The learnable loss weights are optimized together with the model.
    all_params = list(model.parameters()) + list(loss_fn.parameters())
    optimizer = torch.optim.AdamW(all_params, lr=args.lr, weight_decay=args.weight_decay)

    # One optimizer step per full accumulation window plus one for a trailing
    # partial window, so this is an exact upper bound on scheduler.step() calls.
    steps_per_epoch = math.ceil(len(train_loader) / args.grad_accum)
    total_steps = max(1, steps_per_epoch * args.epochs)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        total_steps=total_steps,
        pct_start=0.1,
        anneal_strategy="cos",
    )

    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()

    effective_batch = args.batch_size * args.grad_accum
    print(f"\n[Train] Starting: {args.epochs} epochs, effective batch={effective_batch}")
    print(f" Total optimizer steps: ~{total_steps}")

    global_step = 0
    best_val_loss = float("inf")
    t0 = time.time()

    for epoch in range(args.epochs):
        model.train()
        epoch_losses = []
        optimizer.zero_grad()
        pending = 0  # micro-batches accumulated since the last optimizer step

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}
            targets = {k: v.to(device, non_blocking=True) for k, v in targets.items()}

            try:
                output = model(**inputs)
                losses = loss_fn(output, targets)
                # backward() belongs inside the guard: that is where OOM usually hits.
                (losses["total"] / args.grad_accum).backward()
            except RuntimeError as e:
                if "out of memory" not in str(e):
                    raise
                # An OOM part-way through backward() can leave half-accumulated
                # gradients, so drop the whole accumulation window.
                optimizer.zero_grad()
                pending = 0
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                print(f" OOM at batch {batch_idx}, skipping")
                continue

            pending += 1
            stepped = pending == args.grad_accum
            if stepped:
                _apply_update(all_params, optimizer, scheduler, args.max_grad_norm)
                pending = 0
                global_step += 1

            epoch_losses.append(losses["total"].item())

            if (batch_idx + 1) % args.log_every == 0:
                elapsed = time.time() - t0
                lr = scheduler.get_last_lr()[0]
                avg_loss = np.mean(epoch_losses[-args.log_every:])
                ctrl = losses.get("control", torch.tensor(0.0)).item()
                traj = losses.get("trajectory", torch.tensor(0.0)).item()
                seg = losses.get("segmentation", torch.tensor(0.0)).item()
                safety = losses.get("safety", torch.tensor(0.0)).item()
                print(
                    f" [E{epoch+1}/{args.epochs}][{batch_idx+1}/{len(train_loader)}] "
                    f"loss={avg_loss:.4f} ctrl={ctrl:.4f} traj={traj:.4f} "
                    f"seg={seg:.4f} safety={safety:.4f} lr={lr:.2e} t={elapsed:.0f}s"
                )
                if tracker is not None:
                    tracker.log({
                        "train/loss": avg_loss,
                        "train/control_loss": ctrl,
                        "train/trajectory_loss": traj,
                        "train/segmentation_loss": seg,
                        "train/safety_loss": safety,
                        "train/lr": lr,
                        "train/epoch": epoch + batch_idx / len(train_loader),
                    })

            # Only evaluate right after an optimizer step; otherwise every
            # micro-batch of the next window would re-evaluate the same step.
            if stepped and global_step % args.eval_every == 0:
                val_loss = evaluate(model, loss_fn, val_loader, device)
                print(f" ── EVAL step {global_step}: val_loss={val_loss:.4f} (best={best_val_loss:.4f})")
                if tracker is not None:
                    tracker.log({"val/loss": val_loss, "val/step": global_step})
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    save_checkpoint(model, args.save_dir, "best")
                    print(f" ── Saved best model (val_loss={val_loss:.4f})")
                model.train()

        if pending:
            # Apply the trailing partial window instead of letting the next
            # epoch's zero_grad() discard it; rescale to a per-micro-batch mean.
            _apply_update(
                all_params, optimizer, scheduler, args.max_grad_norm,
                grad_scale=args.grad_accum / pending,
            )
            global_step += 1

        val_loss = evaluate(model, loss_fn, val_loader, device)
        avg_epoch_loss = np.mean(epoch_losses) if epoch_losses else float("inf")
        print(
            f"\n Epoch {epoch+1}/{args.epochs}: "
            f"train_loss={avg_epoch_loss:.4f} val_loss={val_loss:.4f}"
        )
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, args.save_dir, "best")
            print(f" ── New best model (val_loss={val_loss:.4f})")

    total_time = time.time() - t0
    print(f"\n{'='*60}")
    print(f"Training complete in {total_time/60:.1f} min")
    print(f"Best val loss: {best_val_loss:.4f}")
    save_checkpoint(model, args.save_dir, "final")

    meta = {
        "dataset": DATASET_NAME,
        "train_samples": len(train_ds),
        "val_samples": len(val_ds),
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "grad_accum": args.grad_accum,
        "lr": args.lr,
        # json would otherwise emit the non-standard token ``Infinity``.
        "best_val_loss": float(best_val_loss) if math.isfinite(best_val_loss) else None,
        "total_params": total_params,
        "training_time_min": total_time / 60,
        "device": str(device),
    }
    os.makedirs(args.save_dir, exist_ok=True)
    meta_path = os.path.join(args.save_dir, "training_meta.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    print(f" Metadata saved to {meta_path}")

    if args.push_to_hub:
        _push_to_hub(args, best_val_loss, meta_path)

    print("\nDone! ✓")
    return best_val_loss


def _init_trackio():
    """Initialise trackio experiment tracking; return the module, or ``None`` if unavailable."""
    try:
        import trackio

        trackio.init(project="fsd-level5-cot", name="sadc-e2e-training")
    except Exception as e:  # tracking is optional; any failure just disables it
        print(f" Trackio not available: {e}")
        return None
    print(" Trackio initialized ✓")
    return trackio


def _build_loaders(args, train_ds, val_ds, device):
    """Wrap the HF datasets in ``SADCDrivingDataset`` and build train/val DataLoaders."""
    # Pinned host memory only helps host->GPU copies.
    pin = device.type == "cuda"
    train_loader = DataLoader(
        SADCDrivingDataset(train_ds),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        pin_memory=pin,
        drop_last=True,
    )
    # Keep the last partial batch: dropping it wastes validation data and can
    # leave a small validation set with zero batches (val loss = inf).
    val_loader = DataLoader(
        SADCDrivingDataset(val_ds),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        pin_memory=pin,
        drop_last=False,
    )
    print(f" Train batches/epoch: {len(train_loader)}")
    print(f" Val batches: {len(val_loader)}")
    return train_loader, val_loader


def _build_model_and_loss(device):
    """Import ``fsd_model`` from this script's directory and build the model and loss on ``device``."""
    print("\n[Train] Building FSD model...")
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)

    from fsd_model.config import VehicleConfig
    from fsd_model.model import FullSelfDrivingModel, FSDLoss

    model = FullSelfDrivingModel(
        vehicle_config=VehicleConfig(),
        bev_size=BEV_SIZE,
        bev_resolution=0.5,
        bev_feature_dim=BEV_FEATURE_DIM,
        num_object_classes=10,
        num_seg_classes=7,
        num_waypoints=NUM_WAYPOINTS,
        planning_d_model=PLANNING_D_MODEL,
        future_steps=6,
        num_forecast_modes=6,
        forecast_steps=12,
        num_behaviors=10,
        enable_cot=True,
        cot_num_actor_queries=COT_ACTOR_QUERIES,
        cot_num_road_queries=COT_ROAD_QUERIES,
    ).to(device)

    loss_fn = FSDLoss(
        learnable_weights=True,
        w_detection=0.5,
        w_segmentation=1.0,
        w_occupancy=1.0,
        w_motion=0.5,
        w_behavior=1.0,
        w_trajectory=3.0,
        w_control=3.0,
        w_safety=2.0,
    ).to(device)
    return model, loss_fn


def _apply_update(params, optimizer, scheduler, max_grad_norm, grad_scale=1.0):
    """Optionally rescale gradients, clip them, step optimizer and scheduler, then clear gradients."""
    if grad_scale != 1.0:
        for p in params:
            if p.grad is not None:
                p.grad.mul_(grad_scale)
    torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


def _push_to_hub(args, best_val_loss, meta_path):
    """Upload the best (or, if none was saved, the final) checkpoint plus metadata to the Hub."""
    checkpoint_dir = os.path.join(args.save_dir, "best")
    if not os.path.isdir(checkpoint_dir):
        # No evaluation ever beat inf (e.g. empty validation set).
        checkpoint_dir = os.path.join(args.save_dir, "final")
    print(f"\n[Hub] Pushing model to {args.hub_model_id}...")
    try:
        from huggingface_hub import HfApi

        api = HfApi()
        api.upload_folder(
            folder_path=checkpoint_dir,
            repo_id=args.hub_model_id,
            path_in_repo="trained_model",
            commit_message=f"Trained model (best val_loss={best_val_loss:.4f})",
        )
        api.upload_file(
            path_or_fileobj=meta_path,
            path_in_repo="trained_model/training_meta.json",
            repo_id=args.hub_model_id,
        )
        print(f" ✓ Pushed to {args.hub_model_id}/trained_model")
    except Exception as e:
        # Best effort: a failed upload must not lose the local checkpoints.
        print(f" Push failed: {e}")
|
|
|
|
def save_checkpoint(model, save_dir, tag):
    """Persist ``model`` under the directory ``<save_dir>/<tag>`` (created if needed).

    Models that expose ``save_pretrained`` use it; anything else falls back to
    writing its ``state_dict`` to ``model.pt`` inside that directory.
    """
    checkpoint_dir = os.path.join(save_dir, tag)
    os.makedirs(checkpoint_dir, exist_ok=True)
    if not hasattr(model, "save_pretrained"):
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, "model.pt"))
        return
    model.save_pretrained(checkpoint_dir)
|
|
|
|
| |
| |
| |
|
|
def parse_args(argv=None):
    """Parse command-line options for the SADC training run.

    Args:
        argv: Argument list to parse; ``None`` (the default) reads ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with data, training, logging and output options.

    Exits (via argparse) if batch size, gradient accumulation, or logging/eval
    cadence is not a positive integer. Those values are used as divisors or
    moduli in the training loop, and 0 would crash it mid-run.
    """

    def positive_int(text):
        # argparse ``type`` accepting only integers >= 1.
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    p = argparse.ArgumentParser(
        description="End-to-end FSD-Level5-CoT training on SADC",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    data = p.add_argument_group("data")
    data.add_argument("--train_samples", type=int, default=5000, help="training rows to stream")
    data.add_argument("--val_samples", type=int, default=1000, help="validation rows to stream")
    data.add_argument("--train_split", type=str, default="pretrain_train", help="SADC split for training")
    data.add_argument("--val_split", type=str, default="pretrain_val", help="SADC split for validation")
    data.add_argument("--data_dir", type=str, default="./sadc_subset", help="local cache for the subset")

    optim = p.add_argument_group("training")
    optim.add_argument("--epochs", type=int, default=5, help="number of epochs")
    optim.add_argument("--batch_size", type=positive_int, default=8, help="micro-batch size")
    optim.add_argument("--grad_accum", type=positive_int, default=4, help="micro-batches per optimizer step")
    optim.add_argument("--lr", type=float, default=3e-4, help="peak learning rate (OneCycle)")
    optim.add_argument("--weight_decay", type=float, default=1e-4, help="AdamW weight decay")
    optim.add_argument("--max_grad_norm", type=float, default=5.0, help="gradient clipping norm")
    optim.add_argument("--num_workers", type=int, default=4, help="DataLoader worker processes")

    logging_group = p.add_argument_group("logging")
    logging_group.add_argument("--log_every", type=positive_int, default=10, help="log every N micro-batches")
    logging_group.add_argument("--eval_every", type=positive_int, default=500, help="evaluate every N optimizer steps")

    output = p.add_argument_group("output")
    output.add_argument("--save_dir", type=str, default="./checkpoints", help="checkpoint directory")
    # Pushing is on by default; --push_to_hub is kept for compatibility and
    # --no_push_to_hub is the way to disable it.
    output.add_argument("--push_to_hub", action="store_true", default=True, help="upload the result to the Hub")
    output.add_argument("--no_push_to_hub", action="store_false", dest="push_to_hub", help="do not upload to the Hub")
    output.add_argument("--hub_model_id", type=str, default=HUB_MODEL_ID, help="destination Hub repo")

    return p.parse_args(argv)
|
|
|
|
def main():
    """Parse arguments, fetch the SADC subset, train, and return the best validation loss."""
    args = parse_args()
    effective_batch = args.batch_size * args.grad_accum

    print("=" * 60)
    print(" FSD-Level5-CoT · End-to-End Training on SADC")
    print("=" * 60)
    print(f" Train samples: {args.train_samples}")
    print(f" Val samples: {args.val_samples}")
    print(f" Epochs: {args.epochs}")
    print(f" Batch size: {args.batch_size} × {args.grad_accum} accum = {effective_batch}")
    print(f" LR: {args.lr}")
    print(f" Push to Hub: {args.push_to_hub} → {args.hub_model_id}")
    print("=" * 60)

    train_ds, val_ds = download_sadc_subset(
        train_samples=args.train_samples,
        val_samples=args.val_samples,
        output_dir=args.data_dir,
        train_split=args.train_split,
        val_split=args.val_split,
    )

    return train(args, train_ds, val_ds)
|
|
|
|
if __name__ == "__main__":
    # Run the full download -> train -> push pipeline only when executed as a
    # script, not when this module is imported.
    main()
|
|