Lucii1 committed on
Commit
8aeb9ae
·
1 Parent(s): d1c13b5

Add source code and model

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +70 -0
  2. AIGVDet/Dockerfile +21 -0
  3. AIGVDet/README.md +94 -0
  4. AIGVDet/__init__.py +3 -0
  5. AIGVDet/alt_cuda_corr/correlation.cpp +54 -0
  6. AIGVDet/alt_cuda_corr/correlation_kernel.cu +324 -0
  7. AIGVDet/alt_cuda_corr/setup.py +15 -0
  8. AIGVDet/app.py +155 -0
  9. AIGVDet/checkpoints/optical.pth +3 -0
  10. AIGVDet/checkpoints/original.pth +3 -0
  11. AIGVDet/core/__init__.py +0 -0
  12. AIGVDet/core/corr.py +91 -0
  13. AIGVDet/core/datasets.py +235 -0
  14. AIGVDet/core/extractor.py +267 -0
  15. AIGVDet/core/raft.py +144 -0
  16. AIGVDet/core/update.py +139 -0
  17. AIGVDet/core/utils/__init__.py +0 -0
  18. AIGVDet/core/utils/augmentor.py +246 -0
  19. AIGVDet/core/utils/flow_viz.py +132 -0
  20. AIGVDet/core/utils/frame_utils.py +137 -0
  21. AIGVDet/core/utils/utils.py +82 -0
  22. AIGVDet/core/utils1/config.py +156 -0
  23. AIGVDet/core/utils1/datasets.py +178 -0
  24. AIGVDet/core/utils1/earlystop.py +46 -0
  25. AIGVDet/core/utils1/eval.py +66 -0
  26. AIGVDet/core/utils1/trainer.py +163 -0
  27. AIGVDet/core/utils1/utils.py +109 -0
  28. AIGVDet/core/utils1/utils1/config.py +157 -0
  29. AIGVDet/core/utils1/utils1/datasets.py +178 -0
  30. AIGVDet/core/utils1/utils1/earlystop.py +46 -0
  31. AIGVDet/core/utils1/utils1/eval.py +66 -0
  32. AIGVDet/core/utils1/utils1/trainer.py +169 -0
  33. AIGVDet/core/utils1/utils1/utils.py +109 -0
  34. AIGVDet/core/utils1/utils1/warmup.py +70 -0
  35. AIGVDet/core/utils1/warmup.py +70 -0
  36. AIGVDet/docker-compose.yml +17 -0
  37. AIGVDet/main.py +78 -0
  38. AIGVDet/networks/resnet.py +211 -0
  39. AIGVDet/raft_model/raft-things.pth +3 -0
  40. AIGVDet/requirements.txt +22 -0
  41. AIGVDet/run.py +214 -0
  42. AIGVDet/run.sh +1 -0
  43. AIGVDet/test.py +261 -0
  44. AIGVDet/test.sh +1 -0
  45. AIGVDet/train.py +87 -0
  46. AIGVDet/train.sh +4 -0
  47. api_server.py +281 -0
  48. checkpoints/image/best-mirage-img.pt +3 -0
  49. checkpoints/image/cbm-encoder.pt +3 -0
  50. checkpoints/image/cbm-predictor.pt +3 -0
.gitignore ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ AIGVDet/frame/
2
+ AIGVDet/optical_result/
3
+ AIGVDet/temp/
4
+ frame/
5
+ optical_result/
6
+ temp/
7
+ temp_uploads/
8
+
9
+ miragenews/encodings/crops/
10
+ miragenews/encodings/image/test1_nyt_mj/
11
+ miragenews/encodings/image/test2_bbc_dalle/
12
+ miragenews/encodings/image/test3_cnn_dalle/
13
+ miragenews/encodings/image/test4_bbc_sdxl/
14
+ miragenews/encodings/image/test5_cnn_sdxl/
15
+ miragenews/encodings/image/train/
16
+ miragenews/encodings/image/validation/
17
+ miragenews/encodings/predictions/image/cbm-encoder/test1_nyt_mj/
18
+ miragenews/encodings/predictions/image/cbm-encoder/test2_bbc_dalle/
19
+ miragenews/encodings/predictions/image/cbm-encoder/test3_cnn_dalle/
20
+ miragenews/encodings/predictions/image/cbm-encoder/test4_bbc_sdxl/
21
+ miragenews/encodings/predictions/image/cbm-encoder/test5_cnn_sdxl/
22
+ miragenews/encodings/predictions/image/cbm-encoder/train/
23
+ miragenews/encodings/predictions/image/cbm-encoder/validation/
24
+ miragenews/encodings/predictions/image/linear/validation/
25
+ miragenews/encodings/predictions/image/linear/test1_nyt_mj/
26
+ miragenews/encodings/predictions/image/linear/test2_bbc_dalle/
27
+ miragenews/encodings/predictions/image/linear/test3_cnn_dalle/
28
+ miragenews/encodings/predictions/image/linear/test4_bbc_sdxl/
29
+ miragenews/encodings/predictions/image/linear/test5_cnn_sdxl/
30
+ miragenews/encodings/predictions/image/linear/train/
31
+ miragenews/encodings/predictions/image/merged/train/
32
+ miragenews/encodings/predictions/image/merged/test1_nyt_mj/
33
+ miragenews/encodings/predictions/image/merged/test2_bbc_dalle/
34
+ miragenews/encodings/predictions/image/merged/test3_cnn_dalle/
35
+ miragenews/encodings/predictions/image/merged/test4_bbc_sdxl/
36
+ miragenews/encodings/predictions/image/merged/test5_cnn_sdxl/
37
+ miragenews/encodings/predictions/image/merged/validation/
38
+
39
+ miragenews/encodings/text/train/
40
+ miragenews/encodings/text/test1_nyt_mj/
41
+ miragenews/encodings/text/test2_bbc_dalle/
42
+ miragenews/encodings/text/test3_cnn_dalle/
43
+ miragenews/encodings/text/test4_bbc_sdxl/
44
+ miragenews/encodings/text/test5_cnn_sdxl/
45
+ miragenews/encodings/text/validation/
46
+
47
+ miragenews/encodings/predictions/text/merged/train/
48
+ miragenews/encodings/predictions/text/merged/test1_nyt_mj/
49
+ miragenews/encodings/predictions/text/merged/test2_bbc_dalle/
50
+ miragenews/encodings/predictions/text/merged/test3_cnn_dalle/
51
+ miragenews/encodings/predictions/text/merged/test4_bbc_sdxl/
52
+ miragenews/encodings/predictions/text/merged/test5_cnn_sdxl/
53
+ miragenews/encodings/predictions/text/merged/validation/
54
+
55
+ miragenews/encodings/predictions/text/tbm-encoder/train/
56
+ miragenews/encodings/predictions/text/tbm-encoder/test1_nyt_mj/
57
+ miragenews/encodings/predictions/text/tbm-encoder/test2_bbc_dalle/
58
+ miragenews/encodings/predictions/text/tbm-encoder/test3_cnn_dalle/
59
+ miragenews/encodings/predictions/text/tbm-encoder/test4_bbc_sdxl/
60
+ miragenews/encodings/predictions/text/tbm-encoder/test5_cnn_sdxl/
61
+ miragenews/encodings/predictions/text/tbm-encoder/validation/
62
+
63
+ miragenews/encodings/predictions/text/linear/train/
64
+ miragenews/encodings/predictions/text/linear/test1_nyt_mj/
65
+ miragenews/encodings/predictions/text/linear/test2_bbc_dalle/
66
+ miragenews/encodings/predictions/text/linear/test3_cnn_dalle/
67
+ miragenews/encodings/predictions/text/linear/test4_bbc_sdxl/
68
+ miragenews/encodings/predictions/text/linear/test5_cnn_sdxl/
69
+ miragenews/encodings/predictions/text/linear/validation/
70
+ AIGVDet/data/
AIGVDet/Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Runtime image with PyTorch 2.2.0, CUDA 12.1 and cuDNN 8 preinstalled.
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime

# Install necessary OS packages for OpenCV (libGL and glib shared libraries),
# then drop the apt cache to keep the image small.
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Install Python dependencies first so this layer is cached independently
# of source-code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy all source code
COPY . .

# Default run command: serve the FastAPI app (app.py) on port 8003.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8003"]
AIGVDet/README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## AIGVDet
2
+ The official implementation code for the paper "AI-Generated Video Detection via Spatial-Temporal Anomaly Learning", PRCV 2024. This repo will provide <B>code, trained weights, and our training datasets</B>.
3
+
4
+ ## Network Architecture
5
+ <center> <img src="fig/NetworkArchitecture.png" alt="architecture"/> </center>
6
+
7
+ ## Dataset
8
+ - Download the preprocessed training frames from
9
+ [Baiduyun Link](https://pan.baidu.com/s/17xmDyFjtcmNsoxmUeImMTQ?pwd=ra95) (extract code: ra95).
10
+ - Download the test videos from [Google Drive](https://drive.google.com/drive/folders/1D84SRWEJ8BK8KBpTMuGi3BUM80mW_dKb?usp=sharing).
11
+
12
+ **You are allowed to use the datasets for <B>research purpose only</B>.**
13
+
14
+ ## Training
15
+ - Prepare for the training datasets.
16
+ ```
17
+ └─data
18
+ ├── train
19
+ │   └── trainset_1
20
+ │   ├── 0_real
21
+ │   │ ├── video_00000
22
+ │   │ │ ├── 00000.png
23
+ │   │ │ └── ...
24
+ │   │ └── ...
25
+ │   └── 1_fake
26
+ │   ├── video_00000
27
+ │   │ ├── 00000.png
28
+ │   │ └── ...
29
+ │   └── ...
30
+ ├── val
31
+ │   └── val_set_1
32
+ │   ├── 0_real
33
+ │   │ ├── video_00000
34
+ │   │ │ ├── 00000.png
35
+ │   │ │ └── ...
36
+ │   │ └── ...
37
+ │   └── 1_fake
38
+ │   ├── video_00000
39
+ │   │ ├── 00000.png
40
+ │   │ └── ...
41
+ │   └── ...
42
+ └── test
43
+    └── testset_1
44
+    ├── 0_real
45
+    │ ├── video_00000
46
+    │ │ ├── 00000.png
47
+    │ │ └── ...
48
+    │ └── ...
49
+    └── 1_fake
50
+    ├── video_00000
51
+    │ ├── 00000.png
52
+    │ └── ...
53
+    └── ...
54
+
55
+ ```
56
+ - Modify configuration file in `core/utils1/config.py`.
57
+ - Train the Spatial Domain Detector with the RGB frames.
58
+ ```
59
+ python train.py --gpus 0 --exp_name TRAIN_RGB_BRANCH datasets RGB_TRAINSET datasets_test RGB_TESTSET
60
+ ```
61
+ - Train the Optical Flow Detector with the optical flow frames.
62
+ ```
63
+ python train.py --gpus 0 --exp_name TRAIN_OF_BRANCH datasets OpticalFlow_TRAINSET datasets_test OpticalFlow_TESTSET
64
+ ```
65
+ ## Testing
66
+ Download the weights from [Google Drive Link](https://drive.google.com/drive/folders/18JO_YxOEqwJYfbVvy308XjoV-N6fE4yP?usp=share_link) and move it into the `checkpoints/`.
67
+
68
+ - Run on a dataset.
69
+ Prepare the RGB frames and the optical flow maps.
70
+ ```
71
+ python test.py -fop "data/test/hotshot" -mop "checkpoints/optical_aug.pth" -for "data/test/original/hotshot" -mor "checkpoints/original_aug.pth" -e "data/results/T2V/hotshot.csv" -ef "data/results/frame/T2V/hotshot.csv" -t 0.5
72
+ ```
73
+ - Run on a video.
74
+ Download the RAFT model weights from [Google Drive Link](https://drive.google.com/file/d/1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM/view) and move it into the `raft_model/`.
75
+ ```
76
+ python demo.py --use_cpu --path "video/000000.mp4" --folder_original_path "frame/000000" --folder_optical_flow_path "optical_result/000000" -mop "checkpoints/optical.pth" -mor "checkpoints/original.pth"
77
+ ```
78
+
79
+ ## License
80
+ The code and dataset are released only for academic research. Commercial usage is strictly prohibited.
81
+
82
+ ## Citation
83
+ ```
84
+ @article{AIGVDet24,
85
+ author = {Jianfa Bai and Man Lin and Gang Cao and Zijie Lou},
86
+ title = {{AI-generated video detection via spatial-temporal anomaly learning}},
87
+ conference = {The 7th Chinese Conference on Pattern Recognition and Computer Vision (PRCV)},
88
+ year = {2024},}
89
+ ```
90
+
91
+ ## Contact
92
+ If you have any questions, please contact us (lyan924@cuc.edu.cn).
93
+
94
+
AIGVDet/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .main import run_video_to_json
2
+
3
+ __all__ = ["run_video_to_json"]
AIGVDet/alt_cuda_corr/correlation.cpp ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include <vector>
3
+
4
+ // CUDA forward declarations
5
+ std::vector<torch::Tensor> corr_cuda_forward(
6
+ torch::Tensor fmap1,
7
+ torch::Tensor fmap2,
8
+ torch::Tensor coords,
9
+ int radius);
10
+
11
+ std::vector<torch::Tensor> corr_cuda_backward(
12
+ torch::Tensor fmap1,
13
+ torch::Tensor fmap2,
14
+ torch::Tensor coords,
15
+ torch::Tensor corr_grad,
16
+ int radius);
17
+
18
+ // C++ interface
19
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22
+
23
+ std::vector<torch::Tensor> corr_forward(
24
+ torch::Tensor fmap1,
25
+ torch::Tensor fmap2,
26
+ torch::Tensor coords,
27
+ int radius) {
28
+ CHECK_INPUT(fmap1);
29
+ CHECK_INPUT(fmap2);
30
+ CHECK_INPUT(coords);
31
+
32
+ return corr_cuda_forward(fmap1, fmap2, coords, radius);
33
+ }
34
+
35
+
36
+ std::vector<torch::Tensor> corr_backward(
37
+ torch::Tensor fmap1,
38
+ torch::Tensor fmap2,
39
+ torch::Tensor coords,
40
+ torch::Tensor corr_grad,
41
+ int radius) {
42
+ CHECK_INPUT(fmap1);
43
+ CHECK_INPUT(fmap2);
44
+ CHECK_INPUT(coords);
45
+ CHECK_INPUT(corr_grad);
46
+
47
+ return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48
+ }
49
+
50
+
51
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52
+ m.def("forward", &corr_forward, "CORR forward");
53
+ m.def("backward", &corr_backward, "CORR backward");
54
+ }
AIGVDet/alt_cuda_corr/correlation_kernel.cu ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include <cuda.h>
3
+ #include <cuda_runtime.h>
4
+ #include <vector>
5
+
6
+
7
+ #define BLOCK_H 4
8
+ #define BLOCK_W 8
9
+ #define BLOCK_HW BLOCK_H * BLOCK_W
10
+ #define CHANNEL_STRIDE 32
11
+
12
+
13
// True iff the pixel (h, w) lies inside an H x W grid.
__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
  const bool h_ok = (h >= 0) && (h < H);
  const bool w_ok = (w >= 0) && (w < W);
  return h_ok && w_ok;
}
17
+
18
// Forward correlation kernel.
// For each lookup center in `coords` (B, N, H1, W1, 2), accumulates the
// dot-product correlation between fmap1 at (h1, w1) and fmap2 in a
// (2r+1)x(2r+1) neighbourhood around the (fractional) center, splatting each
// sample bilinearly into the output volume `corr` (B, N, (2r+1)^2, H1, W1).
// Each block handles one batch element and a BLOCK_H x BLOCK_W spatial tile;
// channels are processed CHANNEL_STRIDE at a time through shared memory.
template <typename scalar_t>
__global__ void corr_forward_kernel(
  const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
  const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
  const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
  torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
  int r)
{
  const int b = blockIdx.x;
  const int h0 = blockIdx.y * blockDim.x;   // tile origin (row)
  const int w0 = blockIdx.z * blockDim.y;   // tile origin (col)
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  // +1 padding on the inner dimension avoids shared-memory bank conflicts.
  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  for (int c=0; c<C; c+=CHANNEL_STRIDE) {
    // Stage a CHANNEL_STRIDE-deep slice of fmap1 for this tile; out-of-bounds
    // pixels are zero-filled.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      auto fptr = fmap1[b][h1][w1];
      if (within_bounds(h1, w1, H1, W1))
        f1[c1][k1] = fptr[c+c1];
      else
        f1[c1][k1] = 0.0;
    }

    __syncthreads();

    for (int n=0; n<N; n++) {
      int h1 = h0 + threadIdx.x;
      int w1 = w0 + threadIdx.y;
      // NOTE(review): x2s/y2s are left unwritten for out-of-bounds threads,
      // so the dx/dy reads below may see stale values for those lanes; their
      // results are discarded by the within_bounds guards at the splat stage.
      if (within_bounds(h1, w1, H1, W1)) {
        x2s[tid] = coords[b][n][h1][w1][0];
        y2s[tid] = coords[b][n][h1][w1][1];
      }

      // Fractional parts of the lookup center, used as bilinear weights.
      scalar_t dx = x2s[tid] - floor(x2s[tid]);
      scalar_t dy = y2s[tid] - floor(y2s[tid]);

      int rd = 2*r + 1;
      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          // Stage the matching fmap2 samples at integer offset (iy-r, ix-r).
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            auto fptr = fmap2[b][h2][w2];
            if (within_bounds(h2, w2, H2, W2))
              f2[c2][k1] = fptr[c+c2];
            else
              f2[c2][k1] = 0.0;
          }

          __syncthreads();

          // Partial dot product over this channel slice.
          scalar_t s = 0.0;
          for (int k=0; k<CHANNEL_STRIDE; k++)
            s += f1[k][tid] * f2[k][tid];

          // Flat offsets of the four output cells this sample contributes to.
          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
          int ix_ne = H1*W1*((iy-1) + rd*ix);
          int ix_sw = H1*W1*(iy + rd*(ix-1));
          int ix_se = H1*W1*(iy + rd*ix);

          // Bilinear splat weights.
          scalar_t nw = s * (dy) * (dx);
          scalar_t ne = s * (dy) * (1-dx);
          scalar_t sw = s * (1-dy) * (dx);
          scalar_t se = s * (1-dy) * (1-dx);

          scalar_t* corr_ptr = &corr[b][n][0][h1][w1];

          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_nw) += nw;

          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_ne) += ne;

          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_sw) += sw;

          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_se) += se;
        }
      }
    }
  }
}
120
+
121
+
122
// Backward correlation kernel: given `corr_grad` (gradient of the output
// volume), accumulates gradients into fmap1_grad and fmap2_grad by reversing
// the bilinear-splat of the forward pass. fmap2 gradients are scattered with
// atomicAdd because neighbourhoods of different tiles overlap.
// NOTE(review): coords_grad is a parameter but is never written here — the
// launcher returns it as all zeros.
template <typename scalar_t>
__global__ void corr_backward_kernel(
  const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
  const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
  const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
  const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
  torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
  torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
  torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
  int r)
{

  const int b = blockIdx.x;
  const int h0 = blockIdx.y * blockDim.x;   // tile origin (row)
  const int w0 = blockIdx.z * blockDim.y;   // tile origin (col)
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  // Staged feature slices (+1 padding avoids bank conflicts).
  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];

  // Gradient accumulators for the staged slices.
  __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];

  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  for (int c=0; c<C; c+=CHANNEL_STRIDE) {

    // Stage the fmap1 slice for this tile and zero its gradient accumulator.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      auto fptr = fmap1[b][h1][w1];
      if (within_bounds(h1, w1, H1, W1))
        f1[c1][k1] = fptr[c+c1];
      else
        f1[c1][k1] = 0.0;

      f1_grad[c1][k1] = 0.0;
    }

    __syncthreads();

    int h1 = h0 + threadIdx.x;
    int w1 = w0 + threadIdx.y;

    for (int n=0; n<N; n++) {
      // NOTE(review): unlike the forward kernel, coords is read without a
      // bounds check on (h1, w1) here — presumably H1/W1 are assumed to be
      // multiples of the block tile; confirm against callers.
      x2s[tid] = coords[b][n][h1][w1][0];
      y2s[tid] = coords[b][n][h1][w1][1];

      // Fractional parts of the lookup center (bilinear weights).
      scalar_t dx = x2s[tid] - floor(x2s[tid]);
      scalar_t dy = y2s[tid] - floor(y2s[tid]);

      int rd = 2*r + 1;
      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          // Stage the matching fmap2 samples and zero their grad accumulator.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            auto fptr = fmap2[b][h2][w2];
            if (within_bounds(h2, w2, H2, W2))
              f2[c2][k1] = fptr[c+c2];
            else
              f2[c2][k1] = 0.0;

            f2_grad[c2][k1] = 0.0;
          }

          __syncthreads();

          // Gather the upstream gradient contributions from the four output
          // cells this sample fed in the forward pass (reverse of the splat).
          const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
          scalar_t g = 0.0;

          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
          int ix_ne = H1*W1*((iy-1) + rd*ix);
          int ix_sw = H1*W1*(iy + rd*(ix-1));
          int ix_se = H1*W1*(iy + rd*ix);

          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_nw) * dy * dx;

          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_ne) * dy * (1-dx);

          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_sw) * (1-dy) * dx;

          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);

          // d(corr)/d(f1) = f2, d(corr)/d(f2) = f1, both scaled by g.
          for (int k=0; k<CHANNEL_STRIDE; k++) {
            f1_grad[k][tid] += g * f2[k][tid];
            f2_grad[k][tid] += g * f1[k][tid];
          }

          // Scatter fmap2 gradients; atomic because tiles overlap in fmap2.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
            if (within_bounds(h2, w2, H2, W2))
              atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
          }
        }
      }
    }
    __syncthreads();


    // fmap1 pixels are owned exclusively by this tile, so a plain add is safe.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
      if (within_bounds(h1, w1, H1, W1))
        fptr[c+c1] += f1_grad[c1][k1];
    }
  }
}
257
+
258
+
259
+
260
// Host-side launcher for the forward correlation kernel.
// coords: (B, N, H, W, 2) lookup centers. Returns one tensor of shape
// (B, N, (2*radius+1)^2, H, W).
std::vector<torch::Tensor> corr_cuda_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  int radius)
{
  const auto B = coords.size(0);
  const auto N = coords.size(1);
  const auto H = coords.size(2);
  const auto W = coords.size(3);

  const auto rd = 2 * radius + 1;
  auto opts = fmap1.options();
  auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);

  // One block per (batch element, BLOCK_H x BLOCK_W spatial tile).
  const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);

  // Instantiated for float only: packed_accessor32<float,...> requires the
  // inputs to be float32 tensors.
  corr_forward_kernel<float><<<blocks, threads>>>(
    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    radius);

  return {corr};
}
287
+
288
// Host-side launcher for the backward correlation kernel.
// Returns {fmap1_grad, fmap2_grad, coords_grad}. coords_grad is allocated
// here but never written by the kernel, so it is returned as all zeros.
std::vector<torch::Tensor> corr_cuda_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor corr_grad,
  int radius)
{
  const auto B = coords.size(0);
  const auto N = coords.size(1);

  const auto H1 = fmap1.size(1);
  const auto W1 = fmap1.size(2);
  const auto H2 = fmap2.size(1);
  const auto W2 = fmap2.size(2);
  const auto C = fmap1.size(3);

  auto opts = fmap1.options();
  auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
  auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
  auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);

  // One block per (batch element, BLOCK_H x BLOCK_W spatial tile).
  const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);


  // Instantiated for float only: inputs must be float32 tensors.
  corr_backward_kernel<float><<<blocks, threads>>>(
    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    radius);

  return {fmap1_grad, fmap2_grad, coords_grad};
}
AIGVDet/alt_cuda_corr/setup.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Build script for the optional `alt_cuda_corr` CUDA extension
# (the efficient correlation lookup used by RAFT).
correlation_extension = CUDAExtension(
    'alt_cuda_corr',
    sources=['correlation.cpp', 'correlation_kernel.cu'],
    extra_compile_args={'cxx': [], 'nvcc': ['-O3']},
)

setup(
    name='correlation',
    ext_modules=[correlation_extension],
    cmdclass={'build_ext': BuildExtension},
)
AIGVDet/app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import shutil
4
+ import uuid
5
+ import time
6
+ from typing import Optional
7
+
8
+ from fastapi import FastAPI, UploadFile, File, HTTPException, status
9
+ from pydantic import BaseModel
10
+ from run import RUN
11
+
12
+
13
class PredictionResponse(BaseModel):
    # Response schema returned by the /predict endpoint.
    authenticity_assessment: str       # human-readable verdict with confidence score
    verification_tools_methods: str    # description of the detection pipeline used
    synthetic_type: str                # e.g. "Video Deepfake / AI Generated" or "N/A"
    other_artifacts: str               # free-text analysis details or error message
18
+
19
# FastAPI application exposing the video detector as an HTTP service.
app = FastAPI(
    title="Video Authenticity API",
    description="Detect authentic vs synthetic video using deepfake detector.",
)
23
+
24
def predict(input_path: str) -> dict:
    """Run the AIGVDet pipeline on a video file and build the API response dict.

    Args:
        input_path: Path to a video file on disk.

    Returns:
        Dict with the four PredictionResponse fields. On processing failure the
        assessment stays "Error" and `other_artifacts` carries the error text.

    Raises:
        FileNotFoundError: If `input_path` does not exist.
    """
    # Defaults returned when RUN() fails; overwritten on success.
    result_data = {
        "authenticity_assessment": "Error",
        "verification_tools_methods": "Deepfake Detection Model (Optical Flow + Frame Analysis)",
        "synthetic_type": "N/A",
        "other_artifacts": "Analysis failed.",
    }

    if not os.path.isfile(input_path):
        raise FileNotFoundError(f"File not found: {input_path}")

    video_name = os.path.basename(input_path)
    video_id = os.path.splitext(video_name)[0]

    # Per-video scratch folders for extracted RGB frames and optical-flow maps.
    folder_original = f"frame/{video_id}"
    folder_optical = f"optical_result/{video_id}"

    args = [
        "--path", input_path,
        "--folder_original_path", folder_original,
        "--folder_optical_flow_path", folder_optical,
        "--model_optical_flow_path", "checkpoints/optical.pth",
        "--model_original_path", "checkpoints/original.pth",
    ]

    try:
        start_time = time.perf_counter()

        output = RUN(args)

        elapsed = time.perf_counter() - start_time
        print(f"⏱️ [run_AIGVDetection] Service call took {elapsed:.2f} seconds")

        # Missing score keys default to 0.0; a tie is treated as fake.
        real_score = float(output.get("real_score", 0.0))
        fake_score = float(output.get("fake_score", 0.0))

        if real_score > fake_score:
            assessment = f"REAL (Authentic) | Confidence: {real_score:.4f}"
            analysis_text = (
                "Our algorithms observed consistent and natural motion patterns across frames. "
                "Inter-frame motion analysis indicates that objects maintain physical trajectories consistent with real-world recording, "
                "without the jitter or warping artifacts typically associated with generative AI."
            )
            syn_type = "N/A"
        else:
            assessment = f"🤖 NOT REAL (Fake/Synthetic) | Confidence: {fake_score:.4f}"
            analysis_text = (
                "Our algorithms have detected asynchronous and inconsistent movement between frames. "
                "Upon conducting inter-frame motion analysis, we observed that objects and details within the video "
                "fail to maintain natural motion trajectories. These anomalies—such as sudden velocity shifts, "
                "subtle per-frame distortions, or motion vectors that defy physical laws—are characteristic indicators "
                "typically found in AI-generated videos."
            )
            syn_type = "Video Deepfake / AI Generated"

        result_data = {
            "authenticity_assessment": assessment,
            "verification_tools_methods": "Deepfake Detector (Optical Flow + CNN Frame Analysis)",
            "synthetic_type": syn_type,
            # Fix: was wrapped in a redundant no-placeholder f-string.
            "other_artifacts": analysis_text,
        }

    except Exception as e:
        import traceback
        traceback.print_exc()
        result_data["other_artifacts"] = f"Error during processing: {str(e)}"

    finally:
        # Always remove the per-video scratch folders, even on failure.
        for folder in (folder_original, folder_optical):
            try:
                if os.path.exists(folder):
                    shutil.rmtree(folder)
            except Exception as cleanup_error:
                print(f"Error deleting folder {folder}: {cleanup_error}")

    return result_data
110
+
111
# --- API ENDPOINT ---
@app.post("/predict", response_model=PredictionResponse)
async def predict_endpoint(file: UploadFile = File(...)):
    """Accept an uploaded video, run detection, and return the assessment.

    Clears scratch folders left by previous requests, stores the upload under
    a collision-free name in `uploads/`, delegates to `predict`, and always
    removes the temporary upload afterwards.
    """
    temp_filepath = None
    try:
        # Reset scratch folders from previous requests (best effort).
        for parent in ["frame", "optical_result", "uploads"]:
            if os.path.exists(parent):
                print(f"🧹 Cleaning folder: {parent}")
                for item in os.listdir(parent):
                    path = os.path.join(parent, item)
                    try:
                        if os.path.isfile(path) or os.path.islink(path):
                            os.remove(path)
                        elif os.path.isdir(path):
                            shutil.rmtree(path)
                    except Exception as e:
                        print(f"⚠️ Cannot delete {path}: {e}")
            else:
                os.makedirs(parent)
                print(f"📁 Created new folder: {parent}")

        # Persist the upload under a unique name to avoid collisions.
        temp_filename = f"uploads_{uuid.uuid4().hex}_{file.filename}"
        os.makedirs("uploads", exist_ok=True)
        temp_filepath = os.path.join("uploads", temp_filename)

        with open(temp_filepath, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        return predict(temp_filepath)

    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Prediction failed: {e}",
        )
    finally:
        # Bug fix: the upload was previously removed only on success, leaking
        # files in uploads/ whenever predict() raised.
        if temp_filepath and os.path.exists(temp_filepath):
            os.remove(temp_filepath)
152
+
153
if __name__ == "__main__":
    import uvicorn
    # Bug fix: 80022 is not a valid TCP port (maximum is 65535); binding
    # would fail at startup. Use 8003 to match the Dockerfile's CMD.
    uvicorn.run(app, host="0.0.0.0", port=8003)
AIGVDet/checkpoints/optical.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23a167ba7adccb6421fd0eddd81fbfc69519d2a97a343dbf0d5da894b9893b19
3
+ size 282581704
AIGVDet/checkpoints/original.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2df6f477a590f4b24b14eac9654b868a9f178312795df20749274a502d59bdd
3
+ size 282581704
AIGVDet/core/__init__.py ADDED
File without changes
AIGVDet/core/corr.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid

try:
    from .. import alt_cuda_corr
except ImportError:
    # Fix: narrowed from a bare `except:` — only a missing/uncompiled
    # alt_cuda_corr extension should be silenced, not arbitrary errors
    # (e.g. KeyboardInterrupt). AlternateCorrBlock needs this extension.
    pass
10
+
11
+
12
class CorrBlock:
    """All-pairs correlation volume with an average-pooled pyramid (RAFT).

    Computes the full (h1*w1) x (h2*w2) correlation between two feature maps
    once at construction, then answers lookups of a (2r+1)x(2r+1)
    neighbourhood around given coordinates at `num_levels` pyramid scales.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []   # level i has h2/w2 halved i times

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        # Pool only over the (h2, w2) target dimensions for coarser levels.
        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        # coords: (batch, 2, h1, w1) lookup centers in level-0 pixel units.
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # (2r+1)x(2r+1) grid of integer offsets around each center.
            dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            # Centers are rescaled to the current pyramid level.
            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        # Concatenate all levels: (batch, num_levels*(2r+1)^2, h1, w1).
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        # Dot-product correlation between every pair of pixels,
        # scaled by 1/sqrt(dim).
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
61
+
62
+
63
class AlternateCorrBlock:
    """Correlation lookup backed by the `alt_cuda_corr` CUDA extension.

    Computes neighbourhood correlations on demand instead of materialising
    the full all-pairs volume; requires the compiled extension.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        # Average-pooled feature pyramid; level 0 is full resolution.
        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        # coords: (B, 2, H, W) lookup centers in level-0 pixel units.
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # fmap1 is always taken at full resolution; only fmap2 is pooled.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        # Stack levels then flatten: (B, num_levels*(2r+1)^2, H, W),
        # scaled by 1/sqrt(dim) like CorrBlock.
        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())
AIGVDet/core/datasets.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.utils.data as data
6
+ import torch.nn.functional as F
7
+
8
+ import os
9
+ import math
10
+ import random
11
+ from glob import glob
12
+ import os.path as osp
13
+
14
+ from utils import frame_utils
15
+ from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16
+
17
+
18
class FlowDataset(data.Dataset):
    """Base class for optical-flow datasets.

    Subclasses populate `image_list` (pairs of image paths), `flow_list`
    (ground-truth flow file paths) and `extra_info` (per-pair metadata).
    When `is_test` is set, __getitem__ returns only the image pair and
    metadata, with no ground truth.
    """

    def __init__(self, aug_params=None, sparse=False):
        self.augmentor = None
        self.sparse = sparse
        if aug_params is not None:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.image_list = []
        self.extra_info = []

    def __getitem__(self, index):

        # Test mode: return raw image tensors plus metadata, no flow.
        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        # Seed each DataLoader worker once so augmentations differ between
        # workers but remain reproducible per worker id.
        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        # Wrap the index so over-sampled datasets (see __rmul__) stay valid.
        index = index % len(self.image_list)
        valid = None
        if self.sparse:
            # KITTI-style sparse flow ships an explicit validity mask.
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
        else:
            flow = frame_utils.read_gen(self.flow_list[index])

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        flow = np.array(flow).astype(np.float32)
        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        # grayscale images: replicate the single channel to RGB
        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if valid is not None:
            valid = torch.from_numpy(valid)
        else:
            # Dense flow: treat extreme magnitudes (occluded/unknown) as invalid.
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)

        return img1, img2, flow, valid.float()


    def __rmul__(self, v):
        # Support `k * dataset` to over-sample a dataset k times when mixing
        # several training sets (see fetch_dataloader).
        self.flow_list = v * self.flow_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)
100
+
101
+
102
class MpiSintel(FlowDataset):
    """MPI-Sintel dataset: consecutive-frame pairs per scene.

    dstype selects the 'clean' or 'final' render pass. For split='test'
    only images are listed (no ground-truth flow).
    """

    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
        super(MpiSintel, self).__init__(aug_params)
        flow_root = osp.join(root, split, 'flow')
        image_root = osp.join(root, split, dstype)

        if split == 'test':
            self.is_test = True

        for scene in os.listdir(image_root):
            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
            # Pair each frame with its successor within the scene.
            for i in range(len(image_list)-1):
                self.image_list += [ [image_list[i], image_list[i+1]] ]
                self.extra_info += [ (scene, i) ] # scene and frame_id

            if split != 'test':
                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119
+
120
+
121
class FlyingChairs(FlowDataset):
    """FlyingChairs dataset: flat directory of *.ppm image pairs + *.flo flow.

    The train/validation partition comes from 'chairs_split.txt'
    (1 = training sample, 2 = validation sample).
    """

    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(aug_params)

        images = sorted(glob(osp.join(root, '*.ppm')))
        flows = sorted(glob(osp.join(root, '*.flo')))
        # Two images per flow file.
        assert (len(images)//2 == len(flows))

        # BUG FIX: the membership test below only matched 'training' and
        # 'validation', so the default split='train' silently produced an
        # empty dataset. Accept the short aliases as well.
        if split == 'train':
            split = 'training'
        elif split == 'val':
            split = 'validation'

        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
        for i in range(len(flows)):
            xid = split_list[i]
            if (split=='training' and xid==1) or (split=='validation' and xid==2):
                self.flow_list += [ flows[i] ]
                self.image_list += [ [images[2*i], images[2*i+1]] ]
135
+
136
+
137
class FlyingThings3D(FlowDataset):
    """FlyingThings3D dataset (left camera only).

    Pairs are generated in both temporal directions: 'into_future' pairs
    (i, i+1) with flow[i]; 'into_past' pairs (i+1, i) with flow[i+1].
    """

    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
        super(FlyingThings3D, self).__init__(aug_params)

        for cam in ['left']:
            for direction in ['into_future', 'into_past']:
                # Re-globbed per direction; the directory list itself is the same.
                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])

                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])

                for idir, fdir in zip(image_dirs, flow_dirs):
                    images = sorted(glob(osp.join(idir, '*.png')) )
                    flows = sorted(glob(osp.join(fdir, '*.pfm')) )
                    for i in range(len(flows)-1):
                        if direction == 'into_future':
                            self.image_list += [ [images[i], images[i+1]] ]
                            self.flow_list += [ flows[i] ]
                        elif direction == 'into_past':
                            # Reversed pair: flow maps frame i+1 back to frame i.
                            self.image_list += [ [images[i+1], images[i]] ]
                            self.flow_list += [ flows[i+1] ]
159
+
160
+
161
class KITTI(FlowDataset):
    """KITTI flow dataset (sparse ground truth with validity masks).

    Pairs each *_10.png frame with its *_11.png successor; for
    split='training' the occluded flow maps (flow_occ) provide labels.
    """

    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
        super(KITTI, self).__init__(aug_params, sparse=True)
        if split == 'testing':
            self.is_test = True

        root = osp.join(root, split)
        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))

        for img1, img2 in zip(images1, images2):
            # FIX: use osp.basename instead of split('/')[-1] so the frame id
            # is extracted correctly on platforms with other path separators.
            frame_id = osp.basename(img1)
            self.extra_info += [ [frame_id] ]
            self.image_list += [ [img1, img2] ]

        if split == 'training':
            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178
+
179
+
180
class HD1K(FlowDataset):
    """HD1K dataset (sparse ground truth).

    Walks sequence indices 000000, 000001, ... until a sequence with no
    flow files is found, pairing consecutive frames within each sequence.
    """

    def __init__(self, aug_params=None, root='datasets/HD1k'):
        super(HD1K, self).__init__(aug_params, sparse=True)

        seq_ix = 0
        while True:
            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))

            # First missing sequence index terminates the scan.
            if not flows:
                break

            for i in range(len(flows) - 1):
                self.flow_list += [flows[i]]
                self.image_list += [[images[i], images[i + 1]]]

            seq_ix += 1
197
+
198
+
199
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
    """ Create the data loader for the corresponding training set.

    args.stage selects the curriculum stage ('chairs', 'things', 'sintel',
    'kitti'); for the sintel stage, TRAIN_DS chooses which datasets are
    mixed (over-sampling via `k * dataset`).

    Raises:
        ValueError: if args.stage or TRAIN_DS is not recognised (previously
            this fell through and crashed later with a NameError).
    """

    if args.stage == 'chairs':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
        train_dataset = FlyingChairs(aug_params, split='training')

    elif args.stage == 'things':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.stage == 'sintel':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
        sintel_final = MpiSintel(aug_params, split='training', dstype='final')

        if TRAIN_DS == 'C+T+K+S+H':
            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

        elif TRAIN_DS == 'C+T+K/S':
            train_dataset = 100*sintel_clean + 100*sintel_final + things

        else:
            raise ValueError('unknown TRAIN_DS: %s' % TRAIN_DS)

    elif args.stage == 'kitti':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
        train_dataset = KITTI(aug_params, split='training')

    else:
        raise ValueError('unknown training stage: %s' % args.stage)

    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
235
+
AIGVDet/core/extractor.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a residual skip connection.

    norm_fn selects GroupNorm / BatchNorm / InstanceNorm / identity.
    When stride != 1 the skip path is downsampled by a 1x1 conv + norm3.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            # norm3 only exists for the strided (downsampling) skip path.
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
57
+
58
+
59
+
60
class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The internal width is planes//4. norm_fn selects GroupNorm / BatchNorm /
    InstanceNorm / identity; norm4 normalises the strided skip path.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            # norm4 only exists for the strided (downsampling) skip path.
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
117
+
118
class BasicEncoder(nn.Module):
    """RAFT feature/context encoder producing 1/8-resolution features.

    Stem conv (stride 2) followed by three residual stages (64, 96, 128
    channels; stages 2-3 stride 2) and a 1x1 output conv to `output_dim`
    channels. Accepts a tensor or a [img1, img2] list (batched together
    and split back on return).
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may downsample.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
193
+
194
+
195
class SmallEncoder(nn.Module):
    """Small RAFT encoder (bottleneck blocks, 32/64/96 channels).

    Same contract as BasicEncoder: stride-2 stem, three stages, 1x1 output
    conv; accepts a tensor or a [img1, img2] list and splits the result
    back into two tensors.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; only the first may downsample.
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
AIGVDet/core/raft.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .update import BasicUpdateBlock, SmallUpdateBlock
7
+ from .extractor import BasicEncoder, SmallEncoder
8
+ from .corr import CorrBlock, AlternateCorrBlock
9
+ from .utils.utils import bilinear_sampler, coords_grid, upflow8
10
+
11
# torch.cuda.amp.autocast exists from PyTorch 1.6; fall back to a no-op
# context manager on older versions.
try:
    autocast = torch.cuda.amp.autocast
# FIX: was a bare `except:`, which also swallowed unrelated errors such as
# KeyboardInterrupt; only a missing attribute means "old PyTorch".
except AttributeError:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass
+ pass
22
+
23
+
24
class RAFT(nn.Module):
    """RAFT optical flow network (feature encoder + context encoder +
    iterative GRU update operator over a correlation volume)."""

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        # Model size selects hidden/context dims and correlation lookup params.
        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        # `in` on args — assumes an argparse.Namespace-like object that
        # supports membership tests; TODO confirm for other arg containers.
        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        # Put all BatchNorm layers into eval mode (frozen running stats).
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8, device=img.device)
        coords1 = coords_grid(N, H//8, W//8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        # mask: per-pixel weights over a 3x3 neighbourhood for each of the
        # 8x8 upsampled positions; softmax makes each a convex combination.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Flow values are scaled by 8 to match the higher resolution.
        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)


    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        # Normalise uint8-range images to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            # Detach so gradients only flow through the current iteration.
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions
AIGVDet/core/update.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class FlowHead(nn.Module):
    """Two-layer convolutional head mapping a hidden state to a 2-channel
    flow residual."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
+
16
class ConvGRU(nn.Module):
    """Convolutional GRU cell with 3x3 gates (update z, reset r, candidate q)."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        combined = torch.cat([h, x], dim=1)

        update_gate = torch.sigmoid(self.convz(combined))
        reset_gate = torch.sigmoid(self.convr(combined))
        candidate = torch.tanh(self.convq(torch.cat([reset_gate * h, x], dim=1)))

        return (1 - update_gate) * h + update_gate * candidate
32
+
33
class SepConvGRU(nn.Module):
    """Separable convolutional GRU: a horizontal (1x5) GRU pass followed by a
    vertical (5x1) GRU pass over the hidden state."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))

    def _gru_step(self, h, x, convz, convr, convq):
        # One GRU update with the supplied gate convolutions.
        combined = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(combined))
        r = torch.sigmoid(convr(combined))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # horizontal pass, then vertical pass
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
61
+
62
class SmallMotionEncoder(nn.Module):
    """Encode correlation features and the current flow into motion features
    (small model). Output has 82 channels: 80 conv features plus the raw
    2-channel flow appended."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        merged = torch.cat([corr_feat, flow_feat], dim=1)
        motion = F.relu(self.conv(merged))
        return torch.cat([motion, flow], dim=1)
78
+
79
class BasicMotionEncoder(nn.Module):
    """Encode correlation features and the current flow into motion features.
    Output has 128 channels: 126 conv features plus the raw 2-channel flow."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        merged = torch.cat([corr_feat, flow_feat], dim=1)
        motion = F.relu(self.conv(merged))
        return torch.cat([motion, flow], dim=1)
98
+
99
class SmallUpdateBlock(nn.Module):
    """GRU-based update operator for the small RAFT model.

    Returns (new hidden state, None upsampling mask, flow delta); the small
    model has no learned upsampler, hence the None.
    """

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        features = torch.cat([inp, self.encoder(flow, corr)], dim=1)
        net = self.gru(net, features)
        return net, None, self.flow_head(net)
113
+
114
class BasicUpdateBlock(nn.Module):
    """GRU-based update operator for the full RAFT model.

    Returns (new hidden state, upsampling mask, flow delta). The mask feeds
    RAFT.upsample_flow for learned convex upsampling.
    """

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        features = torch.cat([inp, self.encoder(flow, corr)], dim=1)

        net = self.gru(net, features)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
137
+
138
+
139
+
AIGVDet/core/utils/__init__.py ADDED
File without changes
AIGVDet/core/utils/augmentor.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import math
4
+ from PIL import Image
5
+
6
+ import cv2
7
+ cv2.setNumThreads(0)
8
+ cv2.ocl.setUseOpenCL(False)
9
+
10
+ import torch
11
+ from torchvision.transforms import ColorJitter
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class FlowAugmentor:
16
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17
+
18
+ # spatial augmentation params
19
+ self.crop_size = crop_size
20
+ self.min_scale = min_scale
21
+ self.max_scale = max_scale
22
+ self.spatial_aug_prob = 0.8
23
+ self.stretch_prob = 0.8
24
+ self.max_stretch = 0.2
25
+
26
+ # flip augmentation params
27
+ self.do_flip = do_flip
28
+ self.h_flip_prob = 0.5
29
+ self.v_flip_prob = 0.1
30
+
31
+ # photometric augmentation params
32
+ self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33
+ self.asymmetric_color_aug_prob = 0.2
34
+ self.eraser_aug_prob = 0.5
35
+
36
+ def color_transform(self, img1, img2):
37
+ """ Photometric augmentation """
38
+
39
+ # asymmetric
40
+ if np.random.rand() < self.asymmetric_color_aug_prob:
41
+ img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42
+ img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43
+
44
+ # symmetric
45
+ else:
46
+ image_stack = np.concatenate([img1, img2], axis=0)
47
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48
+ img1, img2 = np.split(image_stack, 2, axis=0)
49
+
50
+ return img1, img2
51
+
52
+ def eraser_transform(self, img1, img2, bounds=[50, 100]):
53
+ """ Occlusion augmentation """
54
+
55
+ ht, wd = img1.shape[:2]
56
+ if np.random.rand() < self.eraser_aug_prob:
57
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58
+ for _ in range(np.random.randint(1, 3)):
59
+ x0 = np.random.randint(0, wd)
60
+ y0 = np.random.randint(0, ht)
61
+ dx = np.random.randint(bounds[0], bounds[1])
62
+ dy = np.random.randint(bounds[0], bounds[1])
63
+ img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64
+
65
+ return img1, img2
66
+
67
    def spatial_transform(self, img1, img2, flow):
        # randomly sample scale
        ht, wd = img1.shape[:2]
        # lower bound keeps the rescaled image at least crop_size + 8 pixels
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            # independently stretch x and y for aspect-ratio augmentation
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            # flow vectors scale together with the image
            flow = flow * [scale_x, scale_y]

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]  # u component flips sign

            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                flow = flow[::-1, :] * [1.0, -1.0]  # v component flips sign

        # random crop to crop_size
        # NOTE(review): if the rescale branch is skipped and the image equals
        # crop_size exactly, randint(0, 0) raises — presumably inputs are
        # always larger than crop_size; confirm against the dataset loader.
        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow
110
+
111
    def __call__(self, img1, img2, flow):
        # Full pipeline: photometric -> occlusion -> spatial augmentation
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        # slicing/flipping above can produce non-contiguous views
        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow
121
+
122
class SparseFlowAugmentor:
    """Augmentation pipeline for sparse-flow datasets (e.g. KITTI), where the
    ground-truth flow is valid only at the pixels flagged by `valid`."""

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        # Jitter both frames identically by stacking them into one image
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        # Occlusion augmentation: fill 1-2 random rectangles of img2 with its mean color
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        # Rescale a sparse flow field by scattering the valid samples onto the
        # resized grid (nearest pixel) instead of interpolating invalid values.
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        # keep only pixels with valid ground truth
        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        # scale positions and flow vectors by the same factors
        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        # drop samples that fall outside the resized image
        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        # randomly sample scale

        ht, wd = img1.shape[:2]
        # lower bound keeps the rescaled image at least crop_size (+1 margin)
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if self.do_flip:
            if np.random.rand() < 0.5: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]  # u component flips sign
                valid = valid[:, ::-1]

        # jitter the crop window, then clamp it back inside the image
        margin_y = 20
        margin_x = 50

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        return img1, img2, flow, valid


    def __call__(self, img1, img2, flow, valid):
        # Full pipeline: photometric -> occlusion -> spatial; contiguous outputs
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)
        valid = np.ascontiguousarray(valid)

        return img1, img2, flow, valid
AIGVDet/core/utils/flow_viz.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2
+
3
+
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2018 Tom Runia
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to conditions.
14
+ #
15
+ # Author: Tom Runia
16
+ # Date Created: 2018-08-03
17
+
18
+ import numpy as np
19
+
20
def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel
    """
    # segment lengths between the six key hues
    RY, YG, GC, CB, BM, MR = 15, 6, 4, 11, 13, 6
    ncols = RY + YG + GC + CB + BM + MR
    wheel = np.zeros((ncols, 3))

    # (length, channel held at 255, ramping channel, ramp direction)
    segments = [
        (RY, 0, 1, +1),  # red -> yellow
        (YG, 1, 0, -1),  # yellow -> green
        (GC, 1, 2, +1),  # green -> cyan
        (CB, 2, 1, -1),  # cyan -> blue
        (BM, 2, 0, +1),  # blue -> magenta
        (MR, 0, 2, -1),  # magenta -> red
    ]
    row = 0
    for length, hold, ramp, direction in segments:
        ramp_vals = np.floor(255 * np.arange(length) / length)
        wheel[row:row + length, hold] = 255
        wheel[row:row + length, ramp] = ramp_vals if direction > 0 else 255 - ramp_vals
        row += length
    return wheel
68
+
69
+
70
def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    out = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    wheel = make_colorwheel()  # shape [55, 3]
    ncols = wheel.shape[0]

    magnitude = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, -u) / np.pi
    # fractional position of each pixel's hue on the wheel
    pos = (angle + 1) / 2 * (ncols - 1)
    lo = np.floor(pos).astype(np.int32)
    hi = lo + 1
    hi[hi == ncols] = 0  # wrap around the wheel
    frac = pos - lo

    for ch in range(wheel.shape[1]):
        channel = wheel[:, ch]
        blended = (1 - frac) * (channel[lo] / 255.0) + frac * (channel[hi] / 255.0)
        in_range = (magnitude <= 1)
        # increase saturation with magnitude; desaturate out-of-range flow
        blended[in_range] = 1 - magnitude[in_range] * (1 - blended[in_range])
        blended[~in_range] = blended[~in_range] * 0.75
        # 2-ch flips the channel order to produce BGR instead of RGB
        target = 2 - ch if convert_to_bgr else ch
        out[:, :, target] = np.floor(255 * blended)
    return out
107
+
108
+
109
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u, v = flow_uv[:, :, 0], flow_uv[:, :, 1]
    # normalize by the maximum magnitude (epsilon guards division by zero)
    magnitude = np.sqrt(np.square(u) + np.square(v))
    scale = np.max(magnitude) + 1e-5
    return flow_uv_to_colors(u / scale, v / scale, convert_to_bgr)
AIGVDet/core/utils/frame_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ from os.path import *
4
+ import re
5
+
6
+ import cv2
7
+ cv2.setNumThreads(0)
8
+ cv2.ocl.setUseOpenCL(False)
9
+
10
+ TAG_CHAR = np.array([202021.25], np.float32)
11
+
12
def readFlow(fn):
    """ Read .flo file in Middlebury format"""
    # Adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    with open(fn, 'rb') as f:
        tag = np.fromfile(f, np.float32, count=1)
        if 202021.25 != tag:
            print('Magic number incorrect. Invalid .flo file')
            return None
        width = np.fromfile(f, np.int32, count=1)
        height = np.fromfile(f, np.int32, count=1)
        raw = np.fromfile(f, np.float32, count=2 * int(width) * int(height))
        # return (h, w, 2) layout for visualization; files store (w, h, 2)
        return np.resize(raw, (int(height), int(width), 2))
32
+
33
def readPFM(file):
    """Read a PFM (Portable Float Map) image and return a float array in
    conventional top-down row order. Raises on a malformed header."""
    file = open(file, 'rb')

    color = None
    width = None
    height = None
    scale = None
    endian = None

    # header line: b'PF' = 3-channel color, b'Pf' = single-channel grayscale
    header = file.readline().rstrip()
    if header == b'PF':
        color = True
    elif header == b'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    # second line: "<width> <height>"
    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    # third line: scale factor; its sign encodes the byte order
    scale = float(file.readline().rstrip())
    if scale < 0: # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>' # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    # PFM stores rows bottom-up; flip to top-down
    data = np.flipud(data)
    return data
69
+
70
def writeFlow(filename, uv, v=None):
    """ Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.
    """
    nBands = 2

    if v is None:
        assert (uv.ndim == 3)
        assert (uv.shape[2] == 2)
        u = uv[:, :, 0]
        v = uv[:, :, 1]
    else:
        u = uv

    assert (u.shape == v.shape)
    height, width = u.shape
    # 'with' guarantees the handle is closed even if a write fails
    # (the original opened/closed manually and leaked the handle on error)
    with open(filename, 'wb') as f:
        # write the header: magic tag, then width and height as int32
        f.write(TAG_CHAR)
        np.array(width).astype(np.int32).tofile(f)
        np.array(height).astype(np.int32).tofile(f)
        # interleave u and v column-wise: u0 v0 u1 v1 ...
        tmp = np.zeros((height, width * nBands))
        tmp[:, np.arange(width) * 2] = u
        tmp[:, np.arange(width) * 2 + 1] = v
        tmp.astype(np.float32).tofile(f)
100
+
101
+
102
def readFlowKITTI(filename):
    """Read a KITTI 16-bit PNG flow map; returns (flow, valid) with the flow
    decoded from KITTI's (value - 2^15) / 64 fixed-point encoding."""
    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
    flow = flow[:,:,::-1].astype(np.float32)  # BGR -> RGB channel order
    flow, valid = flow[:, :, :2], flow[:, :, 2]  # third channel is the valid mask
    flow = (flow - 2**15) / 64.0
    return flow, valid
108
+
109
def readDispKITTI(filename):
    """Read a KITTI disparity PNG (uint16 / 256) and convert it to a flow
    field with horizontal component -disparity and zero vertical motion."""
    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
    valid = disp > 0.0  # zero disparity marks missing ground truth
    flow = np.stack([-disp, np.zeros_like(disp)], -1)
    return flow, valid
114
+
115
+
116
def writeFlowKITTI(filename, uv):
    """Write flow in KITTI's 16-bit PNG encoding (64*uv + 2^15) with an
    all-ones valid mask in the third channel."""
    uv = 64.0 * uv + 2**15
    valid = np.ones([uv.shape[0], uv.shape[1], 1])
    uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
    cv2.imwrite(filename, uv[..., ::-1])  # reverse channels: OpenCV expects BGR
121
+
122
+
123
def read_gen(file_name, pil=False):
    """Generic loader dispatching on the file extension: images via PIL,
    .bin/.raw via numpy, .flo/.pfm via the flow readers; [] if unknown."""
    ext = splitext(file_name)[-1]
    if ext in ('.png', '.jpeg', '.ppm', '.jpg'):
        return Image.open(file_name)
    if ext in ('.bin', '.raw'):
        return np.load(file_name)
    if ext == '.flo':
        return readFlow(file_name).astype(np.float32)
    if ext == '.pfm':
        flow = readPFM(file_name).astype(np.float32)
        # 3-channel PFM flow carries a padding channel; drop it
        return flow if len(flow.shape) == 2 else flow[:, :, :-1]
    return []
AIGVDet/core/utils/utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from scipy import interpolate
5
+
6
+
7
class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """

    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        # extra pixels needed to round each spatial dim up to a multiple of 8
        extra_h = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        extra_w = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == 'sintel':
            # split the padding evenly between both sides
            self._pad = [extra_w // 2, extra_w - extra_w // 2,
                         extra_h // 2, extra_h - extra_h // 2]
        else:
            # pad width evenly, height only at the bottom
            self._pad = [extra_w // 2, extra_w - extra_w // 2, 0, extra_h]

    def pad(self, *inputs):
        """Replicate-pad every input tensor to the rounded-up size."""
        return [F.pad(t, self._pad, mode='replicate') for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original spatial size."""
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
25
+
26
def forward_interpolate(flow):
    """Forward-warp a (2, H, W) flow tensor: scatter each vector to its target
    location and fill the grid by nearest-neighbour interpolation."""
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # target positions after applying the flow
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # keep only vectors landing strictly inside the frame
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[inside], y1[inside]
    dx, dy = dx[inside], dy[inside]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
55
+
56
+
57
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    # NOTE: `mode` is accepted but not forwarded to grid_sample (matches the
    # original behaviour; grid_sample's default is bilinear).
    H, W = img.shape[-2:]
    xs, ys = coords.split([1, 1], dim=-1)
    # map pixel coordinates into grid_sample's [-1, 1] range
    xs = 2 * xs / (W - 1) - 1
    ys = 2 * ys / (H - 1) - 1

    sampled = F.grid_sample(img, torch.cat([xs, ys], dim=-1), align_corners=True)

    if mask:
        valid = (xs > -1) & (ys > -1) & (xs < 1) & (ys < 1)
        return sampled, valid.float()

    return sampled
72
+
73
+
74
def coords_grid(batch, ht, wd, device):
    """Return a (batch, 2, ht, wd) grid of pixel coordinates: channel 0 = x, channel 1 = y."""
    ys, xs = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
    grid = torch.stack([xs, ys], dim=0).float()
    return grid[None].repeat(batch, 1, 1, 1)
78
+
79
+
80
+ def upflow8(flow, mode='bilinear'):
81
+ new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
AIGVDet/core/utils1/config.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ from abc import ABC
5
+ from typing import Type
6
+
7
+
8
class DefaultConfigs(ABC):
    """Central experiment configuration. Fields are overridden either by a
    per-experiment config.py or by command-line KEY VALUE pairs (see the
    module-level parsing code below this class)."""
    ####### base setting ######
    gpus = [0]
    seed = 3407
    arch = "resnet50"
    datasets = ["zhaolian_train"]
    datasets_test = ["adm_res_abs_ddim20s"]
    mode = "binary"
    class_bal = False
    batch_size = 64
    loadSize = 256
    cropSize = 224
    epoch = "latest"
    num_workers = 20
    serial_batches = False
    isTrain = True

    # data augmentation
    rz_interp = ["bilinear"]
    # blur_prob = 0.0
    blur_prob = 0.1
    blur_sig = [0.5]
    # jpg_prob = 0.0
    jpg_prob = 0.1
    jpg_method = ["cv2"]
    jpg_qual = [75]
    gray_prob = 0.0
    aug_resize = True
    aug_crop = True
    aug_flip = True
    aug_norm = True

    ####### train setting ######
    warmup = False
    # warmup = True
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5
    optim = "adam"
    new_optim = False
    loss_freq = 400
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 400
    beta1 = 0.9
    lr = 0.0001
    init_type = "normal"
    init_gain = 0.02
    pretrained = True

    # paths information
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    dataset_root = os.path.join(root_dir, "data")
    exp_root = os.path.join(root_dir, "data", "exp")
    _exp_name = ""
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""

    @property
    def exp_name(self):
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        # Setting the experiment name derives all dependent paths and creates
        # the experiment/checkpoint directories on disk as a side effect.
        self._exp_name = value
        self.exp_dir: str = os.path.join(self.exp_root, self.exp_name)
        self.ckpt_dir: str = os.path.join(self.exp_dir, "ckpt")
        self.logs_path: str = os.path.join(self.exp_dir, "logs.txt")

        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def to_dict(self):
        """Export all public, non-callable config fields as a plain dict."""
        dic = {}
        for fieldkey in dir(self):
            fieldvalue = getattr(self, fieldkey)
            if not fieldkey.startswith("__") and not callable(fieldvalue) and not fieldkey.startswith("_"):
                dic[fieldkey] = fieldvalue
        return dic
92
+
93
+
94
def args_list2dict(arg_list: list):
    """Turn a flat [k1, v1, k2, v2, ...] override list into a dict."""
    assert len(arg_list) % 2 == 0, f"Override list has odd length: {arg_list}; it must be a list of pairs"
    keys, values = arg_list[0::2], arg_list[1::2]
    return dict(zip(keys, values))
97
+
98
+
99
def str2bool(v: str) -> bool:
    """Parse common textual booleans; anything unrecognized falls back to
    Python truthiness (non-empty string -> True)."""
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("true", "yes", "on", "y", "t", "1"):
        return True
    if lowered in ("false", "no", "off", "n", "f", "0"):
        return False
    return bool(v)
108
+
109
+
110
def str2list(v: str, element_type=None) -> list:
    """Parse "a,b,c" or "[a, b, c]" into a list; optionally cast each element."""
    if not isinstance(v, (list, tuple, set)):
        stripped = v.lstrip("[").rstrip("]")
        v = [part.strip() for part in stripped.split(",")]
    if element_type is not None:
        v = [element_type(item) for item in v]
    return v
118
+
119
+
120
# Type alias used throughout the project for config objects
CONFIGCLASS = Type[DefaultConfigs]

parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default=[0], type=int, nargs="+")
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
# everything after the named flags is treated as KEY VALUE override pairs
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()

# Prefer a per-experiment config.py when the experiment directory ships one
if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
    sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
    from config import cfg

    cfg: CONFIGCLASS
else:
    cfg = DefaultConfigs()

# Apply command-line overrides, coercing each value to the attribute's current type
if args.opts:
    opts = args_list2dict(args.opts)
    for k, v in opts.items():
        if not hasattr(cfg, k):
            raise ValueError(f"Unrecognized option: {k}")
        original_type = type(getattr(cfg, k))
        if original_type == bool:
            setattr(cfg, k, str2bool(v))
        elif original_type in (list, tuple, set):
            setattr(cfg, k, str2list(v, type(getattr(cfg, k)[0])))
        else:
            setattr(cfg, k, original_type(v))

cfg.gpus: list = args.gpus
os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join([str(gpu) for gpu in cfg.gpus])
cfg.exp_name = args.exp_name  # property setter also creates the exp/ckpt dirs
cfg.ckpt_path: str = os.path.join(cfg.ckpt_dir, args.ckpt)

if isinstance(cfg.datasets, str):
    cfg.datasets = cfg.datasets.split(",")
AIGVDet/core/utils1/datasets.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+ from random import choice, random
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision.datasets as datasets
10
+ import torchvision.transforms as transforms
11
+ import torchvision.transforms.functional as TF
12
+ from PIL import Image, ImageFile
13
+ from scipy.ndimage import gaussian_filter
14
+ from torch.utils.data.sampler import WeightedRandomSampler
15
+
16
+ from utils1.config import CONFIGCLASS
17
+
18
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
19
+
20
+
21
def dataset_folder(root: str, cfg: CONFIGCLASS):
    """Build the dataset for `root` according to `cfg.mode`:
    "binary" -> labeled ImageFolder with augmentations,
    "filename" -> dataset that yields only file paths."""
    if cfg.mode == "binary":
        return binary_dataset(root, cfg)
    if cfg.mode == "filename":
        # FileNameDataset's signature is (opt, root): pass cfg first.
        # The original call passed (root, cfg), which handed the config object
        # to ImageFolder as the directory path and failed at runtime.
        return FileNameDataset(cfg, root)
    raise ValueError("cfg.mode needs to be binary or filename.")
27
+
28
+
29
def binary_dataset(root: str, cfg: CONFIGCLASS):
    """ImageFolder dataset with the train/eval pipeline: optional blur/JPEG
    corruption, 448x448 crop, horizontal flip, ToTensor, ImageNet normalize."""
    identity_transform = transforms.Lambda(lambda img: img)

    # resizing is currently disabled; images are only cropped
    rz_func = identity_transform

    if cfg.isTrain:
        crop_func = transforms.RandomCrop((448,448))
    else:
        crop_func = transforms.CenterCrop((448,448)) if cfg.aug_crop else identity_transform

    if cfg.isTrain and cfg.aug_flip:
        flip_func = transforms.RandomHorizontalFlip()
    else:
        flip_func = identity_transform


    return datasets.ImageFolder(
        root,
        transforms.Compose(
            [
                rz_func,
                #change
                transforms.Lambda(lambda img: blur_jpg_augment(img, cfg)),
                crop_func,
                flip_func,
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                if cfg.aug_norm
                else identity_transform,
            ]
        )
    )
61
+
62
+
63
class FileNameDataset(datasets.ImageFolder):
    """ImageFolder variant whose __getitem__ yields only the sample's file path."""

    def name(self):
        return 'FileNameDataset'

    def __init__(self, opt, root):
        # NOTE(review): parameter order is (opt, root), but dataset_folder()
        # calls FileNameDataset(root, cfg) — the arguments appear swapped at
        # the call site; confirm which order is intended.
        self.opt = opt
        super().__init__(root)

    def __getitem__(self, index):
        # Loading sample
        path, target = self.samples[index]
        return path
75
+
76
+
77
def blur_jpg_augment(img: Image.Image, cfg: CONFIGCLASS):
    """Training-time corruption: random Gaussian blur and/or JPEG
    re-compression; a no-op (besides the PIL round-trip) at eval time."""
    img: np.ndarray = np.array(img)
    if cfg.isTrain:
        if random() < cfg.blur_prob:
            sig = sample_continuous(cfg.blur_sig)
            gaussian_blur(img, sig)  # blurs the array in place

        if random() < cfg.jpg_prob:
            method = sample_discrete(cfg.jpg_method)
            qual = sample_discrete(cfg.jpg_qual)
            img = jpeg_from_key(img, qual, method)

    return Image.fromarray(img)
90
+
91
+
92
def sample_continuous(s: list):
    """Return s[0] if single-valued, else a uniform sample from [s[0], s[1])."""
    if len(s) == 1:
        return s[0]
    if len(s) == 2:
        low, high = s
        return low + random() * (high - low)
    raise ValueError("Length of iterable s should be 1 or 2.")
99
+
100
+
101
def sample_discrete(s: list):
    """Return the single element, or a uniformly random choice among several."""
    if len(s) == 1:
        return s[0]
    return choice(s)
103
+
104
+
105
def gaussian_blur(img: np.ndarray, sigma: float):
    """Blur each of the three color channels of `img` in place."""
    for ch in range(3):
        gaussian_filter(img[:, :, ch], output=img[:, :, ch], sigma=sigma)
109
+
110
+
111
def cv2_jpg(img: np.ndarray, compress_val: int) -> np.ndarray:
    """JPEG round-trip (encode + decode) via OpenCV at the given quality."""
    img_cv2 = img[:, :, ::-1]  # RGB -> BGR for OpenCV
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    result, encimg = cv2.imencode(".jpg", img_cv2, encode_param)
    decimg = cv2.imdecode(encimg, 1)
    return decimg[:, :, ::-1]  # back to RGB
117
+
118
+
119
def pil_jpg(img: np.ndarray, compress_val: int):
    """JPEG round-trip via Pillow through an in-memory buffer."""
    out = BytesIO()
    img = Image.fromarray(img)
    img.save(out, format="jpeg", quality=compress_val)
    img = Image.open(out)
    # load from memory before ByteIO closes
    img = np.array(img)
    out.close()
    return img
128
+
129
+
130
+ jpeg_dict = {"cv2": cv2_jpg, "pil": pil_jpg}
131
+
132
+
133
def jpeg_from_key(img: np.ndarray, compress_val: int, key: str) -> np.ndarray:
    """Apply the JPEG round-trip implementation selected by `key` ("cv2" or "pil")."""
    return jpeg_dict[key](img, compress_val)
136
+
137
+
138
# PIL interpolation modes selectable via cfg.rz_interp
rz_dict = {'bilinear': Image.BILINEAR,
           'bicubic': Image.BICUBIC,
           'lanczos': Image.LANCZOS,
           'nearest': Image.NEAREST}
def custom_resize(img: Image.Image, cfg: CONFIGCLASS) -> Image.Image:
    """Resize to cfg.loadSize using a randomly sampled interpolation mode."""
    interp = sample_discrete(cfg.rz_interp)
    return TF.resize(img, cfg.loadSize, interpolation=rz_dict[interp])
145
+
146
+
147
def get_dataset(cfg: CONFIGCLASS):
    """Concatenate one ImageFolder-style dataset per entry in cfg.datasets."""
    parts = [
        dataset_folder(os.path.join(cfg.dataset_root, name), cfg)
        for name in cfg.datasets
    ]
    return torch.utils.data.ConcatDataset(parts)
154
+
155
+
156
def get_bal_sampler(dataset: torch.utils.data.ConcatDataset):
    """Weighted sampler that rebalances classes across the concatenated datasets."""
    targets = []
    for part in dataset.datasets:
        targets.extend(part.targets)

    # weight of each sample = inverse frequency of its class
    class_counts = np.bincount(targets)
    class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
    sample_weights = class_weights[targets]
    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights))
165
+
166
+
167
def create_dataloader(cfg: CONFIGCLASS):
    """Build the DataLoader described by `cfg`, optionally class-balanced."""
    # shuffle only when training without a balancing sampler (and not serial)
    shuffle = bool(cfg.isTrain and not cfg.class_bal and not cfg.serial_batches)
    dataset = get_dataset(cfg)
    sampler = get_bal_sampler(dataset) if cfg.class_bal else None

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=int(cfg.num_workers),
    )
AIGVDet/core/utils1/earlystop.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from utils1.trainer import Trainer
4
+
5
+
6
class EarlyStopping:
    """Early stops the training if the validation score doesn't improve after a given patience."""

    def __init__(self, patience=1, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time the validation score improved.
                            Default: 1
            verbose (bool): If True, prints a message for each validation score improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf is the supported spelling
        self.score_max = -np.inf
        self.delta = delta

    def __call__(self, score: float, trainer: "Trainer"):
        """Record a new validation score: save a checkpoint on improvement,
        otherwise count towards the patience limit and set `early_stop`."""
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(score, trainer)
        elif score < self.best_score - self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(score, trainer)
            self.counter = 0

    def save_checkpoint(self, score: float, trainer: "Trainer"):
        """Saves the model when the validation score improves."""
        if self.verbose:
            print(f"Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...")
        trainer.save_networks("best")
        self.score_max = score
AIGVDet/core/utils1/eval.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from utils1.config import CONFIGCLASS
10
+ from utils1.utils import to_cuda
11
+
12
+
13
+ def get_val_cfg(cfg: CONFIGCLASS, split="val", copy=True):
14
+ if copy:
15
+ from copy import deepcopy
16
+
17
+ val_cfg = deepcopy(cfg)
18
+ else:
19
+ val_cfg = cfg
20
+ val_cfg.dataset_root = os.path.join(val_cfg.dataset_root, split)
21
+ val_cfg.datasets = cfg.datasets_test
22
+ val_cfg.isTrain = False
23
+ # val_cfg.aug_resize = False
24
+ # val_cfg.aug_crop = False
25
+ val_cfg.aug_flip = False
26
+ val_cfg.serial_batches = True
27
+ val_cfg.jpg_method = ["pil"]
28
+ # Currently assumes jpg_prob, blur_prob 0 or 1
29
+ if len(val_cfg.blur_sig) == 2:
30
+ b_sig = val_cfg.blur_sig
31
+ val_cfg.blur_sig = [(b_sig[0] + b_sig[1]) / 2]
32
+ if len(val_cfg.jpg_qual) != 1:
33
+ j_qual = val_cfg.jpg_qual
34
+ val_cfg.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)]
35
+ return val_cfg
36
+
37
def validate(model: nn.Module, cfg: CONFIGCLASS):
    """Evaluate `model` on the dataset described by `cfg` and return
    accuracy/AP metrics, including per-class accuracies."""
    from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

    from utils1.datasets import create_dataloader

    data_loader = create_dataloader(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        y_true, y_pred = [], []
        for data in data_loader:
            # batches may or may not carry a third "meta" element
            img, label, meta = data if len(data) == 3 else (*data, None)
            in_tens = to_cuda(img, device)
            meta = to_cuda(meta, device)
            predict = model(in_tens, meta).sigmoid()
            y_pred.extend(predict.flatten().tolist())
            y_true.extend(label.flatten().tolist())

        y_true, y_pred = np.array(y_true), np.array(y_pred)
        # label convention here: 0 = real, 1 = fake; threshold at 0.5
        r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
        f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
        acc = accuracy_score(y_true, y_pred > 0.5)
        ap = average_precision_score(y_true, y_pred)
        results = {
            "ACC": acc,
            "AP": ap,
            "R_ACC": r_acc,
            "F_ACC": f_acc,
        }
        return results
AIGVDet/core/utils1/trainer.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import init
6
+
7
+ from utils1.config import CONFIGCLASS
8
+ from utils1.utils import get_network
9
+ from utils1.warmup import GradualWarmupScheduler
10
+
11
+
12
class BaseModel(nn.Module):
    """Common checkpoint save/load scaffolding for trainers.

    Subclasses must assign ``self.model`` and ``self.optimizer`` before
    using the save/load helpers.
    """

    def __init__(self, cfg: "CONFIGCLASS"):
        super().__init__()
        self.cfg = cfg
        self.total_steps = 0
        self.isTrain = cfg.isTrain
        self.save_dir = cfg.ckpt_dir
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # Placeholders only: the concrete network/optimizer are created by
        # the subclass (see Trainer).  The previous code assigned
        # ``nn.Module.to(self.device)`` here, which calls the unbound
        # method with a torch.device as ``self`` and crashes at runtime;
        # declare the attribute types instead.
        self.model: nn.Module
        self.optimizer: torch.optim.Optimizer

    def save_networks(self, epoch):
        """Serialize model, optimizer and step counter to
        ``<save_dir>/model_epoch_<epoch>.pth``.

        ``epoch`` may be an int or a tag such as "latest"/"best".
        """
        save_filename = f"model_epoch_{epoch}.pth"
        save_path = os.path.join(self.save_dir, save_filename)

        # serialize model and optimizer to dict
        state_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "total_steps": self.total_steps,
        }
        torch.save(state_dict, save_path)

    def load_networks(self, epoch):
        """Restore model (and optionally optimizer) state from disk.

        ``epoch == 0`` is treated specially and loads the bundled
        ``checkpoints/optical.pth`` instead of an epoch checkpoint.
        """
        load_filename = f"model_epoch_{epoch}.pth"
        load_path = os.path.join(self.save_dir, load_filename)

        if epoch == 0:
            load_path = "checkpoints/optical.pth"
            print("loading optical path")
        else:
            print(f"loading the model from {load_path}")

        state_dict = torch.load(load_path, map_location=self.device)
        if hasattr(state_dict, "_metadata"):
            del state_dict._metadata

        self.model.load_state_dict(state_dict["model"])
        self.total_steps = state_dict["total_steps"]

        if self.isTrain and not self.cfg.new_optim:
            self.optimizer.load_state_dict(state_dict["optimizer"])
            # Move optimizer state tensors onto the active device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
            # Reset the learning rate to the configured value.
            for g in self.optimizer.param_groups:
                g["lr"] = self.cfg.lr

    def eval(self):
        """Put the wrapped network into evaluation mode."""
        self.model.eval()

    def test(self):
        """Forward pass without gradient tracking (subclass provides forward)."""
        with torch.no_grad():
            self.forward()
79
+
80
+
81
def init_weights(net: nn.Module, init_type="normal", gain=0.02):
    """Initialize Conv/Linear weights with the chosen scheme.

    BatchNorm2d layers get N(1, gain) weights; biases are zeroed in both
    cases.  Raises NotImplementedError for an unknown ``init_type``.
    """
    initializers = {
        "normal": lambda w: init.normal_(w, 0.0, gain),
        "xavier": lambda w: init.xavier_normal_(w, gain=gain),
        "kaiming": lambda w: init.kaiming_normal_(w, a=0, mode="fan_in"),
        "orthogonal": lambda w: init.orthogonal_(w, gain=gain),
    }

    def init_func(m: nn.Module):
        cls_name = m.__class__.__name__
        weighted = hasattr(m, "weight") and ("Conv" in cls_name or "Linear" in cls_name)
        if weighted:
            if init_type not in initializers:
                raise NotImplementedError(f"initialization method [{init_type}] is not implemented")
            initializers[init_type](m.weight.data)
            if getattr(m, "bias", None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in cls_name:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print(f"initialize network with {init_type}")
    net.apply(init_func)
103
+
104
+
105
class Trainer(BaseModel):
    """Binary real/fake classifier trainer built on BaseModel.

    Creates the backbone via ``get_network``, a BCEWithLogits loss, an
    optimizer (adam/sgd) and, optionally, a warmup + cosine-annealing
    learning-rate schedule.
    """

    def name(self):
        return "Trainer"

    def __init__(self, cfg: CONFIGCLASS):
        super().__init__(cfg)
        self.arch = cfg.arch
        # Backbone; pretrained weights are used only when training from scratch.
        self.model = get_network(self.arch, cfg.isTrain, cfg.continue_train, cfg.init_gain, cfg.pretrained)

        self.loss_fn = nn.BCEWithLogitsLoss()
        # initialize optimizers
        if cfg.optim == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
        elif cfg.optim == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4)
        else:
            raise ValueError("optim should be [adam, sgd]")
        if cfg.warmup:
            # Linear warmup for warmup_epoch epochs, then cosine annealing
            # over the remaining epochs down to eta_min.
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6
            )
            self.scheduler = GradualWarmupScheduler(
                self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine
            )
            self.scheduler.step()
        if cfg.continue_train:
            # Resume from ckpt_dir/model_epoch_<epoch>.pth
            # (epoch 0 loads checkpoints/optical.pth instead -- see BaseModel).
            self.load_networks(cfg.epoch)
        self.model.to(self.device)



    def adjust_learning_rate(self, min_lr=1e-6):
        # Divide the LR by 10; report False once any group drops below min_lr.
        for param_group in self.optimizer.param_groups:
            param_group["lr"] /= 10.0
            if param_group["lr"] < min_lr:
                return False
        return True

    def set_input(self, input):
        # Batches are (img, label) or (img, label, meta); tensors move to self.device.
        img, label, meta = input if len(input) == 3 else (input[0], input[1], {})
        self.input = img.to(self.device)
        self.label = label.to(self.device).float()
        for k in meta.keys():
            if isinstance(meta[k], torch.Tensor):
                meta[k] = meta[k].to(self.device)
        self.meta = meta

    def forward(self):
        self.output = self.model(self.input, self.meta)

    def get_loss(self):
        # Single-logit output vs float 0/1 labels.
        return self.loss_fn(self.output.squeeze(1), self.label)

    def optimize_parameters(self):
        # One optimization step on the batch previously set via set_input().
        self.forward()
        self.loss = self.loss_fn(self.output.squeeze(1), self.label)
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
AIGVDet/core/utils1/utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import time
5
+ import warnings
6
+ from importlib import import_module
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from PIL import Image
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
14
+
15
+
16
def str2bool(v: str, strict=True) -> bool:
    """Parse a truthy/falsy string into a bool.

    Accepts common spellings ("true"/"yes"/"on"/"t"/"y"/"1" and their
    negative counterparts, case-insensitive).  Unrecognized input raises
    argparse.ArgumentTypeError when ``strict``, otherwise returns True.
    """
    if isinstance(v, bool):
        return v
    elif isinstance(v, str):
        # Bug fix: the tuple previously read ("on" "t") -- a missing comma
        # made Python concatenate the literals into "ont", so both "on"
        # and "t" were rejected.
        if v.lower() in ("true", "yes", "on", "t", "y", "1"):
            return True
        elif v.lower() in ("false", "no", "off", "f", "n", "0"):
            return False
    if strict:
        raise argparse.ArgumentTypeError("Unsupported value encountered.")
    else:
        return True
28
+
29
+
30
def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None):
    """Recursively move tensors inside ``data`` onto ``device``.

    Tuples/lists/sets come back as lists; dict values are moved in place
    (keys listed in ``exclude_keys`` are skipped); anything else is
    returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, (tuple, list, set)):
        return [to_cuda(item, device) for item in data]
    if isinstance(data, dict):
        skip = exclude_keys if exclude_keys is not None else []
        for key in data:
            if key not in skip:
                data[key] = to_cuda(data[key], device)
        return data
    # Non-tensor leaves (ints, strings, None, ...) pass through untouched.
    return data
45
+
46
+
47
class HiddenPrints:
    """Context manager that silences ``print`` by routing stdout to os.devnull."""

    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        devnull_handle = sys.stdout
        sys.stdout = self._original_stdout
        devnull_handle.close()
55
+
56
+
57
class Logger(object):
    """Tee-style logger mirroring messages to the terminal and a file."""

    def __init__(self):
        self.terminal = sys.stdout  # keep a handle on the real stdout
        self.file = None

    def open(self, file, mode=None):
        """Open the log file ("w" by default)."""
        self.file = open(file, mode if mode is not None else "w")

    def write(self, message, is_terminal=1, is_file=1):
        """Write ``message`` to terminal and/or file.

        Carriage-return progress lines are never written to the file.
        """
        if "\r" in message:
            is_file = 0
        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()
        if is_file == 1:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        """No-op so Logger can stand in for a stream object (py3 needs it)."""
        pass
82
+
83
+
84
def get_network(arch: str, isTrain=False, continue_train=False, init_gain=0.02, pretrained=True):
    """Build the classifier backbone for ``arch`` (ResNet variants only).

    Fresh training loads pretrained weights and replaces the FC head with
    a single-logit output; resuming or inference builds a 1-class network
    whose weights are expected to come from a checkpoint.
    """
    if "resnet" in arch:
        from ...networks.resnet import ResNet

        # Resolve the constructor named by ``arch`` (e.g. "resnet50") from
        # networks.resnet via a package-relative import.
        # NOTE(review): three leading dots resolve relative to this
        # module's package -- verify it matches the on-disk layout
        # (core/utils1 vs AIGVDet root) before relying on it.
        resnet = getattr(import_module("...networks.resnet", package=__package__), arch)
        if isTrain:
            if continue_train:
                # Resuming: a checkpoint supplies the weights.
                model: ResNet = resnet(num_classes=1)
            else:
                # Fresh run: pretrained backbone + new single-logit head.
                model: ResNet = resnet(pretrained=pretrained)
                model.fc = nn.Linear(2048, 1)
                nn.init.normal_(model.fc.weight.data, 0.0, init_gain)
        else:
            model: ResNet = resnet(num_classes=1)
        return model
    else:
        raise ValueError(f"Unsupported arch: {arch}")
101
+
102
+
103
def pad_img_to_square(img: np.ndarray):
    """Zero-pad an HxWxC image on the bottom/right so it becomes square."""
    height, width = img.shape[:2]
    if height == width:
        return img
    side = max(height, width)
    padded = np.pad(img, ((0, side - height), (0, side - width), (0, 0)), mode="constant")
    assert padded.shape[0] == padded.shape[1] == side
    return padded
AIGVDet/core/utils1/utils1/config.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ from abc import ABC
5
+ from typing import Type
6
+
7
+
8
class DefaultConfigs(ABC):
    """Central experiment configuration.

    Values live as plain class attributes; the module-level CLI code in
    this file mutates an instance in place.  Setting ``exp_name`` also
    derives and creates the experiment/checkpoint directories (see the
    property setter).
    """

    ####### base setting ######
    gpus = [0]  # CUDA device ids, exported to CUDA_VISIBLE_DEVICES below
    seed = 3407
    arch = "resnet50"  # backbone name resolved by get_network
    datasets = ["zhaolian_train"]  # training set folders under dataset_root
    datasets_test = ["adm_res_abs_ddim20s"]  # validation/test set folders
    mode = "binary"  # "binary" (ImageFolder) or "filename"
    class_bal = False  # True -> class-balancing WeightedRandomSampler
    batch_size = 64
    loadSize = 256  # resize target used by custom_resize
    cropSize = 224
    epoch = "latest"  # checkpoint tag to resume from
    num_workers = 20
    serial_batches = False  # True -> deterministic (unshuffled) order
    isTrain = True

    # data augmentation
    rz_interp = ["bilinear"]  # candidate resize interpolations
    # blur_prob = 0.0
    blur_prob = 0.1  # per-sample probability of Gaussian blur
    blur_sig = [0.5]  # blur sigma: single value or [lo, hi] range
    # jpg_prob = 0.0
    jpg_prob = 0.1  # per-sample probability of JPEG re-compression
    jpg_method = ["cv2"]  # JPEG backends to sample from ("cv2"/"pil")
    jpg_qual = [75]  # JPEG quality: single value or range
    gray_prob = 0.0
    aug_resize = True
    aug_crop = True
    aug_flip = True
    aug_norm = True  # ImageNet mean/std normalization

    ####### train setting ######
    warmup = False  # enable warmup + cosine LR schedule
    # warmup = True
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5  # early-stopping patience
    optim = "adam"  # "adam" or "sgd"
    new_optim = False  # True -> do not restore optimizer state on resume
    loss_freq = 400  # steps between loss logs
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 400
    beta1 = 0.9  # Adam beta1
    lr = 0.0001
    init_type = "normal"  # weight init scheme (see init_weights)
    init_gain = 0.02
    pretrained = True  # start from pretrained backbone weights

    # paths information
    root_dir1 = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    root_dir = os.path.dirname(root_dir1)  # two levels above this file
    dataset_root = os.path.join(root_dir, "data")
    exp_root = os.path.join(root_dir, "data", "exp")
    _exp_name = ""
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""

    @property
    def exp_name(self):
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        # Derive per-experiment paths and create the directories eagerly,
        # so later code can write checkpoints/logs without checks.
        self._exp_name = value
        self.exp_dir: str = os.path.join(self.exp_root, self.exp_name)
        self.ckpt_dir: str = os.path.join(self.exp_dir, "ckpt")
        self.logs_path: str = os.path.join(self.exp_dir, "logs.txt")

        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def to_dict(self):
        """Export all public, non-callable config fields as a dict."""
        dic = {}
        for fieldkey in dir(self):
            fieldvalue = getattr(self, fieldkey)
            if not fieldkey.startswith("__") and not callable(fieldvalue) and not fieldkey.startswith("_"):
                dic[fieldkey] = fieldvalue
        return dic
93
+
94
+
95
def args_list2dict(arg_list: list):
    """Turn a flat [key, value, key, value, ...] list into a dict."""
    assert len(arg_list) % 2 == 0, f"Override list has odd length: {arg_list}; it must be a list of pairs"
    pairs_iter = iter(arg_list)
    return dict(zip(pairs_iter, pairs_iter))
98
+
99
+
100
def str2bool(v: str) -> bool:
    """Lenient bool parser.

    Known spellings map to True/False; anything else falls back to
    Python truthiness (so any non-empty unknown string is True).
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("true", "yes", "on", "y", "t", "1"):
        return True
    if lowered in ("false", "no", "off", "n", "f", "0"):
        return False
    return bool(v)
109
+
110
+
111
def str2list(v: str, element_type=None) -> list:
    """Parse "[a, b, c]" (or "a,b,c") into a list of stripped strings,
    optionally coercing each element with ``element_type``.

    List-like inputs skip the parsing step but still get coerced.
    """
    if not isinstance(v, (list, tuple, set)):
        stripped = v.lstrip("[").rstrip("]")
        v = [part.strip() for part in stripped.split(",")]
    if element_type is not None:
        v = [element_type(item) for item in v]
    return v
119
+
120
+
121
# Type alias used throughout the project for config annotations.
CONFIGCLASS = Type[DefaultConfigs]

# CLI: GPU ids, experiment name, checkpoint filename, plus free-form
# "key value" overrides collected positionally in ``opts``.
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default=[0], type=int, nargs="+")
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()

# Prefer a per-experiment config.py when the experiment directory has one;
# fall back to the defaults above otherwise.
if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
    sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
    from config import cfg

    cfg: CONFIGCLASS
else:
    cfg = DefaultConfigs()

# Apply "key value" overrides, coercing each value to the attribute's
# existing type (bool and list-likes use the dedicated parsers above).
if args.opts:
    opts = args_list2dict(args.opts)
    for k, v in opts.items():
        if not hasattr(cfg, k):
            raise ValueError(f"Unrecognized option: {k}")
        original_type = type(getattr(cfg, k))
        if original_type == bool:
            setattr(cfg, k, str2bool(v))
        elif original_type in (list, tuple, set):
            setattr(cfg, k, str2list(v, type(getattr(cfg, k)[0])))
        else:
            setattr(cfg, k, original_type(v))

# Restrict visible GPUs before any CUDA initialization happens.
cfg.gpus: list = args.gpus
os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join([str(gpu) for gpu in cfg.gpus])
cfg.exp_name = args.exp_name  # property setter: creates exp/ckpt dirs
cfg.ckpt_path: str = os.path.join(cfg.ckpt_dir, args.ckpt)

if isinstance(cfg.datasets, str):
    cfg.datasets = cfg.datasets.split(",")
AIGVDet/core/utils1/utils1/datasets.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+ from random import choice, random
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision.datasets as datasets
10
+ import torchvision.transforms as transforms
11
+ import torchvision.transforms.functional as TF
12
+ from PIL import Image, ImageFile
13
+ from scipy.ndimage import gaussian_filter
14
+ from torch.utils.data.sampler import WeightedRandomSampler
15
+
16
+ from .config import CONFIGCLASS
17
+
18
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
19
+
20
+
21
def dataset_folder(root: str, cfg: "CONFIGCLASS"):
    """Build a dataset rooted at ``root`` according to ``cfg.mode``."""
    mode = cfg.mode
    if mode == "binary":
        return binary_dataset(root, cfg)
    if mode == "filename":
        return FileNameDataset(root, cfg)
    raise ValueError("cfg.mode needs to be binary or filename.")
27
+
28
+
29
def binary_dataset(root: str, cfg: CONFIGCLASS):
    """ImageFolder dataset with the train/eval augmentation pipeline.

    Pipeline: (optional) blur/JPEG corruption -> 448x448 crop (random when
    training, center when cfg.aug_crop at eval) -> optional horizontal
    flip (train only) -> ToTensor -> optional ImageNet normalization.

    NOTE(review): rz_func is the identity, so cfg.loadSize/aug_resize are
    unused here, and the crop size is hard-coded to 448 rather than
    cfg.cropSize -- confirm this is intentional.
    """
    identity_transform = transforms.Lambda(lambda img: img)

    rz_func = identity_transform

    if cfg.isTrain:
        crop_func = transforms.RandomCrop((448,448))
    else:
        crop_func = transforms.CenterCrop((448,448)) if cfg.aug_crop else identity_transform

    if cfg.isTrain and cfg.aug_flip:
        flip_func = transforms.RandomHorizontalFlip()
    else:
        flip_func = identity_transform


    return datasets.ImageFolder(
        root,
        transforms.Compose(
            [
                rz_func,
                #change
                transforms.Lambda(lambda img: blur_jpg_augment(img, cfg)),
                crop_func,
                flip_func,
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                if cfg.aug_norm
                else identity_transform,
            ]
        )
    )
61
+
62
+
63
class FileNameDataset(datasets.ImageFolder):
    """ImageFolder variant whose __getitem__ returns only the file path.

    NOTE(review): dataset_folder() instantiates this as
    ``FileNameDataset(root, cfg)`` while the constructor signature here is
    ``(opt, root)`` -- the arguments appear swapped at one of the two
    sites; confirm and fix.
    """

    def name(self):
        return 'FileNameDataset'

    def __init__(self, opt, root):
        self.opt = opt
        super().__init__(root)

    def __getitem__(self, index):
        # Loading sample
        # Return only the path; ImageFolder's class label is dropped.
        path, target = self.samples[index]
        return path
75
+
76
+
77
def blur_jpg_augment(img: Image.Image, cfg: CONFIGCLASS):
    """Training-time corruption: random Gaussian blur and/or JPEG round-trip.

    Each corruption fires independently with probability cfg.blur_prob /
    cfg.jpg_prob and only while training; evaluation images pass through
    unchanged (modulo the PIL -> numpy -> PIL round trip).
    """
    img: np.ndarray = np.array(img)
    if cfg.isTrain:
        if random() < cfg.blur_prob:
            sig = sample_continuous(cfg.blur_sig)
            # Blur mutates the numpy copy in place.
            gaussian_blur(img, sig)

        if random() < cfg.jpg_prob:
            method = sample_discrete(cfg.jpg_method)
            qual = sample_discrete(cfg.jpg_qual)
            img = jpeg_from_key(img, qual, method)

    return Image.fromarray(img)
90
+
91
+
92
def sample_continuous(s: list):
    """Sample a value from ``s``: the single element, or uniformly from a
    [lo, hi] pair.  Raises ValueError for any other length."""
    if len(s) == 1:
        return s[0]
    if len(s) == 2:
        lo, hi = s
        return lo + random() * (hi - lo)
    raise ValueError("Length of iterable s should be 1 or 2.")
99
+
100
+
101
def sample_discrete(s: list):
    """Return the only element of ``s``, or a uniformly random choice."""
    if len(s) == 1:
        return s[0]
    return choice(s)
103
+
104
+
105
def gaussian_blur(img: np.ndarray, sigma: float):
    """In-place per-channel Gaussian blur of an HxWx3 image."""
    for channel_idx in range(3):
        channel = img[:, :, channel_idx]
        gaussian_filter(channel, output=channel, sigma=sigma)
109
+
110
+
111
def cv2_jpg(img: np.ndarray, compress_val: int) -> np.ndarray:
    """Round-trip an RGB image through OpenCV JPEG at the given quality."""
    img_cv2 = img[:, :, ::-1]  # RGB -> BGR for OpenCV
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    result, encimg = cv2.imencode(".jpg", img_cv2, encode_param)
    decimg = cv2.imdecode(encimg, 1)
    return decimg[:, :, ::-1]  # BGR -> RGB
117
+
118
+
119
def pil_jpg(img: np.ndarray, compress_val: int):
    """Round-trip an RGB image through PIL JPEG at the given quality."""
    out = BytesIO()
    img = Image.fromarray(img)
    img.save(out, format="jpeg", quality=compress_val)
    img = Image.open(out)
    # load from memory before ByteIO closes
    img = np.array(img)
    out.close()
    return img
128
+
129
+
130
# Dispatch table: backend key -> JPEG round-trip implementation.
jpeg_dict = {"cv2": cv2_jpg, "pil": pil_jpg}


def jpeg_from_key(img: np.ndarray, compress_val: int, key: str) -> np.ndarray:
    """Apply JPEG compression via the backend named by ``key`` ("cv2"/"pil")."""
    method = jpeg_dict[key]
    return method(img, compress_val)
136
+
137
+
138
# PIL resampling filters selectable via cfg.rz_interp.
rz_dict = {'bilinear': Image.BILINEAR,
           'bicubic': Image.BICUBIC,
           'lanczos': Image.LANCZOS,
           'nearest': Image.NEAREST}
def custom_resize(img: Image.Image, cfg: CONFIGCLASS) -> Image.Image:
    """Resize to cfg.loadSize using an interpolation sampled from cfg.rz_interp."""
    interp = sample_discrete(cfg.rz_interp)
    return TF.resize(img, cfg.loadSize, interpolation=rz_dict[interp])
145
+
146
+
147
def get_dataset(cfg: CONFIGCLASS):
    """Concatenate one ImageFolder-style dataset per entry in cfg.datasets."""
    subsets = [
        dataset_folder(os.path.join(cfg.dataset_root, name), cfg)
        for name in cfg.datasets
    ]
    return torch.utils.data.ConcatDataset(subsets)
154
+
155
+
156
def get_bal_sampler(dataset: torch.utils.data.ConcatDataset):
    """Weighted sampler that rebalances classes by inverse frequency."""
    targets = []
    for subset in dataset.datasets:
        targets.extend(subset.targets)

    class_counts = np.bincount(targets)
    class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
    sample_weights = class_weights[targets]
    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights))
165
+
166
+
167
def create_dataloader(cfg: CONFIGCLASS):
    """DataLoader over the configured datasets.

    Shuffling applies only while training without class balancing; with
    class balancing a WeightedRandomSampler drives sample order instead.
    """
    use_shuffle = bool(cfg.isTrain and not cfg.class_bal and not cfg.serial_batches)
    dataset = get_dataset(cfg)
    sampler = get_bal_sampler(dataset) if cfg.class_bal else None

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=use_shuffle,
        sampler=sampler,
        num_workers=int(cfg.num_workers),
    )
AIGVDet/core/utils1/utils1/earlystop.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from .trainer import Trainer
4
+
5
+
6
class EarlyStopping:
    """Early stops the training if the validation score doesn't improve
    after a given patience.

    Tracks the best score seen so far; each call without an improvement of
    more than ``delta`` increments a counter, and once the counter reaches
    ``patience`` the ``early_stop`` flag is raised.  Improvements save a
    "best" checkpoint via ``trainer.save_networks``.
    """

    def __init__(self, patience=1, verbose=False, delta=0):
        """
        Args:
            patience (int): How many calls without improvement to tolerate.
            verbose (bool): If True, prints a message on each improvement.
            delta (float): Minimum score change that counts as improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Fix: np.Inf was removed in NumPy 2.0; np.inf is canonical.
        self.score_max = -np.inf
        self.delta = delta

    def __call__(self, score: float, trainer: "Trainer"):
        if self.best_score is None:
            # First observation: record and checkpoint it.
            self.best_score = score
            self.save_checkpoint(score, trainer)
        elif score < self.best_score - self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(score, trainer)
            self.counter = 0

    def save_checkpoint(self, score: float, trainer: "Trainer"):
        """Save a "best" checkpoint and remember the new best score."""
        if self.verbose:
            print(f"Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...")
        trainer.save_networks("best")
        self.score_max = score
AIGVDet/core/utils1/utils1/eval.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .config import CONFIGCLASS
10
+ from .utils import to_cuda
11
+
12
+
13
def get_val_cfg(cfg: "CONFIGCLASS", split="val", copy=True):
    """Derive a validation configuration from a training configuration.

    Points the dataset root at the given split, switches to the test
    datasets, disables training-only augmentation, and collapses the
    blur/JPEG parameter ranges to single midpoint values so that
    validation is deterministic.
    """
    if copy:
        from copy import deepcopy

        val_cfg = deepcopy(cfg)
    else:
        val_cfg = cfg

    val_cfg.dataset_root = os.path.join(val_cfg.dataset_root, split)
    val_cfg.datasets = cfg.datasets_test
    val_cfg.isTrain = False
    # val_cfg.aug_resize = False
    # val_cfg.aug_crop = False
    val_cfg.aug_flip = False
    val_cfg.serial_batches = True  # fixed sample order for evaluation
    val_cfg.jpg_method = ["pil"]

    # Currently assumes jpg_prob / blur_prob are 0 or 1: reduce each
    # augmentation range to a single representative (midpoint) value.
    if len(val_cfg.blur_sig) == 2:
        low, high = val_cfg.blur_sig
        val_cfg.blur_sig = [(low + high) / 2]
    if len(val_cfg.jpg_qual) != 1:
        quals = val_cfg.jpg_qual
        val_cfg.jpg_qual = [int((quals[0] + quals[-1]) / 2)]
    return val_cfg
36
+
37
def validate(model: nn.Module, cfg: "CONFIGCLASS"):
    """Run the detector over the validation set and compute metrics.

    Returns a dict with overall accuracy (ACC), average precision (AP),
    and per-class accuracies for real (R_ACC, label 0) and fake
    (F_ACC, label 1) samples, thresholding sigmoid outputs at 0.5.

    NOTE(review): the model is not switched to eval() here; callers are
    expected to have done so already -- confirm.
    """
    # Fix: drop the unused roc_auc_score import.
    from sklearn.metrics import accuracy_score, average_precision_score

    from .datasets import create_dataloader

    data_loader = create_dataloader(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        y_true, y_pred = [], []
        for data in data_loader:
            # Batches may come as (img, label) or (img, label, meta).
            img, label, meta = data if len(data) == 3 else (*data, None)
            in_tens = to_cuda(img, device)
            meta = to_cuda(meta, device)
            predict = model(in_tens, meta).sigmoid()
            y_pred.extend(predict.flatten().tolist())
            y_true.extend(label.flatten().tolist())

    y_true, y_pred = np.array(y_true), np.array(y_pred)
    r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
    f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
    acc = accuracy_score(y_true, y_pred > 0.5)
    ap = average_precision_score(y_true, y_pred)
    results = {
        "ACC": acc,
        "AP": ap,
        "R_ACC": r_acc,
        "F_ACC": f_acc,
    }
    return results
AIGVDet/core/utils1/utils1/trainer.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import init
6
+
7
+ from .config import CONFIGCLASS
8
+ from .utils import get_network
9
+ from .warmup import GradualWarmupScheduler
10
+
11
+
12
class BaseModel(nn.Module):
    """Common checkpoint save/load scaffolding for trainers.

    Subclasses must assign ``self.model`` and ``self.optimizer`` before
    using the save/load helpers.
    """

    def __init__(self, cfg: "CONFIGCLASS"):
        super().__init__()
        self.cfg = cfg
        self.total_steps = 0
        self.isTrain = cfg.isTrain
        self.save_dir = cfg.ckpt_dir
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # Placeholders only: the concrete network/optimizer are created by
        # the subclass (see Trainer).  The previous code assigned
        # ``nn.Module.to(self.device)`` (calling the unbound method with a
        # torch.device as ``self``) and then eagerly ran
        # ``load_state_dict(torch.load('./checkpoints/optical.pth'))`` on
        # that broken placeholder -- both crash at runtime and are
        # removed; Trainer performs the checkpoint load itself.
        self.model: nn.Module
        self.optimizer: torch.optim.Optimizer

    def save_networks(self, epoch):
        """Serialize model, optimizer and step counter to
        ``<save_dir>/model_epoch_<epoch>.pth``.

        ``epoch`` may be an int or a tag such as "latest"/"best".
        """
        save_filename = f"model_epoch_{epoch}.pth"
        save_path = os.path.join(self.save_dir, save_filename)

        # serialize model and optimizer to dict
        state_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "total_steps": self.total_steps,
        }
        torch.save(state_dict, save_path)

    def load_networks(self, epoch):
        """Restore model (and optionally optimizer) state from disk.

        ``epoch == 0`` is treated specially and loads the bundled
        ``checkpoints/optical.pth`` instead of an epoch checkpoint.
        """
        load_filename = f"model_epoch_{epoch}.pth"
        load_path = os.path.join(self.save_dir, load_filename)

        if epoch == 0:
            load_path = "checkpoints/optical.pth"
            print("loading optical path")
        else:
            print(f"loading the model from {load_path}")

        state_dict = torch.load(load_path, map_location=self.device)
        if hasattr(state_dict, "_metadata"):
            del state_dict._metadata

        self.model.load_state_dict(state_dict["model"])
        self.total_steps = state_dict["total_steps"]

        if self.isTrain and not self.cfg.new_optim:
            self.optimizer.load_state_dict(state_dict["optimizer"])
            # Move optimizer state tensors onto the active device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
            # Reset the learning rate to the configured value.
            for g in self.optimizer.param_groups:
                g["lr"] = self.cfg.lr

    def eval(self):
        """Put the wrapped network into evaluation mode."""
        self.model.eval()

    def test(self):
        """Forward pass without gradient tracking (subclass provides forward)."""
        with torch.no_grad():
            self.forward()
+ self.forward()
79
+
80
+
81
def init_weights(net: nn.Module, init_type="normal", gain=0.02):
    """Initialize Conv/Linear weights with the chosen scheme.

    BatchNorm2d layers get N(1, gain) weights; biases are zeroed in both
    cases.  Raises NotImplementedError for an unknown ``init_type``.
    """
    initializers = {
        "normal": lambda w: init.normal_(w, 0.0, gain),
        "xavier": lambda w: init.xavier_normal_(w, gain=gain),
        "kaiming": lambda w: init.kaiming_normal_(w, a=0, mode="fan_in"),
        "orthogonal": lambda w: init.orthogonal_(w, gain=gain),
    }

    def init_func(m: nn.Module):
        cls_name = m.__class__.__name__
        weighted = hasattr(m, "weight") and ("Conv" in cls_name or "Linear" in cls_name)
        if weighted:
            if init_type not in initializers:
                raise NotImplementedError(f"initialization method [{init_type}] is not implemented")
            initializers[init_type](m.weight.data)
            if getattr(m, "bias", None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in cls_name:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print(f"initialize network with {init_type}")
    net.apply(init_func)
103
+
104
+
105
class Trainer(BaseModel):
    """Binary real/fake classifier trainer built on BaseModel.

    Creates the backbone via ``get_network``, a BCEWithLogits loss, an
    optimizer (adam/sgd), an optional warmup + cosine LR schedule, then
    force-loads ``checkpoints/optical.pth`` into the model.
    """

    def name(self):
        return "Trainer"

    def __init__(self, cfg: CONFIGCLASS):
        super().__init__(cfg)
        self.arch = cfg.arch
        # Backbone; pretrained weights are used only when training from scratch.
        self.model = get_network(self.arch, cfg.isTrain, cfg.continue_train, cfg.init_gain, cfg.pretrained)

        self.loss_fn = nn.BCEWithLogitsLoss()
        # initialize optimizers
        if cfg.optim == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
        elif cfg.optim == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4)
        else:
            raise ValueError("optim should be [adam, sgd]")
        if cfg.warmup:
            # Linear warmup for warmup_epoch epochs, then cosine annealing.
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6
            )
            self.scheduler = GradualWarmupScheduler(
                self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine
            )
            self.scheduler.step()
        if cfg.continue_train:
            self.load_networks(cfg.epoch)
        self.model.to(self.device)

        # self.model.load_state_dict(torch.load('checkpoints/optical.pth'))
        # NOTE(review): this checkpoint is force-loaded on EVERY
        # construction, overwriting the pretrained/resumed weights above,
        # and the relative path only resolves when run from the project
        # root -- confirm this is intentional.
        load_path='checkpoints/optical.pth'
        state_dict = torch.load(load_path, map_location=self.device)


        self.model.load_state_dict(state_dict["model"])


    def adjust_learning_rate(self, min_lr=1e-6):
        # Divide the LR by 10; report False once any group drops below min_lr.
        for param_group in self.optimizer.param_groups:
            param_group["lr"] /= 10.0
            if param_group["lr"] < min_lr:
                return False
        return True

    def set_input(self, input):
        # Batches are (img, label) or (img, label, meta); tensors move to self.device.
        img, label, meta = input if len(input) == 3 else (input[0], input[1], {})
        self.input = img.to(self.device)
        self.label = label.to(self.device).float()
        for k in meta.keys():
            if isinstance(meta[k], torch.Tensor):
                meta[k] = meta[k].to(self.device)
        self.meta = meta

    def forward(self):
        self.output = self.model(self.input, self.meta)

    def get_loss(self):
        # Single-logit output vs float 0/1 labels.
        return self.loss_fn(self.output.squeeze(1), self.label)

    def optimize_parameters(self):
        # One optimization step on the batch previously set via set_input().
        self.forward()
        self.loss = self.loss_fn(self.output.squeeze(1), self.label)
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
AIGVDet/core/utils1/utils1/utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import time
5
+ import warnings
6
+ from importlib import import_module
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from PIL import Image
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
14
+
15
+
16
def str2bool(v: str, strict=True) -> bool:
    """Parse a truthy/falsy string into a bool.

    Accepts common spellings ("true"/"yes"/"on"/"t"/"y"/"1" and their
    negative counterparts, case-insensitive).  Unrecognized input raises
    argparse.ArgumentTypeError when ``strict``, otherwise returns True.
    """
    if isinstance(v, bool):
        return v
    elif isinstance(v, str):
        # Bug fix: the tuple previously read ("on" "t") -- a missing comma
        # made Python concatenate the literals into "ont", so both "on"
        # and "t" were rejected.
        if v.lower() in ("true", "yes", "on", "t", "y", "1"):
            return True
        elif v.lower() in ("false", "no", "off", "f", "n", "0"):
            return False
    if strict:
        raise argparse.ArgumentTypeError("Unsupported value encountered.")
    else:
        return True
28
+
29
+
30
def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None):
    """Recursively move tensors inside ``data`` onto ``device``.

    Tuples/lists/sets come back as lists; dict values are moved in place
    (keys listed in ``exclude_keys`` are skipped); anything else is
    returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, (tuple, list, set)):
        return [to_cuda(item, device) for item in data]
    if isinstance(data, dict):
        skip = exclude_keys if exclude_keys is not None else []
        for key in data:
            if key not in skip:
                data[key] = to_cuda(data[key], device)
        return data
    # Non-tensor leaves (ints, strings, None, ...) pass through untouched.
    return data
45
+
46
+
47
class HiddenPrints:
    """Context manager that silences ``print`` by routing stdout to os.devnull."""

    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        devnull_handle = sys.stdout
        sys.stdout = self._original_stdout
        devnull_handle.close()
55
+
56
+
57
class Logger(object):
    """Tee-style logger mirroring messages to the terminal and a file."""

    def __init__(self):
        self.terminal = sys.stdout  # keep a handle on the real stdout
        self.file = None

    def open(self, file, mode=None):
        """Open the log file ("w" by default)."""
        self.file = open(file, mode if mode is not None else "w")

    def write(self, message, is_terminal=1, is_file=1):
        """Write ``message`` to terminal and/or file.

        Carriage-return progress lines are never written to the file.
        """
        if "\r" in message:
            is_file = 0
        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()
        if is_file == 1:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        """No-op so Logger can stand in for a stream object (py3 needs it)."""
        pass
82
+
83
+
84
def get_network(arch: str, isTrain=False, continue_train=False, init_gain=0.02, pretrained=True):
    """Build the classifier backbone for ``arch`` (ResNet variants only).

    Fresh training loads pretrained weights and replaces the FC head with
    a single-logit output; resuming or inference builds a 1-class network
    whose weights are expected to come from a checkpoint.
    """
    if "resnet" in arch:
        from networks.resnet import ResNet

        # Resolve the constructor named by ``arch`` (e.g. "resnet50").
        # NOTE(review): this absolute import requires the project root on
        # sys.path -- confirm against how the package is launched.
        resnet = getattr(import_module("networks.resnet"), arch)
        if isTrain:
            if continue_train:
                # Resuming: a checkpoint supplies the weights.
                model: ResNet = resnet(num_classes=1)
            else:
                # Fresh run: pretrained backbone + new single-logit head.
                model: ResNet = resnet(pretrained=pretrained)
                model.fc = nn.Linear(2048, 1)
                nn.init.normal_(model.fc.weight.data, 0.0, init_gain)
        else:
            model: ResNet = resnet(num_classes=1)
        return model
    else:
        raise ValueError(f"Unsupported arch: {arch}")
101
+
102
+
103
def pad_img_to_square(img: np.ndarray):
    """Zero-pad an HxWxC image on the bottom/right so it becomes square."""
    height, width = img.shape[:2]
    if height == width:
        return img
    side = max(height, width)
    padded = np.pad(img, ((0, side - height), (0, side - width), (0, 0)), mode="constant")
    assert padded.shape[0] == padded.shape[1] == side
    return padded
AIGVDet/core/utils1/utils1/warmup.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
2
+
3
+
4
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.0:
            raise ValueError("multiplier should be greater thant or equal to 1.")
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Must be set before super().__init__: the base class calls step(),
        # which in turn calls get_lr() below.
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        # Past the warm-up window: delegate to the wrapped scheduler, if any.
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # One-time rebase so the follow-up scheduler continues
                    # from the fully warmed-up learning rate.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear ramp from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear ramp from base_lr up to base_lr * multiplier.
            return [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        # ReduceLROnPlateau has no get_lr(), so the warm-up is driven manually
        # by writing lr into the optimizer's param groups.
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = (
            epoch if epoch != 0 else 1
        )  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group["lr"] = lr
        else:
            # NOTE(review): `epoch` was defaulted above when None, so the first
            # branch here appears unreachable in practice — confirm intended.
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                # Warm-up done: forward stepping to the follow-up scheduler,
                # shifting the epoch into its own frame of reference.
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super().step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
AIGVDet/core/utils1/warmup.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
2
+
3
+
4
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.0:
            raise ValueError("multiplier should be greater thant or equal to 1.")
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Must be set before super().__init__: the base class calls step(),
        # which in turn calls get_lr() below.
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        # Past the warm-up window: delegate to the wrapped scheduler, if any.
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # One-time rebase so the follow-up scheduler continues
                    # from the fully warmed-up learning rate.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear ramp from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear ramp from base_lr up to base_lr * multiplier.
            return [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        # ReduceLROnPlateau has no get_lr(), so the warm-up is driven manually
        # by writing lr into the optimizer's param groups.
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = (
            epoch if epoch != 0 else 1
        )  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group["lr"] = lr
        else:
            # NOTE(review): `epoch` was defaulted above when None, so the first
            # branch here appears unreachable in practice — confirm intended.
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                # Warm-up done: forward stepping to the follow-up scheduler,
                # shifting the epoch into its own frame of reference.
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super().step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
AIGVDet/docker-compose.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "3.9"
2
+
3
+ services:
4
+ web:
5
+ build: .
6
+ ports:
7
+ - "8003:8003"
8
+ restart: unless-stopped
9
+ deploy:
10
+ resources:
11
+ reservations:
12
+ devices:
13
+ - driver: nvidia
14
+ count: all
15
+ capabilities: [gpu]
16
+ ipc: host
17
+ shm_size: "24gb"
AIGVDet/main.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import shutil
4
+ from typing import Dict, Optional
5
+
6
+ from .run import RUN
7
+
8
+
9
def run_video_to_json(
    video_path: str,
    output_json_path: Optional[str] = None,
    model_optical_path: str = "checkpoints/optical.pth",
    model_original_path: str = "checkpoints/original.pth",
    frame_root: str = "frame",
    optical_root: str = "optical_result"
) -> Dict:
    """Run AIGVDet on one video and optionally dump the result to JSON.

    Returns:
        dict with the result (also written to JSON when `output_json_path`
        is given)
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    if not os.path.isabs(model_optical_path):
        model_optical_path = os.path.join(base_dir, model_optical_path)
    if not os.path.isabs(model_original_path):
        model_original_path = os.path.join(base_dir, model_original_path)

    results = {}

    if not os.path.isfile(video_path):
        raise FileNotFoundError(f"File not found: {video_path}")

    video_name = os.path.basename(video_path)
    video_id, _ = os.path.splitext(video_name)

    folder_original = os.path.join(frame_root, video_id)
    folder_optical = os.path.join(optical_root, video_id)

    cli_args = [
        '--path', video_path,
        '--folder_original_path', folder_original,
        '--folder_optical_flow_path', folder_optical,
        '--model_optical_flow_path', model_optical_path,
        '--model_original_path', model_original_path,
    ]

    try:
        output = RUN(cli_args)
        results[video_id] = {
            "video_name": video_name,
            "authentic_confidence_score": round(output["real_score"], 4),
            "synthetic_confidence_score": round(output["fake_score"], 4),
        }
    except Exception as e:
        results[video_id] = {
            "video_name": video_name,
            "error": str(e),
        }
    finally:
        # Best-effort removal of the intermediate frame / flow folders.
        for folder in (folder_original, folder_optical):
            try:
                if os.path.exists(folder):
                    shutil.rmtree(folder)
            except Exception:
                pass

    # Save result to JSON or return dict
    if output_json_path:
        os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
        with open(output_json_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

    return results
AIGVDet/networks/resnet.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.utils.model_zoo as model_zoo
3
+
4
+ __all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
5
+
6
+
7
+ model_urls = {
8
+ "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
9
+ "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
10
+ "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
11
+ "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
12
+ "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
13
+ }
14
+
15
+
16
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (no bias; BatchNorm follows)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
19
+
20
+
21
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (pointwise, no bias)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
24
+
25
+
26
class BasicBlock(nn.Module):
    """Two 3x3 convs with an identity shortcut (ResNet-18/34 building block)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Module creation order is preserved so weight init is reproducible
        # under a fixed RNG seed.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
56
+
57
+
58
class Bottleneck(nn.Module):
    """1x1 → 3x3 → 1x1 bottleneck with shortcut (ResNet-50/101/152 block)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Module creation order is preserved so weight init is reproducible
        # under a fixed RNG seed.
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
94
+
95
+
96
class ResNet(nn.Module):
    """Generic ResNet backbone: 7x7 stem → 4 residual stages → avg-pool → fc.

    `block` is BasicBlock or Bottleneck; `layers` gives the block count for
    each of the four stages.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super().__init__()
        # Channel count entering the next stage; mutated by _make_layer.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He init for convs, unit-gain for BatchNorm affine params.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of `blocks` blocks at `planes` width."""
        downsample = None
        # Projection shortcut needed whenever spatial or channel dims change.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x, *args):
        # extra positional args are accepted and ignored (caller compatibility)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        x = self.fc(x)

        return x
157
+
158
+
159
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]))
    return model
168
+
169
+
170
def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls["resnet34"]))
    return model
179
+
180
+
181
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
    return model
190
+
191
+
192
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]))
    return model
201
+
202
+
203
def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls["resnet152"]))
    return model
AIGVDet/raft_model/raft-things.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcfa4125d6418f4de95d84aec20a3c5f4e205101715a79f193243c186ac9a7e1
3
+ size 21108000
AIGVDet/requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # conda create -n aigvdet python=3.9
2
+ # pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
3
+ einops
4
+ imageio
5
+ ipympl
6
+ matplotlib
7
+ numpy<2
8
+ opencv-python
9
+ pandas
10
+ scikit-learn
11
+ tensorboard
12
+ tensorboardX
13
+ tqdm
14
+ blobfile>=1.0.5
15
+ natsort
16
+ fastapi==0.116.1
17
+ pydantic==2.11.7
18
+ uvicorn[standard]
19
+ torch==2.0.0+cu117
20
+ torchvision==0.15.1+cu117
21
+ -f https://download.pytorch.org/whl/torch_stable.html
22
+ python-multipart
AIGVDet/run.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import os
4
+ import cv2
5
+ import glob
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+ import pandas as pd
10
+ import torch
11
+ import torch.nn
12
+ import torchvision.transforms as transforms
13
+ import torchvision.transforms.functional as TF
14
+ from tqdm import tqdm
15
+ import time
16
+
17
+ from .core.raft import RAFT
18
+ from .core.utils import flow_viz
19
+ from .core.utils.utils import InputPadder
20
+ from natsort import natsorted
21
+ from .core.utils1.utils import get_network, str2bool, to_cuda
22
+ from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score,roc_auc_score
23
+
24
+ DEVICE = 'cuda'
25
+ # DEVICE = 'cpu' # Changed to 'cpu'
26
+ device = torch.device(DEVICE)
27
+
28
+
29
def load_image(imfile):
    """Read an image file into a 1xCxHxW float tensor on DEVICE."""
    arr = np.array(Image.open(imfile)).astype(np.uint8)
    tensor = torch.from_numpy(arr).permute(2, 0, 1).float()
    return tensor.unsqueeze(0).to(DEVICE)
33
+
34
+
35
def viz(img, flo, folder_optical_flow_path, imfile1):
    """Render an optical-flow field to RGB and save it under the frame's name."""
    img = img[0].permute(1, 2, 0).cpu().numpy()
    flow_np = flo[0].permute(1, 2, 0).cpu().numpy()

    # map flow to RGB image
    flow_rgb = flow_viz.flow_to_image(flow_np)
    # Stacked preview kept for parity with the original code; only the flow
    # image is actually written to disk.
    img_flo = np.concatenate([img, flow_rgb], axis=0)

    # extract filename safely (cross-platform)
    filename = os.path.basename(imfile1).strip()
    output_path = os.path.join(folder_optical_flow_path, filename)

    print(output_path)
    cv2.imwrite(output_path, flow_rgb)
49
+
50
+
51
def video_to_frames(video_path, output_folder):
    """Dump every frame of `video_path` as PNG and return sorted frame paths."""
    os.makedirs(output_folder, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    frame_count = 0

    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            break
        out_name = os.path.join(output_folder, f"frame_{frame_count:05d}.png")
        cv2.imwrite(out_name, frame)
        frame_count += 1

    cap.release()

    # Collect whatever ended up in the folder (png from us, jpg if pre-existing).
    frames = glob.glob(os.path.join(output_folder, '*.png'))
    frames += glob.glob(os.path.join(output_folder, '*.jpg'))
    return sorted(frames)
74
+
75
+
76
+ # generate optical flow images
77
# generate optical flow images
def OF_gen(args):
    """Extract frames from args.path and write RAFT optical-flow renderings."""
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model, map_location=torch.device(DEVICE)))

    # Unwrap DataParallel so the bare module runs on DEVICE.
    model = model.module
    model.to(DEVICE)
    model.eval()

    if not os.path.exists(args.folder_optical_flow_path):
        os.makedirs(args.folder_optical_flow_path)
        print(f'{args.folder_optical_flow_path}')

    with torch.no_grad():
        images = natsorted(video_to_frames(args.path, args.folder_original_path))

        # Consecutive frame pairs: (f0, f1), (f1, f2), ...
        for imfile1, imfile2 in zip(images[:-1], images[1:]):
            image1 = load_image(imfile1)
            image2 = load_image(imfile2)

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up, args.folder_optical_flow_path, imfile1)
104
+
105
+
106
def _list_frames(folder):
    """Return sorted jpg/png/JPEG image paths inside `folder`."""
    return sorted(
        glob.glob(os.path.join(folder, "*.jpg"))
        + glob.glob(os.path.join(folder, "*.png"))
        + glob.glob(os.path.join(folder, "*.JPEG"))
    )


def _load_detector(arch, ckpt_path):
    """Build the classifier for `arch` and load weights from `ckpt_path`."""
    model = get_network(arch)
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # Checkpoints may store weights directly or under a "model" key.
    model.load_state_dict(state_dict["model"] if "model" in state_dict else state_dict)
    return model.eval().to(device)


def _mean_frame_prob(model, file_list, aug_norm, trans, desc):
    """Average per-frame sigmoid "fake" probability of `model` over `file_list`."""
    if not file_list:
        # Previously an empty folder crashed with ZeroDivisionError; fail clearly.
        raise RuntimeError(f"No frames found for {desc} scoring")
    prob_sum = 0.0
    for img_path in tqdm(file_list, desc=desc, dynamic_ncols=True, disable=len(file_list) <= 1):
        img = Image.open(img_path).convert("RGB")
        img = trans(img)
        if aug_norm:
            img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        in_tens = img.unsqueeze(0).to(device)
        with torch.no_grad():
            prob_sum += model(in_tens).sigmoid().item()
    return prob_sum / len(file_list)


def RUN(args=None):
    """Run the full AIGVDet pipeline on one video.

    Extracts frames and RAFT optical flow, scores both streams with their
    detectors, and returns a dict with per-stream, fused, real and fake scores.
    `args` is an optional argv-style list (None falls back to sys.argv).
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_model_path = os.path.join(script_dir, "raft_model/raft-things.pth")

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint", default=default_model_path)
    parser.add_argument('--path', help="dataset for evaluation", default="video/000000.mp4")
    parser.add_argument('--folder_original_path', help="dataset for evaluation_frames", default="frame/000000")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
    parser.add_argument('--folder_optical_flow_path', help="the results to save", default="optical_result/000000")
    parser.add_argument(
        "-mop",
        "--model_optical_flow_path",
        type=str,
        default="checkpoints/optical.pth",
    )
    parser.add_argument(
        "-mor",
        "--model_original_path",
        type=str,
        default="checkpoints/original.pth",
    )
    parser.add_argument(
        "-t",
        "--threshold",
        type=float,
        default=0.5,
    )
    parser.add_argument("--use_cpu", action="store_true", help="uses gpu by default, turn on to use cpu")
    parser.add_argument("--arch", type=str, default="resnet50")
    parser.add_argument("--aug_norm", type=str2bool, default=True)
    args = parser.parse_args(args)

    # Stage 1: frame extraction + RAFT optical flow.
    start_time = time.perf_counter()
    OF_gen(args)
    elapsed = time.perf_counter() - start_time
    print(f"⏱️ [OF_gen] Service call took {elapsed:.2f} seconds")

    # Stage 2: load the two frame-level detectors.
    model_op = _load_detector(args.arch, args.model_optical_flow_path)
    model_or = _load_detector(args.arch, args.model_original_path)

    trans = transforms.Compose([
        transforms.CenterCrop((448, 448)),
        transforms.ToTensor(),
    ])

    # Stage 3: score both streams frame-by-frame and average.
    original_prob = _mean_frame_prob(
        model_or, _list_frames(args.folder_original_path), args.aug_norm, trans, "Original"
    )
    print(f"Original prob: {original_prob:.4f}")

    optical_prob = _mean_frame_prob(
        model_op, _list_frames(args.folder_optical_flow_path), args.aug_norm, trans, "Optical Flow"
    )
    print(f"Optical prob: {optical_prob:.4f}")

    # Stage 4: fuse with an equal-weight average.
    final_prob = (original_prob + optical_prob) / 2
    print(f"predict: {final_prob}")

    real_score = 1 - final_prob
    fake_score = final_prob
    print(f"Confidence scores - Real: {real_score:.4f}, Fake: {fake_score:.4f}")

    return {
        "original_prob": original_prob,
        "optical_prob": optical_prob,
        "final_predict": final_prob,
        "real_score": real_score,
        "fake_score": fake_score
    }
AIGVDet/run.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ python demo.py --path "demo_video/real/094.mp4" --folder_original_path "frame/000000" --folder_optical_flow_path "optical_result/000000" -mop "checkpoints/optical.pth" -mor "checkpoints/original.pth"
AIGVDet/test.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import pandas as pd
5
+
6
+ import torch
7
+ import torch.nn
8
+ import torchvision.transforms as transforms
9
+ import torchvision.transforms.functional as TF
10
+ from PIL import Image
11
+ from tqdm import tqdm
12
+
13
+ from core.utils1.utils import get_network, str2bool, to_cuda
14
+ from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score,roc_auc_score
15
+
16
+ if __name__=="__main__":
17
+
18
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
19
+ parser.add_argument(
20
+ "-fop", "--folder_optical_flow_path", default="data/test/T2V/videocraft", type=str, help="path to optical flow imagefile folder"
21
+ )
22
+ parser.add_argument(
23
+ "-for", "--folder_original_path", default="data/test/original/T2V/videocraft", type=str, help="path to RGB image file folder"
24
+ )
25
+ parser.add_argument(
26
+ "-mop",
27
+ "--model_optical_flow_path",
28
+ type=str,
29
+ default="checkpoints/optical.pth",
30
+ )
31
+ parser.add_argument(
32
+ "-mor",
33
+ "--model_original_path",
34
+ type=str,
35
+ default="checkpoints/original.pth",
36
+ )
37
+
38
+ parser.add_argument(
39
+ "-t",
40
+ "--threshold",
41
+ type=float,
42
+ default=0.5,
43
+ )
44
+
45
+ parser.add_argument(
46
+ "-e",
47
+ "--excel_path",
48
+ type=str,
49
+ help="path to excel of frames",
50
+ default="data/results/moonvalley_wang.csv",
51
+ )
52
+
53
+ parser.add_argument(
54
+ "-ef",
55
+ "--excel_frame_path",
56
+ type=str,
57
+ help="path to excel of frame detection result",
58
+ default="data/results/frame/moonvalley_wang.csv",
59
+ )
60
+
61
+
62
+
63
+
64
+ parser.add_argument("--use_cpu", action="store_true", help="uses gpu by default, turn on to use cpu")
65
+ parser.add_argument("--arch", type=str, default="resnet50")
66
+ parser.add_argument("--aug_norm", type=str2bool, default=True)
67
+
68
+ args = parser.parse_args()
69
+ subfolder_count = 0
70
+
71
+ model_op = get_network(args.arch)
72
+ state_dict = torch.load(args.model_optical_flow_path, map_location="cpu")
73
+ if "model" in state_dict:
74
+ state_dict = state_dict["model"]
75
+ model_op.load_state_dict(state_dict)
76
+ model_op.eval()
77
+ if not args.use_cpu:
78
+ model_op.cuda()
79
+
80
+
81
+ model_or = get_network(args.arch)
82
+ state_dict = torch.load(args.model_original_path, map_location="cpu")
83
+ if "model" in state_dict:
84
+ state_dict = state_dict["model"]
85
+ model_or.load_state_dict(state_dict)
86
+ model_or.eval()
87
+ if not args.use_cpu:
88
+ model_or.cuda()
89
+
90
+
91
+ trans = transforms.Compose(
92
+ (
93
+ transforms.CenterCrop((448,448)),
94
+ transforms.ToTensor(),
95
+ )
96
+ )
97
+
98
+ print("*" * 50)
99
+
100
+ flag=0
101
+ p=0
102
+ n=0
103
+ tp=0
104
+ tn=0
105
+ y_true=[]
106
+ y_pred=[]
107
+
108
+ # create an empty DataFrame
109
+ df = pd.DataFrame(columns=['name', 'pro','flag','optical_pro','original_pro'])
110
+ df1 = pd.DataFrame(columns=['original_path', 'original_pro','optical_path','optical_pro','flag'])
111
+ index1=0
112
+
113
+ # Traverse through subfolders in a large folder.
114
+ for subfolder_name in ["0_real", "1_fake"]:
115
+ optical_subfolder_path = os.path.join(args.folder_optical_flow_path, subfolder_name)
116
+ original_subfolder_path = os.path.join(args.folder_original_path, subfolder_name)
117
+
118
+ if subfolder_name=="0_real":
119
+ flag=0
120
+ else:
121
+ flag=1
122
+
123
+ if os.path.isdir(optical_subfolder_path):
124
+ pass
125
+ else:
126
+ print("Subfolder does not exist.", optical_subfolder_path)
127
+
128
+ # Check if the subfolder path exists.
129
+ if os.path.isdir(original_subfolder_path):
130
+ print("test subfolder:", subfolder_name)
131
+
132
+ # Traverse through sub-subfolders within a subfolder.
133
+ for subsubfolder_name in os.listdir(original_subfolder_path):
134
+ original_subsubfolder_path = os.path.join(original_subfolder_path, subsubfolder_name)
135
+ optical_subsubfolder_path = os.path.join(optical_subfolder_path, subsubfolder_name)
136
+ if os.path.isdir(optical_subsubfolder_path):
137
+ pass
138
+ else:
139
+ print("Sub-subfolder does not exist.",optical_subsubfolder_path)
140
+
141
+ if os.path.isdir(original_subsubfolder_path):
142
+ print("test subsubfolder:", subsubfolder_name)
143
+
144
+ #Detect original
145
+ original_file_list = sorted(glob.glob(os.path.join(original_subsubfolder_path, "*.jpg")) + glob.glob(os.path.join(original_subsubfolder_path, "*.png"))+glob.glob(os.path.join(original_subsubfolder_path, "*.JPEG")))
146
+
147
+ original_prob_sum=0
148
+ for img_path in tqdm(original_file_list, dynamic_ncols=True, disable=len(original_file_list) <= 1):
149
+
150
+ img = Image.open(img_path).convert("RGB")
151
+ img = trans(img)
152
+ if args.aug_norm:
153
+ img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
154
+ in_tens = img.unsqueeze(0)
155
+ if not args.use_cpu:
156
+ in_tens = in_tens.cuda()
157
+
158
+ with torch.no_grad():
159
+ prob = model_or(in_tens).sigmoid().item()
160
+ original_prob_sum+=prob
161
+
162
+ df1 = df1.append({'original_path': img_path, 'original_pro': prob , 'flag':flag}, ignore_index=True)
163
+
164
+
165
+ original_predict=original_prob_sum/len(original_file_list)
166
+ print("original prob",original_predict)
167
+
168
+ #Detect optical flow
169
+ optical_file_list = sorted(glob.glob(os.path.join(optical_subsubfolder_path, "*.jpg")) + glob.glob(os.path.join(optical_subsubfolder_path, "*.png"))+glob.glob(os.path.join(optical_subsubfolder_path, "*.JPEG")))
170
+ optical_prob_sum=0
171
+ for img_path in tqdm(optical_file_list, dynamic_ncols=True, disable=len(original_file_list) <= 1):
172
+
173
+ img = Image.open(img_path).convert("RGB")
174
+ img = trans(img)
175
+ if args.aug_norm:
176
+ img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
177
+ in_tens = img.unsqueeze(0)
178
+ if not args.use_cpu:
179
+ in_tens = in_tens.cuda()
180
+
181
+ with torch.no_grad():
182
+ prob = model_op(in_tens).sigmoid().item()
183
+ optical_prob_sum+=prob
184
+
185
+ df1.loc[index1, 'optical_path'] = img_path
186
+ df1.loc[index1, 'optical_pro'] = prob
187
+ index1=index1+1
188
+ index1=index1+1
189
+
190
+ optical_predict=optical_prob_sum/len(optical_file_list)
191
+ print("optical prob",optical_predict)
192
+
193
+ predict=original_predict*0.5+optical_predict*0.5
194
+ print(f"flag:{flag} predict:{predict}")
195
+ # y_true.append((float)(flag))
196
+ y_true.append((flag))
197
+ y_pred.append(predict)
198
+ if flag==0:
199
+ n+=1
200
+ if predict<args.threshold:
201
+ tn+=1
202
+ else:
203
+ p+=1
204
+ if predict>=args.threshold:
205
+ tp+=1
206
+ df = df.append({'name': subsubfolder_name, 'pro': predict , 'flag':flag ,'optical_pro':optical_predict,'original_pro':original_predict}, ignore_index=True)
207
+ else:
208
+ print("Subfolder does not exist:", original_subfolder_path)
209
+ # r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > args.threshold)
210
+ # f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > args.threshold)
211
+ # acc = accuracy_score(y_true, y_pred > args.threshold)
212
+
213
+ ap = average_precision_score(y_true, y_pred)
214
+ auc=roc_auc_score(y_true,y_pred)
215
+ # print(f"r_acc:{r_acc}")
216
+ print(f"tnr:{tn/n}")
217
+ # print(f"f_acc:{f_acc}")
218
+ print(f"tpr:{tp/p}")
219
+ print(f"acc:{(tp+tn)/(p+n)}")
220
+ # print(f"acc:{acc}")
221
+ print(f"ap:{ap}")
222
+ print(f"auc:{auc}")
223
+ print(f"p:{p}")
224
+ print(f"n:{n}")
225
+ print(f"tp:{tp}")
226
+ print(f"tn:{tn}")
227
+
228
+ # Write the DataFrame to a csv file.
229
+ csv_filename = args.excel_path
230
+ csv_folder = os.path.dirname(csv_filename)
231
+ if not os.path.exists(csv_folder):
232
+ os.makedirs(csv_folder)
233
+
234
+
235
+ if not os.path.exists(csv_filename):
236
+ df.to_csv(csv_filename, index=False)
237
+ else:
238
+ df.to_csv(csv_filename, mode='a', header=False, index=False)
239
+ print(f"Results have been saved to {csv_filename}")
240
+
241
+ # Write the prediction probabilities of the frame to a CSV file.
242
+ csv_filename1 = args.excel_frame_path
243
+ csv_folder1 = os.path.dirname(csv_filename1)
244
+ if not os.path.exists(csv_folder1):
245
+ os.makedirs(csv_folder1)
246
+
247
+ if not os.path.exists(csv_filename1):
248
+ df1.to_csv(csv_filename1, index=False)
249
+ else:
250
+ df1.to_csv(csv_filename1, mode='a', header=False, index=False)
251
+
252
+ # if not os.path.exists(excel_filename):
253
+ # with pd.ExcelWriter(excel_filename, engine='xlsxwriter') as writer:
254
+ # df.to_excel(writer, sheet_name='Sheet1', index=False)
255
+ # else:
256
+ # with pd.ExcelWriter(excel_filename, mode='a', engine='openpyxl') as writer:
257
+ # df.to_excel(writer, sheet_name='Sheet1', index=False, startrow=0, header=False)
258
+ print(f"Results have been saved to {csv_filename1}")
259
+
260
+
261
+
AIGVDet/test.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ python test.py -fop "data/test/T2V/hotshot" -mop "checkpoints/optical_aug.pth" -for "data/test/original/T2V/hotshot" -mor "checkpoints/original_aug.pth" -e "data/results/T2V/hotshot.csv" -ef "data/results/frame/T2V/hotshot.csv" -t 0.5
AIGVDet/train.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from core.utils1.config import cfg # isort: split
2
+
3
+ import os
4
+ import time
5
+
6
+ from tensorboardX import SummaryWriter
7
+ from tqdm import tqdm
8
+
9
+ from core.utils1.datasets import create_dataloader
10
+ from core.utils1.earlystop import EarlyStopping
11
+ from core.utils1.eval import get_val_cfg, validate
12
+ from core.utils1.trainer import Trainer
13
+ from core.utils1.utils import Logger
14
+
15
+ import ssl
16
+ ssl._create_default_https_context = ssl._create_unverified_context
17
+
18
+
19
if __name__ == "__main__":
    # Training entry point: train the detector with per-epoch validation,
    # TensorBoard logging, periodic checkpointing, and early stopping.
    # NOTE: block nesting reconstructed from a diff that lost indentation —
    # verify loop structure against the original file.

    # Build the validation config first (copy=True so mutating cfg below
    # does not affect it), then point the train loader at the "train" split.
    val_cfg = get_val_cfg(cfg, split="val", copy=True)
    cfg.dataset_root = os.path.join(cfg.dataset_root, "train")
    data_loader = create_dataloader(cfg)
    dataset_size = len(data_loader)  # number of batches, not images

    log = Logger()
    log.open(cfg.logs_path, mode="a")
    # batches * batch_size approximates the image count (last batch may be short).
    log.write("Num of training images = %d\n" % (dataset_size * cfg.batch_size))
    log.write("Config:\n" + str(cfg.to_dict()) + "\n")

    # Separate TensorBoard writers for train and validation scalars.
    train_writer = SummaryWriter(os.path.join(cfg.exp_dir, "train"))
    val_writer = SummaryWriter(os.path.join(cfg.exp_dir, "val"))

    trainer = Trainer(cfg)
    # Negative delta: the monitored metric must improve by more than |delta|
    # to reset patience.
    early_stopping = EarlyStopping(patience=cfg.earlystop_epoch, delta=-0.001, verbose=True)
    for epoch in range(cfg.nepoch):
        epoch_start_time = time.time()  # NOTE(review): assigned but never read
        iter_data_time = time.time()    # NOTE(review): assigned but never read
        epoch_iter = 0

        for data in tqdm(data_loader, dynamic_ncols=True):
            trainer.total_steps += 1
            epoch_iter += cfg.batch_size

            trainer.set_input(data)
            trainer.optimize_parameters()

            # if trainer.total_steps % cfg.loss_freq == 0:
            #     log.write(f"Train loss: {trainer.loss} at step: {trainer.total_steps}\n")
            train_writer.add_scalar("loss", trainer.loss, trainer.total_steps)

            if trainer.total_steps % cfg.save_latest_freq == 0:
                log.write(
                    "saving the latest model %s (epoch %d, model.total_steps %d)\n"
                    % (cfg.exp_name, epoch, trainer.total_steps)
                )
                trainer.save_networks("latest")

        if epoch % cfg.save_epoch_freq == 0:
            log.write("saving the model at the end of epoch %d, iters %d\n" % (epoch, trainer.total_steps))
            trainer.save_networks("latest")
            trainer.save_networks(epoch)

        # Validation at the end of every epoch.
        trainer.eval()
        val_results = validate(trainer.model, val_cfg)
        val_writer.add_scalar("AP", val_results["AP"], trainer.total_steps)
        val_writer.add_scalar("ACC", val_results["ACC"], trainer.total_steps)
        # add
        val_writer.add_scalar("AUC", val_results["AUC"], trainer.total_steps)
        val_writer.add_scalar("TPR", val_results["TPR"], trainer.total_steps)
        val_writer.add_scalar("TNR", val_results["TNR"], trainer.total_steps)

        log.write(f"(Val @ epoch {epoch}) AP: {val_results['AP']}; ACC: {val_results['ACC']}\n")

        if cfg.earlystop:
            # Early stopping monitors validation accuracy; on a stall, first
            # drop the learning rate and restart patience, then stop for good.
            early_stopping(val_results["ACC"], trainer)
            if early_stopping.early_stop:
                if trainer.adjust_learning_rate():
                    log.write("Learning rate dropped by 10, continue training...\n")
                    early_stopping = EarlyStopping(patience=cfg.earlystop_epoch, delta=-0.002, verbose=True)
                else:
                    log.write("Early stopping.\n")
                    break
        if cfg.warmup:
            # print(trainer.scheduler.get_lr()[0])
            trainer.scheduler.step()
        trainer.train()  # back to train mode for the next epoch
AIGVDet/train.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ EXP_NAME="moonvalley_vos2_crop"
2
+ DATASETS="moonvalley_vos2_crop"
3
+ DATASETS_TEST="moonvalley_vos2_crop"
4
+ python train.py --gpus 0 --exp_name $EXP_NAME datasets $DATASETS datasets_test $DATASETS_TEST
api_server.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import shutil
4
+ import json
5
+ import uuid
6
+ import time
7
+ import threading
8
+ from typing import List, Dict, Any, Optional
9
+ from concurrent.futures import ThreadPoolExecutor
10
+
11
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
12
+ from fastapi.responses import JSONResponse
13
+ from pydantic import BaseModel
14
+
15
+ # Giả lập import các thư viện của bạn
16
+ from miragenews import run_multimodal_to_json
17
+ from AIGVDet import run_video_to_json
18
+
19
# Root directory where uploaded files are staged, one subdirectory per job.
UPLOAD_DIR = "temp_uploads"
# NOTE(review): MAX_WORKERS is defined but not referenced anywhere in this
# file — confirm whether it was meant to size a ThreadPoolExecutor.
MAX_WORKERS = 4

app = FastAPI(
    title="Multimedia Analysis API (Polling Mode)",
    description="API phân tích đa phương tiện sử dụng cơ chế Polling để tránh Timeout.",
    version="2.0.0",
)

# In-memory job store (job_id -> job record dict); every access goes through
# jobs_lock because background tasks update it from worker threads.
# NOTE(review): state is process-local — lost on restart, not shared across
# multiple server workers.
jobs: Dict[str, Dict] = {}
jobs_lock = threading.Lock()
30
+
31
class AnalysisResult(BaseModel):
    """Combined payload of the image and/or video analysis pipelines."""

    # Per-image results returned by run_multimodal_to_json.
    image_analysis_results: Optional[List[Any]] = None
    # Summary verdict assembled from the AIGVDet video detector output.
    video_analysis_result: Optional[Dict[str, Any]] = None
34
+
35
class JobStatus(BaseModel):
    """Job record returned by every endpoint (the polling contract)."""

    job_id: str
    # One of "queued", "running", "succeeded", "failed" (as set in this file).
    status: str
    message: Optional[str] = None
    # Populated only once status == "succeeded".
    result: Optional[AnalysisResult] = None
    # Unix timestamps from time.time().
    created_at: float
    updated_at: float
42
+
43
def _create_job() -> str:
    """Register a fresh 'queued' job in the shared store and return its id."""
    new_id = uuid.uuid4().hex
    stamp = time.time()
    record = {
        "job_id": new_id,
        "status": "queued",
        "message": "Đang chờ xử lý...",
        "result": None,
        "created_at": stamp,
        "updated_at": stamp,
    }
    with jobs_lock:
        jobs[new_id] = record
    return new_id
56
+
57
def _update_job(job_id: str, **kwargs):
    """Apply field updates to a job record and refresh its timestamp.

    Silently does nothing for an unknown job_id.
    """
    with jobs_lock:
        record = jobs.get(job_id)
        if record is None:
            return
        record.update(kwargs)
        record["updated_at"] = time.time()
62
+
63
def _get_job(job_id: str) -> Optional[Dict]:
    """Thread-safe lookup of a job record; None when the id is unknown."""
    with jobs_lock:
        record = jobs.get(job_id)
    return record
66
+
67
async def run_analysis_logic(
    image_paths: Optional[List[str]] = None,
    video_path: Optional[str] = None,
    text: str = "",
) -> Dict[str, Any]:
    """Run image and/or video analysis concurrently and merge the results.

    The image pipeline (run_multimodal_to_json) is awaited as a coroutine
    task; the video pipeline (run_video_to_json) is a blocking call pushed
    to a worker thread. Raises ValueError when neither input is given.
    """
    if not image_paths and not video_path:
        raise ValueError("Cần cung cấp ít nhất một trong hai: image_paths hoặc video_path.")

    tasks = []

    if image_paths:
        image_task = asyncio.create_task(
            run_multimodal_to_json(image_paths=image_paths, text=text, output_json_path=None)
        )
        tasks.append(image_task)

    if video_path:
        # run_video_to_json is synchronous; to_thread keeps the event loop free.
        video_task = asyncio.to_thread(
            run_video_to_json, video_path=video_path, output_json_path=None
        )
        tasks.append(video_task)

    # gather preserves task order: image result (if any) first, then video.
    task_results = await asyncio.gather(*tasks)

    final_result = {"image_analysis_results": [], "video_analysis_result": {}}
    image_analysis_results = []
    video_result_index = -1

    # Recover which slot of task_results belongs to which pipeline.
    current_idx = 0
    if image_paths:
        image_analysis_results = task_results[current_idx]
        current_idx += 1

    if video_path:
        video_result_index = current_idx

    final_result["image_analysis_results"] = image_analysis_results

    if video_result_index != -1:
        raw_video_result = task_results[video_result_index]
        if raw_video_result:
            # assumes the detector returns {video_id: {...}} with exactly one
            # entry of interest — TODO confirm against run_video_to_json
            video_id_key = list(raw_video_result.keys())[0]
            video_data = raw_video_result[video_id_key]

            avg_authentic = video_data.get("authentic_confidence_score", 0)
            avg_synthetic = video_data.get("synthetic_confidence_score", 0)

            # Map confidence scores to a three-way verdict.
            # NOTE(review): the middle branch (authentic wins but <= 0.5) is
            # labelled "Potentially Synthetic" — looks intentional but worth
            # confirming with the model owners.
            if avg_authentic > avg_synthetic and avg_authentic > 0.5:
                authenticity_assessment = "REAL (Authentic)"
                verification_tools = "Deepfake Detector"
                synthetic_type = "N/A"
                other_artifacts = "Our algorithms conducted a thorough analysis of the video's motion patterns, lighting consistency, and object interactions. We observed fluid, natural movements and consistent physics that align with real-world recordings. No discernible artifacts, such as pixel distortion, unnatural blurring, or shadow inconsistencies, were detected that would indicate digital manipulation or AI-driven synthesis."
            elif avg_authentic > avg_synthetic and avg_authentic <= 0.5:
                authenticity_assessment = "Potentially Synthetic"
                verification_tools = "Deepfake Detector"
                synthetic_type = "Potentially AI-generated"
                other_artifacts = "Our analysis has identified subtle anomalies within the video frames, particularly in areas of complex texture and inconsistent lighting across different objects. While these discrepancies are not significant enough to definitively classify the video as synthetic, they do suggest a possibility of digital alteration or partial AI generation. Further examination may be required for a conclusive determination."
            else:
                authenticity_assessment = "NOT REAL (Fake, Manipulated, or AI)"
                verification_tools = "Deepfake Detector"
                synthetic_type = "AI-generated"
                other_artifacts = "Our deep analysis detected multiple, significant artifacts commonly associated with synthetic or manipulated media. These include, but are not limited to, unnatural facial expressions and eye movements, inconsistent or floating shadows, logical impossibilities in object interaction, and high-frequency digital noise characteristic of generative models. These factors strongly indicate that the video is not authentic."

            final_result["video_analysis_result"] = {
                "filename": video_data.get("video_name", ""),
                "result": {
                    "authenticity_assessment": authenticity_assessment,
                    "verification_tools_methods": verification_tools,
                    "synthetic_type": synthetic_type,
                    "other_artifacts": other_artifacts,
                },
            }

    # Drop empty sections so the response only contains what was analysed.
    if not final_result.get("image_analysis_results"):
        final_result.pop("image_analysis_results", None)
    if not final_result.get("video_analysis_result"):
        final_result.pop("video_analysis_result", None)

    return final_result
147
+
148
async def process_job_background(
    job_id: str,
    temp_dir: str,
    image_paths: List[str],
    video_path: str,
    text: str
):
    """Background worker: run the analysis and record the outcome on the job."""
    _update_job(job_id, status="running", message="Đang phân tích...")

    try:
        outcome = await run_analysis_logic(
            image_paths=image_paths or None,
            video_path=video_path or None,
            text=text,
        )
        _update_job(job_id, status="succeeded", result=outcome, message="Hoàn tất")
    except Exception as err:
        print(f"Error processing job {job_id}: {err}")
        _update_job(job_id, status="failed", message=str(err))
    finally:
        # Best-effort removal of the per-job upload directory.
        try:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
                print(f"Deleted temp dir: {temp_dir}")
        except Exception as cleanup_error:
            print(f"Cleanup error for {job_id}: {cleanup_error}")
178
+
179
+
180
# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favour of lifespan handlers — still works, but consider migrating.
@app.on_event("startup")
def startup_event():
    # Ensure the upload staging root exists before the first request.
    os.makedirs(UPLOAD_DIR, exist_ok=True)
183
+
184
@app.get("/analyze/{job_id}", response_model=JobStatus)
async def get_job_status(job_id: str):
    """Polling endpoint: clients call this periodically to check results."""
    record = _get_job(job_id)
    if record:
        return record
    raise HTTPException(status_code=404, detail="Job not found")
191
+
192
@app.post("/analyze/image/", response_model=JobStatus)
async def analyze_image_endpoint(
    background_tasks: BackgroundTasks,
    images: List[UploadFile] = File(...),
    text: Optional[str] = Form(""),
):
    """Accept image uploads, queue background analysis, return the job record.

    The response is immediate (status "queued"); clients poll
    GET /analyze/{job_id} for the result.
    """
    job_id = _create_job()

    job_dir = os.path.join(UPLOAD_DIR, job_id)
    os.makedirs(job_dir, exist_ok=True)

    saved_image_paths = []
    try:
        for img in images:
            # basename() strips any client-supplied directory components,
            # preventing path traversal out of job_dir (filename is untrusted).
            file_path = os.path.join(job_dir, os.path.basename(img.filename))
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(img.file, buffer)
            saved_image_paths.append(file_path)
    except Exception as e:
        _update_job(job_id, status="failed", message=f"Lỗi upload: {e}")
        return _get_job(job_id)

    background_tasks.add_task(
        process_job_background,
        job_id, job_dir, saved_image_paths, None, text
    )

    return _get_job(job_id)
220
+
221
@app.post("/analyze/video/", response_model=JobStatus)
async def analyze_video_endpoint(
    background_tasks: BackgroundTasks,
    video: UploadFile = File(...),
):
    """Accept a video upload, queue background analysis, return the job record.

    The response is immediate (status "queued"); clients poll
    GET /analyze/{job_id} for the result.
    """
    job_id = _create_job()
    job_dir = os.path.join(UPLOAD_DIR, job_id)
    os.makedirs(job_dir, exist_ok=True)

    # basename() strips any client-supplied directory components, preventing
    # path traversal out of job_dir (filename is untrusted).
    saved_video_path = os.path.join(job_dir, os.path.basename(video.filename))
    try:
        with open(saved_video_path, "wb") as buffer:
            shutil.copyfileobj(video.file, buffer)
    except Exception as e:
        _update_job(job_id, status="failed", message=f"Lỗi upload: {e}")
        return _get_job(job_id)

    background_tasks.add_task(
        process_job_background,
        job_id, job_dir, [], saved_video_path, ""
    )

    return _get_job(job_id)
244
+
245
@app.post("/analyze/multimodal/", response_model=JobStatus)
async def analyze_multimodal_endpoint(
    background_tasks: BackgroundTasks,
    images: List[UploadFile] = File(...),
    video: UploadFile = File(...),
    text: Optional[str] = Form(""),
):
    """Accept images + a video, queue combined analysis, return the job record.

    The response is immediate (status "queued"); clients poll
    GET /analyze/{job_id} for the result.
    """
    job_id = _create_job()
    job_dir = os.path.join(UPLOAD_DIR, job_id)
    os.makedirs(job_dir, exist_ok=True)

    saved_image_paths = []
    saved_video_path = None

    try:
        # Save Images
        for img in images:
            # basename() strips any client-supplied directory components,
            # preventing path traversal out of job_dir (filename is untrusted).
            file_path = os.path.join(job_dir, os.path.basename(img.filename))
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(img.file, buffer)
            saved_image_paths.append(file_path)

        # Save Video
        saved_video_path = os.path.join(job_dir, os.path.basename(video.filename))
        with open(saved_video_path, "wb") as buffer:
            shutil.copyfileobj(video.file, buffer)

    except Exception as e:
        _update_job(job_id, status="failed", message=f"Lỗi upload: {e}")
        return _get_job(job_id)

    background_tasks.add_task(
        process_job_background,
        job_id, job_dir, saved_image_paths, saved_video_path, text
    )

    return _get_job(job_id)
checkpoints/image/best-mirage-img.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:755e26bfe97830e03e4c475596ade7168e1df18c57d89162d685b1148ae9f5a8
3
+ size 2375
checkpoints/image/cbm-encoder.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a90c02450ae4a105414c6701bf571d1a498571083a725d916119d792cbdcf2f
3
+ size 1917947
checkpoints/image/cbm-predictor.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eba7ff0cd8db1c69961478551b2a40e47687f281d7e853f95a6650908e8c3df
3
+ size 2387