{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3c5e0b7a",
   "metadata": {},
   "source": [
    "# `tfv6_navsim` inference example\n",
    "\n",
    "Downloads the `ltfv6` inference module, a model checkpoint and example data from the Hugging Face Hub repository `longpollehn/tfv6_navsim`, runs the model on one sample, and plots the predicted future waypoints after converting them from the CARLA to the NAVSIM/nuPlan coordinate system.\n",
    "\n",
    "Requires network access to the Hugging Face Hub. A CUDA GPU is used when available, otherwise the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "299923dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from huggingface_hub import hf_hub_download, snapshot_download"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d2f41e9",
   "metadata": {},
   "source": [
    "### Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b81c6a05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face Hub repository that hosts the inference code, the checkpoint and the example data\n",
    "REPO_ID = \"longpollehn/tfv6_navsim\"\n",
    "CHECKPOINT_FILENAME = \"model_0060.pth\"\n",
    "\n",
    "# Index of the example sample (within the downloaded data) to run inference on\n",
    "SAMPLE_INDEX = 5\n",
    "\n",
    "# Half-width of the square waypoint plot window, in waypoint units (presumably metres - TODO confirm)\n",
    "PLOT_HALF_WIDTH = 32"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88f1cc80",
   "metadata": {},
   "source": [
    "### Download Inference Code, Model Checkpoint and Example Data\n",
    "\n",
    "`ltfv6.py` is downloaded rather than installed, so its download directory is added to `sys.path` before importing from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da55e6c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "inference_path = hf_hub_download(repo_id=REPO_ID, filename=\"ltfv6.py\")\n",
    "# config.json is not referenced below; presumably load_tf reads it from the same download directory (TODO confirm)\n",
    "config_path = hf_hub_download(repo_id=REPO_ID, filename=\"config.json\")\n",
    "model_path = hf_hub_download(repo_id=REPO_ID, filename=CHECKPOINT_FILENAME)\n",
    "data_path = snapshot_download(repo_id=REPO_ID, allow_patterns=\"data/*\")\n",
    "\n",
    "# Make the downloaded module importable; the guard keeps re-runs from adding duplicate sys.path entries\n",
    "inference_dir = os.path.dirname(inference_path)\n",
    "if inference_dir not in sys.path:\n",
    "    sys.path.insert(0, inference_dir)\n",
    "\n",
    "from ltfv6 import NavsimData, load_tf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e940c04b",
   "metadata": {},
   "source": [
    "### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afa5d128",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = load_tf(model_path, device)\n",
    "\n",
    "# Run dropout / batch-norm layers in inference mode (no effect if load_tf already did this)\n",
    "if isinstance(model, torch.nn.Module):\n",
    "    model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1441cf9b",
   "metadata": {},
   "source": [
    "### Example data\n",
    "\n",
    "Load one sample, collate it into a batch of size 1, and move its tensors to `device`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "691ad60f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = NavsimData(root=data_path, config=model.config)\n",
    "\n",
    "sample = dataset[SAMPLE_INDEX]\n",
    "# Collate a single sample into a batch of size 1 (public API, replaces torch.utils.data._utils.collate)\n",
    "batch_cpu = torch.utils.data.default_collate([sample])\n",
    "# Only tensors can be moved to the device; any other collated values are passed through unchanged\n",
    "batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch_cpu.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4a9e2d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The RGB tensor is channel-first, so move channels last for imshow; values are cast to uint8 for display\n",
    "rgb_image = batch[\"rgb\"][0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(rgb_image)\n",
    "ax.set_title(f\"Sample {SAMPLE_INDEX}: RGB input\")\n",
    "ax.axis(\"off\")\n",
    "plt.show()\n",
    "\n",
    "print(\n",
    "    f\"Inputs: speed={batch['speed'][0].item():.2f} m/s, acceleration={batch['acceleration'][0].item():.2f} m/s², command={batch['command'][0].cpu().numpy()}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "347c8ac1",
   "metadata": {},
   "source": [
    "### Inference\n",
    "\n",
    "The model was trained in the CARLA coordinate system. The predicted waypoints and headings are converted to the NAVSIM/nuPlan convention by negating the Y coordinate and the headings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "388fd167",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: no_grad avoids building an autograd graph (and the need for .detach() later)\n",
    "with torch.no_grad():\n",
    "    prediction = model(batch)\n",
    "\n",
    "# CARLA -> NAVSIM/nuPlan: negate Y and headings. Work on copies so `prediction` keeps the raw model output\n",
    "waypoints = prediction.pred_future_waypoints.clone()\n",
    "waypoints[:, :, 1] = -waypoints[:, :, 1]\n",
    "headings = -prediction.pred_headings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f03b18",
   "metadata": {},
   "outputs": [],
   "source": [
    "waypoints_np = waypoints[0].cpu().numpy()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(waypoints_np[:, 0], waypoints_np[:, 1], marker=\"o\")\n",
    "ax.set(\n",
    "    xlim=(-PLOT_HALF_WIDTH, PLOT_HALF_WIDTH),\n",
    "    ylim=(-PLOT_HALF_WIDTH, PLOT_HALF_WIDTH),\n",
    "    xlabel=\"x\",\n",
    "    ylabel=\"y\",\n",
    "    title=\"Predicted future waypoints (NAVSIM/nuPlan frame)\",\n",
    ")\n",
    "ax.set_aspect(\"equal\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lead",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}