{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1e0e7b4a",
   "metadata": {},
   "source": [
    "# DeepShield: FaceForensics++ ViT Training\n",
    "\n",
    "Fine-tunes a Vision Transformer (ViT) to classify video frames as **real** or **fake** (manipulated), using a small subset of FaceForensics++.\n",
    "\n",
    "Run this notebook entirely in Google Colab.\n",
    "\n",
    "**Before running**:\n",
    "1. Go to `Runtime` -> `Change runtime type` -> select **T4 GPU**.\n",
    "2. Make sure you have agreed to the FaceForensics++ terms of use; the download step fetches videos from the dataset's server.\n",
    "3. Run the cells below in order (`Runtime` -> `Run all` works).\n",
    "\n",
    "**Pipeline**: download videos -> extract frames with a *per-video* train/test split -> fine-tune `google/vit-base-patch16-224-in21k` -> save the model to `./deepshield_vit_model`.\n",
    "\n",
    "> **Note:** frames are whole video frames resized to 224x224, not face crops. Cropping the face first is likely to help with subtle facial manipulations (**TODO**)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fe293e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# transformers>=4.41 is needed for the `eval_strategy` argument used below.\n",
    "%pip install -q \"transformers>=4.41\" datasets accelerate opencv-python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b6f2c1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "\n",
    "import cv2\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import (\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    ViTForImageClassification,\n",
    "    ViTImageProcessor,\n",
    "    set_seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d2e8a90",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Everything you may want to tune lives in this cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c41d3b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = './data'                   # downloaded FaceForensics++ videos\n",
    "DATASET_DIR = './dataset'             # extracted frames: <split>/<label>/*.jpg (recreated on every run)\n",
    "MODEL_DIR = './deepshield_vit_model'  # final model + image processor\n",
    "CHECKPOINT_DIR = './vit-deepshield'   # intermediate Trainer checkpoints\n",
    "\n",
    "COMPRESSION = 'c40'  # one of 'raw', 'c23', 'c40'\n",
    "NUM_VIDEOS = 50      # videos downloaded per subset (kept small for the tutorial)\n",
    "FAKE_METHODS = ['Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures']\n",
    "\n",
    "FRAME_SIZE = (224, 224)\n",
    "FRAMES_PER_FAKE_VIDEO = 4\n",
    "# Each fake method contributes as many videos as the real subset, so sample\n",
    "# proportionally more frames per real video to keep the classes balanced.\n",
    "FRAMES_PER_REAL_VIDEO = FRAMES_PER_FAKE_VIDEO * len(FAKE_METHODS)\n",
    "TEST_FRACTION = 0.1  # fraction of video groups held out for evaluation\n",
    "\n",
    "MODEL_NAME = 'google/vit-base-patch16-224-in21k'\n",
    "SEED = 42\n",
    "set_seed(SEED)  # seeds python, numpy and torch\n",
    "\n",
    "\n",
    "def subset_video_dir(subset):\n",
    "    \"\"\"Folder that download_ffpp.py writes the videos of `subset` to.\"\"\"\n",
    "    base = 'original_sequences/youtube' if subset == 'original' else f'manipulated_sequences/{subset}'\n",
    "    return os.path.join(DATA_DIR, base, COMPRESSION, 'videos')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a0f5e13",
   "metadata": {},
   "source": [
    "## 1. Download a FaceForensics++ subset\n",
    "\n",
    "`download_ffpp.py` fetches the first `NUM_VIDEOS` videos of one subset. Only the real `original` videos and the methods listed in `FAKE_METHODS` are downloaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9387c0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile download_ffpp.py\n",
    "#!/usr/bin/env python\n",
    "\"\"\"Download a small subset of FaceForensics++ videos.\n",
    "\n",
    "Usage: python download_ffpp.py OUTPUT_PATH [-d DATASET] [-c COMPRESSION] [-t TYPE] [-n NUM_VIDEOS]\n",
    "\n",
    "Videos are written to OUTPUT_PATH/<dataset path>/<compression>/<type>/<name>.mp4.\n",
    "\"\"\"\n",
    "import argparse\n",
    "import json\n",
    "import os\n",
    "import urllib.request\n",
    "from os.path import join\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "BASE_URL = 'http://kaldir.vc.in.tum.de/faceforensics/v3/'\n",
    "FILELIST_URL = 'misc/filelist.json'\n",
    "DATASETS = {\n",
    "    'original': 'original_sequences/youtube',\n",
    "    'Deepfakes': 'manipulated_sequences/Deepfakes',\n",
    "    'Face2Face': 'manipulated_sequences/Face2Face',\n",
    "    'FaceShifter': 'manipulated_sequences/FaceShifter',\n",
    "    'FaceSwap': 'manipulated_sequences/FaceSwap',\n",
    "    'NeuralTextures': 'manipulated_sequences/NeuralTextures',\n",
    "}\n",
    "ALL_DATASETS = list(DATASETS)\n",
    "COMPRESSION = ['raw', 'c23', 'c40']\n",
    "TYPE = ['videos']\n",
    "\n",
    "\n",
    "def download_file(url, out_file):\n",
    "    \"\"\"Download url to out_file, skipping files that already exist.\n",
    "\n",
    "    Data is written to a '.part' file that is renamed only once the transfer\n",
    "    has finished, so an interrupted download is retried on the next run\n",
    "    instead of being mistaken for a complete video.\n",
    "    \"\"\"\n",
    "    os.makedirs(os.path.dirname(out_file), exist_ok=True)\n",
    "    if os.path.isfile(out_file):\n",
    "        return\n",
    "    tmp_file = out_file + '.part'\n",
    "    try:\n",
    "        urllib.request.urlretrieve(url, tmp_file)\n",
    "    except BaseException:\n",
    "        if os.path.exists(tmp_file):\n",
    "            os.remove(tmp_file)\n",
    "        raise\n",
    "    os.replace(tmp_file, out_file)\n",
    "\n",
    "\n",
    "def load_file_pairs():\n",
    "    \"\"\"Fetch the list of video-id pairs published on the server.\"\"\"\n",
    "    with urllib.request.urlopen(BASE_URL + FILELIST_URL) as response:\n",
    "        return json.loads(response.read().decode('utf-8'))\n",
    "\n",
    "\n",
    "def video_names(dataset_path, file_pairs, num_videos):\n",
    "    \"\"\"Return the first num_videos video names (without extension) of a subset.\n",
    "\n",
    "    Original videos use each id of a pair on its own; manipulated videos are\n",
    "    named by joining a pair with '_', in both orders.\n",
    "    \"\"\"\n",
    "    names = []\n",
    "    for pair in file_pairs:\n",
    "        if 'original' in dataset_path:\n",
    "            names += pair\n",
    "        else:\n",
    "            names += ['_'.join(pair), '_'.join(pair[::-1])]\n",
    "    return names[:num_videos]\n",
    "\n",
    "\n",
    "def main():\n",
    "    parser = argparse.ArgumentParser(description='Download a subset of FaceForensics++.')\n",
    "    parser.add_argument('output_path', type=str)\n",
    "    parser.add_argument('-d', '--dataset', type=str, default='all', choices=ALL_DATASETS + ['all'])\n",
    "    parser.add_argument('-c', '--compression', type=str, default='c40', choices=COMPRESSION)\n",
    "    parser.add_argument('-t', '--type', type=str, default='videos', choices=TYPE)\n",
    "    parser.add_argument('-n', '--num_videos', type=int, default=50)  # small subset for the tutorial\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    file_pairs = load_file_pairs()  # fetched once, shared by every subset\n",
    "    datasets = ALL_DATASETS if args.dataset == 'all' else [args.dataset]\n",
    "    for dataset in datasets:\n",
    "        dataset_path = DATASETS[dataset]\n",
    "        print(f'Downloading {args.compression} of {dataset}')\n",
    "        videos_url = BASE_URL + f'{dataset_path}/{args.compression}/{args.type}/'\n",
    "        output_dir = join(args.output_path, dataset_path, args.compression, args.type)\n",
    "        for name in tqdm(video_names(dataset_path, file_pairs, args.num_videos)):\n",
    "            download_file(videos_url + name + '.mp4', join(output_dir, name + '.mp4'))\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e9d4a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download only the subsets used below (real videos + FAKE_METHODS).\n",
    "for subset in ['original', *FAKE_METHODS]:\n",
    "    !python download_ffpp.py {DATA_DIR} -d {subset} -c {COMPRESSION} -t videos -n {NUM_VIDEOS}\n",
    "\n",
    "# `!` does not stop the notebook when the script fails, so verify the result explicitly.\n",
    "for subset in ['original', *FAKE_METHODS]:\n",
    "    n_found = len(glob.glob(os.path.join(subset_video_dir(subset), '*.mp4')))\n",
    "    print(f'{subset}: {n_found} videos')\n",
    "    assert n_found > 0, f'No {subset} videos downloaded - check the output above'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f6b1d28",
   "metadata": {},
   "source": [
    "## 2. Extract frames\n",
    "\n",
    "Frames are split into train/test **by video group**, not by frame. Frames of one video are nearly identical, so a random per-frame split would put near-copies of training frames into the test set and inflate the reported accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f33716f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def video_group(video_path):\n",
    "    \"\"\"Return the leading '_'-separated id of a video file name (e.g. 'a' for 'a_b.mp4').\"\"\"\n",
    "    return os.path.splitext(os.path.basename(video_path))[0].split('_')[0]\n",
    "\n",
    "\n",
    "def choose_test_groups(video_dirs, test_fraction, seed):\n",
    "    \"\"\"Pick the video groups held out for evaluation.\n",
    "\n",
    "    Splitting by group (not by frame) keeps every frame of a video, and every\n",
    "    video sharing its leading id, on the same side of the split.\n",
    "    \"\"\"\n",
    "    groups = sorted({video_group(path)\n",
    "                     for video_dir in video_dirs\n",
    "                     for path in glob.glob(os.path.join(video_dir, '*.mp4'))})\n",
    "    random.Random(seed).shuffle(groups)\n",
    "    n_test = max(1, round(len(groups) * test_fraction)) if groups else 0\n",
    "    return set(groups[:n_test])\n",
    "\n",
    "\n",
    "def sample_frame_indices(total_frames, max_frames):\n",
    "    \"\"\"Return up to max_frames frame indices spread evenly over a video.\"\"\"\n",
    "    if max_frames <= 0:\n",
    "        return set()\n",
    "    if total_frames <= 0:  # frame count unknown: fall back to the first frames\n",
    "        return set(range(max_frames))\n",
    "    n = min(max_frames, total_frames)\n",
    "    return set(np.linspace(0, total_frames - 1, num=n).round().astype(int).tolist())\n",
    "\n",
    "\n",
    "def extract_frames(video_folder, output_root, label, max_frames=4,\n",
    "                   test_groups=frozenset(), prefix='', size=(224, 224)):\n",
    "    \"\"\"Save up to max_frames evenly spaced, resized frames of every .mp4 in video_folder.\n",
    "\n",
    "    Frames are written to <output_root>/<split>/<label>/<prefix><video>_f<k>.jpg,\n",
    "    where split is 'test' for videos whose group is in test_groups and 'train'\n",
    "    otherwise. Returns the number of frames written.\n",
    "    \"\"\"\n",
    "    written = 0\n",
    "    videos = sorted(glob.glob(os.path.join(video_folder, '*.mp4')))\n",
    "    for vid_path in tqdm(videos, desc=f'Extracting {prefix}{label}'):\n",
    "        vid_name = os.path.splitext(os.path.basename(vid_path))[0]\n",
    "        split = 'test' if video_group(vid_path) in test_groups else 'train'\n",
    "        out_dir = os.path.join(output_root, split, label)\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "        cap = cv2.VideoCapture(vid_path)\n",
    "        try:\n",
    "            if not cap.isOpened():\n",
    "                print(f'WARNING: could not open {vid_path}')\n",
    "                continue\n",
    "            wanted = sample_frame_indices(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), max_frames)\n",
    "            last_wanted = max(wanted, default=-1)\n",
    "            count = 0\n",
    "            frame_idx = 0\n",
    "            # Read sequentially up to the last wanted frame; only sampled frames are saved.\n",
    "            while frame_idx <= last_wanted:\n",
    "                ok, frame = cap.read()\n",
    "                if not ok:\n",
    "                    break\n",
    "                if frame_idx in wanted:\n",
    "                    frame = cv2.resize(frame, size)\n",
    "                    out_path = os.path.join(out_dir, f'{prefix}{vid_name}_f{count}.jpg')\n",
    "                    if cv2.imwrite(out_path, frame):\n",
    "                        count += 1\n",
    "                frame_idx += 1\n",
    "            written += count\n",
    "        finally:\n",
    "            cap.release()\n",
    "    return written"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c2a6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from an empty frame folder so re-runs (or a changed split) never mix\n",
    "# in frames from a previous run. NOTE: this deletes DATASET_DIR.\n",
    "shutil.rmtree(DATASET_DIR, ignore_errors=True)\n",
    "\n",
    "real_dir = subset_video_dir('original')\n",
    "fake_dirs = {method: subset_video_dir(method) for method in FAKE_METHODS}\n",
    "test_groups = choose_test_groups([real_dir, *fake_dirs.values()], TEST_FRACTION, SEED)\n",
    "print(f'Held-out video groups ({len(test_groups)}): {sorted(test_groups)}')\n",
    "\n",
    "n_real = extract_frames(real_dir, DATASET_DIR, 'real', FRAMES_PER_REAL_VIDEO,\n",
    "                        test_groups, size=FRAME_SIZE)\n",
    "n_fake = 0\n",
    "for method, fake_dir in fake_dirs.items():\n",
    "    # All manipulation methods share the same video names, so prefix the\n",
    "    # method to stop each method overwriting the previous one's frames.\n",
    "    n_fake += extract_frames(fake_dir, DATASET_DIR, 'fake', FRAMES_PER_FAKE_VIDEO,\n",
    "                             test_groups, prefix=f'{method}_', size=FRAME_SIZE)\n",
    "\n",
    "print(f'Extracted {n_real} real and {n_fake} fake frames')\n",
    "assert n_real > 0 and n_fake > 0, 'No frames extracted - check the download step'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e5a2c30",
   "metadata": {},
   "source": [
    "## 3. Fine-tune the ViT\n",
    "\n",
    "The folder layout `<split>/<label>/*.jpg` gives `imagefolder` ready-made `train` and `test` splits. Labels come from the folder names (`real` / `fake`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b79cdd85",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset('imagefolder', data_dir=DATASET_DIR)\n",
    "assert set(dataset) == {'train', 'test'}, f'unexpected splits: {list(dataset)}'\n",
    "labels = dataset['train'].features['label'].names\n",
    "print(dataset)\n",
    "print('Labels:', labels)\n",
    "\n",
    "processor = ViTImageProcessor.from_pretrained(MODEL_NAME)\n",
    "\n",
    "\n",
    "def transform(example_batch):\n",
    "    \"\"\"Turn a batch of PIL images into ViT pixel values, keeping the labels.\"\"\"\n",
    "    inputs = processor([img.convert('RGB') for img in example_batch['image']], return_tensors='pt')\n",
    "    inputs['labels'] = example_batch['label']\n",
    "    return inputs\n",
    "\n",
    "\n",
    "prepared_ds = dataset.with_transform(transform)  # applied lazily when examples are accessed\n",
    "\n",
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Stack individual transformed examples into a model-ready batch.\"\"\"\n",
    "    return {\n",
    "        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),\n",
    "        'labels': torch.tensor([x['labels'] for x in batch]),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f7d9b84",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Accuracy and balanced accuracy (mean per-class recall) of the argmax predictions.\"\"\"\n",
    "    preds = np.argmax(eval_pred.predictions, axis=1)\n",
    "    refs = np.asarray(eval_pred.label_ids)\n",
    "    per_class_recall = [np.mean(preds[refs == c] == c) for c in np.unique(refs)]\n",
    "    return {\n",
    "        'accuracy': float(np.mean(preds == refs)),\n",
    "        'balanced_accuracy': float(np.mean(per_class_recall)),\n",
    "    }\n",
    "\n",
    "\n",
    "model = ViTForImageClassification.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    num_labels=len(labels),\n",
    "    id2label=dict(enumerate(labels)),\n",
    "    label2id={name: i for i, name in enumerate(labels)},\n",
    ")\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=CHECKPOINT_DIR,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=3,\n",
    "    learning_rate=2e-4,\n",
    "    # Evaluate/checkpoint once per epoch: a fixed 100-step interval can exceed\n",
    "    # the total number of steps on a small subset, so nothing would be evaluated.\n",
    "    eval_strategy='epoch',\n",
    "    save_strategy='epoch',\n",
    "    logging_steps=10,\n",
    "    save_total_limit=2,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model='balanced_accuracy',\n",
    "    fp16=torch.cuda.is_available(),  # mixed precision only on a GPU runtime\n",
    "    seed=SEED,\n",
    "    remove_unused_columns=False,  # keep 'image' so the on-the-fly transform can see it\n",
    "    report_to='none',\n",
    "    push_to_hub=False,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=collate_fn,\n",
    "    compute_metrics=compute_metrics,\n",
    "    train_dataset=prepared_ds['train'],\n",
    "    eval_dataset=prepared_ds['test'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a3c0e55",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_results = trainer.train()\n",
    "trainer.save_model(MODEL_DIR)\n",
    "processor.save_pretrained(MODEL_DIR)\n",
    "trainer.log_metrics('train', train_results.metrics)\n",
    "trainer.save_metrics('train', train_results.metrics)\n",
    "trainer.save_state()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4b1f7c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final evaluation on the held-out videos (best checkpoint, via load_best_model_at_end).\n",
    "eval_metrics = trainer.evaluate()\n",
    "trainer.log_metrics('eval', eval_metrics)\n",
    "trainer.save_metrics('eval', eval_metrics)\n",
    "print(f'Training complete! The model is saved to {MODEL_DIR}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}