Upload folder using huggingface_hub
Browse files- .gitattributes +3 -0
- LICENSE +21 -0
- README.md +342 -3
- demo/1.eng4th_semantic_navigation.gif +3 -0
- demo/1.lobby_semantic_navigation.gif +3 -0
- demo/3.cyc4th_semantic_navigation.gif +3 -0
- ros_deployment_ws/src/semantic_cnn_nav/cnn_msgs/CMakeLists.txt +28 -0
- ros_deployment_ws/src/semantic_cnn_nav/cnn_msgs/msg/CNN_data.msg +5 -0
- ros_deployment_ws/src/semantic_cnn_nav/cnn_msgs/package.xml +17 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/CMakeLists.txt +199 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/launch/cnn_data_pub.launch +11 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/launch/semantic_cnn_inference.launch +11 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/launch/semantic_cnn_nav_gazebo.launch +50 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/package.xml +64 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/__pycache__/cnn_model.cpython-38.pyc +0 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/__pycache__/custom_cnn_full.cpython-38.pyc +0 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/__pycache__/model.cpython-38.pyc +0 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/__pycache__/pfeiffer_model.cpython-38.pyc +0 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/cnn_data_pub.py +108 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/cnn_model.py +783 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/goal_visualize.py +40 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/model/s3_net_model.pth +3 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/model/semantic_cnn_model.pth +3 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/pure_pursuit.py +254 -0
- ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/semantic_cnn_nav_inference.py +256 -0
- training/model/semantic_cnn_model.pth +3 -0
- training/run_eval.sh +63 -0
- training/run_train.sh +73 -0
- training/scripts/__pycache__/model.cpython-38.pyc +0 -0
- training/scripts/decode_demo.py +201 -0
- training/scripts/model.py +475 -0
- training/scripts/train.py +385 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
demo/1.eng4th_semantic_navigation.gif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
demo/1.lobby_semantic_navigation.gif filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
demo/3.cyc4th_semantic_navigation.gif filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Temple Robotics and Artificial Intelligence Lab
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,342 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Semantic2D: Enabling Semantic Scene Understanding with 2D Lidar Alone
|
| 2 |
+
|
| 3 |
+
Semantic CNN Navigation implementation code for our paper ["Semantic2D: Enabling Semantic Scene Understanding with 2D Lidar Alone"](https://arxiv.org/pdf/2409.09899).
|
| 4 |
+
Video demos can be found at [multimedia demonstrations](https://youtu.be/P1Hsvj6WUSY).
|
| 5 |
+
The Semantic2D dataset can be found and downloaded at: https://doi.org/10.5281/zenodo.18350696.
|
| 6 |
+
|
| 7 |
+
## Related Resources
|
| 8 |
+
|
| 9 |
+
- **Dataset Download:** https://doi.org/10.5281/zenodo.18350696
|
| 10 |
+
- **SALSA (Dataset and Labeling Framework):** https://github.com/TempleRAIL/semantic2d
|
| 11 |
+
- **S³-Net (Stochastic Semantic Segmentation):** https://github.com/TempleRAIL/s3_net
|
| 12 |
+
- **Semantic CNN Navigation:** https://github.com/TempleRAIL/semantic_cnn_nav
|
| 13 |
+
|
| 14 |
+
## Overview
|
| 15 |
+
|
| 16 |
+
[](https://opensource.org/licenses/MIT)
|
| 17 |
+
|
| 18 |
+
This repository contains two main components:
|
| 19 |
+
1. **Training**: CNN-based control policy training using the Semantic2D dataset
|
| 20 |
+
2. **ROS Deployment**: Real-time semantic-aware navigation for mobile robots
|
| 21 |
+
|
| 22 |
+
The Semantic CNN Navigation system combines:
|
| 23 |
+
- **S³-Net**: Real-time semantic segmentation of 2D LiDAR scans
|
| 24 |
+
- **SemanticCNN**: ResNet-based control policy that uses semantic information for navigation
|
| 25 |
+
|
| 26 |
+
## Demo Results
|
| 27 |
+
|
| 28 |
+
**Engineering Lobby Semantic Navigation**
|
| 29 |
+

|
| 30 |
+
|
| 31 |
+
**Engineering 4th Floor Semantic Navigation**
|
| 32 |
+

|
| 33 |
+
|
| 34 |
+
**CYC 4th Floor Semantic Navigation**
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
## System Architecture
|
| 38 |
+
|
| 39 |
+
```
|
| 40 |
+
┌─────────────────────────────────────────────────────────────────────┐
|
| 41 |
+
│ Semantic CNN Navigation │
|
| 42 |
+
├─────────────────────────────────────────────────────────────────────┤
|
| 43 |
+
│ │
|
| 44 |
+
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │
|
| 45 |
+
│ │ LiDAR Scan │───▶│ S³-Net │───▶│ Semantic Labels (10) │ │
|
| 46 |
+
│ │ + Intensity│ │ Segmentation│ │ per LiDAR point │ │
|
| 47 |
+
│ └─────────────┘ └─────────────┘ └───────────┬─────────────┘ │
|
| 48 |
+
│ │ │
|
| 49 |
+
│ ┌─────────────┐ ▼ │
|
| 50 |
+
│ │ Sub-Goal │───────────────────────▶┌─────────────────────────┐ │
|
| 51 |
+
│ │ (x, y) │ │ SemanticCNN │ │
|
| 52 |
+
│ └─────────────┘ │ (ResNet + Bottleneck) │ │
|
| 53 |
+
│ │ │ │
|
| 54 |
+
│ ┌─────────────┐ │ Input: 80x80 scan map │ │
|
| 55 |
+
│ │ Scan Map │───────────────────────▶│ + semantic map │ │
|
| 56 |
+
│ │ (history) │ │ + sub-goal │ │
|
| 57 |
+
│ └─────────────┘ └───────────┬─────────────┘ │
|
| 58 |
+
│ │ │
|
| 59 |
+
│ ▼ │
|
| 60 |
+
│ ┌─────────────────────────┐ │
|
| 61 |
+
│ │ Velocity Command │ │
|
| 62 |
+
│ │ (linear_x, angular_z) │ │
|
| 63 |
+
│ └─────────────────────────┘ │
|
| 64 |
+
└─────────────────────────────────────────────────────────────────────┘
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## Requirements
|
| 68 |
+
|
| 69 |
+
### Training
|
| 70 |
+
- Python 3.7+
|
| 71 |
+
- PyTorch 1.7.1+
|
| 72 |
+
- TensorBoard
|
| 73 |
+
- NumPy
|
| 74 |
+
- tqdm
|
| 75 |
+
|
| 76 |
+
### ROS Deployment
|
| 77 |
+
- Ubuntu 20.04
|
| 78 |
+
- ROS Noetic
|
| 79 |
+
- Python 3.8.5
|
| 80 |
+
- PyTorch 1.7.1
|
| 81 |
+
|
| 82 |
+
Install training dependencies:
|
| 83 |
+
```bash
|
| 84 |
+
pip install torch torchvision tensorboardX numpy tqdm
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
# Part 1: Training
|
| 90 |
+
|
| 91 |
+
## Dataset Structure
|
| 92 |
+
|
| 93 |
+
The training expects the Semantic2D dataset organized as follows:
|
| 94 |
+
|
| 95 |
+
```
|
| 96 |
+
~/semantic2d_data/
|
| 97 |
+
├── dataset.txt # List of dataset folders
|
| 98 |
+
├── 2024-04-11-15-24-29/ # Dataset folder 1
|
| 99 |
+
│ ├── train.txt # Training sample list
|
| 100 |
+
│ ├── dev.txt # Validation sample list
|
| 101 |
+
│ ├── scans_lidar/ # Range scans (.npy)
|
| 102 |
+
│ ├── semantic_label/ # Semantic labels (.npy)
|
| 103 |
+
│ ├── sub_goals_local/ # Local sub-goals (.npy)
|
| 104 |
+
│ └── velocities/ # Ground truth velocities (.npy)
|
| 105 |
+
└── ...
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## Model Architecture
|
| 109 |
+
|
| 110 |
+
**SemanticCNN** uses a ResNet-style architecture with Bottleneck blocks:
|
| 111 |
+
|
| 112 |
+
| Component | Details |
|
| 113 |
+
|-----------|---------|
|
| 114 |
+
| **Input** | 2 channels: scan map (80x80) + semantic map (80x80) |
|
| 115 |
+
| **Backbone** | ResNet with Bottleneck blocks [2, 1, 1] |
|
| 116 |
+
| **Goal Input** | 2D sub-goal (x, y) concatenated after pooling |
|
| 117 |
+
| **Output** | 2D velocity (linear_x, angular_z) |
|
| 118 |
+
| **Loss** | MSE Loss |
|
| 119 |
+
|
| 120 |
+
**Key Parameters:**
|
| 121 |
+
- Sequence length: 10 frames
|
| 122 |
+
- Image size: 80x80
|
| 123 |
+
- LiDAR points: 1081 → downsampled to 720 (removing ±180 points)
|
| 124 |
+
|
| 125 |
+
## Training
|
| 126 |
+
|
| 127 |
+
Train the Semantic CNN model:
|
| 128 |
+
|
| 129 |
+
```bash
|
| 130 |
+
cd training
|
| 131 |
+
sh run_train.sh ~/semantic2d_data/ ~/semantic2d_data/
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
**Arguments:**
|
| 135 |
+
- `$1` - Training data directory
|
| 136 |
+
- `$2` - Validation data directory
|
| 137 |
+
|
| 138 |
+
**Training Configuration** (in `scripts/train.py`):
|
| 139 |
+
|
| 140 |
+
| Parameter | Default | Description |
|
| 141 |
+
|-----------|---------|-------------|
|
| 142 |
+
| `NUM_EPOCHS` | 4000 | Total training epochs |
|
| 143 |
+
| `BATCH_SIZE` | 64 | Samples per batch |
|
| 144 |
+
| `LEARNING_RATE` | 0.001 | Initial learning rate |
|
| 145 |
+
|
| 146 |
+
**Learning Rate Schedule:**
|
| 147 |
+
- Epochs 0-40: `1e-3`
|
| 148 |
+
- Epochs 40-2000: `2e-4`
|
| 149 |
+
- Epochs 2000-21000: `2e-5`
|
| 150 |
+
- Epochs 21000+: `1e-5`
|
| 151 |
+
|
| 152 |
+
Model checkpoints saved every 50 epochs to `./model/`.
|
| 153 |
+
|
| 154 |
+
## Evaluation
|
| 155 |
+
|
| 156 |
+
Evaluate the trained model:
|
| 157 |
+
|
| 158 |
+
```bash
|
| 159 |
+
cd training
|
| 160 |
+
sh run_eval.sh ~/semantic2d_data/
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
**Output:** Results saved to `./output/`
|
| 164 |
+
|
| 165 |
+
## Training File Structure
|
| 166 |
+
|
| 167 |
+
```
|
| 168 |
+
training/
|
| 169 |
+
├── model/
|
| 170 |
+
│ └── semantic_cnn_model.pth # Pretrained model weights
|
| 171 |
+
├── scripts/
|
| 172 |
+
│ ├── model.py # SemanticCNN architecture + NavDataset
|
| 173 |
+
│ ├── train.py # Training script
|
| 174 |
+
│ └── decode_demo.py # Evaluation/demo script
|
| 175 |
+
├── run_train.sh # Training driver script
|
| 176 |
+
└── run_eval.sh # Evaluation driver script
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
---
|
| 180 |
+
|
| 181 |
+
## TensorBoard Monitoring
|
| 182 |
+
|
| 183 |
+
Training logs are saved to `./runs/`. View training progress:
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
cd training
|
| 187 |
+
tensorboard --logdir=runs
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
Monitored metrics:
|
| 191 |
+
- Training loss
|
| 192 |
+
- Validation loss
|
| 193 |
+
|
| 194 |
+
---
|
| 195 |
+
|
| 196 |
+
# Part 2: ROS Deployment
|
| 197 |
+
|
| 198 |
+
## Prerequisites
|
| 199 |
+
|
| 200 |
+
Install the following ROS packages:
|
| 201 |
+
|
| 202 |
+
```bash
|
| 203 |
+
# Create catkin workspace
|
| 204 |
+
mkdir -p ~/catkin_ws/src
|
| 205 |
+
cd ~/catkin_ws/src
|
| 206 |
+
|
| 207 |
+
# Clone required packages
|
| 208 |
+
git clone https://github.com/TempleRAIL/robot_gazebo.git
|
| 209 |
+
git clone https://github.com/TempleRAIL/pedsim_ros_with_gazebo.git
|
| 210 |
+
|
| 211 |
+
# Build
|
| 212 |
+
cd ~/catkin_ws
|
| 213 |
+
catkin_make
|
| 214 |
+
source devel/setup.bash
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
## Installation
|
| 218 |
+
|
| 219 |
+
1. Copy the ROS workspace to your catkin workspace:
|
| 220 |
+
```bash
|
| 221 |
+
cp -r ros_deployment_ws/src/semantic_cnn_nav ~/catkin_ws/src/
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
2. Build the workspace:
|
| 225 |
+
```bash
|
| 226 |
+
cd ~/catkin_ws
|
| 227 |
+
catkin_make
|
| 228 |
+
source devel/setup.bash
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
## Usage
|
| 232 |
+
|
| 233 |
+
### Launch Gazebo Simulation
|
| 234 |
+
|
| 235 |
+
```bash
|
| 236 |
+
roslaunch semantic_cnn_nav semantic_cnn_nav_gazebo.launch
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
This launch file starts:
|
| 240 |
+
- Gazebo simulator with pedestrians (pedsim)
|
| 241 |
+
- AMCL localization
|
| 242 |
+
- CNN data publisher
|
| 243 |
+
- Semantic CNN inference node
|
| 244 |
+
- RViz visualization
|
| 245 |
+
|
| 246 |
+
### Launch Configuration
|
| 247 |
+
|
| 248 |
+
Key parameters in `semantic_cnn_nav_gazebo.launch`:
|
| 249 |
+
|
| 250 |
+
| Parameter | Default | Description |
|
| 251 |
+
|-----------|---------|-------------|
|
| 252 |
+
| `s3_net_model_file` | `model/s3_net_model.pth` | S³-Net model path |
|
| 253 |
+
| `semantic_cnn_model_file` | `model/semantic_cnn_model.pth` | SemanticCNN model path |
|
| 254 |
+
| `scene_file` | `eng_hall_5.xml` | Pedsim scenario file |
|
| 255 |
+
| `world_name` | `eng_hall.world` | Gazebo world file |
|
| 256 |
+
| `map_file` | `gazebo_eng_lobby.yaml` | Navigation map |
|
| 257 |
+
| `initial_pose_x/y/a` | 1.0, 0.0, 0.13 | Robot initial pose |
|
| 258 |
+
|
| 259 |
+
### Send Navigation Goals
|
| 260 |
+
|
| 261 |
+
Use RViz "2D Nav Goal" tool to send navigation goals to the robot.
|
| 262 |
+
|
| 263 |
+
## ROS Nodes
|
| 264 |
+
|
| 265 |
+
### cnn_data_pub
|
| 266 |
+
Publishes processed LiDAR data for the CNN.
|
| 267 |
+
|
| 268 |
+
**Subscriptions:**
|
| 269 |
+
- `/scan` (sensor_msgs/LaserScan)
|
| 270 |
+
|
| 271 |
+
**Publications:**
|
| 272 |
+
- `/cnn_data` (cnn_msgs/CNN_data)
|
| 273 |
+
|
| 274 |
+
### semantic_cnn_nav_inference
|
| 275 |
+
Main inference node combining S³-Net and SemanticCNN.
|
| 276 |
+
|
| 277 |
+
**Subscriptions:**
|
| 278 |
+
- `/cnn_data` (cnn_msgs/CNN_data)
|
| 279 |
+
|
| 280 |
+
**Publications:**
|
| 281 |
+
- `/navigation_velocity_smoother/raw_cmd_vel` (geometry_msgs/Twist)
|
| 282 |
+
|
| 283 |
+
**Parameters:**
|
| 284 |
+
- `~s3_net_model_file`: Path to S³-Net model
|
| 285 |
+
- `~semantic_cnn_model_file`: Path to SemanticCNN model
|
| 286 |
+
|
| 287 |
+
## ROS Deployment File Structure
|
| 288 |
+
|
| 289 |
+
```
|
| 290 |
+
ros_deployment_ws/
|
| 291 |
+
└── src/
|
| 292 |
+
└── semantic_cnn_nav/
|
| 293 |
+
├── cnn_msgs/
|
| 294 |
+
│ └── msg/
|
| 295 |
+
│ └── CNN_data.msg # Custom message definition
|
| 296 |
+
└── semantic_cnn/
|
| 297 |
+
├── launch/
|
| 298 |
+
│ ├── cnn_data_pub.launch
|
| 299 |
+
│ ├── semantic_cnn_inference.launch
|
| 300 |
+
│ └── semantic_cnn_nav_gazebo.launch
|
| 301 |
+
└── src/
|
| 302 |
+
├── model/
|
| 303 |
+
│ ├── s3_net_model.pth # S³-Net pretrained weights
|
| 304 |
+
│ └── semantic_cnn_model.pth # SemanticCNN weights
|
| 305 |
+
├── cnn_data_pub.py # Data preprocessing node
|
| 306 |
+
├── cnn_model.py # Model definitions
|
| 307 |
+
├── pure_pursuit.py # Pure pursuit controller
|
| 308 |
+
├── goal_visualize.py # Goal visualization
|
| 309 |
+
└── semantic_cnn_nav_inference.py # Main inference node
|
| 310 |
+
```
|
| 311 |
+
---
|
| 312 |
+
|
| 313 |
+
## Pre-trained Models
|
| 314 |
+
|
| 315 |
+
Pre-trained models are included:
|
| 316 |
+
|
| 317 |
+
| Model | Location | Description |
|
| 318 |
+
|-------|----------|-------------|
|
| 319 |
+
| `s3_net_model.pth` | `ros_deployment_ws/.../model/` | S³-Net semantic segmentation |
|
| 320 |
+
| `semantic_cnn_model.pth` | `training/model/` | SemanticCNN navigation policy |
|
| 321 |
+
|
| 322 |
+
---
|
| 323 |
+
|
| 324 |
+
## Citation
|
| 325 |
+
|
| 326 |
+
```bibtex
|
| 327 |
+
@article{xie2026semantic2d,
|
| 328 |
+
title={Semantic2D: Enabling Semantic Scene Understanding with 2D Lidar Alone},
|
| 329 |
+
author={Xie, Zhanteng and Pan, Yipeng and Zhang, Yinqiang and Pan, Jia and Dames, Philip},
|
| 330 |
+
journal={arXiv preprint arXiv:2409.09899},
|
| 331 |
+
year={2026}
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
@inproceedings{xie2021towards,
|
| 335 |
+
title={Towards Safe Navigation Through Crowded Dynamic Environments},
|
| 336 |
+
author={Xie, Zhanteng and Xin, Pujie and Dames, Philip},
|
| 337 |
+
booktitle={2021 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
| 338 |
+
year={2021},
|
| 339 |
+
doi={10.1109/IROS51168.2021.9636102}
|
| 340 |
+
}
|
| 341 |
+
```
|
| 342 |
+
|
demo/1.eng4th_semantic_navigation.gif
ADDED
|
Git LFS Details
|
demo/1.lobby_semantic_navigation.gif
ADDED
|
Git LFS Details
|
demo/3.cyc4th_semantic_navigation.gif
ADDED
|
Git LFS Details
|
ros_deployment_ws/src/semantic_cnn_nav/cnn_msgs/CMakeLists.txt
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cmake_minimum_required(VERSION 2.8.12)
|
| 2 |
+
|
| 3 |
+
project(cnn_msgs)
|
| 4 |
+
|
| 5 |
+
set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS}")
|
| 6 |
+
|
| 7 |
+
find_package(catkin REQUIRED COMPONENTS
|
| 8 |
+
roscpp
|
| 9 |
+
rospy
|
| 10 |
+
std_msgs
|
| 11 |
+
message_generation
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
add_message_files(
|
| 15 |
+
FILES
|
| 16 |
+
CNN_data.msg
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
generate_messages(
|
| 21 |
+
DEPENDENCIES
|
| 22 |
+
std_msgs
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
catkin_package(
|
| 26 |
+
CATKIN_DEPENDS
|
| 27 |
+
std_msgs
|
| 28 |
+
)
|
ros_deployment_ws/src/semantic_cnn_nav/cnn_msgs/msg/CNN_data.msg
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
float32[] scan_his # 1081 range data from the laser scan
|
| 2 |
+
float32[] intensity_his # 1081 intensity data from the laser scan
|
| 3 |
+
float32[] scan # 1081 range data from the laser scan
|
| 4 |
+
float32[] goal # current goal in robot frame
|
| 5 |
+
float32[] vel # current velocity in robot frame
|
ros_deployment_ws/src/semantic_cnn_nav/cnn_msgs/package.xml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0"?>
|
| 2 |
+
<package>
|
| 3 |
+
<name>cnn_msgs</name>
|
| 4 |
+
<version>1.1.3</version>
|
| 5 |
+
<description>CNN messages.</description>
|
| 6 |
+
<maintainer email="zzuxzt@to.do">zzuxzt</maintainer>
|
| 7 |
+
<license>BSD</license>
|
| 8 |
+
<url type="website">https://github.com/zzuxzt</url>
|
| 9 |
+
<author email="zzuxzt@to.do">zzuxzt</author>
|
| 10 |
+
|
| 11 |
+
<buildtool_depend>catkin</buildtool_depend>
|
| 12 |
+
|
| 13 |
+
<build_depend>message_generation</build_depend>
|
| 14 |
+
<build_depend>std_msgs</build_depend>
|
| 15 |
+
|
| 16 |
+
<run_depend>std_msgs</run_depend>
|
| 17 |
+
</package>
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/CMakeLists.txt
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cmake_minimum_required(VERSION 2.8.3)
|
| 2 |
+
project(semantic_cnn_nav)
|
| 3 |
+
|
| 4 |
+
## Compile as C++11, supported in ROS Kinetic and newer
|
| 5 |
+
# add_compile_options(-std=c++11)
|
| 6 |
+
|
| 7 |
+
## Find catkin macros and libraries
|
| 8 |
+
## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
|
| 9 |
+
## is used, also find other catkin packages
|
| 10 |
+
find_package(catkin REQUIRED COMPONENTS
|
| 11 |
+
roscpp
|
| 12 |
+
rospy
|
| 13 |
+
)
|
| 14 |
+
#add_executable(start_learning scripts/start_learning.py)
|
| 15 |
+
|
| 16 |
+
## System dependencies are found with CMake's conventions
|
| 17 |
+
# find_package(Boost REQUIRED COMPONENTS system)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
## Uncomment this if the package has a setup.py. This macro ensures
|
| 21 |
+
## modules and global scripts declared therein get installed
|
| 22 |
+
## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html
|
| 23 |
+
# catkin_python_setup()
|
| 24 |
+
|
| 25 |
+
################################################
|
| 26 |
+
## Declare ROS messages, services and actions ##
|
| 27 |
+
################################################
|
| 28 |
+
|
| 29 |
+
## To declare and build messages, services or actions from within this
|
| 30 |
+
## package, follow these steps:
|
| 31 |
+
## * Let MSG_DEP_SET be the set of packages whose message types you use in
|
| 32 |
+
## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...).
|
| 33 |
+
## * In the file package.xml:
|
| 34 |
+
## * add a build_depend tag for "message_generation"
|
| 35 |
+
## * add a build_depend and a exec_depend tag for each package in MSG_DEP_SET
|
| 36 |
+
## * If MSG_DEP_SET isn't empty the following dependency has been pulled in
|
| 37 |
+
## but can be declared for certainty nonetheless:
|
| 38 |
+
## * add a exec_depend tag for "message_runtime"
|
| 39 |
+
## * In this file (CMakeLists.txt):
|
| 40 |
+
## * add "message_generation" and every package in MSG_DEP_SET to
|
| 41 |
+
## find_package(catkin REQUIRED COMPONENTS ...)
|
| 42 |
+
## * add "message_runtime" and every package in MSG_DEP_SET to
|
| 43 |
+
## catkin_package(CATKIN_DEPENDS ...)
|
| 44 |
+
## * uncomment the add_*_files sections below as needed
|
| 45 |
+
## and list every .msg/.srv/.action file to be processed
|
| 46 |
+
## * uncomment the generate_messages entry below
|
| 47 |
+
## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...)
|
| 48 |
+
|
| 49 |
+
## Generate messages in the 'msg' folder
|
| 50 |
+
# add_message_files(
|
| 51 |
+
# FILES
|
| 52 |
+
# Message1.msg
|
| 53 |
+
# Message2.msg
|
| 54 |
+
# )
|
| 55 |
+
|
| 56 |
+
## Generate services in the 'srv' folder
|
| 57 |
+
# add_service_files(
|
| 58 |
+
# FILES
|
| 59 |
+
# Service1.srv
|
| 60 |
+
# Service2.srv
|
| 61 |
+
# )
|
| 62 |
+
|
| 63 |
+
## Generate actions in the 'action' folder
|
| 64 |
+
# add_action_files(
|
| 65 |
+
# FILES
|
| 66 |
+
# Action1.action
|
| 67 |
+
# Action2.action
|
| 68 |
+
# )
|
| 69 |
+
|
| 70 |
+
## Generate added messages and services with any dependencies listed here
|
| 71 |
+
# generate_messages(
|
| 72 |
+
# DEPENDENCIES
|
| 73 |
+
# std_msgs # Or other packages containing msgs
|
| 74 |
+
# )
|
| 75 |
+
|
| 76 |
+
################################################
|
| 77 |
+
## Declare ROS dynamic reconfigure parameters ##
|
| 78 |
+
################################################
|
| 79 |
+
|
| 80 |
+
## To declare and build dynamic reconfigure parameters within this
|
| 81 |
+
## package, follow these steps:
|
| 82 |
+
## * In the file package.xml:
|
| 83 |
+
## * add a build_depend and a exec_depend tag for "dynamic_reconfigure"
|
| 84 |
+
## * In this file (CMakeLists.txt):
|
| 85 |
+
## * add "dynamic_reconfigure" to
|
| 86 |
+
## find_package(catkin REQUIRED COMPONENTS ...)
|
| 87 |
+
## * uncomment the "generate_dynamic_reconfigure_options" section below
|
| 88 |
+
## and list every .cfg file to be processed
|
| 89 |
+
|
| 90 |
+
## Generate dynamic reconfigure parameters in the 'cfg' folder
|
| 91 |
+
# generate_dynamic_reconfigure_options(
|
| 92 |
+
# cfg/DynReconf1.cfg
|
| 93 |
+
# cfg/DynReconf2.cfg
|
| 94 |
+
# )
|
| 95 |
+
|
| 96 |
+
###################################
|
| 97 |
+
## catkin specific configuration ##
|
| 98 |
+
###################################
|
| 99 |
+
## The catkin_package macro generates cmake config files for your package
|
| 100 |
+
## Declare things to be passed to dependent projects
|
| 101 |
+
## INCLUDE_DIRS: uncomment this if your package contains header files
|
| 102 |
+
## LIBRARIES: libraries you create in this project that dependent projects also need
|
| 103 |
+
## CATKIN_DEPENDS: catkin_packages dependent projects also need
|
| 104 |
+
## DEPENDS: system dependencies of this project that dependent projects also need
|
| 105 |
+
catkin_package(
|
| 106 |
+
# INCLUDE_DIRS include
|
| 107 |
+
# LIBRARIES cablearning
|
| 108 |
+
# CATKIN_DEPENDS other_catkin_pkg
|
| 109 |
+
# DEPENDS system_lib
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
###########
|
| 113 |
+
## Build ##
|
| 114 |
+
###########
|
| 115 |
+
|
| 116 |
+
## Specify additional locations of header files
|
| 117 |
+
## Your package locations should be listed before other locations
|
| 118 |
+
include_directories(
|
| 119 |
+
# include
|
| 120 |
+
${catkin_INCLUDE_DIRS}
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
## Declare a C++ library
|
| 124 |
+
# add_library(${PROJECT_NAME}
|
| 125 |
+
# src/${PROJECT_NAME}/cablearning.cpp
|
| 126 |
+
# )
|
| 127 |
+
|
| 128 |
+
## Add cmake target dependencies of the library
|
| 129 |
+
## as an example, code may need to be generated before libraries
|
| 130 |
+
## either from message generation or dynamic reconfigure
|
| 131 |
+
# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
|
| 132 |
+
|
| 133 |
+
## Declare a C++ executable
|
| 134 |
+
## With catkin_make all packages are built within a single CMake context
|
| 135 |
+
## The recommended prefix ensures that target names across packages don't collide
|
| 136 |
+
# add_executable(${PROJECT_NAME}_node src/cablearning_node.cpp)
|
| 137 |
+
|
| 138 |
+
## Rename C++ executable without prefix
|
| 139 |
+
## The above recommended prefix causes long target names, the following renames the
|
| 140 |
+
## target back to the shorter version for ease of user use
|
| 141 |
+
## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node"
|
| 142 |
+
# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "")
|
| 143 |
+
|
| 144 |
+
## Add cmake target dependencies of the executable
|
| 145 |
+
## same as for the library above
|
| 146 |
+
# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
|
| 147 |
+
|
| 148 |
+
## Specify libraries to link a library or executable target against
|
| 149 |
+
# target_link_libraries(${PROJECT_NAME}_node
|
| 150 |
+
# ${catkin_LIBRARIES}
|
| 151 |
+
# )
|
| 152 |
+
|
| 153 |
+
#############
|
| 154 |
+
## Install ##
|
| 155 |
+
#############
|
| 156 |
+
|
| 157 |
+
# all install targets should use catkin DESTINATION variables
|
| 158 |
+
# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html
|
| 159 |
+
|
| 160 |
+
## Mark executable scripts (Python etc.) for installation
|
| 161 |
+
## in contrast to setup.py, you can choose the destination
|
| 162 |
+
# install(PROGRAMS
|
| 163 |
+
# scripts/my_python_script
|
| 164 |
+
# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
|
| 165 |
+
# )
|
| 166 |
+
|
| 167 |
+
## Mark executables and/or libraries for installation
|
| 168 |
+
# install(TARGETS ${PROJECT_NAME} ${PROJECT_NAME}_node
|
| 169 |
+
# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
|
| 170 |
+
# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
|
| 171 |
+
# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
|
| 172 |
+
# )
|
| 173 |
+
|
| 174 |
+
## Mark cpp header files for installation
|
| 175 |
+
# install(DIRECTORY include/${PROJECT_NAME}/
|
| 176 |
+
# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
|
| 177 |
+
# FILES_MATCHING PATTERN "*.h"
|
| 178 |
+
# PATTERN ".svn" EXCLUDE
|
| 179 |
+
# )
|
| 180 |
+
|
| 181 |
+
## Mark other files for installation (e.g. launch and bag files, etc.)
|
| 182 |
+
# install(FILES
|
| 183 |
+
# # myfile1
|
| 184 |
+
# # myfile2
|
| 185 |
+
# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
|
| 186 |
+
# )
|
| 187 |
+
|
| 188 |
+
#############
|
| 189 |
+
## Testing ##
|
| 190 |
+
#############
|
| 191 |
+
|
| 192 |
+
## Add gtest based cpp test target and link libraries
|
| 193 |
+
# catkin_add_gtest(${PROJECT_NAME}-test test/test_cablearning.cpp)
|
| 194 |
+
# if(TARGET ${PROJECT_NAME}-test)
|
| 195 |
+
# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME})
|
| 196 |
+
# endif()
|
| 197 |
+
|
| 198 |
+
## Add folders to be run by python nosetests
|
| 199 |
+
# catkin_add_nosetests(test)
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/launch/cnn_data_pub.launch
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<launch>
|
| 2 |
+
<!-- Subgoal Publisher -->
|
| 3 |
+
<node name="pure_pursuit" pkg="semantic_cnn_nav" type="pure_pursuit.py" output="screen" required="true" />
|
| 4 |
+
|
| 5 |
+
<!-- CNN Data Publisher -->
|
| 6 |
+
<node name="cnn_data_pub" pkg="semantic_cnn_nav" type="cnn_data_pub.py"/>
|
| 7 |
+
|
| 8 |
+
<!-- Goal visualization -->
|
| 9 |
+
<node name="goal_visualize" pkg="semantic_cnn_nav" type="goal_visualize.py" output="screen" />
|
| 10 |
+
|
| 11 |
+
</launch>
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/launch/semantic_cnn_inference.launch
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<launch>
|
| 2 |
+
<arg name="s3_net_model_file" default="$(find semantic_cnn_nav)/src/model/s3_net_model.pth"/>
|
| 3 |
+
<arg name="semantic_cnn_model_file" default="$(find semantic_cnn_nav)/src/model/semantic_cnn_model.pth"/>
|
| 4 |
+
|
| 5 |
+
<!-- Semantic CNN control-policy publisher (node name "drl_vo_cmd" retained from DRL-VO) -->
|
| 6 |
+
<node name="drl_vo_cmd" pkg="semantic_cnn_nav" type="semantic_cnn_nav_inference.py" output="screen">
|
| 7 |
+
<param name="s3_net_model_file" value="$(arg s3_net_model_file)" type="string"/>
|
| 8 |
+
<param name="semantic_cnn_model_file" value="$(arg semantic_cnn_model_file)" type="string"/>
|
| 9 |
+
</node>
|
| 10 |
+
|
| 11 |
+
</launch>
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/launch/semantic_cnn_nav_gazebo.launch
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<launch>
|
| 2 |
+
<arg name="s3_net_model_file" default="$(find semantic_cnn_nav)/src/model/s3_net_model.pth"/>
|
| 3 |
+
<arg name="semantic_cnn_model_file" default="$(find semantic_cnn_nav)/src/model/semantic_cnn_model.pth"/>
|
| 4 |
+
|
| 5 |
+
<arg name="scene_file" default="$(find pedsim_simulator)/scenarios/lobby/eng_hall_5.xml"/>
|
| 6 |
+
<arg name="world_name" default="$(find pedsim_gazebo_plugin)/worlds/eng_hall.world"/>
|
| 7 |
+
<arg name="gui" default="true" doc="Bring up the Gazebo graphical interface"/>
|
| 8 |
+
<arg name="pose_initial_x" default="1.0"/>
|
| 9 |
+
<arg name="pose_initial_y" default="1.0"/>
|
| 10 |
+
<arg name="map_file" default="$(find robot_gazebo)/maps/gazebo_eng_lobby/gazebo_eng_lobby.yaml"/>
|
| 11 |
+
<arg name="initial_pose_x" default="1.0"/>
|
| 12 |
+
<arg name="initial_pose_y" default="0.0"/>
|
| 13 |
+
<arg name="initial_pose_a" default="0.13"/>
|
| 14 |
+
|
| 15 |
+
<!-- Output -->
|
| 16 |
+
<arg name="output" default="log"/>
|
| 17 |
+
<arg name="enable_opencv" default="true"/>
|
| 18 |
+
<arg name="enable_console_output" default="true"/>
|
| 19 |
+
<arg name="rviz" default="true"/>
|
| 20 |
+
|
| 21 |
+
<!-- Pedsim Gazebo -->
|
| 22 |
+
<include file="$(find pedsim_simulator)/launch/robot.launch">
|
| 23 |
+
<arg name="scene_file" value="$(arg scene_file)"/>
|
| 24 |
+
<arg name="world_name" value="$(arg world_name)"/>
|
| 25 |
+
<arg name="gui" value="$(arg gui)"/>
|
| 26 |
+
<arg name="pose_initial_x" value="$(arg pose_initial_x)"/>
|
| 27 |
+
<arg name="pose_initial_y" value="$(arg pose_initial_y)"/>
|
| 28 |
+
</include>
|
| 29 |
+
|
| 30 |
+
<!-- AMCL -->
|
| 31 |
+
<include file="$(find robot_gazebo)/launch/amcl_demo_drl.launch">
|
| 32 |
+
<arg name="map_file" value="$(arg map_file)"/>
|
| 33 |
+
<arg name="initial_pose_x" value="$(arg initial_pose_x)"/>
|
| 34 |
+
<arg name="initial_pose_y" value="$(arg initial_pose_y)"/>
|
| 35 |
+
<arg name="initial_pose_a" value="$(arg initial_pose_a)"/>
|
| 36 |
+
</include>
|
| 37 |
+
|
| 38 |
+
<!-- CNN DATA -->
|
| 39 |
+
<include file="$(find semantic_cnn_nav)/launch/cnn_data_pub.launch"/>
|
| 40 |
+
|
| 41 |
+
<!-- Semantic CNN Control Policy -->
|
| 42 |
+
<include file="$(find semantic_cnn_nav)/launch/semantic_cnn_inference.launch">
|
| 43 |
+
<arg name="s3_net_model_file" value="$(arg s3_net_model_file)"/>
|
| 44 |
+
<arg name="semantic_cnn_model_file" value="$(arg semantic_cnn_model_file)"/>
|
| 45 |
+
</include>
|
| 46 |
+
|
| 47 |
+
<!-- Rviz-->
|
| 48 |
+
<include file="$(find robot_gazebo)/launch/view_navigation.launch" if="$(arg rviz)"/>
|
| 49 |
+
|
| 50 |
+
</launch>
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/package.xml
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0"?>
|
| 2 |
+
<package format="2">
|
| 3 |
+
<name>semantic_cnn_nav</name>
|
| 4 |
+
<version>0.0.0</version>
|
| 5 |
+
<description>The semantic_cnn_nav package</description>
|
| 6 |
+
|
| 7 |
+
<!-- One maintainer tag required, multiple allowed, one person per tag -->
|
| 8 |
+
<!-- Example: -->
|
| 9 |
+
<!-- <maintainer email="jane.doe@example.com">Jane Doe</maintainer> -->
|
| 10 |
+
<maintainer email="xzt@todo.todo">xzt</maintainer>
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
<!-- One license tag required, multiple allowed, one license per tag -->
|
| 14 |
+
<!-- Commonly used license strings: -->
|
| 15 |
+
<!-- BSD, MIT, Boost Software License, GPLv2, GPLv3, LGPLv2.1, LGPLv3 -->
|
| 16 |
+
<license>MIT</license>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
<!-- Url tags are optional, but multiple are allowed, one per tag -->
|
| 20 |
+
<!-- Optional attribute type can be: website, bugtracker, or repository -->
|
| 21 |
+
<!-- Example: -->
|
| 22 |
+
<!-- <url type="website">http://wiki.ros.org</url> -->
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
<!-- Author tags are optional, multiple are allowed, one per tag -->
|
| 26 |
+
<!-- Authors do not have to be maintainers, but could be -->
|
| 27 |
+
<!-- Example: -->
|
| 28 |
+
<!-- <author email="jane.doe@example.com">Jane Doe</author> -->
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
<!-- The *depend tags are used to specify dependencies -->
|
| 32 |
+
<!-- Dependencies can be catkin packages or system dependencies -->
|
| 33 |
+
<!-- Examples: -->
|
| 34 |
+
<!-- Use depend as a shortcut for packages that are both build and exec dependencies -->
|
| 35 |
+
<!-- <depend>roscpp</depend> -->
|
| 36 |
+
<!-- Note that this is equivalent to the following: -->
|
| 37 |
+
<!-- <build_depend>roscpp</build_depend> -->
|
| 38 |
+
<!-- <exec_depend>roscpp</exec_depend> -->
|
| 39 |
+
<!-- Use build_depend for packages you need at compile time: -->
|
| 40 |
+
<!-- <build_depend>message_generation</build_depend> -->
|
| 41 |
+
<!-- Use build_export_depend for packages you need in order to build against this package: -->
|
| 42 |
+
<!-- <build_export_depend>message_generation</build_export_depend> -->
|
| 43 |
+
<!-- Use buildtool_depend for build tool packages: -->
|
| 44 |
+
<!-- <buildtool_depend>catkin</buildtool_depend> -->
|
| 45 |
+
<!-- Use exec_depend for packages you need at runtime: -->
|
| 46 |
+
<!-- <exec_depend>message_runtime</exec_depend> -->
|
| 47 |
+
<!-- Use test_depend for packages you need only for testing: -->
|
| 48 |
+
<!-- <test_depend>gtest</test_depend> -->
|
| 49 |
+
<!-- Use doc_depend for packages you need only for building documentation: -->
|
| 50 |
+
<!-- <doc_depend>doxygen</doc_depend> -->
|
| 51 |
+
<buildtool_depend>catkin</buildtool_depend>
|
| 52 |
+
<build_depend>roscpp</build_depend>
|
| 53 |
+
<build_depend>rospy</build_depend>
|
| 54 |
+
<build_export_depend>roscpp</build_export_depend>
|
| 55 |
+
<build_export_depend>rospy</build_export_depend>
|
| 56 |
+
<exec_depend>roscpp</exec_depend>
|
| 57 |
+
<exec_depend>rospy</exec_depend>
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
<!-- The export tag contains other, unspecified, tags -->
|
| 61 |
+
<export>
|
| 62 |
+
<!-- Other tools can request additional information be placed here -->
|
| 63 |
+
</export>
|
| 64 |
+
</package>
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/__pycache__/cnn_model.cpython-38.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/__pycache__/custom_cnn_full.cpython-38.pyc
ADDED
|
Binary file (6.02 kB). View file
|
|
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/__pycache__/model.cpython-38.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/__pycache__/pfeiffer_model.cpython-38.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/cnn_data_pub.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
import numpy as np
|
| 3 |
+
import rospy
|
| 4 |
+
from geometry_msgs.msg import Point, PoseStamped, Twist, TwistStamped
|
| 5 |
+
from sensor_msgs.msg import LaserScan
|
| 6 |
+
# custom define messages:
|
| 7 |
+
from cnn_msgs.msg import CNN_data
|
| 8 |
+
|
| 9 |
+
# parameters:
NUM_TP = 10  # number of lidar timesteps accumulated before each CNN_data publish


class CnnData:
    """Accumulates lidar scan/intensity history plus the latest sub-goal and
    commanded velocity, and periodically publishes them as one CNN_data message.
    """

    # Constructor
    def __init__(self):
        # rolling history buffers (one list entry per timestep):
        self.scan_his = []
        self.intensity_his = []
        self.semantic_his = []  # NOTE(review): never written in this file — confirm if still needed

        # latest snapshots (1081 = 1080 lidar rays + 1 pad; see scan_callback):
        self.scan = np.zeros(1081)
        self.intensity = np.zeros(1081)
        self.goal = np.zeros(2)   # sub-goal (x, y)
        self.vel = np.zeros(2)    # (linear x, angular z)

        # timer:
        self.timer = None
        self.rate = 10  # accumulation/publish tick rate in Hz
        self.ts_cnt = 0  # timesteps accumulated so far (capped at NUM_TP)

        # initialize ROS objects
        self.scan_sub = rospy.Subscriber("scan", LaserScan, self.scan_callback)
        self.goal_sub = rospy.Subscriber("cnn_goal", Point, self.goal_callback)
        self.vel_sub = rospy.Subscriber("mobile_base/commands/velocity", Twist, self.vel_callback)
        self.cnn_data_pub = rospy.Publisher('cnn_data', CNN_data, queue_size=1, latch=False)

    # Callback function for the scan measurement subscriber
    def scan_callback(self, laserScan_msg):
        # get the laser scan data, clamping NaN/inf returns to the 30 m max range:
        scan_data = np.array(laserScan_msg.ranges, dtype=np.float32)
        scan_data[np.isnan(scan_data)] = 30.
        scan_data[np.isinf(scan_data)] = 30.
        # gazebo (pad to 1081 points):
        self.scan = np.concatenate((scan_data, [0]), axis=0)
        # real world (overrides the gazebo line above; comment out the variant
        # that does not apply to the deployment target):
        self.scan = scan_data

        intensity_data = np.array(laserScan_msg.intensities, dtype=np.float32)
        intensity_data[np.isnan(intensity_data)] = 0.
        intensity_data[np.isinf(intensity_data)] = 0.
        # gazebo:
        self.intensity = np.concatenate((intensity_data, [0]), axis=0)
        # real world:
        self.intensity = intensity_data

        # start the timer once the first scan has arrived
        if self.timer is None:
            self.start()

    # Callback function for the current goal subscriber
    def goal_callback(self, goal_msg):
        # sub-goal in Cartesian coordinates:
        self.goal = np.zeros(2)
        self.goal[0] = goal_msg.x
        self.goal[1] = goal_msg.y

    # Callback function for the velocity subscriber
    def vel_callback(self, vel_msg):
        self.vel = np.zeros(2)
        self.vel[0] = vel_msg.linear.x
        self.vel[1] = vel_msg.angular.z

    # Start the timer that accumulates timesteps and publishes CNN data
    def start(self):
        # initialize timer for periodic accumulation/publish
        self.timer = rospy.Timer(rospy.Duration(1./self.rate), self.timer_callback)

    # Timer callback: record one timestep; publish once NUM_TP are collected
    def timer_callback(self, event):
        # append the latest scan/intensity snapshot to the history buffers:
        self.scan_his.append(self.scan.tolist())
        self.intensity_his.append(self.intensity.tolist())

        self.ts_cnt = self.ts_cnt + 1
        if(self.ts_cnt == NUM_TP):
            # publish cnn data (histories flattened to flat float lists):
            cnn_data = CNN_data()
            cnn_data.scan_his = [float(val) for sublist in self.scan_his for val in sublist]
            cnn_data.intensity_his = [float(val) for sublist in self.intensity_his for val in sublist]
            cnn_data.scan = self.scan
            cnn_data.goal = self.goal
            cnn_data.vel = self.vel
            self.cnn_data_pub.publish(cnn_data)

            # slide the window: drop the oldest timestep so the next tick
            # completes a fresh NUM_TP-long history.
            # BUG FIX: the original sliced self.scan (the 1081-point scan
            # array) instead of the history lists, which corrupted the scan
            # and let scan_his/intensity_his grow without bound.
            self.ts_cnt = NUM_TP - 1
            self.scan_his = self.scan_his[1:NUM_TP]
            self.intensity_his = self.intensity_his[1:NUM_TP]
|
| 99 |
+
|
| 100 |
+
if __name__ == '__main__':
    try:
        # start the ROS node; all work happens in CnnData's subscriber/timer callbacks
        rospy.init_node('cnn_data')
        CnnData()
        # spin() simply keeps python from exiting until this node is stopped
        rospy.spin()
    except rospy.ROSInterruptException:
        # normal shutdown path (Ctrl-C / node kill) — nothing to clean up
        pass
|
| 108 |
+
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/cnn_model.py
ADDED
|
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# file: $ISIP_EXP/tuh_dpath/exp_0074/scripts/model.py
|
| 4 |
+
#
|
| 5 |
+
# revision history:
|
| 6 |
+
# 20190925 (TE): first version
|
| 7 |
+
#
|
| 8 |
+
# usage:
|
| 9 |
+
#
|
| 10 |
+
# This script hold the model architecture
|
| 11 |
+
#------------------------------------------------------------------------------
|
| 12 |
+
|
| 13 |
+
# import pytorch modules
|
| 14 |
+
#
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import numpy as np
|
| 19 |
+
import numpy.matlib
|
| 20 |
+
|
| 21 |
+
# import modules
|
| 22 |
+
#
|
| 23 |
+
import os
|
| 24 |
+
import random
|
| 25 |
+
|
| 26 |
+
# for reproducibility, we seed the rng
#
# NOTE(review): SEED1 is not referenced in the visible portion of this file;
# presumably passed to set_seed() by callers elsewhere — confirm before removing.
SEED1 = 1337
# newline constant (also unused in the visible portion of this file)
NEW_LINE = "\n"
|
| 30 |
+
|
| 31 |
+
#-----------------------------------------------------------------------------
|
| 32 |
+
#
|
| 33 |
+
# helper functions are listed here
|
| 34 |
+
#
|
| 35 |
+
#-----------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
# function: set_seed
|
| 38 |
+
#
|
| 39 |
+
# arguments: seed - the seed for all the rng
|
| 40 |
+
#
|
| 41 |
+
# returns: none
|
| 42 |
+
#
|
| 43 |
+
# this method seeds all the random number generators and makes
|
| 44 |
+
# the results deterministic
|
| 45 |
+
#
|
| 46 |
+
def set_seed(seed):
    """Seed every RNG this stack uses and force deterministic cuDNN kernels.

    arguments: seed - the seed for all the rng
    returns: none

    Seeds torch (CPU and all CUDA devices), Python's random module, numpy's
    global RNG, and the hash seed, and disables cuDNN autotuning so repeated
    runs are reproducible.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    # BUG FIX: numpy's global RNG was previously left unseeded even though
    # this module uses numpy throughout, so numpy-driven randomness was not
    # reproducible.
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
#
# end of method
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# calculate the angle of incidence of the lidar ray:
|
| 58 |
+
def angle_incidence_calculation(b, c, alpha, last_ray=False):
    """Compute the angle of incidence of a lidar ray via the law of cosines.

    b, c     : range readings of two adjacent rays (scalars or numpy arrays)
    alpha    : angular separation between the two rays
    last_ray : if True, use the triangle angle opposite side b (final ray of
               the scan); otherwise the angle opposite side c
    Returns a numpy array (leading axis of length 1, matching the original
    list-wrapped arccos call).
    """
    # third side of the triangle spanned by the two ray endpoints:
    a = np.sqrt(b * b + c * c - 2 * b * c * np.cos(alpha))
    if last_ray:
        # beta: interior angle opposite side b
        inner_angle = np.arccos([(a * a + c * c - b * b) / (2 * a * c)])
    else:
        # gamma: interior angle opposite side c
        inner_angle = np.arccos([(a * a + b * b - c * c) / (2 * a * b)])
    # incidence is measured relative to the surface normal:
    return np.abs(np.pi / 2 - inner_angle)
|
| 82 |
+
|
| 83 |
+
#------------------------------------------------------------------------------
|
| 84 |
+
#
|
| 85 |
+
# the S3-Net model is defined here
|
| 86 |
+
#
|
| 87 |
+
#------------------------------------------------------------------------------
|
| 88 |
+
# Constants
# three input channels: S3Net.forward concatenates the scan, intensity and
# angle-of-incidence streams along the channel axis
NUM_INPUT_CHANNELS = 3 #2
#NUM_LATENT_DIM = 512 # 16*16*2
NUM_OUTPUT_CHANNELS = 10 # 9 classes of semantic labels + 1 background
|
| 92 |
+
# define the PyTorch VAE model
|
| 93 |
+
#
|
| 94 |
+
# define a VAE
|
| 95 |
+
# Residual blocks:
|
| 96 |
+
class Residual(nn.Module):
|
| 97 |
+
def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
|
| 98 |
+
super(Residual, self).__init__()
|
| 99 |
+
self._block = nn.Sequential(
|
| 100 |
+
nn.ReLU(True),
|
| 101 |
+
nn.Conv1d(in_channels=in_channels,
|
| 102 |
+
out_channels=num_residual_hiddens,
|
| 103 |
+
kernel_size=3, stride=1, padding=1, bias=False),
|
| 104 |
+
nn.BatchNorm1d(num_residual_hiddens),
|
| 105 |
+
nn.ReLU(True),
|
| 106 |
+
nn.Conv1d(in_channels=num_residual_hiddens,
|
| 107 |
+
out_channels=num_hiddens,
|
| 108 |
+
kernel_size=1, stride=1, bias=False),
|
| 109 |
+
nn.BatchNorm1d(num_hiddens)
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def forward(self, x):
|
| 113 |
+
return x + self._block(x)
|
| 114 |
+
|
| 115 |
+
class ResidualStack(nn.Module):
    """A chain of `num_residual_layers` Residual blocks with a final ReLU."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        blocks = [
            Residual(in_channels, num_hiddens, num_residual_hiddens)
            for _ in range(num_residual_layers)
        ]
        self._layers = nn.ModuleList(blocks)

    def forward(self, x):
        # run the blocks back to back
        for layer in self._layers:
            x = layer(x)
        # Residual blocks end in BatchNorm, so apply the activation here
        return F.relu(x)
|
| 126 |
+
|
| 127 |
+
# Encoder & Decoder Architecture:
|
| 128 |
+
# Encoder:
|
| 129 |
+
class Encoder(nn.Module):
    """1-D convolutional encoder: two stride-2 stages (4x temporal downsample)
    followed by a residual stack."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()
        half_hiddens = num_hiddens // 2
        # stage 1: halve the temporal resolution
        self._conv_1 = nn.Sequential(
            nn.Conv1d(in_channels=in_channels,
                      out_channels=half_hiddens,
                      kernel_size=4,
                      stride=2,
                      padding=1),
            nn.BatchNorm1d(half_hiddens),
            nn.ReLU(True),
        )
        # stage 2: halve again (no trailing ReLU — the residual stack opens with one)
        self._conv_2 = nn.Sequential(
            nn.Conv1d(in_channels=half_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=4,
                      stride=2,
                      padding=1),
            nn.BatchNorm1d(num_hiddens),
        )
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        hidden = self._conv_1(inputs)
        hidden = self._conv_2(hidden)
        return self._residual_stack(hidden)
|
| 160 |
+
|
| 161 |
+
# Decoder:
|
| 162 |
+
class Decoder(nn.Module):
    """1-D decoder: residual stack, then two transposed-conv upsampling stages
    and a final 3-wide conv projecting to `out_channels` (raw logits, no
    output activation)."""

    def __init__(self, out_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()

        half_hiddens = num_hiddens // 2

        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

        # first upsampling stage (x2):
        self._conv_trans_2 = nn.Sequential(
            nn.ReLU(True),
            nn.ConvTranspose1d(in_channels=num_hiddens,
                               out_channels=half_hiddens,
                               kernel_size=4,
                               stride=2,
                               padding=1),
            nn.BatchNorm1d(half_hiddens),
            nn.ReLU(True),
        )

        # second upsampling stage (x2, plus one extra sample via
        # output_padding=1) and the output projection:
        self._conv_trans_1 = nn.Sequential(
            nn.ConvTranspose1d(in_channels=half_hiddens,
                               out_channels=half_hiddens,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm1d(half_hiddens),
            nn.ReLU(True),
            nn.Conv1d(in_channels=half_hiddens,
                      out_channels=out_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            # no final activation: callers apply softmax/sigmoid themselves
        )

    def forward(self, inputs):
        hidden = self._residual_stack(inputs)
        hidden = self._conv_trans_2(hidden)
        return self._conv_trans_1(hidden)
|
| 204 |
+
|
| 205 |
+
class VAE_Encoder(nn.Module):
    """Encoder half of the VAE: the shared Encoder backbone plus two 1x1 conv
    heads producing the latent mean and log-standard-deviation maps."""

    def __init__(self, input_channel, num_hiddens, num_residual_layers, num_residual_hiddens, embedding_dim):
        super(VAE_Encoder, self).__init__()
        # channel count expected by forward() when reshaping the flat input
        self.input_channels = input_channel

        # shared convolutional backbone:
        self._encoder = Encoder(input_channel,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

        # 1x1 conv heads for the latent Gaussian parameters:
        self._encoder_z_mu = nn.Conv1d(in_channels=num_hiddens,
                                       out_channels=embedding_dim,
                                       kernel_size=1,
                                       stride=1)
        self._encoder_z_log_sd = nn.Conv1d(in_channels=num_hiddens,
                                           out_channels=embedding_dim,
                                           kernel_size=1,
                                           stride=1)

    def forward(self, x):
        # reshape the input to (batch, channels, POINTS); POINTS is a
        # module-level constant defined elsewhere in this file
        x = x.reshape(-1, self.input_channels, POINTS)
        features = self._encoder(x)
        # mean and log standard deviation of q(z|x):
        z_mu = self._encoder_z_mu(features)
        z_log_sd = self._encoder_z_log_sd(features)
        return z_mu, z_log_sd
|
| 244 |
+
|
| 245 |
+
# our proposed model:
|
| 246 |
+
class S3Net(nn.Module):
    """Semantic scan segmentation VAE.

    Encodes (scan, intensity, angle-of-incidence) lidar streams into a
    Gaussian latent, samples it with the reparameterization trick, and
    decodes per-point semantic class scores.
    """

    def __init__(self, input_channels, output_channels):
        super(S3Net, self).__init__()
        # parameters:
        self.input_channels = input_channels
        # flattened latent size; must equal embedding_dim * (encoder output
        # length). NOTE(review): 270 implies POINTS/4 == 270, i.e. POINTS ==
        # 1080 — confirm against the POINTS constant defined elsewhere.
        self.latent_dim = 270
        self.output_channels = output_channels

        # Constants
        num_hiddens = 64 #128
        num_residual_hiddens = 32 #64
        num_residual_layers = 2
        embedding_dim = 1 #2

        # prediction encoder:
        self._encoder = VAE_Encoder(self.input_channels,
                                    num_hiddens,
                                    num_residual_layers,
                                    num_residual_hiddens,
                                    embedding_dim)

        # decoder: 1x1 transposed conv lifts the latent back to num_hiddens
        # channels before the upsampling Decoder
        self._decoder_z_mu = nn.ConvTranspose1d(in_channels=embedding_dim,
                                                out_channels=num_hiddens,
                                                kernel_size=1,
                                                stride=1)
        self._decoder = Decoder(self.output_channels,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

        # per-point class probabilities over the channel axis
        self.softmax = nn.Softmax(dim=1)

    def vae_reparameterize(self, z_mu, z_log_sd):
        """
        :param mu: mean from the encoder's latent space
        :param log_sd: log standard deviation from the encoder's latent space
        :output: reparameterized latent variable z, Monte carlo KL divergence

        Note: z is drawn with rsample(), so this method is stochastic and the
        sample stays differentiable w.r.t. z_mu / z_log_sd.
        """
        # reshape to (batch, latent_dim, 1):
        z_mu = z_mu.reshape(-1, self.latent_dim, 1)
        z_log_sd = z_log_sd.reshape(-1, self.latent_dim, 1)
        # define the z probabilities (in this case Normal for both)
        # p(z): N(z|0,I)
        pz = torch.distributions.Normal(loc=torch.zeros_like(z_mu), scale=torch.ones_like(z_log_sd))
        # q(z|x,phi): N(z|mu, z_var)
        qz_x = torch.distributions.Normal(loc=z_mu, scale=torch.exp(z_log_sd))

        # repameterization trick: z = z_mu + xi (*) z_log_var, xi~N(xi|0,I)
        z = qz_x.rsample()
        # Monte Carlo KL divergence: MCKL(p(z)||q(z|x,phi)) = log(p(z)) - log(q(z|x,phi))
        # sum over weight dim, leaves the batch dim
        kl_divergence = (pz.log_prob(z) - qz_x.log_prob(z)).sum(dim=1)
        # negated so it is minimized as a loss term
        kl_loss = -kl_divergence.mean()

        return z, kl_loss

    def forward(self, x_s, x_i, x_a):
        """
        Forward pass `input_img` through the network

        x_s / x_i / x_a: scan, intensity, and angle-of-incidence streams,
        each reshaped to one channel of POINTS samples (POINTS is a
        module-level constant defined elsewhere in this file).
        Returns (softmax probabilities, raw logits, KL loss).
        """
        # reconstruction:
        # encode:
        # input reshape to (batch, 1, POINTS) per stream:
        x_s = x_s.reshape(-1, 1, POINTS)
        x_i = x_i.reshape(-1, 1, POINTS)
        x_a = x_a.reshape(-1, 1, POINTS)
        # concatenate along channel axis
        x = torch.cat([x_s, x_i, x_a], dim=1)

        # encode:
        z_mu, z_log_sd = self._encoder(x)

        # get the latent vector through reparameterization (stochastic):
        z, kl_loss = self.vae_reparameterize(z_mu, z_log_sd)

        # decode:
        # reshape to (batch, 1 channel, 270):
        z = z.reshape(-1, 1, 270)
        x_d = self._decoder_z_mu(z)
        semantic_channels = self._decoder(x_d)

        # semantic grid: 10 channels of per-point class probabilities
        semantic_scan = self.softmax(semantic_channels)

        return semantic_scan, semantic_channels, kl_loss
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
#------------------------------------------------------------------------------
|
| 337 |
+
#
|
| 338 |
+
# ResNet blocks
|
| 339 |
+
#
|
| 340 |
+
#------------------------------------------------------------------------------
|
| 341 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
| 342 |
+
"""3x3 convolution with padding"""
|
| 343 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 344 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
| 345 |
+
|
| 346 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 347 |
+
"""1x1 convolution"""
|
| 348 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 349 |
+
|
| 350 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand + shortcut.

    Torchvision places the stride on the 3x3 conv (self.conv2) whereas the
    original paper ("Deep residual learning for image recognition",
    https://arxiv.org/abs/1512.03385) strides the first 1x1 conv.  This
    variant is known as ResNet V1.5 and improves accuracy per
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    # Output channels = planes * expansion; this network halves the usual
    # factor of 4 to keep the model small.
    expansion = 2  # 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and the optional self.downsample branch shrink the
        # input spatially whenever stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Residual forward pass: relu(F(x) + shortcut(x))."""
        # Shortcut path: identity, or a projection when shape changes.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main path: reduce -> transform -> expand, BN after every conv.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
|
| 397 |
+
#
|
| 398 |
+
# end of ResNet blocks
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
#------------------------------------------------------------------------------
|
| 402 |
+
#
|
| 403 |
+
# the model is defined here
|
| 404 |
+
#
|
| 405 |
+
#------------------------------------------------------------------------------
|
| 406 |
+
|
| 407 |
+
# define the PyTorch MLP model
|
| 408 |
+
#
|
| 409 |
+
class SemanticCNN(nn.Module):
    """ResNet-style navigation policy that fuses a lidar occupancy grid with
    a semantic grid (each 80x80) and a 2-D goal vector into `num_classes`
    outputs (presumably a velocity command -- confirm against the caller).

    The backbone is three Bottleneck stages plus two hand-wired long-range
    residual connections (downsample2/downsample3) that skip over stages.
    NOTE(review): attribute names and construction order must stay as-is --
    they define the checkpoint state_dict keys and the RNG init order.
    """

    # function: init
    #
    # arguments: block  - residual block class (e.g. Bottleneck)
    #            layers - list with the number of blocks per stage
    #            cnn_in_channels - input channels (scan + semantics = 2)
    #            num_classes - size of the output vector
    #
    # return: none
    #
    def __init__(self, block, layers, cnn_in_channels=2, num_classes=2, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):

        # inherit the superclass properties/methods
        #
        super(SemanticCNN, self).__init__()
        # define the model
        #
        ################## ped_pos net model: ###################
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: stride-1 3x3 conv keeps the 80x80 resolution.
        self.conv1 = nn.Conv2d(cnn_in_channels, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # stride-1 maxpool: smooths without downsampling
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])

        # Extra bottleneck applied after layer2 (256 -> 256 channels).
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1,1), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(256)
        )
        # Projection for the layer1 -> conv2_2 skip (128 -> 256, /2 spatially).
        self.downsample2 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 1), stride=(2,2), padding=(0, 0)),
            nn.BatchNorm2d(256)
        )
        self.relu2 = nn.ReLU(inplace=True)

        # Extra bottleneck applied after layer3 (512 -> 512 channels).
        self.conv3_2 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1,1), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(512)
        )
        # Projection for the stem -> conv3_2 skip (64 -> 512, /4 spatially).
        self.downsample3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=512, kernel_size=(1, 1), stride=(4,4), padding=(0, 0)),
            nn.BatchNorm2d(512)
        )
        self.relu3 = nn.ReLU(inplace=True)

        # self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
        #                     dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # +2 for the concatenated (x, y) goal vector.
        self.fc = nn.Linear(256 * block.expansion + 2, num_classes)

        # Standard initialization: Kaiming for convs, unit/zero for norms,
        # Xavier for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d): # add by xzt
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack `blocks` residual blocks; the first one may downsample."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # trade stride for dilation to keep the spatial resolution
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # projection shortcut when shape changes
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, scan, semantics, goal):
        """Fuse scan + semantic grids through the CNN, append the goal,
        and map everything through the final linear layer."""
        ###### Start of fusion net ######
        # Stack the two 80x80 grids as a 2-channel image.
        scan_in = scan.reshape(-1,1,80,80)
        semantics_in = semantics.reshape(-1,1,80,80)
        fusion_in = torch.cat((scan_in, semantics_in), dim=1)

        # See note [TorchScript super()]
        x = self.conv1(fusion_in)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # long-range skip from the stem output (64 ch) to after conv3_2
        identity3 = self.downsample3(x)

        x = self.layer1(x)

        # skip from layer1 output (128 ch) to after conv2_2
        identity2 = self.downsample2(x)

        x = self.layer2(x)

        x = self.conv2_2(x)
        x += identity2
        x = self.relu2(x)


        x = self.layer3(x)
        # x = self.layer4(x)

        x = self.conv3_2(x)
        x += identity3
        x = self.relu3(x)

        # global average pool -> (batch, 512) feature vector
        x = self.avgpool(x)
        fusion_out = torch.flatten(x, 1)
        ###### End of fusion net ######

        ###### Start of goal net #######
        # goal is passed through unchanged, just flattened to (batch, 2)
        goal_in = goal.reshape(-1,2)
        goal_out = torch.flatten(goal_in, 1)
        ###### End of goal net #######
        # Combine
        fc_in = torch.cat((fusion_out, goal_out), dim=1)
        x = self.fc(fc_in)

        return x

    def forward(self, scan, semantics, goal):
        """Public entry point; see _forward_impl."""
        return self._forward_impl(scan, semantics, goal)
    #
    # end of method
    #
# end of class
|
| 590 |
+
#
|
| 591 |
+
# end of method
|
| 592 |
+
#
|
| 593 |
+
# end of class
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
# define the PyTorch MLP model
|
| 597 |
+
#
|
| 598 |
+
class CNN(nn.Module):
    """Scan-only baseline of the navigation policy: identical backbone to
    SemanticCNN but the semantic grid is dropped and the stem receives a
    single-channel 80x80 scan (cnn_in_channels should be 1 for this use --
    the default of 2 is inherited; confirm against the caller).

    NOTE(review): attribute names and construction order must stay as-is --
    they define the checkpoint state_dict keys and the RNG init order.
    """

    # function: init
    #
    # arguments: block  - residual block class (e.g. Bottleneck)
    #            layers - list with the number of blocks per stage
    #            cnn_in_channels - input channels for the stem conv
    #            num_classes - size of the output vector
    #
    # return: none
    #
    def __init__(self, block, layers, cnn_in_channels=2, num_classes=2, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):

        # inherit the superclass properties/methods
        #
        super(CNN, self).__init__()
        # define the model
        #
        ################## ped_pos net model: ###################
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: stride-1 3x3 conv keeps the 80x80 resolution.
        self.conv1 = nn.Conv2d(cnn_in_channels, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # stride-1 maxpool: smooths without downsampling
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])

        # Extra bottleneck applied after layer2 (256 -> 256 channels).
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1,1), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(256)
        )
        # Projection for the layer1 -> conv2_2 skip (128 -> 256, /2 spatially).
        self.downsample2 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 1), stride=(2,2), padding=(0, 0)),
            nn.BatchNorm2d(256)
        )
        self.relu2 = nn.ReLU(inplace=True)

        # Extra bottleneck applied after layer3 (512 -> 512 channels).
        self.conv3_2 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1,1), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(512)
        )
        # Projection for the stem -> conv3_2 skip (64 -> 512, /4 spatially).
        self.downsample3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=512, kernel_size=(1, 1), stride=(4,4), padding=(0, 0)),
            nn.BatchNorm2d(512)
        )
        self.relu3 = nn.ReLU(inplace=True)

        # self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
        #                     dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # +2 for the concatenated (x, y) goal vector.
        self.fc = nn.Linear(256 * block.expansion + 2, num_classes)

        # Standard initialization: Kaiming for convs, unit/zero for norms,
        # Xavier for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d): # add by xzt
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack `blocks` residual blocks; the first one may downsample."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # trade stride for dilation to keep the spatial resolution
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # projection shortcut when shape changes
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, scan, goal):
        """Run the scan grid through the CNN, append the goal vector,
        and map everything through the final linear layer."""
        ###### Start of fusion net ######
        scan_in = scan.reshape(-1,1,80,80)
        fusion_in = scan_in #torch.cat((scan_in, semantics_in), dim=1)

        # See note [TorchScript super()]
        x = self.conv1(fusion_in)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # long-range skip from the stem output (64 ch) to after conv3_2
        identity3 = self.downsample3(x)

        x = self.layer1(x)

        # skip from layer1 output (128 ch) to after conv2_2
        identity2 = self.downsample2(x)

        x = self.layer2(x)

        x = self.conv2_2(x)
        x += identity2
        x = self.relu2(x)


        x = self.layer3(x)
        # x = self.layer4(x)

        x = self.conv3_2(x)
        x += identity3
        x = self.relu3(x)

        # global average pool -> (batch, 512) feature vector
        x = self.avgpool(x)
        fusion_out = torch.flatten(x, 1)
        ###### End of fusion net ######

        ###### Start of goal net #######
        # goal is passed through unchanged, just flattened to (batch, 2)
        goal_in = goal.reshape(-1,2)
        goal_out = torch.flatten(goal_in, 1)
        ###### End of goal net #######
        # Combine
        fc_in = torch.cat((fusion_out, goal_out), dim=1)
        x = self.fc(fc_in)

        return x

    def forward(self, scan, goal):
        """Public entry point; see _forward_impl."""
        return self._forward_impl(scan, goal)
    #
    # end of method
    #
# end of class
#
# end of file
|
| 778 |
+
#
|
| 779 |
+
# end of method
|
| 780 |
+
#
|
| 781 |
+
# end of class
|
| 782 |
+
#
|
| 783 |
+
# end of file
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/goal_visualize.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
"""Republish the current move_base goal as an RViz sphere marker."""

import rospy
from std_msgs.msg import Header
from geometry_msgs.msg import Point, PoseStamped
from visualization_msgs.msg import Marker
from visualization_msgs.msg import MarkerArray


def goal_callback(goal_msg):
    """Convert the latest goal pose (PoseStamped) into a red translucent
    sphere Marker in the map frame and publish it on `goal_markers`."""
    # initialize header and color
    h = Header()
    h.frame_id = "map"
    h.stamp = rospy.Time.now()

    # initialize goal marker message
    goal_marker = Marker()
    goal_marker.header = h
    goal_marker.type = Marker.SPHERE
    goal_marker.action = Marker.ADD
    goal_marker.pose = goal_msg.pose
    goal_marker.scale.x = 1.8
    goal_marker.scale.y = 1.8
    goal_marker.scale.z = 1.8
    goal_marker.color.r = 1.0
    goal_marker.color.g = 0.0
    goal_marker.color.b = 0.0
    goal_marker.color.a = 0.5 # set transparency

    goal_vis_pub.publish(goal_marker)


if __name__ == '__main__':
    try:
        rospy.init_node('goal_vis')
        # BUGFIX: create the publisher BEFORE subscribing.  The original code
        # registered the subscriber first, so a goal arriving immediately
        # could invoke goal_callback while the global `goal_vis_pub` was
        # still undefined, raising a NameError in the callback thread.
        goal_vis_pub = rospy.Publisher('goal_markers', Marker, queue_size=1, latch=True)
        goal_sub = rospy.Subscriber("/move_base/current_goal", PoseStamped, goal_callback)
        # spin() simply keeps python from exiting until this node is stopped
        rospy.spin()
    except rospy.ROSInterruptException:
        pass
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/model/s3_net_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:86ffcba0092e8e20d80fc02e5e01bb675c60d0c897d8830305ecc5b8b20b6dbb
|
| 3 |
+
size 741507
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/model/semantic_cnn_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5ed1ba15e9c9df1c7c0b5f1f45545f57ad30134d5bb7673e77ee7df106358a60
|
| 3 |
+
size 28985757
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/pure_pursuit.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import rospy
|
| 4 |
+
from nav_msgs.msg import Path
|
| 5 |
+
from geometry_msgs.msg import Twist, PoseStamped
|
| 6 |
+
import tf
|
| 7 |
+
from scipy.optimize import linprog
|
| 8 |
+
from geometry_msgs.msg import Point
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import threading
|
| 13 |
+
|
| 14 |
+
class PurePursuit:
|
| 15 |
+
|
| 16 |
+
# Constructor
|
| 17 |
+
def __init__(self):
    """Set up pure-pursuit state and wire the ROS pub/sub interface.

    Subscribes to the global plan from move_base's NavfnROS planner and
    publishes the lookahead point (`cnn_goal`) and final goal
    (`final_goal`) as geometry_msgs/Point messages.
    """
    # initialize parameters
    # lookahead distance in meters (hard-coded; the commented call shows
    # it was once a ROS parameter)
    self.lookahead = 2 #rospy.get_param('~lookahead', 5.0)
    # control-loop rate in Hz
    self.rate = 20 #rospy.get_param('~rate', 20.)
    # rospy.Timer handle; stays None until the first path arrives
    self.timer = None
    self.path = None # store the path to the goal
    self.lock = threading.Lock() # lock to keep data thread safe

    # Initialize ROS objects
    #self.goal_sub = rospy.Subscriber("/move_base/current_goal", PoseStamped, self.goal_callback)
    self.path_sub = rospy.Subscriber('move_base/NavfnROS/plan', Path, self.path_callback)
    self.tf_listener = tf.TransformListener()
    #self.cmd_vel_pub = rospy.Publisher('cmd_vel', Twist, queue_size=10)
    self.cnn_goal_pub = rospy.Publisher('cnn_goal', Point, queue_size=1)#, latch=True)
    self.final_goal_pub = rospy.Publisher('final_goal', Point, queue_size=1)#, latch=True)
|
| 32 |
+
|
| 33 |
+
# Callback function for the path subscriber
|
| 34 |
+
def path_callback(self, msg):
    """Store the newest global plan; start the control timer on first path.

    Runs in the subscriber's callback thread, so self.path is written under
    the lock that timer_callback also takes before reading it.
    """
    rospy.logdebug('PurePursuit: Got path')
    # IMPROVED: `with` guarantees the lock is released even if the
    # assignment raises; the manual acquire()/release() pair could leave
    # the lock held forever and deadlock the timer thread.
    with self.lock:
        self.path = msg  # store the path in the class member
    # start the timer if this is the first path received
    if self.timer is None:
        self.start()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Start the timer that calculates command velocities
|
| 46 |
+
def start(self):
    """Start the periodic controller timer at `self.rate` Hz.

    Each tick invokes timer_callback, which computes and publishes the
    current pure-pursuit goal point.
    """
    # initialize timer for controller update
    self.timer = rospy.Timer(rospy.Duration(1./self.rate), self.timer_callback)
|
| 49 |
+
|
| 50 |
+
# Get the current pose of the robot from the tf tree
|
| 51 |
+
def get_current_pose(self):
    """Look up the robot pose (map -> base_link) from the tf tree.

    Returns:
        (x, theta) where x is a 2-element numpy array [x, y] and theta is
        the yaw in radians.  On tf failure returns ([nan, nan], nan) so
        callers can detect the invalid pose with np.isnan.
    """
    trans = rot = None
    # look up the current pose of the base_link using the tf tree
    try:
        (trans,rot) = self.tf_listener.lookupTransform('/map', '/base_link', rospy.Time(0))
    except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException):
        rospy.logwarn('Could not get robot pose')
        # NaNs signal "pose unavailable" to the caller
        return (np.array([np.nan, np.nan]), np.nan)
    x = np.array([trans[0], trans[1]])
    # only yaw matters for a planar robot; roll/pitch are discarded
    (roll, pitch, theta) = tf.transformations.euler_from_quaternion(rot)
    rospy.logdebug("x = {}, y = {}, theta = {}".format(x[0], x[1], theta))

    return (x, theta)
|
| 64 |
+
|
| 65 |
+
# Find the closest point on the current path to the point x
|
| 66 |
+
# Inputs:
|
| 67 |
+
# x = numpy array with 2 elements (x and y position of robot)
|
| 68 |
+
# seg = optional argument that selects which segment of the path to compute the closest point on
|
| 69 |
+
# Outputs:
|
| 70 |
+
# pt_min = closest point on the path to x
|
| 71 |
+
# dist_min = distance from the closest point to x
|
| 72 |
+
# seg_min = index of closest segment to x
|
| 73 |
+
def find_closest_point(self, x, seg=-1):
|
| 74 |
+
# initialize return values
|
| 75 |
+
pt_min = np.array([np.nan, np.nan])
|
| 76 |
+
dist_min = np.inf
|
| 77 |
+
seg_min = -1
|
| 78 |
+
|
| 79 |
+
# check if path has been received yet
|
| 80 |
+
if self.path is None:
|
| 81 |
+
rospy.logwarn('Pure Pursuit: No path received yet')
|
| 82 |
+
return (pt_min, dist_min, seg_min)
|
| 83 |
+
|
| 84 |
+
##### YOUR CODE STARTS HERE #####
|
| 85 |
+
if seg == -1:
|
| 86 |
+
# find closest point on entire pathd
|
| 87 |
+
for i in range(len(self.path.poses) - 1): # gets total number of segments and iterates over them all
|
| 88 |
+
(pt, dist, s) = self.find_closest_point(x, i) # find the closest point to the robot on segment i
|
| 89 |
+
if dist < dist_min: # if new point is closer than the previous best, keep it as the new best point
|
| 90 |
+
pt_min = pt
|
| 91 |
+
dist_min = dist
|
| 92 |
+
seg_min = s
|
| 93 |
+
else:
|
| 94 |
+
# find closest point on segment seg
|
| 95 |
+
# extract the start and end of segment seg from the path
|
| 96 |
+
p_start = np.array([self.path.poses[seg].pose.position.x, self.path.poses[seg].pose.position.y])
|
| 97 |
+
p_end = np.array([self.path.poses[seg+1].pose.position.x, self.path.poses[seg+1].pose.position.y])
|
| 98 |
+
|
| 99 |
+
# calculate the unit direction vector and segment length
|
| 100 |
+
v = p_end - p_start
|
| 101 |
+
length_seg = np.linalg.norm(v)
|
| 102 |
+
v = v / length_seg
|
| 103 |
+
|
| 104 |
+
# calculate projected distance
|
| 105 |
+
dist_projected = np.dot(x - p_start, v)
|
| 106 |
+
|
| 107 |
+
# find closest point on the line segment to x
|
| 108 |
+
if dist_projected < 0.:
|
| 109 |
+
pt_min = p_start
|
| 110 |
+
elif dist_projected > length_seg:
|
| 111 |
+
pt_min = p_end
|
| 112 |
+
else:
|
| 113 |
+
pt_min = p_start + dist_projected * v
|
| 114 |
+
|
| 115 |
+
# calculate other outputs
|
| 116 |
+
dist_min = np.linalg.norm(pt_min - x)
|
| 117 |
+
seg_min = seg
|
| 118 |
+
|
| 119 |
+
##### YOUR CODE ENDS HERE #####
|
| 120 |
+
return (pt_min, dist_min, seg_min)
|
| 121 |
+
|
| 122 |
+
# Find the goal point to drive the robot towards
|
| 123 |
+
# Inputs:
|
| 124 |
+
# x = numpy array with 2 elements (x and y position of robot)
|
| 125 |
+
# pt, dist, seg = outputs of find_closest_point
|
| 126 |
+
# Outputs:
|
| 127 |
+
# goal = numpy array with 2 elements (x and y position of goal)
|
| 128 |
+
def find_goal(self, x, pt, dist, seg):
    """Find the pure-pursuit goal point to drive the robot towards.

    Inputs:
      x             = numpy array with 2 elements (robot x, y position)
      pt, dist, seg = outputs of find_closest_point for x
    Outputs:
      (goal, end_goal_pos, end_goal_rot):
        goal         = numpy array with 2 elements (goal x, y), on the
                       lookahead circle (or the closest path point when the
                       robot is farther than `lookahead` from the path)
        end_goal_pos = [x, y] of the last pose of the path
        end_goal_rot = [x, y, z, w] quaternion of the last pose of the path
    """
    goal = None
    end_goal_pos = None
    end_goal_rot = None
    if dist > self.lookahead:
        # if further than lookahead from the path, drive towards the path
        goal = pt
    else:
        ##### YOUR CODE STARTS HERE #####
        # seg_max is the index of the last segment of the path
        seg_max = len(self.path.poses) - 2
        # extract the end of segment seg from the path
        p_end = np.array([self.path.poses[seg+1].pose.position.x, self.path.poses[seg+1].pose.position.y])
        # calculate the distance from x to p_end:
        dist_end = np.linalg.norm(x - p_end)

        # start from the nearest segment and iterate forward until you find either the last segment or a segment that leaves the lookahead circle
        while(dist_end < self.lookahead and seg < seg_max):
            seg = seg + 1
            # extract the end of segment seg from the path
            p_end = np.array([self.path.poses[seg+1].pose.position.x, self.path.poses[seg+1].pose.position.y])
            # calculate the distance from x to p_end:
            dist_end = np.linalg.norm(x - p_end)

        # if searched the whole path, set the goal as the end of the path
        if(dist_end < self.lookahead):
            pt = np.array([self.path.poses[seg_max+1].pose.position.x, self.path.poses[seg_max+1].pose.position.y])
        # if found a segment that leaves the circle, find the intersection with the circle
        else:
            # find the closest point:
            (pt, dist, seg) = self.find_closest_point(x, seg)
            # extract the start and end of segment seg from the path
            p_start = np.array([self.path.poses[seg].pose.position.x, self.path.poses[seg].pose.position.y])
            p_end = np.array([self.path.poses[seg+1].pose.position.x, self.path.poses[seg+1].pose.position.y])
            # calculate the unit direction vector and segment length
            v = p_end - p_start
            length_seg = np.linalg.norm(v)
            v = v / length_seg
            # decompose (x - pt) into components along and perpendicular
            # to the segment, then step along v to the point where the
            # segment crosses the lookahead circle centered on x
            dist_projected_x = np.dot(x - pt, v)
            dist_projected_y = np.linalg.norm(np.cross(x - pt, v))
            pt = pt + (np.sqrt(self.lookahead**2 - dist_projected_y**2) + dist_projected_x)*v

        goal = pt
        ##### YOUR CODE ENDS HERE #####
    # final pose of the path (position + orientation quaternion)
    end_goal_pos = [self.path.poses[-1].pose.position.x, self.path.poses[-1].pose.position.y]
    end_goal_rot = [self.path.poses[-1].pose.orientation.x, self.path.poses[-1].pose.orientation.y, \
                    self.path.poses[-1].pose.orientation.z, self.path.poses[-1].pose.orientation.w,]

    return (goal, end_goal_pos, end_goal_rot)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# function that runs every time the timer finishes to ensure that velocity commands are sent regularly
|
| 180 |
+
def timer_callback(self, event):
|
| 181 |
+
# lock the path to ensure it is not updated during processing
|
| 182 |
+
self.lock.acquire()
|
| 183 |
+
try:
|
| 184 |
+
# get current pose
|
| 185 |
+
# (x, theta) = self.get_current_pose()
|
| 186 |
+
trans = rot = None
|
| 187 |
+
# look up the current pose of the base_link using the tf tree
|
| 188 |
+
try:
|
| 189 |
+
(trans,rot) = self.tf_listener.lookupTransform('map', 'base_link', rospy.Time(0))
|
| 190 |
+
except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException):
|
| 191 |
+
rospy.logwarn('Could not get robot pose')
|
| 192 |
+
return (np.array([np.nan, np.nan]), np.nan)
|
| 193 |
+
x = np.array([trans[0], trans[1]])
|
| 194 |
+
(roll, pitch, theta) = tf.transformations.euler_from_quaternion(rot)
|
| 195 |
+
rospy.logdebug("x = {}, y = {}, theta = {}".format(x[0], x[1], theta))
|
| 196 |
+
if np.isnan(x[0]): # ensure data is valid
|
| 197 |
+
return
|
| 198 |
+
|
| 199 |
+
# find the closest point
|
| 200 |
+
(pt, dist, seg) = self.find_closest_point(x)
|
| 201 |
+
if np.isnan(pt).any(): # ensure data is valid
|
| 202 |
+
return
|
| 203 |
+
|
| 204 |
+
# find the goal point
|
| 205 |
+
(goal, end_goal_pos, end_goal_rot) = self.find_goal(x, pt, dist, seg)
|
| 206 |
+
if goal is None or end_goal_pos is None: # ensure data is valid
|
| 207 |
+
return
|
| 208 |
+
finally:
|
| 209 |
+
# ensure the lock is released
|
| 210 |
+
self.lock.release()
|
| 211 |
+
|
| 212 |
+
# transform goal to local coordinates
|
| 213 |
+
##### YOUR CODE STARTS HERE #####
|
| 214 |
+
# homogeneous transformation matrix:
|
| 215 |
+
map_T_robot = np.array([[np.cos(theta), -np.sin(theta), x[0]],
|
| 216 |
+
[np.sin(theta), np.cos(theta), x[1]],
|
| 217 |
+
[0, 0, 1]])
|
| 218 |
+
|
| 219 |
+
goal = np.matmul(np.linalg.inv(map_T_robot), np.array([[goal[0]],[goal[1]],[1]])) #np.dot(np.linalg.inv(map_T_robot), np.array([goal[0], goal[1],1])) #
|
| 220 |
+
goal = goal[0:2]
|
| 221 |
+
##### YOUR CODE ENDS HERE #####
|
| 222 |
+
|
| 223 |
+
# final relative goal:
|
| 224 |
+
relative_goal = np.matmul(np.linalg.inv(map_T_robot), np.array([[end_goal_pos[0]],[end_goal_pos[1]],[1]]))
|
| 225 |
+
# Compute the difference to the goal orientation
|
| 226 |
+
orientation_to_target = tf.transformations.quaternion_multiply(end_goal_rot, \
|
| 227 |
+
tf.transformations.quaternion_inverse(rot))
|
| 228 |
+
yaw = tf.transformations.euler_from_quaternion(orientation_to_target)[2]
|
| 229 |
+
|
| 230 |
+
# publish the cnn goal:
|
| 231 |
+
cnn_goal = Point()
|
| 232 |
+
cnn_goal.x = goal[0]
|
| 233 |
+
cnn_goal.y = goal[1]
|
| 234 |
+
cnn_goal.z = 0
|
| 235 |
+
if not np.isnan(cnn_goal.x) and not np.isnan(cnn_goal.y): # ensure data is valid
|
| 236 |
+
self.cnn_goal_pub.publish(cnn_goal)
|
| 237 |
+
|
| 238 |
+
# publish the final goal:
|
| 239 |
+
final_goal = Point()
|
| 240 |
+
final_goal.x = relative_goal[0]
|
| 241 |
+
final_goal.y = relative_goal[1]
|
| 242 |
+
final_goal.z = yaw
|
| 243 |
+
if not np.isnan(final_goal.x) and not np.isnan(final_goal.y): # ensure data is valid
|
| 244 |
+
self.final_goal_pub.publish(final_goal)
|
| 245 |
+
|
| 246 |
+
if __name__ == '__main__':
    try:
        # start the pure-pursuit ROS node
        rospy.init_node('pure_pursuit')
        PurePursuit()
        # spin() simply keeps python from exiting until this node is stopped
        rospy.spin()
    except rospy.ROSInterruptException:
        # normal shutdown (e.g. Ctrl-C) -- exit quietly
        pass
|
| 254 |
+
|
ros_deployment_ws/src/semantic_cnn_nav/semantic_cnn/src/semantic_cnn_nav_inference.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# file: $ISIP_EXP/tuh_dpath/exp_0074/scripts/decode.py
|
| 4 |
+
#
|
| 5 |
+
# revision history:
|
| 6 |
+
# 20190925 (TE): first version
|
| 7 |
+
#
|
| 8 |
+
# usage:
|
| 9 |
+
# python decode.py odir mfile data
|
| 10 |
+
#
|
| 11 |
+
# arguments:
|
| 12 |
+
# odir: the directory where the hypotheses will be stored
|
| 13 |
+
# mfile: input model file
|
| 14 |
+
# data: the input data list to be decoded
|
| 15 |
+
#
|
| 16 |
+
# This script runs the semantic CNN navigation policy for inference.
|
| 17 |
+
#------------------------------------------------------------------------------
|
| 18 |
+
|
| 19 |
+
# import modules
|
| 20 |
+
#
|
| 21 |
+
import sys
|
| 22 |
+
import os
|
| 23 |
+
|
| 24 |
+
# ros:
|
| 25 |
+
import rospy
|
| 26 |
+
#import tf
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
# custom define messages:
|
| 31 |
+
from sensor_msgs.msg import LaserScan
|
| 32 |
+
from cnn_msgs.msg import CNN_data
|
| 33 |
+
from geometry_msgs.msg import Twist
|
| 34 |
+
from geometry_msgs.msg import Point
|
| 35 |
+
from cnn_model import *
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
#-----------------------------------------------------------------------------
|
| 39 |
+
#
|
| 40 |
+
# global variables are listed here
|
| 41 |
+
#
|
| 42 |
+
#-----------------------------------------------------------------------------
|
| 43 |
+
# number of lidar rays per scan on the deployed robot (training used 1081)
POINTS = 1080 #1081
# number of historical scans carried in each /cnn_data message
SEQ_LEN = 10
# side length of the square map fed to the CNN
IMG_SIZE = 80
|
| 46 |
+
#------------------------------------------------------------------------------
|
| 47 |
+
#
|
| 48 |
+
# the main program starts here
|
| 49 |
+
#
|
| 50 |
+
#------------------------------------------------------------------------------
|
| 51 |
+
class SemanticCnnInference:
    """ROS node body for semantic-CNN navigation inference.

    Subscribes to the aggregated sensor message on /cnn_data, runs S3-Net to
    obtain a semantic label per lidar ray, builds scan/semantic history maps,
    feeds them with the sub-goal into the semantic CNN policy, and publishes
    the resulting velocity command.
    """

    # Constructor
    def __init__(self):
        # rolling sensor state, refreshed on every /cnn_data message
        self.scan_his = []        # flattened history of SEQ_LEN scans
        self.intensity_his = []   # flattened history of SEQ_LEN intensity arrays
        self.scan = []            # most recent scan
        self.goal = []            # sub-goal in the robot frame
        self.vx = 0
        self.wz = 0

        # dataset statistics used for standardization:
        # scan, sub_goal, intensity, angle of incidence:
        # [[4.518406, 8.2914915], [0.30655652, 0.5378557], [3081.8167, 1529.4413], [0.5959513, 0.4783924]]
        self.s_mu = 4.518406
        self.s_std = 8.2914915
        self.g_mu = 0.30655652
        self.g_std = 0.5378557
        self.i_mu = 3081.8167
        self.i_std = 1529.4413
        self.a_mu = 0.5959513
        self.a_std = 0.4783924

        # device and per-ray angular increment over the 270 degree FOV
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.alpha = np.ones(POINTS - 1) * ((270 * np.pi / 180) / (POINTS - 1))

        # S3-Net: semantic segmentation of the raw scan
        self.s3_net_model = S3Net(input_channels=NUM_INPUT_CHANNELS,
                                  output_channels=NUM_OUTPUT_CHANNELS)
        # moves the model to device (cpu in our case so no change):
        self.s3_net_model.to(self.device)
        self.s3_net_model.eval()
        # load the weights
        s3_net_mdl_path = rospy.get_param('~s3_net_model_file', "./model/s3_net_model.pth")
        s3_net_checkpoint = torch.load(s3_net_mdl_path, map_location=self.device)
        self.s3_net_model.load_state_dict(s3_net_checkpoint['model'])

        # semantic cnn: the velocity policy network
        self.cnn_model = SemanticCNN(Bottleneck, [2, 1, 1], cnn_in_channels=2)
        self.cnn_model.to(self.device)
        self.cnn_model.eval()
        # load the weights
        cnn_mdl_path = rospy.get_param('~semantic_cnn_model_file', "./model/semantic_cnn_model.pth")
        checkpoint = torch.load(cnn_mdl_path, map_location=self.device)
        self.cnn_model.load_state_dict(checkpoint['model'])

        print("Finish loading model.")

        # initialize ROS objects
        self.cnn_data_sub = rospy.Subscriber("/cnn_data", CNN_data, self.cnn_data_callback,
                                             queue_size=1, buff_size=2**24)
        self.cmd_vel_pub = rospy.Publisher('/navigation_velocity_smoother/raw_cmd_vel',
                                           Twist, queue_size=1, latch=False)

    # Callback function for the cnn_data subscriber
    def cnn_data_callback(self, cnn_data_msg):
        """Run one inference step per incoming CNN_data message and publish
        the resulting velocity command (if valid)."""
        self.scan_his = np.array(cnn_data_msg.scan_his)
        self.intensity_his = np.array(cnn_data_msg.intensity_his)
        self.scan = np.array(cnn_data_msg.scan)
        self.goal = np.array(cnn_data_msg.goal)

        # minimum obstacle distance over the central portion of the scan
        scan_cur = np.array(self.scan[360-40:-360+40])
        scan_cur = scan_cur[scan_cur != 0]
        if scan_cur.size != 0:
            min_scan_dist = np.amin(scan_cur)
        else:
            min_scan_dist = 10

        cmd_vel = Twist()
        if np.linalg.norm(self.goal) <= 0.9:  # goal reached: stop
            cmd_vel.linear.x = 0
            cmd_vel.angular.z = 0
        elif min_scan_dist <= 0.6:  # obstacle too close: rotate in place
            cmd_vel.linear.x = 0
            cmd_vel.angular.z = 0.7
        else:
            # create lidar historical maps: rows 2n hold the min-map,
            # rows 2n+1 hold the avg/majority-vote map
            scan_avg = np.zeros((20, IMG_SIZE))
            semantic_avg = np.zeros((20, IMG_SIZE))
            for n in range(SEQ_LEN):
                # get the scan and intensity data for history step n:
                scan = self.scan_his[n*POINTS:(n+1)*POINTS]
                intensity = self.intensity_his[n*POINTS:(n+1)*POINTS]

                # get the angle of incidence of each ray:
                b = self.scan[:-1]
                c = self.scan[1:]
                theta = angle_incidence_calculation(b, c, self.alpha)
                # last ray handled separately:
                b_last = self.scan[-2]
                c_last = self.scan[-1]
                alpha_last = (270*np.pi / 180) / (POINTS - 1)
                theta_last = angle_incidence_calculation(b_last, c_last, alpha_last, last_ray=True)
                angle_incidence = np.concatenate((theta[0], theta_last), axis=0)

                # sanitize NaN / inf values:
                scan[np.isnan(scan)] = 0.
                scan[np.isinf(scan)] = 0.
                intensity[np.isnan(intensity)] = 0.
                intensity[np.isinf(intensity)] = 0.
                angle_incidence[np.isnan(angle_incidence)] = 0.
                angle_incidence[np.isinf(angle_incidence)] = 0.

                # keep the raw (unnormalized) central range data for the map:
                scan_tmp = scan[180:-180]

                # get semantic lidar data via S3-Net:
                with torch.no_grad():
                    # standardization (mu/std from the training set):
                    scan = (scan - self.s_mu) / self.s_std
                    intensity = (intensity - self.i_mu) / self.i_std
                    angle_incidence = (angle_incidence - self.a_mu) / self.a_std

                    # feed the batch to the network:
                    semantic_scans, _, _ = self.s3_net_model(
                        torch.FloatTensor(scan).to(self.device),
                        torch.FloatTensor(intensity).to(self.device),
                        torch.FloatTensor(angle_incidence).to(self.device))
                    semantic_scans_mx = semantic_scans.argmax(dim=1)
                    semantic_label = semantic_scans_mx.squeeze()

                # get valid range data:
                semantic_label_tmp = semantic_label[180:-180]

                # fill the scan/semantic maps, 9 rays per map cell:
                for i in range(IMG_SIZE):
                    tmp = scan_tmp[i*9:(i+1)*9]
                    semantic_tmp = semantic_label_tmp[i*9:(i+1)*9]
                    # min-map (row 2n):
                    i_min = np.argmin(tmp)
                    scan_avg[2*n, i] = tmp[i_min]
                    semantic_avg[2*n, i] = semantic_tmp[i_min].cpu().numpy()
                    # avg-map (row 2n+1); semantic uses a majority vote.
                    scan_avg[2*n+1, i] = np.mean(tmp)
                    # FIX: the majority vote was written to row 2*n, clobbering
                    # the min-map label and leaving row 2*n+1 all zeros.
                    semantic_avg[2*n+1, i] = semantic_tmp.mode(dim=0).values.cpu().numpy()

            # tile the 20x80 maps to the 80x80 network input
            # (np.tile replaces np.matlib.repmat: same values, no deprecated
            # numpy.matlib dependency)
            scan_avg = scan_avg.reshape(1600)
            scan_map = np.tile(scan_avg, 4).reshape(6400)

            semantic_avg = semantic_avg.reshape(1600)
            semantic_map = np.tile(semantic_avg, 4).reshape(6400)

            # standardization: scan map
            scan_map = (scan_map - self.s_mu) / self.s_std

            # standardization: sub goal
            sub_goal = np.array(self.goal, dtype=np.float32)
            sub_goal = (sub_goal - self.g_mu) / self.g_std

            # policy inference:
            action = self.cnn_model(torch.FloatTensor(scan_map).to(self.device),
                                    torch.FloatTensor(semantic_map).to(self.device),
                                    torch.FloatTensor(sub_goal).to(self.device))

            # extract (v, w) once instead of squeezing twice
            vel = action.squeeze().data.cpu().numpy()
            cmd_vel.linear.x = vel[0]
            cmd_vel.angular.z = vel[1]

        if not np.isnan(cmd_vel.linear.x) and not np.isnan(cmd_vel.angular.z):  # ensure data is valid
            self.cmd_vel_pub.publish(cmd_vel)
|
| 244 |
+
|
| 245 |
+
# end of function
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# begin gracefully
|
| 249 |
+
#
|
| 250 |
+
|
| 251 |
+
if __name__ == '__main__':
    # start the inference node; all work happens inside the /cnn_data callback
    rospy.init_node('semantic_cnn_inference')
    drl_infe = SemanticCnnInference()
    # keep python from exiting until this node is stopped
    rospy.spin()
|
| 255 |
+
|
| 256 |
+
# end of file
|
training/model/semantic_cnn_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5ed1ba15e9c9df1c7c0b5f1f45545f57ad30134d5bb7673e77ee7df106358a60
|
| 3 |
+
size 28985757
|
training/run_eval.sh
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
#
# file: run_eval.sh
#
# Driver script that evaluates a trained semantic CNN model on a data list.
#
# To run this script, execute the following line:
#
#  run_eval.sh eval.dat
#
# The first argument ($1) is the data list to evaluate. The script calls
# scripts/decode_demo.py with the stored model and logs the results to
# output/01_decode_train.log, printing only the "Average" loss line.
#

# decode the number of command line arguments
#
NARGS=$#

if (test "$NARGS" -eq "0") then
    echo "usage: run.sh train.dat [test.dat] [eval.dat]"
    exit 1
fi

# define a base directory for the experiment
#
DL_EXP=`pwd`;
DL_SCRIPTS="$DL_EXP/scripts";
DL_OUT="$DL_EXP/output";
DL_DECODE_ODIR="$DL_OUT";

# define the output directories for training/decoding/scoring
#
#DL_TRAIN_ODIR="$DL_OUT/00_train";
DL_TRAIN_ODIR="$DL_EXP/model";
DL_MDL_PATH="$DL_TRAIN_ODIR/semantic_cnn_model.pth";

# evaluate each data set that was specified
#
echo "... starting evaluation of $1 ..."
$DL_SCRIPTS/decode_demo.py $DL_DECODE_ODIR $DL_MDL_PATH $1 | \
    tee $DL_OUT/01_decode_train.log | grep "Average"
echo "... finished evaluation of $1 ..."


echo "======= end of results ======="

#
# exit gracefully
# exit gracefully
|
training/run_train.sh
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
#
# file: run_train.sh
#
# Driver script that trains the semantic CNN model.
#
# To run this script, execute the following line:
#
#  run_train.sh train.dat [dev.dat]
#
# The first argument ($1) is the training data; the second ($2) is the
# optional dev set passed through to scripts/train.py. Training output is
# logged to output/00_train.log and the model is written under model/.
#

# decode the number of command line arguments
#
NARGS=$#

if (test "$NARGS" -eq "0") then
    echo "usage: run.sh train.dat [test.dat] [eval.dat]"
    exit 1
fi

# define a base directory for the experiment
#
DL_EXP=`pwd`;
DL_SCRIPTS="$DL_EXP/scripts";
DL_OUT="$DL_EXP/output";
DL_LABELS="$DL_EXP/labels";

# define the number of feats environment variable
#
export DL_NUM_FEATS=5 #26

# define the output directories for training/decoding/scoring
#
#DL_TRAIN_ODIR="$DL_OUT/00_train";
DL_TRAIN_ODIR="$DL_EXP/model";
DL_MDL_PATH="$DL_TRAIN_ODIR/model.pth";

DL_DECODE_ODIR="$DL_OUT/01_hyp";
DL_HYP_TRAIN="$DL_DECODE_ODIR/train_set.hyp";
DL_HYP_DEV="$DL_DECODE_ODIR/dev_set.hyp";
DL_HYP_EVAL="$DL_DECODE_ODIR/eval_set.hyp";

# create the output directory
#
#rm -fr $DL_OUT
#mkdir -p $DL_OUT

# execute training: training must always be run
#
echo "... starting training on $1 ..."
$DL_SCRIPTS/train.py $DL_MDL_PATH $1 $2 | tee $DL_OUT/00_train.log | \
    grep "reading\|Step\|Average\|Warning\|Error"
echo "... finished training on $1 ..."

#
|
training/scripts/__pycache__/model.cpython-38.pyc
ADDED
|
Binary file (8.39 kB). View file
|
|
|
training/scripts/decode_demo.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# file: $ISIP_EXP/tuh_dpath/exp_0074/scripts/decode.py
|
| 4 |
+
#
|
| 5 |
+
# revision history:
|
| 6 |
+
# 20190925 (TE): first version
|
| 7 |
+
#
|
| 8 |
+
# usage:
|
| 9 |
+
# python decode.py odir mfile data
|
| 10 |
+
#
|
| 11 |
+
# arguments:
|
| 12 |
+
# odir: the directory where the hypotheses will be stored
|
| 13 |
+
# mfile: input model file
|
| 14 |
+
# data: the input data list to be decoded
|
| 15 |
+
#
|
| 16 |
+
# This script decodes data using the trained semantic CNN model.
|
| 17 |
+
#------------------------------------------------------------------------------
|
| 18 |
+
|
| 19 |
+
# import pytorch modules
|
| 20 |
+
#
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
# import the model and all of its variables/functions
|
| 26 |
+
#
|
| 27 |
+
from model import *
|
| 28 |
+
|
| 29 |
+
# visualize:
|
| 30 |
+
import numpy as np
|
| 31 |
+
# import modules
|
| 32 |
+
#
|
| 33 |
+
import sys
|
| 34 |
+
import os
|
| 35 |
+
from sklearn.metrics import explained_variance_score, mean_squared_error
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
#-----------------------------------------------------------------------------
|
| 39 |
+
#
|
| 40 |
+
# global variables are listed here
|
| 41 |
+
#
|
| 42 |
+
#-----------------------------------------------------------------------------
|
| 43 |
+
# reference path of a stored model (the actual model path comes from argv)
model_dir = './model/semantic_SemanticCNN_model.pth' # the path of model storage: 1400 is the best one
# expected command-line arguments: odir, mfile, data
NUM_ARGS = 3
# extension for hypothesis (model output) files
HYP_EXT = ".hyp"
# extension for ground-truth files
GRT_EXT = ".grt"

# general global values
#
SPACE = " "
NEW_LINE = "\n"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
#------------------------------------------------------------------------------
|
| 56 |
+
#
|
| 57 |
+
# the main program starts here
|
| 58 |
+
#
|
| 59 |
+
#------------------------------------------------------------------------------
|
| 60 |
+
|
| 61 |
+
### explained variance:
|
| 62 |
+
def explained_variance(input, target):
|
| 63 |
+
ev = 1 - np.var(target - input) / np.var(input)
|
| 64 |
+
return ev
|
| 65 |
+
|
| 66 |
+
# function: main
|
| 67 |
+
#
|
| 68 |
+
# arguments: none
|
| 69 |
+
#
|
| 70 |
+
# return: none
|
| 71 |
+
#
|
| 72 |
+
# This method is the main function.
|
| 73 |
+
#
|
| 74 |
+
|
| 75 |
+
def main(argv):
    """Decode a test set with the trained semantic CNN and write hypotheses.

    Arguments:
        argv: [odir, mdl_path, data_list] -- output directory, trained model
              file, and the list of data to be decoded.

    Writes <basename>.hyp (model outputs) and <basename>.grt (ground truth)
    into odir, one "v w" pair per line, and prints the average MSE loss.
    """
    # ensure we have the correct amount of arguments
    #
    if(len(argv) != NUM_ARGS):
        print("usage: python nedc_train_mdl.py [MDL_PATH] [TRAIN_PATH] [DEV_PATH]")
        exit(-1)

    # define local variables:
    odir = argv[0]
    mdl_path = argv[1]
    pTest = argv[2]

    # if the odir doesn't exist, we make it
    #
    if not os.path.exists(odir):
        os.makedirs(odir)

    # hypothesis / ground-truth file names derived from the input list name
    #
    hyp_name = os.path.splitext(os.path.basename(pTest))[0] + HYP_EXT
    grt_name = os.path.splitext(os.path.basename(pTest))[0] + GRT_EXT

    # set the device to use GPU if available
    #
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # build the evaluation dataset/loader (batch size 1, original order)
    #
    eval_dataset = NavDataset(pTest,'test')
    eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=1, \
                                                  shuffle=False, drop_last=True) #, pin_memory=True)

    # instantiate a model:
    model = SemanticCNN(Bottleneck, [2, 1, 1])
    # moves the model to device (cpu in our case so no change):
    model.to(device)

    # set the model to evaluate
    #
    model.eval()

    # set the loss criterion:
    criterion = nn.MSELoss(reduction='sum')
    criterion.to(device)

    # load the weights
    #
    checkpoint = torch.load(mdl_path, map_location=device)
    model.load_state_dict(checkpoint['model'])

    # open the output files
    #
    try:
        ofile = open(os.path.join(odir, hyp_name), 'w+')
        vel_file = open(os.path.join(odir, grt_name), 'w+')
    except IOError as e:
        print(os.path.join(odir, hyp_name))
        print("[%s]: %s" % (hyp_name, e.strerror))
        exit(-1)


    # for each batch in increments of batch size:
    counter = 0
    running_loss = 0
    # number of batches (len/batch_size, truncated; loader drops the last
    # partial batch)
    num_batches = int(len(eval_dataset)/eval_dataloader.batch_size)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(eval_dataloader), total=num_batches):
            counter += 1
            # collect the samples as a batch:
            scan_maps = batch['scan_map']
            scan_maps = scan_maps.to(device)
            semantic_maps = batch['semantic_map']
            semantic_maps = semantic_maps.to(device)
            sub_goals = batch['sub_goal']
            sub_goals = sub_goals.to(device)
            velocities = batch['velocity']
            velocities = velocities.to(device)

            # feed the network the batch
            #
            output = model(scan_maps, semantic_maps, sub_goals)
            # get the loss
            #
            loss = criterion(output, velocities)
            # multiple GPUs: average the per-GPU losses
            if torch.cuda.device_count() > 1:
                loss = loss.mean()

            running_loss += loss.item()
            # write the predicted and ground-truth (v, w) pairs, one per line
            #
            ofile.write(str(float(output.data.cpu().numpy()[0,0])) + \
                        SPACE + str(float(output.data.cpu().numpy()[0,1])) + NEW_LINE)
            vel_file.write(str(float(velocities[0,0])) + \
                           SPACE + str(float(velocities[0,1])) + NEW_LINE)

    # average loss over the whole set:
    val_loss = running_loss / counter
    print('Validation set: Average loss: {:.4f}'.format(val_loss))
    # close the files
    #
    ofile.close()
    vel_file.close()


    # exit gracefully
    #
    return True
|
| 192 |
+
#
|
| 193 |
+
# end of function
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# begin gracefully
|
| 197 |
+
#
|
| 198 |
+
if __name__ == '__main__':
    # strip the program name and hand the remaining arguments to main
    main(sys.argv[1:])
|
| 200 |
+
#
|
| 201 |
+
# end of file
|
training/scripts/model.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# file: $ISIP_EXP/SOGMP/scripts/model.py
|
| 4 |
+
#
|
| 5 |
+
# revision history: xzt
|
| 6 |
+
# 20220824 (TE): first version
|
| 7 |
+
#
|
| 8 |
+
# usage:
|
| 9 |
+
#
|
| 10 |
+
# This script hold the model architecture
|
| 11 |
+
#------------------------------------------------------------------------------
|
| 12 |
+
|
| 13 |
+
# import pytorch modules
|
| 14 |
+
#
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
# import modules
|
| 21 |
+
#
|
| 22 |
+
import os
|
| 23 |
+
import random
|
| 24 |
+
|
| 25 |
+
# for reproducibility, we seed the rng
|
| 26 |
+
#
|
| 27 |
+
SEED1 = 1337          # fixed seed handed to set_seed() for reproducible runs
NEW_LINE = "\n"       # record separator used when parsing the dataset index files
|
| 29 |
+
|
| 30 |
+
#-----------------------------------------------------------------------------
|
| 31 |
+
#
|
| 32 |
+
# helper functions are listed here
|
| 33 |
+
#
|
| 34 |
+
#-----------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
# function: set_seed
|
| 37 |
+
#
|
| 38 |
+
# arguments: seed - the seed for all the rng
|
| 39 |
+
#
|
| 40 |
+
# returns: none
|
| 41 |
+
#
|
| 42 |
+
# this method seeds all the random number generators and makes
|
| 43 |
+
# the results deterministic
|
| 44 |
+
#
|
| 45 |
+
def set_seed(seed):
    """Seed every random number generator used by the pipeline.

    Seeds PyTorch (CPU and all CUDA devices), Python's ``random`` module,
    NumPy, and the interpreter hash seed, and forces cuDNN into its
    deterministic, non-benchmarking mode so repeated runs are reproducible.

    Args:
        seed: integer seed applied to all RNGs.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # deterministic cuDNN kernels; benchmark mode would pick algorithms
    # non-deterministically based on runtime timings
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    # fix: NumPy's RNG was previously left unseeded even though numpy is
    # used throughout the data pipeline
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
|
| 52 |
+
#
|
| 53 |
+
# end of method
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# function: get_data
|
| 57 |
+
#
|
| 58 |
+
# arguments: fp - file pointer
|
| 59 |
+
# num_feats - the number of features in a sample
|
| 60 |
+
#
|
| 61 |
+
# returns: data - the signals/features
|
| 62 |
+
# labels - the correct labels for them
|
| 63 |
+
#
|
| 64 |
+
# this method takes in a fp and returns the data and labels
|
| 65 |
+
POINTS = 1081   # raw LiDAR beams per scan (presumably a 270-degree scanner -- confirm sensor spec)
IMG_SIZE = 80   # side length of the square map the network consumes (80x80)
SEQ_LEN = 10    # number of consecutive frames stacked into one sample
|
| 68 |
+
class NavDataset(torch.utils.data.Dataset):
    """Navigation dataset of per-frame ``.npy`` files: LiDAR scans, semantic
    labels, local sub-goals, and velocity commands.

    Layout assumed from the code (TODO confirm against the data generator):
    ``img_path/dataset.txt`` lists one recording folder per line (folders are
    recognised by containing a ``'-'``); each folder holds a
    ``<file_name>.txt`` (e.g. ``train.txt`` / ``dev.txt``) listing that
    split's frame files, plus sub-directories ``scans_lidar/``,
    ``semantic_label/``, ``sub_goals_local/`` and ``velocities/`` with
    7-digit-numbered frames such as ``0000123.npy``.
    """

    def __init__(self, img_path, file_name):
        # initialize the data and labels
        self.npy_names = []   # one list of frame paths per recording folder
        self.lengths = []     # frame count per recording folder
        # parameters: data mean std: scan, sub_goal, intensity, angle of incidence:
        # [[4.518406, 8.2914915], [0.30655652, 0.5378557], [3081.8167, 1529.4413], [0.5959513, 0.4783924]]
        self.s_mu = 4.518406      # scan mean
        self.s_std = 8.2914915    # scan std
        self.g_mu = 0.30655652    # sub-goal mean
        self.g_std = 0.5378557    # sub-goal std
        self.i_mu = 3081.8167     # intensity mean (not used below)
        self.i_std = 1529.4413    # intensity std (not used below)
        self.a_mu = 0.5959513     # angle-of-incidence mean (not used below)
        self.a_std = 0.4783924    # angle-of-incidence std (not used below)

        # open the top-level folder index; the split files (train.txt/dev.txt)
        # live inside the folders it lists:
        fp_folder = open(img_path+'dataset.txt','r')

        # for each line of the file:
        for folder_line in fp_folder.read().split(NEW_LINE):
            # recording folders are identified by a '-' in their name
            if('-' in folder_line):
                npy_name = []
                folder_path = folder_line
                fp_file = open(img_path+folder_path+'/'+file_name+'.txt', 'r')
                for line in fp_file.read().split(NEW_LINE):
                    if('.npy' in line):
                        # NOTE(review): `line` is concatenated directly after
                        # the folder path, so it must already carry any needed
                        # separator -- confirm the index-file format
                        npy_name.append(img_path+folder_path+line)

                self.lengths.append(len(npy_name))
                self.npy_names.append(npy_name)
                # close txt file:
                fp_file.close()

        # close txt file:
        fp_folder.close()

        # total sample count plus cumulative folder lengths, used by
        # __getitem__ to map a flat index back to (folder, offset)
        self.length = np.sum(self.lengths)
        self.cumsum_lengths = np.cumsum(self.lengths).tolist()

        print("dataset length: ", self.length)


    def __len__(self):
        # total frame count over every recording folder
        return self.length

    def __getitem__(self, idx):
        """Return one sample dict with keys 'scan_map', 'semantic_map',
        'sub_goal' and 'velocity' (all FloatTensors).

        scan_map / semantic_map are flat vectors of SEQ_LEN*2*IMG_SIZE*4 =
        6400 values; the network later reshapes them to 80x80.
        """

        # ---------- FAST FOLDER LOCATE ----------
        # binary-search the cumulative lengths to find which recording folder
        # the flat index falls into, and its offset within that folder
        folder_id = np.searchsorted(self.cumsum_lengths, idx, side='right')
        start = 0 if folder_id == 0 else self.cumsum_lengths[folder_id - 1]
        data_len = self.lengths[folder_id]
        npy_list = self.npy_names[folder_id]

        # ---------- FAST FILE PARSE ----------
        # split "<root>0000123.npy" into the folder root and the 7-digit
        # frame number (the last 11 chars are "NNNNNNN.npy")
        npy_path_name = npy_list[idx - start]
        npy_path = npy_path_name[:-11]
        idx_num = int(npy_path_name[-11:-4])

        # pick a window start so that SEQ_LEN consecutive frames exist;
        # fall back to the middle of the recording for very short folders
        if idx_num + SEQ_LEN < data_len:
            idx_s = idx_num
        elif idx_num - SEQ_LEN > 0:
            idx_s = idx_num - SEQ_LEN
        else:
            idx_s = data_len // 2

        # Build ending frame filename once
        end_str = f"{idx_s + SEQ_LEN - 1:07d}.npy"

        # ---------- LOAD SUBGOAL / VELOCITY ----------
        # targets are taken at the *last* frame of the window
        sub_goal = np.load(f"{npy_path}/sub_goals_local/{end_str}")
        velocity = np.load(f"{npy_path}/velocities/{end_str}")

        # ---------- CREATE LIDAR MAP (VECTORIZED) ----------
        # scan_avg, semantic_avg shape = (SEQ_LEN*2, IMG_SIZE)
        scan_avg = np.zeros((SEQ_LEN * 2, IMG_SIZE), dtype=np.float32)
        semantic_avg = np.zeros((SEQ_LEN * 2, IMG_SIZE), dtype=np.float32)

        # Precompute slicing: IMG_SIZE non-overlapping bins of 9 beams each
        # (consumes the first 720 of the cropped beams)
        slice_idx = np.arange(0, IMG_SIZE * 9, 9).reshape(-1, 1) + np.arange(9)

        for n in range(SEQ_LEN):
            frame_idx = f"{idx_s + n:07d}.npy"

            # drop 180 beams at each end of the scan (assumes scans hold
            # POINTS = 1081 beams -- TODO confirm)
            scan = np.load(f"{npy_path}/scans_lidar/{frame_idx}")[180:-180]
            semantic = np.load(f"{npy_path}/semantic_label/{frame_idx}")[180:-180]

            # Shape after slicing = (IMG_SIZE, 9)
            bins_scan = scan[slice_idx]
            bins_sem = semantic[slice_idx]

            # ---- min map: closest return per bin, with its semantic label ----
            mins = bins_scan.min(axis=1)
            min_idx = bins_scan.argmin(axis=1)
            sem_min = bins_sem[np.arange(IMG_SIZE), min_idx]

            scan_avg[2 * n] = mins
            semantic_avg[2 * n] = sem_min

            # ---- avg map: mean range per bin ----
            scan_avg[2 * n + 1] = bins_scan.mean(axis=1)

            # ---- majority vote (FAST) ----
            # per-bin class histogram (bins are only 9 wide, so a 256-class
            # bincount per row is cheap), then the modal label
            counts = np.apply_along_axis(np.bincount, 1, bins_sem.astype(int), minlength=256)
            semantic_avg[2 * n + 1] = counts.argmax(axis=1)

        # ---------- FINAL MAP EXPANSION ----------
        # flatten to SEQ_LEN*2*IMG_SIZE = 1600 values and repeat each 4x
        # -> 6400 values, later reshaped to 80x80 by the network
        scan_map = np.repeat(scan_avg.reshape(-1), 4)
        semantic_map = np.repeat(semantic_avg.reshape(-1), 4)

        # initialize: scrub NaN/Inf from the regression targets
        sub_goal[np.isnan(sub_goal)] = 0.
        sub_goal[np.isinf(sub_goal)] = 0.

        velocity[np.isnan(velocity)] = 0.
        velocity[np.isinf(velocity)] = 0.

        # data normalization:
        # standardization: scan (mu: 4.518406, std: 8.2914915)
        scan_map = (scan_map - self.s_mu) / self.s_std

        # standardization: sub goal (mu: 0.30655652, std: 0.5378557)
        sub_goal = (sub_goal - self.g_mu) / self.g_std

        # transfer to pytorch tensor:
        scan_tensor = torch.FloatTensor(scan_map)
        semantic_tensor = torch.FloatTensor(semantic_map)
        sub_goal_tensor = torch.FloatTensor(sub_goal)
        velocity_tensor = torch.FloatTensor(velocity)

        data = {
            'scan_map': scan_tensor,
            'semantic_map': semantic_tensor,
            'sub_goal': sub_goal_tensor,
            'velocity': velocity_tensor,
        }

        return data
|
| 210 |
+
|
| 211 |
+
#
|
| 212 |
+
# end of function
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
#------------------------------------------------------------------------------
|
| 216 |
+
#
|
| 217 |
+
# ResNet blocks
|
| 218 |
+
#
|
| 219 |
+
#------------------------------------------------------------------------------
|
| 220 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution.

    Padding is tied to the dilation so the spatial size is preserved
    whenever ``stride == 1``.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
|
| 224 |
+
|
| 225 |
+
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) convolution."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
|
| 228 |
+
|
| 229 |
+
class Bottleneck(nn.Module):
    """ResNet-v1.5 bottleneck block: 1x1 reduce -> 3x3 (carries the stride)
    -> 1x1 expand, with an identity or projection shortcut.

    As in torchvision's variant (and per the note in the original ResNet
    paper https://arxiv.org/abs/1512.03385 vs.
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch),
    the downsampling stride sits on the 3x3 convolution rather than the
    first 1x1. Note the non-standard expansion factor of 2 (torchvision
    uses 4), so output channels = planes * 2.
    """

    expansion = 2  # output channels = planes * expansion (torchvision default is 4)

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        bn = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups

        # 1x1 channel reduction
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1,
                               bias=False)
        self.bn1 = bn(width)
        # 3x3 spatial convolution; any downsampling stride lives here (v1.5)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation,
                               groups=groups, bias=False)
        self.bn2 = bn(width)
        # 1x1 channel expansion back to planes * expansion
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1,
                               stride=1, bias=False)
        self.bn3 = bn(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # project the shortcut when channel count / stride changed
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
|
| 276 |
+
#
|
| 277 |
+
# end of ResNet blocks
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
#------------------------------------------------------------------------------
|
| 281 |
+
#
|
| 282 |
+
# the model is defined here
|
| 283 |
+
#
|
| 284 |
+
#------------------------------------------------------------------------------
|
| 285 |
+
|
| 286 |
+
# define the PyTorch MLP model
|
| 287 |
+
#
|
| 288 |
+
class SemanticCNN(nn.Module):
    """Truncated-ResNet policy network for semantic navigation.

    The flattened 80x80 LiDAR scan map and 80x80 semantic map are stacked
    as a 2-channel image and run through three residual stages plus two
    long-range projection shortcuts (downsample2 / downsample3); the
    globally pooled feature vector is concatenated with the 2-D sub-goal
    and mapped linearly to ``num_classes`` outputs (default 2 --
    presumably a velocity command; confirm with the training script).
    """

    def __init__(self, block, layers, num_classes=2, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        """Build the network.

        Args:
            block: residual block class (e.g. Bottleneck); its ``expansion``
                attribute sizes the stages and the final linear layer.
            layers: per-stage block counts for layer1..layer3.
            num_classes: size of the output vector.
            zero_init_residual: zero-init each block's last BN so residual
                branches start as identities (per arXiv:1706.02677).
            groups / width_per_group / replace_stride_with_dilation /
            norm_layer: standard torchvision-style ResNet knobs.
        """
        # inherit the superclass properties/methods
        #
        super(SemanticCNN, self).__init__()
        # define the model
        #
        ################## ped_pos net model: ###################
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64   # running channel count, updated by _make_layer
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # stem: 2 input channels (scan map + semantic map), stride 1 so the
        # 80x80 resolution is preserved into layer1
        self.conv1 = nn.Conv2d(2, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])

        # extra bottleneck on top of layer2's output (256 ch in/out with the
        # default Bottleneck.expansion == 2)
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1,1), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(256)
        )
        # projection shortcut from layer1's output (128 ch) down to match
        # conv2_2's output (256 ch, half resolution)
        self.downsample2 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 1), stride=(2,2), padding=(0, 0)),
            nn.BatchNorm2d(256)
        )
        self.relu2 = nn.ReLU(inplace=True)

        # extra bottleneck on top of layer3's output (512 ch in/out)
        self.conv3_2 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1,1), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(1, 1), stride=(1,1), padding=(0, 0)),
            nn.BatchNorm2d(512)
        )
        # projection shortcut from the post-stem feature map (64 ch, full
        # resolution) down to match conv3_2's output (512 ch, 1/4 resolution)
        self.downsample3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=512, kernel_size=(1, 1), stride=(4,4), padding=(0, 0)),
            nn.BatchNorm2d(512)
        )
        self.relu3 = nn.ReLU(inplace=True)

        # self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
        #                                dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # pooled features (256 * expansion = 512) + 2-D goal -> outputs
        self.fc = nn.Linear(256 * block.expansion + 2, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d): # add by xzt
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack `blocks` residual blocks; the first one may downsample and
        gets a projection shortcut when channels or stride change."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # trade the stride for dilation (keeps resolution)
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, scan, semantics, goal):
        """Forward pass. `scan` and `semantics` are flat 6400-element maps
        (reshaped to 1x80x80 each); `goal` is reshaped to (B, 2)."""
        ###### Start of fusion net ######
        scan_in = scan.reshape(-1,1,80,80)
        semantics_in = semantics.reshape(-1,1,80,80)
        # stack scan + semantics as a 2-channel image
        fusion_in = torch.cat((scan_in, semantics_in), dim=1)

        # See note [TorchScript super()]
        x = self.conv1(fusion_in)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # long shortcut from the stem output, consumed after conv3_2
        identity3 = self.downsample3(x)

        x = self.layer1(x)

        # shortcut from layer1's output, consumed after conv2_2
        identity2 = self.downsample2(x)

        x = self.layer2(x)

        x = self.conv2_2(x)
        x += identity2
        x = self.relu2(x)


        x = self.layer3(x)
        # x = self.layer4(x)

        x = self.conv3_2(x)
        x += identity3
        x = self.relu3(x)

        # global average pool -> (B, 512) feature vector
        x = self.avgpool(x)
        fusion_out = torch.flatten(x, 1)
        ###### End of fusion net ######

        ###### Start of goal net #######
        goal_in = goal.reshape(-1,2)
        goal_out = torch.flatten(goal_in, 1)
        ###### End of goal net #######
        # Combine pooled features with the goal and regress the outputs
        fc_in = torch.cat((fusion_out, goal_out), dim=1)
        x = self.fc(fc_in)

        return x

    def forward(self, scan, semantics, goal):
        return self._forward_impl(scan, semantics, goal)
|
| 469 |
+
#
|
| 470 |
+
# end of method
|
| 471 |
+
#
|
| 472 |
+
# end of class
|
| 473 |
+
|
| 474 |
+
#
|
| 475 |
+
# end of file
|
training/scripts/train.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# file: $ISIP_EXP/SOGMP/scripts/train.py
|
| 4 |
+
#
|
| 5 |
+
# revision history: xzt
|
| 6 |
+
# 20220824 (TE): first version
|
| 7 |
+
#
|
| 8 |
+
# usage:
|
| 9 |
+
# python train.py mdir train_data val_data
|
| 10 |
+
#
|
| 11 |
+
# arguments:
|
| 12 |
+
# mdir: the directory where the output model is stored
|
| 13 |
+
# train_data: the directory of training data
|
| 14 |
+
# val_data: the directory of valiation data
|
| 15 |
+
#
|
| 16 |
+
# This script trains a Semantic CNN model
|
| 17 |
+
#------------------------------------------------------------------------------
|
| 18 |
+
|
| 19 |
+
# import pytorch modules
|
| 20 |
+
#
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from torch.optim import Adam
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
# visualize:
|
| 27 |
+
from tensorboardX import SummaryWriter
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
# import the model and all of its variables/functions
|
| 31 |
+
#
|
| 32 |
+
from model import *
|
| 33 |
+
|
| 34 |
+
# import modules
|
| 35 |
+
#
|
| 36 |
+
import sys
|
| 37 |
+
import os
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
#-----------------------------------------------------------------------------
|
| 41 |
+
#
|
| 42 |
+
# global variables are listed here
|
| 43 |
+
#
|
| 44 |
+
#-----------------------------------------------------------------------------
|
| 45 |
+
|
| 46 |
+
# general global values
|
| 47 |
+
#
|
| 48 |
+
model_dir = './model/semantic_cnn_model.pth' # the path of model storage
NUM_ARGS = 3         # expected CLI arguments: model path, train dir, dev dir
NUM_EPOCHS = 4000    # total number of training epochs
BATCH_SIZE = 64      # mini-batch size for both the train and dev loaders
# keyword names used to build the Adam optimizer's kwargs dict
LEARNING_RATE = "lr"
BETAS = "betas"
EPS = "eps"
WEIGHT_DECAY = "weight_decay"

# for reproducibility, we seed the rng
#
set_seed(SEED1)
|
| 60 |
+
|
| 61 |
+
# adjust_learning_rate
|
| 62 |
+
#
|
| 63 |
+
def adjust_learning_rate(optimizer, epoch):
    """Apply the hand-tuned step-decay schedule to every param group.

    The rate steps down at fixed epoch thresholds; past epoch 48000 an
    additional decade of decay is applied for every 110000 epochs.
    """
    # (threshold, rate) pairs; later entries override earlier ones
    steps = (
        (40, 2e-4),
        (2000, 2e-5),
        (21000, 1e-5),
        (32984, 1e-6),
    )

    lr = 1e-3
    for threshold, rate in steps:
        if epoch > threshold:
            lr = rate

    if epoch > 48000:
        # extra order-of-magnitude decay per 110000 epochs
        lr = lr * (0.1 ** (epoch // 110000))

    for group in optimizer.param_groups:
        group['lr'] = lr
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# train function:
|
| 84 |
+
def train(model, dataloader, dataset, device, optimizer, criterion, epoch, epochs):
    """Run one training epoch and return the summed loss averaged per sample.

    Args:
        model: network taking (scan_maps, semantic_maps, sub_goals).
        dataloader: yields batch dicts with keys 'scan_map', 'semantic_map',
            'sub_goal', 'velocity'.
        dataset: the underlying dataset (used only for loss normalization
            and the batch-count estimate).
        device: device the batches are moved to.
        optimizer: optimizer stepped once per batch.
        criterion: loss with reduction='sum' (see main()); the return value
            is therefore running summed loss divided by len(dataset).
        epoch, epochs: current / total epoch numbers, for logging only.

    Returns:
        float: average loss per sample over the epoch.
    """
    # Set model to training mode
    model.train()
    running_loss = 0
    # number of batches yielded per epoch; the loaders are built with
    # drop_last=True, so this floor division is the exact count
    # (the original comment claimed "ceiling" -- int(a/b) floors)
    num_batches = int(len(dataset)/dataloader.batch_size)
    for i, batch in tqdm(enumerate(dataloader), total=num_batches):
        # collect the samples as a batch and move them to the device:
        scan_maps = batch['scan_map'].to(device)
        semantic_maps = batch['semantic_map'].to(device)
        sub_goals = batch['sub_goal'].to(device)
        velocities = batch['velocity'].to(device)

        # set all gradients to 0:
        optimizer.zero_grad()
        # feed the network the batch
        output = model(scan_maps, semantic_maps, sub_goals)

        # ---------------------------
        # Mask zero-velocity samples
        # ---------------------------
        mask = (velocities != 0).any(dim=1)  # (B,)

        if mask.sum() == 0:
            # keep the graph alive with a zero loss so backward() is safe
            loss = output.sum() * 0
        else:
            loss = criterion(output[mask], velocities[mask])

        # multiple GPUs: reduce any per-replica losses to a scalar BEFORE
        # backpropagation (previously .mean() ran after backward, and
        # backward was called with an explicit ones_like gradient)
        if torch.cuda.device_count() > 1:
            loss = loss.mean()

        # perform back propagation on the scalar loss:
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # display informational message
        #
        if(i % 1280 == 0):
            print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}'
                  .format(epoch, epochs, i + 1, num_batches, loss.item()))

    # per-sample average (criterion sums over elements)
    train_loss = running_loss / len(dataset)

    return train_loss
|
| 147 |
+
|
| 148 |
+
# validate function:
|
| 149 |
+
def validate(model, dataloader, dataset, device, criterion):
    """Evaluate the model over the dev set and return the per-sample loss.

    Mirrors train() but performs no optimization; runs under
    torch.no_grad() so no autograd graphs are built (previously gradients
    were tracked throughout validation, wasting memory).

    Args:
        model: network taking (scan_maps, semantic_maps, sub_goals).
        dataloader: yields batch dicts with keys 'scan_map', 'semantic_map',
            'sub_goal', 'velocity'.
        dataset: the underlying dataset (used for loss normalization and
            the batch-count estimate).
        device: device the batches are moved to.
        criterion: loss with reduction='sum' (see main()).

    Returns:
        float: average loss per sample over the dev set.
    """
    # set model to evaluation mode:
    model.eval()
    running_loss = 0
    # number of batches yielded; with drop_last=True this floor is exact
    num_batches = int(len(dataset)/dataloader.batch_size)
    # fix: disable gradient tracking for the whole evaluation pass
    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=num_batches):
            # collect the samples as a batch and move them to the device:
            scan_maps = batch['scan_map'].to(device)
            semantic_maps = batch['semantic_map'].to(device)
            sub_goals = batch['sub_goal'].to(device)
            velocities = batch['velocity'].to(device)

            # feed the network the batch
            output = model(scan_maps, semantic_maps, sub_goals)

            # ---------------------------
            # Mask zero-velocity samples
            # ---------------------------
            mask = (velocities != 0).any(dim=1)  # (B,)

            if mask.sum() == 0:
                loss = output.sum() * 0
            else:
                loss = criterion(output[mask], velocities[mask])

            # multiple GPUs: reduce any per-replica losses to a scalar
            if torch.cuda.device_count() > 1:
                loss = loss.mean()

            running_loss += loss.item()

    # per-sample average (criterion sums over elements)
    val_loss = running_loss / len(dataset)

    return val_loss
|
| 201 |
+
|
| 202 |
+
#------------------------------------------------------------------------------
|
| 203 |
+
#
|
| 204 |
+
# the main program starts here
|
| 205 |
+
#
|
| 206 |
+
#------------------------------------------------------------------------------
|
| 207 |
+
|
| 208 |
+
# function: main
|
| 209 |
+
#
|
| 210 |
+
# arguments: none
|
| 211 |
+
#
|
| 212 |
+
# return: none
|
| 213 |
+
#
|
| 214 |
+
# This method is the main function.
|
| 215 |
+
#
|
| 216 |
+
def main(argv):
|
| 217 |
+
|
| 218 |
+
# ensure we have the correct amount of arguments
|
| 219 |
+
#
|
| 220 |
+
#global cur_batch_win
|
| 221 |
+
if(len(argv) != NUM_ARGS):
|
| 222 |
+
print("usage: python nedc_train_mdl.py [MDL_PATH] [TRAIN_PATH] [DEV_PATH]")
|
| 223 |
+
exit(-1)
|
| 224 |
+
|
| 225 |
+
# define local variables
|
| 226 |
+
#
|
| 227 |
+
mdl_path = argv[0]
|
| 228 |
+
pTrain = argv[1]
|
| 229 |
+
pDev = argv[2]
|
| 230 |
+
|
| 231 |
+
# get the output directory name
|
| 232 |
+
#
|
| 233 |
+
odir = os.path.dirname(mdl_path)
|
| 234 |
+
|
| 235 |
+
# if the odir doesn't exits, we make it
|
| 236 |
+
#
|
| 237 |
+
if not os.path.exists(odir):
|
| 238 |
+
os.makedirs(odir)
|
| 239 |
+
|
| 240 |
+
# set the device to use GPU if available
|
| 241 |
+
#
|
| 242 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 243 |
+
|
| 244 |
+
### train:
|
| 245 |
+
print('...Start reading data...')
|
| 246 |
+
# get array of the data
|
| 247 |
+
# data: [[0, 1, ... 26], [27, 28, ...] ...]
|
| 248 |
+
# labels: [0, 0, 1, ...]
|
| 249 |
+
#
|
| 250 |
+
#[ped_pos_t, scan_t, goal_t, vel_t] = get_data(pTrain)
|
| 251 |
+
train_dataset = NavDataset(pTrain, 'train')
|
| 252 |
+
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, \
|
| 253 |
+
shuffle=True, drop_last=True, pin_memory=True)
|
| 254 |
+
#train_data = train_data - np.mean(train_data, axis=0)
|
| 255 |
+
|
| 256 |
+
### dev:
|
| 257 |
+
|
| 258 |
+
# get array of the data
|
| 259 |
+
# data: [[0, 1, ... 26], [27, 28, ...] ...]
|
| 260 |
+
# labels: [0, 0, 1, ...]
|
| 261 |
+
#
|
| 262 |
+
#[ped_pos_d, scan_d, goal_d, vel_d] = get_data(pDev)
|
| 263 |
+
dev_dataset = NavDataset(pDev, 'dev')
|
| 264 |
+
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, \
|
| 265 |
+
shuffle=True, drop_last=True, pin_memory=True)
|
| 266 |
+
#dev_data = dev_data - np.mean(dev_data, axis=0)
|
| 267 |
+
print('...Finish reading data...')
|
| 268 |
+
|
| 269 |
+
# instantiate a model
|
| 270 |
+
#
|
| 271 |
+
model = SemanticCNN(Bottleneck, [2, 1, 1])
|
| 272 |
+
|
| 273 |
+
# moves the model to device (cpu in our case so no change)
|
| 274 |
+
#
|
| 275 |
+
model.to(device)
|
| 276 |
+
|
| 277 |
+
# set the adam optimizer parameters
|
| 278 |
+
#
|
| 279 |
+
opt_params = { LEARNING_RATE: 0.001,
|
| 280 |
+
BETAS: (.9,0.999),
|
| 281 |
+
EPS: 1e-08,
|
| 282 |
+
WEIGHT_DECAY: .001 }
|
| 283 |
+
|
| 284 |
+
# set the loss and optimizer
|
| 285 |
+
#
|
| 286 |
+
criterion = nn.MSELoss(reduction='sum')
|
| 287 |
+
criterion.to(device)
|
| 288 |
+
|
| 289 |
+
# create an optimizer, and pass the model params to it
|
| 290 |
+
#
|
| 291 |
+
optimizer = Adam(model.parameters(), **opt_params)
|
| 292 |
+
|
| 293 |
+
# get the number of epochs to train on
|
| 294 |
+
#
|
| 295 |
+
epochs = NUM_EPOCHS
|
| 296 |
+
|
| 297 |
+
# if there are trained models, continue training:
|
| 298 |
+
if os.path.exists(mdl_path):
|
| 299 |
+
checkpoint = torch.load(mdl_path)
|
| 300 |
+
model.load_state_dict(checkpoint['model'])
|
| 301 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 302 |
+
start_epoch = checkpoint['epoch']
|
| 303 |
+
print('Load epoch {} success'.format(start_epoch))
|
| 304 |
+
else:
|
| 305 |
+
start_epoch = 0
|
| 306 |
+
print('No trained models, restart training')
|
| 307 |
+
|
| 308 |
+
# multiple GPUs:
|
| 309 |
+
if torch.cuda.device_count() > 1:
|
| 310 |
+
print("Let's use 2 of total", torch.cuda.device_count(), "GPUs!")
|
| 311 |
+
# dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
|
| 312 |
+
model = nn.DataParallel(model) #, device_ids=[0, 1])
|
| 313 |
+
|
| 314 |
+
# moves the model to device (cpu in our case so no change)
|
| 315 |
+
#
|
| 316 |
+
model.to(device)
|
| 317 |
+
|
| 318 |
+
# tensorboard writer:
|
| 319 |
+
writer = SummaryWriter('runs')
|
| 320 |
+
|
| 321 |
+
# for each epoch
|
| 322 |
+
#
|
| 323 |
+
#loss_train = []
|
| 324 |
+
#loss_vector = []
|
| 325 |
+
epoch_num = 0
|
| 326 |
+
for epoch in range(start_epoch+1, epochs):
|
| 327 |
+
|
| 328 |
+
# adjust learning rate:
|
| 329 |
+
adjust_learning_rate(optimizer, epoch)
|
| 330 |
+
################################## Train #####################################
|
| 331 |
+
# for each batch in increments of batch size
|
| 332 |
+
#
|
| 333 |
+
train_epoch_loss = train(
|
| 334 |
+
model, train_dataloader, train_dataset, device, optimizer, criterion, epoch, epochs
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
################################## Test #####################################
|
| 338 |
+
valid_epoch_loss = validate(
|
| 339 |
+
model, dev_dataloader, dev_dataset, device, criterion
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
# log the epoch loss
|
| 343 |
+
writer.add_scalar('training loss',
|
| 344 |
+
train_epoch_loss,
|
| 345 |
+
epoch)
|
| 346 |
+
writer.add_scalar('validation loss',
|
| 347 |
+
valid_epoch_loss,
|
| 348 |
+
epoch)
|
| 349 |
+
|
| 350 |
+
print('Train set: Average loss: {:.4f}'.format(train_epoch_loss))
|
| 351 |
+
print('Validation set: Average loss: {:.4f}'.format(valid_epoch_loss))
|
| 352 |
+
|
| 353 |
+
# save the model
|
| 354 |
+
#
|
| 355 |
+
if(epoch % 50 == 0):
|
| 356 |
+
if torch.cuda.device_count() > 1: # multiple GPUS:
|
| 357 |
+
state = {'model':model.module.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
|
| 358 |
+
else:
|
| 359 |
+
state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
|
| 360 |
+
path='./model/model' + str(epoch) +'.pth'
|
| 361 |
+
torch.save(state, path)
|
| 362 |
+
|
| 363 |
+
epoch_num = epoch
|
| 364 |
+
|
| 365 |
+
# save the final model
|
| 366 |
+
if torch.cuda.device_count() > 1: # multiple GPUS:
|
| 367 |
+
state = {'model':model.module.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch_num}
|
| 368 |
+
else:
|
| 369 |
+
state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch_num}
|
| 370 |
+
torch.save(state, mdl_path)
|
| 371 |
+
|
| 372 |
+
# exit gracefully
|
| 373 |
+
#
|
| 374 |
+
|
| 375 |
+
return True
|
| 376 |
+
#
|
| 377 |
+
# end of function
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# Script entry point: forward the command-line arguments (without the
# program name) to main() when run directly rather than imported.
if __name__ == '__main__':
    main(sys.argv[1:])
# end of file
|