diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7e2052b326e8dabeac0020570241e705c879ee4d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,70 @@
+AIGVDet/frame/
+AIGVDet/optical_result/
+AIGVDet/temp/
+frame/
+optical_result/
+temp/
+temp_uploads/
+
+miragenews/encodings/crops/
+miragenews/encodings/image/test1_nyt_mj/
+miragenews/encodings/image/test2_bbc_dalle/
+miragenews/encodings/image/test3_cnn_dalle/
+miragenews/encodings/image/test4_bbc_sdxl/
+miragenews/encodings/image/test5_cnn_sdxl/
+miragenews/encodings/image/train/
+miragenews/encodings/image/validation/
+miragenews/encodings/predictions/image/cbm-encoder/test1_nyt_mj/
+miragenews/encodings/predictions/image/cbm-encoder/test2_bbc_dalle/
+miragenews/encodings/predictions/image/cbm-encoder/test3_cnn_dalle/
+miragenews/encodings/predictions/image/cbm-encoder/test4_bbc_sdxl/
+miragenews/encodings/predictions/image/cbm-encoder/test5_cnn_sdxl/
+miragenews/encodings/predictions/image/cbm-encoder/train/
+miragenews/encodings/predictions/image/cbm-encoder/validation/
+miragenews/encodings/predictions/image/linear/validation/
+miragenews/encodings/predictions/image/linear/test1_nyt_mj/
+miragenews/encodings/predictions/image/linear/test2_bbc_dalle/
+miragenews/encodings/predictions/image/linear/test3_cnn_dalle/
+miragenews/encodings/predictions/image/linear/test4_bbc_sdxl/
+miragenews/encodings/predictions/image/linear/test5_cnn_sdxl/
+miragenews/encodings/predictions/image/linear/train/
+miragenews/encodings/predictions/image/merged/train/
+miragenews/encodings/predictions/image/merged/test1_nyt_mj/
+miragenews/encodings/predictions/image/merged/test2_bbc_dalle/
+miragenews/encodings/predictions/image/merged/test3_cnn_dalle/
+miragenews/encodings/predictions/image/merged/test4_bbc_sdxl/
+miragenews/encodings/predictions/image/merged/test5_cnn_sdxl/
+miragenews/encodings/predictions/image/merged/validation/
+
+miragenews/encodings/text/train/
+miragenews/encodings/text/test1_nyt_mj/
+miragenews/encodings/text/test2_bbc_dalle/
+miragenews/encodings/text/test3_cnn_dalle/
+miragenews/encodings/text/test4_bbc_sdxl/
+miragenews/encodings/text/test5_cnn_sdxl/
+miragenews/encodings/text/validation/
+
+miragenews/encodings/predictions/text/merged/train/
+miragenews/encodings/predictions/text/merged/test1_nyt_mj/
+miragenews/encodings/predictions/text/merged/test2_bbc_dalle/
+miragenews/encodings/predictions/text/merged/test3_cnn_dalle/
+miragenews/encodings/predictions/text/merged/test4_bbc_sdxl/
+miragenews/encodings/predictions/text/merged/test5_cnn_sdxl/
+miragenews/encodings/predictions/text/merged/validation/
+
+miragenews/encodings/predictions/text/tbm-encoder/train/
+miragenews/encodings/predictions/text/tbm-encoder/test1_nyt_mj/
+miragenews/encodings/predictions/text/tbm-encoder/test2_bbc_dalle/
+miragenews/encodings/predictions/text/tbm-encoder/test3_cnn_dalle/
+miragenews/encodings/predictions/text/tbm-encoder/test4_bbc_sdxl/
+miragenews/encodings/predictions/text/tbm-encoder/test5_cnn_sdxl/
+miragenews/encodings/predictions/text/tbm-encoder/validation/
+
+miragenews/encodings/predictions/text/linear/train/
+miragenews/encodings/predictions/text/linear/test1_nyt_mj/
+miragenews/encodings/predictions/text/linear/test2_bbc_dalle/
+miragenews/encodings/predictions/text/linear/test3_cnn_dalle/
+miragenews/encodings/predictions/text/linear/test4_bbc_sdxl/
+miragenews/encodings/predictions/text/linear/test5_cnn_sdxl/
+miragenews/encodings/predictions/text/linear/validation/
+AIGVDet/data/
diff --git a/AIGVDet/Dockerfile b/AIGVDet/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..2840503765b7d1d1d4890efe1ea338d0e74fd0a8
--- /dev/null
+++ b/AIGVDet/Dockerfile
@@ -0,0 +1,21 @@
+FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime
+
+# Install necessary OS packages for OpenCV
+RUN apt-get update && apt-get install -y \
+ libgl1 \
+ libglib2.0-0 \
+ && rm -rf /var/lib/apt/lists/*
+
+# Set working directory
+WORKDIR /app
+
+# Install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy all source code
+COPY . .
+
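+# Typical build/run sequence (illustrative; the image tag and --gpus flag are
+# assumptions, not part of this repo):
+#   docker build -t aigvdet .
+#   docker run --gpus all -p 8003:8003 aigvdet
+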
+# Default run command
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8003"]
+
diff --git a/AIGVDet/README.md b/AIGVDet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3eb259786d2bbd08bd97fbee06234b1f5e01878
--- /dev/null
+++ b/AIGVDet/README.md
@@ -0,0 +1,94 @@
+## AIGVDet
+An official implementation of the paper "AI-Generated Video Detection via Spatial-Temporal Anomaly Learning", PRCV 2024. This repo provides the code, trained weights, and our training datasets.
+
+## Network Architecture
+
+
+## Dataset
+- Download the preprocessed training frames from
+[Baiduyun Link](https://pan.baidu.com/s/17xmDyFjtcmNsoxmUeImMTQ?pwd=ra95) (extract code: ra95).
+- Download the test videos from [Google Drive](https://drive.google.com/drive/folders/1D84SRWEJ8BK8KBpTMuGi3BUM80mW_dKb?usp=sharing).
+
+**The datasets may be used for research purposes only.**
+
+## Training
+- Prepare the training datasets.
+```
+└─data
+ ├── train
+ │ └── trainset_1
+ │ ├── 0_real
+ │ │ ├── video_00000
+ │ │ │ ├── 00000.png
+ │ │ │ └── ...
+ │ │ └── ...
+ │ └── 1_fake
+ │ ├── video_00000
+ │ │ ├── 00000.png
+ │ │ └── ...
+ │ └── ...
+ ├── val
+ │ └── val_set_1
+ │ ├── 0_real
+ │ │ ├── video_00000
+ │ │ │ ├── 00000.png
+ │ │ │ └── ...
+ │ │ └── ...
+ │ └── 1_fake
+ │ ├── video_00000
+ │ │ ├── 00000.png
+ │ │ └── ...
+ │ └── ...
+ └── test
+ └── testset_1
+ ├── 0_real
+ │ ├── video_00000
+ │ │ ├── 00000.png
+ │ │ └── ...
+ │ └── ...
+ └── 1_fake
+ ├── video_00000
+ │ ├── 00000.png
+ │ └── ...
+ └── ...
+
+```
+- Modify the configuration file in `core/utils1/config.py`.
+- Train the Spatial Domain Detector with the RGB frames.
+```
+python train.py --gpus 0 --exp_name TRAIN_RGB_BRANCH datasets RGB_TRAINSET datasets_test RGB_TESTSET
+```
+- Train the Optical Flow Detector with the optical flow frames.
+```
+python train.py --gpus 0 --exp_name TRAIN_OF_BRANCH datasets OpticalFlow_TRAINSET datasets_test OpticalFlow_TESTSET
+```
+## Testing
+Download the weights from the [Google Drive Link](https://drive.google.com/drive/folders/18JO_YxOEqwJYfbVvy308XjoV-N6fE4yP?usp=share_link) and move them into `checkpoints/`.
+
+- Run on a dataset.
+Prepare the RGB frames and the optical flow maps.
+```
+python test.py -fop "data/test/hotshot" -mop "checkpoints/optical_aug.pth" -for "data/test/original/hotshot" -mor "checkpoints/original_aug.pth" -e "data/results/T2V/hotshot.csv" -ef "data/results/frame/T2V/hotshot.csv" -t 0.5
+```
+- Run on a video.
+Download the RAFT model weights from the [Google Drive Link](https://drive.google.com/file/d/1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM/view) and move it into `raft_model/`.
+```
+python demo.py --use_cpu --path "video/000000.mp4" --folder_original_path "frame/000000" --folder_optical_flow_path "optical_result/000000" -mop "checkpoints/optical.pth" -mor "checkpoints/original.pth"
+```
+
+## License
+The code and dataset are released for academic research only. Commercial use is strictly prohibited.
+
+## Citation
+```
+@inproceedings{AIGVDet24,
+  author    = {Jianfa Bai and Man Lin and Gang Cao and Zijie Lou},
+  title     = {{AI-generated video detection via spatial-temporal anomaly learning}},
+  booktitle = {The 7th Chinese Conference on Pattern Recognition and Computer Vision (PRCV)},
+  year      = {2024},
+}
+
+## Contact
+If you have any questions, please contact us (lyan924@cuc.edu.cn).
+
+
diff --git a/AIGVDet/__init__.py b/AIGVDet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c869ebef5ec4af55cae2edf41fce45946dbe4af8
--- /dev/null
+++ b/AIGVDet/__init__.py
@@ -0,0 +1,3 @@
+from .main import run_video_to_json
+
+__all__ = ["run_video_to_json"]
\ No newline at end of file
diff --git a/AIGVDet/alt_cuda_corr/correlation.cpp b/AIGVDet/alt_cuda_corr/correlation.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b01584d19edb99e7feec5f2e4c51169a1ed208db
--- /dev/null
+++ b/AIGVDet/alt_cuda_corr/correlation.cpp
@@ -0,0 +1,54 @@
+#include <torch/extension.h>
+#include <vector>
+
+// CUDA forward declarations
+std::vector<torch::Tensor> corr_cuda_forward(
+ torch::Tensor fmap1,
+ torch::Tensor fmap2,
+ torch::Tensor coords,
+ int radius);
+
+std::vector<torch::Tensor> corr_cuda_backward(
+ torch::Tensor fmap1,
+ torch::Tensor fmap2,
+ torch::Tensor coords,
+ torch::Tensor corr_grad,
+ int radius);
+
+// C++ interface
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<torch::Tensor> corr_forward(
+ torch::Tensor fmap1,
+ torch::Tensor fmap2,
+ torch::Tensor coords,
+ int radius) {
+ CHECK_INPUT(fmap1);
+ CHECK_INPUT(fmap2);
+ CHECK_INPUT(coords);
+
+ return corr_cuda_forward(fmap1, fmap2, coords, radius);
+}
+
+
+std::vector<torch::Tensor> corr_backward(
+ torch::Tensor fmap1,
+ torch::Tensor fmap2,
+ torch::Tensor coords,
+ torch::Tensor corr_grad,
+ int radius) {
+ CHECK_INPUT(fmap1);
+ CHECK_INPUT(fmap2);
+ CHECK_INPUT(coords);
+ CHECK_INPUT(corr_grad);
+
+ return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &corr_forward, "CORR forward");
+ m.def("backward", &corr_backward, "CORR backward");
+}
\ No newline at end of file
diff --git a/AIGVDet/alt_cuda_corr/correlation_kernel.cu b/AIGVDet/alt_cuda_corr/correlation_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..145e5804a16ece51b8ff5f1cb61ae8dab4fc3bb7
--- /dev/null
+++ b/AIGVDet/alt_cuda_corr/correlation_kernel.cu
@@ -0,0 +1,324 @@
+#include <torch/extension.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <vector>
+
+
+#define BLOCK_H 4
+#define BLOCK_W 8
+#define BLOCK_HW BLOCK_H * BLOCK_W
+#define CHANNEL_STRIDE 32
+
+
+__forceinline__ __device__
+bool within_bounds(int h, int w, int H, int W) {
+ return h >= 0 && h < H && w >= 0 && w < W;
+}
+
+template <typename scalar_t>
+__global__ void corr_forward_kernel(
+    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
+    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
+    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
+    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
+    int r)
+{
+  const int b = blockIdx.x;
+  const int h0 = blockIdx.y * blockDim.x;
+  const int w0 = blockIdx.z * blockDim.y;
+  const int tid = threadIdx.x * blockDim.y + threadIdx.y;
+
+  const int H1 = fmap1.size(1);
+  const int W1 = fmap1.size(2);
+  const int H2 = fmap2.size(1);
+  const int W2 = fmap2.size(2);
+  const int N = coords.size(1);
+  const int C = fmap1.size(3);
+
+  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
+  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
+  __shared__ scalar_t x2s[BLOCK_HW];
+  __shared__ scalar_t y2s[BLOCK_HW];
+
+  for (int c=0; c<C; c+=CHANNEL_STRIDE) {
+    // load a CHANNEL_STRIDE-wide slice of fmap1 for this block into shared memory
+    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW) {
+      int k1 = k + tid;
+      int h1 = h0 + k1 / BLOCK_W;
+      int w1 = w0 + k1 % BLOCK_W;
+      int c1 = tid % CHANNEL_STRIDE;
+
+      auto fptr = fmap1[b][h1][w1];
+      if (within_bounds(h1, w1, H1, W1))
+        f1[c1][k1] = fptr[c+c1];
+      else
+        f1[c1][k1] = 0.0;
+    }
+
+    __syncthreads();
+
+    for (int n=0; n<N; n++) {
+      int h1 = h0 + threadIdx.x;
+      int w1 = w0 + threadIdx.y;
+      if (within_bounds(h1, w1, H1, W1)) {
+        x2s[tid] = coords[b][n][0][h1][w1];
+        y2s[tid] = coords[b][n][1][h1][w1];
+      }
+
+      scalar_t dx = x2s[tid] - floor(x2s[tid]);
+      scalar_t dy = y2s[tid] - floor(y2s[tid]);
+
+      int rd = 2*r + 1;
+      for (int iy=0; iy<rd+1; iy++) {
+        for (int ix=0; ix<rd+1; ix++) {
+          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW) {
+            int k1 = k + tid;
+            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
+            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
+            int c2 = tid % CHANNEL_STRIDE;
+
+            auto fptr = fmap2[b][h2][w2];
+            if (within_bounds(h2, w2, H2, W2))
+              f2[c2][k1] = fptr[c+c2];
+            else
+              f2[c2][k1] = 0.0;
+          }
+
+          __syncthreads();
+
+          scalar_t s = 0.0;
+          for (int k=0; k<CHANNEL_STRIDE; k++)
+            s += f1[k][tid] * f2[k][tid];
+
+          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
+          int ix_ne = H1*W1*((iy-1) + rd*ix);
+          int ix_sw = H1*W1*(iy + rd*(ix-1));
+          int ix_se = H1*W1*(iy + rd*ix);
+
+          scalar_t nw = s * (dy) * (dx);
+          scalar_t ne = s * (dy) * (1-dx);
+          scalar_t sw = s * (1-dy) * (dx);
+          scalar_t se = s * (1-dy) * (1-dx);
+
+          scalar_t* corr_ptr = &corr[b][n][0][h1][w1];
+
+          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
+            *(corr_ptr + ix_nw) += nw;
+
+          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
+            *(corr_ptr + ix_ne) += ne;
+
+          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
+            *(corr_ptr + ix_sw) += sw;
+
+          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
+            *(corr_ptr + ix_se) += se;
+        }
+      }
+    }
+  }
+}
+
+
+template <typename scalar_t>
+__global__ void corr_backward_kernel(
+    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
+    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
+    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
+    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
+    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
+    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
+    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
+    int r)
+{
+
+  const int b = blockIdx.x;
+  const int h0 = blockIdx.y * blockDim.x;
+  const int w0 = blockIdx.z * blockDim.y;
+  const int tid = threadIdx.x * blockDim.y + threadIdx.y;
+
+  const int H1 = fmap1.size(1);
+  const int W1 = fmap1.size(2);
+  const int H2 = fmap2.size(1);
+  const int W2 = fmap2.size(2);
+  const int N = coords.size(1);
+  const int C = fmap1.size(3);
+
+  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
+  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
+
+  __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
+  __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
+
+  __shared__ scalar_t x2s[BLOCK_HW];
+  __shared__ scalar_t y2s[BLOCK_HW];
+
+  for (int c=0; c<C; c+=CHANNEL_STRIDE) {
+
+    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW) {
+      int k1 = k + tid;
+      int h1 = h0 + k1 / BLOCK_W;
+      int w1 = w0 + k1 % BLOCK_W;
+      int c1 = tid % CHANNEL_STRIDE;
+
+      auto fptr = fmap1[b][h1][w1];
+      if (within_bounds(h1, w1, H1, W1))
+        f1[c1][k1] = fptr[c+c1];
+      else
+        f1[c1][k1] = 0.0;
+
+      f1_grad[c1][k1] = 0.0;
+    }
+
+    __syncthreads();
+
+    for (int n=0; n<N; n++) {
+      int h1 = h0 + threadIdx.x;
+      int w1 = w0 + threadIdx.y;
+      if (within_bounds(h1, w1, H1, W1)) {
+        x2s[tid] = coords[b][n][0][h1][w1];
+        y2s[tid] = coords[b][n][1][h1][w1];
+      }
+
+      scalar_t dx = x2s[tid] - floor(x2s[tid]);
+      scalar_t dy = y2s[tid] - floor(y2s[tid]);
+
+      int rd = 2*r + 1;
+      for (int iy=0; iy<rd+1; iy++) {
+        for (int ix=0; ix<rd+1; ix++) {
+          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW) {
+            int k1 = k + tid;
+            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
+            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
+            int c2 = tid % CHANNEL_STRIDE;
+
+            auto fptr = fmap2[b][h2][w2];
+            if (within_bounds(h2, w2, H2, W2))
+              f2[c2][k1] = fptr[c+c2];
+            else
+              f2[c2][k1] = 0.0;
+
+            f2_grad[c2][k1] = 0.0;
+          }
+
+          __syncthreads();
+
+          const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
+          scalar_t g = 0.0;
+
+          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
+          int ix_ne = H1*W1*((iy-1) + rd*ix);
+          int ix_sw = H1*W1*(iy + rd*(ix-1));
+          int ix_se = H1*W1*(iy + rd*ix);
+
+          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
+            g += *(grad_ptr + ix_nw) * dy * dx;
+
+          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
+            g += *(grad_ptr + ix_ne) * dy * (1-dx);
+
+          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
+            g += *(grad_ptr + ix_sw) * (1-dy) * dx;
+
+          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
+            g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
+
+          for (int k=0; k<CHANNEL_STRIDE; k++) {
+            f1_grad[k][tid] += g * f2[k][tid];
+            f2_grad[k][tid] += g * f1[k][tid];
+          }
+
+          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW) {
+            int k1 = k + tid;
+            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
+            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
+            int c2 = tid % CHANNEL_STRIDE;
+
+            scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
+            if (within_bounds(h2, w2, H2, W2))
+              atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
+          }
+        }
+      }
+    }
+    __syncthreads();
+
+    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW) {
+      int k1 = k + tid;
+      int h1 = h0 + k1 / BLOCK_W;
+      int w1 = w0 + k1 % BLOCK_W;
+      int c1 = tid % CHANNEL_STRIDE;
+
+      scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
+      if (within_bounds(h1, w1, H1, W1))
+        fptr[c+c1] += f1_grad[c1][k1];
+    }
+  }
+}
+
+
+std::vector<torch::Tensor> corr_cuda_forward(
+  torch::Tensor fmap1,
+  torch::Tensor fmap2,
+  torch::Tensor coords,
+  int radius)
+{
+  const auto B = coords.size(0);
+  const auto N = coords.size(1);
+  const auto H = coords.size(2);
+  const auto W = coords.size(3);
+
+  const auto rd = 2 * radius + 1;
+  auto opts = fmap1.options();
+  auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
+
+  const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
+  const dim3 threads(BLOCK_H, BLOCK_W);
+
+  corr_forward_kernel<float><<<blocks, threads>>>(
+    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    radius);
+
+  return {corr};
+}
+
+std::vector<torch::Tensor> corr_cuda_backward(
+  torch::Tensor fmap1,
+  torch::Tensor fmap2,
+  torch::Tensor coords,
+  torch::Tensor corr_grad,
+  int radius)
+{
+  const auto B = coords.size(0);
+  const auto N = coords.size(1);
+
+  const auto H1 = fmap1.size(1);
+  const auto W1 = fmap1.size(2);
+  const auto H2 = fmap2.size(1);
+  const auto W2 = fmap2.size(2);
+  const auto C = fmap1.size(3);
+
+  auto opts = fmap1.options();
+  auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
+  auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
+  auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
+
+  const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
+  const dim3 threads(BLOCK_H, BLOCK_W);
+
+
+  corr_backward_kernel<float><<<blocks, threads>>>(
+    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    radius);
+
+  return {fmap1_grad, fmap2_grad, coords_grad};
+}
\ No newline at end of file
diff --git a/AIGVDet/alt_cuda_corr/setup.py b/AIGVDet/alt_cuda_corr/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0207ff285ffac4c8146c79d154f12416dbef48c
--- /dev/null
+++ b/AIGVDet/alt_cuda_corr/setup.py
@@ -0,0 +1,15 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
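+# Typical build step for this optional CUDA extension (standard PyTorch
+# cpp_extension workflow; requires a CUDA toolkit matching your PyTorch build):
+#   python setup.py install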
+setup(
+ name='correlation',
+ ext_modules=[
+ CUDAExtension('alt_cuda_corr',
+ sources=['correlation.cpp', 'correlation_kernel.cu'],
+ extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ })
+
diff --git a/AIGVDet/app.py b/AIGVDet/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad7a2d47360871d0e9e1e9eb09f7aa03de2fd9a0
--- /dev/null
+++ b/AIGVDet/app.py
@@ -0,0 +1,155 @@
+# app.py
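+# Example request once the server is running (port 8003 as in the Dockerfile;
+# the sample file name is illustrative):
+#   curl -X POST -F "file=@video.mp4" http://localhost:8003/predict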
+import os
+import shutil
+import uuid
+import time
+
+from fastapi import FastAPI, UploadFile, File, HTTPException, status
+from pydantic import BaseModel
+from run import RUN
+
+
+class PredictionResponse(BaseModel):
+ authenticity_assessment: str
+ verification_tools_methods: str
+ synthetic_type: str
+ other_artifacts: str
+
+app = FastAPI(
+ title="Video Authenticity API",
+ description="Detect authentic vs synthetic video using deepfake detector.",
+)
+
+def predict(input_path: str) -> dict:
+ result_data = {
+ "authenticity_assessment": "Error",
+ "verification_tools_methods": "Deepfake Detection Model (Optical Flow + Frame Analysis)",
+ "synthetic_type": "N/A",
+ "other_artifacts": "Analysis failed."
+ }
+
+ if not os.path.isfile(input_path):
+ raise FileNotFoundError(f"File not found: {input_path}")
+
+ video_name = os.path.basename(input_path)
+ video_id = os.path.splitext(video_name)[0]
+
+ folder_original = f"frame/{video_id}"
+ folder_optical = f"optical_result/{video_id}"
+
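+    # run.RUN takes argparse-style CLI arguments, so the options are passed as a
+    # flat list of flag/value strings.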
+ args = [
+ "--path", input_path,
+ "--folder_original_path", folder_original,
+ "--folder_optical_flow_path", folder_optical,
+ "--model_optical_flow_path", "checkpoints/optical.pth",
+ "--model_original_path", "checkpoints/original.pth",
+ ]
+
+ try:
+ start_time = time.perf_counter()
+
+ output = RUN(args)
+
+ elapsed = time.perf_counter() - start_time
+ print(f"⏱️ [run_AIGVDetection] Service call took {elapsed:.2f} seconds")
+
+ real_score = float(output.get("real_score", 0.0))
+ fake_score = float(output.get("fake_score", 0.0))
+
+ likely_authentic = real_score > fake_score
+
+ if likely_authentic:
+ assessment = f"REAL (Authentic) | Confidence: {real_score:.4f}"
+
+ analysis_text = (
+ "Our algorithms observed consistent and natural motion patterns across frames. "
+ "Inter-frame motion analysis indicates that objects maintain physical trajectories consistent with real-world recording, "
+ "without the jitter or warping artifacts typically associated with generative AI."
+ )
+ syn_type = "N/A"
+ else:
+ assessment = f"🤖 NOT REAL (Fake/Synthetic) | Confidence: {fake_score:.4f}"
+
+ analysis_text = (
+ "Our algorithms have detected asynchronous and inconsistent movement between frames. "
+ "Upon conducting inter-frame motion analysis, we observed that objects and details within the video "
+ "fail to maintain natural motion trajectories. These anomalies—such as sudden velocity shifts, "
+ "subtle per-frame distortions, or motion vectors that defy physical laws—are characteristic indicators "
+ "typically found in AI-generated videos."
+ )
+ syn_type = "Video Deepfake / AI Generated"
+
+ tools = "Deepfake Detector (Optical Flow + CNN Frame Analysis)"
+
+        artifacts = analysis_text
+
+ result_data = {
+ "authenticity_assessment": assessment,
+ "verification_tools_methods": tools,
+ "synthetic_type": syn_type,
+ "other_artifacts": artifacts
+ }
+
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ result_data["other_artifacts"] = f"Error during processing: {str(e)}"
+
+ finally:
+ for folder in [folder_original, folder_optical]:
+ try:
+ if os.path.exists(folder):
+ shutil.rmtree(folder)
+ except Exception as cleanup_error:
+ print(f"Error deleting folder {folder}: {cleanup_error}")
+
+ return result_data
+
+# --- API ENDPOINT ---
+@app.post("/predict", response_model=PredictionResponse)
+async def predict_endpoint(file: UploadFile = File(...)):
+ try:
+ for parent in ["frame", "optical_result", "uploads"]:
+ if os.path.exists(parent):
+ print(f"🧹 Cleaning folder: {parent}")
+ for item in os.listdir(parent):
+ path = os.path.join(parent, item)
+ try:
+ if os.path.isfile(path) or os.path.islink(path):
+ os.remove(path)
+ elif os.path.isdir(path):
+ shutil.rmtree(path)
+ except Exception as e:
+ print(f"⚠️ Cannot delete {path}: {e}")
+ else:
+ os.makedirs(parent)
+ print(f"📁 Created new folder: {parent}")
+
+ temp_filename = f"uploads_{uuid.uuid4().hex}_{file.filename}"
+ os.makedirs("uploads", exist_ok=True)
+ temp_filepath = os.path.join("uploads", temp_filename)
+
+ with open(temp_filepath, "wb") as buffer:
+ shutil.copyfileobj(file.file, buffer)
+
+ result = predict(temp_filepath)
+
+ if os.path.exists(temp_filepath):
+ os.remove(temp_filepath)
+
+ return result
+
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Prediction failed: {e}",
+ )
+
+if __name__ == "__main__":
+ import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8003)  # matches the port in the Dockerfile CMD
diff --git a/AIGVDet/checkpoints/optical.pth b/AIGVDet/checkpoints/optical.pth
new file mode 100644
index 0000000000000000000000000000000000000000..86370d59e8975f4d0c83dbccc4944646c7c7b468
--- /dev/null
+++ b/AIGVDet/checkpoints/optical.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23a167ba7adccb6421fd0eddd81fbfc69519d2a97a343dbf0d5da894b9893b19
+size 282581704
diff --git a/AIGVDet/checkpoints/original.pth b/AIGVDet/checkpoints/original.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f13a9ca5dfb5b7c919082c1b548f3939f127eaaf
--- /dev/null
+++ b/AIGVDet/checkpoints/original.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2df6f477a590f4b24b14eac9654b868a9f178312795df20749274a502d59bdd
+size 282581704
diff --git a/AIGVDet/core/__init__.py b/AIGVDet/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/AIGVDet/core/corr.py b/AIGVDet/core/corr.py
new file mode 100644
index 0000000000000000000000000000000000000000..e31800dcd9883ee355a15f2121e747c14fde2499
--- /dev/null
+++ b/AIGVDet/core/corr.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn.functional as F
+from .utils.utils import bilinear_sampler, coords_grid
+
+try:
+ from .. import alt_cuda_corr
+except ImportError:
+ # alt_cuda_corr is not compiled
+ pass
+
+
+class CorrBlock:
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+ self.num_levels = num_levels
+ self.radius = radius
+ self.corr_pyramid = []
+
+ # all pairs correlation
+ corr = CorrBlock.corr(fmap1, fmap2)
+
+ batch, h1, w1, dim, h2, w2 = corr.shape
+ corr = corr.reshape(batch*h1*w1, dim, h2, w2)
+
+ self.corr_pyramid.append(corr)
+ for i in range(self.num_levels-1):
+ corr = F.avg_pool2d(corr, 2, stride=2)
+ self.corr_pyramid.append(corr)
+
+ def __call__(self, coords):
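+        # For each pixel's current flow target, sample a (2r+1)x(2r+1) window of
+        # correlation values at every pyramid level and concatenate along channels.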
+ r = self.radius
+ coords = coords.permute(0, 2, 3, 1)
+ batch, h1, w1, _ = coords.shape
+
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corr = self.corr_pyramid[i]
+ dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
+ dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
+ delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
+
+ centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
+ delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ corr = bilinear_sampler(corr, coords_lvl)
+ corr = corr.view(batch, h1, w1, -1)
+ out_pyramid.append(corr)
+
+ out = torch.cat(out_pyramid, dim=-1)
+ return out.permute(0, 3, 1, 2).contiguous().float()
+
+ @staticmethod
+ def corr(fmap1, fmap2):
+ batch, dim, ht, wd = fmap1.shape
+ fmap1 = fmap1.view(batch, dim, ht*wd)
+ fmap2 = fmap2.view(batch, dim, ht*wd)
+
+ corr = torch.matmul(fmap1.transpose(1,2), fmap2)
+ corr = corr.view(batch, ht, wd, 1, ht, wd)
+ return corr / torch.sqrt(torch.tensor(dim).float())
+
+
+class AlternateCorrBlock:
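+    # Memory-efficient variant: correlation values are computed on demand by the
+    # compiled alt_cuda_corr extension instead of materializing the full volume.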
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+ self.num_levels = num_levels
+ self.radius = radius
+
+ self.pyramid = [(fmap1, fmap2)]
+ for i in range(self.num_levels):
+ fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
+ fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
+ self.pyramid.append((fmap1, fmap2))
+
+ def __call__(self, coords):
+ coords = coords.permute(0, 2, 3, 1)
+ B, H, W, _ = coords.shape
+ dim = self.pyramid[0][0].shape[1]
+
+ corr_list = []
+ for i in range(self.num_levels):
+ r = self.radius
+ fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
+ fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
+
+ coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
+ corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
+ corr_list.append(corr.squeeze(1))
+
+ corr = torch.stack(corr_list, dim=1)
+ corr = corr.reshape(B, -1, H, W)
+ return corr / torch.sqrt(torch.tensor(dim).float())
diff --git a/AIGVDet/core/datasets.py b/AIGVDet/core/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..3411fdacfb900024005e8997d07c600e963a95ca
--- /dev/null
+++ b/AIGVDet/core/datasets.py
@@ -0,0 +1,235 @@
+# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
+
+import numpy as np
+import torch
+import torch.utils.data as data
+import torch.nn.functional as F
+
+import os
+import math
+import random
+from glob import glob
+import os.path as osp
+
+from utils import frame_utils
+from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
+
+
+class FlowDataset(data.Dataset):
+ def __init__(self, aug_params=None, sparse=False):
+ self.augmentor = None
+ self.sparse = sparse
+ if aug_params is not None:
+ if sparse:
+ self.augmentor = SparseFlowAugmentor(**aug_params)
+ else:
+ self.augmentor = FlowAugmentor(**aug_params)
+
+ self.is_test = False
+ self.init_seed = False
+ self.flow_list = []
+ self.image_list = []
+ self.extra_info = []
+
+ def __getitem__(self, index):
+
+ if self.is_test:
+ img1 = frame_utils.read_gen(self.image_list[index][0])
+ img2 = frame_utils.read_gen(self.image_list[index][1])
+ img1 = np.array(img1).astype(np.uint8)[..., :3]
+ img2 = np.array(img2).astype(np.uint8)[..., :3]
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+ return img1, img2, self.extra_info[index]
+
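+        # Seed each DataLoader worker differently so random augmentations are not
+        # replicated across workers.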
+ if not self.init_seed:
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is not None:
+ torch.manual_seed(worker_info.id)
+ np.random.seed(worker_info.id)
+ random.seed(worker_info.id)
+ self.init_seed = True
+
+ index = index % len(self.image_list)
+ valid = None
+ if self.sparse:
+ flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
+ else:
+ flow = frame_utils.read_gen(self.flow_list[index])
+
+ img1 = frame_utils.read_gen(self.image_list[index][0])
+ img2 = frame_utils.read_gen(self.image_list[index][1])
+
+ flow = np.array(flow).astype(np.float32)
+ img1 = np.array(img1).astype(np.uint8)
+ img2 = np.array(img2).astype(np.uint8)
+
+ # grayscale images
+ if len(img1.shape) == 2:
+ img1 = np.tile(img1[...,None], (1, 1, 3))
+ img2 = np.tile(img2[...,None], (1, 1, 3))
+ else:
+ img1 = img1[..., :3]
+ img2 = img2[..., :3]
+
+ if self.augmentor is not None:
+ if self.sparse:
+ img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
+ else:
+ img1, img2, flow = self.augmentor(img1, img2, flow)
+
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+ flow = torch.from_numpy(flow).permute(2, 0, 1).float()
+
+ if valid is not None:
+ valid = torch.from_numpy(valid)
+ else:
+ valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
+
+ return img1, img2, flow, valid.float()
+
+
+ def __rmul__(self, v):
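+        # Integer multiplication (e.g. 100 * dataset) repeats the sample lists,
+        # which fetch_dataloader uses to oversample smaller datasets.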
+ self.flow_list = v * self.flow_list
+ self.image_list = v * self.image_list
+ return self
+
+ def __len__(self):
+ return len(self.image_list)
+
+
+class MpiSintel(FlowDataset):
+ def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
+ super(MpiSintel, self).__init__(aug_params)
+ flow_root = osp.join(root, split, 'flow')
+ image_root = osp.join(root, split, dstype)
+
+ if split == 'test':
+ self.is_test = True
+
+ for scene in os.listdir(image_root):
+ image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
+ for i in range(len(image_list)-1):
+ self.image_list += [ [image_list[i], image_list[i+1]] ]
+ self.extra_info += [ (scene, i) ] # scene and frame_id
+
+ if split != 'test':
+ self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
+
+
+class FlyingChairs(FlowDataset):
+ def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
+ super(FlyingChairs, self).__init__(aug_params)
+
+ images = sorted(glob(osp.join(root, '*.ppm')))
+ flows = sorted(glob(osp.join(root, '*.flo')))
+ assert (len(images)//2 == len(flows))
+
+ split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
+ for i in range(len(flows)):
+ xid = split_list[i]
+ if (split=='training' and xid==1) or (split=='validation' and xid==2):
+ self.flow_list += [ flows[i] ]
+ self.image_list += [ [images[2*i], images[2*i+1]] ]
+
+
+class FlyingThings3D(FlowDataset):
+ def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
+ super(FlyingThings3D, self).__init__(aug_params)
+
+ for cam in ['left']:
+ for direction in ['into_future', 'into_past']:
+ image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
+ image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
+
+ flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
+ flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
+
+ for idir, fdir in zip(image_dirs, flow_dirs):
+ images = sorted(glob(osp.join(idir, '*.png')) )
+ flows = sorted(glob(osp.join(fdir, '*.pfm')) )
+ for i in range(len(flows)-1):
+ if direction == 'into_future':
+ self.image_list += [ [images[i], images[i+1]] ]
+ self.flow_list += [ flows[i] ]
+ elif direction == 'into_past':
+ self.image_list += [ [images[i+1], images[i]] ]
+ self.flow_list += [ flows[i+1] ]
+
+
+class KITTI(FlowDataset):
+ def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
+ super(KITTI, self).__init__(aug_params, sparse=True)
+ if split == 'testing':
+ self.is_test = True
+
+ root = osp.join(root, split)
+ images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
+ images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
+
+ for img1, img2 in zip(images1, images2):
+ frame_id = img1.split('/')[-1]
+ self.extra_info += [ [frame_id] ]
+ self.image_list += [ [img1, img2] ]
+
+ if split == 'training':
+ self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
+
+
+class HD1K(FlowDataset):
+ def __init__(self, aug_params=None, root='datasets/HD1k'):
+ super(HD1K, self).__init__(aug_params, sparse=True)
+
+ seq_ix = 0
+ while 1:
+ flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
+ images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
+
+ if len(flows) == 0:
+ break
+
+ for i in range(len(flows)-1):
+ self.flow_list += [flows[i]]
+ self.image_list += [ [images[i], images[i+1]] ]
+
+ seq_ix += 1
+
+
+def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
+ """ Create the data loader for the corresponding trainign set """
+
+ if args.stage == 'chairs':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
+ train_dataset = FlyingChairs(aug_params, split='training')
+
+ elif args.stage == 'things':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
+ clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+ final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
+ train_dataset = clean_dataset + final_dataset
+
+ elif args.stage == 'sintel':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
+ things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+ sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
+ sintel_final = MpiSintel(aug_params, split='training', dstype='final')
+
+ if TRAIN_DS == 'C+T+K+S+H':
+ kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
+ hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
+ train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
+
+ elif TRAIN_DS == 'C+T+K/S':
+ train_dataset = 100*sintel_clean + 100*sintel_final + things
+
+ elif args.stage == 'kitti':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
+ train_dataset = KITTI(aug_params, split='training')
+
+ train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
+ pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
+
+ print('Training with %d image pairs' % len(train_dataset))
+ return train_loader
+
diff --git a/AIGVDet/core/extractor.py b/AIGVDet/core/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9c759d1243d4694e8656c2f6f8a37e53edd009
--- /dev/null
+++ b/AIGVDet/core/extractor.py
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+
+
+class BottleneckBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
+ self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes//4)
+ self.norm2 = nn.BatchNorm2d(planes//4)
+ self.norm3 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm4 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes//4)
+ self.norm2 = nn.InstanceNorm2d(planes//4)
+ self.norm3 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm4 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ self.norm3 = nn.Sequential()
+ if not stride == 1:
+ self.norm4 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
+
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+ y = self.relu(self.norm3(self.conv3(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+class BasicEncoder(nn.Module):
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
+ super(BasicEncoder, self).__init__()
+ self.norm_fn = norm_fn
+
+ if self.norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+
+ elif self.norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(64)
+
+ elif self.norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(64)
+
+ elif self.norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.in_planes = 64
+ self.layer1 = self._make_layer(64, stride=1)
+ self.layer2 = self._make_layer(96, stride=2)
+ self.layer3 = self._make_layer(128, stride=2)
+
+ # output convolution
+ self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
+
+ self.dropout = None
+ if dropout > 0:
+ self.dropout = nn.Dropout2d(p=dropout)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+
+ def forward(self, x):
+
+ # if input is list, combine batch dimension
+ is_list = isinstance(x, tuple) or isinstance(x, list)
+ if is_list:
+ batch_dim = x[0].shape[0]
+ x = torch.cat(x, dim=0)
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+
+ x = self.conv2(x)
+
+ if self.training and self.dropout is not None:
+ x = self.dropout(x)
+
+ if is_list:
+ x = torch.split(x, [batch_dim, batch_dim], dim=0)
+
+ return x
+
+
+class SmallEncoder(nn.Module):
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
+ super(SmallEncoder, self).__init__()
+ self.norm_fn = norm_fn
+
+ if self.norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
+
+ elif self.norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(32)
+
+ elif self.norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(32)
+
+ elif self.norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.in_planes = 32
+ self.layer1 = self._make_layer(32, stride=1)
+ self.layer2 = self._make_layer(64, stride=2)
+ self.layer3 = self._make_layer(96, stride=2)
+
+ self.dropout = None
+ if dropout > 0:
+ self.dropout = nn.Dropout2d(p=dropout)
+
+ self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+
+ def forward(self, x):
+
+ # if input is list, combine batch dimension
+ is_list = isinstance(x, tuple) or isinstance(x, list)
+ if is_list:
+ batch_dim = x[0].shape[0]
+ x = torch.cat(x, dim=0)
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.conv2(x)
+
+ if self.training and self.dropout is not None:
+ x = self.dropout(x)
+
+ if is_list:
+ x = torch.split(x, [batch_dim, batch_dim], dim=0)
+
+ return x
diff --git a/AIGVDet/core/raft.py b/AIGVDet/core/raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..14b850058e4df3de51b9efa7031e22f3be6d9516
--- /dev/null
+++ b/AIGVDet/core/raft.py
@@ -0,0 +1,144 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .update import BasicUpdateBlock, SmallUpdateBlock
+from .extractor import BasicEncoder, SmallEncoder
+from .corr import CorrBlock, AlternateCorrBlock
+from .utils.utils import bilinear_sampler, coords_grid, upflow8
+
+try:
+ autocast = torch.cuda.amp.autocast
+except:
+ # dummy autocast for PyTorch < 1.6
+ class autocast:
+ def __init__(self, enabled):
+ pass
+ def __enter__(self):
+ pass
+ def __exit__(self, *args):
+ pass
+
+
+class RAFT(nn.Module):
+ def __init__(self, args):
+ super(RAFT, self).__init__()
+ self.args = args
+
+ if args.small:
+ self.hidden_dim = hdim = 96
+ self.context_dim = cdim = 64
+ args.corr_levels = 4
+ args.corr_radius = 3
+
+ else:
+ self.hidden_dim = hdim = 128
+ self.context_dim = cdim = 128
+ args.corr_levels = 4
+ args.corr_radius = 4
+
+ if 'dropout' not in self.args:
+ self.args.dropout = 0
+
+ if 'alternate_corr' not in self.args:
+ self.args.alternate_corr = False
+
+ # feature network, context network, and update block
+ if args.small:
+ self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
+ self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
+ self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
+
+ else:
+ self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
+ self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
+ self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
+
+ def freeze_bn(self):
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+
+ def initialize_flow(self, img):
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+ N, C, H, W = img.shape
+ coords0 = coords_grid(N, H//8, W//8, device=img.device)
+ coords1 = coords_grid(N, H//8, W//8, device=img.device)
+
+ # optical flow computed as difference: flow = coords1 - coords0
+ return coords0, coords1
+
+ def upsample_flow(self, flow, mask):
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
+ N, _, H, W = flow.shape
+ mask = mask.view(N, 1, 9, 8, 8, H, W)
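+        # 9 convex-combination weights (one per 3x3 coarse neighbor) for each of
+        # the 8x8 fine positions inside every coarse cell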
+ mask = torch.softmax(mask, dim=2)
+
+ up_flow = F.unfold(8 * flow, [3,3], padding=1)
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+ up_flow = torch.sum(mask * up_flow, dim=2)
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+ return up_flow.reshape(N, 2, 8*H, 8*W)
+
+
+ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
+ """ Estimate optical flow between pair of frames """
+
+ image1 = 2 * (image1 / 255.0) - 1.0
+ image2 = 2 * (image2 / 255.0) - 1.0
+
+ image1 = image1.contiguous()
+ image2 = image2.contiguous()
+
+ hdim = self.hidden_dim
+ cdim = self.context_dim
+
+ # run the feature network
+ with autocast(enabled=self.args.mixed_precision):
+ fmap1, fmap2 = self.fnet([image1, image2])
+
+ fmap1 = fmap1.float()
+ fmap2 = fmap2.float()
+ if self.args.alternate_corr:
+ corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
+ else:
+ corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
+
+ # run the context network
+ with autocast(enabled=self.args.mixed_precision):
+ cnet = self.cnet(image1)
+ net, inp = torch.split(cnet, [hdim, cdim], dim=1)
+ net = torch.tanh(net)
+ inp = torch.relu(inp)
+
+ coords0, coords1 = self.initialize_flow(image1)
+
+ if flow_init is not None:
+ coords1 = coords1 + flow_init
+
+ flow_predictions = []
+ for itr in range(iters):
+ coords1 = coords1.detach()
+ corr = corr_fn(coords1) # index correlation volume
+
+ flow = coords1 - coords0
+ with autocast(enabled=self.args.mixed_precision):
+ net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
+
+ # F(t+1) = F(t) + \Delta(t)
+ coords1 = coords1 + delta_flow
+
+ # upsample predictions
+ if up_mask is None:
+ flow_up = upflow8(coords1 - coords0)
+ else:
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+
+ flow_predictions.append(flow_up)
+
+ if test_mode:
+ return coords1 - coords0, flow_up
+
+ return flow_predictions
diff --git a/AIGVDet/core/update.py b/AIGVDet/core/update.py
new file mode 100644
index 0000000000000000000000000000000000000000..f940497f9b5eb1c12091574fe9a0223a1b196d50
--- /dev/null
+++ b/AIGVDet/core/update.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FlowHead(nn.Module):
+ def __init__(self, input_dim=128, hidden_dim=256):
+ super(FlowHead, self).__init__()
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+ self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ return self.conv2(self.relu(self.conv1(x)))
+
+class ConvGRU(nn.Module):
+ def __init__(self, hidden_dim=128, input_dim=192+128):
+ super(ConvGRU, self).__init__()
+ self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+ self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+ self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+
+ def forward(self, h, x):
+ hx = torch.cat([h, x], dim=1)
+
+ z = torch.sigmoid(self.convz(hx))
+ r = torch.sigmoid(self.convr(hx))
+ q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
+
+ h = (1-z) * h + z * q
+ return h
+
+class SepConvGRU(nn.Module):
+ def __init__(self, hidden_dim=128, input_dim=192+128):
+ super(SepConvGRU, self).__init__()
+ self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+ self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+ self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+
+ self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+ self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+ self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+
+
+ def forward(self, h, x):
+ # horizontal
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz1(hx))
+ r = torch.sigmoid(self.convr1(hx))
+ q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
+ h = (1-z) * h + z * q
+
+ # vertical
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz2(hx))
+ r = torch.sigmoid(self.convr2(hx))
+ q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
+ h = (1-z) * h + z * q
+
+ return h
+
+class SmallMotionEncoder(nn.Module):
+ def __init__(self, args):
+ super(SmallMotionEncoder, self).__init__()
+ cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
+ self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
+ self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
+ self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
+ self.conv = nn.Conv2d(128, 80, 3, padding=1)
+
+ def forward(self, flow, corr):
+ cor = F.relu(self.convc1(corr))
+ flo = F.relu(self.convf1(flow))
+ flo = F.relu(self.convf2(flo))
+ cor_flo = torch.cat([cor, flo], dim=1)
+ out = F.relu(self.conv(cor_flo))
+ return torch.cat([out, flow], dim=1)
+
+class BasicMotionEncoder(nn.Module):
+ def __init__(self, args):
+ super(BasicMotionEncoder, self).__init__()
+ cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
+ self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
+ self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
+ self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
+ self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
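+        # 128-2 output channels: the 2-channel flow is concatenated back on below,
+        # giving 128 motion-feature channels in total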
+ self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
+
+ def forward(self, flow, corr):
+ cor = F.relu(self.convc1(corr))
+ cor = F.relu(self.convc2(cor))
+ flo = F.relu(self.convf1(flow))
+ flo = F.relu(self.convf2(flo))
+
+ cor_flo = torch.cat([cor, flo], dim=1)
+ out = F.relu(self.conv(cor_flo))
+ return torch.cat([out, flow], dim=1)
+
+class SmallUpdateBlock(nn.Module):
+ def __init__(self, args, hidden_dim=96):
+ super(SmallUpdateBlock, self).__init__()
+ self.encoder = SmallMotionEncoder(args)
+ self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
+
+ def forward(self, net, inp, corr, flow):
+ motion_features = self.encoder(flow, corr)
+ inp = torch.cat([inp, motion_features], dim=1)
+ net = self.gru(net, inp)
+ delta_flow = self.flow_head(net)
+
+ return net, None, delta_flow
+
+class BasicUpdateBlock(nn.Module):
+ def __init__(self, args, hidden_dim=128, input_dim=128):
+ super(BasicUpdateBlock, self).__init__()
+ self.args = args
+ self.encoder = BasicMotionEncoder(args)
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
+
+ self.mask = nn.Sequential(
+ nn.Conv2d(128, 256, 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 64*9, 1, padding=0))
+
+ def forward(self, net, inp, corr, flow, upsample=True):
+ motion_features = self.encoder(flow, corr)
+ inp = torch.cat([inp, motion_features], dim=1)
+
+ net = self.gru(net, inp)
+ delta_flow = self.flow_head(net)
+
+        # scale mask to balance gradients
+ mask = .25 * self.mask(net)
+ return net, mask, delta_flow
+
+
+
diff --git a/AIGVDet/core/utils/__init__.py b/AIGVDet/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/AIGVDet/core/utils/augmentor.py b/AIGVDet/core/utils/augmentor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e81c4f2b5c16c31c0ae236d744f299d430228a04
--- /dev/null
+++ b/AIGVDet/core/utils/augmentor.py
@@ -0,0 +1,246 @@
+import numpy as np
+import random
+import math
+from PIL import Image
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+import torch
+from torchvision.transforms import ColorJitter
+import torch.nn.functional as F
+
+
+class FlowAugmentor:
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
+
+ # spatial augmentation params
+ self.crop_size = crop_size
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.spatial_aug_prob = 0.8
+ self.stretch_prob = 0.8
+ self.max_stretch = 0.2
+
+ # flip augmentation params
+ self.do_flip = do_flip
+ self.h_flip_prob = 0.5
+ self.v_flip_prob = 0.1
+
+ # photometric augmentation params
+ self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
+ self.asymmetric_color_aug_prob = 0.2
+ self.eraser_aug_prob = 0.5
+
+ def color_transform(self, img1, img2):
+ """ Photometric augmentation """
+
+ # asymmetric
+ if np.random.rand() < self.asymmetric_color_aug_prob:
+ img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
+ img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
+
+ # symmetric
+ else:
+ image_stack = np.concatenate([img1, img2], axis=0)
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ img1, img2 = np.split(image_stack, 2, axis=0)
+
+ return img1, img2
+
+ def eraser_transform(self, img1, img2, bounds=[50, 100]):
+ """ Occlusion augmentation """
+
+ ht, wd = img1.shape[:2]
+ if np.random.rand() < self.eraser_aug_prob:
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+ for _ in range(np.random.randint(1, 3)):
+ x0 = np.random.randint(0, wd)
+ y0 = np.random.randint(0, ht)
+ dx = np.random.randint(bounds[0], bounds[1])
+ dy = np.random.randint(bounds[0], bounds[1])
+ img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+
+ return img1, img2
+
+ def spatial_transform(self, img1, img2, flow):
+ # randomly sample scale
+ ht, wd = img1.shape[:2]
+ min_scale = np.maximum(
+ (self.crop_size[0] + 8) / float(ht),
+ (self.crop_size[1] + 8) / float(wd))
+
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+ scale_x = scale
+ scale_y = scale
+ if np.random.rand() < self.stretch_prob:
+ scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+ scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+
+ scale_x = np.clip(scale_x, min_scale, None)
+ scale_y = np.clip(scale_y, min_scale, None)
+
+ if np.random.rand() < self.spatial_aug_prob:
+ # rescale the images
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow = flow * [scale_x, scale_y]
+
+ if self.do_flip:
+ if np.random.rand() < self.h_flip_prob: # h-flip
+ img1 = img1[:, ::-1]
+ img2 = img2[:, ::-1]
+ flow = flow[:, ::-1] * [-1.0, 1.0]
+
+ if np.random.rand() < self.v_flip_prob: # v-flip
+ img1 = img1[::-1, :]
+ img2 = img2[::-1, :]
+ flow = flow[::-1, :] * [1.0, -1.0]
+
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
+ x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
+
+ img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+
+ return img1, img2, flow
+
+ def __call__(self, img1, img2, flow):
+ img1, img2 = self.color_transform(img1, img2)
+ img1, img2 = self.eraser_transform(img1, img2)
+ img1, img2, flow = self.spatial_transform(img1, img2, flow)
+
+ img1 = np.ascontiguousarray(img1)
+ img2 = np.ascontiguousarray(img2)
+ flow = np.ascontiguousarray(flow)
+
+ return img1, img2, flow
+
+class SparseFlowAugmentor:
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
+ # spatial augmentation params
+ self.crop_size = crop_size
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.spatial_aug_prob = 0.8
+ self.stretch_prob = 0.8
+ self.max_stretch = 0.2
+
+ # flip augmentation params
+ self.do_flip = do_flip
+ self.h_flip_prob = 0.5
+ self.v_flip_prob = 0.1
+
+ # photometric augmentation params
+ self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
+ self.asymmetric_color_aug_prob = 0.2
+ self.eraser_aug_prob = 0.5
+
+ def color_transform(self, img1, img2):
+ image_stack = np.concatenate([img1, img2], axis=0)
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ img1, img2 = np.split(image_stack, 2, axis=0)
+ return img1, img2
+
+ def eraser_transform(self, img1, img2):
+ ht, wd = img1.shape[:2]
+ if np.random.rand() < self.eraser_aug_prob:
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+ for _ in range(np.random.randint(1, 3)):
+ x0 = np.random.randint(0, wd)
+ y0 = np.random.randint(0, ht)
+ dx = np.random.randint(50, 100)
+ dy = np.random.randint(50, 100)
+ img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+
+ return img1, img2
+
+ def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
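+        # Bilinear resizing would blend valid and invalid flow vectors, so the
+        # valid vectors are rescaled and re-scattered onto the resized grid instead.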
+ ht, wd = flow.shape[:2]
+ coords = np.meshgrid(np.arange(wd), np.arange(ht))
+ coords = np.stack(coords, axis=-1)
+
+ coords = coords.reshape(-1, 2).astype(np.float32)
+ flow = flow.reshape(-1, 2).astype(np.float32)
+ valid = valid.reshape(-1).astype(np.float32)
+
+ coords0 = coords[valid>=1]
+ flow0 = flow[valid>=1]
+
+ ht1 = int(round(ht * fy))
+ wd1 = int(round(wd * fx))
+
+ coords1 = coords0 * [fx, fy]
+ flow1 = flow0 * [fx, fy]
+
+ xx = np.round(coords1[:,0]).astype(np.int32)
+ yy = np.round(coords1[:,1]).astype(np.int32)
+
+ v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
+ xx = xx[v]
+ yy = yy[v]
+ flow1 = flow1[v]
+
+ flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
+ valid_img = np.zeros([ht1, wd1], dtype=np.int32)
+
+ flow_img[yy, xx] = flow1
+ valid_img[yy, xx] = 1
+
+ return flow_img, valid_img
+
+ def spatial_transform(self, img1, img2, flow, valid):
+ # randomly sample scale
+
+ ht, wd = img1.shape[:2]
+ min_scale = np.maximum(
+ (self.crop_size[0] + 1) / float(ht),
+ (self.crop_size[1] + 1) / float(wd))
+
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+ scale_x = np.clip(scale, min_scale, None)
+ scale_y = np.clip(scale, min_scale, None)
+
+ if np.random.rand() < self.spatial_aug_prob:
+ # rescale the images
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
+
+ if self.do_flip:
+ if np.random.rand() < 0.5: # h-flip
+ img1 = img1[:, ::-1]
+ img2 = img2[:, ::-1]
+ flow = flow[:, ::-1] * [-1.0, 1.0]
+ valid = valid[:, ::-1]
+
+ margin_y = 20
+ margin_x = 50
+
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
+ x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
+
+ y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
+ x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
+
+ img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ return img1, img2, flow, valid
+
+
+ def __call__(self, img1, img2, flow, valid):
+ img1, img2 = self.color_transform(img1, img2)
+ img1, img2 = self.eraser_transform(img1, img2)
+ img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
+
+ img1 = np.ascontiguousarray(img1)
+ img2 = np.ascontiguousarray(img2)
+ flow = np.ascontiguousarray(flow)
+ valid = np.ascontiguousarray(valid)
+
+ return img1, img2, flow, valid
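+
+
+if __name__ == "__main__":
+ # Smoke-test sketch (an illustrative addition; the shapes below are arbitrary
+ # assumptions, not values taken from the original training pipeline).
+ img1 = np.random.randint(0, 255, (376, 1242, 3), dtype=np.uint8)
+ img2 = np.random.randint(0, 255, (376, 1242, 3), dtype=np.uint8)
+ flow = np.random.randn(376, 1242, 2).astype(np.float32)
+ valid = (np.random.rand(376, 1242) > 0.5).astype(np.float32)
+ aug = SparseFlowAugmentor(crop_size=(368, 768))
+ img1, img2, flow, valid = aug(img1, img2, flow, valid)
+ print(img1.shape, img2.shape, flow.shape, valid.shape)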
diff --git a/AIGVDet/core/utils/flow_viz.py b/AIGVDet/core/utils/flow_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcee65e89b91b07ee0496aeb4c7e7436abf99641
--- /dev/null
+++ b/AIGVDet/core/utils/flow_viz.py
@@ -0,0 +1,132 @@
+# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
+
+
+# MIT License
+#
+# Copyright (c) 2018 Tom Runia
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to conditions.
+#
+# Author: Tom Runia
+# Date Created: 2018-08-03
+
+import numpy as np
+
+def make_colorwheel():
+ """
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+
+ Code follows the original C++ source code of Daniel Scharstein.
+ Code follows the Matlab source code of Deqing Sun.
+
+ Returns:
+ np.ndarray: Color wheel
+ """
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
+ col = col+RY
+ # YG
+ colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
+ colorwheel[col:col+YG, 1] = 255
+ col = col+YG
+ # GC
+ colorwheel[col:col+GC, 1] = 255
+ colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
+ col = col+GC
+ # CB
+ colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
+ colorwheel[col:col+CB, 2] = 255
+ col = col+CB
+ # BM
+ colorwheel[col:col+BM, 2] = 255
+ colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
+ col = col+BM
+ # MR
+ colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
+ colorwheel[col:col+MR, 0] = 255
+ return colorwheel
+
+
+def flow_uv_to_colors(u, v, convert_to_bgr=False):
+ """
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+
+ Args:
+ u (np.ndarray): Input horizontal flow of shape [H,W]
+ v (np.ndarray): Input vertical flow of shape [H,W]
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+ colorwheel = make_colorwheel() # shape [55x3]
+ ncols = colorwheel.shape[0]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ a = np.arctan2(-v, -u)/np.pi
+ fk = (a+1) / 2*(ncols-1)
+ k0 = np.floor(fk).astype(np.int32)
+ k1 = k0 + 1
+ k1[k1 == ncols] = 0
+ f = fk - k0
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:,i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1-f)*col0 + f*col1
+ idx = (rad <= 1)
+ col[idx] = 1 - rad[idx] * (1-col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2-i if convert_to_bgr else i
+ flow_image[:,:,ch_idx] = np.floor(255 * col)
+ return flow_image
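+
+# Example sketch (illustrative): a constant rightward flow maps every pixel to
+# a single hue. Inputs should already be normalized to magnitude <= 1, which
+# flow_to_image below takes care of:
+#
+# u = np.ones((2, 2), np.float32)
+# v = np.zeros((2, 2), np.float32)
+# rgb = flow_uv_to_colors(u, v) # uint8 image of shape [2, 2, 3]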
+
+
+def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
+ """
+ Expects a two dimensional flow image of shape [H,W,2].
+
+ Args:
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
+ u = flow_uv[:,:,0]
+ v = flow_uv[:,:,1]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ rad_max = np.max(rad)
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+ return flow_uv_to_colors(u, v, convert_to_bgr)
\ No newline at end of file
diff --git a/AIGVDet/core/utils/frame_utils.py b/AIGVDet/core/utils/frame_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c491135efaffc25bd61ec3ecde99d236f5deb12
--- /dev/null
+++ b/AIGVDet/core/utils/frame_utils.py
@@ -0,0 +1,137 @@
+import numpy as np
+from PIL import Image
+from os.path import splitext
+import re
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+TAG_CHAR = np.array([202021.25], np.float32)
+
+def readFlow(fn):
+ """ Read .flo file in Middlebury format"""
+ # Code adapted from:
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+ # WARNING: this will work on little-endian architectures (eg Intel x86) only!
+ # print 'fn = %s'%(fn)
+ with open(fn, 'rb') as f:
+ magic = np.fromfile(f, np.float32, count=1)
+ if 202021.25 != magic:
+ print('Magic number incorrect. Invalid .flo file')
+ return None
+ else:
+ w = np.fromfile(f, np.int32, count=1)
+ h = np.fromfile(f, np.int32, count=1)
+ # print 'Reading %d x %d flo file\n' % (w, h)
+ data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
+ # Reshape data into 3D array (columns, rows, bands)
+ # The reshape here is for visualization, the original code is (w,h,2)
+ return np.resize(data, (int(h), int(w), 2))
+
+def readPFM(file):
+ file = open(file, 'rb')
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header == b'PF':
+ color = True
+ elif header == b'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ scale = float(file.readline().rstrip())
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data
+
+def writeFlow(filename,uv,v=None):
+ """ Write optical flow to file.
+
+ If v is None, uv is assumed to contain both u and v channels,
+ stacked in depth.
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
+ """
+ nBands = 2
+
+ if v is None:
+ assert(uv.ndim == 3)
+ assert(uv.shape[2] == 2)
+ u = uv[:,:,0]
+ v = uv[:,:,1]
+ else:
+ u = uv
+
+ assert(u.shape == v.shape)
+ height,width = u.shape
+ f = open(filename,'wb')
+ # write the header
+ f.write(TAG_CHAR)
+ np.array(width).astype(np.int32).tofile(f)
+ np.array(height).astype(np.int32).tofile(f)
+ # arrange into matrix form
+ tmp = np.zeros((height, width*nBands))
+ tmp[:,np.arange(width)*2] = u
+ tmp[:,np.arange(width)*2 + 1] = v
+ tmp.astype(np.float32).tofile(f)
+ f.close()
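+
+# Round-trip sketch (illustrative; 'example.flo' is a hypothetical path):
+# flow = np.random.randn(4, 5, 2).astype(np.float32)
+# writeFlow('example.flo', flow)
+# assert np.allclose(readFlow('example.flo'), flow)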
+
+
+def readFlowKITTI(filename):
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
+ flow = flow[:,:,::-1].astype(np.float32)
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
+ flow = (flow - 2**15) / 64.0
+ return flow, valid
+
+def readDispKITTI(filename):
+ disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
+ valid = disp > 0.0
+ flow = np.stack([-disp, np.zeros_like(disp)], -1)
+ return flow, valid
+
+
+def writeFlowKITTI(filename, uv):
+ uv = 64.0 * uv + 2**15
+ valid = np.ones([uv.shape[0], uv.shape[1], 1])
+ uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
+ cv2.imwrite(filename, uv[..., ::-1])
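+
+# KITTI encodes flow as a 16-bit PNG with u/v stored as 64*flow + 2**15 and a
+# validity mask in the third channel, so readFlowKITTI/writeFlowKITTI invert
+# each other up to the 1/64 px quantization step. A hypothetical round trip:
+# writeFlowKITTI('flow.png', flow) # flow: [H, W, 2] float32
+# flow2, valid = readFlowKITTI('flow.png')
+# assert np.allclose(flow, flow2, atol=1.0/64)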
+
+
+def read_gen(file_name, pil=False):
+ ext = splitext(file_name)[-1]
+ if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+ return Image.open(file_name)
+ elif ext == '.bin' or ext == '.raw':
+ return np.load(file_name)
+ elif ext == '.flo':
+ return readFlow(file_name).astype(np.float32)
+ elif ext == '.pfm':
+ flow = readPFM(file_name).astype(np.float32)
+ if len(flow.shape) == 2:
+ return flow
+ else:
+ return flow[:, :, :-1]
+ return []
\ No newline at end of file
diff --git a/AIGVDet/core/utils/utils.py b/AIGVDet/core/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..741ccfe4d0d778c3199c586d368edc2882d4fff8
--- /dev/null
+++ b/AIGVDet/core/utils/utils.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy import interpolate
+
+
+class InputPadder:
+ """ Pads images such that dimensions are divisible by 8 """
+ def __init__(self, dims, mode='sintel'):
+ self.ht, self.wd = dims[-2:]
+ pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
+ pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
+ if mode == 'sintel':
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
+ else:
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
+
+ def pad(self, *inputs):
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+ def unpad(self,x):
+ ht, wd = x.shape[-2:]
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
+ return x[..., c[0]:c[1], c[2]:c[3]]
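+
+# Typical use (sketch; `image1`, `image2`, `flow_pred` are hypothetical
+# [B, 3, H, W] / [B, 2, H, W] tensors and `model` is some flow network):
+# padder = InputPadder(image1.shape)
+# image1, image2 = padder.pad(image1, image2)
+# flow_pred = model(image1, image2)
+# flow = padder.unpad(flow_pred)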
+
+def forward_interpolate(flow):
+ flow = flow.detach().cpu().numpy()
+ dx, dy = flow[0], flow[1]
+
+ ht, wd = dx.shape
+ x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
+
+ x1 = x0 + dx
+ y1 = y0 + dy
+
+ x1 = x1.reshape(-1)
+ y1 = y1.reshape(-1)
+ dx = dx.reshape(-1)
+ dy = dy.reshape(-1)
+
+ valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
+ x1 = x1[valid]
+ y1 = y1[valid]
+ dx = dx[valid]
+ dy = dy[valid]
+
+ flow_x = interpolate.griddata(
+ (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
+
+ flow_y = interpolate.griddata(
+ (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
+
+ flow = np.stack([flow_x, flow_y], axis=0)
+ return torch.from_numpy(flow).float()
+
+
+def bilinear_sampler(img, coords, mode='bilinear', mask=False):
+ """ Wrapper for grid_sample, uses pixel coordinates """
+ H, W = img.shape[-2:]
+ xgrid, ygrid = coords.split([1,1], dim=-1)
+ xgrid = 2*xgrid/(W-1) - 1
+ ygrid = 2*ygrid/(H-1) - 1
+
+ grid = torch.cat([xgrid, ygrid], dim=-1)
+ img = F.grid_sample(img, grid, align_corners=True)
+
+ if mask:
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+ return img, mask.float()
+
+ return img
+
+
+def coords_grid(batch, ht, wd, device):
+ coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1)
+
+
+def upflow8(flow, mode='bilinear'):
+ new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
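+
+
+if __name__ == "__main__":
+ # Smoke-test sketch (an illustrative addition): sampling an image at its own
+ # coordinate grid is an identity warp, so the output matches the input.
+ img = torch.randn(1, 3, 16, 24)
+ coords = coords_grid(1, 16, 24, device="cpu").permute(0, 2, 3, 1)
+ out = bilinear_sampler(img, coords)
+ assert torch.allclose(out, img, atol=1e-5)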
diff --git a/AIGVDet/core/utils1/config.py b/AIGVDet/core/utils1/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..577968a6d4455019617e0104b7ca358e6c6a61b4
--- /dev/null
+++ b/AIGVDet/core/utils1/config.py
@@ -0,0 +1,156 @@
+import argparse
+import os
+import sys
+from abc import ABC
+from typing import Type
+
+
+class DefaultConfigs(ABC):
+ ####### base setting ######
+ gpus = [0]
+ seed = 3407
+ arch = "resnet50"
+ datasets = ["zhaolian_train"]
+ datasets_test = ["adm_res_abs_ddim20s"]
+ mode = "binary"
+ class_bal = False
+ batch_size = 64
+ loadSize = 256
+ cropSize = 224
+ epoch = "latest"
+ num_workers = 20
+ serial_batches = False
+ isTrain = True
+
+ # data augmentation
+ rz_interp = ["bilinear"]
+ # blur_prob = 0.0
+ blur_prob = 0.1
+ blur_sig = [0.5]
+ # jpg_prob = 0.0
+ jpg_prob = 0.1
+ jpg_method = ["cv2"]
+ jpg_qual = [75]
+ gray_prob = 0.0
+ aug_resize = True
+ aug_crop = True
+ aug_flip = True
+ aug_norm = True
+
+ ####### train setting ######
+ warmup = False
+ # warmup = True
+ warmup_epoch = 3
+ earlystop = True
+ earlystop_epoch = 5
+ optim = "adam"
+ new_optim = False
+ loss_freq = 400
+ save_latest_freq = 2000
+ save_epoch_freq = 20
+ continue_train = False
+ epoch_count = 1
+ last_epoch = -1
+ nepoch = 400
+ beta1 = 0.9
+ lr = 0.0001
+ init_type = "normal"
+ init_gain = 0.02
+ pretrained = True
+
+ # paths information
+ root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ dataset_root = os.path.join(root_dir, "data")
+ exp_root = os.path.join(root_dir, "data", "exp")
+ _exp_name = ""
+ exp_dir = ""
+ ckpt_dir = ""
+ logs_path = ""
+ ckpt_path = ""
+
+ @property
+ def exp_name(self):
+ return self._exp_name
+
+ @exp_name.setter
+ def exp_name(self, value: str):
+ self._exp_name = value
+ self.exp_dir: str = os.path.join(self.exp_root, self.exp_name)
+ self.ckpt_dir: str = os.path.join(self.exp_dir, "ckpt")
+ self.logs_path: str = os.path.join(self.exp_dir, "logs.txt")
+
+ os.makedirs(self.exp_dir, exist_ok=True)
+ os.makedirs(self.ckpt_dir, exist_ok=True)
+
+ def to_dict(self):
+ dic = {}
+ for fieldkey in dir(self):
+ fieldvalue = getattr(self, fieldkey)
+ if not fieldkey.startswith("__") and not callable(fieldvalue) and not fieldkey.startswith("_"):
+ dic[fieldkey] = fieldvalue
+ return dic
+
+
+def args_list2dict(arg_list: list):
+ assert len(arg_list) % 2 == 0, f"Override list has odd length: {arg_list}; it must be a list of pairs"
+ return dict(zip(arg_list[::2], arg_list[1::2]))
+
+
+def str2bool(v: str) -> bool:
+ if isinstance(v, bool):
+ return v
+ elif v.lower() in ("true", "yes", "on", "y", "t", "1"):
+ return True
+ elif v.lower() in ("false", "no", "off", "n", "f", "0"):
+ return False
+ else:
+ return bool(v)
+
+
+def str2list(v: str, element_type=None) -> list:
+ if not isinstance(v, (list, tuple, set)):
+ v = v.lstrip("[").rstrip("]")
+ v = v.split(",")
+ v = list(map(str.strip, v))
+ if element_type is not None:
+ v = list(map(element_type, v))
+ return v
+
+
+CONFIGCLASS = Type[DefaultConfigs]
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--gpus", default=[0], type=int, nargs="+")
+parser.add_argument("--exp_name", default="", type=str)
+parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
+parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
+args = parser.parse_args()
+
+if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
+ sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
+ from config import cfg
+
+ cfg: CONFIGCLASS
+else:
+ cfg = DefaultConfigs()
+
+if args.opts:
+ opts = args_list2dict(args.opts)
+ for k, v in opts.items():
+ if not hasattr(cfg, k):
+ raise ValueError(f"Unrecognized option: {k}")
+ original_type = type(getattr(cfg, k))
+ if original_type == bool:
+ setattr(cfg, k, str2bool(v))
+ elif original_type in (list, tuple, set):
+ setattr(cfg, k, str2list(v, type(getattr(cfg, k)[0])))
+ else:
+ setattr(cfg, k, original_type(v))
+
+cfg.gpus: list = args.gpus
+os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join([str(gpu) for gpu in cfg.gpus])
+cfg.exp_name = args.exp_name
+cfg.ckpt_path: str = os.path.join(cfg.ckpt_dir, args.ckpt)
+
+if isinstance(cfg.datasets, str):
+ cfg.datasets = cfg.datasets.split(",")
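+
+# CLI sketch (`train.py` is a hypothetical entry point that imports this
+# module): any field of DefaultConfigs can be overridden via the trailing
+# key/value `opts` list, and each value is coerced to the type of the
+# existing field (bool, list, or scalar):
+# python train.py --gpus 0 1 --exp_name my_exp lr 0.0002 warmup true datasets zhaolian_train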
diff --git a/AIGVDet/core/utils1/datasets.py b/AIGVDet/core/utils1/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a150dead9ce5b517b4e7de1adc47fdbaac33bc0
--- /dev/null
+++ b/AIGVDet/core/utils1/datasets.py
@@ -0,0 +1,178 @@
+import os
+from io import BytesIO
+from random import choice, random
+
+import cv2
+import numpy as np
+import torch
+import torch.utils.data
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
+from PIL import Image, ImageFile
+from scipy.ndimage import gaussian_filter
+from torch.utils.data.sampler import WeightedRandomSampler
+
+from utils1.config import CONFIGCLASS
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+def dataset_folder(root: str, cfg: CONFIGCLASS):
+ if cfg.mode == "binary":
+ return binary_dataset(root, cfg)
+ if cfg.mode == "filename":
+ return FileNameDataset(cfg, root)
+ raise ValueError("cfg.mode needs to be binary or filename.")
+
+
+def binary_dataset(root: str, cfg: CONFIGCLASS):
+ identity_transform = transforms.Lambda(lambda img: img)
+
+ rz_func = identity_transform
+
+ if cfg.isTrain:
+ crop_func = transforms.RandomCrop((448,448))
+ else:
+ crop_func = transforms.CenterCrop((448,448)) if cfg.aug_crop else identity_transform
+
+ if cfg.isTrain and cfg.aug_flip:
+ flip_func = transforms.RandomHorizontalFlip()
+ else:
+ flip_func = identity_transform
+
+
+ return datasets.ImageFolder(
+ root,
+ transforms.Compose(
+ [
+ rz_func,
+ # blur/JPEG augmentation on the full image before cropping
+ transforms.Lambda(lambda img: blur_jpg_augment(img, cfg)),
+ crop_func,
+ flip_func,
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ if cfg.aug_norm
+ else identity_transform,
+ ]
+ )
+ )
+
+
+class FileNameDataset(datasets.ImageFolder):
+ def name(self):
+ return 'FileNameDataset'
+
+ def __init__(self, opt, root):
+ self.opt = opt
+ super().__init__(root)
+
+ def __getitem__(self, index):
+ # Loading sample
+ path, target = self.samples[index]
+ return path
+
+
+def blur_jpg_augment(img: Image.Image, cfg: CONFIGCLASS):
+ img: np.ndarray = np.array(img)
+ if cfg.isTrain:
+ if random() < cfg.blur_prob:
+ sig = sample_continuous(cfg.blur_sig)
+ gaussian_blur(img, sig)
+
+ if random() < cfg.jpg_prob:
+ method = sample_discrete(cfg.jpg_method)
+ qual = sample_discrete(cfg.jpg_qual)
+ img = jpeg_from_key(img, qual, method)
+
+ return Image.fromarray(img)
+
+
+def sample_continuous(s: list):
+ if len(s) == 1:
+ return s[0]
+ if len(s) == 2:
+ rg = s[1] - s[0]
+ return random() * rg + s[0]
+ raise ValueError("Length of iterable s should be 1 or 2.")
+
+
+def sample_discrete(s: list):
+ return s[0] if len(s) == 1 else choice(s)
+
+
+def gaussian_blur(img: np.ndarray, sigma: float):
+ gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
+ gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
+ gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)
+
+
+def cv2_jpg(img: np.ndarray, compress_val: int) -> np.ndarray:
+ img_cv2 = img[:, :, ::-1]
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
+ result, encimg = cv2.imencode(".jpg", img_cv2, encode_param)
+ decimg = cv2.imdecode(encimg, 1)
+ return decimg[:, :, ::-1]
+
+
+def pil_jpg(img: np.ndarray, compress_val: int):
+ out = BytesIO()
+ img = Image.fromarray(img)
+ img.save(out, format="jpeg", quality=compress_val)
+ img = Image.open(out)
+ # load from memory before ByteIO closes
+ img = np.array(img)
+ out.close()
+ return img
+
+
+jpeg_dict = {"cv2": cv2_jpg, "pil": pil_jpg}
+
+
+def jpeg_from_key(img: np.ndarray, compress_val: int, key: str) -> np.ndarray:
+ method = jpeg_dict[key]
+ return method(img, compress_val)
+
+
+rz_dict = {'bilinear': Image.BILINEAR,
+ 'bicubic': Image.BICUBIC,
+ 'lanczos': Image.LANCZOS,
+ 'nearest': Image.NEAREST}
+def custom_resize(img: Image.Image, cfg: CONFIGCLASS) -> Image.Image:
+ interp = sample_discrete(cfg.rz_interp)
+ return TF.resize(img, cfg.loadSize, interpolation=rz_dict[interp])
+
+
+def get_dataset(cfg: CONFIGCLASS):
+ dset_lst = []
+ for dataset in cfg.datasets:
+ root = os.path.join(cfg.dataset_root, dataset)
+ dset = dataset_folder(root, cfg)
+ dset_lst.append(dset)
+ return torch.utils.data.ConcatDataset(dset_lst)
+
+
+def get_bal_sampler(dataset: torch.utils.data.ConcatDataset):
+ targets = []
+ for d in dataset.datasets:
+ targets.extend(d.targets)
+
+ ratio = np.bincount(targets)
+ w = 1.0 / torch.tensor(ratio, dtype=torch.float)
+ sample_weights = w[targets]
+ return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights))
+
+
+def create_dataloader(cfg: CONFIGCLASS):
+ shuffle = not cfg.serial_batches if (cfg.isTrain and not cfg.class_bal) else False
+ dataset = get_dataset(cfg)
+ sampler = get_bal_sampler(dataset) if cfg.class_bal else None
+
+ return torch.utils.data.DataLoader(
+ dataset,
+ batch_size=cfg.batch_size,
+ shuffle=shuffle,
+ sampler=sampler,
+ num_workers=int(cfg.num_workers),
+ )
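+
+# Usage sketch (illustrative): data is expected under
+# <dataset_root>/<dataset>/<class>/ as required by torchvision's ImageFolder;
+# by the convention in utils1/eval.py, class index 0 is real and 1 is fake
+# (the folder naming that yields this ordering is an assumption here).
+# loader = create_dataloader(cfg)
+# for img, label in loader:
+# ... # img: [B, 3, 448, 448] after cropping, label: [B]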
diff --git a/AIGVDet/core/utils1/earlystop.py b/AIGVDet/core/utils1/earlystop.py
new file mode 100644
index 0000000000000000000000000000000000000000..741d07e466d4f3750d492de7d0a234193cab6bcb
--- /dev/null
+++ b/AIGVDet/core/utils1/earlystop.py
@@ -0,0 +1,46 @@
+import numpy as np
+
+from utils1.trainer import Trainer
+
+
+class EarlyStopping:
+ """Early stops the training if validation loss doesn't improve after a given patience."""
+
+ def __init__(self, patience=1, verbose=False, delta=0):
+ """
+ Args:
+ patience (int): How long to wait after the last validation-score improvement.
+ Default: 1
+ verbose (bool): If True, prints a message for each validation-score improvement.
+ Default: False
+ delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+ Default: 0
+ """
+ self.patience = patience
+ self.verbose = verbose
+ self.counter = 0
+ self.best_score = None
+ self.early_stop = False
+ self.score_max = -np.inf
+ self.delta = delta
+
+ def __call__(self, score: float, trainer: Trainer):
+ if self.best_score is None:
+ self.best_score = score
+ self.save_checkpoint(score, trainer)
+ elif score < self.best_score - self.delta:
+ self.counter += 1
+ print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
+ if self.counter >= self.patience:
+ self.early_stop = True
+ else:
+ self.best_score = score
+ self.save_checkpoint(score, trainer)
+ self.counter = 0
+
+ def save_checkpoint(self, score: float, trainer: Trainer):
+ """Saves model when validation loss decrease."""
+ if self.verbose:
+ print(f"Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...")
+ trainer.save_networks("best")
+ self.score_max = score
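+
+# Driving-loop sketch (illustrative; `val_acc` is a hypothetical metric
+# computed on the validation split each epoch):
+# early_stopping = EarlyStopping(patience=cfg.earlystop_epoch, delta=0.001)
+# for epoch in range(cfg.nepoch):
+# ... # train one epoch, then evaluate val_acc
+# early_stopping(val_acc, trainer)
+# if early_stopping.early_stop:
+# break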
diff --git a/AIGVDet/core/utils1/eval.py b/AIGVDet/core/utils1/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaf62c2f356c039c79055b67bea66762f475e1f9
--- /dev/null
+++ b/AIGVDet/core/utils1/eval.py
@@ -0,0 +1,66 @@
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from utils1.config import CONFIGCLASS
+from utils1.utils import to_cuda
+
+
+def get_val_cfg(cfg: CONFIGCLASS, split="val", copy=True):
+ if copy:
+ from copy import deepcopy
+
+ val_cfg = deepcopy(cfg)
+ else:
+ val_cfg = cfg
+ val_cfg.dataset_root = os.path.join(val_cfg.dataset_root, split)
+ val_cfg.datasets = cfg.datasets_test
+ val_cfg.isTrain = False
+ # val_cfg.aug_resize = False
+ # val_cfg.aug_crop = False
+ val_cfg.aug_flip = False
+ val_cfg.serial_batches = True
+ val_cfg.jpg_method = ["pil"]
+ # Currently assumes jpg_prob, blur_prob 0 or 1
+ if len(val_cfg.blur_sig) == 2:
+ b_sig = val_cfg.blur_sig
+ val_cfg.blur_sig = [(b_sig[0] + b_sig[1]) / 2]
+ if len(val_cfg.jpg_qual) != 1:
+ j_qual = val_cfg.jpg_qual
+ val_cfg.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)]
+ return val_cfg
+
+def validate(model: nn.Module, cfg: CONFIGCLASS):
+ from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
+
+ from utils1.datasets import create_dataloader
+
+ data_loader = create_dataloader(cfg)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ with torch.no_grad():
+ y_true, y_pred = [], []
+ for data in data_loader:
+ img, label, meta = data if len(data) == 3 else (*data, None)
+ in_tens = to_cuda(img, device)
+ meta = to_cuda(meta, device)
+ predict = model(in_tens, meta).sigmoid()
+ y_pred.extend(predict.flatten().tolist())
+ y_true.extend(label.flatten().tolist())
+
+ y_true, y_pred = np.array(y_true), np.array(y_pred)
+ r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
+ f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
+ acc = accuracy_score(y_true, y_pred > 0.5)
+ ap = average_precision_score(y_true, y_pred)
+ results = {
+ "ACC": acc,
+ "AP": ap,
+ "R_ACC": r_acc,
+ "F_ACC": f_acc,
+ }
+ return results
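+
+# Usage sketch (illustrative; `trainer` is assumed to be a
+# utils1.trainer.Trainer instance):
+# val_cfg = get_val_cfg(cfg, split="val")
+# trainer.eval()
+# results = validate(trainer.model, val_cfg)
+# print(results) # {"ACC": ..., "AP": ..., "R_ACC": ..., "F_ACC": ...}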
diff --git a/AIGVDet/core/utils1/trainer.py b/AIGVDet/core/utils1/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..859960375388faf1d97e5c128eab67ab74d2974e
--- /dev/null
+++ b/AIGVDet/core/utils1/trainer.py
@@ -0,0 +1,163 @@
+import os
+
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+from utils1.config import CONFIGCLASS
+from utils1.utils import get_network
+from utils1.warmup import GradualWarmupScheduler
+
+
+class BaseModel(nn.Module):
+ def __init__(self, cfg: CONFIGCLASS):
+ super().__init__()
+ self.cfg = cfg
+ self.total_steps = 0
+ self.isTrain = cfg.isTrain
+ self.save_dir = cfg.ckpt_dir
+ self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ # `self.model` is constructed by the subclass (see Trainer below), which
+ # also moves it to `self.device`; only the attribute types are declared here.
+ self.model: nn.Module
+ self.optimizer: torch.optim.Optimizer
+
+ def save_networks(self, epoch: int):
+ save_filename = f"model_epoch_{epoch}.pth"
+ save_path = os.path.join(self.save_dir, save_filename)
+
+ # serialize model and optimizer to dict
+ state_dict = {
+ "model": self.model.state_dict(),
+ "optimizer": self.optimizer.state_dict(),
+ "total_steps": self.total_steps,
+ }
+
+ torch.save(state_dict, save_path)
+
+ # load models from the disk
+ def load_networks(self, epoch: int):
+ load_filename = f"model_epoch_{epoch}.pth"
+ load_path = os.path.join(self.save_dir, load_filename)
+
+ if epoch == 0:
+ load_path = "checkpoints/optical.pth"
+ print("loading the pretrained optical-flow checkpoint")
+ else:
+ print(f"loading the model from {load_path}")
+
+ state_dict = torch.load(load_path, map_location=self.device)
+ if hasattr(state_dict, "_metadata"):
+ del state_dict._metadata
+
+ self.model.load_state_dict(state_dict["model"])
+ self.total_steps = state_dict["total_steps"]
+
+ if self.isTrain and not self.cfg.new_optim:
+ self.optimizer.load_state_dict(state_dict["optimizer"])
+ # move optimizer state to GPU
+ for state in self.optimizer.state.values():
+ for k, v in state.items():
+ if torch.is_tensor(v):
+ state[k] = v.to(self.device)
+
+ for g in self.optimizer.param_groups:
+ g["lr"] = self.cfg.lr
+
+ def eval(self):
+ self.model.eval()
+
+ def test(self):
+ with torch.no_grad():
+ self.forward()
+
+
+def init_weights(net: nn.Module, init_type="normal", gain=0.02):
+ def init_func(m: nn.Module):
+ classname = m.__class__.__name__
+ if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
+ if init_type == "normal":
+ init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == "xavier":
+ init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == "kaiming":
+ init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
+ elif init_type == "orthogonal":
+ init.orthogonal_(m.weight.data, gain=gain)
+ else:
+ raise NotImplementedError(f"initialization method [{init_type}] is not implemented")
+ if hasattr(m, "bias") and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif classname.find("BatchNorm2d") != -1:
+ init.normal_(m.weight.data, 1.0, gain)
+ init.constant_(m.bias.data, 0.0)
+
+ print(f"initialize network with {init_type}")
+ net.apply(init_func)
+
+
+class Trainer(BaseModel):
+ def name(self):
+ return "Trainer"
+
+ def __init__(self, cfg: CONFIGCLASS):
+ super().__init__(cfg)
+ self.arch = cfg.arch
+ self.model = get_network(self.arch, cfg.isTrain, cfg.continue_train, cfg.init_gain, cfg.pretrained)
+
+ self.loss_fn = nn.BCEWithLogitsLoss()
+ # initialize optimizers
+ if cfg.optim == "adam":
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
+ elif cfg.optim == "sgd":
+ self.optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4)
+ else:
+ raise ValueError("optim should be [adam, sgd]")
+ if cfg.warmup:
+ scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
+ self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6
+ )
+ self.scheduler = GradualWarmupScheduler(
+ self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine
+ )
+ self.scheduler.step()
+ if cfg.continue_train:
+ self.load_networks(cfg.epoch)
+ self.model.to(self.device)
+
+ def adjust_learning_rate(self, min_lr=1e-6):
+ for param_group in self.optimizer.param_groups:
+ param_group["lr"] /= 10.0
+ if param_group["lr"] < min_lr:
+ return False
+ return True
+
+ def set_input(self, input):
+ img, label, meta = input if len(input) == 3 else (input[0], input[1], {})
+ self.input = img.to(self.device)
+ self.label = label.to(self.device).float()
+ for k in meta.keys():
+ if isinstance(meta[k], torch.Tensor):
+ meta[k] = meta[k].to(self.device)
+ self.meta = meta
+
+ def forward(self):
+ self.output = self.model(self.input, self.meta)
+
+ def get_loss(self):
+ return self.loss_fn(self.output.squeeze(1), self.label)
+
+ def optimize_parameters(self):
+ self.forward()
+ self.loss = self.loss_fn(self.output.squeeze(1), self.label)
+ self.optimizer.zero_grad()
+ self.loss.backward()
+ self.optimizer.step()
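+
+# Training-loop sketch (illustrative; create_dataloader comes from
+# utils1.datasets):
+# trainer = Trainer(cfg)
+# loader = create_dataloader(cfg)
+# for epoch in range(cfg.nepoch):
+# for batch in loader:
+# trainer.total_steps += 1
+# trainer.set_input(batch)
+# trainer.optimize_parameters()
+# trainer.save_networks("latest")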
diff --git a/AIGVDet/core/utils1/utils.py b/AIGVDet/core/utils1/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef671dbc6a55b89ce4b2192052aec32bcd177fef
--- /dev/null
+++ b/AIGVDet/core/utils1/utils.py
@@ -0,0 +1,109 @@
+import argparse
+import os
+import sys
+import time
+import warnings
+from importlib import import_module
+
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+
+warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
+
+
+def str2bool(v: str, strict=True) -> bool:
+ if isinstance(v, bool):
+ return v
+ elif isinstance(v, str):
+ if v.lower() in ("true", "yes", "on" "t", "y", "1"):
+ return True
+ elif v.lower() in ("false", "no", "off", "f", "n", "0"):
+ return False
+ if strict:
+ raise argparse.ArgumentTypeError("Unsupported value encountered.")
+ else:
+ return True
+
+
+def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None):
+ if isinstance(data, torch.Tensor):
+ data = data.to(device)
+ elif isinstance(data, (tuple, list, set)):
+ data = [to_cuda(b, device) for b in data]
+ elif isinstance(data, dict):
+ if exclude_keys is None:
+ exclude_keys = []
+ for k in data.keys():
+ if k not in exclude_keys:
+ data[k] = to_cuda(data[k], device)
+ else:
+ # leave unsupported types unchanged rather than raising
+ pass
+ return data
+
+
+class HiddenPrints:
+ def __enter__(self):
+ self._original_stdout = sys.stdout
+ sys.stdout = open(os.devnull, "w")
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ sys.stdout.close()
+ sys.stdout = self._original_stdout
+
+
+class Logger(object):
+ def __init__(self):
+ self.terminal = sys.stdout
+ self.file = None
+
+ def open(self, file, mode=None):
+ if mode is None:
+ mode = "w"
+ self.file = open(file, mode)
+
+ def write(self, message, is_terminal=1, is_file=1):
+ if "\r" in message:
+ is_file = 0
+ if is_terminal == 1:
+ self.terminal.write(message)
+ self.terminal.flush()
+ if is_file == 1:
+ self.file.write(message)
+ self.file.flush()
+
+ def flush(self):
+ # this flush method is needed for python 3 compatibility.
+ # this handles the flush command by doing nothing.
+ # you might want to specify some extra behavior here.
+ pass
+
+
+def get_network(arch: str, isTrain=False, continue_train=False, init_gain=0.02, pretrained=True):
+ if "resnet" in arch:
+ from networks.resnet import ResNet
+
+ resnet = getattr(import_module("networks.resnet"), arch)
+ if isTrain:
+ if continue_train:
+ model: ResNet = resnet(num_classes=1)
+ else:
+ model: ResNet = resnet(pretrained=pretrained)
+ model.fc = nn.Linear(2048, 1)
+ nn.init.normal_(model.fc.weight.data, 0.0, init_gain)
+ else:
+ model: ResNet = resnet(num_classes=1)
+ return model
+ else:
+ raise ValueError(f"Unsupported arch: {arch}")
+
+
+def pad_img_to_square(img: np.ndarray):
+ H, W = img.shape[:2]
+ if H != W:
+ new_size = max(H, W)
+ img = np.pad(img, ((0, new_size - H), (0, new_size - W), (0, 0)), mode="constant")
+ assert img.shape[0] == img.shape[1] == new_size
+ return img
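+
+
+if __name__ == "__main__":
+ # Quick check (an illustrative sketch): a 2x3 RGB image is padded to 3x3,
+ # with the extra bottom row zero-filled.
+ a = np.ones((2, 3, 3), dtype=np.uint8)
+ b = pad_img_to_square(a)
+ assert b.shape == (3, 3, 3) and b[2].sum() == 0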
diff --git a/AIGVDet/core/utils1/utils1/config.py b/AIGVDet/core/utils1/utils1/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..20ccb7a76f81e7ef66fac6f6ae8917751c85f4f6
--- /dev/null
+++ b/AIGVDet/core/utils1/utils1/config.py
@@ -0,0 +1,157 @@
+import argparse
+import os
+import sys
+from abc import ABC
+from typing import Type
+
+
+class DefaultConfigs(ABC):
+ ####### base setting ######
+ gpus = [0]
+ seed = 3407
+ arch = "resnet50"
+ datasets = ["zhaolian_train"]
+ datasets_test = ["adm_res_abs_ddim20s"]
+ mode = "binary"
+ class_bal = False
+ batch_size = 64
+ loadSize = 256
+ cropSize = 224
+ epoch = "latest"
+ num_workers = 20
+ serial_batches = False
+ isTrain = True
+
+ # data augmentation
+ rz_interp = ["bilinear"]
+ # blur_prob = 0.0
+ blur_prob = 0.1
+ blur_sig = [0.5]
+ # jpg_prob = 0.0
+ jpg_prob = 0.1
+ jpg_method = ["cv2"]
+ jpg_qual = [75]
+ gray_prob = 0.0
+ aug_resize = True
+ aug_crop = True
+ aug_flip = True
+ aug_norm = True
+
+ ####### train setting ######
+ warmup = False
+ # warmup = True
+ warmup_epoch = 3
+ earlystop = True
+ earlystop_epoch = 5
+ optim = "adam"
+ new_optim = False
+ loss_freq = 400
+ save_latest_freq = 2000
+ save_epoch_freq = 20
+ continue_train = False
+ epoch_count = 1
+ last_epoch = -1
+ nepoch = 400
+ beta1 = 0.9
+ lr = 0.0001
+ init_type = "normal"
+ init_gain = 0.02
+ pretrained = True
+
+ # paths information
+ root_dir1 = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ root_dir = os.path.dirname(root_dir1)
+ dataset_root = os.path.join(root_dir, "data")
+ exp_root = os.path.join(root_dir, "data", "exp")
+ _exp_name = ""
+ exp_dir = ""
+ ckpt_dir = ""
+ logs_path = ""
+ ckpt_path = ""
+
+ @property
+ def exp_name(self):
+ return self._exp_name
+
+ @exp_name.setter
+ def exp_name(self, value: str):
+ self._exp_name = value
+ self.exp_dir: str = os.path.join(self.exp_root, self.exp_name)
+ self.ckpt_dir: str = os.path.join(self.exp_dir, "ckpt")
+ self.logs_path: str = os.path.join(self.exp_dir, "logs.txt")
+
+ os.makedirs(self.exp_dir, exist_ok=True)
+ os.makedirs(self.ckpt_dir, exist_ok=True)
+
+ def to_dict(self):
+ dic = {}
+ for fieldkey in dir(self):
+ fieldvalue = getattr(self, fieldkey)
+ if not fieldkey.startswith("__") and not callable(fieldvalue) and not fieldkey.startswith("_"):
+ dic[fieldkey] = fieldvalue
+ return dic
+
+
+def args_list2dict(arg_list: list):
+ assert len(arg_list) % 2 == 0, f"Override list has odd length: {arg_list}; it must be a list of pairs"
+ return dict(zip(arg_list[::2], arg_list[1::2]))
+
+
+def str2bool(v: str) -> bool:
+ if isinstance(v, bool):
+ return v
+ elif v.lower() in ("true", "yes", "on", "y", "t", "1"):
+ return True
+ elif v.lower() in ("false", "no", "off", "n", "f", "0"):
+ return False
+ else:
+ return bool(v)
+
+
+def str2list(v: str, element_type=None) -> list:
+ if not isinstance(v, (list, tuple, set)):
+ v = v.lstrip("[").rstrip("]")
+ v = v.split(",")
+ v = list(map(str.strip, v))
+ if element_type is not None:
+ v = list(map(element_type, v))
+ return v
+
+
+CONFIGCLASS = Type[DefaultConfigs]
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--gpus", default=[0], type=int, nargs="+")
+parser.add_argument("--exp_name", default="", type=str)
+parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
+parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
+args = parser.parse_args()
+
+if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
+ sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
+ from config import cfg
+
+ cfg: CONFIGCLASS
+else:
+ cfg = DefaultConfigs()
+
+if args.opts:
+ opts = args_list2dict(args.opts)
+ for k, v in opts.items():
+ if not hasattr(cfg, k):
+ raise ValueError(f"Unrecognized option: {k}")
+ original_type = type(getattr(cfg, k))
+ if original_type == bool:
+ setattr(cfg, k, str2bool(v))
+ elif original_type in (list, tuple, set):
+ setattr(cfg, k, str2list(v, type(getattr(cfg, k)[0])))
+ else:
+ setattr(cfg, k, original_type(v))
+
+cfg.gpus: list = args.gpus
+os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join([str(gpu) for gpu in cfg.gpus])
+cfg.exp_name = args.exp_name
+cfg.ckpt_path: str = os.path.join(cfg.ckpt_dir, args.ckpt)
+
+if isinstance(cfg.datasets, str):
+ cfg.datasets = cfg.datasets.split(",")
diff --git a/AIGVDet/core/utils1/utils1/datasets.py b/AIGVDet/core/utils1/utils1/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..a35863aebae5bbfd6f4a844fc6ac360e84601ed8
--- /dev/null
+++ b/AIGVDet/core/utils1/utils1/datasets.py
@@ -0,0 +1,178 @@
+import os
+from io import BytesIO
+from random import choice, random
+
+import cv2
+import numpy as np
+import torch
+import torch.utils.data
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
+from PIL import Image, ImageFile
+from scipy.ndimage import gaussian_filter
+from torch.utils.data.sampler import WeightedRandomSampler
+
+from .config import CONFIGCLASS
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+def dataset_folder(root: str, cfg: CONFIGCLASS):
+ if cfg.mode == "binary":
+ return binary_dataset(root, cfg)
+ if cfg.mode == "filename":
+ return FileNameDataset(cfg, root)
+ raise ValueError("cfg.mode needs to be binary or filename.")
+
+
+def binary_dataset(root: str, cfg: CONFIGCLASS):
+ identity_transform = transforms.Lambda(lambda img: img)
+
+ rz_func = identity_transform
+
+ if cfg.isTrain:
+ crop_func = transforms.RandomCrop((448,448))
+ else:
+ crop_func = transforms.CenterCrop((448,448)) if cfg.aug_crop else identity_transform
+
+ if cfg.isTrain and cfg.aug_flip:
+ flip_func = transforms.RandomHorizontalFlip()
+ else:
+ flip_func = identity_transform
+
+
+ return datasets.ImageFolder(
+ root,
+ transforms.Compose(
+ [
+ rz_func,
+ # blur/JPEG augmentation on the full image before cropping
+ transforms.Lambda(lambda img: blur_jpg_augment(img, cfg)),
+ crop_func,
+ flip_func,
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ if cfg.aug_norm
+ else identity_transform,
+ ]
+ )
+ )
+
+
+class FileNameDataset(datasets.ImageFolder):
+ def name(self):
+ return 'FileNameDataset'
+
+ def __init__(self, opt, root):
+ self.opt = opt
+ super().__init__(root)
+
+ def __getitem__(self, index):
+ # Loading sample
+ path, target = self.samples[index]
+ return path
+
+
+def blur_jpg_augment(img: Image.Image, cfg: CONFIGCLASS):
+ img: np.ndarray = np.array(img)
+ if cfg.isTrain:
+ if random() < cfg.blur_prob:
+ sig = sample_continuous(cfg.blur_sig)
+ gaussian_blur(img, sig)
+
+ if random() < cfg.jpg_prob:
+ method = sample_discrete(cfg.jpg_method)
+ qual = sample_discrete(cfg.jpg_qual)
+ img = jpeg_from_key(img, qual, method)
+
+ return Image.fromarray(img)
+
+
+def sample_continuous(s: list):
+ if len(s) == 1:
+ return s[0]
+ if len(s) == 2:
+ rg = s[1] - s[0]
+ return random() * rg + s[0]
+ raise ValueError("Length of iterable s should be 1 or 2.")
+
+
+def sample_discrete(s: list):
+ return s[0] if len(s) == 1 else choice(s)
+
+
+def gaussian_blur(img: np.ndarray, sigma: float):
+ gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
+ gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
+ gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)
+
+
+def cv2_jpg(img: np.ndarray, compress_val: int) -> np.ndarray:
+ img_cv2 = img[:, :, ::-1]
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
+ result, encimg = cv2.imencode(".jpg", img_cv2, encode_param)
+ decimg = cv2.imdecode(encimg, 1)
+ return decimg[:, :, ::-1]
+
+
+def pil_jpg(img: np.ndarray, compress_val: int):
+ out = BytesIO()
+ img = Image.fromarray(img)
+ img.save(out, format="jpeg", quality=compress_val)
+ img = Image.open(out)
+ # load from memory before ByteIO closes
+ img = np.array(img)
+ out.close()
+ return img
+
+
+jpeg_dict = {"cv2": cv2_jpg, "pil": pil_jpg}
+
+
+def jpeg_from_key(img: np.ndarray, compress_val: int, key: str) -> np.ndarray:
+ method = jpeg_dict[key]
+ return method(img, compress_val)
+
+
+rz_dict = {'bilinear': Image.BILINEAR,
+ 'bicubic': Image.BICUBIC,
+ 'lanczos': Image.LANCZOS,
+ 'nearest': Image.NEAREST}
+def custom_resize(img: Image.Image, cfg: CONFIGCLASS) -> Image.Image:
+ interp = sample_discrete(cfg.rz_interp)
+ return TF.resize(img, cfg.loadSize, interpolation=rz_dict[interp])
+
+
+def get_dataset(cfg: CONFIGCLASS):
+ dset_lst = []
+ for dataset in cfg.datasets:
+ root = os.path.join(cfg.dataset_root, dataset)
+ dset = dataset_folder(root, cfg)
+ dset_lst.append(dset)
+ return torch.utils.data.ConcatDataset(dset_lst)
+
+
+def get_bal_sampler(dataset: torch.utils.data.ConcatDataset):
+ targets = []
+ for d in dataset.datasets:
+ targets.extend(d.targets)
+
+ ratio = np.bincount(targets)
+ w = 1.0 / torch.tensor(ratio, dtype=torch.float)
+ sample_weights = w[targets]
+ return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights))
+
+
+def create_dataloader(cfg: CONFIGCLASS):
+ shuffle = not cfg.serial_batches if (cfg.isTrain and not cfg.class_bal) else False
+ dataset = get_dataset(cfg)
+ sampler = get_bal_sampler(dataset) if cfg.class_bal else None
+
+ return torch.utils.data.DataLoader(
+ dataset,
+ batch_size=cfg.batch_size,
+ shuffle=shuffle,
+ sampler=sampler,
+ num_workers=int(cfg.num_workers),
+ )
diff --git a/AIGVDet/core/utils1/utils1/earlystop.py b/AIGVDet/core/utils1/utils1/earlystop.py
new file mode 100644
index 0000000000000000000000000000000000000000..25aef71789c0c15e25509c42f28a1927582fd47d
--- /dev/null
+++ b/AIGVDet/core/utils1/utils1/earlystop.py
@@ -0,0 +1,46 @@
+import numpy as np
+
+from .trainer import Trainer
+
+
+class EarlyStopping:
+ """Early stops the training if validation loss doesn't improve after a given patience."""
+
+ def __init__(self, patience=1, verbose=False, delta=0):
+ """
+ Args:
+ patience (int): How long to wait after the last validation-score improvement.
+ Default: 1
+ verbose (bool): If True, prints a message for each validation-score improvement.
+ Default: False
+ delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+ Default: 0
+ """
+ self.patience = patience
+ self.verbose = verbose
+ self.counter = 0
+ self.best_score = None
+ self.early_stop = False
+ self.score_max = -np.inf
+ self.delta = delta
+
+ def __call__(self, score: float, trainer: Trainer):
+ if self.best_score is None:
+ self.best_score = score
+ self.save_checkpoint(score, trainer)
+ elif score < self.best_score - self.delta:
+ self.counter += 1
+ print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
+ if self.counter >= self.patience:
+ self.early_stop = True
+ else:
+ self.best_score = score
+ self.save_checkpoint(score, trainer)
+ self.counter = 0
+
+ def save_checkpoint(self, score: float, trainer: Trainer):
+ """Saves model when validation loss decrease."""
+ if self.verbose:
+ print(f"Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...")
+ trainer.save_networks("best")
+ self.score_max = score
diff --git a/AIGVDet/core/utils1/utils1/eval.py b/AIGVDet/core/utils1/utils1/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..5784de313b831b3c9ef49c91e60a62dd8765e5d3
--- /dev/null
+++ b/AIGVDet/core/utils1/utils1/eval.py
@@ -0,0 +1,66 @@
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .config import CONFIGCLASS
+from .utils import to_cuda
+
+
+def get_val_cfg(cfg: CONFIGCLASS, split="val", copy=True):
+ if copy:
+ from copy import deepcopy
+
+ val_cfg = deepcopy(cfg)
+ else:
+ val_cfg = cfg
+ val_cfg.dataset_root = os.path.join(val_cfg.dataset_root, split)
+ val_cfg.datasets = cfg.datasets_test
+ val_cfg.isTrain = False
+ # val_cfg.aug_resize = False
+ # val_cfg.aug_crop = False
+ val_cfg.aug_flip = False
+ val_cfg.serial_batches = True
+ val_cfg.jpg_method = ["pil"]
+ # Currently assumes jpg_prob, blur_prob 0 or 1
+ if len(val_cfg.blur_sig) == 2:
+ b_sig = val_cfg.blur_sig
+ val_cfg.blur_sig = [(b_sig[0] + b_sig[1]) / 2]
+ if len(val_cfg.jpg_qual) != 1:
+ j_qual = val_cfg.jpg_qual
+ val_cfg.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)]
+ return val_cfg
+
+def validate(model: nn.Module, cfg: CONFIGCLASS):
+ from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
+
+ from .datasets import create_dataloader
+
+ data_loader = create_dataloader(cfg)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ with torch.no_grad():
+ y_true, y_pred = [], []
+ for data in data_loader:
+ img, label, meta = data if len(data) == 3 else (*data, None)
+ in_tens = to_cuda(img, device)
+ meta = to_cuda(meta, device)
+ predict = model(in_tens, meta).sigmoid()
+ y_pred.extend(predict.flatten().tolist())
+ y_true.extend(label.flatten().tolist())
+
+ y_true, y_pred = np.array(y_true), np.array(y_pred)
+ r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
+ f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
+ acc = accuracy_score(y_true, y_pred > 0.5)
+ ap = average_precision_score(y_true, y_pred)
+ results = {
+ "ACC": acc,
+ "AP": ap,
+ "R_ACC": r_acc,
+ "F_ACC": f_acc,
+ }
+ return results
diff --git a/AIGVDet/core/utils1/utils1/trainer.py b/AIGVDet/core/utils1/utils1/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a4ce8bb164c3237dd11801087a6521abcc020f2
--- /dev/null
+++ b/AIGVDet/core/utils1/utils1/trainer.py
@@ -0,0 +1,169 @@
+import os
+
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+from .config import CONFIGCLASS
+from .utils import get_network
+from .warmup import GradualWarmupScheduler
+
+
+class BaseModel(nn.Module):
+ def __init__(self, cfg: CONFIGCLASS):
+ super().__init__()
+ self.cfg = cfg
+ self.total_steps = 0
+ self.isTrain = cfg.isTrain
+ self.save_dir = cfg.ckpt_dir
+ self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ # `self.model` is constructed by the subclass (see Trainer below), which
+ # moves it to `self.device` and loads the pretrained checkpoint; only the
+ # attribute types are declared here.
+ self.model: nn.Module
+ self.optimizer: torch.optim.Optimizer
+
+ def save_networks(self, epoch: int):
+ save_filename = f"model_epoch_{epoch}.pth"
+ save_path = os.path.join(self.save_dir, save_filename)
+
+ # serialize model and optimizer to dict
+ state_dict = {
+ "model": self.model.state_dict(),
+ "optimizer": self.optimizer.state_dict(),
+ "total_steps": self.total_steps,
+ }
+
+ torch.save(state_dict, save_path)
+
+ # load models from the disk
+ def load_networks(self, epoch: int):
+ load_filename = f"model_epoch_{epoch}.pth"
+ load_path = os.path.join(self.save_dir, load_filename)
+
+ if epoch == 0:
+ load_path = "checkpoints/optical.pth"
+ print("loading the pretrained optical-flow checkpoint")
+ else:
+ print(f"loading the model from {load_path}")
+
+ state_dict = torch.load(load_path, map_location=self.device)
+ if hasattr(state_dict, "_metadata"):
+ del state_dict._metadata
+
+ self.model.load_state_dict(state_dict["model"])
+ self.total_steps = state_dict["total_steps"]
+
+ if self.isTrain and not self.cfg.new_optim:
+ self.optimizer.load_state_dict(state_dict["optimizer"])
+ # move optimizer state to GPU
+ for state in self.optimizer.state.values():
+ for k, v in state.items():
+ if torch.is_tensor(v):
+ state[k] = v.to(self.device)
+
+ for g in self.optimizer.param_groups:
+ g["lr"] = self.cfg.lr
+
+ def eval(self):
+ self.model.eval()
+
+ def test(self):
+ with torch.no_grad():
+ self.forward()
+
+
+def init_weights(net: nn.Module, init_type="normal", gain=0.02):
+ def init_func(m: nn.Module):
+ classname = m.__class__.__name__
+ if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
+ if init_type == "normal":
+ init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == "xavier":
+ init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == "kaiming":
+ init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
+ elif init_type == "orthogonal":
+ init.orthogonal_(m.weight.data, gain=gain)
+ else:
+ raise NotImplementedError(f"initialization method [{init_type}] is not implemented")
+ if hasattr(m, "bias") and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif classname.find("BatchNorm2d") != -1:
+ init.normal_(m.weight.data, 1.0, gain)
+ init.constant_(m.bias.data, 0.0)
+
+ print(f"initialize network with {init_type}")
+ net.apply(init_func)
+
+
+class Trainer(BaseModel):
+ def name(self):
+ return "Trainer"
+
+ def __init__(self, cfg: CONFIGCLASS):
+ super().__init__(cfg)
+ self.arch = cfg.arch
+ self.model = get_network(self.arch, cfg.isTrain, cfg.continue_train, cfg.init_gain, cfg.pretrained)
+
+ self.loss_fn = nn.BCEWithLogitsLoss()
+ # initialize optimizers
+ if cfg.optim == "adam":
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
+ elif cfg.optim == "sgd":
+ self.optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4)
+ else:
+ raise ValueError("optim should be [adam, sgd]")
+ if cfg.warmup:
+ scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
+ self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6
+ )
+ self.scheduler = GradualWarmupScheduler(
+ self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine
+ )
+ self.scheduler.step()
+ if cfg.continue_train:
+ self.load_networks(cfg.epoch)
+ self.model.to(self.device)
+
+ # always initialize from the pretrained optical-flow checkpoint
+ load_path = 'checkpoints/optical.pth'
+ state_dict = torch.load(load_path, map_location=self.device)
+ self.model.load_state_dict(state_dict["model"])
+
+
+ def adjust_learning_rate(self, min_lr=1e-6):
+ for param_group in self.optimizer.param_groups:
+ param_group["lr"] /= 10.0
+ if param_group["lr"] < min_lr:
+ return False
+ return True
+
+ def set_input(self, input):
+ img, label, meta = input if len(input) == 3 else (input[0], input[1], {})
+ self.input = img.to(self.device)
+ self.label = label.to(self.device).float()
+ for k in meta.keys():
+ if isinstance(meta[k], torch.Tensor):
+ meta[k] = meta[k].to(self.device)
+ self.meta = meta
+
+ def forward(self):
+ self.output = self.model(self.input, self.meta)
+
+ def get_loss(self):
+ return self.loss_fn(self.output.squeeze(1), self.label)
+
+ def optimize_parameters(self):
+ self.forward()
+ self.loss = self.loss_fn(self.output.squeeze(1), self.label)
+ self.optimizer.zero_grad()
+ self.loss.backward()
+ self.optimizer.step()
diff --git a/AIGVDet/core/utils1/utils1/utils.py b/AIGVDet/core/utils1/utils1/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d52ebbdba20a1b61f3c26cd7bca65a21107b965b
--- /dev/null
+++ b/AIGVDet/core/utils1/utils1/utils.py
@@ -0,0 +1,109 @@
+import argparse
+import os
+import sys
+import time
+import warnings
+from importlib import import_module
+
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+
+warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
+
+
+def str2bool(v: str, strict=True) -> bool:
+ if isinstance(v, bool):
+ return v
+ elif isinstance(v, str):
+ if v.lower() in ("true", "yes", "on", "t", "y", "1"):
+ return True
+ elif v.lower() in ("false", "no", "off", "f", "n", "0"):
+ return False
+ if strict:
+ raise argparse.ArgumentTypeError("Unsupported value encountered.")
+ else:
+ return True
+
+
+def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None):
+ if isinstance(data, torch.Tensor):
+ data = data.to(device)
+ elif isinstance(data, (tuple, list, set)):
+ data = [to_cuda(b, device) for b in data]
+ elif isinstance(data, dict):
+ if exclude_keys is None:
+ exclude_keys = []
+ for k in data.keys():
+ if k not in exclude_keys:
+ data[k] = to_cuda(data[k], device)
+ else:
+ # leave unsupported types unchanged rather than raising
+ pass
+ return data
+
+
+class HiddenPrints:
+ def __enter__(self):
+ self._original_stdout = sys.stdout
+ sys.stdout = open(os.devnull, "w")
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ sys.stdout.close()
+ sys.stdout = self._original_stdout
+
+
+class Logger(object):
+ def __init__(self):
+ self.terminal = sys.stdout
+ self.file = None
+
+ def open(self, file, mode=None):
+ if mode is None:
+ mode = "w"
+ self.file = open(file, mode)
+
+ def write(self, message, is_terminal=1, is_file=1):
+ if "\r" in message:
+ is_file = 0
+ if is_terminal == 1:
+ self.terminal.write(message)
+ self.terminal.flush()
+ if is_file == 1:
+ self.file.write(message)
+ self.file.flush()
+
+ def flush(self):
+ # this flush method is needed for python 3 compatibility.
+ # this handles the flush command by doing nothing.
+ # you might want to specify some extra behavior here.
+ pass
+
+
+def get_network(arch: str, isTrain=False, continue_train=False, init_gain=0.02, pretrained=True):
+ if "resnet" in arch:
+ from networks.resnet import ResNet
+
+ resnet = getattr(import_module("networks.resnet"), arch)
+ if isTrain:
+ if continue_train:
+ model: ResNet = resnet(num_classes=1)
+ else:
+ model: ResNet = resnet(pretrained=pretrained)
+ model.fc = nn.Linear(2048, 1)
+ nn.init.normal_(model.fc.weight.data, 0.0, init_gain)
+ else:
+ model: ResNet = resnet(num_classes=1)
+ return model
+ else:
+ raise ValueError(f"Unsupported arch: {arch}")
+
+
+def pad_img_to_square(img: np.ndarray):
+ H, W = img.shape[:2]
+ if H != W:
+ new_size = max(H, W)
+ img = np.pad(img, ((0, new_size - H), (0, new_size - W), (0, 0)), mode="constant")
+ assert img.shape[0] == img.shape[1] == new_size
+ return img
diff --git a/AIGVDet/core/utils1/utils1/warmup.py b/AIGVDet/core/utils1/utils1/warmup.py
new file mode 100644
index 0000000000000000000000000000000000000000..c193a6cfa664f6330f79f01d59bbf9f8d43272a9
--- /dev/null
+++ b/AIGVDet/core/utils1/utils1/warmup.py
@@ -0,0 +1,70 @@
+from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
+
+
+class GradualWarmupScheduler(_LRScheduler):
+ """Gradually warm-up(increasing) learning rate in optimizer.
+ Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ multiplier: target learning rate = base_lr * multiplier if multiplier > 1.0; if multiplier == 1.0, the lr starts from 0 and ends at base_lr.
+ total_epoch: the target learning rate is reached gradually at total_epoch.
+ after_scheduler: after total_epoch, hand over to this scheduler (e.g. ReduceLROnPlateau).
+ """
+
+ def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
+ self.multiplier = multiplier
+ if self.multiplier < 1.0:
+ raise ValueError("multiplier should be greater thant or equal to 1.")
+ self.total_epoch = total_epoch
+ self.after_scheduler = after_scheduler
+ self.finished = False
+ super().__init__(optimizer)
+
+ def get_lr(self):
+ if self.last_epoch > self.total_epoch:
+ if self.after_scheduler:
+ if not self.finished:
+ self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
+ self.finished = True
+ return self.after_scheduler.get_last_lr()
+ return [base_lr * self.multiplier for base_lr in self.base_lrs]
+
+ if self.multiplier == 1.0:
+ return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
+ else:
+ return [
+ base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
+ for base_lr in self.base_lrs
+ ]
+
+ def step_ReduceLROnPlateau(self, metrics, epoch=None):
+ if epoch is None:
+ epoch = self.last_epoch + 1
+ self.last_epoch = (
+ epoch if epoch != 0 else 1
+ ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
+ if self.last_epoch <= self.total_epoch:
+ warmup_lr = [
+ base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
+ for base_lr in self.base_lrs
+ ]
+ for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
+ param_group["lr"] = lr
+        else:
+            # epoch was normalized above and can no longer be None here
+            self.after_scheduler.step(metrics, epoch - self.total_epoch)
+
+ def step(self, epoch=None, metrics=None):
+        if not isinstance(self.after_scheduler, ReduceLROnPlateau):
+ if self.finished and self.after_scheduler:
+ if epoch is None:
+ self.after_scheduler.step(None)
+ else:
+ self.after_scheduler.step(epoch - self.total_epoch)
+ else:
+ return super().step(epoch)
+ else:
+ self.step_ReduceLROnPlateau(metrics, epoch)
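+
+
+# Illustrative wiring (a sketch; the optimizer/scheduler choices are hypothetical):
+#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+#   cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
+#   scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=5, after_scheduler=cosine)
+#   for epoch in range(105):
+#       train_one_epoch(...)
+#       scheduler.step()  # lr ramps from 0 to 0.1 over 5 epochs, then follows the cosine decay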
diff --git a/AIGVDet/core/utils1/warmup.py b/AIGVDet/core/utils1/warmup.py
new file mode 100644
index 0000000000000000000000000000000000000000..c193a6cfa664f6330f79f01d59bbf9f8d43272a9
--- /dev/null
+++ b/AIGVDet/core/utils1/warmup.py
@@ -0,0 +1,70 @@
+from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
+
+
+class GradualWarmupScheduler(_LRScheduler):
+ """Gradually warm-up(increasing) learning rate in optimizer.
+ Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
+ total_epoch: target learning rate is reached at total_epoch, gradually
+ after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
+ """
+
+ def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
+ self.multiplier = multiplier
+ if self.multiplier < 1.0:
+ raise ValueError("multiplier should be greater thant or equal to 1.")
+ self.total_epoch = total_epoch
+ self.after_scheduler = after_scheduler
+ self.finished = False
+ super().__init__(optimizer)
+
+ def get_lr(self):
+ if self.last_epoch > self.total_epoch:
+ if self.after_scheduler:
+ if not self.finished:
+ self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
+ self.finished = True
+ return self.after_scheduler.get_last_lr()
+ return [base_lr * self.multiplier for base_lr in self.base_lrs]
+
+ if self.multiplier == 1.0:
+ return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
+ else:
+ return [
+ base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
+ for base_lr in self.base_lrs
+ ]
+
+ def step_ReduceLROnPlateau(self, metrics, epoch=None):
+ if epoch is None:
+ epoch = self.last_epoch + 1
+ self.last_epoch = (
+ epoch if epoch != 0 else 1
+ ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
+ if self.last_epoch <= self.total_epoch:
+ warmup_lr = [
+ base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
+ for base_lr in self.base_lrs
+ ]
+ for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
+ param_group["lr"] = lr
+        else:
+            # epoch was normalized above and can no longer be None here
+            self.after_scheduler.step(metrics, epoch - self.total_epoch)
+
+ def step(self, epoch=None, metrics=None):
+        if not isinstance(self.after_scheduler, ReduceLROnPlateau):
+ if self.finished and self.after_scheduler:
+ if epoch is None:
+ self.after_scheduler.step(None)
+ else:
+ self.after_scheduler.step(epoch - self.total_epoch)
+ else:
+ return super().step(epoch)
+ else:
+ self.step_ReduceLROnPlateau(metrics, epoch)
diff --git a/AIGVDet/docker-compose.yml b/AIGVDet/docker-compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f9f12b97f5214dd87145c42f1f06171963bc25b7
--- /dev/null
+++ b/AIGVDet/docker-compose.yml
@@ -0,0 +1,17 @@
+version: "3.9"
+
+services:
+ web:
+ build: .
+ ports:
+ - "8003:8003"
+ restart: unless-stopped
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: all
+ capabilities: [gpu]
+ ipc: host
+ shm_size: "24gb"
diff --git a/AIGVDet/main.py b/AIGVDet/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..6be77c985928d871c357e089c0423e892854d72f
--- /dev/null
+++ b/AIGVDet/main.py
@@ -0,0 +1,78 @@
+import os
+import json
+import shutil
+from typing import Dict, Optional
+
+from .run import RUN
+
+
+def run_video_to_json(
+ video_path: str,
+ output_json_path: Optional[str] = None,
+ model_optical_path: str = "checkpoints/optical.pth",
+ model_original_path: str = "checkpoints/original.pth",
+ frame_root: str = "frame",
+ optical_root: str = "optical_result"
+) -> Dict:
+ """
+ Xử lý 1 video và ghi kết quả ra file JSON.
+
+ Returns:
+ dict kết quả (đồng thời ghi ra JSON)
+ """
+
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ if not os.path.isabs(model_optical_path):
+ model_optical_path = os.path.join(script_dir, model_optical_path)
+ if not os.path.isabs(model_original_path):
+ model_original_path = os.path.join(script_dir, model_original_path)
+
+ results = {}
+
+ if not os.path.isfile(video_path):
+ raise FileNotFoundError(f"File not found: {video_path}")
+
+ video_name = os.path.basename(video_path)
+ video_id = os.path.splitext(video_name)[0]
+
+ folder_original = os.path.join(frame_root, video_id)
+ folder_optical = os.path.join(optical_root, video_id)
+
+ args = [
+ '--path', video_path,
+ '--folder_original_path', folder_original,
+ '--folder_optical_flow_path', folder_optical,
+ '--model_optical_flow_path', model_optical_path,
+ '--model_original_path', model_original_path
+ ]
+
+ try:
+ output = RUN(args)
+ results[video_id] = {
+ "video_name": video_name,
+ "authentic_confidence_score": round(output["real_score"], 4),
+ "synthetic_confidence_score": round(output["fake_score"], 4)
+ }
+
+ except Exception as e:
+ results[video_id] = {
+ "video_name": video_name,
+ "error": str(e)
+ }
+
+ finally:
+ # Clean up intermediate folders
+ for folder in [folder_original, folder_optical]:
+ try:
+ if os.path.exists(folder):
+ shutil.rmtree(folder)
+ except Exception:
+ pass
+
+ # Save result to JSON or return dict
+ if output_json_path:
+        out_dir = os.path.dirname(output_json_path)
+        if out_dir:  # the output may be a bare filename with no directory component
+            os.makedirs(out_dir, exist_ok=True)
+ with open(output_json_path, "w", encoding="utf-8") as f:
+ json.dump(results, f, ensure_ascii=False, indent=2)
+
+ return results
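+
+
+# Illustrative call (a sketch; the video path is hypothetical):
+#   from AIGVDet import run_video_to_json
+#   results = run_video_to_json("videos/sample.mp4", output_json_path="results/sample.json")
+#   print(results["sample"]["synthetic_confidence_score"])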
diff --git a/AIGVDet/networks/resnet.py b/AIGVDet/networks/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c10a34d937d238e56114dd090051c5abec67ede1
--- /dev/null
+++ b/AIGVDet/networks/resnet.py
@@ -0,0 +1,211 @@
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+
+__all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
+
+
+model_urls = {
+ "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
+ "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
+ "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
+ "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
+ "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super().__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super().__init__()
+ self.conv1 = conv1x1(inplanes, planes)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = conv1x1(planes, planes * self.expansion)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
+ super().__init__()
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = [block(self.inplanes, planes, stride, downsample)]
+ self.inplanes = planes * block.expansion
+ layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
+ return nn.Sequential(*layers)
+
+ def forward(self, x, *args):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+
+def resnet18(pretrained=False, **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]))
+ return model
+
+
+def resnet34(pretrained=False, **kwargs):
+ """Constructs a ResNet-34 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet34"]))
+ return model
+
+
+def resnet50(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
+ return model
+
+
+def resnet101(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]))
+ return model
+
+
+def resnet152(pretrained=False, **kwargs):
+ """Constructs a ResNet-152 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls["resnet152"]))
+ return model
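+
+
+# How these backbones are used by the detector (a minimal sketch):
+#   import torch
+#   model = resnet50(num_classes=1)             # single-logit binary head, as in get_network()
+#   logit = model(torch.zeros(1, 3, 448, 448))  # shape (1, 1)
+#   prob = logit.sigmoid().item()               # probability that the input is fake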
diff --git a/AIGVDet/raft_model/raft-things.pth b/AIGVDet/raft_model/raft-things.pth
new file mode 100644
index 0000000000000000000000000000000000000000..1e206ac8a2f660bc7620b0806a9278ddb3fc594d
--- /dev/null
+++ b/AIGVDet/raft_model/raft-things.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcfa4125d6418f4de95d84aec20a3c5f4e205101715a79f193243c186ac9a7e1
+size 21108000
diff --git a/AIGVDet/requirements.txt b/AIGVDet/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7a14c9e0037d6316fccac6a343123b0f46c2ffc0
--- /dev/null
+++ b/AIGVDet/requirements.txt
@@ -0,0 +1,22 @@
+# conda create -n aigvdet python=3.9
+# pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
+einops
+imageio
+ipympl
+matplotlib
+numpy<2
+opencv-python
+pandas
+scikit-learn
+tensorboard
+tensorboardX
+tqdm
+blobfile>=1.0.5
+natsort
+fastapi==0.116.1
+pydantic==2.11.7
+uvicorn[standard]
+torch==2.0.0+cu117
+torchvision==0.15.1+cu117
+-f https://download.pytorch.org/whl/torch_stable.html
+python-multipart
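+# Install into the environment created above (assumes the pinned CUDA 11.7 wheels are usable):
+#   pip install -r requirements.txt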
diff --git a/AIGVDet/run.py b/AIGVDet/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..5420ff7ca6d1b65f6f56e66366599486c0cd443f
--- /dev/null
+++ b/AIGVDet/run.py
@@ -0,0 +1,214 @@
+import sys
+import argparse
+import os
+import cv2
+import glob
+import numpy as np
+import torch
+from PIL import Image
+import pandas as pd
+import torch.nn
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
+from tqdm import tqdm
+import time
+
+from .core.raft import RAFT
+from .core.utils import flow_viz
+from .core.utils.utils import InputPadder
+from natsort import natsorted
+from .core.utils1.utils import get_network, str2bool, to_cuda
+from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
+
+DEVICE = 'cuda'
+# DEVICE = 'cpu'  # uncomment to run on CPU
+device = torch.device(DEVICE)
+
+
+def load_image(imfile):
+ img = np.array(Image.open(imfile)).astype(np.uint8)
+ img = torch.from_numpy(img).permute(2, 0, 1).float()
+ return img[None].to(DEVICE)
+
+
+def viz(img, flo, folder_optical_flow_path, imfile1):
+    flo = flo[0].permute(1, 2, 0).cpu().numpy()
+
+    # map flow to an RGB image (saved as-is, matching the original pipeline)
+    flo = flow_viz.flow_to_image(flo)
+
+    # extract filename safely (cross-platform)
+    filename = os.path.basename(imfile1).strip()
+    output_path = os.path.join(folder_optical_flow_path, filename)
+
+    print(output_path)
+    cv2.imwrite(output_path, flo)
+
+
+def video_to_frames(video_path, output_folder):
+ if not os.path.exists(output_folder):
+ os.makedirs(output_folder)
+
+ cap = cv2.VideoCapture(video_path)
+ frame_count = 0
+
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if not ret:
+ break
+
+ frame_filename = os.path.join(output_folder, f"frame_{frame_count:05d}.png")
+ cv2.imwrite(frame_filename, frame)
+ frame_count += 1
+
+ cap.release()
+
+ images = glob.glob(os.path.join(output_folder, '*.png')) + \
+ glob.glob(os.path.join(output_folder, '*.jpg'))
+ images = sorted(images)
+
+ return images
+
+
+# generate optical flow images
+def OF_gen(args):
+ model = torch.nn.DataParallel(RAFT(args))
+ model.load_state_dict(torch.load(args.model, map_location=torch.device(DEVICE)))
+
+ model = model.module
+ model.to(DEVICE)
+ model.eval()
+
+ if not os.path.exists(args.folder_optical_flow_path):
+ os.makedirs(args.folder_optical_flow_path)
+ print(f'{args.folder_optical_flow_path}')
+
+ with torch.no_grad():
+
+ images = video_to_frames(args.path, args.folder_original_path)
+ images = natsorted(images)
+
+ for imfile1, imfile2 in zip(images[:-1], images[1:]):
+ image1 = load_image(imfile1)
+ image2 = load_image(imfile2)
+
+ padder = InputPadder(image1.shape)
+ image1, image2 = padder.pad(image1, image2)
+
+ flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
+
+            viz(image1, flow_up, args.folder_optical_flow_path, imfile1)
+
+
+def RUN(args=None):
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ default_model_path = os.path.join(script_dir, "raft_model/raft-things.pth")
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model', help="restore checkpoint",default=default_model_path)
+ parser.add_argument('--path', help="dataset for evaluation",default="video/000000.mp4")
+ parser.add_argument('--folder_original_path', help="dataset for evaluation_frames",default="frame/000000")
+ parser.add_argument('--small', action='store_true', help='use small model')
+ parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
+ parser.add_argument('--folder_optical_flow_path',help="the results to save",default="optical_result/000000")
+ parser.add_argument(
+ "-mop",
+ "--model_optical_flow_path",
+ type=str,
+ default="checkpoints/optical.pth",
+ )
+ parser.add_argument(
+ "-mor",
+ "--model_original_path",
+ type=str,
+ default="checkpoints/original.pth",
+ )
+ parser.add_argument(
+ "-t",
+ "--threshold",
+ type=float,
+ default=0.5,
+ )
+ parser.add_argument("--use_cpu", action="store_true", help="uses gpu by default, turn on to use cpu")
+ parser.add_argument("--arch", type=str, default="resnet50")
+ parser.add_argument("--aug_norm", type=str2bool, default=True)
+ args = parser.parse_args(args)
+ start_time = time.perf_counter()
+ OF_gen(args)
+ elapsed = time.perf_counter() - start_time
+ print(f"⏱️ [OF_gen] Service call took {elapsed:.2f} seconds")
+ # Load models
+ model_op = get_network(args.arch)
+ state_dict = torch.load(args.model_optical_flow_path, map_location='cpu')
+ model_op.load_state_dict(state_dict["model"] if "model" in state_dict else state_dict)
+ model_op.eval().to(device)
+
+ model_or = get_network(args.arch)
+ state_dict = torch.load(args.model_original_path, map_location='cpu')
+ model_or.load_state_dict(state_dict["model"] if "model" in state_dict else state_dict)
+ model_or.eval().to(device)
+
+ # Transform
+ trans = transforms.Compose([
+ transforms.CenterCrop((448, 448)),
+ transforms.ToTensor(),
+ ])
+
+ # Process original frames
+ original_file_list = sorted(
+ glob.glob(os.path.join(args.folder_original_path, "*.jpg")) +
+ glob.glob(os.path.join(args.folder_original_path, "*.png")) +
+ glob.glob(os.path.join(args.folder_original_path, "*.JPEG"))
+ )
+ original_prob_sum = 0
+ for img_path in tqdm(original_file_list, desc="Original", dynamic_ncols=True, disable=len(original_file_list) <= 1):
+ img = Image.open(img_path).convert("RGB")
+ img = trans(img)
+ if args.aug_norm:
+ img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ in_tens = img.unsqueeze(0).to(device)
+ with torch.no_grad():
+ prob = model_or(in_tens).sigmoid().item()
+ original_prob_sum += prob
+
+ original_prob = original_prob_sum / len(original_file_list)
+ print(f"Original prob: {original_prob:.4f}")
+
+ # Process optical flow frames
+ optical_file_list = sorted(
+ glob.glob(os.path.join(args.folder_optical_flow_path, "*.jpg")) +
+ glob.glob(os.path.join(args.folder_optical_flow_path, "*.png")) +
+ glob.glob(os.path.join(args.folder_optical_flow_path, "*.JPEG"))
+ )
+ optical_prob_sum = 0
+ for img_path in tqdm(optical_file_list, desc="Optical Flow", dynamic_ncols=True, disable=len(optical_file_list) <= 1):
+ img = Image.open(img_path).convert("RGB")
+ img = trans(img)
+ if args.aug_norm:
+ img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ in_tens = img.unsqueeze(0).to(device)
+ with torch.no_grad():
+ prob = model_op(in_tens).sigmoid().item()
+ optical_prob_sum += prob
+
+ optical_prob = optical_prob_sum / len(optical_file_list)
+ print(f"Optical prob: {optical_prob:.4f}")
+
+ final_prob = (original_prob + optical_prob) / 2
+ print(f"predict: {final_prob}")
+
+ real_score = 1 - final_prob
+ fake_score = final_prob
+ print(f"Confidence scores - Real: {real_score:.4f}, Fake: {fake_score:.4f}")
+
+ return {
+ "original_prob": original_prob,
+ "optical_prob": optical_prob,
+ "final_predict": final_prob,
+ "real_score": real_score,
+ "fake_score": fake_score
+ }
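+
+
+# Illustrative direct invocation (a sketch; the paths are hypothetical):
+#   out = RUN([
+#       "--path", "video/000000.mp4",
+#       "--folder_original_path", "frame/000000",
+#       "--folder_optical_flow_path", "optical_result/000000",
+#   ])
+#   print(out["final_predict"])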
diff --git a/AIGVDet/run.sh b/AIGVDet/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..01026f77544eca84d10d7840ffc5fd544b769f25
--- /dev/null
+++ b/AIGVDet/run.sh
@@ -0,0 +1 @@
+python demo.py --path "demo_video/real/094.mp4" --folder_original_path "frame/000000" --folder_optical_flow_path "optical_result/000000" -mop "checkpoints/optical.pth" -mor "checkpoints/original.pth"
\ No newline at end of file
diff --git a/AIGVDet/test.py b/AIGVDet/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e25e90b0364ce6bdf9f6e25426df1956c57dd03d
--- /dev/null
+++ b/AIGVDet/test.py
@@ -0,0 +1,261 @@
+import argparse
+import glob
+import os
+import pandas as pd
+
+import torch
+import torch.nn
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
+from PIL import Image
+from tqdm import tqdm
+
+from core.utils1.utils import get_network, str2bool, to_cuda
+from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
+
+if __name__=="__main__":
+
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ "-fop", "--folder_optical_flow_path", default="data/test/T2V/videocraft", type=str, help="path to optical flow imagefile folder"
+ )
+ parser.add_argument(
+ "-for", "--folder_original_path", default="data/test/original/T2V/videocraft", type=str, help="path to RGB image file folder"
+ )
+ parser.add_argument(
+ "-mop",
+ "--model_optical_flow_path",
+ type=str,
+ default="checkpoints/optical.pth",
+ )
+ parser.add_argument(
+ "-mor",
+ "--model_original_path",
+ type=str,
+ default="checkpoints/original.pth",
+ )
+
+ parser.add_argument(
+ "-t",
+ "--threshold",
+ type=float,
+ default=0.5,
+ )
+
+ parser.add_argument(
+ "-e",
+ "--excel_path",
+ type=str,
+ help="path to excel of frames",
+ default="data/results/moonvalley_wang.csv",
+ )
+
+ parser.add_argument(
+ "-ef",
+ "--excel_frame_path",
+ type=str,
+ help="path to excel of frame detection result",
+ default="data/results/frame/moonvalley_wang.csv",
+ )
+
+
+
+
+ parser.add_argument("--use_cpu", action="store_true", help="uses gpu by default, turn on to use cpu")
+ parser.add_argument("--arch", type=str, default="resnet50")
+ parser.add_argument("--aug_norm", type=str2bool, default=True)
+
+ args = parser.parse_args()
+ subfolder_count = 0
+
+ model_op = get_network(args.arch)
+ state_dict = torch.load(args.model_optical_flow_path, map_location="cpu")
+ if "model" in state_dict:
+ state_dict = state_dict["model"]
+ model_op.load_state_dict(state_dict)
+ model_op.eval()
+ if not args.use_cpu:
+ model_op.cuda()
+
+
+ model_or = get_network(args.arch)
+ state_dict = torch.load(args.model_original_path, map_location="cpu")
+ if "model" in state_dict:
+ state_dict = state_dict["model"]
+ model_or.load_state_dict(state_dict)
+ model_or.eval()
+ if not args.use_cpu:
+ model_or.cuda()
+
+
+ trans = transforms.Compose(
+ (
+ transforms.CenterCrop((448,448)),
+ transforms.ToTensor(),
+ )
+ )
+
+ print("*" * 50)
+
+    flag = 0
+    p = 0
+    n = 0
+    tp = 0
+    tn = 0
+    y_true = []
+    y_pred = []
+
+ # create an empty DataFrame
+ df = pd.DataFrame(columns=['name', 'pro','flag','optical_pro','original_pro'])
+ df1 = pd.DataFrame(columns=['original_path', 'original_pro','optical_path','optical_pro','flag'])
+ index1=0
+
+ # Traverse through subfolders in a large folder.
+ for subfolder_name in ["0_real", "1_fake"]:
+ optical_subfolder_path = os.path.join(args.folder_optical_flow_path, subfolder_name)
+ original_subfolder_path = os.path.join(args.folder_original_path, subfolder_name)
+
+ if subfolder_name=="0_real":
+ flag=0
+ else:
+ flag=1
+
+        if not os.path.isdir(optical_subfolder_path):
+            print("Subfolder does not exist.", optical_subfolder_path)
+
+ # Check if the subfolder path exists.
+ if os.path.isdir(original_subfolder_path):
+ print("test subfolder:", subfolder_name)
+
+ # Traverse through sub-subfolders within a subfolder.
+ for subsubfolder_name in os.listdir(original_subfolder_path):
+ original_subsubfolder_path = os.path.join(original_subfolder_path, subsubfolder_name)
+ optical_subsubfolder_path = os.path.join(optical_subfolder_path, subsubfolder_name)
+                if not os.path.isdir(optical_subsubfolder_path):
+                    print("Sub-subfolder does not exist.", optical_subsubfolder_path)
+
+ if os.path.isdir(original_subsubfolder_path):
+ print("test subsubfolder:", subsubfolder_name)
+
+ #Detect original
+ original_file_list = sorted(glob.glob(os.path.join(original_subsubfolder_path, "*.jpg")) + glob.glob(os.path.join(original_subsubfolder_path, "*.png"))+glob.glob(os.path.join(original_subsubfolder_path, "*.JPEG")))
+
+ original_prob_sum=0
+ for img_path in tqdm(original_file_list, dynamic_ncols=True, disable=len(original_file_list) <= 1):
+
+ img = Image.open(img_path).convert("RGB")
+ img = trans(img)
+ if args.aug_norm:
+ img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ in_tens = img.unsqueeze(0)
+ if not args.use_cpu:
+ in_tens = in_tens.cuda()
+
+ with torch.no_grad():
+ prob = model_or(in_tens).sigmoid().item()
+ original_prob_sum+=prob
+
+                        # DataFrame.append was removed in pandas 2.0; use pd.concat instead
+                        df1 = pd.concat(
+                            [df1, pd.DataFrame([{'original_path': img_path, 'original_pro': prob, 'flag': flag}])],
+                            ignore_index=True,
+                        )
+
+
+ original_predict=original_prob_sum/len(original_file_list)
+ print("original prob",original_predict)
+
+ #Detect optical flow
+ optical_file_list = sorted(glob.glob(os.path.join(optical_subsubfolder_path, "*.jpg")) + glob.glob(os.path.join(optical_subsubfolder_path, "*.png"))+glob.glob(os.path.join(optical_subsubfolder_path, "*.JPEG")))
+ optical_prob_sum=0
+                    for img_path in tqdm(optical_file_list, dynamic_ncols=True, disable=len(optical_file_list) <= 1):
+
+ img = Image.open(img_path).convert("RGB")
+ img = trans(img)
+ if args.aug_norm:
+ img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ in_tens = img.unsqueeze(0)
+ if not args.use_cpu:
+ in_tens = in_tens.cuda()
+
+ with torch.no_grad():
+ prob = model_op(in_tens).sigmoid().item()
+ optical_prob_sum+=prob
+
+                        df1.loc[index1, 'optical_path'] = img_path
+                        df1.loc[index1, 'optical_pro'] = prob
+                        index1 = index1 + 1
+                    # optical flow has one frame fewer than the RGB frames, so skip
+                    # the last original-frame row to keep the two columns aligned
+                    index1 = index1 + 1
+
+ optical_predict=optical_prob_sum/len(optical_file_list)
+ print("optical prob",optical_predict)
+
+ predict=original_predict*0.5+optical_predict*0.5
+ print(f"flag:{flag} predict:{predict}")
+                    y_true.append(flag)
+                    y_pred.append(predict)
+                    if flag == 0:
+                        n += 1
+                        if predict < args.threshold:
+                            tn += 1
+                    else:
+                        p += 1
+                        if predict >= args.threshold:
+                            tp += 1
+                    df = pd.concat(
+                        [df, pd.DataFrame([{'name': subsubfolder_name, 'pro': predict, 'flag': flag,
+                                            'optical_pro': optical_predict, 'original_pro': original_predict}])],
+                        ignore_index=True,
+                    )
+ else:
+ print("Subfolder does not exist:", original_subfolder_path)
+ # r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > args.threshold)
+ # f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > args.threshold)
+ # acc = accuracy_score(y_true, y_pred > args.threshold)
+
+ ap = average_precision_score(y_true, y_pred)
+ auc=roc_auc_score(y_true,y_pred)
+ # print(f"r_acc:{r_acc}")
+ print(f"tnr:{tn/n}")
+ # print(f"f_acc:{f_acc}")
+ print(f"tpr:{tp/p}")
+ print(f"acc:{(tp+tn)/(p+n)}")
+ # print(f"acc:{acc}")
+ print(f"ap:{ap}")
+ print(f"auc:{auc}")
+ print(f"p:{p}")
+ print(f"n:{n}")
+ print(f"tp:{tp}")
+ print(f"tn:{tn}")
+
+ # Write the DataFrame to a csv file.
+ csv_filename = args.excel_path
+ csv_folder = os.path.dirname(csv_filename)
+ if not os.path.exists(csv_folder):
+ os.makedirs(csv_folder)
+
+
+ if not os.path.exists(csv_filename):
+ df.to_csv(csv_filename, index=False)
+ else:
+ df.to_csv(csv_filename, mode='a', header=False, index=False)
+ print(f"Results have been saved to {csv_filename}")
+
+ # Write the prediction probabilities of the frame to a CSV file.
+ csv_filename1 = args.excel_frame_path
+ csv_folder1 = os.path.dirname(csv_filename1)
+ if not os.path.exists(csv_folder1):
+ os.makedirs(csv_folder1)
+
+ if not os.path.exists(csv_filename1):
+ df1.to_csv(csv_filename1, index=False)
+ else:
+ df1.to_csv(csv_filename1, mode='a', header=False, index=False)
+
+ # if not os.path.exists(excel_filename):
+ # with pd.ExcelWriter(excel_filename, engine='xlsxwriter') as writer:
+ # df.to_excel(writer, sheet_name='Sheet1', index=False)
+ # else:
+ # with pd.ExcelWriter(excel_filename, mode='a', engine='openpyxl') as writer:
+ # df.to_excel(writer, sheet_name='Sheet1', index=False, startrow=0, header=False)
+ print(f"Results have been saved to {csv_filename1}")
+
+
+
diff --git a/AIGVDet/test.sh b/AIGVDet/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c870a97df5c197b1d53c99a762cdf2e560d1bed1
--- /dev/null
+++ b/AIGVDet/test.sh
@@ -0,0 +1 @@
+python test.py -fop "data/test/T2V/hotshot" -mop "checkpoints/optical_aug.pth" -for "data/test/original/T2V/hotshot" -mor "checkpoints/original_aug.pth" -e "data/results/T2V/hotshot.csv" -ef "data/results/frame/T2V/hotshot.csv" -t 0.5
\ No newline at end of file
diff --git a/AIGVDet/train.py b/AIGVDet/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..4792f2e7d38966e90e3bdfaa46cd45adad085cbe
--- /dev/null
+++ b/AIGVDet/train.py
@@ -0,0 +1,87 @@
+from core.utils1.config import cfg # isort: split
+
+import os
+import time
+
+from tensorboardX import SummaryWriter
+from tqdm import tqdm
+
+from core.utils1.datasets import create_dataloader
+from core.utils1.earlystop import EarlyStopping
+from core.utils1.eval import get_val_cfg, validate
+from core.utils1.trainer import Trainer
+from core.utils1.utils import Logger
+
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+
+
+if __name__ == "__main__":
+ val_cfg = get_val_cfg(cfg, split="val", copy=True)
+ cfg.dataset_root = os.path.join(cfg.dataset_root, "train")
+ data_loader = create_dataloader(cfg)
+ dataset_size = len(data_loader)
+
+ log = Logger()
+ log.open(cfg.logs_path, mode="a")
+ log.write("Num of training images = %d\n" % (dataset_size * cfg.batch_size))
+ log.write("Config:\n" + str(cfg.to_dict()) + "\n")
+
+ train_writer = SummaryWriter(os.path.join(cfg.exp_dir, "train"))
+ val_writer = SummaryWriter(os.path.join(cfg.exp_dir, "val"))
+
+ trainer = Trainer(cfg)
+ early_stopping = EarlyStopping(patience=cfg.earlystop_epoch, delta=-0.001, verbose=True)
+ for epoch in range(cfg.nepoch):
+ epoch_start_time = time.time()
+ iter_data_time = time.time()
+ epoch_iter = 0
+
+ for data in tqdm(data_loader, dynamic_ncols=True):
+ trainer.total_steps += 1
+ epoch_iter += cfg.batch_size
+
+ trainer.set_input(data)
+ trainer.optimize_parameters()
+
+ # if trainer.total_steps % cfg.loss_freq == 0:
+ # log.write(f"Train loss: {trainer.loss} at step: {trainer.total_steps}\n")
+ train_writer.add_scalar("loss", trainer.loss, trainer.total_steps)
+
+ if trainer.total_steps % cfg.save_latest_freq == 0:
+ log.write(
+ "saving the latest model %s (epoch %d, model.total_steps %d)\n"
+ % (cfg.exp_name, epoch, trainer.total_steps)
+ )
+ trainer.save_networks("latest")
+
+ if epoch % cfg.save_epoch_freq == 0:
+ log.write("saving the model at the end of epoch %d, iters %d\n" % (epoch, trainer.total_steps))
+ trainer.save_networks("latest")
+ trainer.save_networks(epoch)
+
+ # Validation
+ trainer.eval()
+ val_results = validate(trainer.model, val_cfg)
+ val_writer.add_scalar("AP", val_results["AP"], trainer.total_steps)
+ val_writer.add_scalar("ACC", val_results["ACC"], trainer.total_steps)
+ # add
+ val_writer.add_scalar("AUC", val_results["AUC"], trainer.total_steps)
+ val_writer.add_scalar("TPR", val_results["TPR"], trainer.total_steps)
+ val_writer.add_scalar("TNR", val_results["TNR"], trainer.total_steps)
+
+ log.write(f"(Val @ epoch {epoch}) AP: {val_results['AP']}; ACC: {val_results['ACC']}\n")
+
+ if cfg.earlystop:
+ early_stopping(val_results["ACC"], trainer)
+ if early_stopping.early_stop:
+ if trainer.adjust_learning_rate():
+ log.write("Learning rate dropped by 10, continue training...\n")
+ early_stopping = EarlyStopping(patience=cfg.earlystop_epoch, delta=-0.002, verbose=True)
+ else:
+ log.write("Early stopping.\n")
+ break
+ if cfg.warmup:
+ # print(trainer.scheduler.get_lr()[0])
+ trainer.scheduler.step()
+ trainer.train()
diff --git a/AIGVDet/train.sh b/AIGVDet/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d8dd0b32979f957c9a79c706b45289adc572aea8
--- /dev/null
+++ b/AIGVDet/train.sh
@@ -0,0 +1,4 @@
+EXP_NAME="moonvalley_vos2_crop"
+DATASETS="moonvalley_vos2_crop"
+DATASETS_TEST="moonvalley_vos2_crop"
+python train.py --gpus 0 --exp_name $EXP_NAME datasets $DATASETS datasets_test $DATASETS_TEST
\ No newline at end of file
diff --git a/api_server.py b/api_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..245d33369bd7147a1e906aafa06615efbd59e2cd
--- /dev/null
+++ b/api_server.py
@@ -0,0 +1,281 @@
+import asyncio
+import os
+import shutil
+import json
+import uuid
+import time
+import threading
+from typing import List, Dict, Any, Optional
+from concurrent.futures import ThreadPoolExecutor
+
+from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel
+
+# Import the analysis libraries (image/text and video pipelines)
+from miragenews import run_multimodal_to_json
+from AIGVDet import run_video_to_json
+
+UPLOAD_DIR = "temp_uploads"
+MAX_WORKERS = 4
+
+app = FastAPI(
+ title="Multimedia Analysis API (Polling Mode)",
+ description="API phân tích đa phương tiện sử dụng cơ chế Polling để tránh Timeout.",
+ version="2.0.0",
+)
+
+jobs: Dict[str, Dict] = {}
+jobs_lock = threading.Lock()
+
+class AnalysisResult(BaseModel):
+ image_analysis_results: Optional[List[Any]] = None
+ video_analysis_result: Optional[Dict[str, Any]] = None
+
+class JobStatus(BaseModel):
+ job_id: str
+ status: str
+ message: Optional[str] = None
+ result: Optional[AnalysisResult] = None
+ created_at: float
+ updated_at: float
+
+def _create_job() -> str:
+ job_id = uuid.uuid4().hex
+ now = time.time()
+ with jobs_lock:
+ jobs[job_id] = {
+ "job_id": job_id,
+ "status": "queued",
+ "message": "Đang chờ xử lý...",
+ "result": None,
+ "created_at": now,
+ "updated_at": now
+ }
+ return job_id
+
+def _update_job(job_id: str, **kwargs):
+ with jobs_lock:
+ if job_id in jobs:
+ jobs[job_id].update(kwargs)
+ jobs[job_id]["updated_at"] = time.time()
+
+def _get_job(job_id: str) -> Optional[Dict]:
+ with jobs_lock:
+ return jobs.get(job_id)
+
+async def run_analysis_logic(
+ image_paths: Optional[List[str]] = None,
+ video_path: Optional[str] = None,
+ text: str = "",
+) -> Dict[str, Any]:
+
+ if not image_paths and not video_path:
+ raise ValueError("Cần cung cấp ít nhất một trong hai: image_paths hoặc video_path.")
+
+ tasks = []
+
+ if image_paths:
+ image_task = asyncio.create_task(
+ run_multimodal_to_json(image_paths=image_paths, text=text, output_json_path=None)
+ )
+ tasks.append(image_task)
+
+ if video_path:
+ video_task = asyncio.to_thread(
+ run_video_to_json, video_path=video_path, output_json_path=None
+ )
+ tasks.append(video_task)
+
+ task_results = await asyncio.gather(*tasks)
+
+ final_result = {"image_analysis_results": [], "video_analysis_result": {}}
+ image_analysis_results = []
+ video_result_index = -1
+
+ current_idx = 0
+ if image_paths:
+ image_analysis_results = task_results[current_idx]
+ current_idx += 1
+
+ if video_path:
+ video_result_index = current_idx
+
+ final_result["image_analysis_results"] = image_analysis_results
+
+ if video_result_index != -1:
+ raw_video_result = task_results[video_result_index]
+ if raw_video_result:
+ video_id_key = list(raw_video_result.keys())[0]
+ video_data = raw_video_result[video_id_key]
+
+ avg_authentic = video_data.get("authentic_confidence_score", 0)
+ avg_synthetic = video_data.get("synthetic_confidence_score", 0)
+
+ if avg_authentic > avg_synthetic and avg_authentic > 0.5:
+ authenticity_assessment = "REAL (Authentic)"
+ verification_tools = "Deepfake Detector"
+ synthetic_type = "N/A"
+ other_artifacts = "Our algorithms conducted a thorough analysis of the video's motion patterns, lighting consistency, and object interactions. We observed fluid, natural movements and consistent physics that align with real-world recordings. No discernible artifacts, such as pixel distortion, unnatural blurring, or shadow inconsistencies, were detected that would indicate digital manipulation or AI-driven synthesis."
+ elif avg_authentic > avg_synthetic and avg_authentic <= 0.5:
+ authenticity_assessment = "Potentially Synthetic"
+ verification_tools = "Deepfake Detector"
+ synthetic_type = "Potentially AI-generated"
+ other_artifacts = "Our analysis has identified subtle anomalies within the video frames, particularly in areas of complex texture and inconsistent lighting across different objects. While these discrepancies are not significant enough to definitively classify the video as synthetic, they do suggest a possibility of digital alteration or partial AI generation. Further examination may be required for a conclusive determination."
+ else:
+ authenticity_assessment = "NOT REAL (Fake, Manipulated, or AI)"
+ verification_tools = "Deepfake Detector"
+ synthetic_type = "AI-generated"
+ other_artifacts = "Our deep analysis detected multiple, significant artifacts commonly associated with synthetic or manipulated media. These include, but are not limited to, unnatural facial expressions and eye movements, inconsistent or floating shadows, logical impossibilities in object interaction, and high-frequency digital noise characteristic of generative models. These factors strongly indicate that the video is not authentic."
+
+ final_result["video_analysis_result"] = {
+ "filename": video_data.get("video_name", ""),
+ "result": {
+ "authenticity_assessment": authenticity_assessment,
+ "verification_tools_methods": verification_tools,
+ "synthetic_type": synthetic_type,
+ "other_artifacts": other_artifacts,
+ },
+ }
+
+ if not final_result.get("image_analysis_results"):
+ final_result.pop("image_analysis_results", None)
+ if not final_result.get("video_analysis_result"):
+ final_result.pop("video_analysis_result", None)
+
+ return final_result
+
+async def process_job_background(
+ job_id: str,
+ temp_dir: str,
+ image_paths: List[str],
+ video_path: str,
+ text: str
+):
+ """Hàm chạy ngầm thực hiện phân tích"""
+ _update_job(job_id, status="running", message="Đang phân tích...")
+
+ try:
+ result_data = await run_analysis_logic(
+ image_paths=image_paths if image_paths else None,
+ video_path=video_path if video_path else None,
+ text=text
+ )
+
+ _update_job(job_id, status="succeeded", result=result_data, message="Hoàn tất")
+
+ except Exception as e:
+ print(f"Error processing job {job_id}: {e}")
+ _update_job(job_id, status="failed", message=str(e))
+
+ finally:
+ try:
+ if os.path.exists(temp_dir):
+ shutil.rmtree(temp_dir)
+ print(f"Deleted temp dir: {temp_dir}")
+ except Exception as cleanup_error:
+ print(f"Cleanup error for {job_id}: {cleanup_error}")
+
+
+@app.on_event("startup")
+def startup_event():
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+
+@app.get("/analyze/{job_id}", response_model=JobStatus)
+async def get_job_status(job_id: str):
+ """Client gọi API này định kỳ để kiểm tra kết quả"""
+ job = _get_job(job_id)
+ if not job:
+ raise HTTPException(status_code=404, detail="Job not found")
+ return job
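+
+# Typical polling flow (illustrative):
+#   1. POST /analyze/video/ with the file  -> {"job_id": ..., "status": "queued", ...}
+#   2. GET /analyze/{job_id} every few seconds until status is "succeeded" or "failed"
+#   e.g. curl http://localhost:8003/analyze/<job_id>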
+
+@app.post("/analyze/image/", response_model=JobStatus)
+async def analyze_image_endpoint(
+ background_tasks: BackgroundTasks,
+ images: List[UploadFile] = File(...),
+ text: Optional[str] = Form(""),
+):
+ job_id = _create_job()
+
+ job_dir = os.path.join(UPLOAD_DIR, job_id)
+ os.makedirs(job_dir, exist_ok=True)
+
+ saved_image_paths = []
+ try:
+ for img in images:
+ file_path = os.path.join(job_dir, img.filename)
+ with open(file_path, "wb") as buffer:
+ shutil.copyfileobj(img.file, buffer)
+ saved_image_paths.append(file_path)
+ except Exception as e:
+ _update_job(job_id, status="failed", message=f"Lỗi upload: {e}")
+ return _get_job(job_id)
+
+ background_tasks.add_task(
+ process_job_background,
+ job_id, job_dir, saved_image_paths, None, text
+ )
+
+ return _get_job(job_id)
+
+@app.post("/analyze/video/", response_model=JobStatus)
+async def analyze_video_endpoint(
+ background_tasks: BackgroundTasks,
+ video: UploadFile = File(...),
+):
+ job_id = _create_job()
+ job_dir = os.path.join(UPLOAD_DIR, job_id)
+ os.makedirs(job_dir, exist_ok=True)
+
+ saved_video_path = os.path.join(job_dir, video.filename)
+ try:
+ with open(saved_video_path, "wb") as buffer:
+ shutil.copyfileobj(video.file, buffer)
+ except Exception as e:
+ _update_job(job_id, status="failed", message=f"Lỗi upload: {e}")
+ return _get_job(job_id)
+
+ background_tasks.add_task(
+ process_job_background,
+ job_id, job_dir, [], saved_video_path, ""
+ )
+
+ return _get_job(job_id)
+
+@app.post("/analyze/multimodal/", response_model=JobStatus)
+async def analyze_multimodal_endpoint(
+ background_tasks: BackgroundTasks,
+ images: List[UploadFile] = File(...),
+ video: UploadFile = File(...),
+ text: Optional[str] = Form(""),
+):
+ job_id = _create_job()
+ job_dir = os.path.join(UPLOAD_DIR, job_id)
+ os.makedirs(job_dir, exist_ok=True)
+
+ saved_image_paths = []
+ saved_video_path = None
+
+ try:
+ # Save Images
+ for img in images:
+ file_path = os.path.join(job_dir, img.filename)
+ with open(file_path, "wb") as buffer:
+ shutil.copyfileobj(img.file, buffer)
+ saved_image_paths.append(file_path)
+
+ # Save Video
+ saved_video_path = os.path.join(job_dir, video.filename)
+ with open(saved_video_path, "wb") as buffer:
+ shutil.copyfileobj(video.file, buffer)
+
+ except Exception as e:
+ _update_job(job_id, status="failed", message=f"Lỗi upload: {e}")
+ return _get_job(job_id)
+
+ background_tasks.add_task(
+ process_job_background,
+ job_id, job_dir, saved_image_paths, saved_video_path, text
+ )
+
+ return _get_job(job_id)
\ No newline at end of file
diff --git a/checkpoints/image/best-mirage-img.pt b/checkpoints/image/best-mirage-img.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5c89ed0b73bb6831ec1b6912ab83b8b91887989e
--- /dev/null
+++ b/checkpoints/image/best-mirage-img.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:755e26bfe97830e03e4c475596ade7168e1df18c57d89162d685b1148ae9f5a8
+size 2375
diff --git a/checkpoints/image/cbm-encoder.pt b/checkpoints/image/cbm-encoder.pt
new file mode 100644
index 0000000000000000000000000000000000000000..4e70876f43511c3ff3767a7653e595990ff84d0b
--- /dev/null
+++ b/checkpoints/image/cbm-encoder.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a90c02450ae4a105414c6701bf571d1a498571083a725d916119d792cbdcf2f
+size 1917947
diff --git a/checkpoints/image/cbm-predictor.pt b/checkpoints/image/cbm-predictor.pt
new file mode 100644
index 0000000000000000000000000000000000000000..82d2e0cd974701eb4d60e8c2c72a3f03ecb461ae
--- /dev/null
+++ b/checkpoints/image/cbm-predictor.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2eba7ff0cd8db1c69961478551b2a40e47687f281d7e853f95a6650908e8c3df
+size 2387
diff --git a/checkpoints/image/img-linear.pt b/checkpoints/image/img-linear.pt
new file mode 100644
index 0000000000000000000000000000000000000000..ffc5e1ac651e2c85711dde93dec50ef53690b290
--- /dev/null
+++ b/checkpoints/image/img-linear.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:457af200c5b62936c0dbcfc9a0e930e0c05ea18bea16cc9d56dd869e445c7833
+size 6839
diff --git a/checkpoints/text/mirage-txt.pt b/checkpoints/text/mirage-txt.pt
new file mode 100644
index 0000000000000000000000000000000000000000..a987a21403c6f14b8ded97e127af8f2b5c0143d7
--- /dev/null
+++ b/checkpoints/text/mirage-txt.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b5adbd2992bec2fabb6770debc54c1aab26102fb1ec2f6480da69c5e11797b3
+size 1267
diff --git a/checkpoints/text/tbm-predictor.pt b/checkpoints/text/tbm-predictor.pt
new file mode 100644
index 0000000000000000000000000000000000000000..f0e4b41bd9a58a90512c3fa5657949e107a8b726
--- /dev/null
+++ b/checkpoints/text/tbm-predictor.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:699d961ef46e6c9f997de6e3a7d3003e9ee32c7ebcc0a9cd77d2cde8e36071b9
+size 1279
diff --git a/checkpoints/text/txt-linear.pt b/checkpoints/text/txt-linear.pt
new file mode 100644
index 0000000000000000000000000000000000000000..6aca03bf3336306ef63e6cca34d93777ebd089c5
--- /dev/null
+++ b/checkpoints/text/txt-linear.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2f11dab8f689ea8052f86489d9b5e821c043d8018b7bac0431e40af79685bcb
+size 4275
diff --git a/configs/image/cbm-encoder.yaml b/configs/image/cbm-encoder.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0c566ffa97961004b179a6eb97e1c5dd265a6fd7
--- /dev/null
+++ b/configs/image/cbm-encoder.yaml
@@ -0,0 +1,8 @@
+model:
+ name: "cbm-encoder"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/image/cbm-encoder.pt"
\ No newline at end of file
diff --git a/configs/image/cbm-predictor.yaml b/configs/image/cbm-predictor.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0ec79b79ae5e17e101630cc4a1eaeb80e6b30ab5
--- /dev/null
+++ b/configs/image/cbm-predictor.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "cbm-predictor"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/image/cbm-predictor.pt"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/train/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/validation/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/validation/fake.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/image/cbm-predictor.jsonl"
\ No newline at end of file
diff --git a/configs/image/linear.yaml b/configs/image/linear.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..140016aae0e48db33c8e03f2746635d18648c96b
--- /dev/null
+++ b/configs/image/linear.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "img-linear"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/image/train/real.pt"
+ fake_pt: "encodings/image/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/image/validation/real.pt"
+ fake_pt: "encodings/image/validation/fake.pt"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/image/img-linear.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/image/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/image/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/image/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/image/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/image/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/image/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/image/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/image/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/image/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/image/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/image/img-linear.jsonl"
diff --git a/configs/image/mirage.yaml b/configs/image/mirage.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c8f59ff54a43a3445c2c61c1045fbbc3edcacf65
--- /dev/null
+++ b/configs/image/mirage.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "mirage-img"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/image/mirage-img.pt"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/image/merged/train/real.pt"
+ fake_pt: "encodings/predictions/image/merged/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/image/merged/validation/real.pt"
+ fake_pt: "encodings/predictions/image/merged/validation/fake.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/predictions/image/merged/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/predictions/image/merged/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/predictions/image/merged/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/predictions/image/merged/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/predictions/image/merged/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/image/mirage-img.jsonl"
diff --git a/configs/multimodal/mirage.yaml b/configs/multimodal/mirage.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d0a92529d3bb8f53f5fdc89930f3ef36993e9fc9
--- /dev/null
+++ b/configs/multimodal/mirage.yaml
@@ -0,0 +1,74 @@
+model:
+ name: "mirage"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ image_model_path: "checkpoints/image/best-mirage-img.pt"
+ text_model_path: "checkpoints/text/best-mirage-txt.pt"
+
+train_dataset:
+ name: "multimodal"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/train/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/train/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/train/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/train/fake.pt"
+
+val_dataset:
+ name: "multimodal"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/validation/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/validation/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/validation/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/validation/fake.pt"
+
+test1_dataset:
+ name: "multimodal"
+ test_name: "test1_nyt_mj"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test1_nyt_mj/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test1_nyt_mj/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test1_nyt_mj/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "multimodal"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test2_bbc_dalle/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test2_bbc_dalle/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test2_bbc_dalle/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "multimodal"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test3_cnn_dalle/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test3_cnn_dalle/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test3_cnn_dalle/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "multimodal"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test4_bbc_sdxl/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test4_bbc_sdxl/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test4_bbc_sdxl/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "multimodal"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test5_cnn_sdxl/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test5_cnn_sdxl/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test5_cnn_sdxl/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/multimodal/mirage.jsonl"
\ No newline at end of file
diff --git a/configs/text/linear.yaml b/configs/text/linear.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..368ab3f136c6ddae14fc33e3d2771b5b9acaa87a
--- /dev/null
+++ b/configs/text/linear.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "txt-linear"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/text/train/real.pt"
+ fake_pt: "encodings/text/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/text/validation/real.pt"
+ fake_pt: "encodings/text/validation/fake.pt"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/text/txt-linear.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/text/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/text/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/text/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/text/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/text/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/text/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/text/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/text/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/text/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/text/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/text/txt-linear.jsonl"
diff --git a/configs/text/mirage.yaml b/configs/text/mirage.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..791963dd2591b81d62186a8c2dbfb91b2ff4fe5b
--- /dev/null
+++ b/configs/text/mirage.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "mirage-txt"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/text/mirage-txt.pt"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/text/merged/train/real.pt"
+ fake_pt: "encodings/predictions/text/merged/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/text/merged/validation/real.pt"
+ fake_pt: "encodings/predictions/text/merged/validation/fake.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/predictions/text/merged/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/predictions/text/merged/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/predictions/text/merged/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/predictions/text/merged/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/predictions/text/merged/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/text/mirage-txt.jsonl"
\ No newline at end of file
diff --git a/configs/text/tbm-predictor.yaml b/configs/text/tbm-predictor.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9e1bdc878d5438074a7e924d69798e0e784c8b69
--- /dev/null
+++ b/configs/text/tbm-predictor.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "tbm-predictor"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/text/tbm-predictor.pt"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/train/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/validation/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/validation/fake.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/text/tbm-predictor.jsonl"
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..1200e8c626b0d59ef333ab944d03c34ae706553a
--- /dev/null
+++ b/main.py
@@ -0,0 +1,30 @@
+import asyncio
+
+from miragenews import run_multimodal_to_json
+from AIGVDet import run_video_to_json
+
+
+async def main():
+    # Multimodal (image + text) analysis via the miragenews pipeline
+    image_paths = [
+        "data/1.png",
+        "data/2.png"
+    ]
+
+    text = "This content looks AI-generated"
+
+    output = await run_multimodal_to_json(
+        image_paths=image_paths,
+        text=text,
+        output_json_path="result.json"
+    )
+
+    print("Saved result to:", output)
+
+
+if __name__ == "__main__":
+    # asyncio.run() belongs under the main guard so importing this
+    # module has no side effects.
+    asyncio.run(main())
+
+    # Video analysis via the AIGVDet pipeline
+    result = run_video_to_json(
+        video_path="data/test_video.mp4",
+        output_json_path="results/video_result.json"
+    )
+
+    print(result)
\ No newline at end of file
diff --git "a/miragenews/.gradio/certificate.pem\357\200\272Zone.Identifier" "b/miragenews/.gradio/certificate.pem\357\200\272Zone.Identifier"
new file mode 100644
index 0000000000000000000000000000000000000000..d6c1ec682968c796b9f5e9e080cc6f674b57c766
Binary files /dev/null and "b/miragenews/.gradio/certificate.pem\357\200\272Zone.Identifier" differ
diff --git a/miragenews/README.md b/miragenews/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4a5fea17cb42a906b4d23fdd9bf10ee825ee61a9
--- /dev/null
+++ b/miragenews/README.md
@@ -0,0 +1,119 @@
+# MiRAGeNews: Multimodal Realistic AI-Generated News Detection
+
+## Abstract
+The proliferation of inflammatory or misleading "fake" news content has become increasingly common in recent years. Simultaneously, it has become easier than ever to use AI tools to generate photorealistic images depicting any scene imaginable. Combining these two -- AI-generated fake news content -- is particularly potent and dangerous. To combat the spread of AI-generated fake news, we propose the MiRAGeNews Dataset, a dataset of 12,500 high-quality real and AI-generated image-caption pairs from state-of-the-art generators. We find that our dataset poses a significant challenge to humans (60% F-1) and state-of-the-art multi-modal LLMs (< 24% F-1). Using our dataset we train a multi-modal detector (MiRAGe) that improves by +5.1% F-1 over state-of-the-art baselines on image-caption pairs from out-of-domain image generators and news publishers. We release our code and data to aid future work on detecting AI-generated content.
+
+## ***MiRAGeNews*** Dataset
+The ***MiRAGeNews*** dataset contains a total of 15,000 pieces of real or AI-generated multimodal news (image-caption pairs) -- a training set of 10,000 pairs, a validation set of 2,500 pairs, and five test sets of 500 pairs each. Four of the test sets are out-of-domain data from unseen news publishers and image generators, used to evaluate the detectors' generalization ability.
+
+Download ***MiRAGeNews*** from [HuggingFace](https://huggingface.co/datasets/anson-huang/mirage-news):
+```py
+from datasets import load_dataset
+
+dataset = load_dataset("anson-huang/mirage-news")
+```
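+
+The split names used throughout this repo come straight from the dataset (label convention as in ```data/dataset.py```: 0 = real, 1 = fake), so a quick sanity check looks like:
+
+```py
+from datasets import load_dataset
+
+dataset = load_dataset("anson-huang/mirage-news")
+print(dataset.keys())   # train, validation, test1_nyt_mj, ..., test5_cnn_sdxl
+sample = dataset["train"][0]
+print(sample["label"])  # 0 for real news, 1 for AI-generated
+```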
+
+To train the ***MiRAGe*** detectors on the ***MiRAGeNews*** dataset, we first need to encode both the images and the text from the dataset:
+
+```
+$ python data/encode_image.py
+$ python data/encode_crops.py
+$ python data/encode_text.py
+```
+You can use the ```--custom``` and ```--read_dirs``` flags to encode other datasets, as in the example below.
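+
+For example, a local dataset laid out as ```my_dataset/image/<dir>/real``` and ```my_dataset/image/<dir>/fake``` (the layout assumed by ```data/encode_image.py```; the directory name here is hypothetical) could be encoded with:
+
+```
+$ python data/encode_image.py --custom --read_dirs my_test_set
+```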
+
+## MiRAGe Detectors
+There are three detectors: **MiRAGe-Img** for image-only detection, **MiRAGe-Txt** for text-only detection, and **MiRAGe** for multimodal detection. The single-modal detectors are trained on predictions from a linear model and a concept bottleneck model (CBM). The multimodal detector runs inference directly on the predictions from **MiRAGe-Img** and **MiRAGe-Txt**, without further training. All of the pretrained checkpoints are in ```checkpoints/```.
+
+
+### ***MiRAGe-Img***
+
+#### Training
+To train ***MiRAGe-Img***, we first need to train a linear model and a concept bottleneck model (CBM) to get their predictions:
+```
+$ python train.py --mode image --model_class linear
+$ python train.py --mode image --model_class cbm-encoder
+$ python train.py --mode image --model_class cbm-predictor
+```
+Note that the CBM encoder encodes each image into a concept vector (D=300) based on its object-class crops, and the CBM predictor outputs a real/fake prediction from the concept vectors.
+
+Then, we encode the predictions from the linear model and merge them with the concept vectors to obtain the D=301 vectors before training ***MiRAGe-Img*** (see the sketch after these commands).
+```
+$ python data/encode_predictions.py --mode image --model_class linear
+$ python data/encode_predictions.py --mode image --model_class merged
+$ python train.py --mode image --model_class mirage
+```
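+
+A minimal sketch of the ```merged``` step (mirroring the merge branch in ```data/encode_predictions.py```): the CBM concept logits are concatenated with the linear model's logit along the feature dimension.
+
+```py
+import torch
+
+# Training-split "real" files shown as an example; shapes follow the repo's convention.
+cbm = torch.load("encodings/predictions/image/cbm-encoder/train/real.pt")  # (N, 300)
+linear = torch.load("encodings/predictions/image/linear/train/real.pt")   # (N, 1)
+merged = torch.cat((cbm, linear), dim=1)                                  # (N, 301)
+torch.save(merged, "encodings/predictions/image/merged/train/real.pt")
+```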
+
+#### Testing
+To test ***MiRAGe-Img***, run
+```
+$ python test.py --mode image --model_class mirage
+```
+Modify ```--model_class``` to test any of the subcomponents. The results for all five test sets will be saved to the corresponding JSONL files in ```results/image/```.
+
+
+
+### ***MiRAGe-Txt***
+
+#### Training
+To train ***MiRAGe-Txt***, we first need to train a linear model:
+```
+$ python train.py --mode text --model_class linear
+```
+We provide the concept vectors (D=18) from the text bottleneck model (TBM) in ```encodings/predictions/text/tbm-encoder```, since generating them requires access to the OpenAI API. See the TBM documentation for details.
+
+Optionally, you can train a TBM predictor using only the TBM concepts:
+```
+$ python train.py --mode text --model_class tbm-predictor
+```
+
+Then, we encode the predictions from the linear model and merge them with the TBM concept vectors to obtain the D=19 vectors before training ***MiRAGe-Txt*** (the merge mirrors the image pipeline sketched above).
+```
+$ python data/encode_predictions.py --mode text --model_class linear
+$ python data/encode_predictions.py --mode text --model_class merged
+$ python train.py --mode text --model_class mirage
+```
+#### Testing
+To test ***MiRAGe-Txt***, run
+```
+$ python test.py --mode text --model_class mirage
+```
+Modify ```--model_class``` to test any of the subcomponents. The results for all five test sets will be saved to the corresponding JSONL files in ```results/text/```.
+
+
+
+### ***MiRAGe***
+The ***MiRAGe*** detector uses the trained ***MiRAGe-Img*** and ***MiRAGe-Txt*** models and needs no further training. To test ***MiRAGe***, run:
+```
+$ python test.py --mode multimodal --model_class mirage
+```
+The results for all five test sets will be saved to the corresponding JSONL files in ```results/multimodal/```.
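+
+The exact fusion rule for the two single-modal scores lives in the model code, which this README does not show. Purely as an illustrative sketch (the simple averaging below is an assumption, not the repo's actual rule), a training-free fusion could look like:
+
+```py
+import torch
+
+def combine_predictions(p_img: torch.Tensor, p_txt: torch.Tensor) -> torch.Tensor:
+    """p_img, p_txt: fake-probabilities in [0, 1] from MiRAGe-Img / MiRAGe-Txt."""
+    return (p_img + p_txt) / 2  # assumption: mean of the two modality scores
+
+p_fake = combine_predictions(torch.tensor([0.9]), torch.tensor([0.7]))
+print(bool(p_fake > 0.5))  # True -> classified as AI-generated
+```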
+
+
+Our detectors are more robust than state-of-the-art MLLMs and prior detectors on out-of-domain (OOD) data from unseen news publishers and image generators.
+
+## Acknowledgement
+This research is supported in part by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA), via the HIATUS Program contract #2022-22072200005. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies, either expressed or implied, of ODNI, IARPA, or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for governmental purposes notwithstanding any copyright annotation therein.
diff --git a/miragenews/__init__.py b/miragenews/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3800de4105d6ddcb74890007a208becc142efd80
--- /dev/null
+++ b/miragenews/__init__.py
@@ -0,0 +1,3 @@
+from .main_exe import run_multimodal_to_json
+
+__all__ = ["run_multimodal_to_json"]
diff --git a/miragenews/analyzer.py b/miragenews/analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/miragenews/app_text.py b/miragenews/app_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..7546bb45362f8c2aad7c30060a84dc60a761e24a
--- /dev/null
+++ b/miragenews/app_text.py
@@ -0,0 +1,143 @@
+import gradio as gr
+import asyncio
+from dotenv import load_dotenv
+
+from img.resources import resources
+from img.file_utils import handle_image_list_change
+from img.core import analyze_saved_images
+
+from text_module.pipeline import verify_text_logic
+from text_module.TextAnalysisResult import TextAnalysisResult
+
+from merge_img_text import verify_multimodal_logic
+
+# --- INITIALIZATION ---
+load_dotenv()
+print("⏳ Loading Image Models...")
+resources.load_all()
+print("✅ Ready.")
+
+def text_tab_wrapper(text_input):
+ if not text_input:
+ return "Waiting...", "⚠️ Please enter text.", ""
+
+ res_obj = TextAnalysisResult()
+
+ verify_text_logic(text_input, res_obj)
+
+ auth = res_obj.get_authenticity_assessment()
+
+ report = f"""
+### 📄 Text Analysis Report
+**Assessment:** {auth}
+**Tools:** {res_obj.get_verification_tools_methods()}
+**Type:** {res_obj.get_synthetic_type()}
+
+#### Details:
+{res_obj.get_other_artifacts()}
+ """
+
+ log = "Analysis Finished."
+
+ return log, auth, report
+
+custom_css = """
+#progress_img, #progress_txt, #final_report {
+ opacity: 1 !important;
+ filter: none !important;
+ transition: none !important;
+}
+.generating {
+ border: none !important;
+}
+"""
+
+# --- GRADIO INTERFACE ---
+with gr.Blocks(theme=gr.themes.Soft(), title="Multi-Modal Forensic Tool", css=custom_css) as iface:
+
+ gr.Markdown("# 🕵️ AI Forensic & Fact-Checking Tool")
+
+ with gr.Tabs():
+
+ with gr.TabItem("🖼️ Image Forensics"):
+ saved_state_img = gr.State([])
+ hidden_status = gr.Textbox(visible=False)
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ img_input = gr.File(file_count="multiple", label="Upload Images", file_types=["image"])
+ analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
+
+ with gr.Column(scale=2):
+ output_json = gr.JSON(label="Analysis Results (JSON)")
+ output_md = gr.Markdown(label="Formatted Report")
+
+ # Events Tab 1
+ img_input.change(
+ fn=handle_image_list_change,
+ inputs=[img_input, saved_state_img],
+ outputs=[saved_state_img, hidden_status],
+ queue=False
+ )
+ analyze_btn.click(
+ fn=analyze_saved_images,
+ inputs=[saved_state_img],
+ outputs=[output_json, output_md]
+ )
+
+ with gr.TabItem("📄 Text Verification"):
+ with gr.Row():
+ with gr.Column(scale=1):
+ txt_input = gr.Textbox(lines=6, label="Input Text", placeholder="Enter text to verify...")
+ txt_btn = gr.Button("🕵️ Verify Text", variant="primary")
+
+ with gr.Column(scale=2):
+ txt_log = gr.Textbox(label="Logs", lines=2, interactive=False)
+ txt_status = gr.Label(label="Status")
+ txt_report = gr.Markdown(label="Detailed Report")
+
+ # Events Tab 2
+ txt_btn.click(
+ fn=text_tab_wrapper,
+ inputs=[txt_input],
+ outputs=[txt_log, txt_status, txt_report]
+ )
+
+
+ with gr.TabItem("⚖️ Multi-modal Verification"):
+ gr.Markdown("### Cross-verify Image & Text Context")
+
+ saved_state_multi = gr.State([])
+ hidden_status_multi = gr.Textbox(visible=False)
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ mm_img_input = gr.File(label="Upload Related Images", file_count="multiple", file_types=["image"])
+ mm_txt_input = gr.Textbox(label="Context Text", lines=5, placeholder="Paste text here...")
+ mm_btn = gr.Button("🚀 Start Verification", variant="primary")
+
+ with gr.Column(scale=2):
+ gr.Markdown("#### Analysis Progress")
+
+                    mm_prog_img = gr.HTML(elem_id="progress_img", label="Image Status", value='💤 Waiting...')
+                    mm_prog_txt = gr.HTML(elem_id="progress_txt", label="Text Status", value='💤 Waiting...')
+
+ gr.Markdown("---")
+ mm_report = gr.Markdown(elem_id="final_report", label="Final Report")
+
+ mm_img_input.change(
+ fn=handle_image_list_change,
+ inputs=[mm_img_input, saved_state_multi],
+ outputs=[saved_state_multi, hidden_status_multi],
+ queue=False
+ )
+
+ mm_btn.click(
+ fn=verify_multimodal_logic,
+ inputs=[saved_state_multi, mm_txt_input],
+ outputs=[mm_prog_img, mm_prog_txt, mm_report],
+ show_progress="hidden"
+ )
+
+if __name__ == "__main__":
+ iface.queue().launch(share=True)
\ No newline at end of file
diff --git a/miragenews/checkpoints/image/best-mirage-img.pt b/miragenews/checkpoints/image/best-mirage-img.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5c89ed0b73bb6831ec1b6912ab83b8b91887989e
--- /dev/null
+++ b/miragenews/checkpoints/image/best-mirage-img.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:755e26bfe97830e03e4c475596ade7168e1df18c57d89162d685b1148ae9f5a8
+size 2375
diff --git a/miragenews/checkpoints/image/cbm-encoder.pt b/miragenews/checkpoints/image/cbm-encoder.pt
new file mode 100644
index 0000000000000000000000000000000000000000..4e70876f43511c3ff3767a7653e595990ff84d0b
--- /dev/null
+++ b/miragenews/checkpoints/image/cbm-encoder.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a90c02450ae4a105414c6701bf571d1a498571083a725d916119d792cbdcf2f
+size 1917947
diff --git a/miragenews/checkpoints/image/cbm-predictor.pt b/miragenews/checkpoints/image/cbm-predictor.pt
new file mode 100644
index 0000000000000000000000000000000000000000..82d2e0cd974701eb4d60e8c2c72a3f03ecb461ae
--- /dev/null
+++ b/miragenews/checkpoints/image/cbm-predictor.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2eba7ff0cd8db1c69961478551b2a40e47687f281d7e853f95a6650908e8c3df
+size 2387
diff --git a/miragenews/checkpoints/image/img-linear.pt b/miragenews/checkpoints/image/img-linear.pt
new file mode 100644
index 0000000000000000000000000000000000000000..ffc5e1ac651e2c85711dde93dec50ef53690b290
--- /dev/null
+++ b/miragenews/checkpoints/image/img-linear.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:457af200c5b62936c0dbcfc9a0e930e0c05ea18bea16cc9d56dd869e445c7833
+size 6839
diff --git a/miragenews/checkpoints/text/mirage-txt.pt b/miragenews/checkpoints/text/mirage-txt.pt
new file mode 100644
index 0000000000000000000000000000000000000000..a987a21403c6f14b8ded97e127af8f2b5c0143d7
--- /dev/null
+++ b/miragenews/checkpoints/text/mirage-txt.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b5adbd2992bec2fabb6770debc54c1aab26102fb1ec2f6480da69c5e11797b3
+size 1267
diff --git a/miragenews/checkpoints/text/tbm-predictor.pt b/miragenews/checkpoints/text/tbm-predictor.pt
new file mode 100644
index 0000000000000000000000000000000000000000..f0e4b41bd9a58a90512c3fa5657949e107a8b726
--- /dev/null
+++ b/miragenews/checkpoints/text/tbm-predictor.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:699d961ef46e6c9f997de6e3a7d3003e9ee32c7ebcc0a9cd77d2cde8e36071b9
+size 1279
diff --git a/miragenews/checkpoints/text/txt-linear.pt b/miragenews/checkpoints/text/txt-linear.pt
new file mode 100644
index 0000000000000000000000000000000000000000..6aca03bf3336306ef63e6cca34d93777ebd089c5
--- /dev/null
+++ b/miragenews/checkpoints/text/txt-linear.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2f11dab8f689ea8052f86489d9b5e821c043d8018b7bac0431e40af79685bcb
+size 4275
diff --git a/miragenews/configs/image/cbm-encoder.yaml b/miragenews/configs/image/cbm-encoder.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0c566ffa97961004b179a6eb97e1c5dd265a6fd7
--- /dev/null
+++ b/miragenews/configs/image/cbm-encoder.yaml
@@ -0,0 +1,8 @@
+model:
+ name: "cbm-encoder"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/image/cbm-encoder.pt"
\ No newline at end of file
diff --git a/miragenews/configs/image/cbm-predictor.yaml b/miragenews/configs/image/cbm-predictor.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0ec79b79ae5e17e101630cc4a1eaeb80e6b30ab5
--- /dev/null
+++ b/miragenews/configs/image/cbm-predictor.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "cbm-predictor"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/image/cbm-predictor.pt"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/train/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/validation/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/validation/fake.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/predictions/image/cbm-encoder/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/predictions/image/cbm-encoder/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/image/cbm-predictor.jsonl"
\ No newline at end of file
diff --git a/miragenews/configs/image/linear.yaml b/miragenews/configs/image/linear.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..140016aae0e48db33c8e03f2746635d18648c96b
--- /dev/null
+++ b/miragenews/configs/image/linear.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "img-linear"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/image/train/real.pt"
+ fake_pt: "encodings/image/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/image/validation/real.pt"
+ fake_pt: "encodings/image/validation/fake.pt"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/image/img-linear.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/image/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/image/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/image/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/image/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/image/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/image/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/image/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/image/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/image/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/image/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/image/img-linear.jsonl"
diff --git a/miragenews/configs/image/mirage.yaml b/miragenews/configs/image/mirage.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c8f59ff54a43a3445c2c61c1045fbbc3edcacf65
--- /dev/null
+++ b/miragenews/configs/image/mirage.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "mirage-img"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/image/mirage-img.pt"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/image/merged/train/real.pt"
+ fake_pt: "encodings/predictions/image/merged/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/image/merged/validation/real.pt"
+ fake_pt: "encodings/predictions/image/merged/validation/fake.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/predictions/image/merged/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/predictions/image/merged/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/predictions/image/merged/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/predictions/image/merged/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/predictions/image/merged/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/predictions/image/merged/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/image/mirage-img.jsonl"
diff --git a/miragenews/configs/multimodal/mirage.yaml b/miragenews/configs/multimodal/mirage.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d0a92529d3bb8f53f5fdc89930f3ef36993e9fc9
--- /dev/null
+++ b/miragenews/configs/multimodal/mirage.yaml
@@ -0,0 +1,74 @@
+model:
+ name: "mirage"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ image_model_path: "checkpoints/image/best-mirage-img.pt"
+ text_model_path: "checkpoints/text/best-mirage-txt.pt"
+
+train_dataset:
+ name: "multimodal"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/train/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/train/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/train/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/train/fake.pt"
+
+val_dataset:
+ name: "multimodal"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/validation/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/validation/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/validation/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/validation/fake.pt"
+
+test1_dataset:
+ name: "multimodal"
+ test_name: "test1_nyt_mj"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test1_nyt_mj/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test1_nyt_mj/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test1_nyt_mj/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "multimodal"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test2_bbc_dalle/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test2_bbc_dalle/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test2_bbc_dalle/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "multimodal"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test3_cnn_dalle/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test3_cnn_dalle/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test3_cnn_dalle/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "multimodal"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test4_bbc_sdxl/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test4_bbc_sdxl/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test4_bbc_sdxl/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "multimodal"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_img_pt: "encodings/predictions/image/merged/test5_cnn_sdxl/real.pt"
+ fake_img_pt: "encodings/predictions/image/merged/test5_cnn_sdxl/fake.pt"
+ real_text_pt: "encodings/predictions/text/merged/test5_cnn_sdxl/real.pt"
+ fake_text_pt: "encodings/predictions/text/merged/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/multimodal/mirage.jsonl"
\ No newline at end of file
diff --git a/miragenews/configs/text/linear.yaml b/miragenews/configs/text/linear.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..368ab3f136c6ddae14fc33e3d2771b5b9acaa87a
--- /dev/null
+++ b/miragenews/configs/text/linear.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "txt-linear"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/text/train/real.pt"
+ fake_pt: "encodings/text/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/text/validation/real.pt"
+ fake_pt: "encodings/text/validation/fake.pt"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/text/txt-linear.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/text/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/text/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/text/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/text/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/text/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/text/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/text/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/text/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/text/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/text/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/text/txt-linear.jsonl"
diff --git a/miragenews/configs/text/mirage.yaml b/miragenews/configs/text/mirage.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..791963dd2591b81d62186a8c2dbfb91b2ff4fe5b
--- /dev/null
+++ b/miragenews/configs/text/mirage.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "mirage-txt"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/text/mirage-txt.pt"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/text/merged/train/real.pt"
+ fake_pt: "encodings/predictions/text/merged/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/text/merged/validation/real.pt"
+ fake_pt: "encodings/predictions/text/merged/validation/fake.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/predictions/text/merged/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/predictions/text/merged/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/predictions/text/merged/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/predictions/text/merged/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/predictions/text/merged/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/predictions/text/merged/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/text/mirage-txt.jsonl"
\ No newline at end of file
diff --git a/miragenews/configs/text/tbm-predictor.yaml b/miragenews/configs/text/tbm-predictor.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9e1bdc878d5438074a7e924d69798e0e784c8b69
--- /dev/null
+++ b/miragenews/configs/text/tbm-predictor.yaml
@@ -0,0 +1,59 @@
+model:
+ name: "tbm-predictor"
+
+training:
+ epochs: 100
+ batch_size: 64
+ learning_rate: 0.001
+ save_path: "checkpoints/text/tbm-predictor.pt"
+
+train_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/train/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/train/fake.pt"
+
+val_dataset:
+ name: "img-or-text"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/validation/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/validation/fake.pt"
+
+test1_dataset:
+ name: "img-or-text"
+ test_name: "test1_nyt_mj"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test1_nyt_mj/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test1_nyt_mj/fake.pt"
+
+test2_dataset:
+ name: "img-or-text"
+ test_name: "test2_bbc_dalle"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test2_bbc_dalle/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test2_bbc_dalle/fake.pt"
+
+test3_dataset:
+ name: "img-or-text"
+ test_name: "test3_cnn_dalle"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test3_cnn_dalle/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test3_cnn_dalle/fake.pt"
+
+test4_dataset:
+ name: "img-or-text"
+ test_name: "test4_bbc_sdxl"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test4_bbc_sdxl/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test4_bbc_sdxl/fake.pt"
+
+test5_dataset:
+ name: "img-or-text"
+ test_name: "test5_cnn_sdxl"
+ params:
+ real_pt: "encodings/predictions/text/tbm-encoder/test5_cnn_sdxl/real.pt"
+ fake_pt: "encodings/predictions/text/tbm-encoder/test5_cnn_sdxl/fake.pt"
+
+testing:
+ batch_size: 64
+ save_path: "results/text/tbm-predictor.jsonl"
\ No newline at end of file
diff --git a/miragenews/data/__init__.py b/miragenews/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28a0879bc8c16929c16c504572ce41736562fe7a
--- /dev/null
+++ b/miragenews/data/__init__.py
@@ -0,0 +1,114 @@
+from .dataset import *
+import yaml
+from PIL import Image
+import numpy as np
+from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+
+def load_config(config_path):
+ with open(config_path, "r") as file:
+ config = yaml.safe_load(file)
+ return config
+
+def get_dataset(config, is_eval=False, test_set=None):
+ """
+ Retrieves the dataset class specified in the config and initializes it with provided parameters.
+
+ Args:
+ config (dict): Configuration dictionary loaded from a YAML file.
+ is_eval (bool): Flag indicating whether to load an evaluation (validation) dataset.
+ test_set (str, optional): Specifies the test dataset to load. Use 'test1', 'test2', etc.,
+ for individual test sets, 'all' for all test sets, or None for train/validation.
+
+ Returns:
+ Tuple[Dataset, str] or List[Tuple[Dataset, str]]: Returns (dataset instance, test_name) for a single dataset,
+ or a list of (dataset instance, test_name) tuples for all test sets.
+ """
+ # Mapping dataset names to their classes
+ dataset_classes = {
+ "img-or-text": MiRAGeImageOrTextDataset,
+ "multimodal": MiRAGeNewsDataset,
+ # Add other dataset classes here if needed
+ }
+
+ # Handle loading of all test datasets if specified
+ if test_set == "all":
+ return [
+ (dataset_classes[config[key]['name']](**config[key]['params']), config[key]['test_name'])
+ for key in config if key.startswith("test") and key != "testing" and config[key]['name'] in dataset_classes
+ ]
+
+ # Determine the appropriate dataset section in config
+ dataset_key = test_set if test_set else ('val_dataset' if is_eval else 'train_dataset')
+
+ # Retrieve dataset name, parameters, and test name (if applicable)
+ dataset_name = config[dataset_key]['name']
+ dataset_params = config[dataset_key]['params']
+ test_name = config[dataset_key].get('test_name', '')
+
+ # Initialize and return the dataset and test name
+ if dataset_name in dataset_classes:
+ dataset_class = dataset_classes[dataset_name]
+ return dataset_class(**dataset_params), test_name
+ else:
+ raise ValueError(f"Dataset '{dataset_name}' not recognized. Please check the config file.")
+
+
+def get_object_class(class_file_path='data/class_names.txt', limit=300):
+ """
+ Loads object class names from a text file.
+
+ Args:
+ class_file_path (str): Path to the text file containing class names, with each class name on a separate line.
+ limit (int, optional): The maximum number of classes to load. If None, loads all classes.
+
+ Returns:
+ List[str]: A list of class names.
+ """
+ classes = []
+ with open(class_file_path, 'r', encoding='utf-8') as file:
+ classes = [line.strip() for line in file][:limit]
+ return classes
+
+def get_object_class_caption(class_file_path='data/class_names.txt', limit=300):
+ """
+ Loads object class captions from a text file.
+
+ Args:
+ class_file_path (str): Path to the text file containing class names, with each class name on a separate line.
+ limit (int, optional): The maximum number of classes to load. If None, loads all classes.
+
+ Returns:
+ List[str]: A list of class captions.
+ """
+ classes = []
+ with open(class_file_path, 'r', encoding='utf-8') as file:
+ classes = [f'a photo of {line.strip().lower()}' for line in file][:limit]
+ return classes
+
+
+def get_object_class_config(class_name):
+ config = {
+ 'train_dataset': {
+ 'name': 'img-or-text',
+ 'params': {
+ 'real_pt': f"encodings/crops/{class_name}/train/real.pt",
+ 'fake_pt': f"encodings/crops/{class_name}/train/fake.pt",
+ }
+ },
+ 'val_dataset': {
+ 'name': 'img-or-text',
+ 'params': {
+ 'real_pt': f"encodings/crops/{class_name}/validation/real.pt",
+ 'fake_pt': f"encodings/crops/{class_name}/validation/fake.pt",
+ }
+ }
+ }
+ return config
+
+def get_preprocessed_image(pixel_values):
+ pixel_values = pixel_values.squeeze().numpy()
+ unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
+ unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
+ unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
+ unnormalized_image = Image.fromarray(unnormalized_image)
+ return unnormalized_image
\ No newline at end of file
diff --git a/miragenews/data/dataset.py b/miragenews/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..80ff2773b23bbe2e64902334ff3ac6c5c5bfd8fb
--- /dev/null
+++ b/miragenews/data/dataset.py
@@ -0,0 +1,42 @@
+import torch
+
+class MiRAGeImageOrTextDataset(torch.utils.data.Dataset):
+ def __init__(self, real_pt, fake_pt):
+ # Load real and fake tensors
+ real_pt = torch.load(real_pt)
+ fake_pt = torch.load(fake_pt)
+
+ # Concatenate real and fake images
+ self.all_data = torch.cat((real_pt, fake_pt), dim=0)
+ self.labels = torch.cat((torch.zeros(len(real_pt)), torch.ones(len(fake_pt))))
+
+ def __getitem__(self, index):
+ features = self.all_data[index]
+ label = self.labels[index].item()
+ return features, label
+
+ def __len__(self):
+ return len(self.all_data)
+
+
+class MiRAGeNewsDataset(torch.utils.data.Dataset):
+ def __init__(self, real_img_pt, fake_img_pt, real_text_pt, fake_text_pt):
+ # Load real and fake tensors
+ real_img_pt = torch.load(real_img_pt)
+ fake_img_pt = torch.load(fake_img_pt)
+ real_text_pt = torch.load(real_text_pt)
+ fake_text_pt = torch.load(fake_text_pt)
+
+ # Concatenate real and fake news
+ self.all_imgs = torch.cat((real_img_pt, fake_img_pt), dim=0)
+ self.all_texts = torch.cat((real_text_pt, fake_text_pt), dim=0)
+ self.labels = torch.cat((torch.zeros(len(real_img_pt)), torch.ones(len(fake_img_pt))))
+
+ def __getitem__(self, index):
+ image_features = self.all_imgs[index]
+ text_features = self.all_texts[index]
+ label = self.labels[index].item()
+ return image_features, text_features, label
+
+ def __len__(self):
+ return len(self.all_imgs)
diff --git a/miragenews/data/encode_crops.py b/miragenews/data/encode_crops.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ce9670bb0a34430b569f59e49764e13d390e219
--- /dev/null
+++ b/miragenews/data/encode_crops.py
@@ -0,0 +1,170 @@
+import argparse
+import torch
+import os
+from PIL import Image, ImageDraw, ImageFont
+from tqdm import tqdm
+from transformers import Owlv2Processor, Owlv2ForObjectDetection, Blip2Processor, Blip2ForConditionalGeneration
+from datasets import load_dataset
+from data import get_object_class_caption, get_preprocessed_image
+
+def save_crops(image, boxes, phrases, image_name, split, image_type, output_dir):
+ """Save cropped objects from the image into a structured directory."""
+ image_width, image_height = image.size # Get image dimensions
+
+ for i, box in enumerate(boxes):
+ # Convert box to integer and ensure coordinates are within the image dimensions
+ x_min, y_min, x_max, y_max = map(int, box.tolist())
+ x_min = max(0, x_min)
+ y_min = max(0, y_min)
+ x_max = min(image_width, x_max)
+ y_max = min(image_height, y_max)
+
+ # Skip very small boxes
+ if (x_max - x_min) < 64 or (y_max - y_min) < 64:
+ continue
+
+ # Crop the image based on the adjusted box
+ crop_img = image.crop((x_min, y_min, x_max, y_max))
+
+ # Create directory structure for saving
+ class_dir = os.path.join(output_dir, phrases[i], split, image_type)
+ os.makedirs(class_dir, exist_ok=True)
+
+ # Save the crop
+ crop_path = os.path.join(class_dir, f"{image_name}_{image_type}_{str(i).zfill(2)}.jpg")
+ crop_img.save(crop_path)
+
+def annotate_image(image, boxes, phrases, output_path="annotated_image.jpg"):
+ """Draw bounding boxes and labels on the image and save it."""
+ draw = ImageDraw.Draw(image)
+
+ # Define font for the labels
+ try:
+ font = ImageFont.truetype("arial.ttf", 15)
+ except IOError:
+ font = ImageFont.load_default()
+
+ # Draw each bounding box with label
+ for i, box in enumerate(boxes):
+ x_min, y_min, x_max, y_max = map(int, box.tolist())
+ label = phrases[i]
+
+ # Draw bounding box
+ draw.rectangle([(x_min, y_min), (x_max, y_max)], outline="red", width=2)
+
+ # Draw label
+        # Measure the label text; ImageDraw.textsize was removed in Pillow 10,
+        # so use textbbox instead
+        bbox = draw.textbbox((0, 0), label, font=font)
+        text_size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
+        draw.rectangle([(x_min, y_min - text_size[1]), (x_min + text_size[0], y_min)], fill="red")
+        draw.text((x_min, y_min - text_size[1]), label, fill="white", font=font)
+
+ image.save(output_path)
+ print(f"Annotated image saved to {output_path}")
+
+def encode_and_save_crops(model, processor, device, output_dir, encodings_dir, batch_size=32):
+ """Encode all saved crops in output_dir and save the encodings to encodings_dir with the same structure."""
+ for root, _, files in os.walk(output_dir):
+ if files:
+ encodings = []
+ batch = []
+ for crop_filename in tqdm(sorted(files), desc=f"Encoding crops in {root}"):
+ crop_path = os.path.join(root, crop_filename)
+ image = processor(Image.open(crop_path).convert("RGB")).data['pixel_values'][0]
+ batch.append(torch.Tensor(image))
+
+ if len(batch) == batch_size:
+ encodings.append(process_batch(batch, model, device))
+ batch = []
+
+ if batch:
+ encodings.append(process_batch(batch, model, device))
+
+ # Determine the output path for encoding, mirroring the structure of output_dir within encodings_dir
+ relative_path = os.path.relpath(root, output_dir)
+ encoding_output_path = os.path.join(encodings_dir, f"{relative_path}.pt")
+ os.makedirs(os.path.dirname(encoding_output_path), exist_ok=True)
+ torch.save(torch.cat(encodings), encoding_output_path)
+ print(f"Saved encodings to {encoding_output_path}")
+
+def process_batch(images, model, device):
+ """Process a batch of images through the model and extract embeddings."""
+ images_tensor = torch.stack(images).to(device)
+ with torch.no_grad():
+ return model(images_tensor).pooler_output
+
+# Main processing function
+
+def process_dataset_or_directory(custom=False, read_dirs=None, output_dir="data/crops", encodings_dir="encodings/crops", batch_size=128):
+ device = "cuda"
+
+ # Load models and processors
+ object_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ object_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+ model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").vision_model.to(device)
+ objects = get_object_class_caption()
+
+ # Process custom directories
+ if custom and read_dirs:
+ for read_dir in read_dirs:
+ for image_type in ["real", "fake"]:
+ dir_path = os.path.join(read_dir, image_type)
+ if os.path.exists(dir_path):
+ for idx, image_name in enumerate(tqdm(sorted(os.listdir(dir_path)), desc=f"Processing {image_type} in {read_dir}")):
+ image = Image.open(os.path.join(dir_path, image_name)).convert("RGB")
+
+ # Object detection
+ inputs = object_processor(text=objects, images=image, return_tensors="pt").to(device)
+ unnormalized_image = get_preprocessed_image(inputs.pixel_values.cpu())
+ outputs = object_model(**inputs)
+ target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
+ results = object_processor.post_process_object_detection(outputs=outputs, threshold=0.2, target_sizes=target_sizes)[0]
+
+ # Save crops
+ boxes, labels = results["boxes"].cpu(), results["labels"]
+ phrases = [objects[i].split("of ", 1)[1].replace(' ', '_') for i in labels]
+ save_crops(unnormalized_image, boxes, phrases, image_name, read_dir, image_type, output_dir)
+
+
+ # Process HF dataset
+ else:
+ dataset_name = "anson-huang/mirage-news"
+ available_splits = load_dataset(dataset_name).keys()
+ for split in available_splits:
+ dataset = load_dataset(dataset_name, split=split)
+ for idx, item in enumerate(tqdm(dataset, desc=f"Processing split {split}")):
+ image = item["image"].convert("RGB")
+ label = item["label"]
+ image_type = "real" if label == 0 else "fake"
+
+ # Object detection
+ inputs = object_processor(text=objects, images=image, return_tensors="pt").to(device)
+ unnormalized_image = get_preprocessed_image(inputs.pixel_values.cpu())
+ outputs = object_model(**inputs)
+ target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
+ results = object_processor.post_process_object_detection(outputs=outputs, threshold=0.2, target_sizes=target_sizes)[0]
+
+ # Save crops
+ boxes, labels = results["boxes"].cpu(), results["labels"]
+ phrases = [objects[i].split("of ", 1)[1].replace(' ', '_') for i in labels]
+ save_crops(unnormalized_image, boxes, phrases, f"{split}_{idx}", split, image_type, output_dir)
+
+ # Encode all saved crops after saving is complete
+ encode_and_save_crops(model, processor, device, output_dir, encodings_dir, batch_size)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Process HF dataset or custom image directory for object detection and encoding")
+ parser.add_argument("--custom", action="store_true", help="Use local directories instead of Hugging Face dataset")
+ parser.add_argument("--read_dirs", nargs="+", help="List of directories to read images from (if --custom is set)")
+ parser.add_argument("--output_dir", type=str, default="data/crops", help="Directory to save the cropped images")
+ parser.add_argument("--encodings_dir", type=str, default="encodings/crops", help="Directory to save the image encodings")
+ parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing images")
+
+ args = parser.parse_args()
+ process_dataset_or_directory(
+ custom=args.custom,
+ read_dirs=args.read_dirs,
+ output_dir=args.output_dir,
+ encodings_dir=args.encodings_dir,
+ batch_size=args.batch_size
+ )
diff --git a/miragenews/data/encode_image.py b/miragenews/data/encode_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..f903c18f3529a8f6a67d0023fbcb0012e50ca3ff
--- /dev/null
+++ b/miragenews/data/encode_image.py
@@ -0,0 +1,113 @@
+import argparse
+import torch
+import os
+from PIL import Image
+from tqdm import tqdm
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
+from datasets import load_dataset
+
+def preprocess_image(image_path, processor):
+ image = processor(Image.open(image_path).convert("RGB")).data['pixel_values'][0]
+ return torch.Tensor(image)
+
+def preprocess_hf_image(hf_image, processor):
+ # Process an image from the HF dataset (assuming PIL format)
+ image = processor(hf_image.convert("RGB")).data['pixel_values'][0]
+ return torch.Tensor(image)
+
+def process_batch(images, model, device):
+ images_tensor = torch.stack(images).to(device)
+ with torch.no_grad():
+ return model(images_tensor).pooler_output
+
+def process_directory(directory, model, processor, device, batch_size=32):
+ encodings = []
+ batch = []
+ for filename in tqdm(sorted(os.listdir(directory))):
+ image_path = os.path.join(directory, filename)
+ batch.append(preprocess_image(image_path, processor))
+ if len(batch) == batch_size:
+ encodings.append(process_batch(batch, model, device))
+ batch = []
+
+ if batch:
+ encodings.append(process_batch(batch, model, device))
+
+ return torch.cat(encodings)
+
+def process_hf_dataset_by_label(dataset, model, processor, device, batch_size=32):
+ encodings_real, encodings_fake = [], []
+ batch_real, batch_fake = [], []
+ for item in tqdm(dataset):
+ image_tensor = preprocess_hf_image(item["image"], processor)
+ label = item["label"]
+
+ if label == 0:
+ batch_real.append(image_tensor)
+ if len(batch_real) == batch_size:
+ encodings_real.append(process_batch(batch_real, model, device))
+ batch_real = []
+ elif label == 1:
+ batch_fake.append(image_tensor)
+ if len(batch_fake) == batch_size:
+ encodings_fake.append(process_batch(batch_fake, model, device))
+ batch_fake = []
+
+ if batch_real:
+ encodings_real.append(process_batch(batch_real, model, device))
+ if batch_fake:
+ encodings_fake.append(process_batch(batch_fake, model, device))
+
+ return torch.cat(encodings_real), torch.cat(encodings_fake)
+
+def save_encodings(encodings, filename):
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ torch.save(encodings, filename)
+
+def main(custom=False, read_dirs=None, batch_size=64):
+ device = "cuda"
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+ model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").vision_model.to(device)
+
+ if custom and read_dirs:
+ for read_dir in read_dirs:
+ real_dir = os.path.join("my_dataset/image", read_dir, "real")
+ fake_dir = os.path.join("my_dataset/image", read_dir, "fake")
+
+ if os.path.exists(real_dir):
+ output_file_real = f"encodings/image/{read_dir}/real.pt"
+ image_features_real = process_directory(real_dir, model, processor, device, batch_size)
+ save_encodings(image_features_real, output_file_real)
+ print(f"Encoded features for real images in {read_dir} saved to {output_file_real}")
+
+ if os.path.exists(fake_dir):
+ output_file_fake = f"encodings/image/{read_dir}/fake.pt"
+ image_features_fake = process_directory(fake_dir, model, processor, device, batch_size)
+ save_encodings(image_features_fake, output_file_fake)
+ print(f"Encoded features for fake images in {read_dir} saved to {output_file_fake}")
+
+ else:
+ dataset_name = "anson-huang/mirage-news"
+ available_splits = load_dataset(dataset_name).keys() # Get available splits from the dataset
+ for split in available_splits:
+ dataset = load_dataset(dataset_name, split=split)
+ output_file_real = f"encodings/image/{split}/real.pt"
+ output_file_fake = f"encodings/image/{split}/fake.pt"
+
+ image_features_real, image_features_fake = process_hf_dataset_by_label(dataset, model, processor, device, batch_size)
+ save_encodings(image_features_real, output_file_real)
+ save_encodings(image_features_fake, output_file_fake)
+
+ print(f"Encoded features for real images in {split} saved to {output_file_real}")
+ print(f"Encoded features for fake images in {split} saved to {output_file_fake}")
+
+ print("Feature vectors saved successfully.")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Encode images from local directory or HF dataset")
+ parser.add_argument("--custom", action="store_true", help="Use local directories instead of Hugging Face dataset")
+ parser.add_argument("--read_dirs", nargs="+", help="List of directories to read images from (if --custom is set)")
+ parser.add_argument("--batch_size", type=int, default=64, help="Batch size for processing images (default: 64)")
+
+ args = parser.parse_args()
+ main(custom=args.custom, read_dirs=args.read_dirs, batch_size=args.batch_size)
diff --git a/miragenews/data/encode_predictions.py b/miragenews/data/encode_predictions.py
new file mode 100644
index 0000000000000000000000000000000000000000..e35a7418ea96526f0ee06a20898b5ff21a51d568
--- /dev/null
+++ b/miragenews/data/encode_predictions.py
@@ -0,0 +1,283 @@
+import argparse
+import torch
+from PIL import Image
+from tqdm import tqdm
+from transformers import Owlv2Processor, Owlv2ForObjectDetection, Blip2Processor, Blip2ForConditionalGeneration
+import clip
+from datasets import load_dataset
+import sys
+import os
+
+# Add the parent directory to the Python path
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from models import get_model
+from data import load_config, get_object_class_caption, get_preprocessed_image
+from utils import load_model_checkpoint
+
+def get_logits(y, eps=1e-5):
+ y = torch.clamp(y, eps, 1 - eps)
+ y = torch.log(y / (1 - y))
+ return y
+
+def process_img_linear(model, batch, image_encoder, device):
+ """Process images with the img-linear model."""
+ batch_tensor = torch.stack(batch).to(device)
+ with torch.no_grad():
+ images_encoding = image_encoder(batch_tensor).pooler_output
+ outputs = model(images_encoding)
+ return get_logits(outputs)
+
+def process_cbm_encoder(model, image, objects, object_processor, object_detector, image_processor, image_encoder, device):
+ """Process each image with cbm-encoder model for object crops."""
+ object_scores = torch.full((300,), float('-inf')).to(device) # Initialize scores with -inf
+ filled_indices = torch.zeros((300,), dtype=torch.bool).to(device) # Track updated indices
+
+ inputs = object_processor(text=objects, images=image, return_tensors="pt").to(device)
+ unnormalized_image = get_preprocessed_image(inputs.pixel_values.cpu())
+ outputs = object_detector(**inputs)
+ target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
+ detected_objects = object_processor.post_process_object_detection(outputs=outputs, threshold=0.3, target_sizes=target_sizes)[0]
+
+ for box, label in zip(detected_objects["boxes"], detected_objects["labels"]):
+ obj_class_idx = label.item() # Object class index
+ crop_img = unnormalized_image.crop(box.tolist())
+ crop_tensor = torch.Tensor(image_processor(crop_img.convert("RGB")).data['pixel_values'][0]).unsqueeze(0).to(device)
+ crop_encoding = image_encoder(crop_tensor).pooler_output[0]
+ with torch.no_grad():
+ crop_score = model.classifiers[obj_class_idx](crop_encoding)
+ object_scores[obj_class_idx] = torch.maximum(object_scores[obj_class_idx], crop_score)
+ filled_indices[obj_class_idx] = True
+
+ # Fill unfilled indices with 0.5
+ object_scores[~filled_indices] = 0.5
+ return get_logits(object_scores).unsqueeze(0)
+
+def preprocess_texts(texts, model, device):
+ """Tokenizes and encodes a batch of text using CLIP's model."""
+ tokenized_texts = clip.tokenize(texts, truncate=True).to(device)
+ with torch.no_grad():
+ return model.encode_text(tokenized_texts)
+
+def process_txt_linear(model, text_encoding, device):
+ """Process images with the txt-linear model."""
+ with torch.no_grad():
+ outputs = model(text_encoding.float().to(device))
+ return get_logits(outputs)
+
+def save_predictions(predictions, output_dir, mode, model_class, split, label):
+ output_path = os.path.join(output_dir, mode, model_class, split, f"{label}.pt")
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ torch.save(predictions, output_path)
+
+def main(mode, model_class, custom=False, img_dirs=None, text_dirs=None, batch_size=64, test_only=False):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+ output_dir = "encodings/predictions"
+
+ # Merge predictions
+ if model_class == "merged":
+ if mode == "image":
+ linear_pred_dir = f'{output_dir}/{mode}/linear'
+ cbm_pred_dir = f'{output_dir}/{mode}/cbm-encoder'
+
+ for split in sorted(os.listdir(linear_pred_dir)):
+ linear_real = torch.load(f'{linear_pred_dir}/{split}/real.pt').to(device)
+ linear_fake = torch.load(f'{linear_pred_dir}/{split}/fake.pt').to(device)
+ cbm_real = torch.load(f'{cbm_pred_dir}/{split}/real.pt').to(device)
+ cbm_fake = torch.load(f'{cbm_pred_dir}/{split}/fake.pt').to(device)
+ merged_dir = f'{output_dir}/{mode}/merged/{split}'
+ os.makedirs(merged_dir, exist_ok=True)
+ torch.save(torch.concat((cbm_real, linear_real), dim=1), f'{merged_dir}/real.pt')
+ torch.save(torch.concat((cbm_fake, linear_fake), dim=1), f'{merged_dir}/fake.pt')
+ print(f"Image predictions merged.")
+ return
+ elif mode == "text":
+ linear_pred_dir = f'{output_dir}/{mode}/linear'
+ tbm_pred_dir = f'{output_dir}/{mode}/tbm-encoder'
+
+ for split in sorted(os.listdir(linear_pred_dir)):
+ linear_real = torch.load(f'{linear_pred_dir}/{split}/real.pt').to(device)
+ linear_fake = torch.load(f'{linear_pred_dir}/{split}/fake.pt').to(device)
+ tbm_real = torch.load(f'{tbm_pred_dir}/{split}/real.pt').to(device)
+ tbm_fake = torch.load(f'{tbm_pred_dir}/{split}/fake.pt').to(device)
+ print(f"linear_real shape: {linear_real.shape}")
+ print(f"tbm_real shape: {tbm_real.shape}")
+ merged_dir = f'{output_dir}/{mode}/merged/{split}'
+ os.makedirs(merged_dir, exist_ok=True)
+ torch.save(torch.concat((tbm_real, linear_real), dim=1), f'{merged_dir}/real.pt')
+ torch.save(torch.concat((tbm_fake, linear_fake), dim=1), f'{merged_dir}/fake.pt')
+ print(f"Text predictions merged.")
+ return
+
+ if mode == "image":
+ # Load shared processors and models
+ # Fix OOM for cbm-encoder
+ object_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ object_detector = Owlv2ForObjectDetection.from_pretrained(
+ "google/owlv2-base-patch16-ensemble",
+ torch_dtype=torch.float16 # Load in float16 to reduce GPU memory
+ ).to(device)
+ image_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+ image_encoder = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").vision_model.to(device)
+
+ elif mode == "text":
+ clip_model, _ = clip.load("ViT-L/14@336px", device=device)
+ # Initialize models
+ config = load_config(f"configs/{mode}/{model_class}.yaml")
+ model, _ = load_model_checkpoint(get_model(config).to(device), config['training']['save_path'])
+ objects = get_object_class_caption() if model_class == "cbm-encoder" else None
+
+ if custom:
+ if img_dirs:
+ # Process local directories
+ for read_dir in img_dirs:
+ real_dir = os.path.join("my_dataset/image", read_dir, "real")
+ fake_dir = os.path.join("my_dataset/image", read_dir, "fake")
+
+ batch = []
+ predictions = []
+
+ if os.path.exists(real_dir):
+ for image_name in tqdm(sorted(os.listdir(real_dir)), desc=f"Processing {real_dir} with {model_class}"):
+ image_path = os.path.join(real_dir, image_name)
+ image = Image.open(image_path).convert("RGB")
+ if model_class == "cbm-encoder":
+ predictions.append(process_cbm_encoder(model, image, objects, object_processor, object_detector, image_processor, image_encoder, device))
+ else:
+ image_tensor = torch.Tensor(image_processor(image).data['pixel_values'][0])
+ batch.append(image_tensor)
+ if len(batch) == batch_size:
+ predictions.append(process_img_linear(model, batch, image_encoder, device))
+ batch = []
+ if batch:
+ predictions.append(process_img_linear(model, batch, image_encoder, device))
+ if predictions:
+ save_predictions(torch.cat(predictions), output_dir, mode, model_class, read_dir, "real")
+ print(f"Predictions for real images in {read_dir} saved.")
+
+ if os.path.exists(fake_dir):
+ batch = [] # Reset batch and predictions for fake images
+ predictions = []
+ for image_name in tqdm(sorted(os.listdir(fake_dir)), desc=f"Processing {fake_dir} with {model_class}"):
+ image_path = os.path.join(fake_dir, image_name)
+ image = Image.open(image_path).convert("RGB")
+ if model_class == "cbm-encoder":
+ predictions.append(process_cbm_encoder(model, image, objects, object_processor, object_detector, image_processor, image_encoder, device))
+ else:
+ image_tensor = torch.Tensor(image_processor(image).data['pixel_values'][0])
+ batch.append(image_tensor)
+ if len(batch) == batch_size:
+ predictions.append(process_img_linear(model, batch, image_encoder, device))
+ batch = []
+ if batch:
+ predictions.append(process_img_linear(model, batch, image_encoder, device))
+ if predictions:
+ save_predictions(torch.cat(predictions), output_dir, mode, model_class, read_dir, "fake")
+ print(f"Predictions for fake images in {read_dir} saved.")
+
+ # ==================================================================
+ # === START: NEW BLOCK FOR HANDLING CUSTOM TEXT ===
+ # ==================================================================
+ elif text_dirs:
+ for read_dir in text_dirs:
+ # This logic only runs for 'linear', since TBM (18-dim) is not feasible for custom text
+ if model_class != "linear":
+ print(f"Warning: Only 'linear' model class is supported for custom text dirs. Skipping '{model_class}'.")
+ continue # Skip any model class other than 'linear'
+
+ for label in ["real", "fake"]: # Xử lý cả 'real' và 'fake' nếu thư mục tồn tại
+ text_dir = os.path.join("my_dataset/text", read_dir, label)
+ batch = []
+ predictions = []
+
+ if os.path.exists(text_dir):
+ print(f"Processing directory: {text_dir}")
+ for text_name in tqdm(sorted(os.listdir(text_dir)), desc=f"Processing {text_dir} with {model_class}"):
+ text_path = os.path.join(text_dir, text_name)
+ # Read the text file contents
+ try:
+ with open(text_path, 'r', encoding='utf-8') as f:
+ text = f.read()
+ except Exception as e:
+ print(f"\nWarning: Could not read {text_path}. Skipping. Error: {e}")
+ continue
+
+ batch.append(text)
+ # Process the batch once it is full
+ if len(batch) == batch_size:
+ text_encoding = preprocess_texts(batch, clip_model, device)
+ predictions.append(process_txt_linear(model, text_encoding, device))
+ batch = []
+
+ # Process any leftover batch
+ if batch:
+ text_encoding = preprocess_texts(batch, clip_model, device)
+ predictions.append(process_txt_linear(model, text_encoding, device))
+
+ # Save the .pt file if there are predictions
+ if predictions:
+ save_predictions(torch.cat(predictions), output_dir, mode, model_class, read_dir, label)
+ print(f"Predictions for {label} texts in {read_dir} saved.")
+ else:
+ print(f"Directory not found, skipping: {text_dir}")
+ # ==================================================================
+ # === END: NEW BLOCK ===
+ # ==================================================================
+
+ else:
+ # Process Hugging Face dataset
+ dataset_name = "anson-huang/mirage-news"
+ available_splits = list(load_dataset(dataset_name).keys())
+ if test_only:
+ # Keep only the test splits
+ available_splits = [s for s in available_splits if s.startswith('test')]
+
+ for split in available_splits:
+ if split not in ['train', 'validation'] and not test_only:
+ continue # Skip the test splits unless --test_only is set
+
+ dataset = load_dataset(dataset_name, split=split)
+ for label in ["real", "fake"]:
+ batch = []
+ predictions = []
+ filtered_dataset = [item for item in dataset if item["label"] == (0 if label == "real" else 1)]
+
+ if mode == 'image':
+ for item in tqdm(filtered_dataset, desc=f"Processing {split}/{label} with {model_class}"):
+ image = item["image"].convert("RGB")
+ if model_class == "cbm-encoder":
+ predictions.append(process_cbm_encoder(model, image, objects, object_processor, object_detector, image_processor, image_encoder, device))
+ else:
+ image_tensor = torch.Tensor(image_processor(image).data['pixel_values'][0])
+ batch.append(image_tensor)
+ if len(batch) == batch_size:
+ predictions.append(process_img_linear(model, batch, image_encoder, device))
+ batch = []
+ if batch:
+ predictions.append(process_img_linear(model, batch, image_encoder, device))
+ elif mode == 'text':
+ for item in tqdm(filtered_dataset, desc=f"Processing {split}/{label} with {model_class}"):
+ batch.append(item["text"])
+ if len(batch) == batch_size:
+ text_encoding = preprocess_texts(batch, clip_model, device)
+ predictions.append(process_txt_linear(model, text_encoding, device))
+ batch = []
+ if batch:
+ text_encoding = preprocess_texts(batch, clip_model, device)
+ predictions.append(process_txt_linear(model, text_encoding, device))
+
+ if predictions: # Make sure predictions is not empty
+ save_predictions(torch.cat(predictions), output_dir, mode, model_class, split, label)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Process local directories or Hugging Face datasets for encoding predictions")
+ parser.add_argument("--mode", required=True, choices=["image", "text"], help="Specify 'image' or 'text'")
+ parser.add_argument("--model_class", choices=["linear", "cbm-encoder", "merged"], required=True, help="Specify model class")
+ parser.add_argument("--custom", action="store_true", help="Use local directories instead of Hugging Face dataset")
+ parser.add_argument("--img_dirs", nargs="+", help="List of directories to read images from (if --custom is set)")
+ parser.add_argument("--text_dirs", nargs="+", help="List of directories to read captions from (if --custom is set)")
+ parser.add_argument("--batch_size", type=int, default=64, help="Batch size for processing images")
+ parser.add_argument("--test_only", action="store_true", help="Encode only the test sets from the Hugging Face dataset")
+
+ args = parser.parse_args()
+ main(mode=args.mode, model_class=args.model_class, custom=args.custom, img_dirs=args.img_dirs, text_dirs=args.text_dirs, batch_size=args.batch_size, test_only=args.test_only)
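+
+# Example invocations (illustrative; --custom expects the my_dataset/ layout used above):
+#   python encode_predictions.py --mode image --model_class linear --custom --img_dirs my_single_image_dir
+#   python encode_predictions.py --mode image --model_class cbm-encoder --test_only
+#   python encode_predictions.py --mode image --model_class merged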
diff --git a/miragenews/data/encode_text.py b/miragenews/data/encode_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..3916af9c6f7681ee1f0adf8de6a0d8c14a090b8a
--- /dev/null
+++ b/miragenews/data/encode_text.py
@@ -0,0 +1,109 @@
+import argparse
+import torch
+import os
+import clip
+from tqdm import tqdm
+from datasets import load_dataset
+
+def preprocess_texts(texts, model, device):
+ """Tokenizes and encodes a batch of text using CLIP's model."""
+ tokenized_texts = clip.tokenize(texts, truncate=True).to(device)
+ with torch.no_grad():
+ return model.encode_text(tokenized_texts)
+
+def process_custom_text_file(text_file, model, device, batch_size=64):
+ """Reads texts from a custom text file and encodes them in batches."""
+ encodings = []
+ batch = []
+
+ with open(text_file, 'r') as file:
+ for line in tqdm(file, desc="Processing custom text file"):
+ batch.append(line.strip())
+ if len(batch) == batch_size:
+ encodings.append(preprocess_texts(batch, model, device))
+ batch = []
+
+ if batch:
+ encodings.append(preprocess_texts(batch, model, device))
+
+ return torch.cat(encodings)
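+# Note: the custom file is expected to hold one caption per line; each line is
+# stripped and encoded as its own sample.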
+
+def process_hf_dataset_by_label(dataset, model, device, batch_size=64):
+ """Encodes texts from a Hugging Face dataset based on labels in batches."""
+ encodings_real, encodings_fake = [], []
+ batch_real, batch_fake = [], []
+
+ for item in tqdm(dataset, desc="Processing Hugging Face dataset"):
+ text = item['text']
+ label = item["label"]
+
+ if label == 0:
+ batch_real.append(text)
+ if len(batch_real) == batch_size:
+ encodings_real.append(preprocess_texts(batch_real, model, device))
+ batch_real = []
+ elif label == 1:
+ batch_fake.append(text)
+ if len(batch_fake) == batch_size:
+ encodings_fake.append(preprocess_texts(batch_fake, model, device))
+ batch_fake = []
+
+ if batch_real:
+ encodings_real.append(preprocess_texts(batch_real, model, device))
+ if batch_fake:
+ encodings_fake.append(preprocess_texts(batch_fake, model, device))
+
+ return torch.cat(encodings_real), torch.cat(encodings_fake)
+
+def save_encodings(encodings, filename):
+ """Saves the encoded tensors to a file."""
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ torch.save(encodings, filename)
+
+def main(custom=False, text_files=None, batch_size=64):
+ device = "cuda"
+ model, _ = clip.load("ViT-L/14@336px", device=device)
+
+ if custom and text_files:
+ for text_file in text_files:
+ real_file = os.path.join("my_dataset/text", text_file, "real.txt")
+ fake_file = os.path.join("my_dataset/text", text_file, "fake.txt")
+
+ if os.path.exists(real_file):
+ output_file_real = f"encodings/text/{text_file}/real.pt"
+ text_features_real = process_custom_text_file(real_file, model, device, batch_size)
+ save_encodings(text_features_real, output_file_real)
+ print(f"Encoded features for real texts in {text_file} saved to {output_file_real}")
+
+ if os.path.exists(fake_file):
+ output_file_fake = f"encodings/text/{text_file}/fake.pt"
+ text_features_fake = process_custom_text_file(fake_file, model, device, batch_size)
+ save_encodings(text_features_fake, output_file_fake)
+ print(f"Encoded features for fake texts in {text_file} saved to {output_file_fake}")
+
+ else:
+ dataset_name = "anson-huang/mirage-news"
+ available_splits = load_dataset(dataset_name).keys()
+
+ for split in available_splits:
+ dataset = load_dataset(dataset_name, split=split)
+ output_file_real = f"encodings/text/{split}/real.pt"
+ output_file_fake = f"encodings/text/{split}/fake.pt"
+
+ text_features_real, text_features_fake = process_hf_dataset_by_label(dataset, model, device, batch_size)
+ save_encodings(text_features_real, output_file_real)
+ save_encodings(text_features_fake, output_file_fake)
+
+ print(f"Encoded features for real texts in {split} saved to {output_file_real}")
+ print(f"Encoded features for fake texts in {split} saved to {output_file_fake}")
+
+ print("Text feature vectors saved successfully.")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Encode text from local text files or HF dataset")
+ parser.add_argument("--custom", action="store_true", help="Use local text files instead of Hugging Face dataset")
+ parser.add_argument("--text_files", nargs="+", help="List of text files to read (if --custom is set)")
+ parser.add_argument("--batch_size", type=int, default=64, help="Batch size for processing texts (default: 64)")
+
+ args = parser.parse_args()
+ main(custom=args.custom, text_files=args.text_files, batch_size=args.batch_size)
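+
+# Example invocations (illustrative; --custom expects my_dataset/text/<dir>/{real,fake}.txt):
+#   python encode_text.py --custom --text_files my_single_text_dir
+#   python encode_text.py --batch_size 128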
diff --git a/miragenews/data/infer_single_image.py b/miragenews/data/infer_single_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b7714d39b1ced088cff93d8515b7f37a1bb1bac
--- /dev/null
+++ b/miragenews/data/infer_single_image.py
@@ -0,0 +1,249 @@
+import torch
+from PIL import Image
+from transformers import Owlv2Processor, Owlv2ForObjectDetection, Blip2Processor, Blip2ForConditionalGeneration
+import clip # clip is needed if using text mode
+import os
+import sys
+import argparse # For parsing command-line arguments
+
+# --- Make sure the other project modules can be imported ---
+# Add the parent directory to sys.path if this file lives in a subdirectory (e.g. data/)
+script_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(script_dir, "..")) # Go up one level
+if project_root not in sys.path:
+ sys.path.append(project_root)
+
+# Import the required functions/classes from the project
+try:
+ from models import get_model
+ from data import load_config, get_object_class_caption, get_preprocessed_image
+ from utils import load_model_checkpoint
+except ImportError as e:
+ print(f"Lỗi import: {e}")
+ print("Hãy đảm bảo bạn đang chạy script này từ thư mục gốc của project HOẶC cấu trúc thư mục đúng.")
+ sys.exit(1)
+
+# --- Processing helpers copied from encode_predictions.py ---
+# (These functions are kept unchanged)
+def get_logits(y, eps=1e-5):
+ y = torch.clamp(y, eps, 1 - eps)
+ y = torch.log(y / (1 - y))
+ return y
+
+def process_img_linear(model, batch, image_encoder, device):
+ # Make sure batch is a list of tensors before stacking
+ if isinstance(batch, list) and all(isinstance(t, torch.Tensor) for t in batch):
+ batch_tensor = torch.stack(batch).to(device)
+ elif isinstance(batch, torch.Tensor): # Already a tensor (e.g. batch size 1)
+ batch_tensor = batch.to(device)
+ if batch_tensor.dim() == 3: # Add the batch dimension if missing
+ batch_tensor = batch_tensor.unsqueeze(0)
+ else:
+ raise TypeError("The 'batch' input to process_img_linear must be a list of tensors or a single tensor.")
+
+ with torch.no_grad():
+ images_encoding = image_encoder(batch_tensor).pooler_output
+ outputs = model(images_encoding)
+ return get_logits(outputs)
+
+def process_cbm_encoder(model, image, objects, object_processor, object_detector, image_processor, image_encoder, device):
+ # Initialize the scores and tracking flags
+ num_classes = 300 # Or take this from the config/model if possible
+ object_scores = torch.full((num_classes,), -float('inf')).to(device) # Use -inf so torch.maximum works correctly
+ filled_indices = torch.zeros((num_classes,), dtype=torch.bool).to(device)
+
+ # Process the image with the OWLv2 processor
+ inputs = object_processor(text=objects, images=image, return_tensors="pt").to(device)
+
+ # Use the original (unnormalized) PIL image - assumes the input image is PIL
+ unnormalized_image = image
+
+ # Run the OWLv2 detector
+ outputs = object_detector(**inputs)
+
+ # Post-process the object detection results
+ target_sizes = torch.Tensor([unnormalized_image.size[::-1]]).to(device) # Original image size
+ detected_objects = object_processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)[0] # Lower the threshold if needed
+
+ # Reuse the BLIP-2 processor passed in to prepare the crops (avoids reloading it on every call)
+ blip_image_processor = image_processor
+
+ print(f"Phát hiện {len(detected_objects['labels'])} đối tượng với threshold > 0.1")
+
+ # Iterate over the detected objects
+ for box, label, score in zip(detected_objects["boxes"], detected_objects["labels"], detected_objects["scores"]):
+ obj_class_idx = label.item()
+
+ # Check that the index is valid (important!)
+ if obj_class_idx >= num_classes:
+ # print(f"Warning: object index {obj_class_idx} exceeds the number of classes ({num_classes}). Skipping.")
+ continue
+ # Skip if the CBM model has no classifier for this index
+ if obj_class_idx >= len(model.classifiers):
+ # print(f"Warning: no classifier for index {obj_class_idx} in the CBM model. Skipping.")
+ continue
+
+ # Crop out the detected object
+ crop_img = unnormalized_image.crop(box.tolist())
+
+ # Prepare the crop with the BLIP-2 processor
+ crop_inputs = blip_image_processor(images=crop_img.convert("RGB"), return_tensors="pt")
+ crop_tensor = crop_inputs['pixel_values'].to(device) # Normalized tensor
+
+ # Encode the crop with the BLIP-2 encoder
+ with torch.no_grad():
+ # Make sure crop_tensor has a batch dimension
+ if crop_tensor.dim() == 3:
+ crop_tensor = crop_tensor.unsqueeze(0)
+ crop_encoding = image_encoder(crop_tensor).pooler_output[0] # Encoding of the cropped image
+
+ # Get the score from the corresponding CBM classifier
+ classifier = model.classifiers[obj_class_idx]
+ crop_score = classifier(crop_encoding)
+
+ # Keep the highest score seen for this object class
+ object_scores[obj_class_idx] = torch.maximum(object_scores[obj_class_idx], crop_score.squeeze()) # Squeeze away the extra dimension
+ filled_indices[obj_class_idx] = True
+
+ # Assign a neutral value to classes that were never detected: the scores here live in
+ # logit space, and logit(0.5) == 0.0, so assign 0.0 rather than the probability 0.5.
+ object_scores[~filled_indices] = 0.0
+
+ # If no object scores were filled in, return a zero vector
+ if not torch.any(filled_indices):
+ print("Warning: no objects were detected and scored. Returning a zero vector.")
+ return torch.zeros(1, num_classes).to(device)
+
+
+ # If object_scores held sigmoid probabilities we would call get_logits(...) first;
+ # since they are already logits (pre-sigmoid values), just add a batch dimension.
+ return object_scores.unsqueeze(0)
+
+
+# --- Main entry point for inference ---
+def main_single_image(image_path, model_class, config_path):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # --- Load the config and checkpoint path ---
+ try:
+ config = load_config(config_path)
+ checkpoint_path = config['training']['save_path']
+ print(f"Using checkpoint: {checkpoint_path}")
+ except Exception as e:
+ print(f"Error loading config '{config_path}': {e}")
+ return
+
+ # --- Load the prediction model ---
+ try:
+ model = get_model(config).to(device)
+ model, _ = load_model_checkpoint(model, checkpoint_path)
+ model.eval()
+ print(f"Loaded prediction model '{model_class}' from {checkpoint_path}")
+ except Exception as e:
+ print(f"Error loading the prediction model: {e}")
+ return
+
+ # --- Load the shared models/processors ---
+ print("Loading shared models/processors...")
+ try:
+ # BLIP-2 (always needed for image mode)
+ blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+ blip_image_encoder = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").vision_model.to(device)
+ blip_image_encoder.eval()
+
+ # OWLv2 (only needed for cbm-encoder)
+ if model_class == "cbm-encoder":
+ owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ owl_detector = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
+ owl_detector.eval()
+ objects = get_object_class_caption() # Get the object class names
+ else:
+ owl_processor, owl_detector, objects = None, None, None
+ print("Load model/processor dùng chung thành công.")
+ except Exception as e:
+ print(f"Lỗi khi load model/processor dùng chung: {e}")
+ return
+
+ # --- Load and prepare the input image ---
+ print(f"Loading image: {image_path}")
+ try:
+ input_image_pil = Image.open(image_path).convert("RGB")
+ except FileNotFoundError:
+ print(f"Lỗi: Không tìm thấy file ảnh tại {image_path}")
+ return
+ except Exception as e:
+ print(f"Lỗi khi mở ảnh: {e}")
+ return
+
+ # --- Run inference ---
+ print(f"Running inference with the {model_class} model...")
+ predictions_logits = None # Initialize
+ try:
+ if model_class == "linear":
+ # Prepare the image for BLIP-2
+ blip_inputs = blip_processor(images=input_image_pil, return_tensors="pt")
+ image_tensor_for_linear = blip_inputs['pixel_values'].to(device) # Shape [1, C, H, W]
+ # Run the linear model (batch must be a list of tensors or a tensor)
+ predictions_logits = process_img_linear(model, image_tensor_for_linear, blip_image_encoder, device)
+
+ elif model_class == "cbm-encoder":
+ # Run the cbm-encoder model (pass the original PIL image)
+ predictions_logits = process_cbm_encoder(model, input_image_pil, objects, owl_processor, owl_detector, blip_processor, blip_image_encoder, device)
+
+ else:
+ print(f"Lỗi: model_class '{model_class}' không được hỗ trợ cho image mode.")
+ return
+
+ print("Inference hoàn tất.")
+
+ except Exception as e:
+ print(f"Lỗi trong quá trình inference: {e}")
+ # In thêm traceback nếu cần debug
+ import traceback
+ traceback.print_exc()
+ return
+
+ # --- Display the results ---
+ if predictions_logits is not None:
+ predictions_probs = torch.sigmoid(predictions_logits)
+
+ print("\n--- Kết quả Dự đoán ---")
+ # print("Logits:")
+ # print(predictions_logits.detach().cpu().numpy())
+ print("Xác suất (Probabilities):")
+ print(predictions_probs.detach().cpu().numpy())
+
+ # Example: take the top 5 concepts
+ top_k_probs, top_k_indices = torch.topk(predictions_probs.squeeze(0), 5)
+ print("\nTop 5 concepts (index and probability):")
+ for i in range(5):
+ print(f" - Index: {top_k_indices[i].item()}, Prob: {top_k_probs[i].item():.4f}")
+ # A map from index to concept name would make this output easier to interpret
+ else:
+ print("Không có kết quả dự đoán để hiển thị.")
+
+
+# --- Command-line argument handling ---
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Chạy inference trên một ảnh duy nhất.")
+ parser.add_argument("--image_path", required=True, help="Đường dẫn đến file ảnh cần xử lý.")
+ parser.add_argument("--model_class", required=True, choices=["linear", "cbm-encoder"], help="Loại model ('linear' hoặc 'cbm-encoder').")
+ # Tự động tìm config dựa trên model_class
+ # parser.add_argument("--config_path", required=True, help="Đường dẫn đến file config .yaml của model.")
+
+ args = parser.parse_args()
+
+ # Determine the config path automatically
+ config_path = f"configs/image/{args.model_class}.yaml"
+ if not os.path.exists(config_path):
+ print(f"Lỗi: Không tìm thấy file config tại '{config_path}'.")
+ print("Hãy đảm bảo file config tồn tại hoặc cung cấp đường dẫn đầy đủ qua tham số.")
+ # Có thể thêm tham số --config_path nếu muốn ghi đè
+ sys.exit(1)
+
+
+ main_single_image(args.image_path, args.model_class, config_path)
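+
+# Example invocation (illustrative; assumes configs/image/<model_class>.yaml exists):
+#   python miragenews/data/infer_single_image.py --image_path my_image.jpg --model_class cbm-encoder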
diff --git a/miragenews/data/manual_merge.py b/miragenews/data/manual_merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ea2284b036d52a9a533353e7bf842e02dbdff4
--- /dev/null
+++ b/miragenews/data/manual_merge.py
@@ -0,0 +1,33 @@
+import torch
+import os
+import sys
+
+print("Starting manual merge for REAL files...")
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+split_name = "my_single_image_dir"
+
+cbm_path = f"encodings/predictions/image/cbm-encoder/{split_name}/real.pt"
+linear_path = f"encodings/predictions/image/linear/{split_name}/real.pt"
+output_dir = f"encodings/predictions/image/merged/{split_name}"
+output_path = os.path.join(output_dir, "real.pt")
+
+os.makedirs(output_dir, exist_ok=True)
+
+try:
+ cbm_real = torch.load(cbm_path).to(device) # Shape [1, 300]
+ linear_real = torch.load(linear_path).to(device) # Shape [1, 1]
+
+ print(f"Loaded CBM tensor (shape): {cbm_real.shape}")
+ print(f"Loaded Linear tensor (shape): {linear_real.shape}")
+
+ merged_real = torch.cat((cbm_real, linear_real), dim=1)
+
+ torch.save(merged_real, output_path)
+ print(f"\nSuccess! Merged tensor (shape {merged_real.shape}) saved to: {output_path}")
+
+except FileNotFoundError as e:
+ print(f"\nError: A required file was not found: {e}")
+ print("Please make sure you have SUCCESSFULLY run BOTH 'cbm-encoder' and 'linear' scripts first.")
+except Exception as e:
+ print(f"\nAn error occurred: {e}")
diff --git a/miragenews/data/manual_merge_text.py b/miragenews/data/manual_merge_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b8a706079c6ad6ecdb8fb8dab83dd457cdde930
--- /dev/null
+++ b/miragenews/data/manual_merge_text.py
@@ -0,0 +1,89 @@
+import torch
+import clip
+from models import get_model
+from data import load_config
+from utils import load_model_checkpoint
+
+# --- Key helper functions taken from the main script ---
+def get_logits(y, eps=1e-5):
+ y = torch.clamp(y, eps, 1 - eps)
+ y = torch.log(y / (1 - y))
+ return y
+
+def preprocess_texts(texts, model, device):
+ """Tokenizes and encodes a batch of text using CLIP's model."""
+ tokenized_texts = clip.tokenize(texts, truncate=True).to(device)
+ with torch.no_grad():
+ return model.encode_text(tokenized_texts)
+
+def process_txt_linear(model, text_encoding, device):
+ """Lấy dự đoán (dạng logit) từ mô hình."""
+ with torch.no_grad():
+ # Assumes the model returns probabilities in (0, 1)
+ outputs = model(text_encoding.float().to(device))
+ return get_logits(outputs)
+
+# --- Helper to load all the models ---
+def load_all_models(device):
+ """Loads all three required models."""
+ print("Loading the CLIP model...")
+ clip_model, _ = clip.load("ViT-L/14@336px", device=device)
+
+ print("Đang tải mô hình Linear (1D)...")
+ config_linear = load_config("configs/text/linear.yaml")
+ model_linear, _ = load_model_checkpoint(get_model(config_linear).to(device), config_linear['training']['save_path'])
+
+ print("Đang tải mô hình TBM (18D)...")
+ config_tbm = load_config("configs/text/tbm-encoder.yaml")
+ model_tbm, _ = load_model_checkpoint(get_model(config_tbm).to(device), config_tbm['training']['save_path'])
+
+ # Put the models in eval() mode
+ clip_model.eval()
+ model_linear.eval()
+ model_tbm.eval()
+
+ return clip_model, model_linear, model_tbm
+
+# --- The main helper you need ---
+def get_19d_vector_from_text(text_input, clip_model, model_linear, model_tbm, device):
+ """
+ Converts an input text string into a 19-dimensional logit vector.
+ """
+
+ # 1. Encode the text with CLIP
+ # It must be passed in as a list: [text_input]
+ text_encoding = preprocess_texts([text_input], clip_model, device)
+
+ # 2. Get the 1D prediction from the Linear model
+ # This function returns logits
+ pred_1d_logits = process_txt_linear(model_linear, text_encoding, device)
+
+ # 3. Get the 18D prediction from the TBM model
+ # This function also returns logits
+ pred_18d_logits = process_txt_linear(model_tbm, text_encoding, device)
+
+ # 4. Concatenate
+ # Exactly the same logic as the main script: (18 dims, 1 dim)
+ vector_19d = torch.cat((pred_18d_logits, pred_1d_logits), dim=1)
+
+ return vector_19d
+
+# --- USAGE ---
+if __name__ == "__main__":
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load the models once
+ clip_model, model_linear, model_tbm = load_all_models(device)
+
+ print("\n--- Tải mô hình thành công ---")
+
+ # The text you want to process
+ my_text = "This is a sample text to be encoded."
+
+ # Run the encoding
+ vector_19d = get_19d_vector_from_text(my_text, clip_model, model_linear, model_tbm, device)
+
+ print(f"\nĐầu vào: '{my_text}'")
+ print(f"Shape đầu ra: {vector_19d.shape}")
+ print("Vector 19 chiều (dạng logit):")
+ print(vector_19d)
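+
+# Note: the resulting [1, 19] tensor follows the same (18 TBM logits, then 1 linear logit)
+# layout as the merged text predictions produced by encode_predictions.py, so it should be
+# consumable by the same downstream classifier (an assumption based on the concat order).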
diff --git a/miragenews/encodings/image/my_single_image_dir/real.pt b/miragenews/encodings/image/my_single_image_dir/real.pt
new file mode 100644
index 0000000000000000000000000000000000000000..6d9cfa62bd6d49b81f369451968e419a4625cbce
--- /dev/null
+++ b/miragenews/encodings/image/my_single_image_dir/real.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f9a74f50047e43d5cbcc8ad9c1038d6847855cf95c368fc06a205587618487fb
+size 7124
diff --git a/miragenews/encodings/predictions/image/cbm-encoder/my_single_image_dir/real.pt b/miragenews/encodings/predictions/image/cbm-encoder/my_single_image_dir/real.pt
new file mode 100644
index 0000000000000000000000000000000000000000..d3e6da9d1c5b720986b831fd11124aae2e4772a4
--- /dev/null
+++ b/miragenews/encodings/predictions/image/cbm-encoder/my_single_image_dir/real.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0af65ba6c3878b72db7e5f96f6041bf87ebef98b7c64f5dce5fe76c5b23b320b
+size 2708
diff --git a/miragenews/encodings/predictions/image/linear/my_single_image_dir/real.pt b/miragenews/encodings/predictions/image/linear/my_single_image_dir/real.pt
new file mode 100644
index 0000000000000000000000000000000000000000..e2a8508edf97807f1ca89d079918b2f21c4b6e0b
--- /dev/null
+++ b/miragenews/encodings/predictions/image/linear/my_single_image_dir/real.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7232770804cbb4314dc0863362dc5d49307bb18e4f611b934d5d74cc456b058
+size 1556
diff --git a/miragenews/encodings/predictions/image/merged/my_single_image_dir/real.pt b/miragenews/encodings/predictions/image/merged/my_single_image_dir/real.pt
new file mode 100644
index 0000000000000000000000000000000000000000..9cc83fbaf12fa6c2bf627edf4390d31293ad2af0
--- /dev/null
+++ b/miragenews/encodings/predictions/image/merged/my_single_image_dir/real.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a78584ae854570f8effd5fe89337bae06afd097d8128d6a06450abbb5f58a918
+size 2708
diff --git a/miragenews/encodings/predictions/text/linear/my_single_text_dir/real.pt b/miragenews/encodings/predictions/text/linear/my_single_text_dir/real.pt
new file mode 100644
index 0000000000000000000000000000000000000000..53fbd310599f3b8081521276d7249da7d059c218
--- /dev/null
+++ b/miragenews/encodings/predictions/text/linear/my_single_text_dir/real.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1bebf824e3b6f17c7ed7d2071a18ccd8a3e956dc5d2079ac0339c01370f9cb5a
+size 1556
diff --git a/miragenews/encodings/text/my_single_text_dir/real.pt b/miragenews/encodings/text/my_single_text_dir/real.pt
new file mode 100644
index 0000000000000000000000000000000000000000..3f33849e418627da6c9c661687734ca4c9e939f3
--- /dev/null
+++ b/miragenews/encodings/text/my_single_text_dir/real.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c810edd1b30cd91344388890ad6371e16079e49283dfef31cdbed10669c1ce2
+size 3028
diff --git a/miragenews/img/config.py b/miragenews/img/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f951e8c40317f1b1b267da6b349a0f652352c25
--- /dev/null
+++ b/miragenews/img/config.py
@@ -0,0 +1,21 @@
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+# API Keys
+SCRAPINGDOG_API_KEY = os.getenv("SCRAPINGDOG_API_KEY")
+GOOGLE_SAFE_BROWSING_API_KEY = os.getenv("GOOGLE_SAFE_BROWSING_API_KEY")
+
+# Paths
+BASE_DIR = os.getcwd()
+IMAGE_SAVE_DIR = os.path.join(BASE_DIR, "my_dataset/image/my_single_image_dir/real")
+MERGED_PT_DIR = os.path.join(BASE_DIR, "encodings/predictions/image/merged/my_single_image_dir")
+REAL_PT_FILENAME = "real.pt"
+CUSTOM_IMG_DIR_NAME = "my_single_image_dir"
+TEMP_IMAGE_DIR = os.path.join(BASE_DIR, "temp_images")
+MIRAGE_CONFIG_IMG = "configs/image/mirage.yaml"
+MIRAGE_CONFIG_MULTI = "configs/multimodal/mirage.yaml"
+
+# Ensure directories exist
+os.makedirs(TEMP_IMAGE_DIR, exist_ok=True)
\ No newline at end of file
diff --git a/miragenews/img/constants.py b/miragenews/img/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..14febf555230c76aa92fc7391d96080aa9061ec6
--- /dev/null
+++ b/miragenews/img/constants.py
@@ -0,0 +1,395 @@
+# constants.py
+
+import re
+import os
+from dotenv import load_dotenv
+
+# Load environment variables from the .env file
+load_dotenv()
+
+# --- Threshold parameters ---
+human_thres = 0.55
+machine_thres = 0.30
+
+
+# The REGEX list (containing 'gan', 'clip', 'sd') has been REMOVED
+AI_KEYWORDS_HIGH_CERTAINTY = [
+ # --- English: general terms & concepts ---
+ 'ai generated', 'ai-generated', 'generated by ai', 'ai creation', 'ai artwork',
+ 'ai image', 'ai photo', 'ai generated image', 'ai generated picture',
+ 'ai generated art', 'ai drawing', 'ai illustration', 'ai render', 'ai rendering',
+ 'ai art', 'ai portrait', 'ai painting',
+ 'synthetic image', 'synthetic media', 'synthetic photo', 'synthetic photography',
+ 'generative ai', 'generative art', 'ai-assisted', 'ai enhanced',
+ 'made with ai', 'created with ai', 'powered by ai', 'rendered by ai',
+ 'text-to-image', 'text2image', 'image-to-image', 'img2img', 'txt2img',
+ 'prompted by', 'from a prompt', 'ai prompt', 'image prompt',
+ 'deepfake', 'deep fake', 'ai face', 'ai avatar', 'ai headshot', 'virtual human', 'virtual model', 'ai influencer',
+ 'ai background', 'generated background',
+ 'ai photo manipulation', 'ai edited', 'edited with ai', 'ai filter',
+ 'not a real person', 'this person is not real',
+ 'computer generated imagery',
+ 'procedural generation', 'procedurally generated',
+
+ # --- English: model & technology names ---
+ 'stable diffusion', 'stablediffusion', 'stable diffusion xl',
+ 'midjourney', 'midjourney ai', 'midjourney v5', 'midjourney v6',
+ 'dall-e', 'dall-e 2', 'dall-e 3', 'dalle', 'dalle2', 'dalle3',
+ 'stylegan', 'stylegan2', 'stylegan3', 'generative adversarial network',
+ 'diffusion model', 'latent diffusion',
+ 'adobe firefly', 'photoshop ai', 'generative fill',
+ 'google imagen',
+ 'openai sora',
+ 'microsoft vasa',
+ 'kandinsky',
+ 'controlnet',
+ 'textual inversion', 'dreambooth',
+ 'wuerstchen',
+ 'deepfloyd',
+ 'variational autoencoder', # 'vae' itself has been removed
+
+ # --- English: platform & tool names ---
+ 'leonardo ai', 'leonardo.ai',
+ 'civitai',
+ 'hugging face', 'huggingface',
+ 'runway', 'runwayml', 'runway gen-2',
+ 'playground ai', 'playground.ai',
+ 'artbreeder',
+ 'nightcafe', 'nightcafe.studio',
+ 'craiyon',
+ 'dreamstudio',
+ 'getimg.ai',
+ 'mage.space',
+ 'deepdream',
+ 'artstation',
+ 'deviantart',
+ 'ideogram', 'ideogram.ai',
+ 'bluewillow',
+ 'krea.ai',
+ 'pika', 'pika labs', 'pika.art',
+ 'luma ai', 'luma labs',
+ 'stability ai',
+ 'tensor.art',
+ 'pixlr ai',
+ 'canva ai',
+ 'picsart ai',
+ 'starryai',
+ 'kaiber',
+
+ # --- Vietnamese ---
+ 'tạo bởi ai', 'được tạo bởi ai', 'tạo ra bởi ai',
+ 'ảnh ai', 'anh ai', 'hình ai',
+ 'ảnh do ai tạo', 'anh do ai tao',
+ 'ảnh trí tuệ nhân tạo', 'anh tri tue nhan tao',
+ 'trí tuệ nhân tạo tạo ra', 'tạo bởi trí tuệ nhân tạo',
+ 'tạo sinh', 'ảnh tạo sinh', 'nghệ thuật tạo sinh',
+ 'ảnh deepfake', 'video deepfake', 'giả mạo sâu',
+ 'bằng trí tuệ nhân tạo',
+ 'sử dụng ai', 'từ ai', 'nhờ ai', 'vẽ bằng ai',
+ 'mô hình ai', 'mô hình trí tuệ nhân tạo',
+ 'được vẽ bởi ai', 'vẽ bởi ai',
+ 'mô hình khuếch tán', # Diffusion model
+ 'mạng đối nghịch', 'mạng đối nghịch tạo sinh', # GAN
+ 'người mẫu ảo', 'ảnh ảo', 'người ảo'
+]
+
+# --- AI DOMAIN LIST (unchanged) ---
+ai_domains = [
+ 'artlist.io', 'artlist.ai',
+ 'artstation.com',
+ 'deviantart.com',
+ 'civitai.com',
+ 'lexica.art',
+ 'huggingface.co',
+ 'runwayml.com',
+ 'playground.ai', 'playgroundai.com',
+ 'nightcafe.studio',
+ 'craiyon.com',
+ 'leonardo.ai',
+ 'mage.space',
+ 'dreamstudio.ai',
+ 'thispersondoesnotexist.com',
+ 'deepai.org',
+ 'getimg.ai',
+ 'fotor.com',
+ 'deepdreamgenerator.com',
+ 'starryai.com',
+ 'stockimg.ai',
+ 'firefly.adobe.com',
+ 'pixai.art',
+ 'krea.ai',
+ 'ideogram.ai',
+ 'seaart.ai',
+ 'artbreeder.com',
+ 'artguru.ai',
+ 'pollinations.ai',
+ 'instantart.io',
+ 'kaiber.ai',
+ 'pika.art',
+ 'luma.ai',
+ 'midjourney.com',
+ 'stability.ai',
+ 'tensor.art',
+ 'picsart.com',
+ 'canva.com',
+ 'synthesia.io'
+]
+
+# --- REPUTABLE DOMAINS (unchanged) ---
+reputable_domains_patterns = [
+ # International
+ r'\.gov$', r'\.edu$', r'\.mil$',
+ r'bbc\.com', r'bbc\.co\.uk',
+ r'reuters\.com',
+ r'apnews\.com',
+ r'nytimes\.com',
+ r'washingtonpost\.com',
+ r'theguardian\.com',
+ r'cnn\.com',
+ r'wikipedia\.org',
+ r'wikimedia\.org',
+ r'wsj\.com',
+ r'economist\.com',
+ r'forbes\.com',
+ r'bloomberg\.com',
+
+ # Vietnam (mainstream)
+ r'chinhphu\.vn',
+ r'vnexpress\.net',
+ r'dantri\.com\.vn',
+ r'vietnamnet\.vn',
+ r'tuoitre\.vn',
+ r'thanhnien\.vn',
+ r'laodong\.vn',
+ r'vtv\.vn',
+ r'vov\.vn',
+ r'baotintuc\.vn',
+ r'nhandan\.vn',
+ r'qdnd\.vn'
+]
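+# Minimal usage sketch (an assumption -- the actual matching logic lives elsewhere in the
+# project): each entry is a regex applied to a page hostname with re.search, e.g.:
+#
+#   def is_reputable(hostname: str) -> bool:
+#       return any(re.search(p, hostname) for p in reputable_domains_patterns)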
+
+
+# === EXCLUSION LIST (EXPANDED) ===
+# (Social media, UGC, video, forums, e-commerce, caches, free stock photos)
+excluded_domains = [
+ # Social media
+ "facebook.com",
+ "instagram.com",
+ "twitter.com",
+ "x.com",
+ "pinterest.com",
+ "linkedin.com",
+ "reddit.com",
+ "tumblr.com",
+ "vk.com", # Mạng xã hội của Nga
+ "gab.com",
+ "parler.com",
+ "truthsocial.com",
+ "t.me", # Telegram
+ "discord.com",
+ "weibo.com", # Mạng xã hội Trung Quốc
+ "ok.ru", # Mạng xã hội Nga (Odnoklassniki)
+ "line.me", # Mạng xã hội/nhắn tin (Nhật/Đông Nam Á)
+
+ # Video / streaming platforms
+ "youtube.com",
+ "youtu.be",
+ "tiktok.com",
+ "vimeo.com",
+ "dailymotion.com",
+ "twitch.tv",
+ "bitchute.com",
+ "rumble.com",
+
+ # Image/GIF hosts / UGC
+ "imgur.com",
+ "flickr.com",
+ "giphy.com",
+ "tenor.com",
+ "deviantart.com",
+ "artstation.com", # Có thể là nguồn, nhưng là UGC
+ "photobucket.com",
+ "500px.com",
+ "gfycat.com",
+
+ # Free stock photos (usually illustrative, not original event photos)
+ "unsplash.com",
+ "pexels.com",
+ "pixabay.com",
+ "freepik.com",
+
+ # Blog / forum platforms
+ "blogger.com",
+ "blogspot.com",
+ "wordpress.com", # Trang miễn phí, không phải .org tự host
+ "quora.com",
+ "medium.com",
+ "substack.com",
+ "wix.com",
+ "weebly.com",
+ "livejournal.com",
+ "9gag.com",
+ "4chan.org",
+ "boards.4channel.org",
+ "8kun.top",
+ "ifunny.co",
+
+ # E-commerce / marketplaces
+ "amazon.com",
+ "ebay.com",
+ "etsy.com",
+ "aliexpress.com",
+ "alibaba.com",
+ "walmart.com",
+ "shopify.com",
+
+ # Caches / CDNs (usually not the original source)
+ "googleusercontent.com",
+ "pinimg.com", # CDN của Pinterest
+ "fbcdn.net", # CDN của Facebook
+ "cdn.discordapp.com",
+ "twimg.com", # CDN của Twitter
+ "cloudfront.net", # Amazon CDN
+ "akamaihd.net",
+ "fastly.net",
+
+ # Music / other
+ "spotify.com",
+ "soundcloud.com",
+
+ # Archives / utilities
+ "archive.org", # Wayback Machine
+ "archive.is",
+ "gravatar.com", # Avatars
+ "tineye.com", # TinEye itself is a search tool
+ "bit.ly", # Link shortener
+ "imgflip.com", # Meme generator
+]
+# === REPUTABLE SOURCE LIST (EXPANDED) ===
+# (Fact-checking organizations, news agencies, major press, NGOs/science)
+fact_checker_domains = [
+ # Professional fact-checking organizations
+ "snopes.com",
+ "factcheck.org",
+ "politifact.com",
+ "leadstories.com",
+ "fullfact.org", # UK
+ "boomlive.in", # Ấn Độ
+ "poynter.org", # Tổ chức mẹ của PolitiFact
+ "africacheck.org", # Châu Phi
+ "chequeado.com", # Argentina
+ "aosfatos.org", # Brazil
+ "maldita.es", # Tây Ban Nha
+ "correctiv.org", # Đức
+ "factly.in", # Ấn Độ
+ "vishvasnews.com", # Ấn Độ
+ "taiwantfactcheck.org", # Đài Loan
+ "rmit.org.au", # Úc (RMIT FactLab)
+ "facta.news", # Ý
+ "verafiles.org", # Philippines
+ "factcrescendo.com", # Ấn Độ / Đông Nam Á
+ "stopfake.org", # Ukraine
+ "logically.ai", # Anh / Mỹ
+
+ # International news agencies (very important for image provenance)
+ "reuters.com",
+ "apnews.com", # Associated Press
+ "afp.com", # Agence France-Presse
+ "gettyimages.com", # Huge stock/news image source
+ "shutterstock.com", # Similar to Getty
+ "alamy.com", # Similar to Getty
+ "dpa.com", # German news agency
+ "efe.com", # Spanish news agency
+ "ansa.it", # Italian news agency
+ "kyodonews.net", # Japanese news agency
+ "pa.media", # Press Association (UK)
+ "tass.com", # Russian news agency
+ "xinhuanet.com", # Xinhua (China)
+ "upi.com", # United Press International
+ "canadianpress.com", # The Canadian Press
+
+ # Reputable press / media (global)
+ "bbc.com",
+ "nytimes.com",
+ "washingtonpost.com",
+ "theguardian.com",
+ "wsj.com", # Wall Street Journal
+ "npr.org",
+ "aljazeera.com",
+ "cbsnews.com",
+ "abcnews.go.com",
+ "nbcnews.com",
+ "cnn.com",
+ "ft.com", # Financial Times
+ "economist.com",
+ "time.com",
+ "pbs.org",
+ "dw.com", # Deutsche Welle (Đức)
+ "france24.com",
+ "euronews.com",
+
+ # Reputable press / media (by country)
+ "lemonde.fr", # France
+ "lefigaro.fr", # France
+ "spiegel.de", # Germany
+ "zeit.de", # Germany
+ "faz.net", # Germany
+ "sueddeutsche.de", # Germany
+ "elpais.com", # Spain
+ "elmundo.es", # Spain
+ "corriere.it", # Italy
+ "repubblica.it", # Italy
+ "thehindu.com", # India
+ "timesofindia.indiatimes.com", # India
+ "asahi.com", # Japan (Asahi Shimbun)
+ "yomiuri.co.jp", # Japan (Yomiuri Shimbun)
+ "mainichi.jp", # Japan (Mainichi Shimbun)
+ "koreaherald.com", # South Korea
+ "koreajoongangdaily.joins.com", # South Korea
+ "scmp.com", # South China Morning Post (Hong Kong)
+ "straitstimes.com", # Singapore
+ "cbc.ca", # Canada
+ "ctvnews.ca", # Canada
+ "theglobeandmail.com", # Canada
+ "nationalpost.com", # Canada
+ "abc.net.au", # Úc
+ "smh.com.au", # Úc (Sydney Morning Herald)
+ "theage.com.au", # Úc
+ "theaustralian.com.au", # Úc
+ "telegraph.co.uk", # UK
+ "independent.co.uk", # UK
+ "news.sky.com", # UK
+ "thetimes.co.uk", # UK
+ "latimes.com", # US (Los Angeles Times)
+ "usatoday.com", # US
+ "chicagotribune.com", # US
+ "bostonglobe.com", # US
+
+ # NGOs / scientific organizations (authoritative on specific topics)
+ "who.int", # World Health Organization
+ "cdc.gov", # US Centers for Disease Control
+ "nasa.gov", # NASA
+ "un.org", # United Nations
+ "unicef.org", # United Nations Children's Fund
+ "unesco.org",
+ "wfp.org", # World Food Programme
+ "icrc.org", # International Committee of the Red Cross
+ "redcross.org", # Red Cross (national societies)
+ "doctorswithoutborders.org", # Doctors Without Borders (MSF)
+ "amnesty.org", # Amnesty International
+ "hrw.org", # Human Rights Watch
+ "nature.com",
+ "sciencemag.org",
+ "thelancet.com",
+ "nejm.org", # New England Journal of Medicine
+ "bmj.com", # British Medical Journal
+]
+
+# --- API KEYS (loaded safely from the environment) ---
+SCRAPINGDOG_API_KEY = os.getenv("SCRAPINGDOG_API_KEY")
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+
+if not SCRAPINGDOG_API_KEY:
+ print("Cảnh báo: SCRAPINGDOG_API_KEY không tìm thấy trong tệp .env")
+if not GEMINI_API_KEY:
+ print("Cảnh báo: GEMINI_API_KEY không tìm thấy trong tệp .env")
diff --git a/miragenews/img/core.py b/miragenews/img/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..0819c5ff095b2233506d8009f41de7d5705956ff
--- /dev/null
+++ b/miragenews/img/core.py
@@ -0,0 +1,522 @@
+import os
+import io
+import re
+import asyncio
+import subprocess
+from typing import List, Dict, Any
+from urllib.parse import urlparse
+
+import gradio as gr
+from PIL import Image
+from google.cloud import vision
+import httpx
+
+# 1. Import Config
+from .config import (
+ GOOGLE_SAFE_BROWSING_API_KEY,
+ SCRAPINGDOG_API_KEY,
+ MERGED_PT_DIR,
+ REAL_PT_FILENAME,
+ BASE_DIR,
+ CUSTOM_IMG_DIR_NAME
+)
+
+# 2. Import Resources (Singleton Instance)
+# We import the 'resources' object that is already initialized in resources.py
+from .resources import resources
+
+# 3. Import Utils
+from .llm_analyzer import rank_reputable_urls_llm, get_llm_analysis
+from .web_utils import (
+ check_google_safe_browsing,
+ get_image_embedding,
+ find_best_url_fast_scan_bs4,
+ get_html_context_block_bs4
+)
+from .pixel_utils import run_mirage_subprocess, predict_authenticity_from_pt
+
+# --- MAIN LOGIC ---
+
+async def analyze_saved_images(saved_files_state: List[Dict[str, str]], progress=gr.Progress()):
+
+ mirage_img_global = resources.mirage_model
+ google_vision_client = resources.google_vision_client
+ clip_model = resources.clip_model
+ device_mirage = resources.device
+ best_threshold = resources.best_threshold
+
+ # Internal helper functions
+ def get_friendly_judgment(judgment: str) -> str:
+ if judgment == "AI_RELATED": return "AI Detected"
+ if judgment == "NOT_AI_RELATED": return "Looks Normal"
+ if judgment == "NO_URLS": return "No Web Info"
+ if judgment == "ERROR": return "Error"
+ if judgment == "SKIPPED": return "Skipped"
+ return "Unclear"
+
+ def create_status_output(json_status: Dict):
+ status_json = {"status_update": json_status}
+ return status_json, None
+
+ # 1. Check that the models have been loaded
+ if mirage_img_global is None or google_vision_client is None or clip_model is None:
+ error_msg = "Models not loaded. "
+ if mirage_img_global is None: error_msg += "(MirageNews failed) "
+ if google_vision_client is None: error_msg += "(Google Vision failed) "
+ if clip_model is None: error_msg += "(CLIP model failed) "
+ print(f"LỖI: {error_msg}")
+ yield {"error": error_msg}, f"## ❌ Error\n{error_msg}"
+ return
+
+ # 2. Validate the input
+ if not saved_files_state:
+ yield {"info": "No images were loaded."}, "## ℹ️ Info\nNo images were loaded. Please upload an image."
+ return
+
+ first_file_info = saved_files_state[0]
+ image_name = first_file_info["original_filename"]
+ image_path = first_file_info["saved_path"]
+
+ print(f"\n{'='*50}\nBắt đầu phân tích ảnh: {image_name} (tại {image_path})\n{'='*50}")
+
+ # Initialize the result variables
+ final_judgment = "UNCLEAR"
+ final_reason = "Starting analysis..."
+ gv_sim_score_display = 0.0
+ all_urls_for_display = []
+ found_ai_source = False
+ ai_source_urls_found = []
+ all_pages = []
+ web_judgment = "UNCLEAR"
+ pixel_judgment = "UNCLEAR"
+ llm_only_judgment = "UNCLEAR"
+ md_progress_history = ["**Analysis Log:**"]
+ best_url_from_fast_scan = None
+
+ progress(0.05, desc="Initializing analysis...")
+ yield create_status_output({"step": "Initializing"})
+
+ # --- START OF THE ANALYSIS LOGIC ---
+ try:
+ print(f"Đang đọc file ảnh từ: {image_path}")
+ with io.open(image_path, 'rb') as image_file:
+ content = image_file.read()
+
+ input_image_pil = Image.open(io.BytesIO(content)).convert("RGB")
+
+ # --- GOOGLE VISION CHECK ---
+ image_vision = vision.Image(content=content)
+ feature = vision.Feature(type_=vision.Feature.Type.WEB_DETECTION)
+ gv_request = vision.AnnotateImageRequest(image=image_vision, features=[feature])
+
+ gv_response = google_vision_client.annotate_image(request=gv_request)
+ ann = gv_response.web_detection
+ gv_sim_score_display = min(ann.web_entities[0].score, 1.0) if ann.web_entities else 0.0
+
+ all_found_pages_raw = ann.pages_with_matching_images[:30]
+
+ if not all_found_pages_raw:
+ print("Google Vision không tìm thấy trang nào.")
+ all_pages = []
+ else:
+ print(f"Google Vision tìm thấy {len(all_found_pages_raw)} trang. Bắt đầu lọc an toàn...")
+ md_progress_history.append(f"- ✅ Web search complete. Found {len(all_found_pages_raw)} pages. Checking for safety...")
+ progress(0.2, desc="Verifying link safety (Google)...")
+ yield create_status_output({"step": "Safety Check"})
+
+ urls_to_check = [page.url for page in all_found_pages_raw if page.url]
+ safety_results = check_google_safe_browsing(urls_to_check, GOOGLE_SAFE_BROWSING_API_KEY)
+ unsafe_urls = set()
+ if safety_results and 'matches' in safety_results:
+ for match in safety_results['matches']:
+ unsafe_url = match['threat']['url']
+ threat_type = match['threatType']
+ print(f"🚫 [LỌC AN TOÀN] Phát hiện link nguy hiểm: {unsafe_url} ({threat_type})")
+ unsafe_urls.add(unsafe_url)
+
+ safe_pages = []
+ for page in all_found_pages_raw:
+ if page.url not in unsafe_urls:
+ safe_pages.append(page)
+
+ all_pages = [p.url for p in safe_pages[:10] if p.url]
+
+ print(f"✅ Lọc an toàn hoàn tất. Còn lại {len(all_pages)} trang (Top 10 an toàn) để phân tích.")
+ md_progress_history.append(f"- 🛡️ Safety check complete. Filtered {len(all_found_pages_raw) - len(safe_pages)} unsafe links. Proceeding with {len(all_pages)} pages.")
+
+ all_urls_for_display = all_pages
+ progress(0.3, desc=f"✅ Web search complete. Found {len(all_pages)} safe pages.")
+ yield create_status_output({"step": "Google Vision", "urls_found": len(all_pages)})
+
+ except Exception as e:
+ print(f"❌ Lỗi nghiêm trọng khi gọi Google Vision: {e}")
+ md_progress_history.append(f"- ❌ Error during Google Vision search: {e}")
+ yield create_status_output({"step": "Google Vision Error", "error": str(e)})
+ all_pages = []
+ all_urls_for_display = []
+ gv_sim_score_display = 0.0
+
+ # --- WEB CONTENT ANALYSIS ---
+ if all_pages:
+ progress(0.35, desc="Phase 0/3: Filtering reputable sources (LLM)...")
+ ranked_urls = await rank_reputable_urls_llm(all_pages)
+
+ if not ranked_urls:
+ print("⚠️ LLM Reputation Filter đã loại bỏ tất cả. Bỏ qua web check.")
+ md_progress_history.append("- ⚠️ LLM Reputation Filter removed all URLs. Skipping web checks.")
+ else:
+ md_progress_history.append(f"- ✅ LLM Reputation Filter complete. Kept {len(ranked_urls)} URLs (sorted).")
+
+ if ranked_urls:
+ print("\n--- [GIAI ĐOẠN 1: QUÉT NHANH (BS4)] ---")
+ md_progress_history.append(f"- 🧠 Phase 1: Fast scan ({len(ranked_urls)} URLs) to find best match...")
+ progress(0.4, desc=f"Phase 1/3: Fast scan ({len(ranked_urls)} URLs)...")
+
+ input_embedding = get_image_embedding(clip_model, input_image_pil)
+ best_url_from_fast_scan = None
+ scraped_context_html = ""
+ max_sim = 0.0
+
+ async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
+ best_url_from_fast_scan, max_sim = await find_best_url_fast_scan_bs4(
+ input_embedding, ranked_urls, client, clip_model, SCRAPINGDOG_API_KEY, progress
+ )
+
+ if not best_url_from_fast_scan:
+ print("⚠️ Quét Nhanh (Phase 1) không tìm thấy URL nào. Bỏ qua.")
+ web_judgment = "UNCLEAR"
+ final_reason = "Web analysis unclear. Fast scan (Phase 1) did not find a matching URL."
+ md_progress_history.append("- ⚠️ Fast scan (Phase 1) did not select a URL. Skipping deep scan.")
+ else:
+ print(f"✅ Quét Nhanh (Phase 1) tìm thấy URL tốt nhất: {best_url_from_fast_scan}")
+
+ if max_sim < 0.75:
+ print(f"⚠️ Similarity ({max_sim:.4f}) < 0.75. Bỏ qua phân tích context (Web).")
+ md_progress_history.append(f"- ⚠️ Image match similarity ({max_sim:.4f}) is below 0.75. Skipping web text analysis.")
+ web_judgment = "UNCLEAR"
+ final_reason = "Web context analysis skipped due to low image similarity (< 0.75)."
+ else:
+ print("\n--- [GIAI ĐOẠN 2: LẤY KHỐI HTML (BS4)] ---")
+ md_progress_history.append(f"- 🧠 Phase 2: Fetching HTML block from 1 URL: {best_url_from_fast_scan}")
+
+ async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
+ scraped_context_html = await get_html_context_block_bs4(
+ input_embedding, best_url_from_fast_scan, client, clip_model, SCRAPINGDOG_API_KEY, progress
+ )
+
+ print("\n--- [BẮT ĐẦU QUY TẮC LLM (HTML Check)] ---")
+ if not scraped_context_html:
+ print("⚠️ Quét Sâu (Phase 2) không tìm thấy khối HTML. Bỏ qua LLM (Web).")
+ web_judgment = "UNCLEAR"
+ final_reason = "Deep scan (Phase 2) found a URL, but no HTML block for the image."
+ md_progress_history.append("- ⚠️ No HTML block found (Phase 2). Skipping web text analysis.")
+ else:
+ print(f"✅ Quét Sâu (Phase 2) tìm thấy khối HTML (dài: {len(scraped_context_html)} chars)")
+ md_progress_history.append(f"- ✅ Found HTML block (Length: {len(scraped_context_html)}).")
+ md_progress_history.append(f"- ℹ️ Similarity ({max_sim:.4f}) >= 0.75. Proceeding with web text analysis.")
+
+ domain_name = "unknown.com"
+ try:
+ domain_name = urlparse(best_url_from_fast_scan).netloc.replace("www.", "")
+ except Exception: pass
+
+ scraped_data_dict = {f"HTML snippet from {best_url_from_fast_scan}": scraped_context_html}
+
+ judgment_llm_web, reason_llm_web = await get_llm_analysis(
+ image_path, scraped_data_dict, domain_name, (max_sim > 0.9)
+ )
+
+ web_judgment = judgment_llm_web
+ final_reason = reason_llm_web
+
+ if judgment_llm_web == "AI_RELATED":
+ print(f"✅ [KẾT QUẢ LLM (Web)] LLM TÌM THẤY BẰNG CHỨNG AI/FAKE TRONG CONTEXT")
+ found_ai_source = True
+ ai_source_urls_found.append(best_url_from_fast_scan)
+ elif judgment_llm_web == "NOT_AI_RELATED" and "Visually verified as authentic" in reason_llm_web:
+ print(f"✅ [KẾT QUẢ LLM (Web)] 'Trust Rule' ĐÃ ĐẠT. Xác minh là thật.")
+ llm_only_judgment = "SKIPPED"
+ pixel_judgment = "SKIPPED"
+ md_progress_history.append("- ✅ 'Trust Rule' PASSED (Sim > 0.9 & Clean Text). Skipping Image-Only and Pixel checks.")
+ else:
+ print(f"✅ [KẾT QUẢ LLM (Web)] LLM không tìm thấy bằng chứng AI/FAKE trong context.")
+ pixel_judgment = "SKIPPED"
+ md_progress_history.append("- ℹ️ Web text (HTML context) analysis complete. No AI/Fake evidence found.")
+
+ md_progress_history.append(f"- ➡️ Web Analysis Result: **{get_friendly_judgment(web_judgment)}**")
+ yield create_status_output({"step": "Rule 1 Complete", "web_judgment": web_judgment})
+
+ # --- LLM IMAGE ONLY CHECK (WITH WEB CONTEXT) ---
+ if web_judgment != "AI_RELATED" and llm_only_judgment != "SKIPPED":
+ print(f"ℹ️ Web check không phải AI ('{web_judgment}'). Chạy LLM (Image-Only) check...")
+ md_progress_history.append("- ℹ️ Starting LLM (Gemini) image-only analysis...")
+ progress(0.7, desc="Phase 3/3: LLM (image-only) analysis...")
+ yield create_status_output({"step": "LLM (Image-Only) Check"})
+
+ try:
+ llm_judgment_no_url, llm_reason_no_url = await get_llm_analysis(image_path, {}, "N/A", False)
+ llm_only_judgment = llm_judgment_no_url
+ print(f"✅ [KẾT QUẢ LLM (Image-Only)] Judgment: {llm_only_judgment}, Reason: {llm_reason_no_url}")
+
+ if llm_only_judgment == "AI_RELATED":
+ final_reason = llm_reason_no_url
+
+ except Exception as e:
+ print(f"❌ Lỗi khi gọi get_llm_analysis (image-only): {e}")
+ llm_only_judgment = "ERROR"
+ final_reason = f"Error during LLM (image-only) analysis: {e}"
+
+ md_progress_history.append(f"- ➡️ LLM (Image-Only) Analysis Result: **{get_friendly_judgment(llm_only_judgment)}**")
+ yield create_status_output({"step": "LLM (Image-Only) Complete", "llm_judgment": llm_only_judgment})
+ else:
+ if llm_only_judgment == "SKIPPED":
+ print(f"ℹ️ Bỏ qua LLM (Image-Only) vì 'Trust Rule' (Sim > 0.9) đã thành công.")
+ else:
+ print(f"ℹ️ Bỏ qua LLM (Image-Only) vì Web check đã là 'AI_RELATED'.")
+ md_progress_history.append("- ℹ️ Skipping LLM (Image-Only) check (AI/Fake found on web context).")
+
+ else:
+ # Case: URLs found but filtered out by reputation
+ print(f"ℹ️ Không có URL uy tín nào. Sẽ chạy LLM (Gemini) VÀ Pixel (MirageNews).")
+ md_progress_history.append("- ℹ️ No reputable URLs found. Starting LLM (Gemini) image-only analysis...")
+ progress(0.4, desc="No reputable pages. Starting LLM analysis...")
+ yield create_status_output({"step": "LLM (No-URL) Check"})
+
+ try:
+ llm_judgment_no_url, llm_reason_no_url = await get_llm_analysis(image_path, {}, "N/A", False)
+ llm_only_judgment = llm_judgment_no_url
+ print(f"✅ [KẾT QUẢ LLM (Không URL)] Judgment: {llm_only_judgment}, Reason: {llm_reason_no_url}")
+
+ if llm_only_judgment == "AI_RELATED":
+ final_reason = llm_reason_no_url
+ else:
+ final_reason = f"LLM (Gemini) analysis (image-only) result: {llm_judgment_no_url}. (Reason: {llm_reason_no_url})"
+ except Exception as e:
+ print(f"❌ Lỗi khi gọi get_llm_analysis (không URL): {e}")
+ llm_only_judgment = "ERROR"
+ final_reason = f"Error during LLM (image-only) analysis: {e}"
+
+ md_progress_history.append(f"- ➡️ LLM (No-URL) Analysis Result: **{get_friendly_judgment(llm_only_judgment)}**")
+ progress(0.6, desc=f"LLM (No-URL) Result: {get_friendly_judgment(llm_only_judgment)}")
+ yield create_status_output({"step": "LLM (No-URL) Complete", "llm_judgment": llm_only_judgment})
+ web_judgment = "NO_URLS"
+ md_progress_history.append("- ℹ️ Proceeding to pixel analysis (MirageNews).\n- ➡️ Web Analysis Result: **{get_friendly_judgment(web_judgment)}**")
+
+ else:
+ # Case: No URLs found by Google Vision
+ print(f"ℹ️ Google Vision không tìm thấy URL nào. Sẽ chạy LLM (Gemini) VÀ Pixel (MirageNews).")
+ md_progress_history.append("- ℹ️ No web pages found. Starting LLM (Gemini) image-only analysis...")
+ progress(0.4, desc="No web pages found. Starting LLM analysis...")
+ yield create_status_output({"step": "LLM (No-URL) Check"})
+
+ try:
+ llm_judgment_no_url, llm_reason_no_url = await get_llm_analysis(image_path, {}, "N/A", False)
+ llm_only_judgment = llm_judgment_no_url
+ print(f"✅ [KẾT QUẢ LLM (Không URL)] Judgment: {llm_only_judgment}, Reason: {llm_reason_no_url}")
+
+ if llm_only_judgment == "AI_RELATED":
+ final_reason = llm_reason_no_url
+ else:
+ final_reason = f"LLM (Gemini) analysis (image-only) result: {llm_judgment_no_url}. (Reason: {llm_reason_no_url})"
+ except Exception as e:
+ print(f"❌ Lỗi khi gọi get_llm_analysis (không URL): {e}")
+ llm_only_judgment = "ERROR"
+ final_reason = f"Error during LLM (image-only) analysis: {e}"
+
+ md_progress_history.append(f"- ➡️ LLM (No-URL) Analysis Result: **{get_friendly_judgment(llm_only_judgment)}**")
+ progress(0.6, desc=f"LLM (No-URL) Result: {get_friendly_judgment(llm_only_judgment)}")
+ yield create_status_output({"step": "LLM (No-URL) Complete", "llm_judgment": llm_only_judgment})
+ web_judgment = "NO_URLS"
+ md_progress_history.append("- ℹ️ Proceeding to pixel analysis (MirageNews).\n- ➡️ Web Analysis Result: **{get_friendly_judgment(web_judgment)}**")
+
+ # --- MIRAGENEWS PIXEL ANALYSIS ---
+ if (web_judgment != "AI_RELATED" and
+ llm_only_judgment != "AI_RELATED" and
+ pixel_judgment != "SKIPPED" and
+ mirage_img_global is not None):
+
+ print(f"\n--- [BẮT ĐẦU PHÂN TÍCH PIXEL (MIRAGENEWS)] ---")
+ md_progress_history.append("- ✅ Starting pixel-level analysis... (This may take a moment)")
+ progress(0.8, desc="Phase 4/4: Pixel analysis (MirageNews)...")
+ yield create_status_output({"step": "MirageNews (Start)"})
+
+ try:
+ print("Chạy encode_predictions.py...")
+ cmd_encode = f"python miragenews/data/encode_predictions.py --mode image --model_class linear --custom --img_dirs {CUSTOM_IMG_DIR_NAME} --batch_size 16"
+ process_encode = await asyncio.create_subprocess_shell(
+ cmd_encode, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=BASE_DIR
+ )
+ stdout_e, stderr_e = await process_encode.communicate()
+ if process_encode.returncode != 0:
+ print(f"❌ Lỗi khi chạy encode_predictions.py: {stderr_e.decode()}")
+ raise subprocess.CalledProcessError(process_encode.returncode, cmd_encode, stderr=stderr_e)
+
+ print("Chạy manual_merge.py...")
+ cmd_merge = "python data/manual_merge.py"
+ process_merge = await asyncio.create_subprocess_shell(
+ cmd_merge, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=BASE_DIR
+ )
+ stdout_m, stderr_m = await process_merge.communicate()
+ if process_merge.returncode != 0:
+ print(f"❌ Lỗi khi chạy manual_merge.py: {stderr_m.decode()}")
+ raise subprocess.CalledProcessError(process_merge.returncode, cmd_merge, stderr=stderr_m)
+
+ print("Running pixel detection model...")
+ real_pt_path = os.path.join(MERGED_PT_DIR, REAL_PT_FILENAME)
+ prob_fake, label = predict_authenticity_from_pt(
+ real_pt_path, mirage_img_global, device_mirage, best_threshold) # Truyền best_threshold vào
+
+ if label == "fake": pixel_judgment = "AI_RELATED"
+ elif label == "real": pixel_judgment = "NOT_AI_RELATED"
+ else: pixel_judgment = "UNCLEAR"
+
+ except (subprocess.CalledProcessError, Exception) as e:
+ error_message = f"Lỗi MirageNews: {e}"
+ if hasattr(e, 'stderr'):
+ error_message = f"Lỗi MirageNews subprocess: {e.stderr.decode()}"
+ print(f"❌ {error_message}")
+ pixel_judgment = "ERROR"
+ final_reason = error_message
+
+ md_progress_history.append(f"- ➡️ Pixel Analysis Result: **{get_friendly_judgment(pixel_judgment)}**")
+ progress(0.9, desc=f"✅ Pixel Analysis Result: {get_friendly_judgment(pixel_judgment)}")
+ yield create_status_output({"step": "MirageNews (Complete)", "pixel_judgment": pixel_judgment})
+
+ else:
+ print("\n--- [BỎ QUA MIRAGENEWS] ---")
+ skip_reason = ""
+ if pixel_judgment == "SKIPPED":
+ skip_reason = "'Trust Rule' (Sim > 0.9) đã thành công."
+ md_progress_history.append("- ℹ️ Skipping pixel analysis ('Trust Rule' passed).")
+ elif web_judgment == "AI_RELATED":
+ skip_reason = "Lọc Web (Context) đã tìm thấy bằng chứng AI/Fake."
+ md_progress_history.append("- ℹ️ Skipping pixel analysis (AI/Fake found on web context).")
+ elif llm_only_judgment == "AI_RELATED":
+ skip_reason = "Lọc LLM (Image-Only) đã tìm thấy bằng chứng AI/Fake."
+ md_progress_history.append("- ℹ️ Skipping pixel analysis (AI/Fake found by LLM image-only check).")
+ elif mirage_img_global is None:
+ skip_reason = "Mô hình MirageNews không được tải."
+ md_progress_history.append("- ⚠️ Skipping pixel analysis (MirageNews model not loaded).")
+
+ print(f"(Lý do: {skip_reason})")
+ if pixel_judgment != "SKIPPED":
+ pixel_judgment = "SKIPPED"
+
+ progress(0.9, desc="Skipping pixel analysis (Not needed)")
+ yield create_status_output({"step": "MirageNews (Skipped)"})
+
+ # --- KẾT HỢP KẾT QUẢ CUỐI CÙNG ---
+ print("\n--- [KẾT HỢP KẾT QUẢ CUỐI CÙNG] ---")
+ print(f"Web Judgment: {web_judgment}, LLM-Only: {llm_only_judgment}, Pixel: {pixel_judgment}")
+ md_progress_history.append(f"- ✅ Combining results (Web: {get_friendly_judgment(web_judgment)}, LLM-Only: {get_friendly_judgment(llm_only_judgment)}, Pixel: {get_friendly_judgment(pixel_judgment)})...")
+ progress(0.95, desc="Combining final results...")
+ yield create_status_output({"step": "Combining", "web_judgment": web_judgment, "pixel_judgment": pixel_judgment, "llm_only_judgment": llm_only_judgment})
+
+ await asyncio.sleep(0.5)
+
+ # Logic quyết định cuối cùng (Final Decision Logic)
+ if web_judgment == "NOT_AI_RELATED" and "Visually verified as authentic" in final_reason:
+ final_judgment = "NOT_AI_RELATED"
+ elif web_judgment == "AI_RELATED" or llm_only_judgment == "AI_RELATED":
+ final_judgment = "AI_RELATED"
+ elif pixel_judgment == "AI_RELATED":
+ final_judgment = "AI_RELATED"
+ final_reason = "Pixel analysis (MirageNews) detected image as AI/Fake."
+ if web_judgment == "NOT_AI_RELATED":
+ final_reason += f" (Web context/reputability check passed. LLM-only check: {get_friendly_judgment(llm_only_judgment)})"
+ elif web_judgment == "UNCLEAR":
+ final_reason += f" (Web text context was unclear. LLM-only check: {get_friendly_judgment(llm_only_judgment)})"
+ elif web_judgment == "NO_URLS":
+ final_reason += f" (Web search found no URLs. LLM-only check: {get_friendly_judgment(llm_only_judgment)})"
+ elif (web_judgment == "NOT_AI_RELATED" or web_judgment == "UNCLEAR") and llm_only_judgment == "NOT_AI_RELATED" and pixel_judgment == "NOT_AI_RELATED":
+ final_judgment = "NOT_AI_RELATED"
+ final_reason = "All checks passed: Web text (context), LLM (image-only), and Pixel (MirageNews) found no evidence of being Not-Real."
+ elif web_judgment == "NO_URLS" and llm_only_judgment == "NOT_AI_RELATED" and pixel_judgment == "NOT_AI_RELATED":
+ final_judgment = "NOT_AI_RELATED"
+ final_reason = "All checks passed: No URLs found, LLM (image-only) found no AI/Fake, and Pixel (MirageNews) found no AI/Fake."
+ else:
+ final_judgment = "UNCLEAR"
+ final_reason = f"Analysis is inconclusive. Web (Context): '{get_friendly_judgment(web_judgment)}', LLM (Image-Only): '{get_friendly_judgment(llm_only_judgment)}', Pixel: '{get_friendly_judgment(pixel_judgment)}'."
+
+ print(f"Final Judgment: {final_judgment}")
+
+ final_results = {
+ "judgment": final_judgment,
+ "reason": final_reason,
+ "detected_ai_sources": ai_source_urls_found,
+ "all_detected_urls": all_urls_for_display,
+ "gv_similarity_score": gv_sim_score_display,
+ "debug_web_judgment": web_judgment,
+ "debug_pixel_judgment": pixel_judgment,
+ "debug_llm_only_judgment": llm_only_judgment
+ }
+
+ # --- FORMAT MARKDOWN OUTPUT ---
+ auth_assessment = "❓ **UNCLEAR**"
+ synth_type = "N/A"
+ artifacts = final_reason
+ source_url_to_display = best_url_from_fast_scan if (web_judgment != "NO_URLS" and best_url_from_fast_scan) else None
+
+ if final_judgment == "AI_RELATED":
+ auth_assessment = "🤖 **NOT REAL** (Fake, Manipulated, or AI)"
+ tool_match = re.search(r"\*\*Tool:\*\*(.*)", final_reason, re.IGNORECASE | re.DOTALL)
+ reason_match = re.search(r"\*\*Reason:\*\*(.*)", final_reason, re.IGNORECASE | re.DOTALL)
+
+ if reason_match:
+ full_reason_text = reason_match.group(1).strip()
+ if tool_match:
+ artifacts = full_reason_text.split("**Tool:**")[0].strip()
+ synth_type = tool_match.group(1).strip()
+ else:
+ artifacts = full_reason_text
+ synth_type = "Unknown"
+ elif tool_match:
+ synth_type = tool_match.group(1).strip()
+ artifacts = "Reason not specified, see tool."
+ else:
+ if final_reason.startswith("AI_RELATED"):
+ artifacts = "AI/Fake evidence found, but detailed reason was not extracted."
+ else:
+ artifacts = final_reason
+ synth_type = "Unknown"
+
+ if source_url_to_display:
+ artifacts += f"
**Source:** [{source_url_to_display}]({source_url_to_display})"
+
+ elif final_judgment == "NOT_AI_RELATED":
+ auth_assessment = "🧑 **REAL PHOTO**"
+ synth_type = "N/A"
+ reason_match = re.search(r"\*\*Reason:\*\*(.*)", final_reason, re.IGNORECASE | re.DOTALL)
+ if reason_match:
+ artifacts = reason_match.group(1).strip()
+ else:
+ artifacts = final_reason
+ if "Visually verified as authentic" in artifacts and source_url_to_display:
+ artifacts += f"
**Source:** [{source_url_to_display}]({source_url_to_display})"
+
+ tools_list = ["Google Vision (Web Search)"]
+ if web_judgment == "NO_URLS":
+ tools_list.append("LLM (Image-Only Analysis)")
+ else:
+ tools_list.append("LLM Reputation Filter")
+ tools_list.append("Fast Scan (BS4)")
+ tools_list.append("Deep HTML Scan (BS4)")
+ tools_list.append("LLM (HTML Context & Reputability Analysis)")
+ if llm_only_judgment != "SKIPPED":
+ tools_list.append("LLM (Image-Only Analysis)")
+ if pixel_judgment != "SKIPPED":
+ tools_list.append("Pixel Analysis (MirageNews)")
+
+ tools_methods_str = ", ".join(tools_list)
+
+ md_links = f"## 🏁 Forensic Analysis\n\n"
+ md_links += f"- **Authenticity Assessment:** {auth_assessment}\n"
+ md_links += f"- **Verification Tools & Methods:** {tools_methods_str}\n"
+ md_links += f"- **Synthetic Type (if applicable):** {synth_type.replace('\n', '
')}\n"
+ md_links += f"- **Other Artifacts:** {artifacts.replace('\n', '
')}\n"
+
+ progress(1.0, desc="🏁 Analysis Complete!")
+ await asyncio.sleep(0.5)
+ progress(None)
+
+ yield final_results, md_links
\ No newline at end of file
diff --git a/miragenews/img/file_utils.py b/miragenews/img/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f9403512939dcb8c52dd93393445fa98801d1c7
--- /dev/null
+++ b/miragenews/img/file_utils.py
@@ -0,0 +1,25 @@
+import os
+import shutil
+import uuid
+from typing import List, Dict, Any
+from .config import IMAGE_SAVE_DIR
+
+def handle_image_list_change(image_files: List[Any], current_state: List[Dict[str, str]]):
+ clear_json = None
+ if not image_files:
+ if os.path.exists(IMAGE_SAVE_DIR): shutil.rmtree(IMAGE_SAVE_DIR)
+ os.makedirs(IMAGE_SAVE_DIR, exist_ok=True)
+ return [], clear_json
+
+ if os.path.exists(IMAGE_SAVE_DIR): shutil.rmtree(IMAGE_SAVE_DIR)
+ os.makedirs(IMAGE_SAVE_DIR, exist_ok=True)
+
+ new_state = []
+ for img_obj in image_files:
+ orig_name = getattr(img_obj, 'orig_name', os.path.basename(img_obj.name))
+ ext = os.path.splitext(orig_name)[1] if '.' in orig_name else '.tmp'
+ new_path = os.path.join(IMAGE_SAVE_DIR, f"{uuid.uuid4()}{ext}")
+ shutil.copyfile(img_obj.name, new_path)
+ new_state.append({"original_filename": orig_name, "saved_path": new_path})
+
+ return new_state, clear_json
\ No newline at end of file
diff --git a/miragenews/img/llm_analyzer.py b/miragenews/img/llm_analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..96567f5a1ea87232743f737319795786346a1bab
--- /dev/null
+++ b/miragenews/img/llm_analyzer.py
@@ -0,0 +1,248 @@
+import re
+import asyncio
+from typing import Dict, Tuple, List
+import google.generativeai as genai
+from PIL import Image
+
+from .constants import GEMINI_API_KEY
+
+model = None
+try:
+ if not GEMINI_API_KEY or "DÁN_KEY_CỦA_BẠN" in GEMINI_API_KEY:
+ print("⚠️ Warning: GEMINI_API_KEY is not set in constants.py. LLM analysis will be skipped.")
+ else:
+ genai.configure(api_key=GEMINI_API_KEY)
+
+ safety_settings = [
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
+ ]
+
+ model = genai.GenerativeModel(
+ 'gemini-2.5-pro',
+ safety_settings=safety_settings
+ )
+ print("✅ Gemini 2.5 Pro (Multimodal) model initialized successfully.")
+
+except Exception as e:
+ print(f"❌ Error initializing Gemini: {e}")
+
+def _sync_generation_call(prompt_parts: List) -> str:
+ """Hàm đồng bộ (sync) để gọi LLM (tránh blocking event loop)"""
+ if model is None:
+ print("Lỗi: model Gemini chưa được khởi tạo.")
+ return ""
+ try:
+ generation_config = genai.types.GenerationConfig(
+ temperature=0.0,
+ max_output_tokens=4000
+ )
+ response = model.generate_content(
+ prompt_parts,
+ generation_config=generation_config
+ )
+
+ if response.parts:
+ return response.text
+ else:
+ print(f"[LLM Call Error] API block. Finish Reason: {response.candidates[0].finish_reason}")
+ print(f" Block Message: {response.candidates[0].safety_ratings}")
+ return ""
+
+ except Exception as e:
+ print(f"❌ Error calling Gemini: {e}")
+ return ""
+
+async def rank_reputable_urls_llm(urls: List[str]) -> List[str]:
+ """
+ Gửi 10 URL cho LLM, yêu cầu nó lọc và sắp xếp theo độ uy tín
+ và trả về CHỈ SỐ (INDEX) của chúng.
+ CHỈ trả về các URL uy tín.
+ """
+ if model is None:
+ print("LLM không được khởi tạo, trả về danh sách URL gốc.")
+ return urls
+ if not urls:
+ return []
+
+ url_list_str = "\n".join([f"{i+1}. {url}" for i, url in enumerate(urls)])
+
+ prompt = f"""You are an expert news source analyst. I will provide a list of URLs found by a reverse image search.
+Your task is to:
+1. **Filter:** Identify only the URLs that are **primary, reputable sources**.
+ * **KEEP:** Reputable news organizations. This includes **major global agencies** (e.g., Reuters, AP, BBC, NYT, CNN) AND **established national or regional news sites** (e.g., telegrafi.com, abc.net.au, vnexpress.net, cbc.ca). Also keep government sites (e.g., nasa.gov) and scientific journals.
+ * **EXCLUDE:** Social media (Facebook, X, Reddit, Instagram), forums, blogs, e-commerce sites (Amazon), and general image aggregators (Pinterest, Imgur).
+2. **Sort:** Sort the **KEPT** URLs by their trustworthiness (most reputable first).
+3. **Output:** Return *only* the **NUMBERS** of the sorted, reputable URLs, separated by commas.
+ * Example: 4, 1, 7
+ * If no URLs are kept, return the single word: None
+
+Here is the list:
+{url_list_str}
+
+Sorted, Reputable URL Indexes:
+"""
+ print(f"Đang gửi {len(urls)} URL cho LLM để lọc và sắp xếp (Lấy chỉ số)...")
+
+ response_text = await asyncio.to_thread(_sync_generation_call, [prompt])
+
+ print(f"[LLM Rank Response]: {response_text}")
+
+ ranked_reputable_urls = []
+
+ if response_text.lower() == "none" or not response_text:
+ print("LLM không giữ lại URL uy tín nào. Trả về danh sách rỗng.")
+ return []
+
+ try:
+ indices = [int(i.strip()) for i in re.findall(r'\d+', response_text)]
+ for idx in indices:
+ if 1 <= idx <= len(urls):
+ url = urls[idx - 1]
+ ranked_reputable_urls.append(url)
+
+ if not ranked_reputable_urls:
+ print("LLM không giữ lại URL uy tín nào (sau khi parse). Trả về danh sách rỗng.")
+ return []
+
+ print(f"LLM đã lọc và CHỈ giữ lại {len(ranked_reputable_urls)} URL uy tín (đã sắp xếp).")
+ return ranked_reputable_urls
+
+ except Exception as e:
+ print(f"Lỗi khi phân tích chỉ số (index) từ LLM: {e}. Trả về danh sách rỗng.")
+ return []
+
+
+def check_domain_with_llm(domain: str) -> Tuple[str, str]:
+ if model is None:
+ return "UNCLEAR", "Gemini model not initialized."
+ return "UNCLEAR", "Function is deprecated"
+
+async def get_llm_analysis(
+ image_path: str,
+ scraped_data: Dict[str, str],
+ source_domain: str,
+ sim_gt_90: bool
+) -> Tuple[str, str]:
+ """
+ Gửi 4 mẩu bằng chứng (Ảnh, VĂN BẢN SẠCH, Domain, Sim) cho Gemini để phân tích.
+ """
+ if model is None:
+ return "UNCLEAR", "Gemini model not initialized (missing API key)."
+
+ try:
+ img = Image.open(image_path)
+ except Exception as e:
+ print(f"❌ Error: Could not open image file: {e}")
+ return "UNCLEAR", f"Could not read image file at {image_path}"
+
+ text_part_1 = """You are an expert image analyst and fact-checker.
+Your task is to determine if the provided image is **DECEPTIVE (a fake)** OR **NON-DECEPTIVE (real or illustrative)**.
+
+You will receive 4 types of evidence:
+1. **IMAGE:** The image file to analyze.
+2. **SOURCE_DOMAIN:** The domain the image was found on (e.g., "bbc.com"). (This source has already been pre-filtered as reputable).
+3. **SIMILARITY:** A flag indicating if the image match was > 0.9 (High) or < 0.9 (Low).
+4. **TEXT_SNIPPETS:** Clean text snippets (alt text, caption, nearby paragraphs) from the webpage.
+
+**--- YOUR ANALYSIS PROCESS ---**
+
+**STEP 1: HIGH-TRUST CONTEXT (If Sim > 0.9 AND Text is "Clean")**
+First, check if both of these conditions are met:
+1. **Is SIMILARITY 'High'?** (The flag is True)
+2. **Are the TEXT_SNIPPETS "clean"?** (Read the text and confirm it does *not* contain any *deceptive* keywords like "fake", "AI-generated fake", "hoax", "disinformation").
+ *(Benign words like 'illustration', 'map', 'composite', 'artwork', 'recreation' are OK and considered "clean").*
+
+* **IF BOTH ARE TRUE:** The text context is considered "clean" and trustworthy.
+ * Your task is to perform a **final verification**:
+ * **Visually inspect the IMAGE** to ensure it matches the trustworthy text.
+ * Ask: "Does this image *look* real, or does it have obvious visual artifacts (AI hands, waxy skin, manipulated elements) that contradict the clean text?"
+* **IF FALSE:** (e.g., Similarity is Low, or text contains "fake"), proceed to Step 2.
+
+**STEP 2: STANDARD ANALYSIS (If "High-Trust" conditions are not met)**
+* **TEXTUAL ANALYSIS:** "Read" the `TEXT_SNIPPETS`.
+ * Look for a reason why the image is DECEPTIVE (e.g., "fake picture", "AI hoax").
+ * Look for the tool used (e.g., "Midjourney", "Photoshop").
+* **VISUAL ANALYSIS:** "Look" at the IMAGE for signs it is DECEPTIVE (e.g., AI artifacts trying to look real, malicious Photoshop).
+
+---
+**EVIDENCE PACKET:**
+"""
+
+ sim_text = "High (Match > 0.9)" if sim_gt_90 else "Low (Match < 0.9)"
+
+ clean_text = "\n".join(scraped_data.values())
+ if not clean_text:
+ clean_text = "(No text snippets found. Rely on visual analysis only.)"
+
+ part_2_header = f"""
+**SOURCE_DOMAIN:** {source_domain}
+**SIMILARITY:** {sim_text}
+**TEXT_SNIPPETS:**
+{clean_text}
+"""
+ text_part_2 = part_2_header
+
+ text_part_3 = f"""
+---
+**FINAL ANALYSIS:**
+
+* If the image is **NON-DECEPTIVE** (an authentic photo, a map, an illustration, a chart, or a clearly non-malicious composite), respond ONLY with the single word: `NOT_AI_RELATED`
+ * (If you are in the "High-Trust" path, add a brief reason: **Reason:** Visually verified as authentic, matching clean text from {source_domain}.)
+
+* If the image is **DECEPTIVE** (an AI-generated fake trying to look real, a malicious Photoshop, or a hoax), respond in the following format:
+ `AI_RELATED`
+ **Reason:** [Explain *why* it's deceptive.
+ (If in "High-Trust" path, state: "Text was clean, but visual analysis detected {{artifacts}}."
+ (If in "Standard" path, PRIORITIZE text (e.g., "The web page caption states 'a fake picture'"). IF from visual, describe artifact.)]
+ **Tool:** [Identify the tool. IF from text, quote the text (e.g., "The web page mentions 'Midjourney'"). IF predicted, state (e.g., "Predicted: Photoshop"). If unknown, state "Unknown".]
+
+* If you are completely uncertain, respond ONLY with the single word: `UNCLEAR`
+"""
+
+ print(f"Sending 4-EVIDENCE PACKET (Img, Domain, Sim, Text) to Gemini...")
+
+ llm_result_text = "UNCLEAR"
+ reasoning_text = "Error: Unknown."
+
+ try:
+ prompt_parts = [
+ text_part_1,
+ text_part_2,
+ text_part_3,
+ img,
+ ]
+
+ full_response_text = await asyncio.to_thread(_sync_generation_call, prompt_parts)
+
+ print(f"[LLM Multimodal Response]:\n{full_response_text}")
+
+ reasoning_text = full_response_text
+ first_line_upper = full_response_text.split('\n')[0].upper()
+
+ if "NOT_AI_RELATED" in first_line_upper:
+ llm_result_text = "NOT_AI_RELATED"
+ elif "AI_RELATED" in first_line_upper:
+ llm_result_text = "AI_RELATED"
+ else:
+ llm_result_text = "UNCLEAR"
+
+ except Exception as e:
+ print(f"❌ Error calling Gemini API (multimodal): {e}")
+ reasoning_text = f"Gemini Error: {e}"
+
+ if "NOT_AI_RELATED" in llm_result_text:
+ if "Visually verified as authentic" in reasoning_text:
+ return "NOT_AI_RELATED", reasoning_text
+ elif "**Reason:**" in reasoning_text:
+ return "NOT_AI_RELATED", reasoning_text
+ else:
+ return "NOT_AI_RELATED", "Gemini analysis confirmed: NON-DECEPTIVE."
+
+ elif "AI_RELATED" in llm_result_text:
+ return "AI_RELATED", reasoning_text
+
+ else:
+ return "UNCLEAR", f"Gemini was unable to determine. (Response: {reasoning_text[:100]}...)"
\ No newline at end of file
diff --git a/miragenews/img/pixel_utils.py b/miragenews/img/pixel_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0d6712a1cddcd94afc80cc245f0f684bcf754b0
--- /dev/null
+++ b/miragenews/img/pixel_utils.py
@@ -0,0 +1,38 @@
+import torch
+import os
+import asyncio
+import subprocess
+from typing import Tuple, Optional
+from .config import BASE_DIR, MERGED_PT_DIR, REAL_PT_FILENAME, CUSTOM_IMG_DIR_NAME
+
+def predict_authenticity_from_pt(pt_file_path, model, device, best_threshold) -> Tuple[Optional[float], Optional[str]]:
+ if not os.path.exists(pt_file_path): return None, None
+ try:
+ image_encodings = torch.load(pt_file_path).to(device)
+ with torch.no_grad():
+ logits = model(image_encodings)
+ prob_fake = torch.sigmoid(logits).squeeze().item()
+ label = "fake" if prob_fake >= best_threshold else "real"
+ return prob_fake, label
+ except Exception as e:
+ print(f"Error processing PT: {e}")
+ return None, None
+
+async def run_mirage_subprocess():
+ # Chạy encode_predictions.py
+ cmd_encode = f"python miragenews/data/encode_predictions.py --mode image --model_class linear --custom --img_dirs {CUSTOM_IMG_DIR_NAME} --batch_size 16"
+ process_encode = await asyncio.create_subprocess_shell(
+ cmd_encode, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=BASE_DIR
+ )
+ await process_encode.communicate()
+ if process_encode.returncode != 0: raise subprocess.CalledProcessError(process_encode.returncode, cmd_encode)
+
+ # Chạy manual_merge.py
+ cmd_merge = "python data/manual_merge.py"
+ process_merge = await asyncio.create_subprocess_shell(
+ cmd_merge, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=BASE_DIR
+ )
+ await process_merge.communicate()
+ if process_merge.returncode != 0: raise subprocess.CalledProcessError(process_merge.returncode, cmd_merge)
+
+ return True
\ No newline at end of file
diff --git a/miragenews/img/resources.py b/miragenews/img/resources.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ca7cffc2b082d6796389bc4407ddf69cf1882ba
--- /dev/null
+++ b/miragenews/img/resources.py
@@ -0,0 +1,78 @@
+import sys
+import torch
+from google.cloud import vision
+from sentence_transformers import SentenceTransformer
+import os
+
+# Import local modules (giữ nguyên logic cũ của bạn)
+
+from .semantic_filter import SemanticFilter
+from miragenews.models import get_model
+from miragenews.data import load_config
+from miragenews.utils import load_model_checkpoint
+
+
+class ResourceManager:
+ def __init__(self):
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.mirage_model = None
+ self.google_vision_client = None
+ self.semantic_filter = None
+ self.clip_model = None
+ self.best_threshold = 0.639
+
+ def load_all(self):
+ print("--- Loading Resources ---")
+ self._load_mirage()
+ self._load_google_vision()
+ self._load_semantic_filter()
+ self._load_clip()
+ print("--- Resources Loaded ---\n")
+
+ def _load_mirage(self):
+ from .config import MIRAGE_CONFIG_IMG, MIRAGE_CONFIG_MULTI
+ try:
+ print("Loading MirageNews model...")
+ config_img = load_config(MIRAGE_CONFIG_IMG)
+ config_multi = load_config(MIRAGE_CONFIG_MULTI)
+ model_instance = get_model(config_img).to(self.device)
+ ckpt_path = config_multi['training']['image_model_path']
+
+ loaded_model, loaded_threshold = load_model_checkpoint(model_instance, ckpt_path)
+ if loaded_threshold is not None:
+ self.best_threshold = loaded_threshold
+
+ loaded_model.eval()
+ self.mirage_model = loaded_model
+ print(f"✅ MirageNews loaded. Threshold: {self.best_threshold}")
+ except Exception as e:
+ print(f"❌ Error loading MirageNews: {e}")
+
+ def _load_google_vision(self):
+ try:
+ if 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
+ raise FileNotFoundError("GOOGLE_APPLICATION_CREDENTIALS missing")
+ self.google_vision_client = vision.ImageAnnotatorClient()
+ print("✅ Google Vision client initialized.")
+ except Exception as e:
+ print(f"❌ Google Vision Error: {e}")
+
+ def _load_semantic_filter(self):
+ try:
+ print("Initializing SemanticFilter...")
+ self.semantic_filter = SemanticFilter()
+ print("✅ SemanticFilter initialized.")
+ except Exception as e:
+ print(f"❌ Error SemanticFilter: {e}")
+ sys.exit(1)
+
+ def _load_clip(self):
+ try:
+ print("Loading CLIP...")
+ self.clip_model = SentenceTransformer('clip-ViT-B-32')
+ print("✅ CLIP loaded.")
+ except Exception as e:
+ print(f"❌ Error loading CLIP: {e}")
+
+# Global instance
+resources = ResourceManager()
\ No newline at end of file
diff --git a/miragenews/img/scraper.py b/miragenews/img/scraper.py
new file mode 100644
index 0000000000000000000000000000000000000000..fae54e876fb077fe8bd3e29eb025f37a1891a160
--- /dev/null
+++ b/miragenews/img/scraper.py
@@ -0,0 +1,120 @@
+# scraper.py
+
+import httpx
+import re
+from bs4 import BeautifulSoup
+from typing import Optional
+import urllib.parse
+
+# Import API Key từ constants.py
+try:
+ from .constants import SCRAPINGDOG_API_KEY
+except ImportError:
+ print("Lỗi: Không tìm thấy tệp constants.py")
+ SCRAPINGDOG_API_KEY = ""
+
+async def async_scrape_url_text(url: str, client: httpx.AsyncClient, max_chars: int = 100000) -> Optional[str]:
+ """
+ Truy cập một URL bằng cách sử dụng API của ScrapingDog,
+ trích xuất các văn bản NGỮ CẢNH (meta, title, alt, article, caption)
+ và trả về một chuỗi văn bản đã làm sạch.
+ """
+
+ # 1. Kiểm tra xem API key đã được đặt chưa
+ # [SỬA LỖI] Sửa lại logic kiểm tra key cho đúng
+ if not SCRAPINGDOG_API_KEY or "DÁN_KEY_CỦA_BẠN" in SCRAPINGDOG_API_KEY:
+ print("[Scraper] Lỗi: SCRAPINGDOG_API_KEY chưa được đặt trong tệp constants.py.")
+ print("[Scraper] Hãy đăng ký tại scrapingdog.com và thêm key của bạn.")
+ return None
+
+ # 2. Xây dựng URL gọi API (Giữ nguyên)
+ encoded_url = urllib.parse.quote(url)
+ api_endpoint = f"https://api.scrapingdog.com/scrape?api_key={SCRAPINGDOG_API_KEY}&url={encoded_url}"
+
+ try:
+ # 3. Gửi yêu cầu đến ScrapingDog (Giữ nguyên)
+ print(f"[Scraper] Đang gọi API ScrapingDog cho: {url}")
+ response = await client.get(api_endpoint, timeout=30.0)
+ response.raise_for_status()
+
+ # 4. Lấy HTML (Giữ nguyên)
+ html_content = response.text
+
+ # 5. [THAY ĐỔI] Phân tích HTML để lấy các phần cụ thể
+ soup = BeautifulSoup(html_content, 'html.parser')
+
+ # Tạo một danh sách để chứa các đoạn văn bản tìm thấy
+ scraped_texts = []
+
+ # 5a. Lấy Title của trang
+ page_title = soup.find('title')
+ if page_title and page_title.string:
+ scraped_texts.append(f"Page Title: {page_title.string.strip()}")
+
+ # 5b. Lấy Meta Description
+ meta_desc = soup.find('meta', attrs={'name': 'description'})
+ if meta_desc and meta_desc.get('content'):
+ scraped_texts.append(f"Meta Description: {meta_desc['content'].strip()}")
+
+ # 5c. Lấy nội dung từ thẻ
+ # Nếu không có thẻ , sẽ thử tìm thẻ
+ content_source = soup.find_all('article')
+ content_prefix = "Article Content"
+
+ if not content_source:
+ main_tag = soup.find('main')
+ if main_tag:
+ content_source = [main_tag] # Đặt vào list để xử lý chung
+ content_prefix = "Main Content"
+
+ if content_source:
+ for content_tag in content_source:
+ # Dọn dẹp script/style bên trong article/main
+ for script_or_style in content_tag(["script", "style", "nav", "footer", "aside", "header"]):
+ script_or_style.decompose()
+
+ content_text = content_tag.get_text(separator=' ', strip=True)
+ content_text = re.sub(r'\s+', ' ', content_text)
+ if content_text:
+ scraped_texts.append(f"{content_prefix}: {content_text}")
+
+ # 5d. Lấy tất cả văn bản 'alt' của hình ảnh
+ all_alts = []
+ for img in soup.find_all('img'):
+ alt = img.get('alt')
+ # Bỏ qua nếu alt rỗng hoặc chỉ có khoảng trắng
+ if alt and not alt.isspace():
+ all_alts.append(alt.strip())
+
+ if all_alts:
+ # Xóa trùng lặp nhưng giữ nguyên thứ tự
+ unique_alts = list(dict.fromkeys(all_alts))
+ scraped_texts.append(f"Image Alt Texts: {', '.join(unique_alts)}")
+
+ # 5e. Lấy tất cả nội dung 'figcaption' (chú thích ảnh)
+ all_captions = []
+ for caption in soup.find_all('figcaption'):
+ caption_text = caption.get_text(separator=' ', strip=True)
+ if caption_text and not caption_text.isspace():
+ all_captions.append(re.sub(r'\s+', ' ', caption_text))
+
+ if all_captions:
+ unique_captions = list(dict.fromkeys(all_captions))
+ scraped_texts.append(f"Captions: {', '.join(unique_captions)}")
+
+ # 6. Kết hợp tất cả văn bản tìm thấy thành một chuỗi duy nhất
+ final_text = "\n\n".join(scraped_texts)
+
+ if not final_text:
+ print("[Scraper] API thành công, nhưng không tìm thấy nội dung mục tiêu (title, meta, article, alt, caption).")
+ return None
+
+ print("[Scraper] API ScrapingDog cào và trích xuất thành công.")
+ return final_text[:max_chars] # Trả về text đã trích xuất
+
+ except httpx.HTTPStatusError as e:
+ print(f"[Scraper] Lỗi API ScrapingDog (HTTP {e.response.status_code}): {e.response.text}")
+ return None
+ except Exception as e:
+ print(f"[Scraper] Lỗi chi tiết khi gọi API ScrapingDog {url}: {e}")
+ return None
diff --git a/miragenews/img/semantic_filter.py b/miragenews/img/semantic_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..738d4a47e02c25ce82100d77b020d5f521d6f241
--- /dev/null
+++ b/miragenews/img/semantic_filter.py
@@ -0,0 +1,96 @@
+import re
+from sentence_transformers import SentenceTransformer, util
+
+class SemanticFilter:
+ def __init__(self):
+ model_name = 'all-MiniLM-L6-v2'
+ print(f"Đang tải mô hình lọc ngữ nghĩa ({model_name})...")
+ try:
+ self.model = SentenceTransformer(model_name)
+ print("✅ Mô hình lọc đã tải thành công.")
+ except Exception as e:
+ print(f"❌ Lỗi khi tải mô hình SentenceTransformer: {e}")
+ print("Hãy thử chạy lại 'pip install -U sentence-transformers'")
+ self.model = None
+
+ # 1. [ĐÃ MỞ RỘNG] Thêm các khái niệm cho các từ nhập nhằng
+ self.ai_concepts = [
+ # Tiếng Việt
+ 'ảnh do AI tạo',
+ 'sản phẩm của trí tuệ nhân tạo',
+ 'dùng mô hình AI để vẽ',
+ 'bức ảnh này không có thật',
+ 'đây là ảnh giả mạo',
+ 'mạng đối nghịch tạo sinh', # (GAN)
+
+ # Tiếng Anh (Chung)
+ 'AI generated image',
+ 'created with an artificial intelligence model',
+ 'synthetic media',
+ 'generative art',
+ 'this is not a real photo',
+ 'this person does not exist',
+ 'AI image generator tool',
+ 'this picture is fake',
+
+ # Tiếng Anh (Kỹ thuật & Tên riêng)
+ 'diffusion model', # (diffusion)
+ 'generative adversarial network', # (GAN)
+ 'made with Midjourney',
+ 'made with Stable Diffusion', # (SD)
+ 'made with DALL-E',
+
+ # [THÊM MỚI] Khái niệm cho các từ nhập nhằng
+ 'CLIP model', # (clip)
+ 'variational autoencoder', # (vae)
+ 'LoRA adapter', # (lora)
+ 'Sora video model', # (sora)
+ 'CGI rendering' # (cgi)
+ ]
+
+ # 2. Mã hóa các khái niệm
+ if self.model:
+ self.concept_embeddings = self.model.encode(
+ self.ai_concepts,
+ convert_to_tensor=True
+ )
+
+ def find_suspicious_sentences(self, content: str, threshold: float = 0.5) -> list:
+ if self.model is None:
+ print("Lỗi: Mô hình lọc ngữ nghĩa chưa được tải.")
+ return []
+
+ # Tách văn bản thành các câu
+ sentences = re.split(r'[.!?\n]+', content)
+
+ # Lọc ra các câu rỗng hoặc quá ngắn
+ valid_sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
+
+ if not valid_sentences:
+ return []
+
+ # 3. Mã hóa tất cả các câu cào được
+ print(f"Đang mã hóa {len(valid_sentences)} câu để lọc ngữ nghĩa...")
+ sentence_embeddings = self.model.encode(
+ valid_sentences,
+ convert_to_tensor=True,
+ show_progress_bar=False
+ )
+
+ # 4. Thực hiện tìm kiếm ngữ nghĩa
+ hits = util.semantic_search(
+ self.concept_embeddings,
+ sentence_embeddings,
+ top_k=5
+ )
+
+ suspicious_sentences = set()
+
+ # 5. Lọc ra những câu có điểm số cao (khả nghi)
+ for i in range(len(self.ai_concepts)):
+ for hit in hits[i]:
+ if hit['score'] >= threshold:
+ sentence = valid_sentences[hit['corpus_id']]
+ suspicious_sentences.add(sentence)
+
+ return list(suspicious_sentences)
diff --git a/miragenews/img/web_utils.py b/miragenews/img/web_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e689614327f5ce77085fa6029b3b9f0d9b4bc891
--- /dev/null
+++ b/miragenews/img/web_utils.py
@@ -0,0 +1,225 @@
+import httpx
+import requests
+import io
+import re
+import numpy as np
+from typing import List, Tuple, Optional
+from bs4 import BeautifulSoup
+from urllib.parse import urlparse, urljoin
+from PIL import Image
+import gradio as gr
+from sentence_transformers import SentenceTransformer
+
+from .constants import ai_domains, AI_KEYWORDS_HIGH_CERTAINTY
+from .scraper import async_scrape_url_text
+from .llm_analyzer import check_domain_with_llm, get_llm_analysis
+
+def check_google_safe_browsing(urls: list, api_key: str) -> dict:
+ if not api_key: return {}
+ api_endpoint = f"https://safebrowsing.googleapis.com/v4/threatMatches:find?key={api_key}"
+ payload = {
+ "client": {"clientId": "miragenews-checker", "clientVersion": "1.0.0"},
+ "threatInfo": {
+ "threatTypes": ["MALWARE", "SOCIAL_ENGINEERING", "UNWANTED_SOFTWARE", "POTENTIALLY_HARMFUL_APPLICATION"],
+ "platformTypes": ["ANY_PLATFORM"],
+ "threatEntryTypes": ["URL"],
+ "threatEntries": [{"url": url} for url in urls if url]
+ }
+ }
+ if not payload["threatInfo"]["threatEntries"]: return {}
+ try:
+ response = requests.post(api_endpoint, json=payload, timeout=10)
+ return response.json() if response.status_code == 200 else {}
+ except Exception:
+ return {}
+
+def get_image_embedding(model, image_pil: Image.Image) -> np.ndarray:
+ return model.encode(image_pil)
+
+def calculate_similarity(emb1: np.ndarray, emb2: np.ndarray) -> float:
+ emb1 = emb1 / np.linalg.norm(emb1)
+ emb2 = emb2 / np.linalg.norm(emb2)
+ return np.dot(emb1, emb2)
+
+async def scrape_html_with_fallback(url: str, client: httpx.AsyncClient, api_key: Optional[str]) -> Optional[str]:
+ try:
+ headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}
+ print(f"[Scrape HTML] Trying httpx for: {url}")
+ response = await client.get(url, headers=headers, follow_redirects=True, timeout=10)
+ response.raise_for_status()
+ return response.text
+ except Exception as e:
+ print(f"⚠️ Scrape (httpx) thất bại cho {url}: {e}. Thử ScrappingDog...")
+ if not api_key:
+ print("❌ Không có SCRAPINGDOG_API_KEY. Bỏ qua fallback.")
+ return None
+ try:
+ sd_url = f"https://api.scrapingdog.com/scrape?api_key={api_key}&url={url}&dynamic=true&country=us" # Đã sửa dynamic=true
+ print(f"[Scrape HTML] Trying ScrappingDog for: {url}")
+ response = await client.get(sd_url, timeout=30)
+ response.raise_for_status()
+ return response.text
+ except Exception as e2:
+ print(f"❌ Scrape (ScrappingDog) cũng thất bại cho {url}: {e2}")
+ return None
+
+
+
+async def find_best_url_fast_scan_bs4(
+ input_embedding: np.ndarray,
+ urls: List[str],
+ client: httpx.AsyncClient,
+ model: SentenceTransformer,
+ api_key: Optional[str],
+ progress: gr.Progress
+) -> Tuple[Optional[str], float]:
+
+ max_sim = -1.0
+ best_url = None
+
+ if not urls:
+ return None, 0.0
+
+ for i, url in enumerate(urls):
+ progress_step = (0.6 - 0.4) / len(urls)
+ progress(0.4 + (i * progress_step), desc=f"Phase 1 (Fast Scan): {i+1}/{len(urls)}")
+
+ print(f"--- [Fast Scan] Đang quét trang {i+1}/{len(urls)}: {url} ---")
+
+ html = await scrape_html_with_fallback(url, client, api_key)
+ if not html:
+ print("[Fast Scan] Scrape HTML thất bại, bỏ qua.")
+ continue
+
+ try:
+ soup = BeautifulSoup(html, 'html.parser')
+ img_tags = soup.find_all('img')
+ except Exception as e:
+ print(f"[Fast Scan] Lỗi phân tích HTML: {e}")
+ continue
+
+ if not img_tags:
+ print("[Fast Scan] Không tìm thấy thẻ
.")
+ continue
+
+ for img_tag in img_tags:
+ img_src = img_tag.get('src')
+ if not img_src or img_src.startswith('data:'):
+ continue
+
+ try:
+ img_url = urljoin(url, img_src)
+ img_response = await client.get(img_url, timeout=5)
+ img_response.raise_for_status()
+
+ scraped_image_pil = Image.open(io.BytesIO(img_response.content)).convert("RGB")
+
+ if scraped_image_pil.width < 50 or scraped_image_pil.height < 50:
+ continue
+
+ scraped_embedding = get_image_embedding(model, scraped_image_pil)
+ sim = calculate_similarity(input_embedding, scraped_embedding)
+
+ if sim > max_sim:
+ max_sim = sim
+ best_url = url
+
+ if sim > 0.9:
+ print(f"✅ [Fast Scan] TÌM THẤY KHỚP > 0.9 (Sim: {sim:.4f}) tại: {url}")
+ return url, sim
+
+ except Exception as e:
+ pass
+
+ if best_url:
+ print(f"ℹ️ [Fast Scan] Không tìm thấy > 0.9. Chọn URL khớp nhất: {best_url} (Sim: {max_sim:.4f})")
+ return best_url, max_sim
+
+ if not best_url and urls:
+ print(f"ℹ️ [Fast Scan] Không tìm thấy ảnh nào. Chọn URL đầu tiên làm dự phòng.")
+ return urls[0], 0.0
+
+ return None, 0.0
+
+
+
+async def get_html_context_block_bs4(
+ input_embedding: np.ndarray,
+ url: str,
+ client: httpx.AsyncClient,
+ model: SentenceTransformer,
+ api_key: Optional[str],
+ progress: gr.Progress
+) -> str:
+ print(f"--- [Deep Scan] Lấy khối HTML từ: {url} ---")
+ progress(0.6, desc="Phase 2/2: Deep scan (Fetching HTML block)...")
+
+ html = await scrape_html_with_fallback(url, client, api_key)
+ if not html:
+ print("[Deep Scan] Scrape HTML thất bại.")
+ return ""
+
+ try:
+ soup = BeautifulSoup(html, 'html.parser')
+ img_tags = soup.find_all('img')
+
+ best_tag = None
+ max_sim = -1.0
+
+ for img_tag in img_tags:
+ img_src = img_tag.get('src')
+ if not img_src or img_src.startswith('data:'):
+ continue
+
+ try:
+ img_url = urljoin(url, img_src)
+ img_response = await client.get(img_url, timeout=5)
+ img_response.raise_for_status()
+
+ scraped_image_pil = Image.open(io.BytesIO(img_response.content)).convert("RGB")
+
+ if scraped_image_pil.width < 50 or scraped_image_pil.height < 50:
+ continue
+
+ scraped_embedding = get_image_embedding(model, scraped_image_pil)
+ sim = calculate_similarity(input_embedding, scraped_embedding)
+
+ if sim > max_sim:
+ max_sim = sim
+ best_tag = img_tag
+
+ except Exception:
+ pass
+
+ if best_tag:
+ print(f"[Deep Scan] Tìm thấy ảnh khớp nhất (Sim: {max_sim:.4f}). Đang tìm khối cha...")
+
+ current = best_tag
+ for _ in range(5):
+ parent = current.parent
+ if parent is None or parent.name == 'body':
+ break
+ parent_name = parent.name.lower()
+ if parent_name in ['article', 'section', 'li', 'main']:
+ print(f"[Deep Scan] Tìm thấy khối semantic: <{parent_name}>")
+ return str(parent)
+ if parent_name == 'div':
+ class_list = parent.get('class', [])
+ if any(cls in ['content', 'post', 'article', 'story-body', 'caption'] for cls in class_list):
+ print(f"[Deep Scan] Tìm thấy khối div quan trọng: {class_list}")
+ return str(parent)
+ current = parent
+
+ print("[Deep Scan] Không tìm thấy khối semantic. Trả về 3 cấp cha.")
+ parent_block = best_tag.parent.parent
+ if parent_block:
+ return str(parent_block)
+ else:
+ return str(best_tag.parent)
+ else:
+ print("[Deep Scan] Không tìm thấy ảnh khớp nào.")
+ return ""
+ except Exception as e:
+ print(f"❌ [Deep Scan] Lỗi khi phân tích HTML: {e}")
+ return ""
+
diff --git a/miragenews/main_exe.py b/miragenews/main_exe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd0096aa896b4eb28ee7d13a8daf4f45345a65a
--- /dev/null
+++ b/miragenews/main_exe.py
@@ -0,0 +1,48 @@
+import json
+import os
+from typing import List, Optional, Dict
+
+from .merge import run_multimodal_analysis
+from .img.resources import resources
+
+
+print("⏳ Loading Models...")
+try:
+ resources.load_all()
+ print("✅ Models Ready.")
+except Exception as e:
+ print(f"⚠️ Warning loading models: {e}")
+
+
+async def run_multimodal_to_json(
+ image_paths: Optional[List[str]] = None,
+ text: Optional[str] = None,
+ output_json_path: Optional[str] = "result.json"
+) -> List[Dict]:
+ """
+ Xử lý ảnh + text và ghi kết quả ra file JSON.
+
+ Args:
+ image_paths: list path ảnh
+ text: text input
+ output_json_path: nơi lưu file json
+
+ Returns:
+ path của file json
+ """
+ safe_text = text or ""
+ image_paths = image_paths or []
+
+ # Run analysis
+ result_dict = await run_multimodal_analysis(
+ image_paths,
+ safe_text
+ )
+
+ # Write JSON or return dict
+ if output_json_path:
+ with open(output_json_path, "w", encoding="utf-8") as f:
+ json.dump(result_dict, f, ensure_ascii=False, indent=2)
+ return output_json_path
+ else:
+ return result_dict
diff --git a/miragenews/merge.py b/miragenews/merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..affcfd94619407c7f07f85c3394215416438338e
--- /dev/null
+++ b/miragenews/merge.py
@@ -0,0 +1,195 @@
+import asyncio
+import inspect
+import re
+import os
+from .img.core import analyze_saved_images
+from .text_module.pipeline import verify_text_logic
+from .text_module.TextAnalysisResult import TextAnalysisResult
+
+class MockGradioFile(dict):
+ def __init__(self, path):
+ filename = os.path.basename(path)
+ data = {
+ "name": path,
+ "path": path,
+ "saved_path": path,
+ "orig_name": filename,
+ "original_filename": filename,
+ "size": os.path.getsize(path) if os.path.exists(path) else 0
+ }
+ super().__init__(data)
+ for k, v in data.items():
+ setattr(self, k, v)
+
+def parse_child_report(report_text):
+ data = {"auth": "N/A", "tools": "Unknown", "synth": "N/A", "artifacts": ""}
+ if not report_text: return data
+
+ auth_match = re.search(r"Authenticity Assessment:\s*(.+)", report_text)
+ if auth_match: data["auth"] = auth_match.group(1).strip()
+
+ tools_match = re.search(r"Verification Tools & Methods:\s*(.+)", report_text)
+ if tools_match: data["tools"] = tools_match.group(1).strip()
+
+ synth_match = re.search(r"Synthetic Type \(if applicable\):\s*(.+)", report_text)
+ if synth_match: data["synth"] = synth_match.group(1).strip()
+
+ art_match = re.search(r"Other Artifacts:\s*(.*)", report_text, re.DOTALL)
+ if art_match: data["artifacts"] = art_match.group(1).strip()
+
+ return data
+
+def is_verdict_fake(assessment_string):
+ if not assessment_string: return False
+ s = assessment_string.lower().strip()
+ fake_keywords = ["not real", "fake", "manipulated", "generated", "artificial", "synthetic"]
+ return any(kw in s for kw in fake_keywords)
+
+async def consume_async_generator(gen):
+ last_result = None
+ if inspect.isasyncgen(gen):
+ async for item in gen: last_result = item
+ else:
+ for item in gen: last_result = item
+ return last_result
+
+async def run_multimodal_analysis(image_paths: list, text_input: str) -> dict:
+
+ async def task_image():
+ if not image_paths: return []
+ all_image_reports = []
+ for img_path in image_paths:
+ try:
+ gradio_inputs = [MockGradioFile(img_path)]
+ gen = analyze_saved_images(gradio_inputs)
+ result_tuple = await consume_async_generator(gen)
+ if result_tuple:
+ _, report_md = result_tuple
+ all_image_reports.append(report_md)
+ else:
+ all_image_reports.append(None)
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ all_image_reports.append(f"Error Image: {str(e)}")
+ return all_image_reports
+
+ async def task_text():
+ if not text_input: return None
+ txt_res_obj = TextAnalysisResult()
+ try:
+ await asyncio.to_thread(verify_text_logic, text_input, txt_res_obj)
+ return txt_res_obj
+ except Exception as e:
+ txt_res_obj.set_authenticity_assessment("Error")
+ txt_res_obj.set_other_artifacts(str(e))
+ return txt_res_obj
+
+ img_report_raw, txt_res_obj = await asyncio.gather(task_image(), task_text())
+
+ all_final_results = []
+
+ txt_auth = "No text"
+ txt_tools = ""
+ txt_synth = "N/A"
+ txt_art = ""
+
+ if txt_res_obj:
+ txt_auth = txt_res_obj.get_authenticity_assessment()
+ txt_tools = txt_res_obj.get_verification_tools_methods()
+ txt_synth = txt_res_obj.get_synthetic_type()
+ txt_art = txt_res_obj.get_other_artifacts()
+
+ txt_is_fake = is_verdict_fake(txt_auth)
+
+ if not img_report_raw and not text_input:
+ all_final_results.append({
+ "authenticity_assessment": "⚠️ No Input Provided",
+ "verification_tools_methods": "",
+ "synthetic_type": "N/A",
+ "other_artifacts": "No image or text was provided for analysis."
+ })
+ return all_final_results
+
+ if not img_report_raw: # Only text analysis
+ img_parsed = {"auth": "No images", "tools": "", "synth": "N/A", "artifacts": ""}
+ final_auth = "🤖 NOT REAL (Fake, Manipulated, or AI)" if txt_is_fake else "REAL (Authentic)"
+ final_tools = f"Verified by our model using algorithms SearchLLM."
+ final_synth = f"Text: {txt_synth}" if txt_is_fake and txt_synth != "N/A" else "N/A"
+ final_artifacts_str = f"**For Text:** {txt_art}" if txt_art != "N/A" else "No additional artifacts for text."
+
+ all_final_results.append({
+ "filename": "N/A",
+ "text_used": text_input,
+ "result": {
+ "authenticity_assessment": final_auth,
+ "verification_tools_methods": final_tools,
+ "synthetic_type": final_synth,
+ "other_artifacts": final_artifacts_str
+ }
+ })
+ return all_final_results
+
+ for idx, img_report in enumerate(img_report_raw):
+ img_filename = os.path.basename(image_paths[idx]) if idx < len(image_paths) else "Unknown Image"
+ img_parsed = {"auth": "No images", "tools": "", "synth": "N/A", "artifacts": ""}
+ if img_report:
+ if "Error Image" in img_report:
+ img_parsed["artifacts"] = img_report
+ img_parsed["auth"] = "Error"
+ else:
+ img_parsed = parse_child_report(img_report)
+
+ img_is_fake = is_verdict_fake(img_parsed["auth"])
+
+ if img_is_fake or txt_is_fake:
+ final_auth = "🤖 NOT REAL (Fake, Manipulated, or AI)"
+ else:
+ final_auth = "REAL (Authentic)"
+
+ final_tools = f"Verified by our model using algorithms SearchLLM and ImageForensics."
+
+ synth_list = []
+ if img_is_fake:
+ s_type = img_parsed["synth"] if img_parsed["synth"] != "N/A" else "Manipulated Image"
+ synth_list.append(f"Image: {s_type}")
+
+ if txt_is_fake:
+ s_type = txt_synth
+ if not s_type or s_type == "N/A": s_type = "Generated Content"
+ synth_list.append(f"Text: {s_type}")
+
+ final_synth = " | ".join(synth_list) if synth_list else "N/A"
+
+ final_artifacts_str = ""
+
+ if img_is_fake and txt_is_fake:
+ final_artifacts_str = f"**[Image Evidence for {img_filename}]**\n{img_parsed['artifacts']}\n\n**[Text Evidence]**\n{txt_art}"
+ elif img_is_fake:
+ final_artifacts_str = f"**[Image Evidence for {img_filename}]**\n{img_parsed['artifacts']}"
+ elif txt_is_fake:
+ final_artifacts_str = f"**[Text Evidence]**\n{txt_art}"
+ else:
+ img_src = img_parsed.get('artifacts', '').strip()
+ if img_src and img_src != "N/A" and "No details" not in img_src and "Error" not in img_src:
+ final_artifacts_str += f"**For Image {img_filename}:** {img_src}"
+
+ if txt_art and txt_art != "N/A":
+ if final_artifacts_str: final_artifacts_str += "\n\n"
+ final_artifacts_str += f"**For Text:** {txt_art}"
+
+ if not final_artifacts_str:
+ final_artifacts_str = "Both image and text are verified as authentic."
+
+ all_final_results.append({
+ "filename": img_filename,
+ "text_used": text_input,
+ "result": {
+ "authenticity_assessment": final_auth,
+ "verification_tools_methods": final_tools,
+ "synthetic_type": final_synth,
+ "other_artifacts": final_artifacts_str
+ }
+ })
+
+ return all_final_results
\ No newline at end of file
diff --git a/miragenews/merge_img_text.py b/miragenews/merge_img_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc2b49b14cb0e206ad025ba6ce46a8765fab7012
--- /dev/null
+++ b/miragenews/merge_img_text.py
@@ -0,0 +1,249 @@
+import asyncio
+import inspect
+import re
+from img.core import analyze_saved_images
+from text_module.pipeline import verify_text_logic
+from text_module.TextAnalysisResult import TextAnalysisResult
+
+# --- HELPER: BÓC TÁCH REPORT ---
+def parse_child_report(report_text):
+ """
+ Dùng Regex lấy giá trị từng dòng cụ thể.
+ """
+ data = {
+ "auth": "N/A", "tools": "Unknown", "synth": "N/A", "artifacts": ""
+ }
+ if not report_text: return data
+
+ # 1. Lấy Authenticity Assessment (Quan trọng nhất)
+ # Regex này chỉ lấy nội dung trên cùng 1 dòng sau dấu hai chấm
+ auth_match = re.search(r"Authenticity Assessment:\s*(.+)", report_text)
+ if auth_match:
+ data["auth"] = auth_match.group(1).strip()
+
+ # 2. Lấy Tools
+ tools_match = re.search(r"Verification Tools & Methods:\s*(.+)", report_text)
+ if tools_match:
+ data["tools"] = tools_match.group(1).strip()
+
+ # 3. Lấy Synthetic Type
+ synth_match = re.search(r"Synthetic Type \(if applicable\):\s*(.+)", report_text)
+ if synth_match:
+ data["synth"] = synth_match.group(1).strip()
+
+ # 4. Lấy Artifacts (Lấy từ dòng đó xuống hết)
+ art_match = re.search(r"Other Artifacts:\s*(.*)", report_text, re.DOTALL)
+ if art_match:
+ data["artifacts"] = art_match.group(1).strip()
+
+ return data
+
+# --- HELPER: CHECK FAKE CHỈ TRÊN DÒNG ASSESSMENT ---
+def is_verdict_fake(assessment_string):
+ if not assessment_string: return False
+ s = assessment_string.lower().strip()
+
+ # Các từ khóa khẳng định là FAKE
+ fake_keywords = ["not real", "fake", "manipulated", "generated", "artificial", "synthetic"]
+
+ for kw in fake_keywords:
+ if kw in s:
+ return True
+ return False
+
+# --- HTML STATUS BAR (GIỮ NGUYÊN) ---
+def create_status_html(label, status, message):
+ color = "#9ca3af"; percent = 5; icon = "⏳"; bg_pulse = ""; text_color = "#374151"
+ if status == 'processing':
+ color = "#2563eb"; percent = 60; icon = "⚙️"; bg_pulse = "animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;"
+ elif status == 'done':
+ color = "#16a34a"; percent = 100; icon = "✅"
+ elif status == 'error':
+ color = "#dc2626"; percent = 100; icon = "❌"
+
+ html = f"""
+
+
+
+ {icon} {label}
+ {message}
+
+
+
+ """
+ return html
+
+# --- TASK 1: XỬ LÝ ẢNH ---
+async def run_image_task(shared_state, image_input):
+ shared_state['img_status'] = 'processing'
+ shared_state['img_msg'] = "Scanning artifacts..."
+ img_result_obj = TextAnalysisResult()
+ try:
+ final_json = {}
+ final_report_md = ""
+ gen = analyze_saved_images(image_input)
+ if inspect.isasyncgen(gen):
+ async for res in gen: final_json, final_report_md = res
+ else:
+ for res in gen: final_json, final_report_md = res
+
+ # Lưu toàn bộ chuỗi report vào artifact
+ img_result_obj.set_other_artifacts(final_report_md)
+
+ # Parse lấy đúng dòng Auth để set status cho object (để dùng cho short-circuit nếu cần)
+ parsed = parse_child_report(final_report_md)
+ img_result_obj.set_authenticity_assessment(parsed["auth"])
+
+ shared_state['img_status'] = 'done'
+ shared_state['img_msg'] = "Done"
+ except Exception as e:
+ shared_state['img_status'] = 'error'
+ shared_state['img_msg'] = "Error"
+ img_result_obj.set_authenticity_assessment("Error")
+ return img_result_obj
+
+# --- TASK 2: XỬ LÝ TEXT ---
+async def run_text_task(shared_state, text_input):
+ shared_state['txt_status'] = 'processing'
+ shared_state['txt_msg'] = "Verifying logic..."
+ txt_result_obj = TextAnalysisResult()
+ try:
+ await asyncio.to_thread(verify_text_logic, text_input, txt_result_obj)
+ shared_state['txt_status'] = 'done'
+ shared_state['txt_msg'] = "Done"
+ except Exception as e:
+ shared_state['txt_status'] = 'error'
+ shared_state['txt_msg'] = str(e)
+ txt_result_obj.set_authenticity_assessment("Error")
+ return txt_result_obj
+
+# --- MAIN ORCHESTRATOR ---
+async def verify_multimodal_logic(image_state, text_input):
+
+ shared_state = {'img_status': 'waiting', 'img_msg': 'Ready...', 'txt_status': 'waiting', 'txt_msg': 'Ready...'}
+ def get_ui(): return create_status_html("Image Analysis", shared_state['img_status'], shared_state['img_msg']), create_status_html("Text Analysis", shared_state['txt_status'], shared_state['txt_msg']), "..."
+ yield get_ui()
+
+ task_img = asyncio.create_task(run_image_task(shared_state, image_state))
+ task_txt = asyncio.create_task(run_text_task(shared_state, text_input))
+
+ img_res, txt_res = None, None
+
+ # Loop short-circuit
+ while not (task_img.done() and task_txt.done()):
+ yield get_ui()
+ if task_img.done() and img_res is None:
+ try:
+ img_res = task_img.result()
+ # Check Fake chỉ dựa trên Assessment (ngắn gọn)
+ if is_verdict_fake(img_res.get_authenticity_assessment()):
+ if not task_txt.done(): task_txt.cancel(); shared_state['txt_msg'] = "Stopped (Image is Fake)"
+ break
+ except: pass
+
+ if task_txt.done() and txt_res is None:
+ try:
+ txt_res = task_txt.result()
+ # Check Fake chỉ dựa trên Assessment
+ if is_verdict_fake(txt_res.get_authenticity_assessment()):
+ if not task_img.done(): task_img.cancel(); shared_state['img_msg'] = "Stopped (Text is Fake)"
+ break
+ except: pass
+ await asyncio.sleep(0.1)
+
+ if img_res is None and task_img.done(): img_res = task_img.result()
+ if txt_res is None and task_txt.done(): txt_res = task_txt.result()
+ if not img_res: img_res = TextAnalysisResult(authenticity_assessment="Skipped")
+ if not txt_res: txt_res = TextAnalysisResult(authenticity_assessment="Skipped")
+
+ # =========================================================================
+ # LOGIC MERGE: CHỈ DỰA VÀO DÒNG ASSESSMENT
+ # =========================================================================
+
+ # 1. Parse Image Report để lấy dòng "Authenticity Assessment" sạch
+ img_data_parsed = parse_child_report(img_res.get_other_artifacts())
+ img_auth_line = img_data_parsed["auth"] # VD: "🧑 REAL PHOTO"
+
+ # 2. Lấy dòng Assessment của Text
+ txt_auth_line = txt_res.get_authenticity_assessment() # VD: "REAL (Authentic)"
+
+ # 3. KIỂM TRA FAKE/REAL (Dựa trên 2 dòng trên)
+ img_is_fake = is_verdict_fake(img_auth_line)
+ txt_is_fake = is_verdict_fake(txt_auth_line)
+
+ # --- FIELD 1: Authenticity Assessment ---
+ if img_is_fake or txt_is_fake:
+ final_auth = "🤖 NOT REAL (Fake, Manipulated, or AI)"
+ color_hex = "#dc2626"
+ else:
+ final_auth = " REAL (Authentic)"
+ color_hex = "#16a34a"
+
+ # --- FIELD 2: Verification Tools ---
+ final_tools = f"Verified by our model using algorithms SearchLLM and ImageForensics."
+
+ # --- FIELD 3: Synthetic Type ---
+ final_synth_list = []
+
+ # Chỉ lấy Synthetic Type từ module Ảnh nếu Ảnh bị kết luận là Fake
+ if img_is_fake:
+ s_type = img_data_parsed["synth"] if img_data_parsed["synth"] != "N/A" else "Manipulated Image"
+ final_synth_list.append(f"**Image:** {s_type}")
+
+ # Chỉ lấy Synthetic Type từ module Text nếu Text bị kết luận là Fake
+ if txt_is_fake:
+ s_type = txt_res.get_synthetic_type()
+ if not s_type or s_type == "N/A": s_type = "Generated Content"
+ final_synth_list.append(f"**Text:** {s_type}")
+
+ final_synth_str = "\n".join(final_synth_list) if final_synth_list else "N/A"
+
+    # --- FIELD 4: Other Artifacts (source/evidence display logic) ---
+    final_artifacts_str = ""
+
+    # Case: both fake -> show both sets of evidence
+    if img_is_fake and txt_is_fake:
+        final_artifacts_str = f"**[Image Evidence]**\n{img_data_parsed['artifacts']}\n\n**[Text Evidence]**\n{txt_res.get_other_artifacts()}"
+
+    # Case: only the image is fake -> show the image evidence
+    elif img_is_fake:
+        final_artifacts_str = img_data_parsed['artifacts']
+
+    # Case: only the text is fake -> show the text evidence
+    elif txt_is_fake:
+        final_artifacts_str = txt_res.get_other_artifacts()
+
+    # Case: both REAL -> show the sources (if any)
+    else:
+        final_artifacts_str = "Both image and text are verified as authentic by our multi-modal pipeline."
+
+        # Image source check (must be non-empty and not "N/A")
+        img_src = img_data_parsed.get('artifacts', '').strip()
+        if img_src and img_src != "N/A" and "No details" not in img_src:
+            final_artifacts_str += f"\n\n**For Image:** {img_src}"
+
+        # Text source check
+        txt_src = txt_res.get_other_artifacts().strip()
+        if txt_src and txt_src != "N/A":
+            final_artifacts_str += f"\n\n**For Text:** {txt_src}"
+
+    # Build the final markdown report
+ final_report_md = f"""
+### 📋 Final Verification Report
+
+**1. Authenticity Assessment:**
+{final_auth}
+
+**2. Verification Tools & Methods:**
+{final_tools}
+
+**3. Synthetic Type (if applicable):**
+{final_synth_str}
+
+**4. Other Artifacts:**
+{final_artifacts_str}
+ """
+
+    ui = get_ui()
+    yield ui[0], ui[1], final_report_md
\ No newline at end of file
diff --git a/miragenews/models/__init__.py b/miragenews/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..86d11c235e709999d121d7540ac4df7ea6cc0170
--- /dev/null
+++ b/miragenews/models/__init__.py
@@ -0,0 +1,27 @@
+from .mirage_img import *
+from .mirage_txt import *
+
+def get_model(config):
+ """
+ Retrieves the model class specified in the config and initializes it with provided parameters.
+ """
+ model_name = config['model']['name']
+ model_params = config['model'].get('params', {})
+
+ # Mapping model names to their classes
+ model_classes = {
+ "img-linear": ImageLinearModel,
+ "cbm-encoder": ObjectClassCBMEncoder,
+ "cbm-predictor": ObjectClassCBMPredictor,
+ "mirage-img": MiRAGeImg,
+ "txt-linear": TextLinearModel,
+ "tbm-predictor": TBMPredictor,
+ "mirage-txt": MiRAGeTxt
+ # Add other models here as needed
+ }
+
+ if model_name in model_classes:
+ model_class = model_classes[model_name]
+ return model_class(**model_params) # Instantiate model with parameters
+ else:
+ raise ValueError(f"Model {model_name} not recognized. Please check config.")
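+
+# A minimal usage sketch, assuming a config dict shaped like the YAML files
+# under configs/ (the exact keys and params below are illustrative):
+#
+#   config = {"model": {"name": "mirage-img", "params": {}}}
+#   model = get_model(config)  # -> a MiRAGeImg instance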
diff --git a/miragenews/models/mirage_img.py b/miragenews/models/mirage_img.py
new file mode 100644
index 0000000000000000000000000000000000000000..d83ed56cf9a1c91b3d24cc727777b0436a15ee6d
--- /dev/null
+++ b/miragenews/models/mirage_img.py
@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+
+class ImageLinearModel(nn.Module):
+ def __init__(self):
+ super(ImageLinearModel, self).__init__()
+ self.linear = nn.Linear(1408, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, image_features):
+ x = self.linear(image_features)
+ x = self.sigmoid(x)
+ return x
+
+class ObjectClassCBMEncoder(nn.Module):
+ def __init__(self):
+ super(ObjectClassCBMEncoder, self).__init__()
+ self.classifiers = nn.ModuleList([nn.Sequential(
+ nn.Linear(1408, 1),
+ nn.Sigmoid()
+ ) for _ in range(300)])
+
+ def forward(self, image_features, classifier_index):
+ return self.classifiers[classifier_index](image_features)
+
+class ObjectClassCBMPredictor(nn.Module):
+ def __init__(self):
+ super(ObjectClassCBMPredictor, self).__init__()
+ self.linear = nn.Linear(300, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, logits_per_image):
+ x = self.linear(logits_per_image)
+ x = self.sigmoid(x)
+ return x
+
+class MiRAGeImg(nn.Module):
+ def __init__(self):
+ super(MiRAGeImg, self).__init__()
+ self.linear = nn.Linear(301, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, logits_per_image):
+ x = self.linear(logits_per_image)
+ x = self.sigmoid(x)
+ return x
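+
+# A minimal shape sanity-check sketch (assumed dims: 1408-dim image embeddings,
+# 300 concept scores, and 301 = 300 concept scores + 1 linear score):
+#
+#   feats = torch.randn(1, 1408)
+#   ImageLinearModel()(feats).shape                     # torch.Size([1, 1])
+#   ObjectClassCBMEncoder()(feats, classifier_index=0)  # one concept score in [0, 1]
+#   MiRAGeImg()(torch.randn(1, 301)).shape              # torch.Size([1, 1])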
\ No newline at end of file
diff --git a/miragenews/models/mirage_txt.py b/miragenews/models/mirage_txt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ec707b06edb753a97f252923fca9eda0297f4d1
--- /dev/null
+++ b/miragenews/models/mirage_txt.py
@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+
+class TextLinearModel(nn.Module):
+ def __init__(self):
+ super(TextLinearModel, self).__init__()
+ self.linear = nn.Linear(768, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, text_features):
+ x = self.linear(text_features)
+ x = self.sigmoid(x)
+ return x
+
+# class TBMEncoder(nn.Module):
+# def __init__(self):
+# super(TBMEncoder, self).__init__()
+# self.classifiers = nn.ModuleList([nn.Sequential(
+# nn.Linear(1408, 1),
+# nn.Sigmoid()
+# ) for _ in range(300)])
+
+# def forward(self, image_features, classifier_index):
+# return self.classifiers[classifier_index](image_features)
+
+class TBMPredictor(nn.Module):
+ def __init__(self):
+ super(TBMPredictor, self).__init__()
+ self.linear = nn.Linear(18, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, logits_per_text):
+ x = self.linear(logits_per_text)
+ x = self.sigmoid(x)
+ return x
+
+class MiRAGeTxt(nn.Module):
+ def __init__(self):
+ super(MiRAGeTxt, self).__init__()
+ self.linear = nn.Linear(19, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, logits_per_text):
+ x = self.linear(logits_per_text)
+ x = self.sigmoid(x)
+ return x
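+
+# A minimal shape sketch (assumed dims: 768-dim text embeddings, 18 concept
+# scores, and 19 = 18 concept scores + 1 linear score):
+#
+#   TextLinearModel()(torch.randn(1, 768)).shape  # torch.Size([1, 1])
+#   TBMPredictor()(torch.randn(1, 18)).shape      # torch.Size([1, 1])
+#   MiRAGeTxt()(torch.randn(1, 19)).shape         # torch.Size([1, 1])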
\ No newline at end of file
diff --git a/miragenews/my_dataset/text/my_single_text_dir/real/real.txt b/miragenews/my_dataset/text/my_single_text_dir/real/real.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6804be6d154e92fa47e3710e7d290575086450e4
--- /dev/null
+++ b/miragenews/my_dataset/text/my_single_text_dir/real/real.txt
@@ -0,0 +1 @@
+An immigrant viciously attacks medical personnel after being wounded in clashes with residents of Rosarno, in the southern region of Calabria.
diff --git a/miragenews/requirements.txt b/miragenews/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..78dc1d737a92e3986caa42f6b09d3c16a1339e82
--- /dev/null
+++ b/miragenews/requirements.txt
@@ -0,0 +1,19 @@
+beautifulsoup4==4.14.3
+clip==1.0
+datasets==4.3.0
+gradio==6.1.0
+httpx==0.28.1
+nltk==3.9.2
+numpy < 2
+Pillow==12.0.0
+protobuf==6.33.2
+python-dotenv==1.2.1
+PyYAML==6.0.3
+Requests==2.32.5
+scikit_learn==1.8.0
+sentence_transformers==5.1.2
+tqdm==4.67.1
+trafilatura==2.0.0
+transformers==4.57.1
diff --git a/miragenews/scrappingdog.py b/miragenews/scrappingdog.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cef89dd75dc3eca5f2ab4a5eaabf60657e6320a
--- /dev/null
+++ b/miragenews/scrappingdog.py
@@ -0,0 +1,189 @@
+import os
+import re
+
+import requests
+from bs4 import BeautifulSoup
+import google.generativeai as genai
+
+headers = {
+ "User-Agent": "Mozilla/5.0",
+ "Referer": "https://www.duixs.com/"
+}
+domain = "https://www.duixs.com"
+
+def translate_with_gemini(text, target_lang="vi", api_key=None, gemini_version="gemini-2.5-flash"):
+    # Read the key from the environment; API keys must never be hard-coded in source
+    api_key = api_key or os.getenv("GEMINI_API_KEY")
+    genai.configure(api_key=api_key)
+
+ prompt = f"""
+ Hãy đóng vai một biên dịch viên chuyên nghiệp về tiểu thuyết Tiên Hiệp/Kiếm Hiệp Trung Quốc.
+ Nhiệm vụ của bạn là dịch văn bản sau từ tiếng Trung sang tiếng Việt.
+
+ Yêu cầu quan trọng:
+ 1. Văn phong: Hùng hồn, cổ trang, giữ được cái "chất" của truyện tu tiên.
+ 2. Thuật ngữ: Giữ nguyên các từ Hán Việt đặc thù của thể loại này (ví dụ: Tu vi, Động phủ, Nguyên anh, Pháp bảo...).
+ 3. Đại từ nhân xưng: Dựa vào ngữ cảnh để chọn đại từ phù hợp (Ta/Ngươi, Huynh/Muội, Lão phu/Tiểu tử, Tại hạ/Các hạ...). Tuyệt đối không dùng "tôi/bạn".
+ 4. Tên riêng và địa danh: Phiên âm Hán Việt chuẩn.
+ 5. Định dạng: Giữ nguyên tuyệt đối cấu trúc xuống dòng và hội thoại của bản gốc.
+ 6. Kiểm tra kỹ lỗi chính tả và ngữ pháp, đảm bảo văn bản cuối cùng không có lỗi.
+ 7. Chỉ trả về văn bản đã được dịch, KHÔNG thêm bất kỳ lời giải thích hay chú thích nào khác.
+
+ Văn bản cần dịch:
+ {text}
+ """
+
+ model = genai.GenerativeModel(gemini_version)
+
+ response = model.generate_content(prompt)
+
+ return response.text.strip()
+
+
+# ---------------------------
+# Fetch a single chapter's content
+# ---------------------------
+def get_chapter_content(url):
+    # 1. Fetch the first page
+    print(f"Processing: {url}")
+    try:
+        html = requests.get(url, headers=headers).text
+        soup = BeautifulSoup(html, "html.parser")
+
+        # Extract the page-1 content
+        full_text = []
+ book_div = soup.select_one("#booktxt")
+ if book_div:
+ full_text.append(book_div.get_text(separator="\n", strip=True))
+
+
+        # 2. Detect the sub-page count from the title, e.g. "Chapter X (1/3)"
+        title_tag = soup.select_one(".bookname")
+
+        total_pages = 1
+        if title_tag:
+            title_text = title_tag.get_text()
+            match = re.search(r'\((\d+)/(\d+)\)', title_text) or re.search(r'(\d+)/(\d+)', title_text)
+
+            if match:
+                total_pages = int(match.group(2))
+                print(f" -> Chapter has {total_pages} sub-pages.")
+
+        # 3. More than one page -> generate the sub-page URLs ourselves
+        if total_pages > 1:
+            # Strip the .html suffix so the page number can be inserted
+            if url.endswith(".html"):
+                base_url = url[:-5]  # drop ".html"
+                extension = ".html"
+            else:
+                base_url = url
+                extension = ""
+
+            # Loop from page 2 through the last page
+            for i in range(2, total_pages + 1):
+                # Pattern: base_2.html, base_3.html, ...
+                sub_url = f"{base_url}_{i}{extension}"
+                print(f" -> Scraping part {i}/{total_pages}: {sub_url}")
+
+ try:
+ sub_html = requests.get(sub_url, headers=headers).text
+ sub_soup = BeautifulSoup(sub_html, "html.parser")
+ sub_div = sub_soup.select_one("#booktxt")
+ if sub_div:
+ full_text.append(sub_div.get_text(separator="\n", strip=True))
+ except Exception as e:
+ print(f"Lỗi cào trang {sub_url}: {e}")
+
+ return "\n\n".join(full_text)
+
+    except Exception as e:
+        return f"Error processing chapter: {e}"
+
+# ---------------------------
+# Fetch the book's basic info
+# ---------------------------
+def get_book_info(url):
+ html = requests.get(url, headers=headers).text
+ soup = BeautifulSoup(html, "html.parser")
+
+ name = soup.select_one('div#info h1').get_text(strip=True)
+ authors = soup.select_one('div#info h1 + p').get_text(strip=True)
+
+    # Clean up the intro block (drop the download widget)
+ intro_div = soup.select_one('div#intro')
+ bad = intro_div.select_one('#downapi')
+ if bad:
+ bad.decompose()
+ intro = intro_div.get_text(strip=True)
+
+ return name, authors, intro, soup
+
+
+# ---------------------------
+# Collect the index-page URLs
+# ---------------------------
+def get_index_pages(first_chapter_href):
+    # "txt_123456/xxxx.html" -> take the "txt_123456" part
+ part = first_chapter_href.split('/')[2]
+
+ index_url = f"{domain}/read/{part}.html"
+
+ html = requests.get(index_url, headers=headers).text
+ soup = BeautifulSoup(html, "html.parser")
+
+ select_tag = soup.select_one('#indexselect')
+ options = select_tag.select('option')
+
+ index_pages = [domain + opt['value'] for opt in options]
+ return index_pages
+
+
+# ---------------------------
+# Collect every chapter
+# ---------------------------
+def get_all_chapters(index_pages):
+ chapter_list = []
+
+ for page in index_pages:
+ html = requests.get(page, headers=headers).text
+ soup = BeautifulSoup(html, "html.parser")
+
+ links = soup.select('div#list dl a[rel="chapter"]')
+ for link in links:
+ chapter_title = link.get_text(strip=True)
+ chapter_url = domain + link['href']
+ chapter_list.append((chapter_title, chapter_url))
+
+ return chapter_list
+
+
+# ---------------------------
+# MAIN
+# ---------------------------
+def main():
+ URL = "https://www.duixs.com/txt_zsg.html"
+
+    # Fetch the book info
+ name, authors, intro, soup = get_book_info(URL)
+
+ print("Name:", name)
+ print("Authors:", authors)
+ print("Intro:", intro)
+
+    # Grab the first chapter link
+ first_chapter = soup.select_one('div#list dl a[rel="chapter"]')
+ first_chapter_url = domain + first_chapter['href']
+ print("First chapter:", first_chapter_url)
+
+    # Collect the index pages
+ index_pages = get_index_pages(first_chapter['href'])
+ print("\nTotal index pages:", len(index_pages))
+
+    # Fetch the full chapter list
+ chapters = get_all_chapters(index_pages)
+ print("Total chapters found:", len(chapters))
+
+    # Demo: translate and print the first chapter
+    title, url = chapters[0]
+    print("\n>>> Chapter:", translate_with_gemini(title))
+    print(translate_with_gemini(get_chapter_content(url), gemini_version="gemini-3-pro-preview"))
+
+
+# Entry point
+if __name__ == "__main__":
+ main()
diff --git a/miragenews/test_single_pair.py b/miragenews/test_single_pair.py
new file mode 100644
index 0000000000000000000000000000000000000000..383b6696c8ae063c95d590dc10c7f2c761b04c15
--- /dev/null
+++ b/miragenews/test_single_pair.py
@@ -0,0 +1,131 @@
+import torch
+from models import get_model
+from data import load_config
+from utils import load_model_checkpoint
+import os
+
+# --- 1. LOAD CONFIGS AND INITIALIZE MODELS (done once) ---
+print("Loading configs and initializing models ONCE...")
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
+
+try:
+ config_img = load_config("configs/image/mirage.yaml")
+ config_multi = load_config("configs/multimodal/mirage.yaml")
+
+ mirage_img = get_model(config_img).to(device)
+
+    # --- 2. Load the model checkpoint (done once) ---
+ print("Loading model checkpoint ONCE...")
+ checkpoint_path_img = config_multi['training']['image_model_path']
+ if not os.path.exists(checkpoint_path_img):
+ print(f"FATAL ERROR: Model checkpoint not found at {checkpoint_path_img}")
+ exit()
+ mirage_img, _ = load_model_checkpoint(mirage_img, checkpoint_path_img)
+ print(f"Loaded image model from {checkpoint_path_img}")
+
+    # --- 3. Switch to evaluation mode (done once) ---
+ mirage_img.eval()
+ print("Model set to evaluation mode.")
+
+except Exception as e:
+ print(f"FATAL ERROR during model initialization or checkpoint loading: {e}")
+ exit()
+
+# --- 4. PREDICT AUTHENTICITY FROM A .PT FILE ---
+def predict_authenticity_from_pt(pt_file_path, model, device):
+    """
+    Load an encoding tensor from a .pt file and predict authenticity (real/fake).
+
+    Args:
+        pt_file_path (str): Path to a .pt file holding a [1, 301] tensor.
+        model (torch.nn.Module): Image model with its checkpoint loaded, in eval mode.
+        device (str): 'cuda' or 'cpu'.
+
+    Returns:
+        tuple: (probability_fake, prediction_label)
+            - probability_fake (float): Probability the image is FAKE (0.0 to 1.0).
+            - prediction_label (str): Predicted label ("real" or "fake").
+        Returns (None, None) on error.
+    """
+ print(f"\n--- Processing: {pt_file_path} ---")
+
+    # 4.1. Check that the file exists
+ if not os.path.exists(pt_file_path):
+ print(f"ERROR: Input file not found at {pt_file_path}")
+ return None, None
+
+    # 4.2. Load the encoding tensor
+    try:
+        image_encodings = torch.load(pt_file_path).to(device)
+        # The shape must be [1, 301]
+        if image_encodings.shape != (1, 301):
+            print(f"ERROR: Expected tensor shape [1, 301], but got {image_encodings.shape} from {pt_file_path}")
+            return None, None
+        print(f"Loaded image data (301-dim) with shape: {image_encodings.shape}")
+
+ except Exception as e:
+ print(f"ERROR loading or checking data from {pt_file_path}: {e}")
+ return None, None
+
+    # 4.3. Run inference
+    try:
+        with torch.no_grad():
+            output_probs_img = model(image_encodings)  # the image model expects [1, 301]
+ except Exception as e:
+ print(f"ERROR during model inference for {pt_file_path}: {e}")
+ return None, None
+
+    # 4.4. Interpret the output (label convention: 1 = fake)
+    # MiRAGeImg already ends in a Sigmoid, so the model output is a probability;
+    # applying torch.sigmoid() to it again would squash the score a second time.
+    probability_fake = output_probs_img.squeeze().item()  # probability the image is FAKE
+
+    # Decide the label at the 0.5 threshold:
+    # P(fake) >= 0.5 -> "fake", otherwise -> "real"
+    prediction_label = "fake" if probability_fake >= 0.5 else "real"
+
+    print(f"Probability (0=real, 1=fake): {probability_fake:.4f}")
+    print(f"Predicted Label: {prediction_label}")
+
+    # Return the fake probability and the label
+    return probability_fake, prediction_label
+
+# --- 5. USAGE ---
+if __name__ == "__main__":
+    # --- SINGLE FILE ---
+    # Replace with the path to the .pt file you want to check
+    input_pt_path_single = "encodings/predictions/image/merged/my_single_image_dir/real.pt"
+
+ print("\n--- Processing a single file ---")
+ prob_fake_single, label_single = predict_authenticity_from_pt(input_pt_path_single, mirage_img, device)
+
+ if prob_fake_single is not None and label_single is not None:
+ print(f"\nFinal result for {input_pt_path_single}: Probability Fake={prob_fake_single:.4f}, Label='{label_single}'")
+ else:
+ print(f"\nFailed to process {input_pt_path_single}.")
+
+ print("\n" + "="*50 + "\n") # Thêm dòng phân cách
+
+    # --- EXAMPLE: MULTIPLE FILES ---
+    pt_files_to_check = [
+        "encodings/predictions/image/merged/my_single_image_dir/real.pt",  # replace with a real path
+        "path/to/nonexistent.pt"  # example of a file that does not exist
+    ]
+ print("\n--- Processing multiple files ---")
+ results = {}
+ for file_path in pt_files_to_check:
+ prob_fake, label = predict_authenticity_from_pt(file_path, mirage_img, device)
+        results[file_path] = (prob_fake, label)  # store results keyed by file path
+
+ print("\n--- Summary ---")
+ for file, (prob_fake, label) in results.items():
+ if prob_fake is not None:
+ print(f"{file}: Prob Fake={prob_fake:.4f}, Label='{label}'")
+ else:
+ print(f"{file}: Processing FAILED")
+
+ print("\nScript finished.")
diff --git a/miragenews/text_module/TextAnalysisResult.py b/miragenews/text_module/TextAnalysisResult.py
new file mode 100644
index 0000000000000000000000000000000000000000..11a12e4027a1e23f0eb0edbdada10d5834844aed
--- /dev/null
+++ b/miragenews/text_module/TextAnalysisResult.py
@@ -0,0 +1,70 @@
+import json
+import os
+
+class TextAnalysisResult:
+ def __init__(self,
+ authenticity_assessment: str = "",
+ verification_tools_methods: str = "",
+ synthetic_type: str = "",
+ other_artifacts: str = ""):
+
+ self._authenticity_assessment = authenticity_assessment
+ self._verification_tools_methods = verification_tools_methods
+ self._synthetic_type = synthetic_type
+ self._other_artifacts = other_artifacts
+
+ def get_authenticity_assessment(self) -> str:
+ return self._authenticity_assessment
+
+ def get_verification_tools_methods(self) -> str:
+ return self._verification_tools_methods
+
+ def get_synthetic_type(self) -> str:
+ return self._synthetic_type
+
+ def get_other_artifacts(self) -> str:
+ return self._other_artifacts
+
+ def set_authenticity_assessment(self, value: str):
+ self._authenticity_assessment = value
+
+ def set_verification_tools_methods(self, value: str):
+ self._verification_tools_methods = value
+
+ def set_synthetic_type(self, value: str):
+ self._synthetic_type = value
+
+ def set_other_artifacts(self, value: str):
+ self._other_artifacts = value
+
+    def save_json(self, file_path: str):
+        """
+        Save the object to a JSON file using the required report schema.
+        """
+ data = {
+ "Authenticity Assessment": self._authenticity_assessment,
+ "Verification Tools & Methods": self._verification_tools_methods,
+ "Synthetic Type (if applicable)": self._synthetic_type,
+ "Other Artifacts": self._other_artifacts
+ }
+
+        try:
+            # Make sure the target directory exists
+            directory = os.path.dirname(file_path)
+            if directory and not os.path.exists(directory):
+                os.makedirs(directory)
+
+            with open(file_path, 'w', encoding='utf-8') as f:
+                # ensure_ascii=False keeps Vietnamese and other non-ASCII characters readable
+                json.dump(data, f, indent=4, ensure_ascii=False)
+
+            print(f"✅ Result saved successfully to: {file_path}")
+            return True
+        except Exception as e:
+            print(f"❌ Error saving the JSON file: {e}")
+ return False
+
+    def __str__(self):
+        """Compact string form for quick debugging."""
+ return (f"Authenticity: {self._authenticity_assessment} | "
+ f"Method: {self._verification_tools_methods}")
\ No newline at end of file
diff --git a/miragenews/text_module/__init__.py b/miragenews/text_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/miragenews/text_module/config.py b/miragenews/text_module/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3c03938351053622b5af09334c16ff033f96153
--- /dev/null
+++ b/miragenews/text_module/config.py
@@ -0,0 +1,15 @@
+# config.py
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+# API Keys
+GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")  # used for Gemini
+GOOGLE_SAFE_BROWSING_API_KEY = os.getenv("GOOGLE_SAFE_BROWSING_API_KEY")
+GOOGLE_CX = os.getenv("GOOGLE_CSE_CX")  # used for Custom Search
+SCRAPINGDOG_API_KEY = os.getenv("SCRAPINGDOG_API_KEY")
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+
+# Constants
+USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
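+
+# A hypothetical .env layout these lookups assume (variable names as read above):
+#
+#   GOOGLE_API_KEY=...
+#   GOOGLE_SAFE_BROWSING_API_KEY=...
+#   GOOGLE_CSE_CX=...
+#   SCRAPINGDOG_API_KEY=...
+#   GEMINI_API_KEY=...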
\ No newline at end of file
diff --git a/miragenews/text_module/llm_utils.py b/miragenews/text_module/llm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9693da7d2c4c4df2378acdcec8edafb3722b0130
--- /dev/null
+++ b/miragenews/text_module/llm_utils.py
@@ -0,0 +1,77 @@
+import google.generativeai as genai
+from .models import text_llm_model
+
+MAX_TOKENS = 4096
+
+def ask_llm_about_text(text_content):
+    """
+    Deep-analyze the text with the default LLM loaded in models.py (e.g. gemini-2.5-pro).
+    """
+ if not text_llm_model:
+ return "ERROR", "Text LLM Model not initialized."
+
+ prompt = f"""
+ You are an expert AI text forensic analyst.
+
+ TASK: Analyze the text below and determine if it is "Human-written" or "AI-generated".
+
+ TEXT:
+ "{text_content}"
+
+ CRITERIA:
+ - AI: Overly structured, repetitive, perfect grammar but lack of depth, neutral tone, hallucinations.
+ - Human: Contextual, emotional, slang, imperfections, complex sentence structures.
+
+ OUTPUT FORMAT:
+ Strictly follow this format:
+ Status: [LIKELY AI GENERATED / LIKELY HUMAN WRITTEN]
+ Reason: [Short explanation, max 2 sentences]
+ """
+
+ try:
+ generation_config = genai.types.GenerationConfig(
+ temperature=0.0,
+ max_output_tokens=MAX_TOKENS
+ )
+
+ response = text_llm_model.generate_content(
+ prompt,
+ generation_config=generation_config
+ )
+ return "SUCCESS", response.text.strip()
+ except Exception as e:
+ return "ERROR", str(e)
+
+def ask_llm_to_rewrite(text_content):
+    """Ask a fast LLM to paraphrase the text; returns the rewrite or None."""
+    if not text_llm_model:
+        return None
+
+    full_prompt = f"""
+    You are a helpful assistant.
+    Paraphrase the following text: ```{text_content}``` Only output the paraphrased text, with no explanation.
+    Rewritten Version:
+    """
+
+ try:
+ flash_model = genai.GenerativeModel('gemini-2.0-flash')
+
+ config = genai.types.GenerationConfig(
+ temperature=0.0,
+ max_output_tokens=MAX_TOKENS
+ )
+
+        # Call generate_content with the deterministic config
+ response = flash_model.generate_content(
+ full_prompt,
+ generation_config=config
+ )
+
+ if response and response.text:
+ return response.text.strip()
+
+ return None
+
+ except Exception as e:
+ print(f"❌ Gemini Flash Rewrite Error: {e}")
+ return None
\ No newline at end of file
diff --git a/miragenews/text_module/models.py b/miragenews/text_module/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5d37a0ed8d31e8995209204f9388e85ca879c8c
--- /dev/null
+++ b/miragenews/text_module/models.py
@@ -0,0 +1,21 @@
+# text_module/models.py
+import google.generativeai as genai
+from sentence_transformers import SentenceTransformer
+from dotenv import load_dotenv
+load_dotenv()
+from .config import GEMINI_API_KEY
+
+print("⏳ Loading Text Embedding Model...")
+text_model = SentenceTransformer('intfloat/multilingual-e5-large')
+
+# --- CONFIG GEMINI (Detection) ---
+text_llm_model = None
+try:
+ if GEMINI_API_KEY:
+ genai.configure(api_key=GEMINI_API_KEY)
+ text_llm_model = genai.GenerativeModel('gemini-2.5-pro')
+ print("✅ Google Gemini (Text LLM) configured.")
+ else:
+ print("⚠️ Missing GEMINI_API_KEY.")
+except Exception as e:
+ print(f"⚠️ Error configuring Gemini: {e}")
diff --git a/miragenews/text_module/pipeline.py b/miragenews/text_module/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..3449ae8ae5751b317f469532ad9178f1f0c47b61
--- /dev/null
+++ b/miragenews/text_module/pipeline.py
@@ -0,0 +1,305 @@
+import time
+import re
+import torch
+import numpy as np
+
+# --- PROJECT MODULES ---
+# Make sure these files exist in your project structure
+try:
+    from .models import text_model  # SentenceTransformer model
+    from .search import search_google_official, search_scrapingdog_serp
+    from .scraper import scrape_text_content
+    from .llm_utils import ask_llm_about_text, ask_llm_to_rewrite
+except ImportError:
+    print("⚠️ Warning: Import failed. Running in standalone logic-checking mode.")
+
+# --- CONFIGURATION ---
+ALPHA = 0.96              # similarity at or above this is treated as REAL (near-exact match)
+DELTA = 0.00923           # minimum similarity gain for the AI-regeneration hypothesis
+MIN_CANDIDATE_SIM = 0.50  # minimum candidate similarity before running the AI check (Step 2)
+TIMEOUT_SECONDS = 300     # ⏳ hard time limit: 5 minutes (300 seconds)
+
+def split_sentences(text):
+    """Split text into sentences."""
+    if not text:
+        return []
+    # Split on terminal punctuation or on line breaks
+    sentences = re.split(r'(?<=[.!?])\s+|[\n\r]+', text)
+    return [s.strip() for s in sentences if s.strip()]
+
+def get_similarity(text1, text2):
+    """Cosine similarity between two texts, computed with the embedding model."""
+ if not text1 or not text2: return 0.0
+ try:
+ emb1 = text_model.encode([text1])
+ emb2 = text_model.encode([text2])
+ return text_model.similarity(emb1, emb2)[0][0].item()
+ except Exception as e:
+ print(f"Embedding Error: {e}")
+ return 0.0
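+
+# A quick sketch of the two helpers above (strings are illustrative):
+#
+#   split_sentences("A dog barks. It runs!")      # -> ["A dog barks.", "It runs!"]
+#   get_similarity("a dog barks", "a dog barks")  # -> ~1.0 (cosine similarity)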
+
+def run_text_alignment_algorithm(input_text, url_content_list):
+    """
+    Algorithm 2: EXTRACT_CANDIDATE & COMPUTE SIMILARITY
+    Find the span of the web content that best matches the user's input text.
+    """
+ St = split_sentences(input_text)
+
+    # Accept either raw text or a pre-split sentence list for the web content
+ if isinstance(url_content_list, str):
+ Su = split_sentences(url_content_list)
+ else:
+ Su = url_content_list
+
+ if not St or not Su: return 0.0, "", ""
+
+    # Phase 1: Find the starting point (anchor alignment)
+ best_b = -1
+ max_sim_start = -1.0
+ target_start = St[0]
+
+ for i, sent_u in enumerate(Su):
+ sim = get_similarity(target_start, sent_u)
+ if best_b == -1 or sim > max_sim_start:
+ max_sim_start = sim
+ best_b = i
+
+    # Trim the unrelated head of the web content
+ if best_b != -1 and max_sim_start > 0.4:
+ Su = Su[best_b:]
+
+ # Phase 2: Sliding Window Matching
+ P = []
+ total_score = 0.0
+ count_matches = 0
+ match_details = []
+ candidate_parts = []
+ MAX_WINDOW = 3
+
+    # Work on copies so the original lists stay untouched
+ curr_St = list(St)
+ curr_Su = list(Su)
+
+ while curr_Su and curr_St:
+ best_p_score = -1.0
+ best_pair = None
+ remove_from_St = 0
+ remove_from_Su = 0
+
+        # Fix the first input sentence and scan a window over the web side
+ target_st_1 = curr_St[0]
+ limit_k = min(len(curr_Su), MAX_WINDOW)
+ for k in range(1, limit_k + 1):
+ Ck = " ".join(curr_Su[:k])
+ sim = get_similarity(target_st_1, Ck)
+ if best_pair is None or sim > best_p_score:
+ best_p_score = sim
+ best_pair = (target_st_1, Ck)
+ remove_from_St = 1
+ remove_from_Su = k
+
+        # Fix the first web sentence and scan a window over the input side
+ target_su_1 = curr_Su[0]
+ limit_l = min(len(curr_St), MAX_WINDOW)
+ for l in range(1, limit_l + 1):
+ Cl = " ".join(curr_St[:l])
+ sim = get_similarity(Cl, target_su_1)
+ if sim > best_p_score:
+ best_p_score = sim
+ best_pair = (Cl, target_su_1)
+ remove_from_St = l
+ remove_from_Su = 1
+
+        # Record this round's best match
+ if best_pair:
+ P.append(best_pair)
+ total_score += best_p_score
+ count_matches += 1
+
+            # Keep the matching web span as evidence
+ if remove_from_Su > 1:
+ web_part = " ".join(curr_Su[:remove_from_Su])
+ else:
+ web_part = curr_Su[0]
+ candidate_parts.append(web_part)
+
+ if best_p_score > 0.6:
+ match_details.append(f"• Sim: {best_p_score:.2f} | Web: \"{best_pair[1][:50]}...\"")
+
+ curr_St = curr_St[remove_from_St:]
+ curr_Su = curr_Su[remove_from_Su:]
+ else:
+            # No match at all: drop the current sentence on both sides
+ curr_St.pop(0)
+ curr_Su.pop(0)
+
+ final_avg_score = (total_score / count_matches) if count_matches > 0 else 0.0
+ candidate_content_str = " ".join(candidate_parts)
+
+ return final_avg_score, "\n".join(match_details), candidate_content_str
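+
+# A minimal usage sketch (strings are illustrative):
+#
+#   score, details, candidate = run_text_alignment_algorithm(
+#       "The cat sat on the mat.",
+#       "Breaking news. The cat sat on the mat. More text follows.")
+#   # score is the average cosine similarity of the aligned pairs; candidate is
+#   # the concatenated matching web span, reused later by the AI-rewrite check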
+
+def verify_text_logic(user_input, result_object):
+    """
+    Main entry point for text verification.
+    Args:
+        user_input (str): The text to verify.
+        result_object: Result holder (exposes the set_... methods).
+    """
+
+    # [TIMER START] Start the clock
+    start_time = time.time()
+    timed_out = False
+
+    # 0. Basic input validation
+ if not user_input or len(user_input.strip()) < 5:
+ result_object.set_authenticity_assessment("⚠️ INSUFFICIENT DATA")
+ result_object.set_verification_tools_methods("Input Validation")
+ result_object.set_synthetic_type("N/A")
+ result_object.set_other_artifacts("Input text is too short or empty.")
+ return
+
+    # Free the GPU cache if one is available
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+    # Best candidate found across all searches
+ best_candidate = {
+ "url": "",
+ "score_c": 0.0,
+ "content_tc": "",
+ "evidence": ""
+ }
+
+    # Search strategies, tried in order
+ search_strategies = [
+ ("Google CSE", search_google_official),
+ ("ScrapingDog", search_scrapingdog_serp)
+ ]
+
+ tools_used_list = ["SearchLLM Algorithm", "Sentence Transformer"]
+ found_any_url = False
+
+    # --- STEP 1: SEARCH & DEEP ANALYSIS ---
+ print(f"🚀 Starting Text Analysis for: {user_input[:30]}...")
+
+ for strategy_name, search_func in search_strategies:
+        # [TIMEOUT CHECK] before starting a new strategy
+ if time.time() - start_time > TIMEOUT_SECONDS:
+ timed_out = True
+ break
+
+ try:
+ print(f" Using {strategy_name}...")
+ urls = search_func(user_input, num_results=5)
+ except Exception as e:
+ print(f" Error in search {strategy_name}: {e}")
+ continue
+
+ if not urls: continue
+ found_any_url = True
+ tools_used_list.append(strategy_name)
+
+        # Walk through each URL that was found
+ for i, url in enumerate(urls):
+ if time.time() - start_time > TIMEOUT_SECONDS:
+ print(f"⏳ TIMEOUT reached ({TIMEOUT_SECONDS}s). Aborting Search/Scrape.")
+ timed_out = True
+ break
+
+            # Basic URL filter (the blocklist is currently empty; add domains to skip here)
+            if any(blocked in url for blocked in []):
+ continue
+
+ # 1.1 Scrape Content
+ try:
+ url_sentences = scrape_text_content(url)
+ if not url_sentences: continue
+
+                # 1.2 Run the matching algorithm (the most expensive step)
+ sigma_c, evidence_str, tc_content = run_text_alignment_algorithm(user_input, url_sentences)
+
+                # Track the best candidate so far
+ if sigma_c > best_candidate["score_c"]:
+ best_candidate["score_c"] = sigma_c
+ best_candidate["url"] = url
+ best_candidate["content_tc"] = tc_content
+ best_candidate["evidence"] = evidence_str
+
+                # CASE 1: REAL - a near-exact match -> return immediately
+ if sigma_c >= ALPHA:
+ print(f"✅ Found exact match (Sim: {sigma_c:.4f}) at {url}")
+ result_object.set_authenticity_assessment("✅ REAL (Authentic)")
+ result_object.set_verification_tools_methods(", ".join(tools_used_list))
+ result_object.set_synthetic_type("Human/Existing Source")
+
+ artifacts = (f"An exact match has been found.\n"
+ f"Source: {url}\n"
+ f"Similarity Score: {sigma_c:.4f}")
+ result_object.set_other_artifacts(artifacts)
+                    return  # done
+
+ except Exception as e:
+ print(f"⚠️ Error processing URL {url}: {e}")
+ continue
+
+        # Leave the strategy loop if the URL loop already timed out
+ if timed_out:
+ break
+
+ if not timed_out and best_candidate["score_c"] >= MIN_CANDIDATE_SIM:
+
+        # Final time check before calling the LLM rewrite
+ if time.time() - start_time <= TIMEOUT_SECONDS:
+ print("🤖 Running AI Regeneration Hypothesis check...")
+ tools_used_list.append("LLM Regeneration Analysis")
+
+ tr_content = ask_llm_to_rewrite(best_candidate["content_tc"])
+
+ if tr_content:
+ sigma_r = get_similarity(user_input, tr_content)
+ diff = sigma_r - best_candidate["score_c"]
+
+ print(f" Delta Diff: {diff:.4f} (Threshold: {DELTA})")
+
+                # CASE 2: AI DETECTED (the rewrite resembles the input more than the web original does)
+ if diff >= DELTA:
+ result_object.set_authenticity_assessment("🤖 NOT REAL (Fake, Manipulated, or AI)")
+ result_object.set_verification_tools_methods(", ".join(tools_used_list))
+ result_object.set_synthetic_type("AI Generated (Derived from Source)")
+
+ artifacts = (f"Detected AI generation pattern by our system.\n"
+ f"Potential Source: {best_candidate['url']}\n"
+ f"Rewrite Similarity Gain: +{diff:.4f}")
+ result_object.set_other_artifacts(artifacts)
+                    return  # done
+ else:
+            timed_out = True  # mark as timed out when too close to the limit
+
+ print("🔻 Switching to Direct LLM Analysis (Fallback)...")
+ tools_used_list.append("Direct LLM Analysis")
+
+    # Ask the LLM directly for a judgment
+ status_llm, response_llm = ask_llm_about_text(user_input)
+
+ # CASE 3: FALLBACK RESULT
+ if "AI GENERATED" in str(response_llm).upper():
+ auth_status = "🤖 NOT REAL (Fake, Manipulated, or AI)"
+ syn_type = "AI Generated Text"
+ verdict_color = "#dc2626" # Red
+ else:
+ auth_status = "✅ REAL (Authentic)"
+ syn_type = "Human Written"
+ verdict_color = "#16a34a" # Green
+
+    # Set the final result
+ result_object.set_authenticity_assessment(auth_status)
+ result_object.set_verification_tools_methods(", ".join(tools_used_list))
+ result_object.set_synthetic_type(syn_type)
+
+    # Artifacts include the fallback reasoning
+    final_artifacts = (f"Best matching source URL: {best_candidate['url'] if best_candidate['url'] else 'None'}\n\n"
+                       f"-------------------\n"
+                       f"LLM Reasoning: {response_llm}")
+
+ result_object.set_other_artifacts(final_artifacts)
+ print("🏁 Text Verification Completed.")
\ No newline at end of file
diff --git a/miragenews/text_module/scraper.py b/miragenews/text_module/scraper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b881fe07327b32d3761bb92c20e31bf8b85b95d7
--- /dev/null
+++ b/miragenews/text_module/scraper.py
@@ -0,0 +1,72 @@
+# text_module/scraper.py
+import requests
+import nltk
+import shutil
+import os
+import zipfile
+from bs4 import BeautifulSoup
+import trafilatura
+
+try:
+ from config import SCRAPINGDOG_API_KEY, USER_AGENT
+except ImportError:
+ from .config import SCRAPINGDOG_API_KEY, USER_AGENT
+
+# --- NLTK SETUP ---
+def setup_nltk():
+ try:
+ nltk.data.find('tokenizers/punkt')
+ nltk.data.find('tokenizers/punkt_tab')
+ except (LookupError, zipfile.BadZipFile, OSError):
+        try:
+            # Remove a possibly corrupted punkt install before re-downloading
+            for path in nltk.data.path:
+                punkt_path = os.path.join(path, 'tokenizers', 'punkt')
+                if os.path.exists(punkt_path):
+                    shutil.rmtree(punkt_path, ignore_errors=True)
+        except Exception:
+            pass
+ nltk.download('punkt')
+ nltk.download('punkt_tab')
+
+setup_nltk()
+
+def parse_text_html(html_content):
+ if not html_content: return []
+
+ text = trafilatura.extract(html_content, include_comments=False, include_tables=False)
+
+ if not text or len(text) < 50: return []
+
+ text = text.replace('\n', ' ')
+ return nltk.sent_tokenize(text)
+
+# --- MAIN ENTRY ---
+def scrape_text_content(url):
+    """
+    Scrape a URL: try a plain request first; if that fails, fall back to the ScrapingDog API.
+    """
+ try:
+ headers = {'User-Agent': USER_AGENT}
+ response = requests.get(url, headers=headers, timeout=5)
+ if response.status_code == 200:
+ content = parse_text_html(response.text)
+ if content:
+ return content
+ except Exception:
+ pass
+
+ if SCRAPINGDOG_API_KEY:
+ try:
+            # dynamic=true renders JS-heavy pages (React/Vue, ...)
+ sd_url = "https://api.scrapingdog.com/scrape"
+ params = {
+ 'api_key': SCRAPINGDOG_API_KEY,
+ 'url': url,
+ 'dynamic': 'true'
+ }
+ r = requests.get(sd_url, params=params, timeout=20)
+ if r.status_code == 200:
+ return parse_text_html(r.text)
+ except Exception as e:
+ print(f"ScrapingDog Error: {e}")
+
+ return []
\ No newline at end of file
diff --git a/miragenews/text_module/search.py b/miragenews/text_module/search.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0919c79f9c9d695124119df7a653c5c041a4731
--- /dev/null
+++ b/miragenews/text_module/search.py
@@ -0,0 +1,54 @@
+# text_module/search.py
+import requests
+from .config import GOOGLE_SAFE_BROWSING_API_KEY, GOOGLE_CX, SCRAPINGDOG_API_KEY
+
+def search_google_official(query, num_results=10):
+    """Round 1: the official Google Custom Search API."""
+    if not GOOGLE_SAFE_BROWSING_API_KEY or not GOOGLE_CX:
+        print("⚠️ Missing Google API key/CX")
+ return []
+
+ url = "https://www.googleapis.com/customsearch/v1"
+ params = {
+ 'key': GOOGLE_SAFE_BROWSING_API_KEY,
+ 'cx': GOOGLE_CX,
+ 'q': query,
+ 'num': num_results
+ }
+ try:
+ response = requests.get(url, params=params, timeout=10)
+ if response.status_code == 200:
+ data = response.json()
+ return [item['link'] for item in data.get('items', [])]
+ except Exception as e:
+ print(f"❌ Google Official Search Error: {e}")
+ return []
+
+def search_scrapingdog_serp(query, num_results=10):
+    """Round 2: the ScrapingDog SERP API."""
+    if not SCRAPINGDOG_API_KEY:
+        print("⚠️ Missing SCRAPINGDOG_API_KEY")
+ return []
+
+ url = "https://api.scrapingdog.com/google_search"
+ params = {
+ 'api_key': SCRAPINGDOG_API_KEY,
+ 'query': query,
+ 'results': num_results,
+ 'country': 'vn'
+ }
+
+ try:
+ print(f"🔎 Calling ScrapingDog SERP for: {query}...")
+ response = requests.get(url, params=params, timeout=20)
+ if response.status_code == 200:
+ data = response.json()
+ links = []
+ if 'organic_data' in data:
+ for item in data['organic_data']:
+ if 'link' in item:
+ links.append(item['link'])
+ return links
+ except Exception as e:
+ print(f"❌ ScrapingDog Search Error: {e}")
+ return []
\ No newline at end of file
diff --git a/miragenews/utils/__init__.py b/miragenews/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fdede4fb3e78fb1a30483f19b83bb449f6e3285
--- /dev/null
+++ b/miragenews/utils/__init__.py
@@ -0,0 +1,224 @@
+import torch
+import numpy as np
+import json
+from tqdm import tqdm
+from .metrics import calculate_metrics, find_best_threshold
+from miragenews.models import *
+
+# === SHARED FUNCTIONS ===
+
+def save_model_checkpoint(model, save_path, threshold):
+ """
+ Save the model state and threshold as a checkpoint.
+
+ Args:
+ model (torch.nn.Module): The model to save.
+ save_path (str): Path to save the checkpoint.
+ threshold (float): The best threshold to save.
+ """
+ checkpoint = {
+ "model_state_dict": model.state_dict(),
+ "best_threshold": threshold,
+ }
+ torch.save(checkpoint, save_path)
+
+
+def load_model_checkpoint(model, save_path):
+ """
+ Load the model state and best threshold from a checkpoint.
+
+ Args:
+ model (torch.nn.Module): The model to load the state into.
+ save_path (str): Path to the checkpoint.
+
+ Returns:
+ model (torch.nn.Module): Model loaded with state_dict.
+ best_threshold (float): Best threshold saved in the checkpoint.
+ """
+ checkpoint = torch.load(save_path)
+
+ model.load_state_dict(checkpoint["model_state_dict"])
+    best_threshold = checkpoint.get("best_threshold")
+ return model, best_threshold
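+
+# A minimal checkpoint round-trip sketch (path is illustrative):
+#
+#   model = MiRAGeImg()
+#   save_model_checkpoint(model, "checkpoints/mirage_img.pth", threshold=0.42)
+#   model, thr = load_model_checkpoint(model, "checkpoints/mirage_img.pth")
+#   assert thr == 0.42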
+
+
+def evaluate_model(model, data_loader, criterion, device="cuda", threshold=0.5, cbm_encoder=None, concept_num=300):
+ """
+ Evaluate the model on a given dataset, calculating metrics using a specified threshold.
+
+ Args:
+ model (torch.nn.Module): The model to evaluate.
+ data_loader (DataLoader): DataLoader for the dataset.
+ criterion: Loss function.
+ device (str): Device for computation.
+ threshold (float): Threshold for binary classification.
+ cbm_encoder (torch.nn.Module, optional): CBM encoder model, if using CBM Predictor.
+ concept_num (int): Number of concepts in CBM Predictor.
+
+ Returns:
+ avg_loss (float): Average loss over the dataset.
+ metrics (dict): Calculated metrics using the specified threshold.
+ """
+ model.eval()
+ y_true = []
+ y_probs = []
+ total_loss = 0.0
+
+ with torch.no_grad():
+ for inputs, labels in data_loader:
+ inputs, labels = inputs.to(device), labels.to(device)
+
+ # Forward pass logic based on model type
+ if isinstance(model, ObjectClassCBMEncoder):
+ classifier_idx = 0
+ outputs = model(inputs.float(), classifier_idx)
+ # elif isinstance(model, ObjectClassCBMPredictor) and cbm_encoder:
+ # concept_features = [cbm_encoder(inputs.float(), i) for i in range(concept_num)]
+ # pred_scores = torch.cat(concept_features, dim=1)
+ # outputs = model(pred_scores.to(device))
+ else:
+ outputs = model(inputs.float())
+
+ # Calculate loss and store probabilities
+
+ loss = criterion(outputs.squeeze(-1), labels.float())
+ total_loss += loss.item()
+ y_true.extend(labels.cpu().numpy())
+ y_probs.extend(outputs.squeeze(-1).cpu().numpy())
+
+ avg_loss = total_loss / len(data_loader)
+ y_true = np.array(y_true)
+ y_probs = np.array(y_probs)
+ metrics = calculate_metrics(y_true, y_probs, threshold)
+
+ model.train()
+ return avg_loss, metrics, y_true, y_probs # Return y_true and y_probs for threshold finding
+
+def save_metrics(metrics, save_path):
+    """
+    Append metrics to a file as one JSON object per line (JSON Lines).
+
+    Args:
+        metrics (dict): Dictionary of metrics to save.
+        save_path (str): Path to the file metrics are appended to.
+    """
+ with open(save_path, "a") as f:
+ f.write(json.dumps(metrics) + "\n")
+
+# === TRAINING FUNCTION ===
+
+def train_model(model, train_loader, eval_loader, config, classifier_idx=None, cbm_encoder=None, device="cuda"):
+ """
+ Train the model with early stopping and save the best model checkpoint with the threshold.
+ """
+ criterion = torch.nn.BCELoss()
+ optimizer = torch.optim.Adam(model.parameters(), lr=config['training']['learning_rate'])
+ best_val_loss = float('inf')
+ patience_counter = 0
+ best_threshold = 0.5
+
+ for epoch in tqdm(range(config['training']['epochs']), desc="Training"):
+ model.train()
+ running_loss = 0.0
+
+ for batch_idx, (inputs, labels) in enumerate(train_loader):
+ inputs, labels = inputs.to(device), labels.to(device)
+ optimizer.zero_grad()
+ if isinstance(model, ObjectClassCBMEncoder):
+ outputs = model(inputs.float(), classifier_idx)
+ # elif isinstance(model, ObjectClassCBMPredictor) and cbm_encoder:
+ # concept_features = [cbm_encoder(inputs.float(), i) for i in range(config.get('concept_num', 300))]
+ # pred_scores = torch.cat(concept_features, dim=1)
+ # outputs = model(pred_scores.to(device))
+ else:
+ outputs = model(inputs.float())
+
+ loss = criterion(outputs.squeeze(-1), labels.float())
+ loss.backward()
+ optimizer.step()
+ running_loss += loss.item()
+
+ train_loss = running_loss / len(train_loader)
+ # print(f"Epoch {epoch + 1}/{config['training']['epochs']}, Training Loss: {train_loss:.4f}")
+
+ # Evaluate on validation set
+ eval_loss, val_metrics, y_true, y_probs = evaluate_model(model, eval_loader, criterion, device, cbm_encoder=cbm_encoder)
+ # print(f"Validation Loss: {eval_loss:.4f}, Metrics: {val_metrics}")
+
+ # Find and save the best threshold based on validation
+ best_threshold = find_best_threshold(y_true, y_probs)
+ # print(f"Best Threshold Found: {best_threshold}")
+
+ if eval_loss < best_val_loss - 0.001:
+ best_val_loss = eval_loss
+ patience_counter = 0
+ save_model_checkpoint(model, config['training']['save_path'], best_threshold)
+ else:
+ patience_counter += 1
+
+ if patience_counter > 10:
+ print("Early stopping triggered. Training completed.")
+ break
+
+
+# === TESTING FUNCTION ===
+
+def test_model(model, test_loader, checkpoint_path, device="cuda", model_2=None):
+ """
+ Load the best model and threshold, then evaluate on the test set and save metrics.
+
+ Args:
+ model (torch.nn.Module): The trained model.
+ test_loader (DataLoader): DataLoader for the test set.
+ checkpoint_path (str): Path to the saved model checkpoint.
+ device (str): Device for computation (e.g., 'cuda' or 'cpu').
+
+ Returns:
+ dict: Test metrics.
+ """
+ criterion = torch.nn.BCELoss()
+ # Load model and best threshold
+ model, best_threshold = load_model_checkpoint(model, checkpoint_path)
+ # Evaluate on the test set with the loaded threshold
+ test_loss, test_metrics, _, _ = evaluate_model(model, test_loader, criterion, device, threshold=best_threshold)
+ # print("Test Metrics:", test_metrics)
+
+ return test_metrics
+
+def test_multimodal_model(image_model, text_model, test_loader, threshold=0.5, device='cuda'):
+ image_model.eval()
+ text_model.eval()
+ criterion = torch.nn.BCELoss()
+ y_true = []
+ y_probs = []
+ total_loss = 0.0
+
+ with torch.no_grad():
+ for image_inputs, text_inputs, labels in test_loader:
+ image_inputs, text_inputs, labels = image_inputs.to(device), text_inputs.to(device), labels.to(device)
+
+ # Forward pass logic based on model type
+ if isinstance(image_model, ObjectClassCBMEncoder):
+ classifier_idx = 0
+ image_outputs = image_model(image_inputs.float(), classifier_idx)
+ # elif isinstance(model, ObjectClassCBMPredictor) and cbm_encoder:
+ # concept_features = [cbm_encoder(inputs.float(), i) for i in range(concept_num)]
+ # pred_scores = torch.cat(concept_features, dim=1)
+ # outputs = model(pred_scores.to(device))
+ else:
+ image_outputs = image_model(image_inputs.float())
+
+ text_outputs = text_model(text_inputs.float())
+ # Calculate loss and store probabilities
+ outputs = (image_outputs + text_outputs) / 2
+ loss = criterion(outputs.squeeze(-1), labels.float())
+ total_loss += loss.item()
+ y_true.extend(labels.cpu().numpy())
+ y_probs.extend(outputs.squeeze(-1).cpu().numpy())
+
+ avg_loss = total_loss / len(test_loader)
+ y_true = np.array(y_true)
+ y_probs = np.array(y_probs)
+ metrics = calculate_metrics(y_true, y_probs, threshold)
+
+ return metrics
\ No newline at end of file
diff --git a/miragenews/utils/metrics.py b/miragenews/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..3087794e6cebef8778da3e40197689d4af73259e
--- /dev/null
+++ b/miragenews/utils/metrics.py
@@ -0,0 +1,56 @@
+from sklearn.metrics import accuracy_score, f1_score, average_precision_score
+import numpy as np
+
+def calculate_metrics(y_true, y_probs, threshold=0.5):
+ """
+ Calculate evaluation metrics for binary classification using a specified threshold.
+
+ Args:
+ y_true (np.array): True binary labels.
+ y_probs (np.array): Predicted probabilities.
+ threshold (float): Threshold to convert probabilities to binary predictions.
+
+ Returns:
+ dict: Dictionary containing accuracy, F1 score, average precision, real accuracy,
+ fake accuracy, and the threshold used.
+ """
+ # Convert probabilities to binary predictions using the specified threshold
+ y_pred = (y_probs > threshold).astype(int)
+
+ # Calculate metrics
+ metrics = {
+ "accuracy": round(accuracy_score(y_true, y_pred), 4),
+ "real_accuracy": round(accuracy_score(y_true[y_true == 0], y_pred[y_true == 0]), 4) if (y_true == 0).any() else 0.0,
+ "fake_accuracy": round(accuracy_score(y_true[y_true == 1], y_pred[y_true == 1]), 4) if (y_true == 1).any() else 0.0,
+ "f1": round(f1_score(y_true, y_pred), 4),
+ "average_precision": round(average_precision_score(y_true, y_probs), 4),
+ "threshold": round(threshold, 4)
+ }
+
+ # Convert all values to native Python types
+ return {k: float(v) if isinstance(v, (np.float32, np.float64)) else v for k, v in metrics.items()}
+
+
+def find_best_threshold(y_true, y_probs):
+ """
+ Find the best threshold for binary classification to maximize accuracy.
+
+ Args:
+ y_true (np.array): True binary labels.
+ y_probs (np.array): Predicted probabilities.
+
+ Returns:
+ float: The best threshold that maximizes accuracy.
+ """
+ thresholds = np.unique(y_probs)
+ best_accuracy = 0.0
+ best_threshold = 0.5
+
+ for thresh in thresholds:
+ y_pred = (y_probs > thresh).astype(int)
+ accuracy = accuracy_score(y_true, y_pred)
+ if accuracy > best_accuracy:
+ best_accuracy = accuracy
+ best_threshold = thresh
+
+ return round(float(best_threshold), 4)
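+
+# A small worked example (arrays are illustrative):
+#
+#   y_true  = np.array([0, 0, 1, 1])
+#   y_probs = np.array([0.2, 0.4, 0.6, 0.9])
+#   find_best_threshold(y_true, y_probs)           # -> 0.4 (predicts [0, 0, 1, 1], accuracy 1.0)
+#   calculate_metrics(y_true, y_probs, 0.4)["f1"]  # -> 1.0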
diff --git a/raft_model/raft-things.pth b/raft_model/raft-things.pth
new file mode 100644
index 0000000000000000000000000000000000000000..1e206ac8a2f660bc7620b0806a9278ddb3fc594d
--- /dev/null
+++ b/raft_model/raft-things.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcfa4125d6418f4de95d84aec20a3c5f4e205101715a79f193243c186ac9a7e1
+size 21108000
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b590b2b234143d5b95fd7b1994b93b51a1d6c027
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,34 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.3.1
+torchvision==0.18.1
+torchaudio==2.3.1
+numpy<2
+pandas
+scikit-learn
+matplotlib
+scipy
+tqdm
+einops
+imageio
+opencv-python
+Pillow==12.0.0
+transformers==4.57.1
+sentence-transformers==5.1.2
+nltk==3.9.2
+tensorboard
+tensorboardX
+blobfile>=1.0.5
+natsort
+trafilatura==2.0.0
+fastapi==0.116.1
+pydantic==2.11.7
+uvicorn[standard]
+python-multipart
+python-dotenv==1.2.1
+httpx==0.28.1
+Requests==2.32.5
+beautifulsoup4==4.14.3
+git+https://github.com/openai/CLIP.git
+google-cloud-vision
+gradio
+google-generativeai
\ No newline at end of file
diff --git a/uploads/video_result.json b/uploads/video_result.json
new file mode 100644
index 0000000000000000000000000000000000000000..2b698277f6eb54a67495248bd11a7dc69c2083d2
--- /dev/null
+++ b/uploads/video_result.json
@@ -0,0 +1,6 @@
+{
+ "Original_Index_0": {
+ "video_name": "Original_Index_0.png",
+ "error": "division by zero"
+ }
+}
\ No newline at end of file