Long Nguyen commited on
Commit
d954131
·
verified ·
1 Parent(s): 5c324e4

Upload example.ipynb

Browse files
Files changed (1) hide show
  1. example.ipynb +144 -0
example.ipynb ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "299923dd",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import matplotlib.pyplot as plt\n",
11
+ "import numpy as np\n",
12
+ "import torch\n",
13
+ "import sys\n",
14
+ "import os\n",
15
+ "from huggingface_hub import hf_hub_download\n",
16
+ "from huggingface_hub import snapshot_download"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "markdown",
21
+ "id": "88f1cc80",
22
+ "metadata": {},
23
+ "source": [
24
+ "### Download Inference Code, Model Checkpoint and Example Data"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "id": "da55e6c8",
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "inference_path = hf_hub_download(repo_id=\"longpollehn/tfv6_navsim\", filename=\"ltfv6.py\")\n",
35
+ "config_path = hf_hub_download(repo_id=\"longpollehn/tfv6_navsim\", filename=\"config.json\")\n",
36
+ "model_path = hf_hub_download(repo_id=\"longpollehn/tfv6_navsim\", filename=\"model_0060.pth\")\n",
37
+ "data_path = snapshot_download(repo_id=\"longpollehn/tfv6_navsim\", allow_patterns=\"data/*\")\n",
38
+ "\n",
39
+ "sys.path.insert(0, os.path.dirname(inference_path))\n",
40
+ "\n",
41
+ "from ltfv6 import NavsimData, load_tf"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "markdown",
46
+ "id": "e940c04b",
47
+ "metadata": {},
48
+ "source": [
49
+ "### Load Model"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "afa5d128",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
60
+ "model = load_tf(model_path, device)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "markdown",
65
+ "id": "1441cf9b",
66
+ "metadata": {},
67
+ "source": [
68
+ "### Example data"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "id": "691ad60f",
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "dataset = NavsimData(root=data_path, config=model.config)\n",
79
+ "\n",
80
+ "sample_index = 5\n",
81
+ "data = dataset[sample_index]\n",
82
+ "data = torch.utils.data._utils.collate.default_collate([data])\n",
83
+ "data = {k: v.to(device) for k, v in data.items()}\n",
84
+ "\n",
85
+ "plt.imshow(np.transpose(data[\"rgb\"][0].cpu().numpy(), (1, 2, 0)).astype(np.uint8))\n",
86
+ "plt.show()\n",
87
+ "\n",
88
+ "print(\n",
89
+ " f\"Inputs: speed={data['speed'][0].item():.2f} m/s, acceleration={data['acceleration'][0].item():.2f} m/s², command={data['command'][0].cpu().numpy()}\"\n",
90
+ ")"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "id": "347c8ac1",
96
+ "metadata": {},
97
+ "source": [
98
+ "### Inference"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "id": "388fd167",
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "prediction = model(data)\n",
109
+ "waypoints = prediction.pred_future_waypoints\n",
110
+ "headings = prediction.pred_headings\n",
111
+ "\n",
112
+ "# Model was trained in CARLA coordinate system, convert to NavSim/NuPlan coordinate system\n",
113
+ "waypoints[:, :, 1] *= -1 # Invert Y axis\n",
114
+ "headings *= -1\n",
115
+ "\n",
116
+ "plt.plot(waypoints[0, :, 0].cpu().detach().numpy(), waypoints[0, :, 1].cpu().detach().numpy(), marker=\"o\")\n",
117
+ "plt.xlim(-32, 32)\n",
118
+ "plt.ylim(-32, 32)\n",
119
+ "plt.show()"
120
+ ]
121
+ }
122
+ ],
123
+ "metadata": {
124
+ "kernelspec": {
125
+ "display_name": "lead",
126
+ "language": "python",
127
+ "name": "python3"
128
+ },
129
+ "language_info": {
130
+ "codemirror_mode": {
131
+ "name": "ipython",
132
+ "version": 3
133
+ },
134
+ "file_extension": ".py",
135
+ "mimetype": "text/x-python",
136
+ "name": "python",
137
+ "nbconvert_exporter": "python",
138
+ "pygments_lexer": "ipython3",
139
+ "version": "3.10.15"
140
+ }
141
+ },
142
+ "nbformat": 4,
143
+ "nbformat_minor": 5
144
+ }