fix model
Browse files
- app.py +16 -5
- environment.yaml +0 -161
app.py
CHANGED
|
@@ -22,6 +22,7 @@ import gradio as gr
|
|
| 22 |
import torchvision.transforms as standard_transforms
|
| 23 |
from torch.utils.data import DataLoader
|
| 24 |
from torch.utils.data import Dataset
|
|
|
|
| 25 |
|
| 26 |
warnings.filterwarnings('ignore')
|
| 27 |
|
|
@@ -96,14 +97,23 @@ with gr.Blocks() as demo:
|
|
| 96 |
We implemented an image crowd counting model with VGG16 following the paper of Song et al. (2021).
|
| 97 |
|
| 98 |
## Abstract
|
| 99 |
-
In this paper, we address the large scale variation problem in crowd counting by taking full advantage of the multi-scale feature representations in a multi-level network. We
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
## References
|
| 104 |
-
Song, Q., Wang, C., Wang, Y., Tai, Y., Wang, C., Li, J., … Ma, J. (2021). To Choose or to Fuse? Scale Selection for Crowd Counting.
|
|
|
|
| 105 |
""")
|
| 106 |
-
image_button = gr.Button("Count the Crowd!")
|
| 107 |
with gr.Row():
|
| 108 |
with gr.Column():
|
| 109 |
image_input = gr.Image(type="pil")
|
|
@@ -112,6 +122,7 @@ The code will be available at: https://github.com/TencentYoutuResearch/CrowdCoun
|
|
| 112 |
image_output = gr.Plot()
|
| 113 |
with gr.Column():
|
| 114 |
text_output = gr.Label()
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])
|
|
|
|
| 22 |
import torchvision.transforms as standard_transforms
|
| 23 |
from torch.utils.data import DataLoader
|
| 24 |
from torch.utils.data import Dataset
|
| 25 |
+
from model import SASNet
|
| 26 |
|
| 27 |
warnings.filterwarnings('ignore')
|
| 28 |
|
|
|
|
| 97 |
We implemented an image crowd counting model with VGG16 following the paper of Song et al. (2021).
|
| 98 |
|
| 99 |
## Abstract
|
| 100 |
+
In this paper, we address the large scale variation problem in crowd counting by taking full advantage of the multi-scale feature representations in a multi-level network. We
|
| 101 |
+
implement such an idea by keeping the counting error of a patch as small as possible with a proper feature level selection strategy, since a specific feature level tends to perform
|
| 102 |
+
better for a certain range of scales. However, without scale annotations, it is sub-optimal and error-prone to manually assign the predictions for heads of different scales to
|
| 103 |
+
specific feature levels. Therefore, we propose a Scale-Adaptive Selection Network (SASNet), which automatically learns the internal correspondence between the scales and the feature
|
| 104 |
+
levels. Instead of directly using the predictions from the most appropriate feature level as the final estimation, our SASNet also considers the predictions from other feature
|
| 105 |
+
levels via weighted average, which helps to mitigate the gap between discrete feature levels and continuous scale variation. Since the heads in a local patch share roughly a same
|
| 106 |
+
scale, we conduct the adaptive selection strategy in a patch-wise style. However, pixels within a patch contribute different counting errors due to the various difficulty degrees of
|
| 107 |
+
learning. Thus, we further propose a Pyramid Region Awareness Loss (PRA Loss) to recursively select the most hard sub-regions within a patch until reaching the pixel level. With
|
| 108 |
+
awareness of whether the parent patch is over-estimated or under-estimated, the fine-grained optimization with the PRA Loss for these region-aware hard pixels helps to alleviate the
|
| 109 |
+
inconsistency problem between training target and evaluation metric. The state-of-the-art results on four datasets demonstrate the superiority of our approach.
|
| 110 |
+
|
| 111 |
+
The code will be available at: https://github.com/TencentYoutuResearch/CrowdCounting-SASNet.
|
| 112 |
|
| 113 |
## References
|
| 114 |
+
Song, Q., Wang, C., Wang, Y., Tai, Y., Wang, C., Li, J., … Ma, J. (2021). To Choose or to Fuse? Scale Selection for Crowd Counting.
|
| 115 |
+
The Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI-21).
|
| 116 |
""")
|
|
|
|
| 117 |
with gr.Row():
|
| 118 |
with gr.Column():
|
| 119 |
image_input = gr.Image(type="pil")
|
|
|
|
| 122 |
image_output = gr.Plot()
|
| 123 |
with gr.Column():
|
| 124 |
text_output = gr.Label()
|
| 125 |
+
image_button = gr.Button("Count the Crowd!")
|
| 126 |
|
| 127 |
|
| 128 |
image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])
|
environment.yaml
DELETED
|
@@ -1,161 +0,0 @@
|
|
| 1 |
-
name: SASNet
|
| 2 |
-
channels:
|
| 3 |
-
- pytorch
|
| 4 |
-
- nvidia
|
| 5 |
-
- anaconda
|
| 6 |
-
- defaults
|
| 7 |
-
dependencies:
|
| 8 |
-
- _libgcc_mutex=0.1=main
|
| 9 |
-
- _openmp_mutex=5.1=1_gnu
|
| 10 |
-
- _pytorch_select=0.1=cpu_0
|
| 11 |
-
- backcall=0.2.0=pyhd3eb1b0_0
|
| 12 |
-
- blas=1.0=mkl
|
| 13 |
-
- ca-certificates=2022.07.19=h06a4308_0
|
| 14 |
-
- certifi=2022.6.15=py37h06a4308_0
|
| 15 |
-
- cffi=1.15.0=py37h7f8727e_0
|
| 16 |
-
- cuda=12.0.0=0
|
| 17 |
-
- cuda-cccl=12.0.90=0
|
| 18 |
-
- cuda-command-line-tools=12.0.0=0
|
| 19 |
-
- cuda-compiler=12.0.0=0
|
| 20 |
-
- cuda-cudart=12.0.107=0
|
| 21 |
-
- cuda-cudart-dev=12.0.107=0
|
| 22 |
-
- cuda-cudart-static=12.0.107=0
|
| 23 |
-
- cuda-cuobjdump=12.0.76=0
|
| 24 |
-
- cuda-cupti=12.0.90=0
|
| 25 |
-
- cuda-cupti-static=12.0.90=0
|
| 26 |
-
- cuda-cuxxfilt=12.0.76=0
|
| 27 |
-
- cuda-demo-suite=12.0.76=0
|
| 28 |
-
- cuda-documentation=12.0.76=0
|
| 29 |
-
- cuda-driver-dev=12.0.107=0
|
| 30 |
-
- cuda-gdb=12.0.90=0
|
| 31 |
-
- cuda-libraries=12.0.0=0
|
| 32 |
-
- cuda-libraries-dev=12.0.0=0
|
| 33 |
-
- cuda-libraries-static=12.0.0=0
|
| 34 |
-
- cuda-nsight=12.0.78=0
|
| 35 |
-
- cuda-nsight-compute=12.0.0=0
|
| 36 |
-
- cuda-nvcc=12.0.76=0
|
| 37 |
-
- cuda-nvdisasm=12.0.76=0
|
| 38 |
-
- cuda-nvml-dev=12.0.76=0
|
| 39 |
-
- cuda-nvprof=12.0.90=0
|
| 40 |
-
- cuda-nvprune=12.0.76=0
|
| 41 |
-
- cuda-nvrtc=12.0.76=0
|
| 42 |
-
- cuda-nvrtc-dev=12.0.76=0
|
| 43 |
-
- cuda-nvrtc-static=12.0.76=0
|
| 44 |
-
- cuda-nvtx=12.0.76=0
|
| 45 |
-
- cuda-nvvp=12.0.90=0
|
| 46 |
-
- cuda-opencl=12.0.76=0
|
| 47 |
-
- cuda-opencl-dev=12.0.76=0
|
| 48 |
-
- cuda-profiler-api=12.0.76=0
|
| 49 |
-
- cuda-runtime=12.0.0=0
|
| 50 |
-
- cuda-sanitizer-api=12.0.90=0
|
| 51 |
-
- cuda-toolkit=12.0.0=0
|
| 52 |
-
- cuda-tools=12.0.0=0
|
| 53 |
-
- cuda-visual-tools=12.0.0=0
|
| 54 |
-
- cudatoolkit=10.2.89=hfd86e86_1
|
| 55 |
-
- debugpy=1.5.1=py37h295c915_0
|
| 56 |
-
- decorator=5.1.1=pyhd3eb1b0_0
|
| 57 |
-
- entrypoints=0.4=py37h06a4308_0
|
| 58 |
-
- freetype=2.12.1=h4a9f257_0
|
| 59 |
-
- gds-tools=1.5.0.59=0
|
| 60 |
-
- giflib=5.2.1=h7b6447c_0
|
| 61 |
-
- intel-openmp=2022.1.0=h9e868ea_3769
|
| 62 |
-
- ipykernel=6.9.1=py37h06a4308_0
|
| 63 |
-
- ipython=7.31.1=py37h06a4308_1
|
| 64 |
-
- jedi=0.18.1=py37h06a4308_1
|
| 65 |
-
- jpeg=9e=h7f8727e_0
|
| 66 |
-
- jupyter_client=7.2.2=py37h06a4308_0
|
| 67 |
-
- jupyter_core=4.10.0=py37h06a4308_0
|
| 68 |
-
- lcms2=2.12=h3be6417_0
|
| 69 |
-
- lerc=3.0=h295c915_0
|
| 70 |
-
- libcublas=12.0.1.189=0
|
| 71 |
-
- libcublas-dev=12.0.1.189=0
|
| 72 |
-
- libcublas-static=12.0.1.189=0
|
| 73 |
-
- libcufft=11.0.0.21=0
|
| 74 |
-
- libcufft-dev=11.0.0.21=0
|
| 75 |
-
- libcufft-static=11.0.0.21=0
|
| 76 |
-
- libcufile=1.5.0.59=0
|
| 77 |
-
- libcufile-dev=1.5.0.59=0
|
| 78 |
-
- libcufile-static=1.5.0.59=0
|
| 79 |
-
- libcurand=10.3.1.50=0
|
| 80 |
-
- libcurand-dev=10.3.1.50=0
|
| 81 |
-
- libcurand-static=10.3.1.50=0
|
| 82 |
-
- libcusolver=11.4.2.57=0
|
| 83 |
-
- libcusolver-dev=11.4.2.57=0
|
| 84 |
-
- libcusolver-static=11.4.2.57=0
|
| 85 |
-
- libcusparse=12.0.0.76=0
|
| 86 |
-
- libcusparse-dev=12.0.0.76=0
|
| 87 |
-
- libcusparse-static=12.0.0.76=0
|
| 88 |
-
- libdeflate=1.8=h7f8727e_5
|
| 89 |
-
- libedit=3.1.20221030=h5eee18b_0
|
| 90 |
-
- libffi=3.2.1=hf484d3e_1007
|
| 91 |
-
- libgcc-ng=11.2.0=h1234567_1
|
| 92 |
-
- libgfortran-ng=7.5.0=ha8ba4b0_17
|
| 93 |
-
- libgfortran4=7.5.0=ha8ba4b0_17
|
| 94 |
-
- libgomp=11.2.0=h1234567_1
|
| 95 |
-
- libnpp=12.0.0.30=0
|
| 96 |
-
- libnpp-dev=12.0.0.30=0
|
| 97 |
-
- libnpp-static=12.0.0.30=0
|
| 98 |
-
- libnvjitlink=12.0.76=0
|
| 99 |
-
- libnvjitlink-dev=12.0.76=0
|
| 100 |
-
- libnvjpeg=12.0.0.28=0
|
| 101 |
-
- libnvjpeg-dev=12.0.0.28=0
|
| 102 |
-
- libnvjpeg-static=12.0.0.28=0
|
| 103 |
-
- libnvvm-samples=12.0.94=0
|
| 104 |
-
- libpng=1.6.37=hbc83047_0
|
| 105 |
-
- libsodium=1.0.18=h7b6447c_0
|
| 106 |
-
- libstdcxx-ng=11.2.0=h1234567_1
|
| 107 |
-
- libtiff=4.5.0=hecacb30_0
|
| 108 |
-
- libwebp=1.2.4=h11a3e52_0
|
| 109 |
-
- libwebp-base=1.2.4=h5eee18b_0
|
| 110 |
-
- lz4-c=1.9.4=h6a678d5_0
|
| 111 |
-
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
|
| 112 |
-
- mkl=2019.4=243
|
| 113 |
-
- mkl-service=2.3.0=py37he8ac12f_0
|
| 114 |
-
- mkl_fft=1.3.0=py37h54f3939_0
|
| 115 |
-
- mkl_random=1.1.0=py37hd6b4f25_0
|
| 116 |
-
- ncurses=6.3=h5eee18b_3
|
| 117 |
-
- nest-asyncio=1.5.5=py37h06a4308_0
|
| 118 |
-
- ninja=1.10.2=h06a4308_5
|
| 119 |
-
- ninja-base=1.10.2=hd09550d_5
|
| 120 |
-
- nsight-compute=2022.4.0.15=0
|
| 121 |
-
- numpy-base=1.17.0=py37hde5b4d6_0
|
| 122 |
-
- openssl=1.0.2u=h7b6447c_0
|
| 123 |
-
- parso=0.8.3=pyhd3eb1b0_0
|
| 124 |
-
- pexpect=4.8.0=pyhd3eb1b0_3
|
| 125 |
-
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
| 126 |
-
- pip=22.3.1=py37h06a4308_0
|
| 127 |
-
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
|
| 128 |
-
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
| 129 |
-
- pycparser=2.21=pyhd3eb1b0_0
|
| 130 |
-
- pygments=2.11.2=pyhd3eb1b0_0
|
| 131 |
-
- python=3.7.0=h6e4f718_3
|
| 132 |
-
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
| 133 |
-
- pytorch=1.5.0=py3.7_cuda10.2.89_cudnn7.6.5_0
|
| 134 |
-
- pyzmq=23.2.0=py37h6a678d5_0
|
| 135 |
-
- readline=7.0=h7b6447c_5
|
| 136 |
-
- setuptools=65.6.3=py37h06a4308_0
|
| 137 |
-
- six=1.16.0=pyhd3eb1b0_1
|
| 138 |
-
- sqlite=3.33.0=h62c20be_0
|
| 139 |
-
- tk=8.6.12=h1ccaba5_0
|
| 140 |
-
- torchvision=0.6.0=py37_cu102
|
| 141 |
-
- tornado=6.1=py37h27cfd23_0
|
| 142 |
-
- traitlets=5.1.1=pyhd3eb1b0_0
|
| 143 |
-
- wcwidth=0.2.5=pyhd3eb1b0_0
|
| 144 |
-
- wheel=0.37.1=pyhd3eb1b0_0
|
| 145 |
-
- xz=5.2.8=h5eee18b_0
|
| 146 |
-
- zeromq=4.3.4=h2531618_0
|
| 147 |
-
- zlib=1.2.13=h5eee18b_0
|
| 148 |
-
- zstd=1.5.2=ha4553b6_0
|
| 149 |
-
- pip:
|
| 150 |
-
- cached-property==1.5.2
|
| 151 |
-
- cycler==0.11.0
|
| 152 |
-
- h5py==3.1.0
|
| 153 |
-
- kiwisolver==1.4.4
|
| 154 |
-
- matplotlib==3.3.3
|
| 155 |
-
- numpy==1.19.0
|
| 156 |
-
- opencv-python==4.4.0.46
|
| 157 |
-
- pillow==8.0.1
|
| 158 |
-
- pyparsing==3.0.9
|
| 159 |
-
- scipy==1.5.4
|
| 160 |
-
- typing-extensions==4.4.0
|
| 161 |
-
prefix: /home/leuschnm/miniconda3/envs/SASNet
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|