Spaces:
Sleeping
Sleeping
Add source code and model
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +70 -0
- AIGVDet/Dockerfile +21 -0
- AIGVDet/README.md +94 -0
- AIGVDet/__init__.py +3 -0
- AIGVDet/alt_cuda_corr/correlation.cpp +54 -0
- AIGVDet/alt_cuda_corr/correlation_kernel.cu +324 -0
- AIGVDet/alt_cuda_corr/setup.py +15 -0
- AIGVDet/app.py +155 -0
- AIGVDet/checkpoints/optical.pth +3 -0
- AIGVDet/checkpoints/original.pth +3 -0
- AIGVDet/core/__init__.py +0 -0
- AIGVDet/core/corr.py +91 -0
- AIGVDet/core/datasets.py +235 -0
- AIGVDet/core/extractor.py +267 -0
- AIGVDet/core/raft.py +144 -0
- AIGVDet/core/update.py +139 -0
- AIGVDet/core/utils/__init__.py +0 -0
- AIGVDet/core/utils/augmentor.py +246 -0
- AIGVDet/core/utils/flow_viz.py +132 -0
- AIGVDet/core/utils/frame_utils.py +137 -0
- AIGVDet/core/utils/utils.py +82 -0
- AIGVDet/core/utils1/config.py +156 -0
- AIGVDet/core/utils1/datasets.py +178 -0
- AIGVDet/core/utils1/earlystop.py +46 -0
- AIGVDet/core/utils1/eval.py +66 -0
- AIGVDet/core/utils1/trainer.py +163 -0
- AIGVDet/core/utils1/utils.py +109 -0
- AIGVDet/core/utils1/utils1/config.py +157 -0
- AIGVDet/core/utils1/utils1/datasets.py +178 -0
- AIGVDet/core/utils1/utils1/earlystop.py +46 -0
- AIGVDet/core/utils1/utils1/eval.py +66 -0
- AIGVDet/core/utils1/utils1/trainer.py +169 -0
- AIGVDet/core/utils1/utils1/utils.py +109 -0
- AIGVDet/core/utils1/utils1/warmup.py +70 -0
- AIGVDet/core/utils1/warmup.py +70 -0
- AIGVDet/docker-compose.yml +17 -0
- AIGVDet/main.py +78 -0
- AIGVDet/networks/resnet.py +211 -0
- AIGVDet/raft_model/raft-things.pth +3 -0
- AIGVDet/requirements.txt +22 -0
- AIGVDet/run.py +214 -0
- AIGVDet/run.sh +1 -0
- AIGVDet/test.py +261 -0
- AIGVDet/test.sh +1 -0
- AIGVDet/train.py +87 -0
- AIGVDet/train.sh +4 -0
- api_server.py +281 -0
- checkpoints/image/best-mirage-img.pt +3 -0
- checkpoints/image/cbm-encoder.pt +3 -0
- checkpoints/image/cbm-predictor.pt +3 -0
.gitignore
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
AIGVDet/frame/
|
| 2 |
+
AIGVDet/optical_result/
|
| 3 |
+
AIGVDet/temp/
|
| 4 |
+
frame/
|
| 5 |
+
optical_result/
|
| 6 |
+
temp/
|
| 7 |
+
temp_uploads/
|
| 8 |
+
|
| 9 |
+
miragenews/encodings/crops/
|
| 10 |
+
miragenews/encodings/image/test1_nyt_mj/
|
| 11 |
+
miragenews/encodings/image/test2_bbc_dalle/
|
| 12 |
+
miragenews/encodings/image/test3_cnn_dalle/
|
| 13 |
+
miragenews/encodings/image/test4_bbc_sdxl/
|
| 14 |
+
miragenews/encodings/image/test5_cnn_sdxl/
|
| 15 |
+
miragenews/encodings/image/train/
|
| 16 |
+
miragenews/encodings/image/validation/
|
| 17 |
+
miragenews/encodings/predictions/image/cbm-encoder/test1_nyt_mj/
|
| 18 |
+
miragenews/encodings/predictions/image/cbm-encoder/test2_bbc_dalle/
|
| 19 |
+
miragenews/encodings/predictions/image/cbm-encoder/test3_cnn_dalle/
|
| 20 |
+
miragenews/encodings/predictions/image/cbm-encoder/test4_bbc_sdxl/
|
| 21 |
+
miragenews/encodings/predictions/image/cbm-encoder/test5_cnn_sdxl/
|
| 22 |
+
miragenews/encodings/predictions/image/cbm-encoder/train/
|
| 23 |
+
miragenews/encodings/predictions/image/cbm-encoder/validation/
|
| 24 |
+
miragenews/encodings/predictions/image/linear/validation/
|
| 25 |
+
miragenews/encodings/predictions/image/linear/test1_nyt_mj/
|
| 26 |
+
miragenews/encodings/predictions/image/linear/test2_bbc_dalle/
|
| 27 |
+
miragenews/encodings/predictions/image/linear/test3_cnn_dalle/
|
| 28 |
+
miragenews/encodings/predictions/image/linear/test4_bbc_sdxl/
|
| 29 |
+
miragenews/encodings/predictions/image/linear/test5_cnn_sdxl/
|
| 30 |
+
miragenews/encodings/predictions/image/linear/train/
|
| 31 |
+
miragenews/encodings/predictions/image/merged/train/
|
| 32 |
+
miragenews/encodings/predictions/image/merged/test1_nyt_mj/
|
| 33 |
+
miragenews/encodings/predictions/image/merged/test2_bbc_dalle/
|
| 34 |
+
miragenews/encodings/predictions/image/merged/test3_cnn_dalle/
|
| 35 |
+
miragenews/encodings/predictions/image/merged/test4_bbc_sdxl/
|
| 36 |
+
miragenews/encodings/predictions/image/merged/test5_cnn_sdxl/
|
| 37 |
+
miragenews/encodings/predictions/image/merged/validation/
|
| 38 |
+
|
| 39 |
+
miragenews/encodings/text/train/
|
| 40 |
+
miragenews/encodings/text/test1_nyt_mj/
|
| 41 |
+
miragenews/encodings/text/test2_bbc_dalle/
|
| 42 |
+
miragenews/encodings/text/test3_cnn_dalle/
|
| 43 |
+
miragenews/encodings/text/test4_bbc_sdxl/
|
| 44 |
+
miragenews/encodings/text/test5_cnn_sdxl/
|
| 45 |
+
miragenews/encodings/text/validation/
|
| 46 |
+
|
| 47 |
+
miragenews/encodings/predictions/text/merged/train/
|
| 48 |
+
miragenews/encodings/predictions/text/merged/test1_nyt_mj/
|
| 49 |
+
miragenews/encodings/predictions/text/merged/test2_bbc_dalle/
|
| 50 |
+
miragenews/encodings/predictions/text/merged/test3_cnn_dalle/
|
| 51 |
+
miragenews/encodings/predictions/text/merged/test4_bbc_sdxl/
|
| 52 |
+
miragenews/encodings/predictions/text/merged/test5_cnn_sdxl/
|
| 53 |
+
miragenews/encodings/predictions/text/merged/validation/
|
| 54 |
+
|
| 55 |
+
miragenews/encodings/predictions/text/tbm-encoder/train/
|
| 56 |
+
miragenews/encodings/predictions/text/tbm-encoder/test1_nyt_mj/
|
| 57 |
+
miragenews/encodings/predictions/text/tbm-encoder/test2_bbc_dalle/
|
| 58 |
+
miragenews/encodings/predictions/text/tbm-encoder/test3_cnn_dalle/
|
| 59 |
+
miragenews/encodings/predictions/text/tbm-encoder/test4_bbc_sdxl/
|
| 60 |
+
miragenews/encodings/predictions/text/tbm-encoder/test5_cnn_sdxl/
|
| 61 |
+
miragenews/encodings/predictions/text/tbm-encoder/validation/
|
| 62 |
+
|
| 63 |
+
miragenews/encodings/predictions/text/linear/train/
|
| 64 |
+
miragenews/encodings/predictions/text/linear/test1_nyt_mj/
|
| 65 |
+
miragenews/encodings/predictions/text/linear/test2_bbc_dalle/
|
| 66 |
+
miragenews/encodings/predictions/text/linear/test3_cnn_dalle/
|
| 67 |
+
miragenews/encodings/predictions/text/linear/test4_bbc_sdxl/
|
| 68 |
+
miragenews/encodings/predictions/text/linear/test5_cnn_sdxl/
|
| 69 |
+
miragenews/encodings/predictions/text/linear/validation/
|
| 70 |
+
AIGVDet/data/
|
AIGVDet/Dockerfile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime
|
| 2 |
+
|
| 3 |
+
# Install necessary OS packages for OpenCV
|
| 4 |
+
RUN apt-get update && apt-get install -y \
|
| 5 |
+
libgl1 \
|
| 6 |
+
libglib2.0-0 \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
+
|
| 9 |
+
# Set working directory
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
# Install Python dependencies
|
| 13 |
+
COPY requirements.txt .
|
| 14 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 15 |
+
|
| 16 |
+
# Copy all source code
|
| 17 |
+
COPY . .
|
| 18 |
+
|
| 19 |
+
# Default run command
|
| 20 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8003"]
|
| 21 |
+
|
AIGVDet/README.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## AIGVDet
|
| 2 |
+
An official implementation code for paper "AI-Generated Video Detection via Spatial-Temporal Anomaly Learning", PRCV 2024. This repo will provide <B>codes, trained weights, and our training datasets</B>.
|
| 3 |
+
|
| 4 |
+
## Network Architecture
|
| 5 |
+
<center> <img src="fig/NetworkArchitecture.png" alt="architecture"/> </center>
|
| 6 |
+
|
| 7 |
+
## Dataset
|
| 8 |
+
- Download the preprocessed training frames from
|
| 9 |
+
[Baiduyun Link](https://pan.baidu.com/s/17xmDyFjtcmNsoxmUeImMTQ?pwd=ra95) (extract code: ra95).
|
| 10 |
+
- Download the test videos from [Google Drive](https://drive.google.com/drive/folders/1D84SRWEJ8BK8KBpTMuGi3BUM80mW_dKb?usp=sharing).
|
| 11 |
+
|
| 12 |
+
**You are allowed to use the datasets for <B>research purpose only</B>.**
|
| 13 |
+
|
| 14 |
+
## Training
|
| 15 |
+
- Prepare for the training datasets.
|
| 16 |
+
```
|
| 17 |
+
└─data
|
| 18 |
+
├── train
|
| 19 |
+
│ └── trainset_1
|
| 20 |
+
│ ├── 0_real
|
| 21 |
+
│ │ ├── video_00000
|
| 22 |
+
│ │ │ ├── 00000.png
|
| 23 |
+
│ │ │ └── ...
|
| 24 |
+
│ │ └── ...
|
| 25 |
+
│ └── 1_fake
|
| 26 |
+
│ ├── video_00000
|
| 27 |
+
│ │ ├── 00000.png
|
| 28 |
+
│ │ └── ...
|
| 29 |
+
│ └── ...
|
| 30 |
+
├── val
|
| 31 |
+
│ └── val_set_1
|
| 32 |
+
│ ├── 0_real
|
| 33 |
+
│ │ ├── video_00000
|
| 34 |
+
│ │ │ ├── 00000.png
|
| 35 |
+
│ │ │ └── ...
|
| 36 |
+
│ │ └── ...
|
| 37 |
+
│ └── 1_fake
|
| 38 |
+
│ ├── video_00000
|
| 39 |
+
│ │ ├── 00000.png
|
| 40 |
+
│ │ └── ...
|
| 41 |
+
│ └── ...
|
| 42 |
+
└── test
|
| 43 |
+
└── testset_1
|
| 44 |
+
├── 0_real
|
| 45 |
+
│ ├── video_00000
|
| 46 |
+
│ │ ├── 00000.png
|
| 47 |
+
│ │ └── ...
|
| 48 |
+
│ └── ...
|
| 49 |
+
└── 1_fake
|
| 50 |
+
├── video_00000
|
| 51 |
+
│ ├── 00000.png
|
| 52 |
+
│ └── ...
|
| 53 |
+
└── ...
|
| 54 |
+
|
| 55 |
+
```
|
| 56 |
+
- Modify configuration file in `core/utils1/config.py`.
|
| 57 |
+
- Train the Spatial Domain Detector with the RGB frames.
|
| 58 |
+
```
|
| 59 |
+
python train.py --gpus 0 --exp_name TRAIN_RGB_BRANCH datasets RGB_TRAINSET datasets_test RGB_TESTSET
|
| 60 |
+
```
|
| 61 |
+
- Train the Optical Flow Detector with the optical flow frames.
|
| 62 |
+
```
|
| 63 |
+
python train.py --gpus 0 --exp_name TRAIN_OF_BRANCH datasets OpticalFlow_TRAINSET datasets_test OpticalFlow_TESTSET
|
| 64 |
+
```
|
| 65 |
+
## Testing
|
| 66 |
+
Download the weights from [Google Drive Link](https://drive.google.com/drive/folders/18JO_YxOEqwJYfbVvy308XjoV-N6fE4yP?usp=share_link) and move it into the `checkpoints/`.
|
| 67 |
+
|
| 68 |
+
- Run on a dataset.
|
| 69 |
+
Prepare the RGB frames and the optical flow maps.
|
| 70 |
+
```
|
| 71 |
+
python test.py -fop "data/test/hotshot" -mop "checkpoints/optical_aug.pth" -for "data/test/original/hotshot" -mor "checkpoints/original_aug.pth" -e "data/results/T2V/hotshot.csv" -ef "data/results/frame/T2V/hotshot.csv" -t 0.5
|
| 72 |
+
```
|
| 73 |
+
- Run on a video.
|
| 74 |
+
Download the RAFT model weights from [Google Drive Link](https://drive.google.com/file/d/1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM/view) and move it into the `raft_model/`.
|
| 75 |
+
```
|
| 76 |
+
python demo.py --use_cpu --path "video/000000.mp4" --folder_original_path "frame/000000" --folder_optical_flow_path "optical_result/000000" -mop "checkpoints/optical.pth" -mor "checkpoints/original.pth"
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## License
|
| 80 |
+
The code and dataset is released only for academic research. Commercial usage is strictly prohibited.
|
| 81 |
+
|
| 82 |
+
## Citation
|
| 83 |
+
```
|
| 84 |
+
@article{AIGVDet24,
|
| 85 |
+
author = {Jianfa Bai and Man Lin and Gang Cao and Zijie Lou},
|
| 86 |
+
title = {{AI-generated video detection via spatial-temporal anomaly learning}},
|
| 87 |
+
conference = {The 7th Chinese Conference on Pattern Recognition and Computer Vision (PRCV)},
|
| 88 |
+
year = {2024},}
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
## Contact
|
| 92 |
+
If you have any questions, please contact us(lyan924@cuc.edu.cn).
|
| 93 |
+
|
| 94 |
+
|
AIGVDet/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .main import run_video_to_json
|
| 2 |
+
|
| 3 |
+
__all__ = ["run_video_to_json"]
|
AIGVDet/alt_cuda_corr/correlation.cpp
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include <vector>
|
| 3 |
+
|
| 4 |
+
// CUDA forward declarations
|
| 5 |
+
std::vector<torch::Tensor> corr_cuda_forward(
|
| 6 |
+
torch::Tensor fmap1,
|
| 7 |
+
torch::Tensor fmap2,
|
| 8 |
+
torch::Tensor coords,
|
| 9 |
+
int radius);
|
| 10 |
+
|
| 11 |
+
std::vector<torch::Tensor> corr_cuda_backward(
|
| 12 |
+
torch::Tensor fmap1,
|
| 13 |
+
torch::Tensor fmap2,
|
| 14 |
+
torch::Tensor coords,
|
| 15 |
+
torch::Tensor corr_grad,
|
| 16 |
+
int radius);
|
| 17 |
+
|
| 18 |
+
// C++ interface
|
| 19 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
| 20 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 21 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
| 22 |
+
|
| 23 |
+
std::vector<torch::Tensor> corr_forward(
|
| 24 |
+
torch::Tensor fmap1,
|
| 25 |
+
torch::Tensor fmap2,
|
| 26 |
+
torch::Tensor coords,
|
| 27 |
+
int radius) {
|
| 28 |
+
CHECK_INPUT(fmap1);
|
| 29 |
+
CHECK_INPUT(fmap2);
|
| 30 |
+
CHECK_INPUT(coords);
|
| 31 |
+
|
| 32 |
+
return corr_cuda_forward(fmap1, fmap2, coords, radius);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
std::vector<torch::Tensor> corr_backward(
|
| 37 |
+
torch::Tensor fmap1,
|
| 38 |
+
torch::Tensor fmap2,
|
| 39 |
+
torch::Tensor coords,
|
| 40 |
+
torch::Tensor corr_grad,
|
| 41 |
+
int radius) {
|
| 42 |
+
CHECK_INPUT(fmap1);
|
| 43 |
+
CHECK_INPUT(fmap2);
|
| 44 |
+
CHECK_INPUT(coords);
|
| 45 |
+
CHECK_INPUT(corr_grad);
|
| 46 |
+
|
| 47 |
+
return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 52 |
+
m.def("forward", &corr_forward, "CORR forward");
|
| 53 |
+
m.def("backward", &corr_backward, "CORR backward");
|
| 54 |
+
}
|
AIGVDet/alt_cuda_corr/correlation_kernel.cu
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include <cuda.h>
|
| 3 |
+
#include <cuda_runtime.h>
|
| 4 |
+
#include <vector>
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
#define BLOCK_H 4
|
| 8 |
+
#define BLOCK_W 8
|
| 9 |
+
#define BLOCK_HW BLOCK_H * BLOCK_W
|
| 10 |
+
#define CHANNEL_STRIDE 32
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
__forceinline__ __device__
|
| 14 |
+
bool within_bounds(int h, int w, int H, int W) {
|
| 15 |
+
return h >= 0 && h < H && w >= 0 && w < W;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
template <typename scalar_t>
|
| 19 |
+
__global__ void corr_forward_kernel(
|
| 20 |
+
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
|
| 21 |
+
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
|
| 22 |
+
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
|
| 23 |
+
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
|
| 24 |
+
int r)
|
| 25 |
+
{
|
| 26 |
+
const int b = blockIdx.x;
|
| 27 |
+
const int h0 = blockIdx.y * blockDim.x;
|
| 28 |
+
const int w0 = blockIdx.z * blockDim.y;
|
| 29 |
+
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
|
| 30 |
+
|
| 31 |
+
const int H1 = fmap1.size(1);
|
| 32 |
+
const int W1 = fmap1.size(2);
|
| 33 |
+
const int H2 = fmap2.size(1);
|
| 34 |
+
const int W2 = fmap2.size(2);
|
| 35 |
+
const int N = coords.size(1);
|
| 36 |
+
const int C = fmap1.size(3);
|
| 37 |
+
|
| 38 |
+
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
|
| 39 |
+
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
|
| 40 |
+
__shared__ scalar_t x2s[BLOCK_HW];
|
| 41 |
+
__shared__ scalar_t y2s[BLOCK_HW];
|
| 42 |
+
|
| 43 |
+
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
|
| 44 |
+
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
| 45 |
+
int k1 = k + tid / CHANNEL_STRIDE;
|
| 46 |
+
int h1 = h0 + k1 / BLOCK_W;
|
| 47 |
+
int w1 = w0 + k1 % BLOCK_W;
|
| 48 |
+
int c1 = tid % CHANNEL_STRIDE;
|
| 49 |
+
|
| 50 |
+
auto fptr = fmap1[b][h1][w1];
|
| 51 |
+
if (within_bounds(h1, w1, H1, W1))
|
| 52 |
+
f1[c1][k1] = fptr[c+c1];
|
| 53 |
+
else
|
| 54 |
+
f1[c1][k1] = 0.0;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
__syncthreads();
|
| 58 |
+
|
| 59 |
+
for (int n=0; n<N; n++) {
|
| 60 |
+
int h1 = h0 + threadIdx.x;
|
| 61 |
+
int w1 = w0 + threadIdx.y;
|
| 62 |
+
if (within_bounds(h1, w1, H1, W1)) {
|
| 63 |
+
x2s[tid] = coords[b][n][h1][w1][0];
|
| 64 |
+
y2s[tid] = coords[b][n][h1][w1][1];
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
scalar_t dx = x2s[tid] - floor(x2s[tid]);
|
| 68 |
+
scalar_t dy = y2s[tid] - floor(y2s[tid]);
|
| 69 |
+
|
| 70 |
+
int rd = 2*r + 1;
|
| 71 |
+
for (int iy=0; iy<rd+1; iy++) {
|
| 72 |
+
for (int ix=0; ix<rd+1; ix++) {
|
| 73 |
+
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
| 74 |
+
int k1 = k + tid / CHANNEL_STRIDE;
|
| 75 |
+
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
|
| 76 |
+
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
|
| 77 |
+
int c2 = tid % CHANNEL_STRIDE;
|
| 78 |
+
|
| 79 |
+
auto fptr = fmap2[b][h2][w2];
|
| 80 |
+
if (within_bounds(h2, w2, H2, W2))
|
| 81 |
+
f2[c2][k1] = fptr[c+c2];
|
| 82 |
+
else
|
| 83 |
+
f2[c2][k1] = 0.0;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
__syncthreads();
|
| 87 |
+
|
| 88 |
+
scalar_t s = 0.0;
|
| 89 |
+
for (int k=0; k<CHANNEL_STRIDE; k++)
|
| 90 |
+
s += f1[k][tid] * f2[k][tid];
|
| 91 |
+
|
| 92 |
+
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
|
| 93 |
+
int ix_ne = H1*W1*((iy-1) + rd*ix);
|
| 94 |
+
int ix_sw = H1*W1*(iy + rd*(ix-1));
|
| 95 |
+
int ix_se = H1*W1*(iy + rd*ix);
|
| 96 |
+
|
| 97 |
+
scalar_t nw = s * (dy) * (dx);
|
| 98 |
+
scalar_t ne = s * (dy) * (1-dx);
|
| 99 |
+
scalar_t sw = s * (1-dy) * (dx);
|
| 100 |
+
scalar_t se = s * (1-dy) * (1-dx);
|
| 101 |
+
|
| 102 |
+
scalar_t* corr_ptr = &corr[b][n][0][h1][w1];
|
| 103 |
+
|
| 104 |
+
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
|
| 105 |
+
*(corr_ptr + ix_nw) += nw;
|
| 106 |
+
|
| 107 |
+
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
|
| 108 |
+
*(corr_ptr + ix_ne) += ne;
|
| 109 |
+
|
| 110 |
+
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
|
| 111 |
+
*(corr_ptr + ix_sw) += sw;
|
| 112 |
+
|
| 113 |
+
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
|
| 114 |
+
*(corr_ptr + ix_se) += se;
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
template <typename scalar_t>
|
| 123 |
+
__global__ void corr_backward_kernel(
|
| 124 |
+
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
|
| 125 |
+
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
|
| 126 |
+
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
|
| 127 |
+
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
|
| 128 |
+
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
|
| 129 |
+
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
|
| 130 |
+
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
|
| 131 |
+
int r)
|
| 132 |
+
{
|
| 133 |
+
|
| 134 |
+
const int b = blockIdx.x;
|
| 135 |
+
const int h0 = blockIdx.y * blockDim.x;
|
| 136 |
+
const int w0 = blockIdx.z * blockDim.y;
|
| 137 |
+
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
|
| 138 |
+
|
| 139 |
+
const int H1 = fmap1.size(1);
|
| 140 |
+
const int W1 = fmap1.size(2);
|
| 141 |
+
const int H2 = fmap2.size(1);
|
| 142 |
+
const int W2 = fmap2.size(2);
|
| 143 |
+
const int N = coords.size(1);
|
| 144 |
+
const int C = fmap1.size(3);
|
| 145 |
+
|
| 146 |
+
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
|
| 147 |
+
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
|
| 148 |
+
|
| 149 |
+
__shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
|
| 150 |
+
__shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
|
| 151 |
+
|
| 152 |
+
__shared__ scalar_t x2s[BLOCK_HW];
|
| 153 |
+
__shared__ scalar_t y2s[BLOCK_HW];
|
| 154 |
+
|
| 155 |
+
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
|
| 156 |
+
|
| 157 |
+
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
| 158 |
+
int k1 = k + tid / CHANNEL_STRIDE;
|
| 159 |
+
int h1 = h0 + k1 / BLOCK_W;
|
| 160 |
+
int w1 = w0 + k1 % BLOCK_W;
|
| 161 |
+
int c1 = tid % CHANNEL_STRIDE;
|
| 162 |
+
|
| 163 |
+
auto fptr = fmap1[b][h1][w1];
|
| 164 |
+
if (within_bounds(h1, w1, H1, W1))
|
| 165 |
+
f1[c1][k1] = fptr[c+c1];
|
| 166 |
+
else
|
| 167 |
+
f1[c1][k1] = 0.0;
|
| 168 |
+
|
| 169 |
+
f1_grad[c1][k1] = 0.0;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
__syncthreads();
|
| 173 |
+
|
| 174 |
+
int h1 = h0 + threadIdx.x;
|
| 175 |
+
int w1 = w0 + threadIdx.y;
|
| 176 |
+
|
| 177 |
+
for (int n=0; n<N; n++) {
|
| 178 |
+
x2s[tid] = coords[b][n][h1][w1][0];
|
| 179 |
+
y2s[tid] = coords[b][n][h1][w1][1];
|
| 180 |
+
|
| 181 |
+
scalar_t dx = x2s[tid] - floor(x2s[tid]);
|
| 182 |
+
scalar_t dy = y2s[tid] - floor(y2s[tid]);
|
| 183 |
+
|
| 184 |
+
int rd = 2*r + 1;
|
| 185 |
+
for (int iy=0; iy<rd+1; iy++) {
|
| 186 |
+
for (int ix=0; ix<rd+1; ix++) {
|
| 187 |
+
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
| 188 |
+
int k1 = k + tid / CHANNEL_STRIDE;
|
| 189 |
+
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
|
| 190 |
+
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
|
| 191 |
+
int c2 = tid % CHANNEL_STRIDE;
|
| 192 |
+
|
| 193 |
+
auto fptr = fmap2[b][h2][w2];
|
| 194 |
+
if (within_bounds(h2, w2, H2, W2))
|
| 195 |
+
f2[c2][k1] = fptr[c+c2];
|
| 196 |
+
else
|
| 197 |
+
f2[c2][k1] = 0.0;
|
| 198 |
+
|
| 199 |
+
f2_grad[c2][k1] = 0.0;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
__syncthreads();
|
| 203 |
+
|
| 204 |
+
const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
|
| 205 |
+
scalar_t g = 0.0;
|
| 206 |
+
|
| 207 |
+
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
|
| 208 |
+
int ix_ne = H1*W1*((iy-1) + rd*ix);
|
| 209 |
+
int ix_sw = H1*W1*(iy + rd*(ix-1));
|
| 210 |
+
int ix_se = H1*W1*(iy + rd*ix);
|
| 211 |
+
|
| 212 |
+
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
|
| 213 |
+
g += *(grad_ptr + ix_nw) * dy * dx;
|
| 214 |
+
|
| 215 |
+
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
|
| 216 |
+
g += *(grad_ptr + ix_ne) * dy * (1-dx);
|
| 217 |
+
|
| 218 |
+
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
|
| 219 |
+
g += *(grad_ptr + ix_sw) * (1-dy) * dx;
|
| 220 |
+
|
| 221 |
+
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
|
| 222 |
+
g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
|
| 223 |
+
|
| 224 |
+
for (int k=0; k<CHANNEL_STRIDE; k++) {
|
| 225 |
+
f1_grad[k][tid] += g * f2[k][tid];
|
| 226 |
+
f2_grad[k][tid] += g * f1[k][tid];
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
| 230 |
+
int k1 = k + tid / CHANNEL_STRIDE;
|
| 231 |
+
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
|
| 232 |
+
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
|
| 233 |
+
int c2 = tid % CHANNEL_STRIDE;
|
| 234 |
+
|
| 235 |
+
scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
|
| 236 |
+
if (within_bounds(h2, w2, H2, W2))
|
| 237 |
+
atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
|
| 238 |
+
}
|
| 239 |
+
}
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
__syncthreads();
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
| 246 |
+
int k1 = k + tid / CHANNEL_STRIDE;
|
| 247 |
+
int h1 = h0 + k1 / BLOCK_W;
|
| 248 |
+
int w1 = w0 + k1 % BLOCK_W;
|
| 249 |
+
int c1 = tid % CHANNEL_STRIDE;
|
| 250 |
+
|
| 251 |
+
scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
|
| 252 |
+
if (within_bounds(h1, w1, H1, W1))
|
| 253 |
+
fptr[c+c1] += f1_grad[c1][k1];
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
std::vector<torch::Tensor> corr_cuda_forward(
|
| 261 |
+
torch::Tensor fmap1,
|
| 262 |
+
torch::Tensor fmap2,
|
| 263 |
+
torch::Tensor coords,
|
| 264 |
+
int radius)
|
| 265 |
+
{
|
| 266 |
+
const auto B = coords.size(0);
|
| 267 |
+
const auto N = coords.size(1);
|
| 268 |
+
const auto H = coords.size(2);
|
| 269 |
+
const auto W = coords.size(3);
|
| 270 |
+
|
| 271 |
+
const auto rd = 2 * radius + 1;
|
| 272 |
+
auto opts = fmap1.options();
|
| 273 |
+
auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
|
| 274 |
+
|
| 275 |
+
const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
|
| 276 |
+
const dim3 threads(BLOCK_H, BLOCK_W);
|
| 277 |
+
|
| 278 |
+
corr_forward_kernel<float><<<blocks, threads>>>(
|
| 279 |
+
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
| 280 |
+
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
| 281 |
+
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
| 282 |
+
corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
| 283 |
+
radius);
|
| 284 |
+
|
| 285 |
+
return {corr};
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
std::vector<torch::Tensor> corr_cuda_backward(
|
| 289 |
+
torch::Tensor fmap1,
|
| 290 |
+
torch::Tensor fmap2,
|
| 291 |
+
torch::Tensor coords,
|
| 292 |
+
torch::Tensor corr_grad,
|
| 293 |
+
int radius)
|
| 294 |
+
{
|
| 295 |
+
const auto B = coords.size(0);
|
| 296 |
+
const auto N = coords.size(1);
|
| 297 |
+
|
| 298 |
+
const auto H1 = fmap1.size(1);
|
| 299 |
+
const auto W1 = fmap1.size(2);
|
| 300 |
+
const auto H2 = fmap2.size(1);
|
| 301 |
+
const auto W2 = fmap2.size(2);
|
| 302 |
+
const auto C = fmap1.size(3);
|
| 303 |
+
|
| 304 |
+
auto opts = fmap1.options();
|
| 305 |
+
auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
|
| 306 |
+
auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
|
| 307 |
+
auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
|
| 308 |
+
|
| 309 |
+
const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
|
| 310 |
+
const dim3 threads(BLOCK_H, BLOCK_W);
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
corr_backward_kernel<float><<<blocks, threads>>>(
|
| 314 |
+
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
| 315 |
+
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
| 316 |
+
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
| 317 |
+
corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
| 318 |
+
fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
| 319 |
+
fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
| 320 |
+
coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
| 321 |
+
radius);
|
| 322 |
+
|
| 323 |
+
return {fmap1_grad, fmap2_grad, coords_grad};
|
| 324 |
+
}
|
AIGVDet/alt_cuda_corr/setup.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup
|
| 2 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
setup(
|
| 6 |
+
name='correlation',
|
| 7 |
+
ext_modules=[
|
| 8 |
+
CUDAExtension('alt_cuda_corr',
|
| 9 |
+
sources=['correlation.cpp', 'correlation_kernel.cu'],
|
| 10 |
+
extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
|
| 11 |
+
],
|
| 12 |
+
cmdclass={
|
| 13 |
+
'build_ext': BuildExtension
|
| 14 |
+
})
|
| 15 |
+
|
AIGVDet/app.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import uuid
|
| 5 |
+
import time
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException, status
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
from run import RUN
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PredictionResponse(BaseModel):
|
| 14 |
+
authenticity_assessment: str
|
| 15 |
+
verification_tools_methods: str
|
| 16 |
+
synthetic_type: str
|
| 17 |
+
other_artifacts: str
|
| 18 |
+
|
| 19 |
+
app = FastAPI(
|
| 20 |
+
title="Video Authenticity API",
|
| 21 |
+
description="Detect authentic vs synthetic video using deepfake detector.",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def predict(input_path: str) -> dict:
    """Run the AIGVDet deepfake pipeline on a local video file.

    Frames and optical-flow images are written to per-video scratch folders
    (``frame/<id>`` and ``optical_result/<id>``), scored via ``RUN``, and the
    folders are removed again afterwards.  Pipeline failures never propagate:
    they are reported through the ``other_artifacts`` field instead.

    Args:
        input_path: Path to an existing video file on local disk.

    Returns:
        A dict matching the ``PredictionResponse`` schema.

    Raises:
        FileNotFoundError: If ``input_path`` does not point to a file.
    """
    # Default payload returned if the pipeline fails before a verdict is set.
    result_data = {
        "authenticity_assessment": "Error",
        "verification_tools_methods": "Deepfake Detection Model (Optical Flow + Frame Analysis)",
        "synthetic_type": "N/A",
        "other_artifacts": "Analysis failed."
    }

    if not os.path.isfile(input_path):
        raise FileNotFoundError(f"File not found: {input_path}")

    video_name = os.path.basename(input_path)
    video_id = os.path.splitext(video_name)[0]

    # Per-video scratch folders; removed in the finally block below.
    folder_original = f"frame/{video_id}"
    folder_optical = f"optical_result/{video_id}"

    # CLI-style argument list consumed by RUN() (see run.py).
    args = [
        "--path", input_path,
        "--folder_original_path", folder_original,
        "--folder_optical_flow_path", folder_optical,
        "--model_optical_flow_path", "checkpoints/optical.pth",
        "--model_original_path", "checkpoints/original.pth",
    ]

    try:
        start_time = time.perf_counter()

        output = RUN(args)

        elapsed = time.perf_counter() - start_time
        print(f"⏱️ [run_AIGVDetection] Service call took {elapsed:.2f} seconds")

        # RUN() is assumed to return a mapping with "real_score" and
        # "fake_score" entries -- TODO confirm against run.py.
        real_score = float(output.get("real_score", 0.0))
        fake_score = float(output.get("fake_score", 0.0))

        # Simple argmax over the two scores decides the verdict; ties count
        # as fake because the comparison is strict.
        likely_authentic = real_score > fake_score

        if likely_authentic:
            assessment = f"REAL (Authentic) | Confidence: {real_score:.4f}"

            analysis_text = (
                "Our algorithms observed consistent and natural motion patterns across frames. "
                "Inter-frame motion analysis indicates that objects maintain physical trajectories consistent with real-world recording, "
                "without the jitter or warping artifacts typically associated with generative AI."
            )
            syn_type = "N/A"
        else:
            assessment = f"🤖 NOT REAL (Fake/Synthetic) | Confidence: {fake_score:.4f}"

            analysis_text = (
                "Our algorithms have detected asynchronous and inconsistent movement between frames. "
                "Upon conducting inter-frame motion analysis, we observed that objects and details within the video "
                "fail to maintain natural motion trajectories. These anomalies—such as sudden velocity shifts, "
                "subtle per-frame distortions, or motion vectors that defy physical laws—are characteristic indicators "
                "typically found in AI-generated videos."
            )
            syn_type = "Video Deepfake / AI Generated"

        tools = "Deepfake Detector (Optical Flow + CNN Frame Analysis)"

        artifacts = (
            f"{analysis_text}"
        )

        result_data = {
            "authenticity_assessment": assessment,
            "verification_tools_methods": tools,
            "synthetic_type": syn_type,
            "other_artifacts": artifacts
        }

    except Exception as e:
        import traceback
        traceback.print_exc()
        # Surface the failure to the client instead of raising.
        result_data["other_artifacts"] = f"Error during processing: {str(e)}"

    finally:
        # Best-effort cleanup of the per-video scratch folders.
        for folder in [folder_original, folder_optical]:
            try:
                if os.path.exists(folder):
                    shutil.rmtree(folder)
            except Exception as cleanup_error:
                print(f"Error deleting folder {folder}: {cleanup_error}")

    return result_data
|
| 110 |
+
|
| 111 |
+
# --- API ENDPOINT ---
|
| 112 |
+
@app.post("/predict", response_model=PredictionResponse)
async def predict_endpoint(file: UploadFile = File(...)):
    """Accept an uploaded video, run the detector, and return the verdict.

    The working folders ("frame", "optical_result", "uploads") are wiped at
    the start of every request, so the endpoint is effectively
    single-request: concurrent calls would clobber each other's files.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    temp_filepath = None
    try:
        # Reset working directories left over from a previous request.
        for parent in ["frame", "optical_result", "uploads"]:
            if os.path.exists(parent):
                print(f"🧹 Cleaning folder: {parent}")
                for item in os.listdir(parent):
                    path = os.path.join(parent, item)
                    try:
                        if os.path.isfile(path) or os.path.islink(path):
                            os.remove(path)
                        elif os.path.isdir(path):
                            shutil.rmtree(path)
                    except Exception as e:
                        print(f"⚠️ Cannot delete {path}: {e}")
            else:
                os.makedirs(parent)
                print(f"📁 Created new folder: {parent}")

        # Persist the upload under a unique name to avoid collisions.
        temp_filename = f"uploads_{uuid.uuid4().hex}_{file.filename}"
        os.makedirs("uploads", exist_ok=True)
        temp_filepath = os.path.join("uploads", temp_filename)

        with open(temp_filepath, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        return predict(temp_filepath)

    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {e}",
        )

    finally:
        # Fix: the original removed the temp file only on the success path,
        # leaking the upload whenever predict() or the cleanup loop raised.
        if temp_filepath is not None and os.path.exists(temp_filepath):
            os.remove(temp_filepath)
|
| 152 |
+
|
| 153 |
+
if __name__ == "__main__":
    import uvicorn

    # Fix: the original hard-coded port=80022, which is outside the valid
    # TCP port range (1-65535) and makes uvicorn fail at startup.  Read the
    # port from the environment (7860 is the Hugging Face Spaces default)
    # so deployments can override it.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))
|
AIGVDet/checkpoints/optical.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:23a167ba7adccb6421fd0eddd81fbfc69519d2a97a343dbf0d5da894b9893b19
|
| 3 |
+
size 282581704
|
AIGVDet/checkpoints/original.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c2df6f477a590f4b24b14eac9654b868a9f178312795df20749274a502d59bdd
|
| 3 |
+
size 282581704
|
AIGVDet/core/__init__.py
ADDED
|
File without changes
|
AIGVDet/core/corr.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from .utils.utils import bilinear_sampler, coords_grid
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from .. import alt_cuda_corr
|
| 7 |
+
except:
|
| 8 |
+
# alt_cuda_corr is not compiled
|
| 9 |
+
pass
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CorrBlock:
    """Multi-scale all-pairs correlation volume used by RAFT.

    Builds a pyramid of correlation volumes between two feature maps and,
    on each call, samples a (2r+1)x(2r+1) window around the current flow
    estimate at every pyramid level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        # fmap1/fmap2 are assumed to be (batch, dim, ht, wd) feature maps --
        # inferred from the unpacking in corr(); verify against the caller.
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        # Level 0 is full resolution; each further level halves the target
        # (h2, w2) resolution via average pooling.
        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        # coords: coordinate grid displaced by the current flow estimate,
        # presumably (batch, 2, h1, w1) -- TODO confirm against raft.py.
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # Relative offsets for the (2r+1)x(2r+1) lookup window.
            dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            # Coordinates are scaled by 1/2^i to index the pooled level i.
            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        # Concatenate all levels along the channel axis:
        # (batch, num_levels*(2r+1)^2, h1, w1).
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Dense all-pairs correlation, normalized by sqrt(feature dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AlternateCorrBlock:
    """Memory-efficient correlation via the compiled alt_cuda_corr extension.

    Instead of precomputing the full all-pairs volume, keeps a pyramid of
    pooled feature maps and computes local correlations on demand.
    NOTE(review): requires the alt_cuda_corr CUDA extension; if its import
    failed at module load time, __call__ raises NameError.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        # Pyramid of (fmap1, fmap2) pairs, halved per level by avg-pooling.
        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # fmap1 always comes from full resolution; fmap2 from level i.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            # Coordinates scaled into the pooled level's resolution.
            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        # Same sqrt(dim) normalization as CorrBlock.corr.
        return corr / torch.sqrt(torch.tensor(dim).float())
|
AIGVDet/core/datasets.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils.data as data
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import math
|
| 10 |
+
import random
|
| 11 |
+
from glob import glob
|
| 12 |
+
import os.path as osp
|
| 13 |
+
|
| 14 |
+
from utils import frame_utils
|
| 15 |
+
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class FlowDataset(data.Dataset):
    """Base class for optical-flow datasets: pairs of frames plus flow.

    Subclasses populate image_list (pairs of image paths), flow_list
    (ground-truth flow paths), and extra_info (per-sample metadata).
    When sparse=True, flow files also carry a validity mask (KITTI-style).
    """

    def __init__(self, aug_params=None, sparse=False):
        self.augmentor = None
        self.sparse = sparse
        if aug_params is not None:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        # is_test: return images + metadata only (no ground truth).
        self.is_test = False
        # init_seed: set once per dataloader worker to decorrelate RNG.
        self.init_seed = False
        self.flow_list = []
        self.image_list = []
        self.extra_info = []

    def __getitem__(self, index):

        # Test mode: no flow available, return the raw frame pair.
        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        # Seed each dataloader worker once so augmentations differ between
        # workers but stay deterministic within one.
        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        # Wrap the index so oversampled datasets (see __rmul__) work.
        index = index % len(self.image_list)
        valid = None
        if self.sparse:
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
        else:
            flow = frame_utils.read_gen(self.flow_list[index])

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        flow = np.array(flow).astype(np.float32)
        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        # grayscale images
        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        # Without an explicit mask, treat implausibly large flow (>=1000px)
        # as invalid.
        if valid is not None:
            valid = torch.from_numpy(valid)
        else:
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)

        return img1, img2, flow, valid.float()


    def __rmul__(self, v):
        # k * dataset repeats the sample lists k times (oversampling).
        self.flow_list = v * self.flow_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MpiSintel(FlowDataset):
    """MPI-Sintel optical flow dataset (consecutive-frame pairs per scene)."""

    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
        super(MpiSintel, self).__init__(aug_params)
        flow_root = osp.join(root, split, 'flow')
        image_root = osp.join(root, split, dstype)

        # The test split ships no ground-truth flow; switch __getitem__ to
        # inference mode.
        self.is_test = split == 'test'

        for scene in os.listdir(image_root):
            frames = sorted(glob(osp.join(image_root, scene, '*.png')))
            # Each consecutive pair of frames is one sample.
            for frame_id, pair in enumerate(zip(frames[:-1], frames[1:])):
                self.image_list.append(list(pair))
                self.extra_info.append((scene, frame_id))  # scene and frame_id

            if split != 'test':
                self.flow_list.extend(sorted(glob(osp.join(flow_root, scene, '*.flo'))))
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class FlyingChairs(FlowDataset):
    """FlyingChairs dataset.

    'chairs_split.txt' (read from the current working directory) assigns each
    sample to the training (1) or validation (2) split.
    """

    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(aug_params)

        images = sorted(glob(osp.join(root, '*.ppm')))
        flows = sorted(glob(osp.join(root, '*.flo')))
        # Two .ppm frames per .flo ground-truth file.
        assert (len(images)//2 == len(flows))

        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
        wanted = {'training': 1, 'validation': 2}.get(split)
        for i, flow_file in enumerate(flows):
            if wanted is not None and split_list[i] == wanted:
                self.flow_list.append(flow_file)
                self.image_list.append([images[2*i], images[2*i+1]])
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class FlyingThings3D(FlowDataset):
    """FlyingThings3D TRAIN split (left camera, both flow directions).

    'into_future' pairs (frame_i, frame_i+1) with the forward flow;
    'into_past' pairs (frame_i+1, frame_i) with the backward flow.
    """

    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
        super(FlyingThings3D, self).__init__(aug_params)

        for cam in ['left']:
            for direction in ['into_future', 'into_past']:
                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])

                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])

                # Relies on image and flow scene directories sorting into the
                # same order.
                for idir, fdir in zip(image_dirs, flow_dirs):
                    images = sorted(glob(osp.join(idir, '*.png')) )
                    flows = sorted(glob(osp.join(fdir, '*.pfm')) )
                    for i in range(len(flows)-1):
                        if direction == 'into_future':
                            self.image_list += [ [images[i], images[i+1]] ]
                            self.flow_list += [ flows[i] ]
                        elif direction == 'into_past':
                            # Reversed frame order; note the flows[i+1] offset.
                            self.image_list += [ [images[i+1], images[i]] ]
                            self.flow_list += [ flows[i+1] ]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class KITTI(FlowDataset):
    """KITTI optical flow benchmark; sparse ground truth with validity masks."""

    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
        super(KITTI, self).__init__(aug_params, sparse=True)
        # The testing split has no ground truth.
        self.is_test = split == 'testing'

        split_root = osp.join(root, split)
        first_frames = sorted(glob(osp.join(split_root, 'image_2/*_10.png')))
        second_frames = sorted(glob(osp.join(split_root, 'image_2/*_11.png')))

        for first, second in zip(first_frames, second_frames):
            # Basename (e.g. "000000_10.png") identifies the sample.
            self.extra_info.append([first.split('/')[-1]])
            self.image_list.append([first, second])

        if split == 'training':
            self.flow_list = sorted(glob(osp.join(split_root, 'flow_occ/*_10.png')))
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class HD1K(FlowDataset):
    """HD1K benchmark; sequences are numbered and enumerated until one
    yields no flow files."""

    def __init__(self, aug_params=None, root='datasets/HD1k'):
        super(HD1K, self).__init__(aug_params, sparse=True)

        seq_ix = 0
        while True:
            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))

            # No flow files for this sequence index means we ran off the end.
            if not flows:
                break

            for i, flow_file in enumerate(flows[:-1]):
                self.flow_list.append(flow_file)
                self.image_list.append([images[i], images[i+1]])

            seq_ix += 1
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
    """Create the training data loader for the stage given by args.stage.

    Stages follow the RAFT curriculum: 'chairs', 'things', 'sintel' (with a
    dataset mixture controlled by TRAIN_DS), or 'kitti'.  The integer
    multipliers oversample the smaller datasets (see FlowDataset.__rmul__).

    NOTE(review): an unrecognized args.stage leaves train_dataset unbound and
    the DataLoader call below raises NameError -- confirm upstream validation.
    """

    if args.stage == 'chairs':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
        train_dataset = FlyingChairs(aug_params, split='training')

    elif args.stage == 'things':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.stage == 'sintel':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
        sintel_final = MpiSintel(aug_params, split='training', dstype='final')

        if TRAIN_DS == 'C+T+K+S+H':
            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

        elif TRAIN_DS == 'C+T+K/S':
            train_dataset = 100*sintel_clean + 100*sintel_final + things

    elif args.stage == 'kitti':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
        train_dataset = KITTI(aug_params, split='training')

    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
|
| 235 |
+
|
AIGVDet/core/extractor.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ResidualBlock(nn.Module):
    """Two 3x3-conv residual block with a configurable normalization layer.

    norm_fn selects 'group' | 'batch' | 'instance' | 'none'.  When stride > 1
    a 1x1 conv plus norm3 downsamples the skip connection to match.
    NOTE(review): an unrecognized norm_fn leaves the norm attributes
    undefined, failing later with AttributeError -- confirm callers only pass
    the four supported values.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        # GroupNorm group count scales with width (8 channels per group).
        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            # Identity placeholders keep forward() uniform.
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        # Match the skip connection's resolution/channels when downsampling.
        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The inner convolutions run at planes//4 channels; norm_fn selects
    'group' | 'batch' | 'instance' | 'none'.  When stride > 1 a 1x1 conv plus
    norm4 downsamples the skip connection.
    NOTE(review): an unrecognized norm_fn leaves the norm attributes
    undefined -- confirm callers only pass the four supported values.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        # Group count derived from the *output* width (8 channels per group).
        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            # Identity placeholders keep forward() uniform.
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        # Match the skip connection's resolution/channels when downsampling.
        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
|
| 117 |
+
|
| 118 |
+
class BasicEncoder(nn.Module):
    """RAFT feature/context encoder.

    A stride-2 stem followed by three residual stages (two of them stride-2)
    downsamples the input 8x overall, then a 1x1 conv projects to output_dim
    channels.  Accepts a single image tensor or a list/tuple of two, which
    are batched together and split again on return.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; unit-weight/zero-bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may downsample.  Tracks
        # self.in_planes so consecutive calls chain correctly.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        # Dropout only during training (mirrors nn.Dropout2d semantics).
        if self.training and self.dropout is not None:
            x = self.dropout(x)

        # Split back into the two input feature maps.
        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class SmallEncoder(nn.Module):
    """Compact variant of BasicEncoder used by RAFT-small.

    Same 8x-downsampling layout but with narrower widths (32/64/96) and
    BottleneckBlocks instead of ResidualBlocks.  Accepts a single image
    tensor or a list/tuple of two, batched together and split on return.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Final 1x1 projection to the requested channel count.
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit-weight/zero-bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; only the first may downsample.
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        # Dropout only during training.
        if self.training and self.dropout is not None:
            x = self.dropout(x)

        # Split back into the two input feature maps.
        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
AIGVDet/core/raft.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from .update import BasicUpdateBlock, SmallUpdateBlock
|
| 7 |
+
from .extractor import BasicEncoder, SmallEncoder
|
| 8 |
+
from .corr import CorrBlock, AlternateCorrBlock
|
| 9 |
+
from .utils.utils import bilinear_sampler, coords_grid, upflow8
|
| 10 |
+
|
| 11 |
+
# Use native mixed-precision autocast when available.
try:
    autocast = torch.cuda.amp.autocast
# NOTE(review): bare except is deliberately broad here (compat shim), but
# it also hides e.g. a missing `torch` import — confirm that is acceptable.
except:
    # dummy autocast for PyTorch < 1.6
    # No-op context manager with the same call signature, so call sites can
    # use `with autocast(enabled=...):` unconditionally.
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class RAFT(nn.Module):
    """RAFT optical-flow network.

    Builds feature/context encoders and a recurrent update block, then
    iteratively refines a flow field at 1/8 resolution and upsamples it.
    """

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        # The "small" variant uses thinner hidden/context features and a
        # smaller correlation lookup radius.
        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        # Default optional flags when absent from the args namespace.
        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        # Put all BatchNorm layers in eval mode (freeze running statistics),
        # typically used when fine-tuning with small batches.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        # Both grids live at 1/8 of the input resolution.
        coords0 = coords_grid(N, H//8, W//8, device=img.device)
        coords1 = coords_grid(N, H//8, W//8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        # mask holds 9 weights per output pixel in each 8x8 upsampled cell.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Multiply flow by 8 because displacement scales with resolution.
        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        # Convex combination of each pixel's 3x3 coarse neighborhood.
        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)


    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        # Normalize uint8-range images [0, 255] to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        # Correlation is computed in fp32 regardless of autocast.
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            # Split context features into GRU hidden state and input features.
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            # Truncate gradients between refinement iterations.
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                # Small model has no learned mask; use bilinear 8x upsampling.
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            # Return (low-res flow, final upsampled flow) only.
            return coords1 - coords0, flow_up

        return flow_predictions
|
AIGVDet/core/update.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FlowHead(nn.Module):
    """Two-layer convolutional head mapping a GRU hidden state to a
    2-channel flow update (dx, dy)."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # conv -> ReLU -> conv, written with an explicit intermediate.
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
| 15 |
+
|
| 16 |
+
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """Return the updated hidden state given hidden `h` and input `x`."""
        stacked = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(stacked))   # update gate z
        reset = torch.sigmoid(self.convr(stacked))    # reset gate r
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        # Blend old state and candidate according to the update gate.
        return (1 - update) * h + update * candidate
|
| 32 |
+
|
| 33 |
+
class SepConvGRU(nn.Module):
    """GRU cell with separable 1x5 / 5x1 convolutions: one horizontal
    gating pass followed by one vertical gating pass."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))

    def _gate(self, h, x, convz, convr, convq):
        # One standard GRU gating step using the supplied conv layers.
        stacked = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(stacked))
        r = torch.sigmoid(convr(stacked))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        """Return the updated hidden state after horizontal then vertical passes."""
        h = self._gate(h, x, self.convz1, self.convr1, self.convq1)  # horizontal
        h = self._gate(h, x, self.convz2, self.convr2, self.convq2)  # vertical
        return h
|
| 61 |
+
|
| 62 |
+
class SmallMotionEncoder(nn.Module):
    """Encodes correlation features and the current flow into motion
    features; the raw flow is re-appended to the output channels."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        # Channels of the stacked correlation pyramid lookup.
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Keep the raw 2-channel flow alongside the learned features.
        return torch.cat([merged, flow], dim=1)
|
| 78 |
+
|
| 79 |
+
class BasicMotionEncoder(nn.Module):
    """Larger motion encoder for the full RAFT model; output is 128
    channels (126 learned features plus the raw 2-channel flow)."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        # Channels of the stacked correlation pyramid lookup.
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        merged = torch.cat([corr_feat, flow_feat], dim=1)
        features = F.relu(self.conv(merged))
        # Keep the raw 2-channel flow alongside the learned features.
        return torch.cat([features, flow], dim=1)
|
| 98 |
+
|
| 99 |
+
class SmallUpdateBlock(nn.Module):
    """Recurrent update block for the small RAFT variant.

    Returns (new hidden state, None, flow delta); the None slot signals
    that no learned upsampling mask is produced.
    """

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        # Fuse context features with encoded motion, step the GRU, then
        # predict the residual flow from the new hidden state.
        gru_input = torch.cat([inp, self.encoder(flow, corr)], dim=1)
        net = self.gru(net, gru_input)
        return net, None, self.flow_head(net)
|
| 113 |
+
|
| 114 |
+
class BasicUpdateBlock(nn.Module):
    """Recurrent update block for the full RAFT model.

    Returns (new hidden state, upsampling mask, flow delta). The mask has
    64*9 channels: 9 convex weights for each cell of the 8x8 upsampling.
    """

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        # Fuse context features with encoded motion and step the GRU.
        gru_input = torch.cat([inp, self.encoder(flow, corr)], dim=1)
        net = self.gru(net, gru_input)

        delta_flow = self.flow_head(net)
        # scale mask to balence gradients
        upsample_mask = .25 * self.mask(net)
        return net, upsample_mask, delta_flow
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
AIGVDet/core/utils/__init__.py
ADDED
|
File without changes
|
AIGVDet/core/utils/augmentor.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
import math
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
cv2.setNumThreads(0)
|
| 8 |
+
cv2.ocl.setUseOpenCL(False)
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torchvision.transforms import ColorJitter
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FlowAugmentor:
    """Joint photometric + spatial augmentation for dense optical-flow pairs.

    NOTE(review): each method draws from np.random in a fixed order;
    reordering any call would change the augmentation stream for a seed.
    """

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):

        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric: jitter each frame independently
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric: stack vertically so both frames get identical jitter
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=[50, 100]):
        """ Occlusion augmentation """
        # Erase 1-2 random rectangles in img2 with its mean color, simulating
        # occlusions.  NOTE(review): `bounds` is a mutable default argument,
        # but it is never mutated so this is harmless.

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(bounds[0], bounds[1])
                dy = np.random.randint(bounds[0], bounds[1])
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def spatial_transform(self, img1, img2, flow):
        # randomly sample scale
        ht, wd = img1.shape[:2]
        # smallest scale that still leaves room for the crop (+8 margin)
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            # anisotropic stretch on top of the base scale
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            # flow vectors scale with the image
            flow = flow * [scale_x, scale_y]

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                # horizontal flow component flips sign
                flow = flow[:, ::-1] * [-1.0, 1.0]

            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                # vertical flow component flips sign
                flow = flow[::-1, :] * [1.0, -1.0]

        # NOTE(review): assumes the (possibly rescaled) image is strictly
        # larger than crop_size; randint raises otherwise — confirm callers.
        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow

    def __call__(self, img1, img2, flow):
        # Full pipeline: color -> eraser -> spatial; outputs are contiguous.
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow
|
| 121 |
+
|
| 122 |
+
class SparseFlowAugmentor:
    """Augmentation for sparse ground-truth flow (e.g. KITTI), where a
    validity mask accompanies the flow field.

    NOTE(review): like FlowAugmentor, methods consume np.random in a fixed
    order; restructuring would change the augmentation stream for a seed.
    """

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        # Symmetric jitter only: stack frames so both get identical changes.
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        # Erase 1-2 random rectangles in img2 with its mean color.
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        """Resize a sparse flow map by scattering valid samples onto the
        rescaled grid (nearest pixel), instead of interpolating."""
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        # keep only valid flow samples
        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        # scale positions and flow vectors to the new resolution
        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        # drop samples that fall outside the new image bounds
        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        # randomly sample scale

        ht, wd = img1.shape[:2]
        # smallest scale that still leaves room for the crop (+1 margin)
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if self.do_flip:
            if np.random.rand() < 0.5: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                # horizontal flow component flips sign
                flow = flow[:, ::-1] * [-1.0, 1.0]
                valid = valid[:, ::-1]

        # Sample the crop origin with extra margin, then clamp into bounds.
        margin_y = 20
        margin_x = 50

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        return img1, img2, flow, valid


    def __call__(self, img1, img2, flow, valid):
        # Full pipeline: color -> eraser -> spatial; outputs are contiguous.
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)
        valid = np.ascontiguousarray(valid)

        return img1, img2, flow, valid
|
AIGVDet/core/utils/flow_viz.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# MIT License
|
| 5 |
+
#
|
| 6 |
+
# Copyright (c) 2018 Tom Runia
|
| 7 |
+
#
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to conditions.
|
| 14 |
+
#
|
| 15 |
+
# Author: Tom Runia
|
| 16 |
+
# Date Created: 2018-08-03
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel of shape (55, 3)
    """
    # Each segment holds one RGB channel at 255 while ramping another
    # linearly up (+1) or down (-1): (length, held channel, ramp channel, dir).
    segments = [
        (15, 0, 1, +1),  # RY
        (6, 1, 0, -1),   # YG
        (4, 1, 2, +1),   # GC
        (11, 2, 1, -1),  # CB
        (13, 2, 0, +1),  # BM
        (6, 0, 2, -1),   # MR
    ]

    total = sum(length for length, _, _, _ in segments)
    wheel = np.zeros((total, 3))

    row = 0
    for length, hold, ramp, direction in segments:
        ramp_vals = np.floor(255 * np.arange(length) / length)
        wheel[row:row + length, hold] = 255
        if direction > 0:
            wheel[row:row + length, ramp] = ramp_vals
        else:
            wheel[row:row + length, ramp] = 255 - ramp_vals
        row += length

    return wheel
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    out = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    wheel = make_colorwheel()  # shape [55x3]
    ncols = wheel.shape[0]

    # Angle determines hue (position on the wheel), magnitude saturation.
    magnitude = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, -u) / np.pi
    fk = (angle + 1) / 2 * (ncols - 1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0  # wrap around the wheel
    f = fk - k0

    for ch in range(wheel.shape[1]):
        col_vals = wheel[:, ch]
        c0 = col_vals[k0] / 255.0
        c1 = col_vals[k1] / 255.0
        # Interpolate between the two nearest wheel entries.
        col = (1 - f) * c0 + f * c1

        in_range = (magnitude <= 1)
        col[in_range] = 1 - magnitude[in_range] * (1 - col[in_range])
        col[~in_range] = col[~in_range] * 0.75  # out of range

        # Note the 2-ch => BGR instead of RGB
        target = 2 - ch if convert_to_bgr else ch
        out[:, :, target] = np.floor(255 * col)

    return out
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'

    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)

    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]

    # Normalize by the largest magnitude so colors span the full wheel.
    rad_max = np.max(np.sqrt(np.square(u) + np.square(v)))
    scale = rad_max + 1e-5  # epsilon avoids division by zero on all-zero flow
    return flow_uv_to_colors(u / scale, v / scale, convert_to_bgr)
|
AIGVDet/core/utils/frame_utils.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from os.path import *
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
cv2.setNumThreads(0)
|
| 8 |
+
cv2.ocl.setUseOpenCL(False)
|
| 9 |
+
|
| 10 |
+
TAG_CHAR = np.array([202021.25], np.float32)
|
| 11 |
+
|
| 12 |
+
def readFlow(fn):
    """ Read .flo file in Middlebury format"""
    # Adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if 202021.25 != magic:
            print('Magic number incorrect. Invalid .flo file')
            return None
        w = np.fromfile(f, np.int32, count=1)
        h = np.fromfile(f, np.int32, count=1)
        data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
        # Reshape to (rows, columns, bands); original layout on disk is
        # row-major with u/v interleaved per pixel.
        return np.resize(data, (int(h), int(w), 2))
|
| 32 |
+
|
| 33 |
+
def readPFM(file):
    """Read a PFM image file.

    Args:
        file: Path to the .pfm file.

    Returns:
        np.ndarray: Data of shape (H, W, 3) for color ('PF') files or
        (H, W) for grayscale ('Pf') files, flipped to top-down row order.

    Raises:
        Exception: If the header is not 'PF'/'Pf' or the dimension line is
            malformed.
    """
    # Use a context manager so the handle is closed even on parse errors
    # (the original opened the file and never closed it).
    with open(file, 'rb') as f:
        header = f.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', f.readline())
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        # A negative scale marks little-endian data; the magnitude itself
        # is a brightness scale the original never applied, so it is only
        # used here to pick the byte order.
        scale = float(f.readline().rstrip())
        endian = '<' if scale < 0 else '>'

        data = np.fromfile(f, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # PFM stores rows bottom-up; flip to the conventional top-down order.
    data = np.flipud(data)
    return data
|
| 69 |
+
|
| 70 |
+
def writeFlow(filename, uv, v=None):
    """ Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.
    """
    nBands = 2

    if v is None:
        assert (uv.ndim == 3)
        assert (uv.shape[2] == 2)
        u = uv[:, :, 0]
        v = uv[:, :, 1]
    else:
        u = uv

    assert (u.shape == v.shape)
    height, width = u.shape
    # Context manager guarantees the handle is closed even if a write
    # fails partway (the original relied on an explicit close()).
    with open(filename, 'wb') as f:
        # Header: magic tag, then width and height as int32.
        f.write(TAG_CHAR)
        np.array(width).astype(np.int32).tofile(f)
        np.array(height).astype(np.int32).tofile(f)
        # Interleave u and v per pixel: [u0, v0, u1, v1, ...] row-major.
        tmp = np.zeros((height, width * nBands))
        tmp[:, np.arange(width) * 2] = u
        tmp[:, np.arange(width) * 2 + 1] = v
        tmp.astype(np.float32).tofile(f)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def readFlowKITTI(filename):
    """Read a KITTI 16-bit flow PNG.

    Returns (flow, valid): flow in pixels, valid is the third channel mask.
    """
    raw = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
    raw = raw[:, :, ::-1].astype(np.float32)  # BGR -> RGB channel order
    flow, valid = raw[:, :, :2], raw[:, :, 2]
    # KITTI encoding: stored_value = 64 * flow + 2^15
    flow = (flow - 2**15) / 64.0
    return flow, valid
|
| 108 |
+
|
| 109 |
+
def readDispKITTI(filename):
    """Read a KITTI disparity PNG and convert it to a flow-style array
    (negative horizontal displacement, zero vertical component)."""
    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
    valid = disp > 0.0
    flow = np.stack([-disp, np.zeros_like(disp)], -1)
    return flow, valid
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def writeFlowKITTI(filename, uv):
    """Write flow in the KITTI 16-bit PNG encoding
    (stored_value = 64 * flow + 2^15, third channel = all-ones valid mask)."""
    encoded = 64.0 * uv + 2**15
    ones = np.ones([encoded.shape[0], encoded.shape[1], 1])
    packed = np.concatenate([encoded, ones], axis=-1).astype(np.uint16)
    cv2.imwrite(filename, packed[..., ::-1])  # RGB -> BGR for OpenCV
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def read_gen(file_name, pil=False):
    """Load an image / raw array / flow / PFM file based on its extension.
    Returns [] for unrecognized extensions."""
    ext = splitext(file_name)[-1]
    if ext in ('.png', '.jpeg', '.ppm', '.jpg'):
        return Image.open(file_name)
    if ext in ('.bin', '.raw'):
        return np.load(file_name)
    if ext == '.flo':
        return readFlow(file_name).astype(np.float32)
    if ext == '.pfm':
        flow = readPFM(file_name).astype(np.float32)
        # PFM may carry an extra channel; drop it for flow data.
        return flow if len(flow.shape) == 2 else flow[:, :, :-1]
    return []
|
AIGVDet/core/utils/utils.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy import interpolate
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """

    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        # Amount needed to round each spatial dim up to a multiple of 8.
        extra_h = (8 - self.ht % 8) % 8
        extra_w = (8 - self.wd % 8) % 8
        if mode == 'sintel':
            # Distribute padding evenly on both sides (centered).
            self._pad = [extra_w // 2, extra_w - extra_w // 2,
                         extra_h // 2, extra_h - extra_h // 2]
        else:
            # Pad width on both sides, height only at the bottom.
            self._pad = [extra_w // 2, extra_w - extra_w // 2, 0, extra_h]

    def pad(self, *inputs):
        """Replicate-pad every tensor in *inputs* to the rounded size."""
        return [F.pad(t, self._pad, mode='replicate') for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original (ht, wd)."""
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
| 25 |
+
|
| 26 |
+
def forward_interpolate(flow):
    """Forward-warp a flow field: scatter each vector to its target pixel and
    fill every location by nearest-neighbour interpolation.
    Input and output are tensors of shape (2, H, W)."""
    flow_np = flow.detach().cpu().numpy()
    dx, dy = flow_np[0], flow_np[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # Target coordinates after applying the flow, flattened.
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dxf = dx.reshape(-1)
    dyf = dy.reshape(-1)

    # Keep only vectors whose target lands strictly inside the frame.
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[inside], y1[inside]
    dxf, dyf = dxf[inside], dyf[inside]

    flow_x = interpolate.griddata(
        (x1, y1), dxf, (x0, y0), method='nearest', fill_value=0)
    flow_y = interpolate.griddata(
        (x1, y1), dyf, (x0, y0), method='nearest', fill_value=0)

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xs, ys = coords.split([1, 1], dim=-1)
    # Map pixel coordinates to grid_sample's normalized [-1, 1] range.
    xs = 2 * xs / (W - 1) - 1
    ys = 2 * ys / (H - 1) - 1

    sampled = F.grid_sample(img, torch.cat([xs, ys], dim=-1), align_corners=True)

    if mask:
        # Valid where the normalized coordinate falls strictly inside the image.
        valid = (xs > -1) & (ys > -1) & (xs < 1) & (ys < 1)
        return sampled, valid.float()

    return sampled
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def coords_grid(batch, ht, wd, device):
    """Return a (batch, 2, ht, wd) grid of pixel coordinates:
    channel 0 is x (column index), channel 1 is y (row index)."""
    ys = torch.arange(ht, device=device)
    xs = torch.arange(wd, device=device)
    mesh = torch.meshgrid(ys, xs)  # 'ij' ordering (torch default here)
    grid = torch.stack(mesh[::-1], dim=0).float()  # reversed -> (x, y)
    return grid[None].repeat(batch, 1, 1, 1)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def upflow8(flow, mode='bilinear'):
    """Upsample a flow field 8x spatially, scaling values by 8 to match."""
    target = (8 * flow.shape[2], 8 * flow.shape[3])
    up = F.interpolate(flow, size=target, mode=mode, align_corners=True)
    return 8 * up
|
AIGVDet/core/utils1/config.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from abc import ABC
|
| 5 |
+
from typing import Type
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DefaultConfigs(ABC):
    """Default experiment configuration.

    Values may be overridden by a per-experiment ``config.py`` or by
    command-line ``opts`` key/value pairs (see the module-level CLI code).
    """

    ####### base setting ######
    gpus = [0]  # CUDA device indices to expose
    seed = 3407
    arch = "resnet50"  # backbone architecture name
    datasets = ["zhaolian_train"]  # training dataset folder names
    datasets_test = ["adm_res_abs_ddim20s"]  # validation/test dataset names
    mode = "binary"  # "binary" (real/fake labels) or "filename"
    class_bal = False  # use class-balanced weighted sampling
    batch_size = 64
    loadSize = 256  # resize target before cropping
    cropSize = 224  # final crop size
    epoch = "latest"  # which checkpoint epoch to load
    num_workers = 20
    serial_batches = False  # if True, disable dataloader shuffling
    isTrain = True

    # data augmentation
    rz_interp = ["bilinear"]  # candidate resize interpolations
    # blur_prob = 0.0
    blur_prob = 0.1  # probability of Gaussian blur
    blur_sig = [0.5]  # blur sigma: single value, or [lo, hi] range
    # jpg_prob = 0.0
    jpg_prob = 0.1  # probability of JPEG recompression
    jpg_method = ["cv2"]  # JPEG backends to sample from
    jpg_qual = [75]  # JPEG quality values to sample from
    gray_prob = 0.0
    aug_resize = True
    aug_crop = True
    aug_flip = True
    aug_norm = True  # apply ImageNet mean/std normalization

    ####### train setting ######
    warmup = False  # enable LR warmup scheduler
    # warmup = True
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5  # early-stopping patience (epochs)
    optim = "adam"  # "adam" or "sgd"
    new_optim = False  # if True, don't restore optimizer state on resume
    loss_freq = 400  # steps between loss logs
    save_latest_freq = 2000  # steps between "latest" checkpoints
    save_epoch_freq = 20  # epochs between periodic checkpoints
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 400
    beta1 = 0.9  # Adam beta1
    lr = 0.0001
    init_type = "normal"  # weight-init scheme (see init_weights)
    init_gain = 0.02
    pretrained = True

    # paths information
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    dataset_root = os.path.join(root_dir, "data")
    exp_root = os.path.join(root_dir, "data", "exp")
    _exp_name = ""
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""

    @property
    def exp_name(self):
        """Experiment name; assigning it also derives and creates the
        experiment and checkpoint directories (side effect)."""
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        self._exp_name = value
        self.exp_dir: str = os.path.join(self.exp_root, self.exp_name)
        self.ckpt_dir: str = os.path.join(self.exp_dir, "ckpt")
        self.logs_path: str = os.path.join(self.exp_dir, "logs.txt")

        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def to_dict(self):
        """Return all public, non-callable config fields as a plain dict."""
        dic = {}
        for fieldkey in dir(self):
            fieldvalue = getattr(self, fieldkey)
            if not fieldkey.startswith("__") and not callable(fieldvalue) and not fieldkey.startswith("_"):
                dic[fieldkey] = fieldvalue
        return dic
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def args_list2dict(arg_list: list):
    """Convert a flat [key, value, key, value, ...] list into a dict."""
    assert len(arg_list) % 2 == 0, f"Override list has odd length: {arg_list}; it must be a list of pairs"
    keys = arg_list[0::2]
    values = arg_list[1::2]
    return dict(zip(keys, values))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def str2bool(v: str) -> bool:
    """Parse a human-friendly boolean string ('yes'/'no', 'on'/'off', ...).

    Already-bool inputs pass through. Unrecognized strings fall back to
    Python truthiness, so any non-empty unknown string is True.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("true", "yes", "on", "y", "t", "1"):
        return True
    if lowered in ("false", "no", "off", "n", "f", "0"):
        return False
    return bool(v)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def str2list(v: str, element_type=None) -> list:
    """Parse '[a, b, c]' or 'a,b,c' into a list; sequences pass through.
    If element_type is given, each element is converted with it."""
    if not isinstance(v, (list, tuple, set)):
        stripped = v.lstrip("[").rstrip("]")
        v = [part.strip() for part in stripped.split(",")]
    if element_type is not None:
        v = [element_type(item) for item in v]
    return v
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Type alias used throughout for config-object annotations.
CONFIGCLASS = Type[DefaultConfigs]

parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default=[0], type=int, nargs="+")
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
# Remaining positional args are free-form "key value" override pairs.
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()

# Prefer the experiment's own config.py when it exists; otherwise defaults.
# NOTE: this runs at import time and mutates sys.path.
if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
    sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
    from config import cfg

    cfg: CONFIGCLASS
else:
    cfg = DefaultConfigs()

# Apply command-line overrides, coercing each value to the existing
# field's type (bools and lists need special-case parsing).
if args.opts:
    opts = args_list2dict(args.opts)
    for k, v in opts.items():
        if not hasattr(cfg, k):
            raise ValueError(f"Unrecognized option: {k}")
        original_type = type(getattr(cfg, k))
        if original_type == bool:
            setattr(cfg, k, str2bool(v))
        elif original_type in (list, tuple, set):
            # Element type is inferred from the first existing element.
            setattr(cfg, k, str2list(v, type(getattr(cfg, k)[0])))
        else:
            setattr(cfg, k, original_type(v))

cfg.gpus: list = args.gpus
# NOTE(review): joining with ", " yields e.g. "0, 1"; CUDA tolerates the
# space, but "0,1" is the conventional form -- confirm before changing.
os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join([str(gpu) for gpu in cfg.gpus])
# Assigning exp_name triggers the property setter, which creates dirs.
cfg.exp_name = args.exp_name
cfg.ckpt_path: str = os.path.join(cfg.ckpt_dir, args.ckpt)

if isinstance(cfg.datasets, str):
    cfg.datasets = cfg.datasets.split(",")
|
AIGVDet/core/utils1/datasets.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
from random import choice, random
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
import torchvision.datasets as datasets
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
import torchvision.transforms.functional as TF
|
| 12 |
+
from PIL import Image, ImageFile
|
| 13 |
+
from scipy.ndimage import gaussian_filter
|
| 14 |
+
from torch.utils.data.sampler import WeightedRandomSampler
|
| 15 |
+
|
| 16 |
+
from utils1.config import CONFIGCLASS
|
| 17 |
+
|
| 18 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def dataset_folder(root: str, cfg: CONFIGCLASS):
    """Build the dataset for *root* according to cfg.mode.

    Args:
        root: Directory laid out for torchvision ImageFolder.
        cfg: Experiment configuration.

    Returns:
        A binary-label ImageFolder dataset, or a FileNameDataset that
        yields only file paths.

    Raises:
        ValueError: If cfg.mode is neither "binary" nor "filename".
    """
    if cfg.mode == "binary":
        return binary_dataset(root, cfg)
    if cfg.mode == "filename":
        # BUGFIX: FileNameDataset.__init__(self, opt, root) takes the config
        # first; the original call passed (root, cfg) and would crash inside
        # ImageFolder.__init__ with the config object used as the path.
        return FileNameDataset(cfg, root)
    raise ValueError("cfg.mode needs to be binary or filename.")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def binary_dataset(root: str, cfg: CONFIGCLASS):
    """Build an ImageFolder with the train/eval transform pipeline for
    binary real-vs-fake classification."""
    identity_transform = transforms.Lambda(lambda img: img)

    # NOTE(review): resizing is effectively disabled -- rz_func is always
    # the identity, so cfg.loadSize / custom_resize are not applied here.
    rz_func = identity_transform

    # Crop size is hard-coded to 448x448; cfg.cropSize is not used here.
    if cfg.isTrain:
        crop_func = transforms.RandomCrop((448,448))
    else:
        crop_func = transforms.CenterCrop((448,448)) if cfg.aug_crop else identity_transform

    if cfg.isTrain and cfg.aug_flip:
        flip_func = transforms.RandomHorizontalFlip()
    else:
        flip_func = identity_transform


    return datasets.ImageFolder(
        root,
        transforms.Compose(
            [
                rz_func,
                # Random blur / JPEG recompression on the PIL image
                # (train-time only; see blur_jpg_augment).
                transforms.Lambda(lambda img: blur_jpg_augment(img, cfg)),
                crop_func,
                flip_func,
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                if cfg.aug_norm
                else identity_transform,
            ]
        )
    )
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class FileNameDataset(datasets.ImageFolder):
    """ImageFolder variant whose __getitem__ returns only the file path."""

    def name(self):
        return 'FileNameDataset'

    def __init__(self, opt, root):
        # NOTE(review): parameter order is (opt, root) -- the config comes
        # first, then the ImageFolder root; callers must match this order.
        self.opt = opt
        super().__init__(root)

    def __getitem__(self, index):
        # Loading sample
        # Return only the path; the ImageFolder label is discarded.
        path, target = self.samples[index]
        return path
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def blur_jpg_augment(img: Image.Image, cfg: CONFIGCLASS):
    """Randomly apply Gaussian blur and/or JPEG recompression (train only)."""
    arr: np.ndarray = np.array(img)
    if cfg.isTrain:
        if random() < cfg.blur_prob:
            # gaussian_blur writes into arr in place.
            gaussian_blur(arr, sample_continuous(cfg.blur_sig))

        if random() < cfg.jpg_prob:
            backend = sample_discrete(cfg.jpg_method)
            quality = sample_discrete(cfg.jpg_qual)
            arr = jpeg_from_key(arr, quality, backend)

    return Image.fromarray(arr)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def sample_continuous(s: list):
    """Return s[0] for a single value, or a uniform sample from [s[0], s[1])."""
    if len(s) == 1:
        return s[0]
    if len(s) == 2:
        span = s[1] - s[0]
        return random() * span + s[0]
    raise ValueError("Length of iterable s should be 1 or 2.")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def sample_discrete(s: list):
    """Return the single element of s, or a uniformly chosen one."""
    if len(s) == 1:
        return s[0]
    return choice(s)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def gaussian_blur(img: np.ndarray, sigma: float):
|
| 106 |
+
gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
|
| 107 |
+
gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
|
| 108 |
+
gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def cv2_jpg(img: np.ndarray, compress_val: int) -> np.ndarray:
    """Round-trip an RGB array through OpenCV JPEG compression at the
    given quality (lossy augmentation)."""
    bgr = img[:, :, ::-1]  # OpenCV expects BGR channel order
    params = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    _, encoded = cv2.imencode(".jpg", bgr, params)
    decoded = cv2.imdecode(encoded, 1)
    return decoded[:, :, ::-1]  # back to RGB
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def pil_jpg(img: np.ndarray, compress_val: int):
    """Round-trip an RGB array through Pillow JPEG compression at the
    given quality (lossy augmentation)."""
    buf = BytesIO()
    Image.fromarray(img).save(buf, format="jpeg", quality=compress_val)
    reloaded = Image.open(buf)
    # Materialize pixel data before the in-memory buffer is closed.
    arr = np.array(reloaded)
    buf.close()
    return arr
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Dispatch table: backend name -> JPEG round-trip implementation.
jpeg_dict = {"cv2": cv2_jpg, "pil": pil_jpg}


def jpeg_from_key(img: np.ndarray, compress_val: int, key: str) -> np.ndarray:
    """Apply JPEG compression using the backend selected by *key*
    ('cv2' or 'pil')."""
    return jpeg_dict[key](img, compress_val)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Interpolation name -> PIL resampling filter.
rz_dict = {'bilinear': Image.BILINEAR,
           'bicubic': Image.BICUBIC,
           'lanczos': Image.LANCZOS,
           'nearest': Image.NEAREST}


def custom_resize(img: Image.Image, cfg: CONFIGCLASS) -> Image.Image:
    """Resize to cfg.loadSize using a randomly sampled interpolation method."""
    method = sample_discrete(cfg.rz_interp)
    return TF.resize(img, cfg.loadSize, interpolation=rz_dict[method])
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def get_dataset(cfg: CONFIGCLASS):
    """Concatenate one ImageFolder dataset per entry in cfg.datasets."""
    parts = [dataset_folder(os.path.join(cfg.dataset_root, name), cfg)
             for name in cfg.datasets]
    return torch.utils.data.ConcatDataset(parts)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def get_bal_sampler(dataset: torch.utils.data.ConcatDataset):
    """Weighted sampler that balances classes across all concatenated datasets."""
    targets = []
    for sub in dataset.datasets:
        targets.extend(sub.targets)

    # Weight each sample by the inverse frequency of its class.
    counts = np.bincount(targets)
    class_weights = 1.0 / torch.tensor(counts, dtype=torch.float)
    sample_weights = class_weights[targets]
    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def create_dataloader(cfg: CONFIGCLASS):
    """Build the DataLoader described by cfg (optionally class-balanced)."""
    # Class-balanced sampling is mutually exclusive with shuffling.
    if cfg.isTrain and not cfg.class_bal:
        shuffle = not cfg.serial_batches
    else:
        shuffle = False
    dataset = get_dataset(cfg)
    sampler = get_bal_sampler(dataset) if cfg.class_bal else None

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=int(cfg.num_workers),
    )
|
AIGVDet/core/utils1/earlystop.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from utils1.trainer import Trainer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EarlyStopping:
    """Early-stops training when the validation score stops improving.

    Note: despite the classic "loss" phrasing, this monitors a score where
    HIGHER is better (the checkpoint message reports accuracy).
    """

    def __init__(self, patience=1, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after the last validation-score
                improvement. Default: 1
            verbose (bool): If True, prints a message for each improvement.
                Default: False
            delta (float): Minimum change in the monitored quantity to
                qualify as an improvement. Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # BUGFIX: np.Inf was removed in NumPy 2.0; np.inf is the supported
        # spelling and is behaviorally identical.
        self.score_max = -np.inf
        self.delta = delta

    def __call__(self, score: float, trainer: "Trainer"):
        """Record a new validation score: checkpoint on improvement,
        count stalls otherwise, and trip early_stop past patience."""
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(score, trainer)
        elif score < self.best_score - self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(score, trainer)
            self.counter = 0

    def save_checkpoint(self, score: float, trainer: "Trainer"):
        """Saves the model when the validation score improves."""
        if self.verbose:
            print(f"Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...")
        trainer.save_networks("best")
        self.score_max = score
|
AIGVDet/core/utils1/eval.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from utils1.config import CONFIGCLASS
|
| 10 |
+
from utils1.utils import to_cuda
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_val_cfg(cfg: CONFIGCLASS, split="val", copy=True):
|
| 14 |
+
if copy:
|
| 15 |
+
from copy import deepcopy
|
| 16 |
+
|
| 17 |
+
val_cfg = deepcopy(cfg)
|
| 18 |
+
else:
|
| 19 |
+
val_cfg = cfg
|
| 20 |
+
val_cfg.dataset_root = os.path.join(val_cfg.dataset_root, split)
|
| 21 |
+
val_cfg.datasets = cfg.datasets_test
|
| 22 |
+
val_cfg.isTrain = False
|
| 23 |
+
# val_cfg.aug_resize = False
|
| 24 |
+
# val_cfg.aug_crop = False
|
| 25 |
+
val_cfg.aug_flip = False
|
| 26 |
+
val_cfg.serial_batches = True
|
| 27 |
+
val_cfg.jpg_method = ["pil"]
|
| 28 |
+
# Currently assumes jpg_prob, blur_prob 0 or 1
|
| 29 |
+
if len(val_cfg.blur_sig) == 2:
|
| 30 |
+
b_sig = val_cfg.blur_sig
|
| 31 |
+
val_cfg.blur_sig = [(b_sig[0] + b_sig[1]) / 2]
|
| 32 |
+
if len(val_cfg.jpg_qual) != 1:
|
| 33 |
+
j_qual = val_cfg.jpg_qual
|
| 34 |
+
val_cfg.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)]
|
| 35 |
+
return val_cfg
|
| 36 |
+
|
| 37 |
+
def validate(model: nn.Module, cfg: CONFIGCLASS):
    """Run the model over the dataset described by cfg and report overall
    ACC / AP plus per-class accuracies (R_ACC: real, F_ACC: fake)."""
    from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

    from utils1.datasets import create_dataloader

    loader = create_dataloader(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in loader:
            if len(batch) == 3:
                img, label, meta = batch
            else:
                img, label = batch
                meta = None
            logits = model(to_cuda(img, device), to_cuda(meta, device))
            y_pred.extend(logits.sigmoid().flatten().tolist())
            y_true.extend(label.flatten().tolist())

    y_true, y_pred = np.array(y_true), np.array(y_pred)
    hard = y_pred > 0.5
    return {
        "ACC": accuracy_score(y_true, hard),
        "AP": average_precision_score(y_true, y_pred),
        "R_ACC": accuracy_score(y_true[y_true == 0], hard[y_true == 0]),
        "F_ACC": accuracy_score(y_true[y_true == 1], hard[y_true == 1]),
    }
AIGVDet/core/utils1/trainer.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import init
|
| 6 |
+
|
| 7 |
+
from utils1.config import CONFIGCLASS
|
| 8 |
+
from utils1.utils import get_network
|
| 9 |
+
from utils1.warmup import GradualWarmupScheduler
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseModel(nn.Module):
|
| 13 |
+
def __init__(self, cfg: CONFIGCLASS):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.cfg = cfg
|
| 16 |
+
self.total_steps = 0
|
| 17 |
+
self.isTrain = cfg.isTrain
|
| 18 |
+
self.save_dir = cfg.ckpt_dir
|
| 19 |
+
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 20 |
+
self.model:nn.Module
|
| 21 |
+
self.model=nn.Module.to(self.device)
|
| 22 |
+
# self.model.to(self.device)
|
| 23 |
+
#self.model.load_state_dict(torch.load('./checkpoints/optical.pth'))
|
| 24 |
+
self.optimizer: torch.optim.Optimizer
|
| 25 |
+
|
| 26 |
+
def save_networks(self, epoch: int):
|
| 27 |
+
save_filename = f"model_epoch_{epoch}.pth"
|
| 28 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
| 29 |
+
|
| 30 |
+
# serialize model and optimizer to dict
|
| 31 |
+
state_dict = {
|
| 32 |
+
"model": self.model.state_dict(),
|
| 33 |
+
"optimizer": self.optimizer.state_dict(),
|
| 34 |
+
"total_steps": self.total_steps,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
torch.save(state_dict, save_path)
|
| 38 |
+
|
| 39 |
+
# load models from the disk
|
| 40 |
+
def load_networks(self, epoch: int):
|
| 41 |
+
load_filename = f"model_epoch_{epoch}.pth"
|
| 42 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
| 43 |
+
|
| 44 |
+
if epoch==0:
|
| 45 |
+
# load_filename = f"lsun_adm.pth"
|
| 46 |
+
load_path="checkpoints/optical.pth"
|
| 47 |
+
print("loading optical path")
|
| 48 |
+
else :
|
| 49 |
+
print(f"loading the model from {load_path}")
|
| 50 |
+
|
| 51 |
+
# print(f"loading the model from {load_path}")
|
| 52 |
+
|
| 53 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
| 54 |
+
# GitHub source), you can remove str() on self.device
|
| 55 |
+
state_dict = torch.load(load_path, map_location=self.device)
|
| 56 |
+
if hasattr(state_dict, "_metadata"):
|
| 57 |
+
del state_dict._metadata
|
| 58 |
+
|
| 59 |
+
self.model.load_state_dict(state_dict["model"])
|
| 60 |
+
self.total_steps = state_dict["total_steps"]
|
| 61 |
+
|
| 62 |
+
if self.isTrain and not self.cfg.new_optim:
|
| 63 |
+
self.optimizer.load_state_dict(state_dict["optimizer"])
|
| 64 |
+
# move optimizer state to GPU
|
| 65 |
+
for state in self.optimizer.state.values():
|
| 66 |
+
for k, v in state.items():
|
| 67 |
+
if torch.is_tensor(v):
|
| 68 |
+
state[k] = v.to(self.device)
|
| 69 |
+
|
| 70 |
+
for g in self.optimizer.param_groups:
|
| 71 |
+
g["lr"] = self.cfg.lr
|
| 72 |
+
|
| 73 |
+
def eval(self):
    """Switch the wrapped model to evaluation mode."""
    self.model.eval()
|
| 75 |
+
|
| 76 |
+
def test(self):
    """Run a forward pass with gradient tracking disabled (inference only)."""
    with torch.no_grad():
        self.forward()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def init_weights(net: nn.Module, init_type="normal", gain=0.02):
    """Initialize Conv/Linear/BatchNorm2d weights of ``net`` in place.

    Args:
        net: network whose submodules are (re)initialized via ``net.apply``.
        init_type: one of "normal", "xavier", "kaiming", "orthogonal".
        gain: std (normal) or gain (xavier/orthogonal) for the initializer.

    Raises:
        NotImplementedError: for an unknown ``init_type``.
    """
    def init_func(m: nn.Module):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(f"initialization method [{init_type}] is not implemented")
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm2d") != -1:
            # BatchNorm weights are scale factors: init around 1, biases to 0
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print(f"initialize network with {init_type}")
    net.apply(init_func)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Trainer(BaseModel):
    """Binary (real-vs-fake) classifier trainer built on BaseModel.

    Owns the backbone network, the BCE-with-logits loss, the optimizer and
    (optionally) a warm-up + cosine LR schedule.
    """

    def name(self):
        return "Trainer"

    def __init__(self, cfg: CONFIGCLASS):
        super().__init__(cfg)
        self.arch = cfg.arch
        self.model = get_network(self.arch, cfg.isTrain, cfg.continue_train, cfg.init_gain, cfg.pretrained)

        self.loss_fn = nn.BCEWithLogitsLoss()
        # initialize optimizers
        if cfg.optim == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
        elif cfg.optim == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4)
        else:
            raise ValueError("optim should be [adam, sgd]")
        if cfg.warmup:
            # cosine annealing over the epochs remaining after the warm-up phase
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6
            )
            self.scheduler = GradualWarmupScheduler(
                self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine
            )
            # prime the scheduler so the first epoch already uses the warm-up LR
            self.scheduler.step()
        if cfg.continue_train:
            self.load_networks(cfg.epoch)
        self.model.to(self.device)



    def adjust_learning_rate(self, min_lr=1e-6):
        """Divide every param group's LR by 10; return False once any drops below ``min_lr``."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] /= 10.0
            if param_group["lr"] < min_lr:
                return False
        return True

    def set_input(self, input):
        """Stage a batch: move image, label and any tensor-valued meta entries to the device."""
        img, label, meta = input if len(input) == 3 else (input[0], input[1], {})
        self.input = img.to(self.device)
        self.label = label.to(self.device).float()
        for k in meta.keys():
            if isinstance(meta[k], torch.Tensor):
                meta[k] = meta[k].to(self.device)
        self.meta = meta

    def forward(self):
        # model consumes both the image batch and the meta dict
        self.output = self.model(self.input, self.meta)

    def get_loss(self):
        """BCE-with-logits loss of the current output against the staged labels."""
        return self.loss_fn(self.output.squeeze(1), self.label)

    def optimize_parameters(self):
        """One training step: forward, loss, backward, optimizer update."""
        self.forward()
        self.loss = self.loss_fn(self.output.squeeze(1), self.label)
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
|
AIGVDet/core/utils1/utils.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
import warnings
|
| 6 |
+
from importlib import import_module
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def str2bool(v: str, strict=True) -> bool:
    """Parse a human-friendly boolean value.

    Args:
        v: a bool (returned unchanged) or a string spelling of true/false
           (case-insensitive).
        strict: when True, unrecognized values raise
            ``argparse.ArgumentTypeError``; when False they default to True.

    Returns:
        The parsed boolean.
    """
    if isinstance(v, bool):
        return v
    elif isinstance(v, str):
        # BUG FIX: the truthy tuple previously read ("on" "t") with a missing
        # comma; Python concatenated the literals to "ont", so neither "on"
        # nor "t" was recognized. (The sibling str2bool in config.py has the
        # comma, confirming the intent.)
        if v.lower() in ("true", "yes", "on", "t", "y", "1"):
            return True
        elif v.lower() in ("false", "no", "off", "f", "n", "0"):
            return False
    if strict:
        raise argparse.ArgumentTypeError("Unsupported value encountered.")
    else:
        return True
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None):
    """Recursively move tensors inside ``data`` onto ``device``.

    Tuples/lists/sets are rebuilt as lists; dicts are mutated in place,
    skipping any top-level keys listed in ``exclude_keys``. Other leaf
    values are returned unchanged (best effort, no error raised).
    """
    if isinstance(data, torch.Tensor):
        data = data.to(device)
    elif isinstance(data, (tuple, list, set)):
        data = [to_cuda(b, device) for b in data]
    elif isinstance(data, dict):
        if exclude_keys is None:
            exclude_keys = []
        for k in data.keys():
            if k not in exclude_keys:
                # BUG FIX: propagate exclude_keys into nested dicts; it was
                # previously dropped one level down, so excluded keys inside
                # nested dicts were still moved.
                data[k] = to_cuda(data[k], device, exclude_keys)
    else:
        # raise TypeError(f"Unsupported type: {type(data)}")
        data = data
    return data
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class HiddenPrints:
    """Context manager that discards everything printed inside its body."""

    def __enter__(self):
        # keep a handle on the real stdout so we can restore it on exit
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        # close the devnull handle, then restore the saved stream
        sys.stdout.close()
        sys.stdout = self._original_stdout
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Logger(object):
    """Tee-style logger: mirrors writes to the terminal and an optional file.

    Messages containing a carriage return (e.g. progress bars) are shown on
    the terminal only and never persisted to the file. ``flush`` is a no-op
    so instances can stand in for ``sys.stdout``.
    """

    def __init__(self):
        self.terminal = sys.stdout
        self.file = None

    def open(self, file, mode=None):
        """Attach a log file; default mode is "w"."""
        self.file = open(file, "w" if mode is None else mode)

    def write(self, message, is_terminal=1, is_file=1):
        """Write ``message`` to the enabled sinks, flushing each immediately."""
        if "\r" in message:
            is_file = 0
        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()
        if is_file == 1:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        # this flush method is needed for python 3 compatibility:
        # it handles the flush command by doing nothing.
        pass
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_network(arch: str, isTrain=False, continue_train=False, init_gain=0.02, pretrained=True):
    """Build the ResNet-based detector with a single-logit head.

    Fresh training (isTrain and not continue_train) starts from a pretrained
    backbone and swaps in a newly initialized 1-output fc layer; resuming or
    inference builds a bare ``num_classes=1`` model whose weights are
    expected to be loaded from a checkpoint afterwards.

    Raises:
        ValueError: if ``arch`` is not a resnet variant.
    """
    if "resnet" in arch:
        from ...networks.resnet import ResNet

        # resolve e.g. arch="resnet50" to the constructor of the same name
        # in the (relatively imported) networks.resnet module
        resnet = getattr(import_module("...networks.resnet", package=__package__), arch)
        if isTrain:
            if continue_train:
                model: ResNet = resnet(num_classes=1)
            else:
                model: ResNet = resnet(pretrained=pretrained)
                # replace the ImageNet 1000-way head with a single logit
                # (2048 assumes a Bottleneck-based resnet; resnet50+)
                model.fc = nn.Linear(2048, 1)
                nn.init.normal_(model.fc.weight.data, 0.0, init_gain)
        else:
            model: ResNet = resnet(num_classes=1)
        return model
    else:
        raise ValueError(f"Unsupported arch: {arch}")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def pad_img_to_square(img: np.ndarray):
    """Zero-pad an H x W x C image at the bottom/right so that H == W."""
    height, width = img.shape[:2]
    if height == width:
        return img
    side = max(height, width)
    img = np.pad(img, ((0, side - height), (0, side - width), (0, 0)), mode="constant")
    assert img.shape[0] == img.shape[1] == side
    return img
|
AIGVDet/core/utils1/utils1/config.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from abc import ABC
|
| 5 |
+
from typing import Type
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DefaultConfigs(ABC):
    """Project-wide default configuration.

    Values may be overridden by an experiment-local ``config.py`` or by
    CLI ``key value`` pairs (see the module-level code below). Setting
    ``exp_name`` derives and creates all experiment directories.
    """

    ####### base setting ######
    gpus = [0]
    seed = 3407
    arch = "resnet50"  # backbone name passed to get_network
    datasets = ["zhaolian_train"]  # training dataset folder names under dataset_root
    datasets_test = ["adm_res_abs_ddim20s"]  # evaluation dataset folder names
    mode = "binary"  # dataset mode: "binary" or "filename"
    class_bal = False  # use class-balanced weighted sampling
    batch_size = 64
    loadSize = 256  # resize target used by custom_resize
    cropSize = 224
    epoch = "latest"  # checkpoint tag to resume from
    num_workers = 20
    serial_batches = False  # True disables training-set shuffling
    isTrain = True

    # data augmentation
    rz_interp = ["bilinear"]
    # blur_prob = 0.0
    blur_prob = 0.1  # probability of Gaussian blur per sample
    blur_sig = [0.5]  # sigma (or [lo, hi] range) for the blur
    # jpg_prob = 0.0
    jpg_prob = 0.1  # probability of JPEG recompression per sample
    jpg_method = ["cv2"]
    jpg_qual = [75]
    gray_prob = 0.0
    aug_resize = True
    aug_crop = True
    aug_flip = True
    aug_norm = True  # apply ImageNet mean/std normalization

    ####### train setting ######
    warmup = False
    # warmup = True
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5
    optim = "adam"  # "adam" or "sgd"
    new_optim = False  # True: do not restore optimizer state when resuming
    loss_freq = 400
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 400
    beta1 = 0.9
    lr = 0.0001
    init_type = "normal"
    init_gain = 0.02
    pretrained = True

    # paths information (derived relative to this file's location)
    root_dir1 = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    root_dir = os.path.dirname(root_dir1)
    dataset_root = os.path.join(root_dir, "data")
    exp_root = os.path.join(root_dir, "data", "exp")
    _exp_name = ""
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""

    @property
    def exp_name(self):
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        # Side effect: derives exp_dir/ckpt_dir/logs_path and creates the
        # directories on disk.
        self._exp_name = value
        self.exp_dir: str = os.path.join(self.exp_root, self.exp_name)
        self.ckpt_dir: str = os.path.join(self.exp_dir, "ckpt")
        self.logs_path: str = os.path.join(self.exp_dir, "logs.txt")

        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def to_dict(self):
        """Return public, non-callable attributes as a plain dict (for logging)."""
        dic = {}
        for fieldkey in dir(self):
            fieldvalue = getattr(self, fieldkey)
            if not fieldkey.startswith("__") and not callable(fieldvalue) and not fieldkey.startswith("_"):
                dic[fieldkey] = fieldvalue
        return dic
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def args_list2dict(arg_list: list):
    """Turn a flat [k1, v1, k2, v2, ...] override list into {k1: v1, k2: v2, ...}."""
    assert len(arg_list) % 2 == 0, f"Override list has odd length: {arg_list}; it must be a list of pairs"
    keys = arg_list[::2]
    values = arg_list[1::2]
    return dict(zip(keys, values))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def str2bool(v: str) -> bool:
    """Lenient boolean parser.

    Known truthy/falsy spellings (case-insensitive) map to True/False;
    anything else falls back to Python truthiness of the string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("true", "yes", "on", "y", "t", "1"):
        return True
    if lowered in ("false", "no", "off", "n", "f", "0"):
        return False
    return bool(v)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def str2list(v: str, element_type=None) -> list:
    """Parse "a,b,c" (optionally wrapped in brackets) into a list.

    Already-sequence inputs are left as-is; ``element_type``, when given,
    is applied to every element.
    """
    if not isinstance(v, (list, tuple, set)):
        bare = v.lstrip("[").rstrip("]")
        v = [part.strip() for part in bare.split(",")]
    if element_type is not None:
        v = [element_type(item) for item in v]
    return v
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# Alias used throughout the project for "DefaultConfigs or a subclass of it".
CONFIGCLASS = Type[DefaultConfigs]

# NOTE: this module parses CLI arguments and mutates the global `cfg` at
# import time — importing it has side effects.
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default=[0], type=int, nargs="+")
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
# remaining positional args are free-form "key value" config overrides
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()

# Prefer an experiment-local config (exp_root/<exp_name>/config.py) if present.
if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
    sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
    from config import cfg

    cfg: CONFIGCLASS
else:
    cfg = DefaultConfigs()

# Apply CLI overrides, coercing each value to the existing attribute's type.
if args.opts:
    opts = args_list2dict(args.opts)
    for k, v in opts.items():
        if not hasattr(cfg, k):
            raise ValueError(f"Unrecognized option: {k}")
        original_type = type(getattr(cfg, k))
        if original_type == bool:
            setattr(cfg, k, str2bool(v))
        elif original_type in (list, tuple, set):
            setattr(cfg, k, str2list(v, type(getattr(cfg, k)[0])))
        else:
            setattr(cfg, k, original_type(v))

cfg.gpus: list = args.gpus
# NOTE(review): joined with ", " (comma + space); CUDA_VISIBLE_DEVICES is
# conventionally comma-separated without spaces — confirm this is accepted.
os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join([str(gpu) for gpu in cfg.gpus])
cfg.exp_name = args.exp_name
cfg.ckpt_path: str = os.path.join(cfg.ckpt_dir, args.ckpt)

if isinstance(cfg.datasets, str):
    cfg.datasets = cfg.datasets.split(",")
|
AIGVDet/core/utils1/utils1/datasets.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
from random import choice, random
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
import torchvision.datasets as datasets
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
import torchvision.transforms.functional as TF
|
| 12 |
+
from PIL import Image, ImageFile
|
| 13 |
+
from scipy.ndimage import gaussian_filter
|
| 14 |
+
from torch.utils.data.sampler import WeightedRandomSampler
|
| 15 |
+
|
| 16 |
+
from .config import CONFIGCLASS
|
| 17 |
+
|
| 18 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def dataset_folder(root: str, cfg: CONFIGCLASS):
    """Build the dataset for ``root`` according to ``cfg.mode``.

    "binary" yields an ImageFolder with the augmentation pipeline;
    "filename" yields a FileNameDataset returning paths only.

    Raises:
        ValueError: for any other mode.
    """
    if cfg.mode == "binary":
        return binary_dataset(root, cfg)
    if cfg.mode == "filename":
        # BUG FIX: FileNameDataset.__init__ signature is (opt, root); the
        # arguments were previously passed reversed as (root, cfg), which
        # made ImageFolder receive the config object as its root path.
        return FileNameDataset(cfg, root)
    raise ValueError("cfg.mode needs to be binary or filename.")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def binary_dataset(root: str, cfg: CONFIGCLASS):
    """Real-vs-fake ImageFolder dataset with the augmentation pipeline.

    Pipeline: identity resize stage -> blur/JPEG augmentation ->
    448x448 crop -> optional horizontal flip -> ToTensor -> optional
    ImageNet normalization.
    """
    identity_transform = transforms.Lambda(lambda img: img)

    # resize is currently disabled; kept as an identity placeholder
    rz_func = identity_transform

    if cfg.isTrain:
        # NOTE(review): the training branch ignores cfg.aug_crop and always
        # random-crops to 448x448 — confirm this is intended.
        crop_func = transforms.RandomCrop((448,448))
    else:
        crop_func = transforms.CenterCrop((448,448)) if cfg.aug_crop else identity_transform

    if cfg.isTrain and cfg.aug_flip:
        flip_func = transforms.RandomHorizontalFlip()
    else:
        flip_func = identity_transform


    return datasets.ImageFolder(
        root,
        transforms.Compose(
            [
                rz_func,
                #change
                transforms.Lambda(lambda img: blur_jpg_augment(img, cfg)),
                crop_func,
                flip_func,
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                if cfg.aug_norm
                else identity_transform,
            ]
        )
    )
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class FileNameDataset(datasets.ImageFolder):
    """ImageFolder variant whose ``__getitem__`` yields only the file path.

    NOTE(review): the constructor takes (opt, root) — the config first —
    unlike the (root, cfg) order used by the other dataset builders; keep
    call sites consistent with this signature.
    """

    def name(self):
        return 'FileNameDataset'

    def __init__(self, opt, root):
        self.opt = opt
        super().__init__(root)

    def __getitem__(self, index):
        # Loading sample: return the path only (image and label are discarded)
        path, target = self.samples[index]
        return path
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def blur_jpg_augment(img: Image.Image, cfg: CONFIGCLASS):
    """Training-time augmentation: random Gaussian blur and/or JPEG recompression.

    When ``cfg.isTrain`` is False this is a no-op apart from the
    PIL -> ndarray -> PIL round trip.
    """
    img: np.ndarray = np.array(img)
    if cfg.isTrain:
        if random() < cfg.blur_prob:
            sig = sample_continuous(cfg.blur_sig)
            # gaussian_blur mutates the array in place; no reassignment needed
            gaussian_blur(img, sig)

        if random() < cfg.jpg_prob:
            method = sample_discrete(cfg.jpg_method)
            qual = sample_discrete(cfg.jpg_qual)
            img = jpeg_from_key(img, qual, method)

    return Image.fromarray(img)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def sample_continuous(s: list):
    """Return s[0] for a singleton, or a uniform sample from [s[0], s[1])."""
    if len(s) == 1:
        return s[0]
    if len(s) == 2:
        low, high = s
        return low + random() * (high - low)
    raise ValueError("Length of iterable s should be 1 or 2.")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def sample_discrete(s: list):
    """Return the sole element of a singleton list, else a uniform random choice."""
    if len(s) == 1:
        return s[0]
    return choice(s)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def gaussian_blur(img: np.ndarray, sigma: float):
    """In-place per-channel Gaussian blur of an H x W x 3 image."""
    for channel in range(3):
        # output aliases the input slice, so the blur is written back in place
        gaussian_filter(img[:, :, channel], output=img[:, :, channel], sigma=sigma)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def cv2_jpg(img: np.ndarray, compress_val: int) -> np.ndarray:
    """JPEG-recompress an RGB image via OpenCV at quality ``compress_val``."""
    img_cv2 = img[:, :, ::-1]  # RGB -> BGR, the channel order cv2 expects
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    result, encimg = cv2.imencode(".jpg", img_cv2, encode_param)
    decimg = cv2.imdecode(encimg, 1)
    return decimg[:, :, ::-1]  # BGR -> RGB
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def pil_jpg(img: np.ndarray, compress_val: int):
    """JPEG-recompress an RGB image via Pillow through an in-memory buffer."""
    out = BytesIO()
    img = Image.fromarray(img)
    img.save(out, format="jpeg", quality=compress_val)
    img = Image.open(out)
    # load from memory before ByteIO closes
    img = np.array(img)
    out.close()
    return img
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Registry mapping backend name -> JPEG recompression function.
jpeg_dict = {"cv2": cv2_jpg, "pil": pil_jpg}


def jpeg_from_key(img: np.ndarray, compress_val: int, key: str) -> np.ndarray:
    """Recompress ``img`` as JPEG using the backend registered under ``key``."""
    return jpeg_dict[key](img, compress_val)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Map of interpolation names (cfg.rz_interp entries) to PIL resampling filters.
rz_dict = {'bilinear': Image.BILINEAR,
           'bicubic': Image.BICUBIC,
           'lanczos': Image.LANCZOS,
           'nearest': Image.NEAREST}
def custom_resize(img: Image.Image, cfg: CONFIGCLASS) -> Image.Image:
    """Resize to cfg.loadSize using an interpolation sampled from cfg.rz_interp."""
    interp = sample_discrete(cfg.rz_interp)
    return TF.resize(img, cfg.loadSize, interpolation=rz_dict[interp])
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def get_dataset(cfg: CONFIGCLASS):
    """Concatenate one dataset per entry of cfg.datasets, rooted at cfg.dataset_root."""
    subsets = [
        dataset_folder(os.path.join(cfg.dataset_root, name), cfg)
        for name in cfg.datasets
    ]
    return torch.utils.data.ConcatDataset(subsets)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def get_bal_sampler(dataset: torch.utils.data.ConcatDataset):
    """Weighted sampler that rebalances classes across the concatenated datasets.

    Each sample's weight is the inverse frequency of its class, so rare
    classes are drawn as often as common ones.
    """
    targets = []
    for subset in dataset.datasets:
        targets.extend(subset.targets)

    class_counts = np.bincount(targets)
    class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
    sample_weights = class_weights[targets]
    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def create_dataloader(cfg: CONFIGCLASS):
    """Build the DataLoader for cfg.

    Shuffling is enabled only for (non-class-balanced) training with
    serial_batches off; class-balanced runs use a weighted sampler instead.
    """
    if cfg.isTrain and not cfg.class_bal:
        shuffle = not cfg.serial_batches
    else:
        shuffle = False
    dataset = get_dataset(cfg)
    sampler = get_bal_sampler(dataset) if cfg.class_bal else None

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=int(cfg.num_workers),
    )
|
AIGVDet/core/utils1/utils1/earlystop.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from .trainer import Trainer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EarlyStopping:
    """Early stops the training if the validation score doesn't improve after a given patience.

    Each call compares ``score`` against the best seen so far; a run of
    ``patience`` calls without an improvement of more than ``delta`` sets
    the ``early_stop`` flag. Improvements trigger a checkpoint save.
    """

    def __init__(self, patience=1, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # BUG FIX: np.Inf was removed in NumPy 2.0; use the builtin float
        # literal instead (same value, no numpy dependency here).
        self.score_max = float("-inf")
        self.delta = delta

    def __call__(self, score: float, trainer: "Trainer"):
        if self.best_score is None:
            # first observation always counts as an improvement
            self.best_score = score
            self.save_checkpoint(score, trainer)
        elif score < self.best_score - self.delta:
            # no (sufficient) improvement this round
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(score, trainer)
            self.counter = 0

    def save_checkpoint(self, score: float, trainer: "Trainer"):
        """Saves model when validation loss decrease."""
        if self.verbose:
            print(f"Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...")
        trainer.save_networks("best")
        self.score_max = score
|
AIGVDet/core/utils1/utils1/eval.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from .config import CONFIGCLASS
|
| 10 |
+
from .utils import to_cuda
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_val_cfg(cfg: CONFIGCLASS, split="val", copy=True):
    """Derive a validation config from a training config.

    Points the dataset root at ``split``, switches to the test dataset list,
    disables train-only augmentation/shuffling, forces the PIL jpeg backend,
    and collapses blur/jpeg augmentation ranges to their midpoints.
    When ``copy`` is True the input config is left untouched.
    """
    if copy:
        from copy import deepcopy

        val_cfg = deepcopy(cfg)
    else:
        val_cfg = cfg
    val_cfg.dataset_root = os.path.join(val_cfg.dataset_root, split)
    val_cfg.datasets = cfg.datasets_test
    val_cfg.isTrain = False
    # val_cfg.aug_resize = False
    # val_cfg.aug_crop = False
    val_cfg.aug_flip = False
    val_cfg.serial_batches = True
    val_cfg.jpg_method = ["pil"]
    # Currently assumes jpg_prob, blur_prob 0 or 1
    if len(val_cfg.blur_sig) == 2:
        lo, hi = val_cfg.blur_sig
        val_cfg.blur_sig = [(lo + hi) / 2]
    if len(val_cfg.jpg_qual) != 1:
        quals = val_cfg.jpg_qual
        val_cfg.jpg_qual = [int((quals[0] + quals[-1]) / 2)]
    return val_cfg
|
| 36 |
+
|
| 37 |
+
def validate(model: nn.Module, cfg: CONFIGCLASS):
    """Evaluate ``model`` on the dataset described by ``cfg``.

    Runs the model (sigmoid outputs thresholded at 0.5) over the whole
    dataloader and returns a dict with overall accuracy (ACC), average
    precision (AP), and per-class accuracies on real (R_ACC, label 0)
    and fake (F_ACC, label 1) samples.
    """
    from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

    from .datasets import create_dataloader

    data_loader = create_dataloader(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        y_true, y_pred = [], []
        for data in data_loader:
            # batches are (img, label) or (img, label, meta)
            img, label, meta = data if len(data) == 3 else (*data, None)
            in_tens = to_cuda(img, device)
            meta = to_cuda(meta, device)
            predict = model(in_tens, meta).sigmoid()
            y_pred.extend(predict.flatten().tolist())
            y_true.extend(label.flatten().tolist())

    y_true, y_pred = np.array(y_true), np.array(y_pred)
    r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
    f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
    acc = accuracy_score(y_true, y_pred > 0.5)
    ap = average_precision_score(y_true, y_pred)
    results = {
        "ACC": acc,
        "AP": ap,
        "R_ACC": r_acc,
        "F_ACC": f_acc,
    }
    return results
|
AIGVDet/core/utils1/utils1/trainer.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import init
|
| 6 |
+
|
| 7 |
+
from .config import CONFIGCLASS
|
| 8 |
+
from .utils import get_network
|
| 9 |
+
from .warmup import GradualWarmupScheduler
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseModel(nn.Module):
    """Shared checkpointing/eval scaffolding for trainers.

    Subclasses (e.g. Trainer) are expected to assign ``self.model`` and
    ``self.optimizer`` before any save/load call.
    """

    def __init__(self, cfg: "CONFIGCLASS"):
        super().__init__()
        self.cfg = cfg
        self.total_steps = 0
        self.isTrain = cfg.isTrain
        self.save_dir = cfg.ckpt_dir
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # BUG FIX: this previously executed
        #     self.model = nn.Module.to(self.device)
        #     self.model.load_state_dict(torch.load('./checkpoints/optical.pth'))
        # The first call invokes the unbound Module.to with a device as
        # `self` (AttributeError), and the second loads a full checkpoint
        # dict into a bare nn.Module — both crash at construction. The model
        # is created, moved, and loaded by the subclass instead (as in the
        # sibling trainer module, where these lines are commented out).
        self.model: nn.Module
        self.optimizer: torch.optim.Optimizer

    def save_networks(self, epoch: int):
        """Serialize model, optimizer and step counter to <save_dir>/model_epoch_<epoch>.pth."""
        save_filename = f"model_epoch_{epoch}.pth"
        save_path = os.path.join(self.save_dir, save_filename)

        # serialize model and optimizer to dict
        state_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "total_steps": self.total_steps,
        }

        torch.save(state_dict, save_path)

    # load models from the disk
    def load_networks(self, epoch: int):
        """Restore model/optimizer/total_steps from a checkpoint.

        ``epoch == 0`` is special-cased to load the bundled
        ``checkpoints/optical.pth``. When training and ``cfg.new_optim`` is
        False, the optimizer state is restored, moved to the device, and the
        LR is reset to ``cfg.lr``.
        """
        load_filename = f"model_epoch_{epoch}.pth"
        load_path = os.path.join(self.save_dir, load_filename)

        if epoch == 0:
            # NOTE(review): relative path — resolution depends on the CWD
            load_path = "checkpoints/optical.pth"
            print("loading optical path")
        else:
            print(f"loading the model from {load_path}")

        state_dict = torch.load(load_path, map_location=self.device)
        if hasattr(state_dict, "_metadata"):
            del state_dict._metadata

        self.model.load_state_dict(state_dict["model"])
        self.total_steps = state_dict["total_steps"]

        if self.isTrain and not self.cfg.new_optim:
            self.optimizer.load_state_dict(state_dict["optimizer"])
            # move optimizer state to GPU
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)

            # override the checkpointed LR with the currently configured one
            for g in self.optimizer.param_groups:
                g["lr"] = self.cfg.lr

    def eval(self):
        """Switch the wrapped model to evaluation mode."""
        self.model.eval()

    def test(self):
        """Run a forward pass with gradient tracking disabled (inference only)."""
        with torch.no_grad():
            self.forward()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def init_weights(net: nn.Module, init_type="normal", gain=0.02):
    """Apply the chosen weight-initialization scheme, in place, to every
    Conv/Linear/BatchNorm2d submodule of ``net``.

    Args:
        net: network traversed via ``net.apply``.
        init_type: "normal", "xavier", "kaiming", or "orthogonal".
        gain: std (normal) or gain (xavier/orthogonal) for the initializer.

    Raises:
        NotImplementedError: for an unknown ``init_type``.
    """
    def init_func(m: nn.Module):
        cls_name = m.__class__.__name__
        has_learned_weight = hasattr(m, "weight") and ("Conv" in cls_name or "Linear" in cls_name)
        if has_learned_weight:
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(f"initialization method [{init_type}] is not implemented")
            if getattr(m, "bias", None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in cls_name:
            # BatchNorm weights are scale factors: init around 1, biases to 0
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print(f"initialize network with {init_type}")
    net.apply(init_func)
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Trainer(BaseModel):
    """Training/inference wrapper around the binary real-vs-fake classifier.

    Builds the backbone via get_network, sets up a BCE-with-logits loss, an
    Adam/SGD optimizer and (optionally) a warmup + cosine LR schedule.

    NOTE(review): __init__ unconditionally overwrites the model weights with
    checkpoints/optical.pth at the end, even after continue_train loading —
    presumably intentional for this inference Space; confirm before training.
    """

    def name(self):
        return "Trainer"

    def __init__(self, cfg: CONFIGCLASS):
        super().__init__(cfg)
        self.arch = cfg.arch
        self.model = get_network(self.arch, cfg.isTrain, cfg.continue_train, cfg.init_gain, cfg.pretrained)

        self.loss_fn = nn.BCEWithLogitsLoss()
        # initialize optimizers
        if cfg.optim == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
        elif cfg.optim == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4)
        else:
            raise ValueError("optim should be [adam, sgd]")
        if cfg.warmup:
            # Linear warmup for warmup_epoch epochs, cosine annealing afterwards.
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6
            )
            self.scheduler = GradualWarmupScheduler(
                self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine
            )
            self.scheduler.step()
        if cfg.continue_train:
            self.load_networks(cfg.epoch)
        self.model.to(self.device)

        # self.model.load_state_dict(torch.load('checkpoints/optical.pth'))
        # Hard-coded checkpoint load: replaces whatever was set up above.
        load_path='checkpoints/optical.pth'
        state_dict = torch.load(load_path, map_location=self.device)

        self.model.load_state_dict(state_dict["model"])

    def adjust_learning_rate(self, min_lr=1e-6):
        """Divide every param-group LR by 10; return False once any LR drops below min_lr."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] /= 10.0
            if param_group["lr"] < min_lr:
                return False
        return True

    def set_input(self, input):
        """Store a (img, label[, meta]) batch on self.device for the next forward()."""
        img, label, meta = input if len(input) == 3 else (input[0], input[1], {})
        self.input = img.to(self.device)
        self.label = label.to(self.device).float()
        for k in meta.keys():
            if isinstance(meta[k], torch.Tensor):
                meta[k] = meta[k].to(self.device)
        self.meta = meta

    def forward(self):
        # The backbone accepts (and may ignore) the meta dict as extra args.
        self.output = self.model(self.input, self.meta)

    def get_loss(self):
        """BCE-with-logits loss of the current output against the stored labels."""
        return self.loss_fn(self.output.squeeze(1), self.label)

    def optimize_parameters(self):
        """One optimization step: forward, loss, backward, optimizer update."""
        self.forward()
        self.loss = self.loss_fn(self.output.squeeze(1), self.label)
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
|
AIGVDet/core/utils1/utils1/utils.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
import warnings
|
| 6 |
+
from importlib import import_module
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def str2bool(v: str, strict=True) -> bool:
    """Convert a bool or a truthy/falsy string to bool.

    Accepted true values: "true", "yes", "on", "t", "y", "1" (case-insensitive);
    false values: "false", "no", "off", "f", "n", "0".

    Args:
        v: value to interpret (bool is returned unchanged).
        strict: if True, raise argparse.ArgumentTypeError on unrecognized
            input; if False, fall back to True.
    """
    if isinstance(v, bool):
        return v
    elif isinstance(v, str):
        # BUG FIX: the original tuple contained `"on" "t"` (implicit string
        # concatenation -> "ont"), so both "on" and "t" were rejected.
        if v.lower() in ("true", "yes", "on", "t", "y", "1"):
            return True
        elif v.lower() in ("false", "no", "off", "f", "n", "0"):
            return False
    if strict:
        raise argparse.ArgumentTypeError("Unsupported value encountered.")
    else:
        return True
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None):
    """Recursively move tensors inside *data* to *device*.

    Tensors are moved; tuples/lists/sets come back as lists of moved items;
    dicts are updated in place (keys listed in ``exclude_keys`` are skipped at
    the top level only). Anything else is returned untouched.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, (tuple, list, set)):
        return [to_cuda(item, device) for item in data]
    if isinstance(data, dict):
        skipped = [] if exclude_keys is None else exclude_keys
        for key in data.keys():
            if key not in skipped:
                data[key] = to_cuda(data[key], device)
        return data
    # raise TypeError(f"Unsupported type: {type(data)}")
    # Unknown types pass through unchanged.
    return data
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class HiddenPrints:
    """Context manager that temporarily redirects stdout to os.devnull."""

    def __enter__(self):
        self._saved_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the devnull handle and restore the real stream.
        sys.stdout.close()
        sys.stdout = self._saved_stdout
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Logger(object):
    """Tee-style logger that writes messages to the terminal and, once
    open() has been called, to a log file as well.

    Messages containing "\r" (progress-bar updates) are never written to the
    file. flush() exists only for file-object compatibility.
    """

    def __init__(self):
        self.terminal = sys.stdout  # original stdout, used for screen output
        self.file = None            # log file handle; set by open()

    def open(self, file, mode=None):
        """Attach a log file (default mode "w")."""
        if mode is None:
            mode = "w"
        self.file = open(file, mode)

    def write(self, message, is_terminal=1, is_file=1):
        """Write *message* to the enabled sinks, flushing immediately."""
        if "\r" in message:
            is_file = 0
        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()
        # Robustness fix: previously this raised AttributeError when no file
        # had been opened; now the file sink is simply skipped.
        if is_file == 1 and self.file is not None:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # this handles the flush command by doing nothing.
        # you might want to specify some extra behavior here.
        pass
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_network(arch: str, isTrain=False, continue_train=False, init_gain=0.02, pretrained=True):
    """Build the single-logit real/fake classifier for *arch*.

    Args:
        arch: backbone name; must contain "resnet" and match a factory in
            networks.resnet (e.g. "resnet50").
        isTrain: build for training (fresh head) rather than inference.
        continue_train: when training, build with num_classes=1 so a previous
            checkpoint can be loaded instead of ImageNet weights.
        init_gain: std of the normal init applied to the new fc head.
        pretrained: load ImageNet weights when starting fresh training.

    Raises:
        ValueError: for any arch not containing "resnet".
    """
    if "resnet" in arch:
        from networks.resnet import ResNet

        # Resolve the factory function (resnet18/34/50/...) by name.
        resnet = getattr(import_module("networks.resnet"), arch)
        if isTrain:
            if continue_train:
                model: ResNet = resnet(num_classes=1)
            else:
                # Start from ImageNet weights and replace the classifier head.
                # NOTE(review): 2048 assumes a Bottleneck backbone (resnet50+);
                # resnet18/34 have 512 features — confirm arch before training.
                model: ResNet = resnet(pretrained=pretrained)
                model.fc = nn.Linear(2048, 1)
                nn.init.normal_(model.fc.weight.data, 0.0, init_gain)
        else:
            model: ResNet = resnet(num_classes=1)
        return model
    else:
        raise ValueError(f"Unsupported arch: {arch}")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def pad_img_to_square(img: np.ndarray):
    """Zero-pad an HxWxC image on the bottom/right so that H == W.

    Already-square images are returned unchanged (same object).
    """
    height, width = img.shape[:2]
    if height == width:
        return img
    side = max(height, width)
    padded = np.pad(
        img,
        ((0, side - height), (0, side - width), (0, 0)),
        mode="constant",
    )
    assert padded.shape[0] == padded.shape[1] == side
    return padded
|
AIGVDet/core/utils1/utils1/warmup.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.0:
            raise ValueError("multiplier should be greater thant or equal to 1.")
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False  # set True once after_scheduler has taken over
        super().__init__(optimizer)

    def get_lr(self):
        # Past the warmup window: delegate to after_scheduler (if any),
        # rescaling its base LRs exactly once at hand-off.
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        # Inside the warmup window: linear ramp (0 -> base_lr when
        # multiplier == 1.0, otherwise base_lr -> base_lr * multiplier).
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        # Warmup-aware step for a ReduceLROnPlateau after_scheduler, which
        # needs the validation metric rather than an epoch counter.
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = (
            epoch if epoch != 0 else 1
        )  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group["lr"] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        # Route to the plateau-specific path when wrapping ReduceLROnPlateau;
        # otherwise behave as a normal chained scheduler.
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super().step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
|
AIGVDet/core/utils1/warmup.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.0:
            raise ValueError("multiplier should be greater thant or equal to 1.")
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False  # set True once after_scheduler has taken over
        super().__init__(optimizer)

    def get_lr(self):
        # Past the warmup window: delegate to after_scheduler (if any),
        # rescaling its base LRs exactly once at hand-off.
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        # Inside the warmup window: linear ramp (0 -> base_lr when
        # multiplier == 1.0, otherwise base_lr -> base_lr * multiplier).
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        # Warmup-aware step for a ReduceLROnPlateau after_scheduler, which
        # needs the validation metric rather than an epoch counter.
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = (
            epoch if epoch != 0 else 1
        )  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group["lr"] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        # Route to the plateau-specific path when wrapping ReduceLROnPlateau;
        # otherwise behave as a normal chained scheduler.
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super().step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
|
AIGVDet/docker-compose.yml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NOTE: the top-level `version` key is obsolete and ignored by Docker Compose V2
version: "3.9"
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
web:
|
| 5 |
+
build: .
|
| 6 |
+
ports:
|
| 7 |
+
- "8003:8003"
|
| 8 |
+
restart: unless-stopped
|
| 9 |
+
deploy:
|
| 10 |
+
resources:
|
| 11 |
+
reservations:
|
| 12 |
+
devices:
|
| 13 |
+
- driver: nvidia
|
| 14 |
+
count: all
|
| 15 |
+
capabilities: [gpu]
|
| 16 |
+
ipc: host
|
| 17 |
+
shm_size: "24gb"
|
AIGVDet/main.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import shutil
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
+
|
| 6 |
+
from .run import RUN
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def run_video_to_json(
    video_path: str,
    output_json_path: Optional[str] = None,
    model_optical_path: str = "checkpoints/optical.pth",
    model_original_path: str = "checkpoints/original.pth",
    frame_root: str = "frame",
    optical_root: str = "optical_result"
) -> Dict:
    """Run AIGVDet on a single video and optionally write the scores to JSON.

    Args:
        video_path: path to the input video file.
        output_json_path: if given, the result dict is also dumped here.
        model_optical_path / model_original_path: checkpoint paths, resolved
            relative to this file when not absolute.
        frame_root / optical_root: scratch directories for extracted frames
            and optical-flow images; cleaned up afterwards.

    Returns:
        dict keyed by video id with either confidence scores or an "error"
        entry if processing failed.

    Raises:
        FileNotFoundError: if video_path does not exist.
    """

    script_dir = os.path.dirname(os.path.abspath(__file__))
    if not os.path.isabs(model_optical_path):
        model_optical_path = os.path.join(script_dir, model_optical_path)
    if not os.path.isabs(model_original_path):
        model_original_path = os.path.join(script_dir, model_original_path)

    results = {}

    if not os.path.isfile(video_path):
        raise FileNotFoundError(f"File not found: {video_path}")

    video_name = os.path.basename(video_path)
    video_id = os.path.splitext(video_name)[0]

    folder_original = os.path.join(frame_root, video_id)
    folder_optical = os.path.join(optical_root, video_id)

    args = [
        '--path', video_path,
        '--folder_original_path', folder_original,
        '--folder_optical_flow_path', folder_optical,
        '--model_optical_flow_path', model_optical_path,
        '--model_original_path', model_original_path
    ]

    try:
        output = RUN(args)
        results[video_id] = {
            "video_name": video_name,
            "authentic_confidence_score": round(output["real_score"], 4),
            "synthetic_confidence_score": round(output["fake_score"], 4)
        }

    except Exception as e:
        # Best-effort: report the failure in the result instead of raising.
        results[video_id] = {
            "video_name": video_name,
            "error": str(e)
        }

    finally:
        # Clean up intermediate folders
        for folder in [folder_original, folder_optical]:
            try:
                if os.path.exists(folder):
                    shutil.rmtree(folder)
            except OSError:
                pass

    # Save result to JSON or return dict
    if output_json_path:
        # BUG FIX: os.path.dirname() returns "" for a bare filename and
        # os.makedirs("") raises FileNotFoundError — only create the
        # directory when there is one.
        out_dir = os.path.dirname(output_json_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_json_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

    return results
|
AIGVDet/networks/resnet.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.utils.model_zoo as model_zoo
|
| 3 |
+
|
| 4 |
+
__all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
model_urls = {
|
| 8 |
+
"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
|
| 9 |
+
"resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
|
| 10 |
+
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
|
| 11 |
+
"resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
|
| 12 |
+
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution with padding 1 (spatial size preserved at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 convolution (channel projection)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BasicBlock(nn.Module):
    """Two-layer residual block (3x3 -> 3x3) used by resnet18/34."""

    expansion = 1  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional projection for the identity path when shape/stride changes.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual connection followed by the final activation.
        out += identity
        out = self.relu(out)

        return out
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Bottleneck(nn.Module):
    """Three-layer bottleneck residual block (1x1 -> 3x3 -> 1x1) used by resnet50+."""

    expansion = 4  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection for the identity path when shape/stride changes.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual connection followed by the final activation.
        out += identity
        out = self.relu(out)

        return out
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class ResNet(nn.Module):
    """Standard torchvision-style ResNet backbone with a linear classifier head.

    Args:
        block: residual block class (BasicBlock or Bottleneck).
        layers: number of blocks in each of the four stages.
        num_classes: size of the final fc output.
        zero_init_residual: zero the last BN weight of each block so every
            residual branch starts as an identity.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super().__init__()
        self.inplanes = 64
        # Stem: 7x7 stride-2 conv + stride-2 max-pool (4x spatial reduction).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack *blocks* residual blocks; the first may downsample via *stride*."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut to match channels/stride of the main path.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x, *args):
        # *args absorbs extra inputs (e.g. the meta dict passed by the
        # trainer) — they are intentionally ignored by this backbone.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18.

    Args:
        pretrained (bool): If True, load ImageNet weights via model_zoo.
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        net.load_state_dict(model_zoo.load_url(model_urls["resnet18"]))
    return net
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def resnet34(pretrained=False, **kwargs):
    """Build a ResNet-34.

    Args:
        pretrained (bool): If True, load ImageNet weights via model_zoo.
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        net.load_state_dict(model_zoo.load_url(model_urls["resnet34"]))
    return net
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50.

    Args:
        pretrained (bool): If True, load ImageNet weights via model_zoo.
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        net.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
    return net
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def resnet101(pretrained=False, **kwargs):
    """Build a ResNet-101.

    Args:
        pretrained (bool): If True, load ImageNet weights via model_zoo.
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        net.load_state_dict(model_zoo.load_url(model_urls["resnet101"]))
    return net
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def resnet152(pretrained=False, **kwargs):
    """Build a ResNet-152.

    Args:
        pretrained (bool): If True, load ImageNet weights via model_zoo.
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        net.load_state_dict(model_zoo.load_url(model_urls["resnet152"]))
    return net
|
AIGVDet/raft_model/raft-things.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fcfa4125d6418f4de95d84aec20a3c5f4e205101715a79f193243c186ac9a7e1
|
| 3 |
+
size 21108000
|
AIGVDet/requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# conda create -n aigvdet python=3.9
|
| 2 |
+
# pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
|
| 3 |
+
einops
|
| 4 |
+
imageio
|
| 5 |
+
ipympl
|
| 6 |
+
matplotlib
|
| 7 |
+
numpy<2
|
| 8 |
+
opencv-python
|
| 9 |
+
pandas
|
| 10 |
+
scikit-learn
|
| 11 |
+
tensorboard
|
| 12 |
+
tensorboardX
|
| 13 |
+
tqdm
|
| 14 |
+
blobfile>=1.0.5
|
| 15 |
+
natsort
|
| 16 |
+
fastapi==0.116.1
|
| 17 |
+
pydantic==2.11.7
|
| 18 |
+
uvicorn[standard]
|
| 19 |
+
torch==2.0.0+cu117
|
| 20 |
+
torchvision==0.15.1+cu117
|
| 21 |
+
-f https://download.pytorch.org/whl/torch_stable.html
|
| 22 |
+
python-multipart
|
AIGVDet/run.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import glob
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn
|
| 12 |
+
import torchvision.transforms as transforms
|
| 13 |
+
import torchvision.transforms.functional as TF
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
from .core.raft import RAFT
|
| 18 |
+
from .core.utils import flow_viz
|
| 19 |
+
from .core.utils.utils import InputPadder
|
| 20 |
+
from natsort import natsorted
|
| 21 |
+
from .core.utils1.utils import get_network, str2bool, to_cuda
|
| 22 |
+
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score,roc_auc_score
|
| 23 |
+
|
| 24 |
+
# Inference device for RAFT and both classifiers; flip to 'cpu' when no GPU
# is available (see the commented alternative below).
DEVICE = 'cuda'
# DEVICE = 'cpu' # Changed to 'cpu'
device = torch.device(DEVICE)
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_image(imfile):
    """Read an image file and return it as a float batch tensor on DEVICE.

    The result has shape (1, C, H, W) with pixel values still in [0, 255].
    """
    pixels = np.asarray(Image.open(imfile), dtype=np.uint8)
    tensor = torch.from_numpy(pixels).permute(2, 0, 1).float()
    return tensor.unsqueeze(0).to(DEVICE)
|
| 34 |
+
|
| 35 |
+
def viz(img, flo, folder_optical_flow_path, imfile1):
    """Render a predicted flow field as a colour image and save it.

    img: (1, C, H, W) frame tensor; flo: (1, 2, H, W) flow tensor.
    The output file reuses the basename of the source frame ``imfile1``.
    """
    img = img[0].permute(1, 2, 0).cpu().numpy()
    flo = flo[0].permute(1, 2, 0).cpu().numpy()

    # map flow to RGB image
    flo = flow_viz.flow_to_image(flo)
    # NOTE(review): img_flo is computed but never saved or returned.
    img_flo = np.concatenate([img, flo], axis=0)

    # extract filename safely (cross-platform)
    filename = os.path.basename(imfile1).strip()
    output_path = os.path.join(folder_optical_flow_path, filename)

    print(output_path)
    # NOTE(review): flow_to_image presumably returns RGB while cv2.imwrite
    # expects BGR, so saved colours are channel-swapped; the downstream
    # classifier checkpoints were likely trained on images saved exactly this
    # way — confirm before "fixing".
    cv2.imwrite(output_path, flo)
| 50 |
+
|
| 51 |
+
def video_to_frames(video_path, output_folder):
    """Decode a video into numbered PNG frames and return the sorted paths.

    Frames are written to ``output_folder`` (created if missing) as
    ``frame_00000.png``, ``frame_00001.png``, ...; afterwards every .png and
    .jpg file found in that folder is returned in lexicographic order.
    """
    os.makedirs(output_folder, exist_ok=True)

    capture = cv2.VideoCapture(video_path)
    index = 0
    while capture.isOpened():
        ok, frame = capture.read()
        if not ok:
            break
        target = os.path.join(output_folder, f"frame_{index:05d}.png")
        cv2.imwrite(target, frame)
        index += 1
    capture.release()

    found = []
    for pattern in ('*.png', '*.jpg'):
        found.extend(glob.glob(os.path.join(output_folder, pattern)))
    return sorted(found)
+
|
| 76 |
+
# generate optical flow images
def OF_gen(args):
    """Run RAFT over consecutive frame pairs of ``args.path`` and write the
    rendered optical-flow images into ``args.folder_optical_flow_path``.
    """
    # RAFT checkpoints were saved from a DataParallel wrapper, so the model
    # must be wrapped before load_state_dict and unwrapped (.module) after.
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model, map_location=torch.device(DEVICE)))

    model = model.module
    model.to(DEVICE)
    model.eval()

    if not os.path.exists(args.folder_optical_flow_path):
        os.makedirs(args.folder_optical_flow_path)
        print(f'{args.folder_optical_flow_path}')

    with torch.no_grad():

        # extract frames from the video, then sort them naturally so
        # frame_2 < frame_10 (plain sort would misorder mixed-width names)
        images = video_to_frames(args.path, args.folder_original_path)
        images = natsorted(images)

        # flow is computed for every consecutive frame pair
        for imfile1, imfile2 in zip(images[:-1], images[1:]):
            image1 = load_image(imfile1)
            image2 = load_image(imfile2)

            # pad both frames to dimensions RAFT accepts — presumably
            # multiples of 8; confirm in core.utils.utils.InputPadder
            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)

            viz(image1, flow_up, args.folder_optical_flow_path, imfile1)
+
|
| 106 |
+
def RUN(args=None):
    """End-to-end AIGVDet inference on a single video.

    Extracts frames, renders RAFT optical flow, scores both the RGB frames
    and the flow images with their respective classifiers, and averages the
    two probabilities into a final fake-ness score.

    Parameters
    ----------
    args : list[str] | None
        Optional argv-style list forwarded to argparse; ``None`` parses
        ``sys.argv`` as usual.

    Returns
    -------
    dict
        Keys: ``original_prob``, ``optical_prob``, ``final_predict``,
        ``real_score``, ``fake_score``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_model_path = os.path.join(script_dir, "raft_model/raft-things.pth")

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint", default=default_model_path)
    parser.add_argument('--path', help="dataset for evaluation", default="video/000000.mp4")
    parser.add_argument('--folder_original_path', help="dataset for evaluation_frames", default="frame/000000")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
    parser.add_argument('--folder_optical_flow_path', help="the results to save", default="optical_result/000000")
    parser.add_argument("-mop", "--model_optical_flow_path", type=str, default="checkpoints/optical.pth")
    parser.add_argument("-mor", "--model_original_path", type=str, default="checkpoints/original.pth")
    parser.add_argument("-t", "--threshold", type=float, default=0.5)
    parser.add_argument("--use_cpu", action="store_true", help="uses gpu by default, turn on to use cpu")
    parser.add_argument("--arch", type=str, default="resnet50")
    parser.add_argument("--aug_norm", type=str2bool, default=True)
    args = parser.parse_args(args)

    start_time = time.perf_counter()
    OF_gen(args)
    elapsed = time.perf_counter() - start_time
    print(f"⏱️ [OF_gen] Service call took {elapsed:.2f} seconds")

    def _load_classifier(checkpoint_path):
        # Checkpoints may nest the weights under a "model" key.
        net = get_network(args.arch)
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        net.load_state_dict(state_dict["model"] if "model" in state_dict else state_dict)
        return net.eval().to(device)

    # Load models
    model_op = _load_classifier(args.model_optical_flow_path)
    model_or = _load_classifier(args.model_original_path)

    # Transform: centre-crop to the classifier's expected 448x448 input.
    trans = transforms.Compose([
        transforms.CenterCrop((448, 448)),
        transforms.ToTensor(),
    ])

    def _score_folder(folder, model, desc):
        # Average per-frame fake probability over every image in `folder`.
        # Returns 0.0 for an empty folder instead of raising ZeroDivisionError
        # (the original crashed when no frames/flow images were produced).
        files = sorted(
            glob.glob(os.path.join(folder, "*.jpg")) +
            glob.glob(os.path.join(folder, "*.png")) +
            glob.glob(os.path.join(folder, "*.JPEG"))
        )
        if not files:
            print(f"Warning: no images found in {folder}")
            return 0.0
        prob_sum = 0.0
        for img_path in tqdm(files, desc=desc, dynamic_ncols=True, disable=len(files) <= 1):
            img = Image.open(img_path).convert("RGB")
            img = trans(img)
            if args.aug_norm:
                img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            in_tens = img.unsqueeze(0).to(device)
            with torch.no_grad():
                prob_sum += model(in_tens).sigmoid().item()
        return prob_sum / len(files)

    # Process original frames
    original_prob = _score_folder(args.folder_original_path, model_or, "Original")
    print(f"Original prob: {original_prob:.4f}")

    # Process optical flow frames
    optical_prob = _score_folder(args.folder_optical_flow_path, model_op, "Optical Flow")
    print(f"Optical prob: {optical_prob:.4f}")

    # Late fusion: equal-weight average of the two branch probabilities.
    final_prob = (original_prob + optical_prob) / 2
    print(f"predict: {final_prob}")

    real_score = 1 - final_prob
    fake_score = final_prob
    print(f"Confidence scores - Real: {real_score:.4f}, Fake: {fake_score:.4f}")

    return {
        "original_prob": original_prob,
        "optical_prob": optical_prob,
        "final_predict": final_prob,
        "real_score": real_score,
        "fake_score": fake_score
    }
AIGVDet/run.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Run AIGVDet end-to-end on a demo video: frame extraction, optical flow, and
# both-branch detection with the given checkpoints.
python demo.py --path "demo_video/real/094.mp4" --folder_original_path "frame/000000" --folder_optical_flow_path "optical_result/000000" -mop "checkpoints/optical.pth" -mor "checkpoints/original.pth"
|
AIGVDet/test.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn
|
| 8 |
+
import torchvision.transforms as transforms
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from core.utils1.utils import get_network, str2bool, to_cuda
|
| 14 |
+
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score,roc_auc_score
|
| 15 |
+
|
| 16 |
+
if __name__ == "__main__":
    # Evaluate the two-branch AIGVDet detector over a dataset laid out as
    # <root>/{0_real,1_fake}/<video_name>/<frames>: per-frame fake
    # probabilities are averaged per video, then the RGB and optical-flow
    # branch scores are fused with equal weights.

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-fop", "--folder_optical_flow_path", default="data/test/T2V/videocraft", type=str, help="path to optical flow imagefile folder"
    )
    parser.add_argument(
        "-for", "--folder_original_path", default="data/test/original/T2V/videocraft", type=str, help="path to RGB image file folder"
    )
    parser.add_argument("-mop", "--model_optical_flow_path", type=str, default="checkpoints/optical.pth")
    parser.add_argument("-mor", "--model_original_path", type=str, default="checkpoints/original.pth")
    parser.add_argument("-t", "--threshold", type=float, default=0.5)
    parser.add_argument("-e", "--excel_path", type=str, help="path to excel of frames", default="data/results/moonvalley_wang.csv")
    parser.add_argument("-ef", "--excel_frame_path", type=str, help="path to excel of frame detection result", default="data/results/frame/moonvalley_wang.csv")
    parser.add_argument("--use_cpu", action="store_true", help="uses gpu by default, turn on to use cpu")
    parser.add_argument("--arch", type=str, default="resnet50")
    parser.add_argument("--aug_norm", type=str2bool, default=True)

    args = parser.parse_args()
    subfolder_count = 0

    # Load the optical-flow branch classifier (checkpoints may nest the
    # weights under a "model" key).
    model_op = get_network(args.arch)
    state_dict = torch.load(args.model_optical_flow_path, map_location="cpu")
    if "model" in state_dict:
        state_dict = state_dict["model"]
    model_op.load_state_dict(state_dict)
    model_op.eval()
    if not args.use_cpu:
        model_op.cuda()

    # Load the RGB branch classifier.
    model_or = get_network(args.arch)
    state_dict = torch.load(args.model_original_path, map_location="cpu")
    if "model" in state_dict:
        state_dict = state_dict["model"]
    model_or.load_state_dict(state_dict)
    model_or.eval()
    if not args.use_cpu:
        model_or.cuda()

    trans = transforms.Compose(
        (
            transforms.CenterCrop((448, 448)),
            transforms.ToTensor(),
        )
    )

    print("*" * 50)

    flag = 0
    p = 0   # fake (positive) video count
    n = 0   # real (negative) video count
    tp = 0
    tn = 0
    y_true = []
    y_pred = []

    # Rows are accumulated in plain lists and converted to DataFrames at the
    # end: DataFrame.append was removed in pandas 2.0, so the original
    # df.append(...) calls crash on current pandas.
    video_rows = []
    frame_rows = []
    index1 = 0

    # Traverse through subfolders in a large folder.
    for subfolder_name in ["0_real", "1_fake"]:
        optical_subfolder_path = os.path.join(args.folder_optical_flow_path, subfolder_name)
        original_subfolder_path = os.path.join(args.folder_original_path, subfolder_name)

        flag = 0 if subfolder_name == "0_real" else 1

        if not os.path.isdir(optical_subfolder_path):
            print("Subfolder does not exist.", optical_subfolder_path)

        # Check if the subfolder path exists.
        if os.path.isdir(original_subfolder_path):
            print("test subfolder:", subfolder_name)

            # Traverse through sub-subfolders (one per video).
            for subsubfolder_name in os.listdir(original_subfolder_path):
                original_subsubfolder_path = os.path.join(original_subfolder_path, subsubfolder_name)
                optical_subsubfolder_path = os.path.join(optical_subfolder_path, subsubfolder_name)
                if not os.path.isdir(optical_subsubfolder_path):
                    print("Sub-subfolder does not exist.", optical_subsubfolder_path)

                if os.path.isdir(original_subsubfolder_path):
                    print("test subsubfolder:", subsubfolder_name)

                    # --- RGB branch ---
                    original_file_list = sorted(
                        glob.glob(os.path.join(original_subsubfolder_path, "*.jpg"))
                        + glob.glob(os.path.join(original_subsubfolder_path, "*.png"))
                        + glob.glob(os.path.join(original_subsubfolder_path, "*.JPEG"))
                    )
                    if not original_file_list:
                        # Guard: the original divided by len(...) and crashed here.
                        print("No frames found in", original_subsubfolder_path)
                        continue

                    original_prob_sum = 0
                    for img_path in tqdm(original_file_list, dynamic_ncols=True, disable=len(original_file_list) <= 1):
                        img = Image.open(img_path).convert("RGB")
                        img = trans(img)
                        if args.aug_norm:
                            img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                        in_tens = img.unsqueeze(0)
                        if not args.use_cpu:
                            in_tens = in_tens.cuda()

                        with torch.no_grad():
                            prob = model_or(in_tens).sigmoid().item()
                        original_prob_sum += prob

                        frame_rows.append({
                            "original_path": img_path,
                            "original_pro": prob,
                            "optical_path": None,
                            "optical_pro": None,
                            "flag": flag,
                        })

                    original_predict = original_prob_sum / len(original_file_list)
                    print("original prob", original_predict)

                    # --- optical-flow branch ---
                    optical_file_list = sorted(
                        glob.glob(os.path.join(optical_subsubfolder_path, "*.jpg"))
                        + glob.glob(os.path.join(optical_subsubfolder_path, "*.png"))
                        + glob.glob(os.path.join(optical_subsubfolder_path, "*.JPEG"))
                    )
                    if not optical_file_list:
                        print("No optical-flow frames found in", optical_subsubfolder_path)
                        continue

                    optical_prob_sum = 0
                    for img_path in tqdm(optical_file_list, dynamic_ncols=True, disable=len(optical_file_list) <= 1):
                        img = Image.open(img_path).convert("RGB")
                        img = trans(img)
                        if args.aug_norm:
                            img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                        in_tens = img.unsqueeze(0)
                        if not args.use_cpu:
                            in_tens = in_tens.cuda()

                        with torch.no_grad():
                            prob = model_op(in_tens).sigmoid().item()
                        optical_prob_sum += prob

                        # Pair the flow probability with the RGB row recorded
                        # above. The original advanced this index twice per
                        # frame (duplicated `index1=index1+1`), leaving every
                        # other row unpaired; it is now advanced exactly once.
                        if index1 < len(frame_rows):
                            frame_rows[index1]["optical_path"] = img_path
                            frame_rows[index1]["optical_pro"] = prob
                        index1 = index1 + 1

                    optical_predict = optical_prob_sum / len(optical_file_list)
                    print("optical prob", optical_predict)

                    # Late fusion: equal-weight average of the two branches.
                    predict = original_predict * 0.5 + optical_predict * 0.5
                    print(f"flag:{flag} predict:{predict}")
                    y_true.append(flag)
                    y_pred.append(predict)
                    if flag == 0:
                        n += 1
                        if predict < args.threshold:
                            tn += 1
                    else:
                        p += 1
                        if predict >= args.threshold:
                            tp += 1
                    video_rows.append({
                        "name": subsubfolder_name,
                        "pro": predict,
                        "flag": flag,
                        "optical_pro": optical_predict,
                        "original_pro": original_predict,
                    })
        else:
            print("Subfolder does not exist:", original_subfolder_path)

    df = pd.DataFrame(video_rows, columns=['name', 'pro', 'flag', 'optical_pro', 'original_pro'])
    df1 = pd.DataFrame(frame_rows, columns=['original_path', 'original_pro', 'optical_path', 'optical_pro', 'flag'])

    ap = average_precision_score(y_true, y_pred)
    auc = roc_auc_score(y_true, y_pred)
    # Guard against empty classes: a real-only or fake-only dataset used to
    # crash with ZeroDivisionError on tn/n or tp/p.
    print(f"tnr:{tn / n if n else float('nan')}")
    print(f"tpr:{tp / p if p else float('nan')}")
    print(f"acc:{(tp + tn) / (p + n) if (p + n) else float('nan')}")
    print(f"ap:{ap}")
    print(f"auc:{auc}")
    print(f"p:{p}")
    print(f"n:{n}")
    print(f"tp:{tp}")
    print(f"tn:{tn}")

    # Append per-video results to CSV (header only when the file is new).
    csv_filename = args.excel_path
    csv_folder = os.path.dirname(csv_filename)
    if not os.path.exists(csv_folder):
        os.makedirs(csv_folder)

    if not os.path.exists(csv_filename):
        df.to_csv(csv_filename, index=False)
    else:
        df.to_csv(csv_filename, mode='a', header=False, index=False)
    print(f"Results have been saved to {csv_filename}")

    # Append per-frame results to CSV.
    csv_filename1 = args.excel_frame_path
    csv_folder1 = os.path.dirname(csv_filename1)
    if not os.path.exists(csv_folder1):
        os.makedirs(csv_folder1)

    if not os.path.exists(csv_filename1):
        df1.to_csv(csv_filename1, index=False)
    else:
        df1.to_csv(csv_filename1, mode='a', header=False, index=False)
    print(f"Results have been saved to {csv_filename1}")
AIGVDet/test.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Evaluate the detector on one T2V test set with augmented checkpoints and
# write per-video and per-frame CSV results.
python test.py -fop "data/test/T2V/hotshot" -mop "checkpoints/optical_aug.pth" -for "data/test/original/T2V/hotshot" -mor "checkpoints/original_aug.pth" -e "data/results/T2V/hotshot.csv" -ef "data/results/frame/T2V/hotshot.csv" -t 0.5
|
AIGVDet/train.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from core.utils1.config import cfg # isort: split
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from tensorboardX import SummaryWriter
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from core.utils1.datasets import create_dataloader
|
| 10 |
+
from core.utils1.earlystop import EarlyStopping
|
| 11 |
+
from core.utils1.eval import get_val_cfg, validate
|
| 12 |
+
from core.utils1.trainer import Trainer
|
| 13 |
+
from core.utils1.utils import Logger
|
| 14 |
+
|
| 15 |
+
import ssl
|
| 16 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if __name__ == "__main__":
    # Training loop for the AIGVDet classifier: train with step-based and
    # epoch-based checkpointing, validate once per epoch, and early-stop with
    # a staged learning-rate drop.
    val_cfg = get_val_cfg(cfg, split="val", copy=True)
    cfg.dataset_root = os.path.join(cfg.dataset_root, "train")
    data_loader = create_dataloader(cfg)
    dataset_size = len(data_loader)

    log = Logger()
    log.open(cfg.logs_path, mode="a")
    log.write("Num of training images = %d\n" % (dataset_size * cfg.batch_size))
    log.write("Config:\n" + str(cfg.to_dict()) + "\n")

    train_writer = SummaryWriter(os.path.join(cfg.exp_dir, "train"))
    val_writer = SummaryWriter(os.path.join(cfg.exp_dir, "val"))

    trainer = Trainer(cfg)
    # NOTE(review): delta is negative — presumably EarlyStopping treats larger
    # validation accuracy as better; confirm against core.utils1.earlystop.
    early_stopping = EarlyStopping(patience=cfg.earlystop_epoch, delta=-0.001, verbose=True)
    for epoch in range(cfg.nepoch):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        for data in tqdm(data_loader, dynamic_ncols=True):
            trainer.total_steps += 1
            epoch_iter += cfg.batch_size

            trainer.set_input(data)
            trainer.optimize_parameters()

            # if trainer.total_steps % cfg.loss_freq == 0:
            #     log.write(f"Train loss: {trainer.loss} at step: {trainer.total_steps}\n")
            train_writer.add_scalar("loss", trainer.loss, trainer.total_steps)

            # step-based checkpointing ("latest" is overwritten in place)
            if trainer.total_steps % cfg.save_latest_freq == 0:
                log.write(
                    "saving the latest model %s (epoch %d, model.total_steps %d)\n"
                    % (cfg.exp_name, epoch, trainer.total_steps)
                )
                trainer.save_networks("latest")

        # epoch-based checkpointing (both "latest" and an epoch-numbered copy)
        if epoch % cfg.save_epoch_freq == 0:
            log.write("saving the model at the end of epoch %d, iters %d\n" % (epoch, trainer.total_steps))
            trainer.save_networks("latest")
            trainer.save_networks(epoch)

        # Validation
        trainer.eval()
        val_results = validate(trainer.model, val_cfg)
        val_writer.add_scalar("AP", val_results["AP"], trainer.total_steps)
        val_writer.add_scalar("ACC", val_results["ACC"], trainer.total_steps)
        # add
        val_writer.add_scalar("AUC", val_results["AUC"], trainer.total_steps)
        val_writer.add_scalar("TPR", val_results["TPR"], trainer.total_steps)
        val_writer.add_scalar("TNR", val_results["TNR"], trainer.total_steps)

        log.write(f"(Val @ epoch {epoch}) AP: {val_results['AP']}; ACC: {val_results['ACC']}\n")

        if cfg.earlystop:
            early_stopping(val_results["ACC"], trainer)
            if early_stopping.early_stop:
                # first trigger: drop LR by 10x and restart patience;
                # second trigger: stop training for good
                if trainer.adjust_learning_rate():
                    log.write("Learning rate dropped by 10, continue training...\n")
                    early_stopping = EarlyStopping(patience=cfg.earlystop_epoch, delta=-0.002, verbose=True)
                else:
                    log.write("Early stopping.\n")
                    break
        if cfg.warmup:
            # print(trainer.scheduler.get_lr()[0])
            trainer.scheduler.step()
        trainer.train()
AIGVDet/train.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch training on GPU 0; trailing key/value pairs are config overrides
# consumed by core.utils1.config.
EXP_NAME="moonvalley_vos2_crop"
DATASETS="moonvalley_vos2_crop"
DATASETS_TEST="moonvalley_vos2_crop"
python train.py --gpus 0 --exp_name $EXP_NAME datasets $DATASETS datasets_test $DATASETS_TEST
|
api_server.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import json
|
| 5 |
+
import uuid
|
| 6 |
+
import time
|
| 7 |
+
import threading
|
| 8 |
+
from typing import List, Dict, Any, Optional
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
+
|
| 11 |
+
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
|
| 12 |
+
from fastapi.responses import JSONResponse
|
| 13 |
+
from pydantic import BaseModel
|
| 14 |
+
|
| 15 |
+
# Import the project's analysis entry points (placeholder wiring)
|
| 16 |
+
from miragenews import run_multimodal_to_json
|
| 17 |
+
from AIGVDet import run_video_to_json
|
| 18 |
+
|
| 19 |
+
# Directory where uploaded media is staged per job before analysis.
UPLOAD_DIR = "temp_uploads"
# NOTE(review): declared but no thread pool is built from it in this file.
MAX_WORKERS = 4

app = FastAPI(
    title="Multimedia Analysis API (Polling Mode)",
    description="API phân tích đa phương tiện sử dụng cơ chế Polling để tránh Timeout.",
    version="2.0.0",
)

# In-memory job store keyed by job id; guarded by jobs_lock because the
# server may touch it from multiple threads.
jobs: Dict[str, Dict] = {}
jobs_lock = threading.Lock()
| 30 |
+
|
| 31 |
+
class AnalysisResult(BaseModel):
    """Combined payload of the image and video analysis branches."""
    # Per-image results; None when no images were submitted.
    image_analysis_results: Optional[List[Any]] = None
    # Formatted verdict for the submitted video; None when no video was sent.
    video_analysis_result: Optional[Dict[str, Any]] = None
+
|
| 35 |
+
class JobStatus(BaseModel):
    """Response schema for the polling endpoints."""
    job_id: str
    # One of "queued", "running", "succeeded", "failed" (set by the worker).
    status: str
    message: Optional[str] = None
    result: Optional[AnalysisResult] = None
    # Unix timestamps in seconds.
    created_at: float
    updated_at: float
+
|
| 43 |
+
def _create_job() -> str:
    """Register a fresh job record in the in-memory store and return its id."""
    new_id = uuid.uuid4().hex
    timestamp = time.time()
    record = {
        "job_id": new_id,
        "status": "queued",
        "message": "Đang chờ xử lý...",
        "result": None,
        "created_at": timestamp,
        "updated_at": timestamp
    }
    with jobs_lock:
        jobs[new_id] = record
    return new_id
|
| 56 |
+
|
| 57 |
+
def _update_job(job_id: str, **kwargs):
    """Merge ``kwargs`` into an existing job record and refresh its timestamp.

    Silently ignores unknown job ids.
    """
    with jobs_lock:
        record = jobs.get(job_id)
        if record is not None:
            record.update(kwargs)
            record["updated_at"] = time.time()
|
| 62 |
+
|
| 63 |
+
def _get_job(job_id: str) -> Optional[Dict]:
    """Thread-safe lookup of a job record; None when the id is unknown."""
    with jobs_lock:
        return jobs.get(job_id)
|
| 66 |
+
|
| 67 |
+
async def run_analysis_logic(
    image_paths: Optional[List[str]] = None,
    video_path: Optional[str] = None,
    text: str = "",
) -> Dict[str, Any]:
    """Run image and/or video analysis concurrently and merge the results.

    Raises ValueError when neither image paths nor a video path is supplied.
    Returns a dict with optional keys ``image_analysis_results`` and
    ``video_analysis_result`` (empty branches are dropped from the output).
    """
    if not image_paths and not video_path:
        raise ValueError("Cần cung cấp ít nhất một trong hai: image_paths hoặc video_path.")

    tasks = []

    if image_paths:
        # assumes run_multimodal_to_json is a coroutine function — TODO confirm
        image_task = asyncio.create_task(
            run_multimodal_to_json(image_paths=image_paths, text=text, output_json_path=None)
        )
        tasks.append(image_task)

    if video_path:
        # the video detector is synchronous, so it runs in a worker thread
        video_task = asyncio.to_thread(
            run_video_to_json, video_path=video_path, output_json_path=None
        )
        tasks.append(video_task)

    task_results = await asyncio.gather(*tasks)

    final_result = {"image_analysis_results": [], "video_analysis_result": {}}
    image_analysis_results = []
    video_result_index = -1

    # gather() preserves submission order: images first (if any), then video
    current_idx = 0
    if image_paths:
        image_analysis_results = task_results[current_idx]
        current_idx += 1

    if video_path:
        video_result_index = current_idx

    final_result["image_analysis_results"] = image_analysis_results

    if video_result_index != -1:
        raw_video_result = task_results[video_result_index]
        if raw_video_result:
            # assumes the detector returns {video_id: {...}} with a single
            # entry whose first key is the video id — TODO confirm
            video_id_key = list(raw_video_result.keys())[0]
            video_data = raw_video_result[video_id_key]

            avg_authentic = video_data.get("authentic_confidence_score", 0)
            avg_synthetic = video_data.get("synthetic_confidence_score", 0)

            # map the confidence pair onto a three-way verdict with canned
            # explanatory text
            if avg_authentic > avg_synthetic and avg_authentic > 0.5:
                authenticity_assessment = "REAL (Authentic)"
                verification_tools = "Deepfake Detector"
                synthetic_type = "N/A"
                other_artifacts = "Our algorithms conducted a thorough analysis of the video's motion patterns, lighting consistency, and object interactions. We observed fluid, natural movements and consistent physics that align with real-world recordings. No discernible artifacts, such as pixel distortion, unnatural blurring, or shadow inconsistencies, were detected that would indicate digital manipulation or AI-driven synthesis."
            elif avg_authentic > avg_synthetic and avg_authentic <= 0.5:
                authenticity_assessment = "Potentially Synthetic"
                verification_tools = "Deepfake Detector"
                synthetic_type = "Potentially AI-generated"
                other_artifacts = "Our analysis has identified subtle anomalies within the video frames, particularly in areas of complex texture and inconsistent lighting across different objects. While these discrepancies are not significant enough to definitively classify the video as synthetic, they do suggest a possibility of digital alteration or partial AI generation. Further examination may be required for a conclusive determination."
            else:
                authenticity_assessment = "NOT REAL (Fake, Manipulated, or AI)"
                verification_tools = "Deepfake Detector"
                synthetic_type = "AI-generated"
                other_artifacts = "Our deep analysis detected multiple, significant artifacts commonly associated with synthetic or manipulated media. These include, but are not limited to, unnatural facial expressions and eye movements, inconsistent or floating shadows, logical impossibilities in object interaction, and high-frequency digital noise characteristic of generative models. These factors strongly indicate that the video is not authentic."

            final_result["video_analysis_result"] = {
                "filename": video_data.get("video_name", ""),
                "result": {
                    "authenticity_assessment": authenticity_assessment,
                    "verification_tools_methods": verification_tools,
                    "synthetic_type": synthetic_type,
                    "other_artifacts": other_artifacts,
                },
            }

    # drop empty branches so the client only sees what was actually analysed
    if not final_result.get("image_analysis_results"):
        final_result.pop("image_analysis_results", None)
    if not final_result.get("video_analysis_result"):
        final_result.pop("video_analysis_result", None)

    return final_result
|
| 147 |
+
|
| 148 |
+
async def process_job_background(
    job_id: str,
    temp_dir: str,
    image_paths: Optional[List[str]],
    video_path: Optional[str],
    text: str,
):
    """Background worker that runs the analysis pipeline for one job.

    Marks the job "running", awaits ``run_analysis_logic`` with whatever
    media was uploaded, then records "succeeded" (with the result) or
    "failed" (with the error message). The job's temporary upload
    directory is always removed afterwards, best-effort.

    Args:
        job_id: Identifier of the job record to update.
        temp_dir: Per-job upload directory; deleted when the job finishes.
        image_paths: Saved image file paths, or None/empty when no images.
        video_path: Saved video file path, or None when no video.
        text: Optional accompanying text for the analysis.
    """
    _update_job(job_id, status="running", message="Đang phân tích...")

    try:
        # Normalize empty containers to None so the analysis logic can
        # distinguish "no media of this kind" from an empty upload.
        result_data = await run_analysis_logic(
            image_paths=image_paths if image_paths else None,
            video_path=video_path if video_path else None,
            text=text,
        )

        _update_job(job_id, status="succeeded", result=result_data, message="Hoàn tất")

    except Exception as e:
        print(f"Error processing job {job_id}: {e}")
        _update_job(job_id, status="failed", message=str(e))

    finally:
        # Best-effort cleanup: never let a cleanup failure mask the job
        # outcome recorded above.
        try:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
                print(f"Deleted temp dir: {temp_dir}")
        except Exception as cleanup_error:
            print(f"Cleanup error for {job_id}: {cleanup_error}")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
@app.on_event("startup")
def startup_event():
    """Make sure the upload root exists before the server accepts requests."""
    upload_root = UPLOAD_DIR
    os.makedirs(upload_root, exist_ok=True)
|
| 183 |
+
|
| 184 |
+
@app.get("/analyze/{job_id}", response_model=JobStatus)
async def get_job_status(job_id: str):
    """Polling endpoint: clients call this periodically to check the result."""
    if job := _get_job(job_id):
        return job
    raise HTTPException(status_code=404, detail="Job not found")
|
| 191 |
+
|
| 192 |
+
@app.post("/analyze/image/", response_model=JobStatus)
async def analyze_image_endpoint(
    background_tasks: BackgroundTasks,
    images: List[UploadFile] = File(...),
    text: Optional[str] = Form(""),
):
    """Accept image uploads, create a job, and schedule background analysis.

    Saves each uploaded image into a per-job directory, then queues
    ``process_job_background``. Returns the freshly created job record
    immediately; clients poll GET /analyze/{job_id} for the result.
    """
    job_id = _create_job()

    job_dir = os.path.join(UPLOAD_DIR, job_id)
    os.makedirs(job_dir, exist_ok=True)

    saved_image_paths = []
    try:
        for img in images:
            # basename() guards against path traversal via a crafted
            # client-supplied filename (e.g. "../../etc/passwd").
            safe_name = os.path.basename(img.filename or "upload")
            file_path = os.path.join(job_dir, safe_name)
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(img.file, buffer)
            saved_image_paths.append(file_path)
    except Exception as e:
        # Don't leave partially written uploads behind on failure.
        shutil.rmtree(job_dir, ignore_errors=True)
        _update_job(job_id, status="failed", message=f"Lỗi upload: {e}")
        return _get_job(job_id)

    background_tasks.add_task(
        process_job_background,
        job_id, job_dir, saved_image_paths, None, text
    )

    return _get_job(job_id)
|
| 220 |
+
|
| 221 |
+
@app.post("/analyze/video/", response_model=JobStatus)
async def analyze_video_endpoint(
    background_tasks: BackgroundTasks,
    video: UploadFile = File(...),
):
    """Accept a single video upload, create a job, and schedule analysis.

    Saves the video into a per-job directory, then queues
    ``process_job_background`` (no images, empty text). Returns the job
    record immediately; clients poll GET /analyze/{job_id}.
    """
    job_id = _create_job()
    job_dir = os.path.join(UPLOAD_DIR, job_id)
    os.makedirs(job_dir, exist_ok=True)

    # basename() guards against path traversal via a crafted
    # client-supplied filename (e.g. "../../etc/passwd").
    safe_name = os.path.basename(video.filename or "upload")
    saved_video_path = os.path.join(job_dir, safe_name)
    try:
        with open(saved_video_path, "wb") as buffer:
            shutil.copyfileobj(video.file, buffer)
    except Exception as e:
        # Don't leave a partially written upload behind on failure.
        shutil.rmtree(job_dir, ignore_errors=True)
        _update_job(job_id, status="failed", message=f"Lỗi upload: {e}")
        return _get_job(job_id)

    background_tasks.add_task(
        process_job_background,
        job_id, job_dir, [], saved_video_path, ""
    )

    return _get_job(job_id)
|
| 244 |
+
|
| 245 |
+
@app.post("/analyze/multimodal/", response_model=JobStatus)
async def analyze_multimodal_endpoint(
    background_tasks: BackgroundTasks,
    images: List[UploadFile] = File(...),
    video: UploadFile = File(...),
    text: Optional[str] = Form(""),
):
    """Accept images + a video (+ optional text) and schedule joint analysis.

    Saves all uploads into a per-job directory, then queues
    ``process_job_background``. Returns the job record immediately;
    clients poll GET /analyze/{job_id} for the result.
    """
    job_id = _create_job()
    job_dir = os.path.join(UPLOAD_DIR, job_id)
    os.makedirs(job_dir, exist_ok=True)

    saved_image_paths = []
    saved_video_path = None

    try:
        # Save images. basename() guards against path traversal via a
        # crafted client-supplied filename (e.g. "../../etc/passwd").
        for img in images:
            safe_name = os.path.basename(img.filename or "upload")
            file_path = os.path.join(job_dir, safe_name)
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(img.file, buffer)
            saved_image_paths.append(file_path)

        # Save video (same filename sanitization).
        safe_video_name = os.path.basename(video.filename or "upload")
        saved_video_path = os.path.join(job_dir, safe_video_name)
        with open(saved_video_path, "wb") as buffer:
            shutil.copyfileobj(video.file, buffer)

    except Exception as e:
        # Don't leave partially written uploads behind on failure.
        shutil.rmtree(job_dir, ignore_errors=True)
        _update_job(job_id, status="failed", message=f"Lỗi upload: {e}")
        return _get_job(job_id)

    background_tasks.add_task(
        process_job_background,
        job_id, job_dir, saved_image_paths, saved_video_path, text
    )

    return _get_job(job_id)
|
checkpoints/image/best-mirage-img.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:755e26bfe97830e03e4c475596ade7168e1df18c57d89162d685b1148ae9f5a8
|
| 3 |
+
size 2375
|
checkpoints/image/cbm-encoder.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2a90c02450ae4a105414c6701bf571d1a498571083a725d916119d792cbdcf2f
|
| 3 |
+
size 1917947
|
checkpoints/image/cbm-predictor.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2eba7ff0cd8db1c69961478551b2a40e47687f281d7e853f95a6650908e8c3df
|
| 3 |
+
size 2387
|