diff --git a/.env b/.env new file mode 100644 index 0000000000000000000000000000000000000000..14c42fd0e598482e9653f00b9d959afb58093bf4 --- /dev/null +++ b/.env @@ -0,0 +1,39 @@ +# ═══════════════════════════════════════ +# DeepShield — Environment Configuration +# ═══════════════════════════════════════ +# Copy this file to backend/.env and customize + +# Server +APP_HOST=0.0.0.0 +APP_PORT=8000 +DEBUG=true +CORS_ORIGINS=["http://localhost:5173"] + +# Database +# For local dev: sqlite:///./deepshield.db +# For production (Neon/Supabase): postgresql://username:password@ep-cool... +DATABASE_URL=postgresql://neondb_owner:npg_YUdXqlrDP3H2@ep-divine-tooth-ame27uf3-pooler.c-5.us-east-1.aws.neon.tech/neondb?sslmode=require&channel_binding=require + +# File Upload +MAX_UPLOAD_SIZE_MB=100 +UPLOAD_DIR=./temp_uploads +FILE_RETENTION_SECONDS=300 + +# AI Models +IMAGE_MODEL_ID=prithivMLmods/Deep-Fake-Detector-v2-Model +TEXT_MODEL_ID=jy46604790/Fake-News-Bert-Detect +DEVICE=cpu +PRELOAD_MODELS=true + +# News API (optional — sign up at https://newsdata.io) +NEWS_API_KEY=pub_83c8fca805124a4fb074256825decd4c +NEWS_API_BASE_URL=https://newsdata.io/api/1/news + +# PDF Reports +REPORT_DIR=./temp_reports +REPORT_TTL_SECONDS=3600 + +# Auth — CHANGE JWT_SECRET_KEY IN PRODUCTION! +JWT_SECRET_KEY=change-me-in-production +JWT_ALGORITHM=HS256 +JWT_EXPIRATION_MINUTES=1440 diff --git a/Colab_ViT_Training.ipynb b/Colab_ViT_Training.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..cb7dff448bd445cb9a01a51ca01984864890df14 --- /dev/null +++ b/Colab_ViT_Training.ipynb @@ -0,0 +1,233 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1e0e7b4a", + "metadata": {}, + "source": [ + "# DeepShield: FaceForensics++ ViT Training \n", + "Run this entirely in Google Colab.\n", + "**Before running**:\n", + "1. Go to `Runtime` -> `Change runtime type` -> select **T4 GPU**.\n", + "2. 
Run the cells below sequentially.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4fe293e7", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install timm transformers datasets accelerate evaluate opencv-python\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9387c0f", + "metadata": {}, + "outputs": [], + "source": [ + "# We create the download script inside the Colab environment\n", + "download_script = '''#!/usr/bin/env python\n", + "import argparse\n", + "import os\n", + "import urllib.request\n", + "import tempfile\n", + "import time\n", + "import sys\n", + "import json\n", + "from tqdm import tqdm\n", + "from os.path import join\n", + "\n", + "FILELIST_URL = 'misc/filelist.json'\n", + "DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json'\n", + "DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',]\n", + "DATASETS = {\n", + " 'original': 'original_sequences/youtube',\n", + " 'Deepfakes': 'manipulated_sequences/Deepfakes',\n", + " 'Face2Face': 'manipulated_sequences/Face2Face',\n", + " 'FaceShifter': 'manipulated_sequences/FaceShifter',\n", + " 'FaceSwap': 'manipulated_sequences/FaceSwap',\n", + " 'NeuralTextures': 'manipulated_sequences/NeuralTextures'\n", + "}\n", + "ALL_DATASETS = ['original', 'Deepfakes', 'Face2Face', 'FaceShifter', 'FaceSwap', 'NeuralTextures']\n", + "COMPRESSION = ['raw', 'c23', 'c40']\n", + "TYPE = ['videos']\n", + "\n", + "def download_file(url, out_file):\n", + " os.makedirs(os.path.dirname(out_file), exist_ok=True)\n", + " if not os.path.isfile(out_file):\n", + " urllib.request.urlretrieve(url, out_file)\n", + "\n", + "def main():\n", + " parser = argparse.ArgumentParser()\n", + " parser.add_argument('output_path', type=str)\n", + " parser.add_argument('-d', '--dataset', type=str, default='all')\n", + " parser.add_argument('-c', '--compression', type=str, default='c40')\n", + " parser.add_argument('-t', '--type', type=str, 
default='videos')\n", + " parser.add_argument('-n', '--num_videos', type=int, default=50) # Small amount for tutorial\n", + " args = parser.parse_args()\n", + " \n", + " base_url = 'http://kaldir.vc.in.tum.de/faceforensics/v3/'\n", + " \n", + " datasets = [args.dataset] if args.dataset != 'all' else ALL_DATASETS\n", + " for dataset in datasets:\n", + " dataset_path = DATASETS[dataset]\n", + " print(f'Downloading {args.compression} of {dataset}')\n", + " \n", + " file_pairs = json.loads(urllib.request.urlopen(base_url + FILELIST_URL).read().decode(\"utf-8\"))\n", + " filelist = []\n", + " if 'original' in dataset_path:\n", + " for pair in file_pairs:\n", + " filelist += pair\n", + " else:\n", + " for pair in file_pairs:\n", + " filelist.append('_'.join(pair))\n", + " filelist.append('_'.join(pair[::-1]))\n", + " \n", + " filelist = filelist[:args.num_videos]\n", + " dataset_videos_url = base_url + f'{dataset_path}/{args.compression}/{args.type}/'\n", + " dataset_output_path = join(args.output_path, dataset_path, args.compression, args.type)\n", + " \n", + " for filename in tqdm(filelist):\n", + " download_file(dataset_videos_url + filename + \".mp4\", join(dataset_output_path, filename + \".mp4\"))\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()\n", + "'''\n", + "\n", + "with open(\"download_ffpp.py\", \"w\") as f:\n", + " f.write(download_script)\n", + "\n", + "!python download_ffpp.py ./data -d all -c c40 -t videos -n 50\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f33716f6", + "metadata": {}, + "outputs": [], + "source": [ + "import cv2\n", + "import os\n", + "import glob\n", + "from tqdm import tqdm\n", + "\n", + "def extract_frames(video_folder, output_folder, label, max_frames=4):\n", + " os.makedirs(output_folder, exist_ok=True)\n", + " videos = glob.glob(os.path.join(video_folder, \"*.mp4\"))\n", + " \n", + " for vid_path in tqdm(videos, desc=f\"Extracting {label}\"):\n", + " vid_name = 
os.path.basename(vid_path).replace('.mp4','')\n", + " cap = cv2.VideoCapture(vid_path)\n", + " count = 0\n", + " while cap.isOpened() and count < max_frames:\n", + " ret, frame = cap.read()\n", + " if not ret: break\n", + " frame = cv2.resize(frame, (224, 224))\n", + " out_path = os.path.join(output_folder, f\"{vid_name}_f{count}.jpg\")\n", + " cv2.imwrite(out_path, frame)\n", + " count += 1\n", + " cap.release()\n", + "\n", + "# Extract Real\n", + "extract_frames('./data/original_sequences/youtube/c40/videos', './dataset/real', 'real')\n", + "\n", + "# Extract Fakes\n", + "fakes = ['Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures']\n", + "for f in fakes:\n", + " extract_frames(f'./data/manipulated_sequences/{f}/c40/videos', './dataset/fake', 'fake')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b79cdd85", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from datasets import load_dataset\n", + "from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer\n", + "import torch\n", + "\n", + "# 1. Load Dataset\n", + "dataset = load_dataset('imagefolder', data_dir='./dataset')\n", + "# Split into train/validation\n", + "dataset = dataset['train'].train_test_split(test_size=0.1)\n", + "\n", + "# 2. 
Preprocessor\n", + "model_name_or_path = 'google/vit-base-patch16-224-in21k'\n", + "processor = ViTImageProcessor.from_pretrained(model_name_or_path)\n", + "\n", + "def transform(example_batch):\n", + " # Take a list of PIL images and turn them to pixel values\n", + " inputs = processor([x.convert(\"RGB\") for x in example_batch['image']], return_tensors='pt')\n", + " inputs['labels'] = example_batch['label']\n", + " return inputs\n", + "\n", + "prepared_ds = dataset.with_transform(transform)\n", + "\n", + "def collate_fn(batch):\n", + " return {\n", + " 'pixel_values': torch.stack([x['pixel_values'] for x in batch]),\n", + " 'labels': torch.tensor([x['labels'] for x in batch])\n", + " }\n", + "\n", + "# 3. Load Model\n", + "labels = dataset['train'].features['label'].names\n", + "model = ViTForImageClassification.from_pretrained(\n", + " model_name_or_path,\n", + " num_labels=len(labels),\n", + " id2label={str(i): c for i, c in enumerate(labels)},\n", + " label2id={c: str(i) for i, c in enumerate(labels)}\n", + ")\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir=\"./vit-deepshield\",\n", + " per_device_train_batch_size=16,\n", + " eval_strategy=\"steps\",\n", + " num_train_epochs=3,\n", + " fp16=True, # Mixed precision for speed\n", + " save_steps=100,\n", + " eval_steps=100,\n", + " logging_steps=10,\n", + " learning_rate=2e-4,\n", + " save_total_limit=2,\n", + " remove_unused_columns=False,\n", + " push_to_hub=False,\n", + " load_best_model_at_end=True,\n", + ")\n", + "\n", + "import evaluate\n", + "metric = evaluate.load(\"accuracy\")\n", + "def compute_metrics(p):\n", + " return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " data_collator=collate_fn,\n", + " compute_metrics=compute_metrics,\n", + " train_dataset=prepared_ds[\"train\"],\n", + " eval_dataset=prepared_ds[\"test\"],\n", + ")\n", + "\n", + "# 4. 
# Base image with Python 3.10
FROM python:3.10-slim

# Set the working directory
WORKDIR /app

# Install system dependencies required for OpenCV, PyTorch, etc.
# NOTE: `libgl1-mesa-glx` was only a transitional package and is no longer
# shipped by the Debian release that current python:3.10-slim images are
# built on, so installing it fails with "no installation candidate".
# `libgl1` is the package that actually provides libGL.so.1 for OpenCV.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy the requirements file into the container
COPY requirements.txt .

# Install Python dependencies
# Using --no-cache-dir keeps the Docker image smaller
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the backend code
# NOTE(review): this also copies any local `.env` into the image unless a
# .dockerignore excludes it -- confirm secrets are injected at runtime instead.
COPY . .

# Create directories for models and temporary uploads if they don't exist
RUN mkdir -p /app/temp_uploads /app/models

# Expose port 7860 (This is the default port required by Hugging Face Spaces)
EXPOSE 7860

# Run the FastAPI server on port 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
# Disclaimer carried by every analysis response. Defined once so the four
# response models below cannot drift apart.
_RESPONSIBLE_AI_NOTICE = (
    "AI-based analysis may not be 100% accurate. Cross-check with trusted sources before sharing."
)


# Sensationalism score plus the individual counters reported next to it.
class SensationalismBreakdown(BaseModel):
    score: int = 0
    level: str = "Low"
    exclamation_count: int = 0
    caps_word_count: int = 0
    clickbait_matches: int = 0
    emotional_word_count: int = 0
    superlative_count: int = 0


# One matched manipulation pattern, located by character offsets.
class ManipulationIndicatorOut(BaseModel):
    pattern_type: str
    matched_text: str
    start_pos: int
    end_pos: int
    severity: str
    description: str


# Explainability payload attached to a text analysis.
class TextExplainability(BaseModel):
    fake_probability: float
    top_label: str
    all_scores: dict = {}
    keywords: list[str] = []
    sensationalism: SensationalismBreakdown = SensationalismBreakdown()
    manipulation_indicators: list[ManipulationIndicatorOut] = []
    detected_language: str = "en"  # ISO 639-1 code, e.g. "en", "hi"
    truth_override: TruthOverride | None = None


# Response body for a text analysis (media_type defaults to "text").
class TextAnalysisResponse(BaseModel):
    analysis_id: str
    record_id: int = 0
    media_type: str = "text"
    timestamp: str
    verdict: Verdict
    explainability: TextExplainability
    llm_summary: LLMExplainabilitySummary | None = None
    trusted_sources: list[TrustedSource] = []
    contradicting_evidence: list[ContradictingEvidence] = []
    processing_summary: ProcessingSummary
    responsible_ai_notice: str = _RESPONSIBLE_AI_NOTICE


# A single OCR detection: recognised text, its bounding box and confidence.
class OCRBoxOut(BaseModel):
    text: str
    bbox: list[list[int]]
    confidence: float


# A phrase flagged inside a screenshot, with its bounding box.
class SuspiciousPhraseOut(BaseModel):
    text: str
    bbox: list[list[int]]
    pattern_type: str
    severity: str
    description: str


# A layout anomaly reported for a screenshot.
class LayoutAnomalyOut(BaseModel):
    type: str
    severity: str
    description: str
    confidence: float


# Explainability payload attached to a screenshot analysis; every field has a
# default, so an empty instance is valid.
class ScreenshotExplainability(BaseModel):
    extracted_text: str = ""
    ocr_boxes: list[OCRBoxOut] = []
    fake_probability: float = 0.0
    sensationalism: SensationalismBreakdown = SensationalismBreakdown()
    suspicious_phrases: list[SuspiciousPhraseOut] = []
    layout_anomalies: list[LayoutAnomalyOut] = []
    keywords: list[str] = []
    detected_language: str = "en"
    truth_override: TruthOverride | None = None


# Response body for a screenshot analysis (media_type defaults to "screenshot").
class ScreenshotAnalysisResponse(BaseModel):
    analysis_id: str
    record_id: int = 0
    media_type: str = "screenshot"
    timestamp: str
    verdict: Verdict
    explainability: ScreenshotExplainability
    llm_summary: LLMExplainabilitySummary | None = None
    trusted_sources: list[TrustedSource] = []
    contradicting_evidence: list[ContradictingEvidence] = []
    processing_summary: ProcessingSummary
    responsible_ai_notice: str = _RESPONSIBLE_AI_NOTICE


# Explainability payload attached to an image analysis. Unlike the other
# media types, the LLM summary lives here rather than on the response.
class ImageExplainability(BaseModel):
    heatmap_base64: str = ""
    ela_base64: str = ""
    boxes_base64: str = ""
    heatmap_status: str = "success"  # success | failed | degraded
    artifact_indicators: list[ArtifactIndicator] = []
    exif: ExifSummary | None = None
    llm_summary: LLMExplainabilitySummary | None = None
    vlm_breakdown: VLMBreakdown | None = None


# Per-frame result inside a video analysis.
class FrameAnalysisOut(BaseModel):
    index: int
    timestamp_s: float
    label: str
    confidence: float
    suspicious_prob: float
    is_suspicious: bool
    has_face: bool = False
    scored: bool = False


# Aggregate statistics and per-frame results for a video analysis.
class VideoExplainability(BaseModel):
    num_frames_sampled: int
    num_face_frames: int = 0
    num_suspicious_frames: int
    mean_suspicious_prob: float
    max_suspicious_prob: float
    suspicious_ratio: float
    insufficient_faces: bool = False
    suspicious_timestamps: list[float] = []
    frames: list[FrameAnalysisOut] = []


# Response body for a video analysis (media_type defaults to "video").
class VideoAnalysisResponse(BaseModel):
    analysis_id: str
    record_id: int = 0
    media_type: str = "video"
    timestamp: str
    verdict: Verdict
    explainability: VideoExplainability
    llm_summary: LLMExplainabilitySummary | None = None
    trusted_sources: list[TrustedSource] = []
    contradicting_evidence: list[ContradictingEvidence] = []
    processing_summary: ProcessingSummary
    responsible_ai_notice: str = _RESPONSIBLE_AI_NOTICE


# Response body for an image analysis (media_type defaults to "image").
class ImageAnalysisResponse(BaseModel):
    analysis_id: str
    record_id: int = 0
    media_type: str = "image"
    timestamp: str
    verdict: Verdict
    explainability: ImageExplainability
    trusted_sources: list[TrustedSource] = []
    contradicting_evidence: list[ContradictingEvidence] = []
    processing_summary: ProcessingSummary
    responsible_ai_notice: str = _RESPONSIBLE_AI_NOTICE
def _severity_from_score(score: float) -> str:
    """Bucket a suspiciousness score into a severity label.

    Returns "high" for scores >= 0.7, "medium" for scores >= 0.4 and
    "low" for everything else.
    """
    if score >= 0.7:
        return "high"
    if score >= 0.4:
        return "medium"
    return "low"


def _clamp01(value: float) -> float:
    """Clamp *value* into the closed interval [0.0, 1.0]."""
    return max(0.0, min(1.0, value))


# ---------- 1. GAN high-frequency signature (FFT) ----------
def detect_gan_hf_artifact(pil_img: Image.Image) -> ArtifactIndicator | None:
    """Score the share of spectral energy in the high-frequency band.

    The image is converted to luminance, downsampled so its longest side is
    at most 512 px, and transformed with a centred 2-D FFT. The ratio of
    magnitude beyond half the maximum radius to total magnitude is mapped
    linearly from [0.10, 0.30] onto a [0, 1] suspiciousness score.

    Returns:
        A "gan_artifact" indicator, or None if any step raised (the failure
        is logged as a warning; this detector is best-effort).
    """
    try:
        gray = np.asarray(pil_img.convert("L"), dtype=np.float32)
        longest = max(gray.shape)
        # downsample for speed
        if longest > 512:
            import cv2

            scale = 512 / longest
            # Keep both sides >= 1 px: on extreme aspect ratios the short side
            # would otherwise truncate to 0 and cv2.resize would raise.
            new_w = max(1, int(gray.shape[1] * scale))
            new_h = max(1, int(gray.shape[0] * scale))
            gray = cv2.resize(gray, (new_w, new_h))

        mag = np.abs(np.fft.fftshift(np.fft.fft2(gray)))
        h, w = mag.shape
        cy, cx = h // 2, w // 2
        y, x = np.ogrid[:h, :w]
        r = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
        r_max = np.sqrt(cx * cx + cy * cy)
        hf_mask = r > (0.5 * r_max)

        total = float(mag.sum() + 1e-9)  # epsilon guards an all-zero spectrum
        hf = float(mag[hf_mask].sum())
        ratio = hf / total  # typically 0.05–0.20 for natural photos

        # normalize to [0,1] suspiciousness
        score = _clamp01((ratio - 0.10) / 0.20)
        return ArtifactIndicator(
            type="gan_artifact",
            severity=_severity_from_score(score),
            description=(
                f"High-frequency energy ratio {ratio:.3f} — "
                + ("elevated HF energy consistent with GAN/diffusion outputs" if score > 0.4
                   else "natural frequency falloff")
            ),
            confidence=float(score),
        )
    except Exception as e:  # noqa: BLE001
        logger.warning(f"GAN HF detection failed: {e}")
        return None


# ---------- 2. JPEG quantization table anomaly ----------
# NOTE(review): not referenced by any function in this section; kept because
# other modules may import it.
_STANDARD_Q_SUMS = {  # rough heuristic: camera JPEGs fall in these ranges
    50: (1500, 4500),
    75: (600, 2500),
    90: (200, 1000),
    95: (100, 600),
}


def detect_compression_anomaly(raw_bytes: bytes) -> ArtifactIndicator | None:
    """Inspect the JPEG quantization tables of an encoded image.

    Non-JPEG input and JPEGs without readable tables yield a low-severity
    indicator. Otherwise the score accumulates 0.4 for a table count other
    than 1 or 2, 0.3 if any table sums below 60, and 0.2 if any table sums
    above 8000.

    Returns:
        A "compression" indicator, or None if the bytes could not be
        inspected (the failure is logged as a warning).
    """
    try:
        # The context manager releases the decoder as soon as the header
        # fields have been read; the original left the image open.
        with Image.open(io.BytesIO(raw_bytes)) as img:
            fmt = img.format
            q = getattr(img, "quantization", None)

        if fmt != "JPEG":
            return ArtifactIndicator(
                type="compression",
                severity="low",
                description=f"Non-JPEG format ({fmt}); compression signature not available",
                confidence=0.1,
            )

        if not q:
            return ArtifactIndicator(
                type="compression",
                severity="low",
                description="No JPEG quantization tables readable",
                confidence=0.2,
            )

        tables = list(q.values())
        sums = [int(sum(t)) for t in tables]
        num_tables = len(tables)

        # Heuristics: very low sum → very high quality (possibly re-saved);
        # non-standard number of tables; extreme values.
        suspicious = 0.0
        reasons: list[str] = []
        if num_tables not in (1, 2):
            suspicious += 0.4
            reasons.append(f"unusual table count ({num_tables})")
        if any(s < 60 for s in sums):
            suspicious += 0.3
            reasons.append("very low quantization sums (possible re-encoding)")
        if any(s > 8000 for s in sums):
            suspicious += 0.2
            reasons.append("very high quantization sums")

        score = _clamp01(suspicious)
        desc = (
            f"JPEG Q-table sums {sums}"
            + (f"; {', '.join(reasons)}" if reasons else "; within typical camera range")
        )
        return ArtifactIndicator(
            type="compression",
            severity=_severity_from_score(score),
            description=desc,
            confidence=float(score),
        )
    except Exception as e:  # noqa: BLE001
        logger.warning(f"Compression anomaly detection failed: {e}")
        return None


# ---------- 3. Facial boundary + 4. Lighting (MediaPipe) ----------
# FaceMesh jaw/oval landmark indices (approximate face contour). Hoisted to
# module level so the sequence is not rebuilt on every call.
_JAW_IDX = (
    10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361,
    288, 397, 365, 379, 378, 400, 377, 152, 148, 176, 149,
    150, 136, 172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109,
)


def detect_face_based_artifacts(pil_img: Image.Image) -> List[ArtifactIndicator]:
    """Analyze the first detected face for boundary jitter and lighting balance.

    If a face is found, emits a "facial_boundary" indicator from the mean
    second-difference of the contour landmarks, and a "lighting" indicator
    from the spread of mean luminance across the four quadrants of the face
    crop (only when the crop is larger than 4 px on each side).

    Returns:
        0, 1, or 2 indicators. Any exception is logged and whatever was
        collected so far is returned.
    """
    results: List[ArtifactIndicator] = []
    try:
        # NOTE(review): `mp` is never used below; presumably imported so a
        # missing mediapipe install fails here -- confirm before removing.
        import mediapipe as mp  # type: ignore

        from models.model_loader import get_model_loader

        detector = get_model_loader().load_face_detector()
        rgb = np.asarray(pil_img.convert("RGB"))
        h, w = rgb.shape[:2]
        mp_result = detector.process(rgb)

        if not mp_result.multi_face_landmarks:
            return results

        landmarks = mp_result.multi_face_landmarks[0].landmark

        # ----- Jaw boundary jitter -----
        pts = np.array([(landmarks[i].x * w, landmarks[i].y * h) for i in _JAW_IDX])
        # Second-difference magnitude = local curvature jitter
        diffs = np.diff(pts, axis=0)
        seconds = np.diff(diffs, axis=0)
        jitter = float(np.linalg.norm(seconds, axis=1).mean()) / max(w, h)
        jitter_score = _clamp01((jitter - 0.003) / 0.010)
        results.append(
            ArtifactIndicator(
                type="facial_boundary",
                severity=_severity_from_score(jitter_score),
                description=(
                    f"Jaw-contour jitter {jitter:.4f} (normalized) — "
                    + ("inconsistent boundary blending detected" if jitter_score > 0.4
                       else "face boundary appears smooth")
                ),
                confidence=float(jitter_score),
            )
        )

        # ----- Lighting inconsistency (per-quadrant luminance) -----
        xs = np.array([lm.x * w for lm in landmarks])
        ys = np.array([lm.y * h for lm in landmarks])
        x0, x1 = int(max(0, xs.min())), int(min(w, xs.max()))
        y0, y1 = int(max(0, ys.min())), int(min(h, ys.max()))
        if x1 > x0 + 4 and y1 > y0 + 4:
            face_crop = rgb[y0:y1, x0:x1]
            gray = 0.299 * face_crop[..., 0] + 0.587 * face_crop[..., 1] + 0.114 * face_crop[..., 2]
            hh, ww = gray.shape
            quads = [
                gray[: hh // 2, : ww // 2],
                gray[: hh // 2, ww // 2 :],
                gray[hh // 2 :, : ww // 2],
                gray[hh // 2 :, ww // 2 :],
            ]
            means = np.array([q.mean() for q in quads if q.size > 0])
            if means.size == 4 and means.mean() > 1e-3:
                imbalance = float(means.std() / means.mean())
                lighting_score = _clamp01((imbalance - 0.08) / 0.20)
                results.append(
                    ArtifactIndicator(
                        type="lighting",
                        severity=_severity_from_score(lighting_score),
                        description=(
                            f"Luminance imbalance across face quadrants {imbalance:.3f} — "
                            + ("inconsistent lighting direction" if lighting_score > 0.4
                               else "lighting appears uniform")
                        ),
                        confidence=float(lighting_score),
                    )
                )
    except Exception as e:  # noqa: BLE001
        logger.warning(f"Face-based artifact detection failed: {e}")

    return results


# ---------- Orchestrator ----------
def scan_artifacts(pil_img: Image.Image, raw_bytes: bytes) -> List[ArtifactIndicator]:
    """Run every detector and return the indicators that were produced.

    Order is fixed: GAN high-frequency, JPEG compression, then the
    face-based indicators. Detectors that return None are skipped.
    """
    single_results = (
        detect_gan_hf_artifact(pil_img),
        detect_compression_anomaly(raw_bytes),
    )
    indicators: List[ArtifactIndicator] = [ind for ind in single_results if ind is not None]
    indicators.extend(detect_face_based_artifacts(pil_img))
    return indicators
def _encode_pw(plain: str) -> bytes:
    """Return the UTF-8 bytes of *plain*, cut to bcrypt's 72-byte limit."""
    # bcrypt truncates to 72 bytes silently in some builds and hard-errors in others.
    # Truncate explicitly so behavior is deterministic across versions.
    return plain.encode("utf-8")[:72]


def hash_password(plain: str) -> str:
    """Hash *plain* with bcrypt and a fresh salt; returns the hash as text."""
    return bcrypt.hashpw(_encode_pw(plain), bcrypt.gensalt()).decode("utf-8")


def verify_password(plain: str, hashed: str) -> bool:
    """Return True if *plain* matches the bcrypt hash *hashed*.

    Any error while checking (for example a malformed stored hash) is
    treated as a failed match rather than propagated.
    """
    try:
        return bcrypt.checkpw(_encode_pw(plain), hashed.encode("utf-8"))
    except Exception:
        return False


# Lazily computed hash used only to spend bcrypt time for unknown accounts.
_DUMMY_PASSWORD_HASH: str | None = None


def _dummy_password_hash() -> str:
    """Return a cached throwaway bcrypt hash (computed on first use)."""
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = hash_password("deepshield-timing-equaliser")
    return _DUMMY_PASSWORD_HASH


def create_access_token(user_id: int, email: str) -> str:
    """Issue a signed JWT for the given user.

    The payload carries `sub` (user id as a string), `email`, `iat` and an
    `exp` set `settings.JWT_EXPIRATION_MINUTES` after the current UTC time.
    """
    now = datetime.now(timezone.utc)
    payload = {
        "sub": str(user_id),
        "email": email,
        "iat": int(now.timestamp()),
        "exp": int((now + timedelta(minutes=settings.JWT_EXPIRATION_MINUTES)).timestamp()),
    }
    return jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)


def decode_token(token: str) -> dict[str, Any] | None:
    """Decode and verify *token*; returns the claims, or None on `JWTError`."""
    try:
        return jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError:
        return None


def register_user(db: Session, email: str, password: str, name: str | None) -> User:
    """Create and persist a user with a normalised email and hashed password.

    Raises:
        Whatever `db.commit()` raises (for example on a duplicate email).
        The session is rolled back first so the caller can keep using it.
    """
    email = email.strip().lower()
    user = User(email=email, password_hash=hash_password(password), name=(name or None))
    db.add(user)
    try:
        db.commit()
    except Exception:
        # A failed commit leaves the session unusable until it is rolled back.
        db.rollback()
        raise
    db.refresh(user)
    return user


def authenticate(db: Session, email: str, password: str) -> User | None:
    """Return the user matching *email* and *password*, or None.

    The email is stripped and lower-cased before lookup, mirroring
    `register_user`.
    """
    email = email.strip().lower()
    user = db.query(User).filter(User.email == email).first()
    if not user:
        # Spend the same bcrypt work as a real check so response time does
        # not reveal whether the account exists.
        verify_password(password, _dummy_password_hash())
        return None
    if not verify_password(password, user.password_hash):
        return None
    return user


def get_user(db: Session, user_id: int) -> User | None:
    """Fetch a user by primary key; returns None if no row matches."""
    return db.query(User).filter(User.id == user_id).first()
# Final verdict for an analysed item. `protected_namespaces=()` lets the
# `model_*` field names coexist with pydantic's reserved `model_` prefix.
class Verdict(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    label: str
    severity: str
    authenticity_score: int = Field(ge=0, le=100)
    model_confidence: float = Field(ge=0.0, le=1.0)
    model_label: str


# One forensic signal; `confidence` is validated to lie in [0, 1].
class ArtifactIndicator(BaseModel):
    type: str
    severity: str  # low | medium | high
    description: str
    confidence: float = Field(ge=0.0, le=1.0)


# A supporting article; `relevance_score` is validated to lie in [0, 1].
class TrustedSource(BaseModel):
    source_name: str
    title: str
    url: str
    published_at: Optional[str] = None
    relevance_score: float = Field(ge=0.0, le=1.0)


# A contradicting item; `type` defaults to "fact_check".
class ContradictingEvidence(BaseModel):
    source_name: str
    title: str
    url: str
    type: str = "fact_check"


# Records whether a fake probability was overridden and its before/after values.
class TruthOverride(BaseModel):
    applied: bool = False
    source_url: str = ""
    source_name: str = ""
    similarity: float = 0.0
    fake_prob_before: float = 0.0
    fake_prob_after: float = 0.0


# Selected EXIF fields plus the trust adjustment attributed to them.
class ExifSummary(BaseModel):
    make: Optional[str] = None
    model: Optional[str] = None
    datetime_original: Optional[str] = None
    gps_info: Optional[str] = None
    software: Optional[str] = None
    lens_model: Optional[str] = None
    trust_adjustment: int = 0  # negative = more real, positive = more fake
    trust_reason: str = ""


# LLM-written explanation. Declares `model_used`, so it needs the same
# protected-namespace opt-out as Verdict and ProcessingSummary; without it
# pydantic releases that reserve the `model_` prefix warn at import time.
class LLMExplainabilitySummary(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    paragraph: str = ""
    bullets: List[str] = []
    model_used: str = ""
    cached: bool = False


# A 0-100 component score (default 75) with free-text notes.
class VLMComponentScore(BaseModel):
    score: int = Field(ge=0, le=100, default=75)
    notes: str = ""


# Six component scores from a vision-language model. Declares `model_used`,
# hence the protected-namespace opt-out (see LLMExplainabilitySummary).
class VLMBreakdown(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    facial_symmetry: VLMComponentScore = VLMComponentScore()
    skin_texture: VLMComponentScore = VLMComponentScore()
    lighting_consistency: VLMComponentScore = VLMComponentScore()
    background_coherence: VLMComponentScore = VLMComponentScore()
    anatomy_hands_eyes: VLMComponentScore = VLMComponentScore()
    context_objects: VLMComponentScore = VLMComponentScore()
    model_used: str = ""
    cached: bool = False


# Pipeline bookkeeping returned with every analysis.
class ProcessingSummary(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    stages_completed: List[str]
    total_duration_ms: int
    model_used: str
# ----- config.py -----
# Application settings. Each attribute below is a typed field with a default;
# `model_config` points pydantic-settings at a `.env` file and tells it to
# ignore keys that are not declared here.
class Settings(BaseSettings):
    # Server
    APP_HOST: str = "0.0.0.0"
    APP_PORT: int = 8000
    DEBUG: bool = False
    CORS_ORIGINS: list[str] = ["http://localhost:5173"]

    # Database
    DATABASE_URL: str = "sqlite:///./deepshield.db"

    # File Upload
    MAX_UPLOAD_SIZE_MB: int = 100
    UPLOAD_DIR: str = "./temp_uploads"
    ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"]
    ALLOWED_VIDEO_TYPES: list[str] = ["video/mp4", "video/avi", "video/mov", "video/webm"]
    FILE_RETENTION_SECONDS: int = 300

    # AI Models
    IMAGE_MODEL_ID: str = "prithivMLmods/Deep-Fake-Detector-v2-Model"
    TEXT_MODEL_ID: str = "jy46604790/Fake-News-Bert-Detect"
    # Multilingual text model for non-English (Hindi etc.). Leave empty to fall back to TEXT_MODEL_ID.
    TEXT_MULTILANG_MODEL_ID: str = ""
    DEVICE: str = "cpu"
    PRELOAD_MODELS: bool = True  # preload models at startup

    # Phase 13: OCR language list (comma-separated ISO codes, e.g. "en,hi")
    OCR_LANGS: str = "en,hi"

    # News API
    NEWS_API_KEY: str = ""
    NEWS_API_BASE_URL: str = "https://newsdata.io/api/1/news"

    # Reports
    REPORT_DIR: str = "./temp_reports"
    REPORT_TTL_SECONDS: int = 3600  # 1h expiry

    # LLM Explainability (Phase 12)
    LLM_PROVIDER: str = "gemini"  # "gemini" | "openai"
    LLM_API_KEY: str = ""
    LLM_MODEL: str = "gemini-1.5-flash"  # or "gpt-4o-mini"

    # Auth
    # NOTE(review): the default secret is a placeholder; tokens signed with it
    # are forgeable, so deployments must override JWT_SECRET_KEY.
    JWT_SECRET_KEY: str = "change-me-in-production"
    JWT_ALGORITHM: str = "HS256"
    JWT_EXPIRATION_MINUTES: int = 1440

    model_config = SettingsConfigDict(env_file=".env", extra="ignore")


# Module-level singleton, built once at import time.
settings = Settings()


# ----- database.py -----
# Engine bound to settings.DATABASE_URL. `check_same_thread=False` is passed
# only when the URL starts with "sqlite"; other URLs get no connect args.
engine = create_engine(
    settings.DATABASE_URL,
    connect_args={"check_same_thread": False} if settings.DATABASE_URL.startswith("sqlite") else {},
)

# Session factory: sessions neither autocommit nor autoflush.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# Declarative base whose metadata `init_db` uses to create tables.
class Base(DeclarativeBase):
    pass


def get_db():
    """Yield a new session and close it when the consumer is finished."""
    db = SessionLocal()
    try:
        yield db
    finally:
        # Runs on normal completion and on error alike.
        db.close()


def init_db():
    """Create every table registered on `Base.metadata`."""
    # Imported for its side effect, presumably registering the ORM models on
    # Base before create_all runs -- the import is otherwise unused (noqa).
    from db import models  # noqa: F401
    Base.metadata.create_all(bind=engine)
+ +Expected input layout (produced by the other scripts in this package): + + data_root/ + real/ + ffpp_youtube/*.jpg # frames from FFPP original_sequences + ffhq/*.jpg # FFHQ thumbnails + + fake/ + ffpp_deepfakes/*.jpg + ffpp_face2face/*.jpg + ffpp_faceswap/*.jpg + ffpp_neuraltextures/*.jpg + ffpp_faceshifter/*.jpg + dfdc/*.jpg + +The manifest is stratified by (label, source) so FFHQ stays represented +in val/test. + +Usage: + python -m backend.training.datasets.build_manifest \ + --data ./data --out ./data/manifest.csv --seed 42 +""" +from __future__ import annotations + +import argparse +import csv +import random +from collections import defaultdict +from pathlib import Path + +IMG_EXTS = {".jpg", ".jpeg", ".png"} + + +def collect(data_root: Path) -> list[tuple[str, str, str]]: + rows: list[tuple[str, str, str]] = [] + for label in ("real", "fake"): + label_root = data_root / label + if not label_root.exists(): + continue + for source_dir in sorted(p for p in label_root.iterdir() if p.is_dir()): + for img in source_dir.rglob("*"): + if img.suffix.lower() in IMG_EXTS and img.is_file(): + rows.append((str(img.resolve()), label, source_dir.name)) + return rows + + +def split(rows: list[tuple[str, str, str]], seed: int) -> dict[str, list[tuple[str, str, str]]]: + buckets: dict[tuple[str, str], list[tuple[str, str, str]]] = defaultdict(list) + for r in rows: + buckets[(r[1], r[2])].append(r) + + rng = random.Random(seed) + out = {"train": [], "val": [], "test": []} + for key, items in buckets.items(): + rng.shuffle(items) + n = len(items) + n_train = int(0.70 * n) + n_val = int(0.15 * n) + out["train"].extend(items[:n_train]) + out["val"].extend(items[n_train : n_train + n_val]) + out["test"].extend(items[n_train + n_val :]) + return out + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--data", required=True, type=Path) + ap.add_argument("--out", required=True, type=Path) + ap.add_argument("--seed", type=int, default=42) + args = 
ap.parse_args() + + rows = collect(args.data) + if not rows: + raise SystemExit(f"No images found under {args.data}") + + splits = split(rows, args.seed) + args.out.parent.mkdir(parents=True, exist_ok=True) + with args.out.open("w", newline="", encoding="utf-8") as f: + w = csv.writer(f) + w.writerow(["path", "label", "source", "split"]) + for name, items in splits.items(): + for path, label, source in items: + w.writerow([path, label, source, name]) + + summary = {k: len(v) for k, v in splits.items()} + print(f"Manifest: {args.out}") + print(f"Totals: {summary} (overall {sum(summary.values())})") + + +if __name__ == "__main__": + main() diff --git a/datasets/download_dfdc_sample.py b/datasets/download_dfdc_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..290f639f16430ea03bec20cad71c34b4f4a3a898 --- /dev/null +++ b/datasets/download_dfdc_sample.py @@ -0,0 +1,44 @@ +"""Download a sample of the DFDC (Deepfake Detection Challenge) Preview dataset. + +The full DFDC is ~470GB; the *preview* release (~5GB, Kaggle) is enough for +diversity augmentation alongside FFPP. + +Requires the Kaggle CLI (`pip install kaggle`) and ~/.kaggle/kaggle.json. + +Usage: + python -m backend.training.datasets.download_dfdc_sample --output ./data/dfdc_preview +""" +from __future__ import annotations + +import argparse +import shutil +import subprocess +import sys +from pathlib import Path + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--output", required=True, type=Path) + ap.add_argument( + "--competition", + default="deepfake-detection-challenge", + help="Kaggle competition slug (default: deepfake-detection-challenge preview).", + ) + args = ap.parse_args() + + kaggle = shutil.which("kaggle") + if kaggle is None: + print("Kaggle CLI not found. 
Install with: pip install kaggle", file=sys.stderr) + print("Then place kaggle.json in ~/.kaggle/ (chmod 600).", file=sys.stderr) + sys.exit(2) + + args.output.mkdir(parents=True, exist_ok=True) + cmd = [kaggle, "competitions", "download", "-c", args.competition, "-p", str(args.output)] + print("Running:", " ".join(cmd)) + subprocess.run(cmd, check=True) + print(f"Downloaded to {args.output}. Unzip with: unzip *.zip") + + +if __name__ == "__main__": + main() diff --git a/datasets/download_ffhq.py b/datasets/download_ffhq.py new file mode 100644 index 0000000000000000000000000000000000000000..9aad01da57b77488e4d3113295cd2769c3376826 --- /dev/null +++ b/datasets/download_ffhq.py @@ -0,0 +1,49 @@ +"""Download the FFHQ 128x128 thumbnail subset from the official Google Drive mirror. + +Pulls up to N images (default 10k) into the `real` bucket of the training set. +Falls back to the NVlabs 'ffhq-dataset' helper if available; otherwise expects +user to run the manual download once. + +Usage: + python -m backend.training.datasets.download_ffhq --output ./data/real/ffhq -n 10000 +""" +from __future__ import annotations + +import argparse +import shutil +import subprocess +import sys +from pathlib import Path + + +def try_nvlabs_helper(output: Path, num: int) -> bool: + """Prefer the official ffhq-dataset downloader if installed.""" + helper = shutil.which("ffhq-dataset") + if helper is None: + return False + cmd = [helper, "--json", "ffhq-dataset-v2.json", "--thumbs", "--num_threads", "4"] + print("Running:", " ".join(cmd)) + subprocess.run(cmd, cwd=output, check=False) + return True + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--output", required=True, type=Path) + ap.add_argument("-n", "--num", type=int, default=10000) + args = ap.parse_args() + args.output.mkdir(parents=True, exist_ok=True) + + if try_nvlabs_helper(args.output, args.num): + return + + print("[!] 
`ffhq-dataset` helper not installed.") + print(" Install via: pip install ffhq-dataset (requires gdown)") + print(" Or download thumbnails128x128.zip manually from:") + print(" https://github.com/NVlabs/ffhq-dataset") + print(f" Extract into: {args.output}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/datasets/extract_frames.py b/datasets/extract_frames.py new file mode 100644 index 0000000000000000000000000000000000000000..28bebebb62c2676744a72e1f68ffdf4f5d74c2f4 --- /dev/null +++ b/datasets/extract_frames.py @@ -0,0 +1,90 @@ +"""Convert FFPP / DFDC videos -> 16 sampled frames at 224x224 RGB. + +Usage: + python -m backend.training.datasets.extract_frames \ + --input ./ffpp_data/original_sequences/youtube/raw/videos \ + --output ./ffpp_data/frames/real \ + --label real --frames 16 --size 224 +""" +from __future__ import annotations + +import argparse +import csv +from pathlib import Path + +import cv2 +import numpy as np +from tqdm import tqdm + + +def sample_frame_indices(total: int, n: int) -> list[int]: + if total <= 0: + return [] + if total <= n: + return list(range(total)) + step = total / float(n) + return [min(total - 1, int(step * i + step / 2)) for i in range(n)] + + +def extract_from_video(path: Path, out_dir: Path, n: int, size: int) -> int: + cap = cv2.VideoCapture(str(path)) + if not cap.isOpened(): + return 0 + total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = set(sample_frame_indices(total, n)) + out_dir.mkdir(parents=True, exist_ok=True) + + saved = 0 + i = 0 + while True: + ok, frame = cap.read() + if not ok: + break + if i in indices: + frame = cv2.resize(frame, (size, size), interpolation=cv2.INTER_AREA) + cv2.imwrite(str(out_dir / f"{path.stem}_f{i:06d}.jpg"), frame, [cv2.IMWRITE_JPEG_QUALITY, 95]) + saved += 1 + i += 1 + cap.release() + return saved + + +def main() -> None: + ap = argparse.ArgumentParser(description="Sample N frames per video and resize.") + ap.add_argument("--input", required=True, type=Path, 
help="Directory of .mp4 videos (recursive).") + ap.add_argument("--output", required=True, type=Path, help="Directory to write .jpg frames.") + ap.add_argument("--label", required=True, choices=["real", "fake"], help="Label tag for manifest.") + ap.add_argument("--frames", type=int, default=16) + ap.add_argument("--size", type=int, default=224) + ap.add_argument("--manifest", type=Path, default=None, help="Optional CSV manifest append path.") + args = ap.parse_args() + + videos = [p for p in args.input.rglob("*.mp4")] + if not videos: + print(f"No .mp4 found under {args.input}") + return + + rows: list[tuple[str, str, str]] = [] + total_frames = 0 + for vid in tqdm(videos, desc=f"extract[{args.label}]"): + rel_out = args.output / vid.stem + saved = extract_from_video(vid, rel_out, args.frames, args.size) + total_frames += saved + if args.manifest is not None: + for jpg in rel_out.glob("*.jpg"): + rows.append((str(jpg), args.label, vid.stem)) + + if args.manifest is not None and rows: + args.manifest.parent.mkdir(parents=True, exist_ok=True) + new_file = not args.manifest.exists() + with args.manifest.open("a", newline="", encoding="utf-8") as f: + w = csv.writer(f) + if new_file: + w.writerow(["path", "label", "source_video"]) + w.writerows(rows) + + print(f"Done. 
Videos: {len(videos)}, frames written: {total_frames}") + + +if __name__ == "__main__": + main() diff --git a/datasets/procure_all.ps1 b/datasets/procure_all.ps1 new file mode 100644 index 0000000000000000000000000000000000000000..4edafaffd1c99062de0244d2c7a3db1cf6813cc6 --- /dev/null +++ b/datasets/procure_all.ps1 @@ -0,0 +1,40 @@ +# Phase 11.1 orchestrator for Windows (PowerShell) +$ErrorActionPreference = "Stop" + +$ROOT = if ($env:ROOT) { $env:ROOT } else { ".\data" } +$FFPP = if ($env:FFPP) { $env:FFPP } else { ".\ffpp_data" } + +New-Item -ItemType Directory -Force -Path "$ROOT\real" | Out-Null +New-Item -ItemType Directory -Force -Path "$ROOT\fake" | Out-Null +New-Item -ItemType Directory -Force -Path $FFPP | Out-Null + +Write-Host "1. FaceForensics++ (highly compressed c40, 10 videos only) -- requires TOS keypress" +python backend\scripts\download_ffpp.py $FFPP -d all -c c40 -t videos -n 10 + +Write-Host "2. Frame extraction: real (original youtube)" +python -m backend.training.datasets.extract_frames ` + --input "$FFPP\original_sequences\youtube\c40\videos" ` + --output "$ROOT\real\ffpp_youtube" --label real --frames 4 --size 224 + +Write-Host "3. Frame extraction: fakes (each manipulation family)" +$Families = @("Deepfakes", "Face2Face", "FaceSwap", "NeuralTextures", "FaceShifter") +foreach ($fam in $Families) { + $famLower = $fam.ToLower() + python -m backend.training.datasets.extract_frames ` + --input "$FFPP\manipulated_sequences\$fam\c40\videos" ` + --output "$ROOT\fake\ffpp_$famLower" --label fake --frames 4 --size 224 +} + +Write-Host "4. FFHQ thumbnails (real - limited to 100 items)" +python -m backend.training.datasets.download_ffhq --output "$ROOT\real\ffhq" -n 100 + + +Write-Host "6. DFDC preview sample (fake+real)" +python -m backend.training.datasets.download_dfdc_sample --output "$ROOT\_dfdc_raw" +Write-Host "NOTE: You will need to manually unzip + sort DFDC into $ROOT\fake\dfdc and $ROOT\real\dfdc" + +Write-Host "7. 
Build manifest" +python -m backend.training.datasets.build_manifest ` + --data $ROOT --out "$ROOT\manifest.csv" --seed 42 + +Write-Host "Phase 11.1 complete. See $ROOT\manifest.csv" diff --git a/datasets/procure_all.sh b/datasets/procure_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..fa2f94ecf2959a92df56d31572c4c26c9256b0d1 --- /dev/null +++ b/datasets/procure_all.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# Phase 11.1 orchestrator: download + frame-extract + manifest. +# Total disk target: ~120k labeled images. Expect 60-80GB intermediate, ~30GB frames. + +set -euo pipefail + +ROOT="${ROOT:-./data}" +FFPP="${FFPP:-./ffpp_data}" +mkdir -p "$ROOT/real" "$ROOT/fake" "$FFPP" + +# 1. FaceForensics++ (raw, videos) -- requires TOS keypress +python backend/scripts/download_ffpp.py "$FFPP" -d all -c raw -t videos + +# 2. Frame extraction: real (original youtube) +python -m backend.training.datasets.extract_frames \ + --input "$FFPP/original_sequences/youtube/raw/videos" \ + --output "$ROOT/real/ffpp_youtube" --label real --frames 16 --size 224 + +# 3. Frame extraction: fakes (each manipulation family) +for fam in Deepfakes Face2Face FaceSwap NeuralTextures FaceShifter; do + python -m backend.training.datasets.extract_frames \ + --input "$FFPP/manipulated_sequences/$fam/raw/videos" \ + --output "$ROOT/fake/ffpp_${fam,,}" --label fake --frames 16 --size 224 +done + +# 4. FFHQ thumbnails (real) +python -m backend.training.datasets.download_ffhq --output "$ROOT/real/ffhq" -n 10000 + +# 6. DFDC preview sample (fake+real) -- needs Kaggle creds +python -m backend.training.datasets.download_dfdc_sample --output "$ROOT/_dfdc_raw" +# NOTE: unzip + sort into $ROOT/fake/dfdc and $ROOT/real/dfdc per DFDC metadata.json + +# 7. Build manifest +python -m backend.training.datasets.build_manifest \ + --data "$ROOT" --out "$ROOT/manifest.csv" --seed 42 + +echo "Phase 11.1 complete. 
See $ROOT/manifest.csv" diff --git a/deepshield_13_5bcf1328.pdf b/deepshield_13_5bcf1328.pdf new file mode 100644 index 0000000000000000000000000000000000000000..dee4c670b4e4e2064488fe79f0eec4ac48bf39ea --- /dev/null +++ b/deepshield_13_5bcf1328.pdf @@ -0,0 +1,148 @@ +%PDF-1.4 +% ReportLab Generated PDF document (opensource) +1 0 obj +<< +/F1 2 0 R /F2 3 0 R /F3 5 0 R +>> +endobj +2 0 obj +<< +/BaseFont /Helvetica /Encoding /WinAnsiEncoding /Name /F1 /Subtype /Type1 /Type /Font +>> +endobj +3 0 obj +<< +/BaseFont /Helvetica-Bold /Encoding /WinAnsiEncoding /Name /F2 /Subtype /Type1 /Type /Font +>> +endobj +4 0 obj +<< +/Contents 18 0 R /MediaBox [ 0 0 595.2756 841.8898 ] /Parent 17 0 R /Resources << +/Font 1 0 R /ProcSet [ /PDF /Text /ImageB /ImageC /ImageI ] +>> /Rotate 0 /Trans << + +>> + /Type /Page +>> +endobj +5 0 obj +<< +/BaseFont /Symbol /Name /F3 /Subtype /Type1 /Type /Font +>> +endobj +6 0 obj +<< +/Contents 19 0 R /MediaBox [ 0 0 595.2756 841.8898 ] /Parent 17 0 R /Resources << +/Font 1 0 R /ProcSet [ /PDF /Text /ImageB /ImageC /ImageI ] +>> /Rotate 0 /Trans << + +>> + /Type /Page +>> +endobj +7 0 obj +<< +/Outlines 9 0 R /PageMode /UseNone /Pages 17 0 R /Type /Catalog +>> +endobj +8 0 obj +<< +/Author () /CreationDate (D:20260415181653+05'00') /Creator (\(unspecified\)) /Keywords () /ModDate (D:20260415181653+05'00') /Producer (xhtml2pdf ) + /Subject () /Title (DeepShield Analysis Report \204 7771f496-45b1-4c97-8a1a-d9d2492ca67d) /Trapped /False +>> +endobj +9 0 obj +<< +/Count 3 /First 10 0 R /Last 10 0 R /Type /Outlines +>> +endobj +10 0 obj +<< +/Count -4 /Dest [ 4 0 R /Fit ] /First 11 0 R /Last 16 0 R /Parent 9 0 R /Title (DeepShield Analysis Report) +>> +endobj +11 0 obj +<< +/Dest [ 4 0 R /Fit ] /Next 12 0 R /Parent 10 0 R /Title (Verdict) +>> +endobj +12 0 obj +<< +/Count -2 /Dest [ 4 0 R /Fit ] /First 13 0 R /Last 14 0 R /Next 15 0 R /Parent 10 0 R + /Prev 11 0 R /Title (Text Classification) +>> +endobj +13 0 obj +<< +/Dest [ 4 0 R /Fit ] /Next 
14 0 R /Parent 12 0 R /Title (Sensationalism Signals) +>> +endobj +14 0 obj +<< +/Dest [ 4 0 R /Fit ] /Parent 12 0 R /Prev 13 0 R /Title (Extracted Keywords) +>> +endobj +15 0 obj +<< +/Dest [ 4 0 R /Fit ] /Next 16 0 R /Parent 10 0 R /Prev 12 0 R /Title (Trusted Source Cross-Reference \(1\)) +>> +endobj +16 0 obj +<< +/Dest [ 6 0 R /Fit ] /Parent 10 0 R /Prev 15 0 R /Title (Processing Summary) +>> +endobj +17 0 obj +<< +/Count 2 /Kids [ 4 0 R 6 0 R ] /Type /Pages +>> +endobj +18 0 obj +<< +/Filter [ /ASCII85Decode /FlateDecode ] /Length 1750 +>> +stream +Gb"/(9lo&I&A@sBlm4G[Acr2Y4p^$ca2t\gAsuiHo\c,I9gURE8lSA3M>qu?,XkR;()9nE&%0G$"Ts\%gUFdJ0E[3iXSb#I!k]Slq-+&^_fu5V&-:f'>`[5155TjpXI_!]U"iQd1qrcX0jNK021sk.K_S`f[kfkaR[pr2$LLU)UX&`3>7R17rJ3t':B_<4Kk*Grr8\a:5/Z<<[I]mbfHq28c@Y+3O)t)0k@mu0K^fiq^N*(u.%T.'jlS?L^o+>>SgBV8H:sX>5A0-l`)&\h4Lk6L5I=)ArV#_bh%^>M_c,"jSErfH[2A&CfKtLn_&K3h)!u;:i'6.H*(apE@/QWkIgF*OaTZ"ZT=me'_?iN-hL[(uHeb"'/B!\/7d068ieW>Y3P8NcsU#;"%eOe_!^-"Xsc?9a'H,u4"nMEm$3F[>c1S8J!`Sh;Ye8pG>de>ac3KpI*&j-(`*[@OB&i#OgJSl=(I-';4Vs.^rc%L+kt99^Gd]mfUsWoLD02jLH*WUl.Pb(oF^j?7RUN!m&Us22M!@Ald**8+J._-f-FEVm$t<`HO6GNqd_[bhJ&8qK0d-ZKt;EB60ud0-2Z:*Z]IT(dG)'7QU\#u^ecY/FgdnO#RWf_=Js*t;iiO?'fQ:g&@nC/Xhu.;&o1b+?_6-Z%i4;1H5GAUag0*4LfL'2;Sl`["O/H6p>jU\SO4%Ffq^-']muUp/PKbuj>J71&Mh5t,WF_k&]O@P+do^;.WV"r6Kkb#5`,aF$-adPdc+'072](pse[q;.^?I#Q#kci1Qr9Z_U:Q_lQ53n!nIBHrchNfMeP-HF*=<22XdSrZ8j>sP4CR1SEP\Ge.aCh(VEW.)F'<]`"gVnaq<<]K,.uCIMlUqSgV3UTe;V8("S^2/7e`3>4E]],alEY#@T-dG.(=/^7(s[bh3%omN/'WKl<"q_K`T7$VrMt.GfckX6]1EfAB]1F6o6g>\:2Etf)rD.XNrRc2pgl"Hr<(1MCd%~>endstream +endobj +19 0 obj +<< +/Filter [ /ASCII85Decode /FlateDecode ] /Length 1251 +>> +stream 
+Gau`R;01GN&:Vs/fU'm&SZsB\Z>@pd[^l$Ne'"!6Hco+&(^1n.H%;Q95P;kU[/"Vgs.N%@'=M6kAJN1afF&?E_+rA+1KE+S:4],1QpOr^qg01e<#d,;@\e=!\1-*,1T[41J&^DSg86dC5.#&+tMiZhie$%p]f=sWJ!9ni#^ZR?Gp5lVJY,MW+]dGA*V*5[2WS\gs>9t%t32b/^W)[_+r7&3kOLD>8WTI508QU_ZkVRb*l"j_,ie@Wk/$,J'=rjAsRr^aIAp,g4N\@rcW@_7fV)G7.f:C\2aDCnK2"(-Yh-fNKV4ogPJ_Bbno/AG^W)=l`02mHESBSd,2MW2Q,8S^O,7f_^Pj+'$c\[n!'TZ'8A[[6$M/6Vlo9egXU318J0Zl;rXSYgM=-\-3TecfRc]m]FKNI.=E4amT3\PSaWQi;TtrPVN"#t`E;bkNM&M.:/OC)MK2$$?Jp$`SY/%t"jbj6*+.%6.71qjEsp)j@\0#RIF/1!&^q"O7Ou;8DL^2(?$>18.AWa`H$Fi,Ak&SQPl+Y^;rG>nArp/_q%9B[r]_;\_^p'[__7OH7)iuf]c[rld?RB/MrP3T8Xk7VY%=qG1""FA,mioCp,lF3^-AZtKRg/NFX>&kA^rZpnFAendstream +endobj +xref +0 20 +0000000000 65535 f +0000000061 00000 n +0000000112 00000 n +0000000219 00000 n +0000000331 00000 n +0000000536 00000 n +0000000613 00000 n +0000000818 00000 n +0000000903 00000 n +0000001223 00000 n +0000001296 00000 n +0000001426 00000 n +0000001514 00000 n +0000001667 00000 n +0000001770 00000 n +0000001869 00000 n +0000001999 00000 n +0000002098 00000 n +0000002164 00000 n +0000004006 00000 n +trailer +<< +/ID +[<8e273c2672d813e3cd44109eb1edd604><8e273c2672d813e3cd44109eb1edd604>] +% ReportLab generated PDF document -- digest (opensource) + +/Info 8 0 R +/Root 7 0 R +/Size 20 +>> +startxref +5349 +%%EOF diff --git a/deps.py b/deps.py new file mode 100644 index 0000000000000000000000000000000000000000..776c7ceb8b1184c3e0e03627b7757a6b34e2bc7b --- /dev/null +++ b/deps.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from fastapi import Depends, Header, HTTPException, status +from sqlalchemy.orm import Session + +from db.database import get_db +from db.models import User +from services.auth_service import decode_token, get_user + + +def _extract_bearer(authorization: str | None) -> str | None: + if not authorization: + return None + parts = authorization.split() + if len(parts) != 2 or parts[0].lower() != "bearer": + return None + return parts[1] + + +def get_current_user( + authorization: str | None = 
Header(default=None), + db: Session = Depends(get_db), +) -> User: + token = _extract_bearer(authorization) + if not token: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing bearer token") + payload = decode_token(token) + if not payload or "sub" not in payload: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired token") + user = get_user(db, int(payload["sub"])) + if not user: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found") + return user + + +def optional_current_user( + authorization: str | None = Header(default=None), + db: Session = Depends(get_db), +) -> User | None: + token = _extract_bearer(authorization) + if not token: + return None + payload = decode_token(token) + if not payload or "sub" not in payload: + return None + return get_user(db, int(payload["sub"])) diff --git a/download_ffpp.py b/download_ffpp.py new file mode 100644 index 0000000000000000000000000000000000000000..4b86da46b8afaa685f312a5e6617e8dd75add856 --- /dev/null +++ b/download_ffpp.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python +""" Downloads FaceForensics++ and Deep Fake Detection public data release +Example usage: + see -h or https://github.com/ondyari/FaceForensics +""" +# -*- coding: utf-8 -*- +import argparse +import os +import urllib +import urllib.request +import tempfile +import time +import sys +import json +import random +from tqdm import tqdm +from os.path import join + + +# URLs and filenames +FILELIST_URL = 'misc/filelist.json' +DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json' +DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',] + +# Parameters +DATASETS = { + 'original_youtube_videos': 'misc/downloaded_youtube_videos.zip', + 'original_youtube_videos_info': 'misc/downloaded_youtube_videos_info.zip', + 'original': 'original_sequences/youtube', + 'DeepFakeDetection_original': 'original_sequences/actors', + 'Deepfakes': 'manipulated_sequences/Deepfakes', + 'DeepFakeDetection': 
'manipulated_sequences/DeepFakeDetection', + 'Face2Face': 'manipulated_sequences/Face2Face', + 'FaceShifter': 'manipulated_sequences/FaceShifter', + 'FaceSwap': 'manipulated_sequences/FaceSwap', + 'NeuralTextures': 'manipulated_sequences/NeuralTextures' + } +ALL_DATASETS = ['original', 'DeepFakeDetection_original', 'Deepfakes', + 'DeepFakeDetection', 'Face2Face', 'FaceShifter', 'FaceSwap', + 'NeuralTextures'] +COMPRESSION = ['raw', 'c23', 'c40'] +TYPE = ['videos', 'masks', 'models'] +SERVERS = ['EU', 'EU2', 'CA'] + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Downloads FaceForensics v2 public data release.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('output_path', type=str, help='Output directory.') + parser.add_argument('-d', '--dataset', type=str, default='all', + help='Which dataset to download, either pristine or ' + 'manipulated data or the downloaded youtube ' + 'videos.', + choices=list(DATASETS.keys()) + ['all'] + ) + parser.add_argument('-c', '--compression', type=str, default='raw', + help='Which compression degree. All videos ' + 'have been generated with h264 with a varying ' + 'codec. Raw (c0) videos are lossless compressed.', + choices=COMPRESSION + ) + parser.add_argument('-t', '--type', type=str, default='videos', + help='Which file type, i.e. videos, masks, for our ' + 'manipulation methods, models, for Deepfakes.', + choices=TYPE + ) + parser.add_argument('-n', '--num_videos', type=int, default=None, + help='Select a number of videos number to ' + "download if you don't want to download the full" + ' dataset.') + parser.add_argument('--server', type=str, default='EU', + help='Server to download the data from. 
If you ' + 'encounter a slow download speed, consider ' + 'changing the server.', + choices=SERVERS + ) + args = parser.parse_args() + + # URLs + server = args.server + if server == 'EU': + server_url = 'http://canis.vc.in.tum.de:8100/' + elif server == 'EU2': + server_url = 'http://kaldir.vc.in.tum.de/faceforensics/' + elif server == 'CA': + server_url = 'http://falas.cmpt.sfu.ca:8100/' + else: + raise Exception('Wrong server name. Choices: {}'.format(str(SERVERS))) + args.tos_url = server_url + 'webpage/FaceForensics_TOS.pdf' + args.base_url = server_url + 'v3/' + args.deepfakes_model_url = server_url + 'v3/manipulated_sequences/' + \ + 'Deepfakes/models/' + + return args + + +def download_files(filenames, base_url, output_path, report_progress=True): + os.makedirs(output_path, exist_ok=True) + if report_progress: + filenames = tqdm(filenames) + for filename in filenames: + download_file(base_url + filename, join(output_path, filename)) + + +def reporthook(count, block_size, total_size): + global start_time + if count == 0: + start_time = time.time() + return + duration = time.time() - start_time + progress_size = int(count * block_size) + speed = int(progress_size / (1024 * duration)) + percent = int(count * block_size * 100 / total_size) + sys.stdout.write("\rProgress: %d%%, %d MB, %d KB/s, %d seconds passed" % + (percent, progress_size / (1024 * 1024), speed, duration)) + sys.stdout.flush() + + +def download_file(url, out_file, report_progress=False): + out_dir = os.path.dirname(out_file) + if not os.path.isfile(out_file): + fh, out_file_tmp = tempfile.mkstemp(dir=out_dir) + f = os.fdopen(fh, 'w') + f.close() + if report_progress: + urllib.request.urlretrieve(url, out_file_tmp, + reporthook=reporthook) + else: + urllib.request.urlretrieve(url, out_file_tmp) + os.rename(out_file_tmp, out_file) + else: + tqdm.write('WARNING: skipping download of existing file ' + out_file) + + +def main(args): + # TOS + print('By pressing any key to continue you confirm that you 
have agreed '\ + 'to the FaceForensics terms of use as described at:') + print(args.tos_url) + print('***') + print('Press any key to continue, or CTRL-C to exit.') + _ = input('') + + # Extract arguments + c_datasets = [args.dataset] if args.dataset != 'all' else ALL_DATASETS + c_type = args.type + c_compression = args.compression + num_videos = args.num_videos + output_path = args.output_path + os.makedirs(output_path, exist_ok=True) + + # Check for special dataset cases + for dataset in c_datasets: + dataset_path = DATASETS[dataset] + # Special cases + if 'original_youtube_videos' in dataset: + # Here we download the original youtube videos zip file + print('Downloading original youtube videos.') + if not 'info' in dataset_path: + print('Please be patient, this may take a while (~40gb)') + suffix = '' + else: + suffix = 'info' + download_file(args.base_url + '/' + dataset_path, + out_file=join(output_path, + 'downloaded_videos{}.zip'.format( + suffix)), + report_progress=True) + return + + # Else: regular datasets + print('Downloading {} of dataset "{}"'.format( + c_type, dataset_path + )) + + # Get filelists and video lenghts list from server + if 'DeepFakeDetection' in dataset_path or 'actors' in dataset_path: + filepaths = json.loads(urllib.request.urlopen(args.base_url + '/' + + DEEPFEAKES_DETECTION_URL).read().decode("utf-8")) + if 'actors' in dataset_path: + filelist = filepaths['actors'] + else: + filelist = filepaths['DeepFakesDetection'] + elif 'original' in dataset_path: + # Load filelist from server + file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' + + FILELIST_URL).read().decode("utf-8")) + filelist = [] + for pair in file_pairs: + filelist += pair + else: + # Load filelist from server + file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' + + FILELIST_URL).read().decode("utf-8")) + # Get filelist + filelist = [] + for pair in file_pairs: + filelist.append('_'.join(pair)) + if c_type != 'models': + 
filelist.append('_'.join(pair[::-1])) + # Maybe limit number of videos for download + if num_videos is not None and num_videos > 0: + print('Downloading the first {} videos'.format(num_videos)) + filelist = filelist[:num_videos] + + # Server and local paths + dataset_videos_url = args.base_url + '{}/{}/{}/'.format( + dataset_path, c_compression, c_type) + dataset_mask_url = args.base_url + '{}/{}/videos/'.format( + dataset_path, 'masks', c_type) + + if c_type == 'videos': + dataset_output_path = join(output_path, dataset_path, c_compression, + c_type) + print('Output path: {}'.format(dataset_output_path)) + filelist = [filename + '.mp4' for filename in filelist] + download_files(filelist, dataset_videos_url, dataset_output_path) + elif c_type == 'masks': + dataset_output_path = join(output_path, dataset_path, c_type, + 'videos') + print('Output path: {}'.format(dataset_output_path)) + if 'original' in dataset: + if args.dataset != 'all': + print('Only videos available for original data. Aborting.') + return + else: + print('Only videos available for original data. ' + 'Skipping original.\n') + continue + if 'FaceShifter' in dataset: + print('Masks not available for FaceShifter. Aborting.') + return + filelist = [filename + '.mp4' for filename in filelist] + download_files(filelist, dataset_mask_url, dataset_output_path) + + # Else: models for deepfakes + else: + if dataset != 'Deepfakes' and c_type == 'models': + print('Models only available for Deepfakes. 
Aborting') + return + dataset_output_path = join(output_path, dataset_path, c_type) + print('Output path: {}'.format(dataset_output_path)) + + # Get Deepfakes models + for folder in tqdm(filelist): + folder_filelist = DEEPFAKES_MODEL_NAMES + + # Folder paths + folder_base_url = args.deepfakes_model_url + folder + '/' + folder_dataset_output_path = join(dataset_output_path, + folder) + download_files(folder_filelist, folder_base_url, + folder_dataset_output_path, + report_progress=False) # already done + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/ela_service.py b/ela_service.py new file mode 100644 index 0000000000000000000000000000000000000000..e937502d11a21c347c224611a155047d8c88bbfc --- /dev/null +++ b/ela_service.py @@ -0,0 +1,88 @@ +"""Error Level Analysis (ELA) — Phase 12.1 + +Re-saves an image at a fixed JPEG quality and diffs against the original to reveal +per-pixel manipulation artifacts. Regions that were recently edited will show +higher error levels than untouched areas. +""" + +from __future__ import annotations + +import base64 +import io + +import cv2 +import numpy as np +from loguru import logger +from PIL import Image + + +def _compute_ela(pil_img: Image.Image, quality: int = 90, scale: float = 15.0) -> np.ndarray: + """Return an ELA difference map as a uint8 (H,W,3) RGB array. + + Args: + pil_img: Input image (any format — converted to RGB internally). + quality: JPEG re-save quality level (lower = more aggressive compression). + scale: Amplification factor for the difference (higher = more contrast). + + Returns: + Difference image as uint8 (H,W,3) array. 
+ """ + rgb = pil_img.convert("RGB") + + # Re-save at specified JPEG quality into an in-memory buffer + buf = io.BytesIO() + rgb.save(buf, format="JPEG", quality=quality) + buf.seek(0) + resaved = Image.open(buf).convert("RGB") + + original_arr = np.array(rgb, dtype=np.float32) + resaved_arr = np.array(resaved, dtype=np.float32) + + # Per-pixel absolute difference, amplified + diff = np.abs(original_arr - resaved_arr) * scale + diff = np.clip(diff, 0, 255).astype(np.uint8) + + return diff + + +def generate_ela_base64(pil_img: Image.Image, quality: int = 90, scale: float = 15.0) -> str: + """Produce a base64 data-URL PNG of the ELA difference map. + + Regions with higher error levels (brighter in the output) are more likely + to have been digitally manipulated. + """ + diff = _compute_ela(pil_img, quality=quality, scale=scale) + + buf = io.BytesIO() + Image.fromarray(diff).save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode("ascii") + + logger.info(f"ELA map generated ({diff.shape[1]}x{diff.shape[0]})") + return f"data:image/png;base64,{b64}" + + +def generate_blended_ela_base64( + pil_img: Image.Image, + gradcam_weight: float = 0.6, + ela_weight: float = 0.4, + quality: int = 90, + scale: float = 15.0, +) -> str: + """Blend Grad-CAM heatmap overlay with ELA at specified weights. + + This is a utility for the 'blended' mode — it composites the ELA + difference map on top of the original image for visual clarity. 
+ """ + rgb = pil_img.convert("RGB") + original_arr = np.array(rgb, dtype=np.float32) + ela_arr = _compute_ela(pil_img, quality=quality, scale=scale).astype(np.float32) + + # Blend: overlay ELA on the original for visual context + blended = np.clip(original_arr * 0.5 + ela_arr * 0.5, 0, 255).astype(np.uint8) + + buf = io.BytesIO() + Image.fromarray(blended).save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode("ascii") + + logger.info(f"Blended ELA generated ({blended.shape[1]}x{blended.shape[0]})") + return f"data:image/png;base64,{b64}" diff --git a/exif_service.py b/exif_service.py new file mode 100644 index 0000000000000000000000000000000000000000..61c69ef39ff455182b5b1f0ab151f0d5aac8bbff --- /dev/null +++ b/exif_service.py @@ -0,0 +1,129 @@ +"""EXIF Metadata Extraction — Phase 12.2 + +Extracts camera metadata from uploaded images and computes a trust adjustment +score: presence of authentic camera metadata lowers fake probability, while +evidence of editing software raises it. 
+""" + +from __future__ import annotations + +from typing import Optional + +from loguru import logger +from PIL import Image +from PIL.ExifTags import TAGS, GPSTAGS + +from schemas.common import ExifSummary + + +# Software strings that suggest post-processing / generation +_SUSPICIOUS_SOFTWARE = { + "adobe photoshop", "photoshop", "gimp", "affinity photo", + "stable diffusion", "midjourney", "dall-e", "comfyui", + "automatic1111", "invokeai", +} + +# Software strings that are normal camera firmware +_CAMERA_SOFTWARE = { + "ver.", "firmware", "camera", "dji", "gopro", +} + + +def _decode_gps(gps_info: dict) -> Optional[str]: + """Decode EXIF GPSInfo dict into a human-readable lat/lon string.""" + try: + def _to_decimal(values, ref): + d, m, s = [float(v) for v in values] + decimal = d + m / 60.0 + s / 3600.0 + if ref in ("S", "W"): + decimal = -decimal + return decimal + + lat = _to_decimal(gps_info.get(2, (0, 0, 0)), gps_info.get(1, "N")) + lon = _to_decimal(gps_info.get(4, (0, 0, 0)), gps_info.get(3, "E")) + return f"{lat:.6f}, {lon:.6f}" + except Exception: + return None + + +def extract_exif(pil_img: Image.Image, raw_bytes: bytes) -> ExifSummary: + """Extract EXIF metadata and compute a trust adjustment score. 
+ + Trust adjustment logic: + - Valid Make + Model + DateTimeOriginal → -15 (more likely real camera photo) + - GPS info present → -5 additional (real photos often have GPS) + - Suspicious editing software detected → +10 (more likely manipulated) + - No EXIF at all → 0 (inconclusive — many platforms strip EXIF) + """ + summary = ExifSummary() + + try: + exif_data = pil_img._getexif() + except Exception: + exif_data = None + + if not exif_data: + # Try exifread as fallback for formats Pillow doesn't handle well + try: + import exifread + from io import BytesIO + tags = exifread.process_file(BytesIO(raw_bytes), details=False) + if tags: + summary.make = str(tags.get("Image Make", "")).strip() or None + summary.model = str(tags.get("Image Model", "")).strip() or None + summary.datetime_original = str(tags.get("EXIF DateTimeOriginal", "")).strip() or None + summary.software = str(tags.get("Image Software", "")).strip() or None + summary.lens_model = str(tags.get("EXIF LensModel", "")).strip() or None + except ImportError: + logger.debug("exifread not installed, skipping fallback EXIF extraction") + except Exception as e: + logger.debug(f"exifread fallback failed: {e}") + else: + # Decode Pillow EXIF + decoded = {} + for tag_id, value in exif_data.items(): + tag_name = TAGS.get(tag_id, tag_id) + decoded[tag_name] = value + + summary.make = str(decoded.get("Make", "")).strip() or None + summary.model = str(decoded.get("Model", "")).strip() or None + summary.datetime_original = str(decoded.get("DateTimeOriginal", "")).strip() or None + summary.software = str(decoded.get("Software", "")).strip() or None + summary.lens_model = str(decoded.get("LensModel", "")).strip() or None + + # GPS + gps_raw = decoded.get("GPSInfo") + if gps_raw and isinstance(gps_raw, dict): + gps_decoded = {} + for k, v in gps_raw.items(): + gps_decoded[GPSTAGS.get(k, k)] = v + summary.gps_info = _decode_gps(gps_decoded) + + # ── Trust adjustment scoring ── + adjustment = 0 + reasons = [] + + 
def _detect_mime_by_magic(head: bytes) -> str | None:
    """Return the MIME type whose magic-byte signature prefixes *head*.

    Args:
        head: The leading bytes of the uploaded file (at least 12 bytes are
            needed to recognise WebP).

    Returns:
        The matching MIME type from ``IMAGE_MAGIC_BYTES``, or None when no
        signature matches.
    """
    for sig, mime in IMAGE_MAGIC_BYTES.items():
        if not head.startswith(sig):
            continue
        # "RIFF" is a generic container prefix (also used by WAV/AVI); a WebP
        # file carries the form type "WEBP" at bytes 8..11, right after the
        # 4-byte chunk size. Checking that exact position avoids accepting
        # any RIFF file that merely contains "WEBP" somewhere in its header.
        if mime == "image/webp" and head[8:12] != b"WEBP":
            continue
        return mime
    return None
async def save_upload_to_tempfile(
    file: UploadFile,
    allowed_mimes: Iterable[str],
    max_size_mb: int,
    suffix: str = ".mp4",
) -> tuple[str, str]:
    """Stream an UploadFile to a temp file on disk.

    MIME is taken from the client's content_type (no magic-byte check for
    videos). The caller is responsible for deleting the temp file.

    Args:
        file: Upload to persist; read in 1 MiB chunks.
        allowed_mimes: Accepted (lower-case) MIME types.
        max_size_mb: Maximum accepted size in MiB.
        suffix: File-name suffix for the temp file.

    Returns:
        ``(path, mime)`` of the written temp file.

    Raises:
        HTTPException: 415 when the declared MIME type is not allowed, 413
            when the stream exceeds ``max_size_mb``.
    """
    # Materialise once so a one-shot iterable is not consumed by the
    # membership test before it is echoed in the error message.
    allowed = tuple(allowed_mimes)
    mime = (file.content_type or "").lower()
    if mime not in allowed:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"Unsupported type '{mime}'. Allowed: {list(allowed)}",
        )

    max_bytes = max_size_mb * 1024 * 1024
    fd, path = tempfile.mkstemp(suffix=suffix, prefix="ds_vid_")
    written = 0
    try:
        with os.fdopen(fd, "wb") as out:
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                written += len(chunk)
                if written > max_bytes:
                    raise HTTPException(
                        status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                        detail=f"File too large (> {max_size_mb} MB)",
                    )
                out.write(chunk)
    except BaseException:
        # BaseException, not Exception: asyncio.CancelledError (raised when
        # the request is cancelled mid-upload) derives from BaseException and
        # would otherwise leave the partial temp file behind.
        try:
            os.unlink(path)
        except OSError:
            pass
        raise
    return path, mime
+""" + +code_install = """\ +!pip install timm transformers datasets accelerate evaluate opencv-python +""" + +code_ffpp = """\ +# We create the download script inside the Colab environment +download_script = '''#!/usr/bin/env python +import argparse +import os +import urllib.request +import tempfile +import time +import sys +import json +from tqdm import tqdm +from os.path import join + +FILELIST_URL = 'misc/filelist.json' +DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json' +DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',] +DATASETS = { + 'original': 'original_sequences/youtube', + 'Deepfakes': 'manipulated_sequences/Deepfakes', + 'Face2Face': 'manipulated_sequences/Face2Face', + 'FaceShifter': 'manipulated_sequences/FaceShifter', + 'FaceSwap': 'manipulated_sequences/FaceSwap', + 'NeuralTextures': 'manipulated_sequences/NeuralTextures' +} +ALL_DATASETS = ['original', 'Deepfakes', 'Face2Face', 'FaceShifter', 'FaceSwap', 'NeuralTextures'] +COMPRESSION = ['raw', 'c23', 'c40'] +TYPE = ['videos'] + +def download_file(url, out_file): + os.makedirs(os.path.dirname(out_file), exist_ok=True) + if not os.path.isfile(out_file): + urllib.request.urlretrieve(url, out_file) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('output_path', type=str) + parser.add_argument('-d', '--dataset', type=str, default='all') + parser.add_argument('-c', '--compression', type=str, default='c40') + parser.add_argument('-t', '--type', type=str, default='videos') + parser.add_argument('-n', '--num_videos', type=int, default=50) # Small amount for tutorial + args = parser.parse_args() + + base_url = 'http://kaldir.vc.in.tum.de/faceforensics/v3/' + + datasets = [args.dataset] if args.dataset != 'all' else ALL_DATASETS + for dataset in datasets: + dataset_path = DATASETS[dataset] + print(f'Downloading {args.compression} of {dataset}') + + file_pairs = json.loads(urllib.request.urlopen(base_url + FILELIST_URL).read().decode("utf-8")) + 
filelist = [] + if 'original' in dataset_path: + for pair in file_pairs: + filelist += pair + else: + for pair in file_pairs: + filelist.append('_'.join(pair)) + filelist.append('_'.join(pair[::-1])) + + filelist = filelist[:args.num_videos] + dataset_videos_url = base_url + f'{dataset_path}/{args.compression}/{args.type}/' + dataset_output_path = join(args.output_path, dataset_path, args.compression, args.type) + + for filename in tqdm(filelist): + download_file(dataset_videos_url + filename + ".mp4", join(dataset_output_path, filename + ".mp4")) + +if __name__ == "__main__": + main() +''' + +with open("download_ffpp.py", "w") as f: + f.write(download_script) + +!python download_ffpp.py ./data -d all -c c40 -t videos -n 50 +""" + +code_extract = """\ +import cv2 +import os +import glob +from tqdm import tqdm + +def extract_frames(video_folder, output_folder, label, max_frames=4): + os.makedirs(output_folder, exist_ok=True) + videos = glob.glob(os.path.join(video_folder, "*.mp4")) + + for vid_path in tqdm(videos, desc=f"Extracting {label}"): + vid_name = os.path.basename(vid_path).replace('.mp4','') + cap = cv2.VideoCapture(vid_path) + count = 0 + while cap.isOpened() and count < max_frames: + ret, frame = cap.read() + if not ret: break + frame = cv2.resize(frame, (224, 224)) + out_path = os.path.join(output_folder, f"{vid_name}_f{count}.jpg") + cv2.imwrite(out_path, frame) + count += 1 + cap.release() + +# Extract Real +extract_frames('./data/original_sequences/youtube/c40/videos', './dataset/real', 'real') + +# Extract Fakes +fakes = ['Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures'] +for f in fakes: + extract_frames(f'./data/manipulated_sequences/{f}/c40/videos', './dataset/fake', 'fake') +""" + +code_train = """\ +import numpy as np +from datasets import load_dataset +from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer +import torch + +# 1. 
Load Dataset +dataset = load_dataset('imagefolder', data_dir='./dataset') +# Split into train/validation +dataset = dataset['train'].train_test_split(test_size=0.1) + +# 2. Preprocessor +model_name_or_path = 'google/vit-base-patch16-224-in21k' +processor = ViTImageProcessor.from_pretrained(model_name_or_path) + +def transform(example_batch): + # Take a list of PIL images and turn them to pixel values + inputs = processor([x.convert("RGB") for x in example_batch['image']], return_tensors='pt') + inputs['labels'] = example_batch['label'] + return inputs + +prepared_ds = dataset.with_transform(transform) + +def collate_fn(batch): + return { + 'pixel_values': torch.stack([x['pixel_values'] for x in batch]), + 'labels': torch.tensor([x['labels'] for x in batch]) + } + +# 3. Load Model +labels = dataset['train'].features['label'].names +model = ViTForImageClassification.from_pretrained( + model_name_or_path, + num_labels=len(labels), + id2label={str(i): c for i, c in enumerate(labels)}, + label2id={c: str(i) for i, c in enumerate(labels)} +) + +training_args = TrainingArguments( + output_dir="./vit-deepshield", + per_device_train_batch_size=16, + eval_strategy="steps", + num_train_epochs=3, + fp16=True, # Mixed precision for speed + save_steps=100, + eval_steps=100, + logging_steps=10, + learning_rate=2e-4, + save_total_limit=2, + remove_unused_columns=False, + push_to_hub=False, + load_best_model_at_end=True, +) + +import evaluate +metric = evaluate.load("accuracy") +def compute_metrics(p): + return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids) + +trainer = Trainer( + model=model, + args=training_args, + data_collator=collate_fn, + compute_metrics=compute_metrics, + train_dataset=prepared_ds["train"], + eval_dataset=prepared_ds["test"], +) + +# 4. 
def _vit_reshape_transform(tensor: torch.Tensor, height: int = 14, width: int = 14) -> torch.Tensor:
    """Turn ViT token activations into the (B, C, H, W) layout Grad-CAM expects.

    The first token along dimension 1 (the CLS token) is discarded; the
    remaining tokens are laid out row-major on a ``height`` x ``width`` grid
    and the channel dimension is moved in front of the spatial ones.
    """
    patch_tokens = tensor[:, 1:, :]
    batch, _token_count, channels = patch_tokens.shape
    spatial = patch_tokens.reshape(batch, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W)
    return spatial.permute(0, 3, 1, 2)
def generate_heatmap_base64(
    pil_img: Image.Image,
    target_class_idx: Optional[int] = None,
) -> str:
    """Render the Grad-CAM++ heatmap over the image and return it as a base64 PNG data URL.

    Args:
        pil_img: Image to explain.
        target_class_idx: Class index to explain; passed through unchanged to
            ``_compute_gradcam_pp``.
    """
    cam_map, model_rgb = _compute_gradcam_pp(pil_img, target_class_idx)
    blended = show_cam_on_image(model_rgb, cam_map, use_rgb=True)
    dim0, dim1 = blended.shape[0], blended.shape[1]
    logger.info(f"Heatmap generated ({dim0}x{dim1})")
    return _encode_overlay_to_base64(blended)
@dataclass
class ImageClassification:
    """Outcome of a single image-classifier inference."""

    # Name of the highest-probability class; the class index rendered as a
    # string when the model config has no name for it.
    label: str
    # Softmax probability assigned to ``label``.
    confidence: float
    # Softmax probability of every class, keyed the same way as ``label``.
    all_scores: dict[str, float]
def classify_image(pil_img: Image.Image) -> ImageClassification:
    """Classify a PIL image with the shared image model.

    Returns an ``ImageClassification`` holding the top label, its softmax
    probability, and the probability of every class.
    """
    model, processor = get_model_loader().load_image_model()

    encoded = processor(images=pil_img, return_tensors="pt")
    batch = {name: tensor.to(settings.DEVICE) for name, tensor in encoded.items()}

    # Inference only: no gradients needed.
    with torch.no_grad():
        probabilities = torch.softmax(model(**batch).logits, dim=-1)[0]

    names = getattr(model.config, "id2label", {})

    scores = {}
    for idx, prob in enumerate(probabilities):
        # Fall back to the numeric index when the config has no class name.
        scores[names.get(idx, str(idx))] = float(prob.item())

    best_idx = int(torch.argmax(probabilities).item())
    best_label = names.get(best_idx, str(best_idx))
    best_prob = float(probabilities[best_idx].item())

    logger.info(f"Image classify → {best_label} @ {best_prob:.3f}")
    return ImageClassification(label=best_label, confidence=best_prob, all_scores=scores)
+""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any + +from loguru import logger + +from config import settings +from schemas.common import LLMExplainabilitySummary + +# ── In-memory cache keyed by record_id ── +_cache: dict[str, LLMExplainabilitySummary] = {} + + +_PROMPT_TEMPLATE = """\ +You are DeepShield's explainability engine. Given the JSON analysis payload below, +write a concise, accessible summary for a non-technical user. + +**Output format (strict JSON only — no markdown fences):** +{{ + "paragraph": "<2-3 sentence plain-English summary of the verdict and key signals>", + "bullets": [ + "", + "", + "" + ] +}} + +Rules: +- Be factual. State what the analysis found, not what you speculate. +- Reference specific indicators (e.g. "GAN artifact score", "EXIF metadata", "sensationalism level"). +- If the verdict is "Likely Authentic", reassure the user and explain why. +- If the verdict is "Likely Manipulated" or "Suspicious", highlight the strongest evidence. +- Keep the paragraph under 60 words. Each bullet under 20 words. 
def _parse_llm_response(raw: str) -> tuple[str, list[str]]:
    """Parse the LLM's JSON response into ``(paragraph, bullets)``.

    Tolerates responses wrapped in markdown code fences or surrounded by
    extra prose: when the text is not valid JSON as a whole, the span from
    the first ``{`` to the last ``}`` is parsed instead.

    Args:
        raw: Raw text returned by the provider (empty/None is treated as
            an unparseable response).

    Returns:
        The summary paragraph and at most three bullet strings.

    Raises:
        json.JSONDecodeError: If no JSON object can be extracted, so the
            caller's parse-failure handling applies uniformly.
    """
    text = (raw or "").strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fences / preamble / trailing chatter: isolate the outermost object.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end + 1])

    if not isinstance(parsed, dict):
        # Valid JSON but not the requested object (e.g. a bare list/string).
        raise json.JSONDecodeError("LLM response is not a JSON object", text, 0)

    paragraph = parsed.get("paragraph", "")
    if paragraph is None:
        paragraph = ""
    elif not isinstance(paragraph, str):
        paragraph = str(paragraph)

    bullets = parsed.get("bullets", [])
    if bullets is None:
        bullets = []
    elif not isinstance(bullets, list):
        bullets = [bullets]
    # The prompt asks for exactly three bullets; never return more.
    return paragraph, [str(item) for item in bullets[:3]]
async def _report_cleanup_loop():
    """Run ``cleanup_expired()`` forever, pausing ten minutes between runs.

    A failing run is logged as a warning and does not stop the loop.
    """
    pause_seconds = 600  # every 10 min
    while True:
        try:
            cleanup_expired()
        except Exception as exc:  # noqa: BLE001
            logger.warning(f"Report cleanup error: {exc}")
        await asyncio.sleep(pause_seconds)
class ModelLoader:
    """Singleton holder for lazily loaded AI models. Thread-safe lazy init.

    Every ``load_*`` method loads its model on first use and caches it on the
    singleton. Loading is serialised with double-checked locking so that
    concurrent first calls load each model only once.
    """

    _instance: Optional["ModelLoader"] = None
    # Guards creation of the singleton instance.
    _lock: Lock = Lock()
    # Guards the lazy loaders. Never held while another loader is called,
    # so a plain (non-reentrant) lock is sufficient.
    _load_lock: Lock = Lock()

    def __new__(cls) -> "ModelLoader":
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._image_model = None
                    instance._image_processor = None
                    instance._text_pipeline = None
                    instance._multilang_text_pipeline = None
                    instance._ocr_reader = None
                    instance._face_detector = None
                    instance._spacy_nlp = None
                    instance._sentence_transformer = None
                    # Publish only once every slot exists, so the lock-free
                    # fast path never sees a half-initialised instance.
                    cls._instance = instance
        return cls._instance

    @classmethod
    def get_instance(cls) -> "ModelLoader":
        """Return the process-wide singleton."""
        return cls()

    # ---------- Image (ViT deepfake classifier) ----------
    def load_image_model(self) -> Tuple[object, object]:
        """Return ``(model, processor)`` for ``settings.IMAGE_MODEL_ID``."""
        if self._image_model is None:
            with self._load_lock:
                if self._image_model is None:
                    logger.info(f"Loading image model: {settings.IMAGE_MODEL_ID}")
                    from transformers import AutoImageProcessor, AutoModelForImageClassification

                    self._image_processor = AutoImageProcessor.from_pretrained(settings.IMAGE_MODEL_ID)
                    model = AutoModelForImageClassification.from_pretrained(settings.IMAGE_MODEL_ID)
                    model.to(settings.DEVICE)
                    model.eval()
                    # Assigned last: the fast-path check keys on _image_model.
                    self._image_model = model
                    logger.info("Image model loaded")
        return self._image_model, self._image_processor

    # ---------- Text (BERT fake-news classifier — English) ----------
    def load_text_model(self):
        """Return the text-classification pipeline for ``settings.TEXT_MODEL_ID``."""
        if self._text_pipeline is None:
            with self._load_lock:
                if self._text_pipeline is None:
                    logger.info(f"Loading text model: {settings.TEXT_MODEL_ID}")
                    from transformers import pipeline

                    self._text_pipeline = pipeline(
                        "text-classification",
                        model=settings.TEXT_MODEL_ID,
                        device=0 if settings.DEVICE == "cuda" else -1,
                    )
                    logger.info("Text model loaded")
        return self._text_pipeline

    # ---------- Multilingual text model (Phase 13) ----------
    def load_multilang_text_model(self):
        """Load multilingual fake-news classifier. Falls back to English model if not configured."""
        model_id = settings.TEXT_MULTILANG_MODEL_ID
        if not model_id:
            logger.debug("TEXT_MULTILANG_MODEL_ID not set — falling back to English text model")
            # Called before _load_lock is taken, so no nested locking occurs.
            return self.load_text_model()

        if self._multilang_text_pipeline is None:
            with self._load_lock:
                if self._multilang_text_pipeline is None:
                    logger.info(f"Loading multilingual text model: {model_id}")
                    from transformers import pipeline

                    self._multilang_text_pipeline = pipeline(
                        "text-classification",
                        model=model_id,
                        device=0 if settings.DEVICE == "cuda" else -1,
                    )
                    logger.info("Multilingual text model loaded")
        return self._multilang_text_pipeline

    # ---------- spaCy NLP (Phase 13 NER) ----------
    def load_spacy_nlp(self):
        """Lazy-load spaCy English NLP model. Returns None if spaCy or its model is not installed."""
        if self._spacy_nlp is None:
            with self._load_lock:
                if self._spacy_nlp is None:
                    try:
                        import spacy  # type: ignore
                        try:
                            self._spacy_nlp = spacy.load("en_core_web_sm")
                            logger.info("spaCy en_core_web_sm loaded")
                        except OSError:
                            logger.warning(
                                "spaCy model 'en_core_web_sm' not found. "
                                "Run: python -m spacy download en_core_web_sm"
                            )
                            return None
                    except ImportError:
                        logger.warning("spaCy not installed — NER keyword extraction disabled")
                        return None
        return self._spacy_nlp

    # ---------- Sentence-Transformer (Phase 13 truth-override) ----------
    def load_sentence_transformer(self):
        """Lazy-load sentence-transformers/all-MiniLM-L6-v2. Returns None if unavailable."""
        if self._sentence_transformer is None:
            with self._load_lock:
                if self._sentence_transformer is None:
                    try:
                        from sentence_transformers import SentenceTransformer  # type: ignore
                        self._sentence_transformer = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
                        logger.info("Sentence-transformer (all-MiniLM-L6-v2) loaded")
                    except ImportError:
                        logger.warning("sentence-transformers not installed — truth-override disabled")
                        return None
                    except Exception as e:
                        logger.warning(f"Sentence-transformer load failed: {e}")
                        return None
        return self._sentence_transformer

    # ---------- OCR (EasyOCR) — Phase 13: use OCR_LANGS from config ----------
    def load_ocr_engine(self):
        """Return the EasyOCR reader for the comma-separated ``settings.OCR_LANGS`` (default ``en``)."""
        if self._ocr_reader is None:
            with self._load_lock:
                if self._ocr_reader is None:
                    langs = [l.strip() for l in settings.OCR_LANGS.split(",") if l.strip()]
                    if not langs:
                        langs = ["en"]
                    logger.info(f"Loading EasyOCR reader (langs: {langs})")
                    import easyocr  # type: ignore

                    self._ocr_reader = easyocr.Reader(
                        langs, gpu=(settings.DEVICE == "cuda"), verbose=False, download_enabled=True,
                    )
                    logger.info("EasyOCR loaded")
        return self._ocr_reader

    # ---------- Face detector (MediaPipe) ----------
    def load_face_detector(self):
        """Return the MediaPipe FaceMesh detector (static-image mode, up to 5 faces)."""
        if self._face_detector is None:
            with self._load_lock:
                if self._face_detector is None:
                    logger.info("Loading MediaPipe FaceMesh")
                    import mediapipe as mp  # type: ignore

                    self._face_detector = mp.solutions.face_mesh.FaceMesh(
                        static_image_mode=True,
                        max_num_faces=5,
                        min_detection_confidence=0.5,
                    )
                    logger.info("MediaPipe FaceMesh loaded")
        return self._face_detector

    # ---------- Preload ----------
    def preload_phase1(self) -> None:
        """Preload only what Phase 1 needs (image model)."""
        self.load_image_model()
class AnalysisRecord(Base):
    """ORM row for one completed analysis, stored in the ``analyses`` table."""

    __tablename__ = "analyses"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    # Nullable foreign key: a row may exist without an owning user.
    user_id: Mapped[int | None] = mapped_column(ForeignKey("users.id"), nullable=True)
    media_type: Mapped[str] = mapped_column(String(32), nullable=False)  # image|video|text|screenshot
    verdict: Mapped[str] = mapped_column(String(32), nullable=False)
    authenticity_score: Mapped[float] = mapped_column(nullable=False)
    # Serialized analysis payload; report_service.generate_report parses it with json.loads.
    result_json: Mapped[str] = mapped_column(Text, nullable=False)
    # datetime.utcnow yields a naive datetime (no tzinfo) expressed in UTC.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)

    # Inverse side of User.analyses.
    user: Mapped["User | None"] = relationship(back_populates="analyses")
    # uselist=False makes this a scalar: at most one Report object per analysis.
    report: Mapped["Report | None"] = relationship(back_populates="analysis", uselist=False)
100644 index 0000000000000000000000000000000000000000..8831afb27b3e5d852cf6c2838c8bd96ceca8420d --- /dev/null +++ b/news_lookup.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Tuple +from urllib.parse import urlparse + +import httpx +from loguru import logger + +from config import settings +from schemas.common import ContradictingEvidence, TrustedSource, TruthOverride + +# Trusted news domains — higher relevance boost +TRUSTED_DOMAINS = { + "reuters.com": 1.0, "apnews.com": 1.0, "bbc.com": 1.0, "bbc.co.uk": 1.0, + "theguardian.com": 0.95, "nytimes.com": 0.95, "washingtonpost.com": 0.95, + "cnn.com": 0.9, "npr.org": 0.95, "aljazeera.com": 0.9, + "thehindu.com": 0.9, "indianexpress.com": 0.9, "ndtv.com": 0.85, + "hindustantimes.com": 0.85, "pti.news": 0.95, +} + +# Fact-check / contradiction sources +FACTCHECK_DOMAINS = { + "factcheck.org", "snopes.com", "politifact.com", "fullfact.org", + "reuters.com/fact-check", "apnews.com/hub/ap-fact-check", + "factly.in", "altnews.in", "boomlive.in", "vishvasnews.com", +} + +# Domains eligible for truth-override (weight >= 0.9 per BUILD_PLAN spec) +_HIGH_TRUST_DOMAINS = {d for d, w in TRUSTED_DOMAINS.items() if w >= 0.9} + +# Thresholds per BUILD_PLAN §13.2 +_OVERRIDE_SIMILARITY_THRESHOLD = 0.6 +_OVERRIDE_FAKE_PROB_CAP = 0.15 +_OVERRIDE_FAKE_PROB_MULTIPLIER = 0.3 + + +@dataclass +class NewsLookupResult: + trusted_sources: List[TrustedSource] + contradicting_evidence: List[ContradictingEvidence] + total_articles: int + truth_override: Optional[TruthOverride] = None + + +def _domain_of(url: str) -> str: + try: + return urlparse(url).netloc.lower().replace("www.", "") + except Exception: + return "" + + +def _is_factcheck(url: str, title: str) -> bool: + dom = _domain_of(url) + if any(fc in dom for fc in FACTCHECK_DOMAINS): + return True + tl = (title or "").lower() + return any(kw in tl for kw in ("fact check", "fact-check", "debunked", "false 
def _relevance(url: str) -> float:
    """Return the trust weight for *url*'s domain, or 0.5 when it is unknown.

    The host must equal a key of ``TRUSTED_DOMAINS`` or be a subdomain of
    one (``edition.cnn.com`` -> ``cnn.com``). A plain substring test would
    also accept look-alike hosts such as ``notreuters.com`` or
    ``reuters.com.evil.net``, which must not inherit a trusted score.
    """
    labels = _domain_of(url).split(".")
    # Try the full host first, then each parent domain in turn.
    for start in range(len(labels)):
        score = TRUSTED_DOMAINS.get(".".join(labels[start:]))
        if score is not None:
            return score
    return 0.5
current_fake_prob * _OVERRIDE_FAKE_PROB_MULTIPLIER, + _OVERRIDE_FAKE_PROB_CAP, + ) + logger.info( + f"Truth-override APPLIED: fake_prob {current_fake_prob:.3f} → {new_fake_prob:.3f}" + ) + return TruthOverride( + applied=True, + source_url=best_source.url, + source_name=best_source.source_name, + similarity=round(best_sim, 4), + fake_prob_before=round(current_fake_prob, 4), + fake_prob_after=round(new_fake_prob, 4), + ) + + return TruthOverride( + applied=False, + source_url=best_source.url, + source_name=best_source.source_name, + similarity=round(best_sim, 4), + fake_prob_before=round(current_fake_prob, 4), + fake_prob_after=round(current_fake_prob, 4), + ) + + except Exception as e: + logger.warning(f"Truth-override computation failed: {e}") + return None + + +async def _fetch(q: str, country: Optional[str]) -> list[dict]: + target_country = country or "in" + params = {"apikey": settings.NEWS_API_KEY, "q": q, "language": "en", "size": 10, "country": "in"} + + try: + async with httpx.AsyncClient(timeout=8.0) as c: + r = await c.get(settings.NEWS_API_BASE_URL, params=params) + r.raise_for_status() + return (r.json() or {}).get("results") or [] + except Exception as e: + logger.warning(f"News lookup failed: {e}") + return [] + + +async def search_news( + keywords: List[str], + limit: int = 6, + country: Optional[str] = None, +) -> List[TrustedSource]: + """Back-compat simple form — returns trusted sources only.""" + result = await search_news_full(keywords, limit=limit, country=country) + return result.trusted_sources + + +async def search_news_full( + keywords: List[str], + limit: int = 6, + country: Optional[str] = None, + original_text: Optional[str] = None, + current_fake_prob: float = 0.5, +) -> NewsLookupResult: + """Full news lookup with truth-override support. + + Args: + keywords: NER-extracted or frequency-extracted keywords to search. + limit: Max sources to return. + country: Country code for newsdata.io. 
+ original_text: Input text to compare against headlines for truth-override. + current_fake_prob: Current fake probability — may be adjusted by truth-override. + """ + if not settings.NEWS_API_KEY or not keywords: + return NewsLookupResult([], [], 0) + + q = " ".join(keywords[:4]) + articles = await _fetch(q, country) + + seen: set[str] = set() + trusted: List[TrustedSource] = [] + contradictions: List[ContradictingEvidence] = [] + + for art in articles: + url = art.get("link") or "" + if not url or url in seen: + continue + seen.add(url) + + title = art.get("title") or "" + dom = _domain_of(url) + src_name = art.get("source_id") or dom or "news" + + if _is_factcheck(url, title): + contradictions.append(ContradictingEvidence( + source_name=src_name, title=title, url=url, type="fact_check", + )) + continue + + trusted.append(TrustedSource( + source_name=src_name, + title=title, + url=url, + published_at=art.get("pubDate"), + relevance_score=_relevance(url), + )) + + trusted.sort(key=lambda s: -s.relevance_score) + trusted = trusted[:limit] + + # ── Phase 13.2: Truth-override ── + truth_override = None + if original_text and trusted: + truth_override = _compute_truth_override(original_text, trusted, current_fake_prob) + + return NewsLookupResult( + trusted_sources=trusted, + contradicting_evidence=contradictions[:limit], + total_articles=len(articles), + truth_override=truth_override, + ) diff --git a/report.html b/report.html new file mode 100644 index 0000000000000000000000000000000000000000..17189b7194f885ccb8c690ae7accfe93578780a9 --- /dev/null +++ b/report.html @@ -0,0 +1,367 @@ + + + + + DeepShield Analysis Report — {{ analysis_id }} + + + + + {# ── Header ── #} + + + + + +
DeepShield + Analysis Report  ·  ID: {{ analysis_id }}
+ Media: {{ media_type | upper }}  ·  Generated: {{ generated_at }} +
+ + {# ── Verdict ── #} +

Verdict

+ + + + + {% if donut_b64 %} + + {% endif %} + +
+
{{ verdict.authenticity_score }}
+
/ 100
+
+
{{ verdict.label }}
+
Severity: {{ verdict.severity }}
+
Model: {{ verdict.model_label }}  ({{ '%.1f' | format(verdict.model_confidence * 100) }}% confidence)
+
+ score donut +
+ + {# ── LLM Explanation ── #} + {% if llm_summary and llm_summary.paragraph %} +

AI Explanation

+
+

{{ llm_summary.paragraph }}

+ {% if llm_summary.bullets %} +
    + {% for b in llm_summary.bullets %}
  • {{ b }}
  • {% endfor %} +
+ {% endif %} + {% if llm_summary.model_used %} +
via {{ llm_summary.model_used }}
+ {% endif %} +
+ {% endif %} + + {# ══════════ IMAGE ══════════ #} + {% if media_type == 'image' %} + + {# EXIF #} + {% if explainability.exif %} +

EXIF Metadata

+ + + {% if explainability.exif.make %} + + {% endif %} + {% if explainability.exif.model %} + + {% endif %} + {% if explainability.exif.datetime_original %} + + {% endif %} + {% if explainability.exif.software %} + + + {% endif %} + {% if explainability.exif.lens_model %} + + {% endif %} + {% if explainability.exif.gps_info %} + + {% endif %} + + + + +
FieldValueTrust Signal
Camera Make{{ explainability.exif.make }}+real
Camera Model{{ explainability.exif.model }}
Date Taken{{ explainability.exif.datetime_original }}+real
Software{{ explainability.exif.software }}{% if 'photoshop' in explainability.exif.software | lower %}+fake{% endif %}
Lens Model{{ explainability.exif.lens_model }}
GPS{{ explainability.exif.gps_info }}
Trust adjustment + {% if explainability.exif.trust_adjustment > 0 %} + +{{ explainability.exif.trust_adjustment }} (fake signal) + {% elif explainability.exif.trust_adjustment < 0 %} + {{ explainability.exif.trust_adjustment }} (real signal) + {% else %} + neutral + {% endif %} +
+ {% endif %} + + {# Artifact indicators #} + {% if explainability.artifact_indicators %} +

Artifact Indicators

+ + + {% for ind in explainability.artifact_indicators %} + + + + + + + {% endfor %} +
TypeSeverityConfidenceDescription
{{ ind.type }}{{ ind.severity }}{{ '%.0f' | format(ind.confidence * 100) }}%{{ ind.description }}
+ {% else %} +

Artifact Indicators

+
No artifacts detected.
+ {% endif %} + + {# VLM Detailed Breakdown #} + {% if explainability.vlm_breakdown %} +

Detailed Breakdown

+ {% if explainability.vlm_breakdown.model_used %} +
Scored by {{ explainability.vlm_breakdown.model_used }}
+ {% endif %} + + + {% set bd = explainability.vlm_breakdown %} + {% for comp_key, comp_label in [ + ('facial_symmetry', 'Facial Symmetry'), + ('skin_texture', 'Skin Texture'), + ('lighting_consistency', 'Lighting Consistency'), + ('background_coherence', 'Background Coherence'), + ('anatomy_hands_eyes', 'Anatomy / Hands & Eyes'), + ('context_objects', 'Context & Objects') + ] %} + {% set comp = bd[comp_key] %} + {% set sc2 = comp.score if comp else 75 %} + {% set bar_cls = 'vlm-real' if sc2 >= 70 else ('vlm-warn' if sc2 >= 40 else 'vlm-fake') %} + + + + + + + {% endfor %} +
ComponentScoreBarNotes
{{ comp_label }}{{ sc2 }}/100 + + + + {{ comp.notes if comp else '' }}
+ {% endif %} + + {% endif %}{# end image #} + + {# ══════════ VIDEO ══════════ #} + {% if media_type == 'video' %} +

Frame-Level Analysis

+ + + + + + + + +
MetricValue
Frames sampled{{ explainability.num_frames_sampled }}
Frames with face{{ explainability.num_face_frames }}
Suspicious frames{{ explainability.num_suspicious_frames }}
Mean suspicious prob{{ '%.1f' | format(explainability.mean_suspicious_prob * 100) }}%
Max suspicious prob{{ '%.1f' | format(explainability.max_suspicious_prob * 100) }}%
Insufficient faces{{ explainability.insufficient_faces }}
+ {% endif %} + + {# ══════════ TEXT ══════════ #} + {% if media_type == 'text' %} + + {# Language + truth-override #} + {% if explainability.detected_language and explainability.detected_language != 'en' %} +

Language

+
Detected: {{ explainability.detected_language | upper }} — analysed via multilingual model
+ {% endif %} + {% if explainability.truth_override and explainability.truth_override.applied %} +
+ Truth-override applied. + Corroborated by {{ explainability.truth_override.source_name }} + ({{ '%.0f' | format(explainability.truth_override.similarity * 100) }}% similarity). + Fake probability reduced from {{ '%.1f' | format(explainability.truth_override.fake_prob_before * 100) }}% + to {{ '%.1f' | format(explainability.truth_override.fake_prob_after * 100) }}%. +
+ {% endif %} + +

Text Classification

+ + + + + + + + + +
MetricValue
Fake probability{{ '%.1f' | format(explainability.fake_probability * 100) }}%
Top label{{ explainability.top_label }}
Sensationalism score{{ explainability.sensationalism.score }}/100 ({{ explainability.sensationalism.level }})
Exclamations{{ explainability.sensationalism.exclamation_count }}
ALL CAPS words{{ explainability.sensationalism.caps_word_count }}
Clickbait matches{{ explainability.sensationalism.clickbait_matches }}
Emotional words{{ explainability.sensationalism.emotional_word_count }}
+ + {% if explainability.manipulation_indicators %} +

Manipulation Indicators ({{ explainability.manipulation_indicators | length }})

+ + + {% for m in explainability.manipulation_indicators %} + + + + + + {% endfor %} +
PatternSeverityMatched text
{{ m.pattern_type }}{{ m.severity }}{{ m.matched_text }}
+ {% endif %} + + {% if explainability.keywords %} +

Extracted Keywords

+
{% for kw in explainability.keywords %}{{ kw }}{% endfor %}
+ {% endif %} + + {% endif %}{# end text #} + + {# ══════════ SCREENSHOT ══════════ #} + {% if media_type == 'screenshot' %} + + {% if explainability.detected_language and explainability.detected_language != 'en' %} +
Detected language: {{ explainability.detected_language | upper }}
+ {% endif %} + {% if explainability.truth_override and explainability.truth_override.applied %} +
+ Truth-override applied. {{ explainability.truth_override.source_name }} + ({{ '%.0f' | format(explainability.truth_override.similarity * 100) }}% similarity) +
+ {% endif %} + +

Extracted Text

+
{{ explainability.ocr_boxes | length }} OCR regions detected
+ + +
{{ explainability.extracted_text }}
+ +

Analysis Summary

+ + + + + + +
MetricValue
Fake probability{{ '%.1f' | format(explainability.fake_probability * 100) }}%
Sensationalism{{ explainability.sensationalism.score }}/100 ({{ explainability.sensationalism.level }})
Suspicious phrases{{ explainability.suspicious_phrases | length }}
Layout anomalies{{ explainability.layout_anomalies | length }}
+ + {% if explainability.suspicious_phrases %} +

Suspicious Phrases

+ + + {% for p in explainability.suspicious_phrases %} + + + + + + {% endfor %} +
TextPatternSeverity
{{ p.text }}{{ p.pattern_type }}{{ p.severity }}
+ {% endif %} + + {% endif %}{# end screenshot #} + + {# ══════════ SOURCES (all types) ══════════ #} + {% if trusted_sources %} +

Trusted Source Cross-Reference ({{ trusted_sources | length }})

+ + + {% for s in trusted_sources %} + + + + + + {% endfor %} +
SourceTitleRelevance
{{ s.source_name }}{{ s.title }}{{ '%.0f' | format(s.relevance_score * 100) }}%
+ {% endif %} + + {% if contradicting_evidence %} +

Contradicting Evidence ({{ contradicting_evidence | length }})

+ + + {% for c in contradicting_evidence %} + + {% endfor %} +
SourceTitleType
{{ c.source_name }}{{ c.title }}{{ c.type }}
+ {% endif %} + + {# ══════════ PROCESSING ══════════ #} +

Processing Summary

+
Model: {{ processing_summary.model_used }}  ·  Duration: {{ processing_summary.total_duration_ms }} ms
+
{{ processing_summary.stages_completed | join(' → ') }}
+ + {# ══════════ FOOTER ══════════ #} + + + + diff --git a/report_service.py b/report_service.py new file mode 100644 index 0000000000000000000000000000000000000000..154503b6179adbcd19ba924682e3f3649d8b0cc6 --- /dev/null +++ b/report_service.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import base64 +import json +import os +import time +import uuid +from datetime import datetime, timedelta, timezone +from io import BytesIO +from pathlib import Path +from typing import Any, Optional + +from jinja2 import Environment, FileSystemLoader, select_autoescape +from loguru import logger +from xhtml2pdf import pisa # type: ignore + +from config import settings +from db.models import AnalysisRecord, Report + +TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" + +_env = Environment( + loader=FileSystemLoader(str(TEMPLATES_DIR)), + autoescape=select_autoescape(["html", "xml"]), +) + + +def _score_class(score: int) -> str: + if score >= 70: + return "real" + if score >= 40: + return "warn" + return "fake" + + +def _ensure_dir() -> Path: + p = Path(settings.REPORT_DIR) + p.mkdir(parents=True, exist_ok=True) + return p + + +def _make_donut_chart(score: int, score_cls: str) -> str: + """Render authenticity score as a donut chart PNG; return base64 or '' on failure.""" + try: + import matplotlib # type: ignore + matplotlib.use("Agg") + import matplotlib.pyplot as plt # type: ignore + + color_map = {"real": "#43A047", "warn": "#FB8C00", "fake": "#E53935"} + color = color_map.get(score_cls, "#6B7280") + + fig, ax = plt.subplots(figsize=(2.2, 2.2), dpi=96) + sizes = [score, 100 - score] + wedge_colors = [color, "#F3F4F6"] + ax.pie(sizes, colors=wedge_colors, startangle=90, + wedgeprops=dict(width=0.42, edgecolor="white", linewidth=1)) + ax.text(0, 0, str(score), ha="center", va="center", + fontsize=20, fontweight="bold", color=color) + ax.set_aspect("equal") + plt.tight_layout(pad=0.05) + + buf = BytesIO() + fig.savefig(buf, format="png", 
bbox_inches="tight", transparent=True) + plt.close(fig) + buf.seek(0) + return base64.b64encode(buf.read()).decode() + except Exception as e: + logger.debug(f"Donut chart skipped: {e}") + return "" + + +def _extract_llm_summary(analysis_json: dict) -> dict | None: + """Extract llm_summary from either top-level or inside explainability (images).""" + top = analysis_json.get("llm_summary") + if top: + return top + return (analysis_json.get("explainability") or {}).get("llm_summary") + + +def render_html(analysis_json: dict) -> str: + score = analysis_json.get("verdict", {}).get("authenticity_score", 50) + sc = _score_class(score) + donut_b64 = _make_donut_chart(score, sc) + llm_summary = _extract_llm_summary(analysis_json) + expl: dict[str, Any] = analysis_json.get("explainability") or {} + + tmpl = _env.get_template("report.html") + return tmpl.render( + analysis_id=analysis_json.get("analysis_id", ""), + media_type=analysis_json.get("media_type", "unknown"), + verdict=analysis_json.get("verdict", {}), + explainability=expl, + trusted_sources=analysis_json.get("trusted_sources", []), + contradicting_evidence=analysis_json.get("contradicting_evidence", []), + processing_summary=analysis_json.get("processing_summary", {}), + responsible_ai_notice=analysis_json.get( + "responsible_ai_notice", + "AI-based analysis may not be 100% accurate.", + ), + score_class=sc, + generated_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC"), + donut_b64=donut_b64, + llm_summary=llm_summary, + ) + + +def html_to_pdf(html: str, out_path: Path) -> None: + with open(out_path, "wb") as f: + result = pisa.CreatePDF(html, dest=f) + if result.err: + raise RuntimeError(f"xhtml2pdf failed with {result.err} errors") + + +def generate_report(record: AnalysisRecord) -> Path: + out_dir = _ensure_dir() + filename = f"deepshield_{record.id}_{uuid.uuid4().hex[:8]}.pdf" + out_path = out_dir / filename + + data = json.loads(record.result_json) + html = render_html(data) + html_to_pdf(html, 
def cleanup_expired(now: Optional[datetime] = None) -> int:
    """Delete expired PDFs from disk. Returns count deleted.

    A PDF in ``settings.REPORT_DIR`` is expired when its modification time
    is more than ``settings.REPORT_TTL_SECONDS`` older than *now*.

    Args:
        now: Reference time. Defaults to the current UTC time. A naive
            value is interpreted as UTC (the previous calling convention);
            an aware value is used as-is instead of raising ``TypeError``.

    Returns:
        Number of files removed. Per-file OS errors are logged and skipped.
    """
    if now is None:
        now = datetime.now(timezone.utc)
    elif now.tzinfo is None:
        now = now.replace(tzinfo=timezone.utc)

    d = Path(settings.REPORT_DIR)
    if not d.exists():
        return 0

    # Compare epoch seconds so no naive/aware datetime arithmetic is needed
    # (and the deprecated datetime.utcfromtimestamp is avoided).
    cutoff = (now - timedelta(seconds=settings.REPORT_TTL_SECONDS)).timestamp()
    deleted = 0
    for f in d.glob("*.pdf"):
        try:
            if f.stat().st_mtime < cutoff:
                f.unlink()
                deleted += 1
        except OSError as e:
            logger.warning(f"Cleanup failed for {f}: {e}")
    if deleted:
        logger.info(f"Cleaned up {deleted} expired reports")
    return deleted
def compute_authenticity_score(model_confidence: float, label: str) -> int:
    """Map (confidence, label) to 0-100 authenticity score.

    Real-ish labels give high score; fake-ish labels give low score.

    A label is fake-ish when it contains one of the long fake tokens as a
    substring, or contains ``ai`` as a standalone word. The two-letter
    token is matched on word boundaries because a substring test would
    also hit unrelated labels such as "Plain photo" or "painting".

    Args:
        model_confidence: Classifier confidence for *label*, nominally 0-1.
        label: Classifier label (case-insensitive).

    Returns:
        Integer score clamped to the range 0-100.
    """
    label_l = label.lower()
    fake_tokens = ("fake", "deepfake", "manipulated", "generated", "synthetic")
    # Split on anything that is not a letter: "ai_generated" -> ["ai", "generated"].
    words = "".join(ch if ch.isalpha() else " " for ch in label_l).split()
    is_fake = "ai" in words or any(tok in label_l for tok in fake_tokens)

    confidence = float(model_confidence)
    score = (1.0 - confidence) * 100.0 if is_fake else confidence * 100.0
    return int(round(max(0.0, min(100.0, score))))
def detect_layout_anomalies(boxes: List[OCRBox]) -> List[LayoutAnomaly]:
    """Heuristic layout checks on OCR bboxes.

    Three independent checks run over the boxes' bounding polygons:

    * ``font_mismatch`` — coefficient of variation of box heights > 0.7.
    * ``misalignment`` — with more than 4 boxes, fewer than 40% of left
      edges lie within 20 units of the median left edge.
    * ``uneven_spacing`` — with at least 3 positive gaps between sorted
      top edges, their coefficient of variation exceeds 1.5.

    Returns an empty list when fewer than 3 boxes are given.
    """
    out: List[LayoutAnomaly] = []
    if len(boxes) < 3:
        return out

    heights = []
    x_lefts = []
    for b in boxes:
        ys = [p[1] for p in b.bbox]
        xs = [p[0] for p in b.bbox]
        heights.append(max(ys) - min(ys))
        x_lefts.append(min(xs))

    h_arr = np.array(heights, dtype=float)
    if h_arr.mean() > 0:
        cv_h = float(h_arr.std() / h_arr.mean())
        if cv_h > 0.7:
            out.append(LayoutAnomaly(
                type="font_mismatch",
                severity="medium" if cv_h < 1.2 else "high",
                description=f"High variance in text heights (cv={cv_h:.2f}) — mixed fonts/sizes possible",
                confidence=min(cv_h / 1.5, 1.0),
            ))

    x_arr = np.array(x_lefts, dtype=float)
    if x_arr.std() > 0 and len(x_arr) > 4:
        # The median is loop-invariant: compute it once rather than per element.
        median_x = np.median(x_arr)
        clustered = int((np.abs(x_arr - median_x) < 20).sum())
        align_ratio = clustered / len(x_arr)
        if align_ratio < 0.4:
            out.append(LayoutAnomaly(
                type="misalignment",
                severity="low",
                description=f"Only {align_ratio*100:.0f}% of text blocks share left-alignment — unusual layout",
                confidence=1.0 - align_ratio,
            ))

    if len(boxes) >= 4:
        tops = sorted([min(p[1] for p in b.bbox) for b in boxes])
        gaps = np.diff(tops)
        # Boxes sharing a top edge produce zero gaps; ignore them.
        gaps = gaps[gaps > 0]
        if len(gaps) >= 3 and gaps.mean() > 0:
            cv_g = float(gaps.std() / gaps.mean())
            if cv_g > 1.5:
                out.append(LayoutAnomaly(
                    type="uneven_spacing",
                    severity="low",
                    description=f"Irregular vertical spacing between text blocks (cv={cv_g:.2f})",
                    confidence=min(cv_g / 2.5, 1.0),
                ))

    return out
def main() -> int:
    """Run the image smoke test end to end and return the exit code 0.

    Steps: download the sample image, classify it, derive the score and
    verdict, list artifact indicators, then write the Grad-CAM heatmap to
    ``heatmap_smoketest.png`` two directories above this script.
    """
    print(f"Fetching sample image: {SAMPLE_URL}")
    request = urllib.request.Request(SAMPLE_URL, headers={"User-Agent": "DeepShield/0.1"})
    with urllib.request.urlopen(request, timeout=30) as response:
        payload = response.read()
    print(f" got {len(payload)} bytes")

    print("Running classifier (first run will download model ~350MB)…")
    image, outcome = preprocess_and_classify(payload)
    print(f" image size: {image.size}")
    print(f" label: {outcome.label}")
    print(f" confidence: {outcome.confidence:.4f}")
    print(f" all scores: {outcome.all_scores}")

    score = compute_authenticity_score(outcome.confidence, outcome.label)
    verdict_label, severity = get_verdict_label(score)
    print(f"\n authenticity_score: {score}")
    print(f" verdict: {verdict_label} ({severity})")

    print("\nScanning artifact indicators\u2026")
    for ind in scan_artifacts(image, payload):
        print(f" [{ind.severity.upper():6s}] {ind.type}: {ind.description} (conf {ind.confidence:.2f})")

    print("\nGenerating Grad-CAM heatmap\u2026")
    heatmap_url = generate_heatmap_base64(image)
    # A data URL is "<header>,<base64>"; only the payload part is decoded.
    _, encoded = heatmap_url.split(",", 1)
    target = Path(__file__).resolve().parent.parent / "heatmap_smoketest.png"
    target.write_bytes(base64.b64decode(encoded))
    print(f" saved: {target}")
    print(f" data URL length: {len(heatmap_url)} chars")
    return 0
async def test_news():
    """Manually exercise the news lookup with fixed keywords and print the results.

    Exits early with an error message when ``settings.NEWS_API_KEY`` is
    unset. The key is checked before it is sliced for display, so a
    ``None`` key reports the configuration error instead of raising
    ``TypeError``.
    """
    if not settings.NEWS_API_KEY:
        print("ERROR: NEWS_API_KEY is empty in .env")
        return

    print(f"Testing News API Integration with key: {settings.NEWS_API_KEY[:6]}... (masked)")

    keywords = ["modi", "election", "bjp", "congress"]
    print(f"Searching for keywords: {keywords}")

    try:
        result = await search_news_full(keywords, limit=5)

        print("\n=== RAW RESULT ===")
        print(f"Total articles found: {result.total_articles}")

        print("\n=== TRUSTED SOURCES ===")
        for i, source in enumerate(result.trusted_sources, 1):
            date_str = str(source.published_at)[:10] if source.published_at else "Unknown date"
            print(f"{i}. [{source.relevance_score}] {source.source_name}: {source.title[:60]}... ({date_str})")

        print("\n=== CONTRADICTING EVIDENCE / FACT CHECKS ===")
        if not result.contradicting_evidence:
            print("No fact-check articles found for these keywords.")
        for i, ev in enumerate(result.contradicting_evidence, 1):
            print(f"{i}. {ev.source_name}: {ev.title[:60]}...")

    except Exception as e:
        # Top-level boundary of a manual script: report and finish.
        print(f"\nERROR running test: {e}")
Click now!"} + async with httpx.AsyncClient(timeout=180.0) as c: + r = await c.post("http://127.0.0.1:8000/api/v1/analyze/text", json=body) + r.raise_for_status() + j = r.json() + assert j["media_type"] == "text" + assert "trusted_sources" in j + assert "contradicting_evidence" in j + assert "news_lookup" in j["processing_summary"]["stages_completed"] + print(f"[OK] /analyze/text -> verdict={j['verdict']['label']} " + f"score={j['verdict']['authenticity_score']} " + f"trusted={len(j['trusted_sources'])} contradictions={len(j['contradicting_evidence'])}") + + +async def main(): + test_domain() + test_factcheck_detection() + test_relevance() + await test_empty_key_returns_empty() + await test_endpoint_wiring() + print("\n=== Phase 5 smoke PASS ===") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/test_text_analysis.py b/test_text_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..3bedd8f87f8f7be25039e23d5abff27031c5a22c --- /dev/null +++ b/test_text_analysis.py @@ -0,0 +1,34 @@ +"""Quick smoke test for sensationalism + manipulation detection.""" +import sys +sys.path.insert(0, ".") + +from services.text_service import score_sensationalism, detect_manipulation_indicators + +# --- Sensationalism --- +text1 = "BREAKING: You wont believe this SHOCKING truth! Experts confirm the most DEVASTATING scandal exposed!!!" +s = score_sensationalism(text1) +print(f"Sensationalism: score={s.score} level={s.level}") +print(f" excl={s.exclamation_count} caps={s.caps_word_count} clickbait={s.clickbait_matches} emotional={s.emotional_word_count} superlative={s.superlative_count}") +assert s.score > 50, f"Expected high sensationalism, got {s.score}" +assert s.level in ("Medium", "High"), f"Expected Medium/High, got {s.level}" +print(" PASS") + +# --- Manipulation --- +text2 = "Sources say that experts confirm the shocking truth. Allegedly, everyone knows this is a proven fact." 
+m = detect_manipulation_indicators(text2) +print(f"\nManipulation indicators: {len(m)} found") +for ind in m: + print(f" [{ind.severity}] {ind.pattern_type}: \"{ind.matched_text}\"") +assert len(m) >= 3, f"Expected >=3 indicators, got {len(m)}" +print(" PASS") + +# --- Clean text --- +text3 = "The weather today is sunny with clear skies in New Delhi." +s2 = score_sensationalism(text3) +m2 = detect_manipulation_indicators(text3) +print(f"\nClean text: sensationalism={s2.score} ({s2.level}), manipulation={len(m2)}") +assert s2.score < 20, f"Expected low sensationalism for clean text, got {s2.score}" +assert len(m2) == 0, f"Expected 0 manipulation indicators for clean text, got {len(m2)}" +print(" PASS") + +print("\nAll tests passed!") diff --git a/text_service.py b/text_service.py new file mode 100644 index 0000000000000000000000000000000000000000..556ac48a8f0195fe5b41d3a789179979e778fae3 --- /dev/null +++ b/text_service.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import List, Optional + +from loguru import logger + +from models.model_loader import get_model_loader + +FAKE_TOKENS = ("fake", "false", "unreliable", "misinformation") + +# --- Sensationalism patterns --- +CLICKBAIT_PATTERNS = [ + (r"\byou won'?t believe\b", "clickbait"), + (r"\bbreaking\s*:", "clickbait"), + (r"\bshocking\s*:", "clickbait"), + (r"\bexclusive\s*:", "clickbait"), + (r"\bjust\s+in\s*:", "clickbait"), + (r"\burgent\s*:", "clickbait"), + (r"\bwhat\s+happens\s+next\b", "clickbait"), + (r"\bthis\s+will\s+change\b", "clickbait"), + (r"\b(?:everyone|nobody)\s+(?:is|was)\s+talking\b", "clickbait"), +] +EMOTIONAL_WORDS = { + "outrage", "shocking", "horrifying", "disgusting", "amazing", "incredible", + "unbelievable", "devastating", "terrifying", "explosive", "bombshell", + "jaw-dropping", "heartbreaking", "furious", "scandal", "crisis", + "chaos", "destroyed", "slammed", "blasted", "exposed", "revealed", +} 
# Superlative words counted (case-insensitively) by the sensationalism scorer.
SUPERLATIVES = {
    "best", "worst", "greatest", "biggest", "most", "least",
    "fastest", "deadliest", "largest", "smallest", "ultimate",
}

# --- Manipulation indicator patterns ---
# Each entry is (regex, pattern_type, severity, description). The regexes are
# applied with re.IGNORECASE, so they are written in lower case.
MANIPULATION_PATTERNS = [
    # Unverified claims
    (r"\bsources?\s+(?:say|said|claim|report)\b", "unverified_claim", "medium",
     "Unverified source attribution without specific citation"),
    (r"\ballegedly\b", "unverified_claim", "low",
     "Hedging language suggests unverified information"),
    (r"\breports?\s+suggest\b", "unverified_claim", "medium",
     "Vague report attribution"),
    (r"\baccording\s+to\s+(?:some|many|several)\b", "unverified_claim", "medium",
     "Non-specific source attribution"),
    # Fixed: the pattern used to read r"\brunconfirmed\b" (a stray "r" after the
    # word boundary), which can only match the non-word "runconfirmed".
    (r"\bunconfirmed\b", "unverified_claim", "medium",
     "Explicitly unconfirmed information"),
    # Emotional manipulation
    (r"\boutrage\b", "emotional_manipulation", "medium",
     "Emotional trigger word designed to provoke reaction"),
    (r"\bshocking\s+truth\b", "emotional_manipulation", "high",
     "Sensationalist phrase designed to manipulate reader emotion"),
    (r"\bwake\s+up\b", "emotional_manipulation", "medium",
     "Call-to-action implying hidden knowledge"),
    (r"\bthey\s+don'?t\s+want\s+you\s+to\s+know\b", "emotional_manipulation", "high",
     "Conspiracy framing language"),
    (r"\bopen\s+your\s+eyes\b", "emotional_manipulation", "medium",
     "Implies audience ignorance"),
    # False authority
    (r"\bexperts?\s+(?:confirm|say|agree|warn)\b", "false_authority", "medium",
     "Unnamed expert citation without specific attribution"),
    (r"\bscientists?\s+(?:confirm|prove|say)\b", "false_authority", "medium",
     "Unnamed scientist citation"),
    (r"\bstudies?\s+(?:show|prove|confirm)\b", "false_authority", "low",
     "Vague study reference without citation"),
    (r"\beveryone\s+knows\b", "false_authority", "medium",
     "Appeal to common knowledge fallacy"),
    (r"\bit'?s\s+(?:a\s+)?(?:well-?known|proven)\s+fact\b", "false_authority", "medium",
     "Assertion of fact without evidence"),
]

# NER entity
# labels to prefer for keyword extraction
_NER_PREFERRED = {"PERSON", "ORG", "GPE", "EVENT", "PRODUCT", "NORP"}


@dataclass
class TextClassification:
    """Outcome of running the fake/real text classifier on one input."""

    label: str  # top-scoring label exactly as the model emitted it
    confidence: float  # score of that top label
    fake_prob: float  # score attributed to the "fake" class (0.0 when none is identifiable)
    all_scores: dict[str, float]  # every label -> score


@dataclass
class SensationalismResult:
    """Sensationalism score plus the raw counts it was derived from."""

    score: int  # 0-100
    level: str  # Low / Medium / High
    exclamation_count: int
    caps_word_count: int
    clickbait_matches: int
    emotional_word_count: int
    superlative_count: int


@dataclass
class ManipulationIndicator:
    """One regex hit from MANIPULATION_PATTERNS, with its position in the text."""

    pattern_type: str  # unverified_claim / emotional_manipulation / false_authority
    matched_text: str
    start_pos: int
    end_pos: int
    severity: str  # low / medium / high
    description: str


def detect_language(text: str) -> str:
    """Detect the primary language of ``text`` using langdetect.

    Returns the language code reported by langdetect (e.g. 'en', 'hi').
    Falls back to 'en' when the text is empty or shorter than 10 characters
    after stripping, when langdetect is not installed, or when detection fails.
    """
    if not text or len(text.strip()) < 10:
        return "en"
    try:
        from langdetect import DetectorFactory, detect  # type: ignore

        # langdetect samples randomly unless seeded, so the same text could be
        # routed to different models on different requests. Pin the seed.
        DetectorFactory.seed = 0
        lang = detect(text.strip())
        logger.info(f"Language detected: {lang}")
        return lang
    except ImportError:
        logger.debug("langdetect not installed — defaulting to 'en'")
        return "en"
    except Exception as e:  # noqa: BLE001 - detection is best-effort
        logger.debug(f"Language detection failed: {e} — defaulting to 'en'")
        return "en"


def _scores_to_classification(items) -> TextClassification:
    """Convert pipeline output (iterable of ``{"label", "score"}``) to TextClassification.

    Empty or missing output yields the same "unknown" result that
    ``classify_text`` returns for empty text, instead of a ``ValueError``
    from ``max()`` on an empty sequence.
    """
    scores = {i["label"]: float(i["score"]) for i in (items or ())}
    if not scores:
        return TextClassification("unknown", 0.0, 0.0, {})

    top_label, top_conf = max(scores.items(), key=lambda kv: kv[1])
    # Extract fake probability: "LABEL_0" is taken as the fake class when
    # present; otherwise use the highest score among labels whose name
    # contains one of FAKE_TOKENS.
    if "LABEL_0" in scores:
        fake_prob = scores["LABEL_0"]
    else:
        fake_prob = max(
            (p for lbl, p in scores.items() if any(t in lbl.lower() for t in FAKE_TOKENS)),
            default=0.0,
        )
    return TextClassification(top_label, top_conf, fake_prob, scores)


def classify_text(text: str, language: Optional[str] = None) -> TextClassification:
    """Classify text as fake/real.

    Routes to the multilingual model when ``language`` is set and is not "en";
    otherwise uses the default text model. Empty or whitespace-only text
    returns an "unknown" classification without loading any model.
    """
    text = (text or "").strip()
    if not text:
        return TextClassification("unknown", 0.0, 0.0, {})

    loader = get_model_loader()

    if language and language != "en":
        pipe = loader.load_multilang_text_model()
    else:
        pipe = loader.load_text_model()

    # Only the first 2000 characters are sent to the pipeline.
    out = pipe(text[:2000], truncation=True, top_k=None)
    # The pipeline may return a flat list of scores or a list wrapped per input.
    items = out[0] if out and isinstance(out[0], list) else out
    clf = _scores_to_classification(items)
    logger.info(
        f"Text classify [{language or 'en'}] → {clf.label} @ {clf.confidence:.3f} "
        f"fake_p={clf.fake_prob:.3f}"
    )
    return clf


def score_sensationalism(text: str) -> SensationalismResult:
    """Compute a 0-100 sensationalism score from structural/linguistic signals.

    The score is a capped sum of five components: exclamation marks (max 25),
    share of ALL-CAPS words longer than two characters (max 25), distinct
    clickbait patterns matched (max 25), emotional words (max 15) and
    superlatives (max 10). Levels: < 30 Low, < 60 Medium, otherwise High.
    """
    if not text:
        return SensationalismResult(0, "Low", 0, 0, 0, 0, 0)

    words = text.split()
    total_words = max(len(words), 1)  # avoid division by zero on whitespace-only text
    # Lower-cased tokens with surrounding punctuation removed, for lexicon lookups.
    normalized = [w.lower().strip(".,!?;:") for w in words]

    excl = text.count("!")
    caps = sum(1 for w in words if w.isupper() and len(w) > 2)
    clickbait = sum(
        1 for pat, _ in CLICKBAIT_PATTERNS
        if re.search(pat, text, re.IGNORECASE)
    )
    emotional = sum(1 for w in normalized if w in EMOTIONAL_WORDS)
    superlative = sum(1 for w in normalized if w in SUPERLATIVES)

    raw = (
        min(excl * 8, 25)
        + min(caps / total_words * 200, 25)
        + min(clickbait * 12, 25)
        + min(emotional * 6, 15)
        + min(superlative * 5, 10)
    )
    score = int(min(100, max(0, raw)))
    level = "Low" if score < 30 else ("Medium" if score < 60 else "High")

    logger.info(f"Sensationalism → {score} ({level}) excl={excl} caps={caps} cb={clickbait} emo={emotional}")
    return SensationalismResult(score, level, excl, caps, clickbait, emotional, superlative)


def detect_manipulation_indicators(text: str) -> List[ManipulationIndicator]:
    """Scan text for manipulation linguistic patterns with positions.

    Every match of every entry in MANIPULATION_PATTERNS becomes one indicator;
    the result is ordered by the match start offset.
    """
    if not text:
        return []
    indicators: List[ManipulationIndicator] = []
    for pattern, ptype, severity, description in MANIPULATION_PATTERNS:
        for m in re.finditer(pattern, text, re.IGNORECASE):
            indicators.append(ManipulationIndicator(
                pattern_type=ptype,
                matched_text=m.group(),
                start_pos=m.start(),
                end_pos=m.end(),
                severity=severity,
                description=description,
            ))
    indicators.sort(key=lambda i: i.start_pos)
    logger.info(f"Manipulation indicators → {len(indicators)} found")
    return indicators


def extract_entities(text: str, max_k: int = 6) -> List[str]:
    """Extract keywords via spaCy NER (entity types in _NER_PREFERRED come first).

    Falls back to frequency-based extraction when the text is empty or shorter
    than 20 characters, when spaCy is unavailable, or when NER raises. If NER
    yields fewer than two entities, frequency keywords are appended to them.
    At most ``max_k`` items are returned.
    """
    if not text or len(text.strip()) < 20:
        return _extract_keywords_freq(text, max_k)

    loader = get_model_loader()
    nlp = loader.load_spacy_nlp()

    if nlp is None:
        # spaCy not available — use frequency fallback
        return _extract_keywords_freq(text, max_k)

    try:
        doc = nlp(text[:5000])  # cap for performance

        # Collect named entities, preferring high-value types
        preferred: List[str] = []
        other: List[str] = []
        seen: set[str] = set()  # lower-cased entity texts already collected

        for ent in doc.ents:
            norm = ent.text.strip()
            norm_lower = norm.lower()
            if not norm or norm_lower in seen or len(norm) < 2:
                continue
            seen.add(norm_lower)
            if ent.label_ in _NER_PREFERRED:
                preferred.append(norm)
            else:
                other.append(norm)

        entities = preferred + other

        if len(entities) >= 2:
            logger.info(f"NER extracted {len(entities)} entities: {entities[:max_k]}")
            return entities[:max_k]

        # Not enough entities — supplement with frequency keywords
        freq_kws = _extract_keywords_freq(text, max_k)
        combined = entities + [k for k in freq_kws if k.lower() not in seen]
        return combined[:max_k]

    except Exception as e:  # noqa: BLE001 - keyword extraction must not fail the request
        logger.warning(f"spaCy NER failed: {e} — falling back to frequency extraction")
        return _extract_keywords_freq(text, max_k)


def _extract_keywords_freq(text: str, max_k: int = 6) -> List[str]:
    """Frequency-based keyword extraction (original implementation, kept as fallback).

    Returns up to ``max_k`` lower-cased words of three or more characters,
    stop words excluded, ordered by descending count and then alphabetically.
    """
    stop = {
        "the","a","an","is","are","was","were","be","been","being","to","of","and","or","but",
        "in","on","at","for","with","by","from","as","that","this","it","its","has","have","had",
        "will","would","can","could","should","may","might","do","does","did","not","no","so",
        "than","then","there","their","they","them","we","our","you","your","he","she","his","her",
    }
    words = re.findall(r"[A-Za-z][A-Za-z\-']{2,}", text or "")
    freq: dict[str, int] = {}
    for w in words:
        wl = w.lower()
        if wl in stop:
            continue
        freq[wl] = freq.get(wl, 0) + 1
    return [w for w, _ in sorted(freq.items(), key=lambda kv: (-kv[1], kv[0]))[:max_k]]


# Back-compat alias: routes that still call extract_keywords get NER-first behaviour
extract_keywords = extract_entities
0000000000000000000000000000000000000000..b1546b02f22634cce20396352c2729418b411e89 Binary files /dev/null and b/v1/__pycache__/health.cpython-311.pyc differ diff --git a/v1/__pycache__/history.cpython-311.pyc b/v1/__pycache__/history.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff4b5cf59e3ab6fa3e0fa3c35ef4beba4f377c0d Binary files /dev/null and b/v1/__pycache__/history.cpython-311.pyc differ diff --git a/v1/__pycache__/report.cpython-311.pyc b/v1/__pycache__/report.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b306d1542a167f78b0d17bd576448755943635e Binary files /dev/null and b/v1/__pycache__/report.cpython-311.pyc differ diff --git a/v1/analyze.py b/v1/analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..9663da62c33140dc37ca3fc24830458d4ad7d3fa --- /dev/null +++ b/v1/analyze.py @@ -0,0 +1,581 @@ +from __future__ import annotations + +import json +import os +import time +import uuid +from datetime import datetime, timezone + +from fastapi import APIRouter, Body, Depends, File, UploadFile +from pydantic import BaseModel +from loguru import logger +from sqlalchemy.orm import Session + +from api.deps import optional_current_user +from config import settings +from db.database import get_db +from db.models import AnalysisRecord, User +from models.heatmap_generator import generate_heatmap_base64, generate_boxes_base64 +from schemas.analyze import ( + FrameAnalysisOut, + ImageAnalysisResponse, + ImageExplainability, + LayoutAnomalyOut, + ManipulationIndicatorOut, + OCRBoxOut, + ScreenshotAnalysisResponse, + ScreenshotExplainability, + SensationalismBreakdown, + SuspiciousPhraseOut, + TextAnalysisResponse, + TextExplainability, + VideoAnalysisResponse, + VideoExplainability, +) +from services.screenshot_service import ( + detect_layout_anomalies, + extract_full_text, + map_phrases_to_boxes, + run_ocr, +) +from services.ela_service import generate_ela_base64 +from 
@router.post("/image", response_model=ImageAnalysisResponse)
async def analyze_image(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    user: User | None = Depends(optional_current_user),
) -> ImageAnalysisResponse:
    """Analyze an uploaded image and return verdict plus explainability artifacts.

    Pipeline: upload validation -> classification -> artifact scan -> optional
    heatmap / ELA / bounding boxes / EXIF (each best-effort: a failure is
    logged and the request continues) -> score and verdict -> persist an
    AnalysisRecord -> optional LLM summary and VLM breakdown.

    The persisted JSON and the LLM payload exclude the three base64 images.
    The record is written before the LLM/VLM steps, so it does not contain
    their output.
    """
    start = time.perf_counter()
    stages: list[str] = []

    raw, _mime = await read_upload_bytes(
        file, settings.ALLOWED_IMAGE_TYPES, max_size_mb=IMAGE_MAX_MB
    )
    stages.append("validation")

    pil, clf = preprocess_and_classify(raw)
    stages.append("classification")

    indicators = scan_artifacts(pil, raw)
    stages.append("artifact_scanning")

    # ── Phase 12: Grad-CAM++ heatmap ──
    heatmap_status = "success"
    heatmap = ""
    try:
        heatmap = generate_heatmap_base64(pil)
        stages.append("heatmap_generation")
    except Exception as e:  # noqa: BLE001
        logger.warning(f"Heatmap generation failed, continuing: {e}")
        heatmap_status = "failed"

    # ── Phase 12: ELA (Error Level Analysis) ──
    ela_b64 = ""
    try:
        ela_b64 = generate_ela_base64(pil)
        stages.append("ela_generation")
    except Exception as e:  # noqa: BLE001
        logger.warning(f"ELA generation failed, continuing: {e}")

    # ── Phase 12: Bounding box mode ──
    boxes_b64 = ""
    try:
        boxes_b64 = generate_boxes_base64(pil)
        stages.append("boxes_generation")
    except Exception as e:  # noqa: BLE001
        logger.warning(f"Bounding box generation failed, continuing: {e}")

    # ── Phase 12: EXIF extraction + trust adjustment ──
    exif_summary = None
    try:
        exif_summary = extract_exif(pil, raw)
        stages.append("exif_extraction")
    except Exception as e:  # noqa: BLE001
        logger.warning(f"EXIF extraction failed, continuing: {e}")

    score = compute_authenticity_score(clf.confidence, clf.label)

    # Apply EXIF trust adjustment to the score, clamped to 0-100
    if exif_summary and exif_summary.trust_adjustment != 0:
        score = int(round(max(0, min(100, score + exif_summary.trust_adjustment))))

    label, severity = get_verdict_label(score)
    duration_ms = int((time.perf_counter() - start) * 1000)

    analysis_id = str(uuid.uuid4())

    response = ImageAnalysisResponse(
        analysis_id=analysis_id,
        media_type="image",
        timestamp=datetime.now(timezone.utc).isoformat(),
        verdict=Verdict(
            label=label,
            severity=severity,
            authenticity_score=score,
            model_confidence=clf.confidence,
            model_label=clf.label,
        ),
        explainability=ImageExplainability(
            heatmap_base64=heatmap,
            ela_base64=ela_b64,
            boxes_base64=boxes_b64,
            heatmap_status=heatmap_status,
            artifact_indicators=indicators,
            exif=exif_summary,
        ),
        trusted_sources=[],
        contradicting_evidence=[],
        processing_summary=ProcessingSummary(
            stages_completed=stages,
            total_duration_ms=duration_ms,
            model_used=settings.IMAGE_MODEL_ID,
        ),
    )

    def _mark_late_stage(name: str) -> None:
        """Record a stage that finishes after ``response`` was constructed."""
        stages.append(name)
        # NOTE(review): assumes ProcessingSummary is a pydantic model, which
        # copies list fields on validation - confirm against schemas.common.
        # In that case appending to the local ``stages`` alone never reaches
        # the response. The identity check avoids a double entry if the
        # model kept a reference to the same list.
        completed = response.processing_summary.stages_completed
        if completed is not stages:
            completed.append(name)

    # The large base64 images are kept out of the DB record and the LLM payload.
    without_images = {"explainability": {"heatmap_base64", "ela_base64", "boxes_base64"}}

    record = AnalysisRecord(
        user_id=user.id if user else None,
        media_type="image",
        verdict=label,
        authenticity_score=float(score),
        result_json=json.dumps(response.model_dump(exclude=without_images)),
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    response.record_id = record.id
    logger.info(f"Saved AnalysisRecord id={record.id} score={score} verdict={label}")

    # ── Phase 12: LLM explainability card (runs after DB save so we have record_id) ──
    try:
        llm_summary = generate_llm_summary(
            payload=response.model_dump(exclude=without_images),
            record_id=str(record.id),
        )
        response.explainability.llm_summary = llm_summary
        _mark_late_stage("llm_explanation")
    except Exception as e:  # noqa: BLE001
        logger.warning(f"LLM explainer failed, continuing: {e}")

    # ── Phase 14: VLM detailed breakdown (vision LLM scores 6 perceptual components) ──
    try:
        vlm_bd = generate_vlm_breakdown(pil, record_id=str(record.id))
        if vlm_bd:
            response.explainability.vlm_breakdown = vlm_bd
            _mark_late_stage("vlm_breakdown")
    except Exception as e:  # noqa: BLE001
        logger.warning(f"VLM breakdown failed, continuing: {e}")

    return response
1000) + + response = VideoAnalysisResponse( + analysis_id=str(uuid.uuid4()), + media_type="video", + timestamp=datetime.now(timezone.utc).isoformat(), + verdict=Verdict( + label=label, + severity=severity, + authenticity_score=score, + model_confidence=float(agg.mean_suspicious_prob), + model_label="suspicious_mean" if not agg.insufficient_faces else "no_faces", + ), + explainability=VideoExplainability( + num_frames_sampled=agg.num_frames_sampled, + num_face_frames=agg.num_face_frames, + num_suspicious_frames=agg.num_suspicious_frames, + mean_suspicious_prob=agg.mean_suspicious_prob, + max_suspicious_prob=agg.max_suspicious_prob, + suspicious_ratio=agg.suspicious_ratio, + insufficient_faces=agg.insufficient_faces, + suspicious_timestamps=agg.suspicious_timestamps, + frames=[ + FrameAnalysisOut( + index=f.index, + timestamp_s=f.timestamp_s, + label=f.label, + confidence=f.confidence, + suspicious_prob=f.suspicious_prob, + is_suspicious=f.is_suspicious, + has_face=f.has_face, + scored=f.scored, + ) + for f in agg.frames + ], + ), + processing_summary=ProcessingSummary( + stages_completed=stages, + total_duration_ms=duration_ms, + model_used=settings.IMAGE_MODEL_ID, + ), + ) + + record = AnalysisRecord( + user_id=user.id if user else None, + media_type="video", + verdict=label, + authenticity_score=float(score), + result_json=json.dumps(response.model_dump()), + ) + db.add(record) + db.commit() + db.refresh(record) + response.record_id = record.id + logger.info( + f"Saved AnalysisRecord id={record.id} video score={score} verdict={label} " + f"frames={agg.num_frames_sampled} susp={agg.num_suspicious_frames}" + ) + + # Phase 12: LLM explainability card + try: + response.llm_summary = generate_llm_summary( + payload=response.model_dump(), record_id=str(record.id), + ) + except Exception as e: # noqa: BLE001 + logger.warning(f"LLM explainer failed for video: {e}") + + return response + + +class TextAnalyzeBody(BaseModel): + text: str + + +@router.post("/text", 
response_model=TextAnalysisResponse) +async def analyze_text_endpoint( + body: TextAnalyzeBody = Body(...), + db: Session = Depends(get_db), + user: User | None = Depends(optional_current_user), +) -> TextAnalysisResponse: + start = time.perf_counter() + stages: list[str] = [] + + # Phase 13: language detection — routes to multilang model when non-English + lang = detect_language(body.text) + stages.append("language_detection") + + clf = classify_text(body.text, language=lang) + stages.append("classification") + + sens = score_sensationalism(body.text) + stages.append("sensationalism_analysis") + + manip = detect_manipulation_indicators(body.text) + stages.append("manipulation_detection") + + # Phase 13.1: NER-based keyword extraction (spaCy entities first, frequency fallback) + keywords = extract_entities(body.text) + stages.append("ner_keyword_extraction") + + # Phase 13.2: pass original text + current fake_prob for truth-override computation + news = await search_news_full( + keywords, + original_text=body.text, + current_fake_prob=clf.fake_prob, + ) + stages.append("news_lookup") + + # Apply truth-override to fake_prob before scoring + effective_fake_prob = clf.fake_prob + if news.truth_override and news.truth_override.applied: + effective_fake_prob = news.truth_override.fake_prob_after + stages.append("truth_override_applied") + + # Weighted score: 70% classifier + 20% inverse sensationalism + 10% manipulation penalty + manip_penalty = min(len(manip) * 5, 30) + raw_score = (1.0 - effective_fake_prob) * 100.0 + weighted = raw_score * 0.70 + max(0, 100 - sens.score) * 0.20 + max(0, 100 - manip_penalty) * 0.10 + score = int(round(max(0.0, min(100.0, weighted)))) + label, severity = get_verdict_label(score) + duration_ms = int((time.perf_counter() - start) * 1000) + + model_used = ( + settings.TEXT_MULTILANG_MODEL_ID if (lang != "en" and settings.TEXT_MULTILANG_MODEL_ID) + else settings.TEXT_MODEL_ID + ) + + response = TextAnalysisResponse( + 
analysis_id=str(uuid.uuid4()), + media_type="text", + timestamp=datetime.now(timezone.utc).isoformat(), + verdict=Verdict( + label=label, + severity=severity, + authenticity_score=score, + model_confidence=float(clf.confidence), + model_label=clf.label, + ), + explainability=TextExplainability( + fake_probability=effective_fake_prob, + top_label=clf.label, + all_scores=clf.all_scores, + keywords=keywords, + sensationalism=SensationalismBreakdown( + score=sens.score, + level=sens.level, + exclamation_count=sens.exclamation_count, + caps_word_count=sens.caps_word_count, + clickbait_matches=sens.clickbait_matches, + emotional_word_count=sens.emotional_word_count, + superlative_count=sens.superlative_count, + ), + manipulation_indicators=[ + ManipulationIndicatorOut( + pattern_type=m.pattern_type, + matched_text=m.matched_text, + start_pos=m.start_pos, + end_pos=m.end_pos, + severity=m.severity, + description=m.description, + ) + for m in manip + ], + detected_language=lang, + truth_override=news.truth_override, + ), + trusted_sources=news.trusted_sources, + contradicting_evidence=news.contradicting_evidence, + processing_summary=ProcessingSummary( + stages_completed=stages, + total_duration_ms=duration_ms, + model_used=model_used, + ), + ) + + record = AnalysisRecord( + user_id=user.id if user else None, + media_type="text", + verdict=label, + authenticity_score=float(score), + result_json=json.dumps(response.model_dump()), + ) + db.add(record) + db.commit() + db.refresh(record) + response.record_id = record.id + logger.info(f"Saved AnalysisRecord id={record.id} text score={score} verdict={label}") + + # Phase 12: LLM explainability card + try: + response.llm_summary = generate_llm_summary( + payload=response.model_dump(), record_id=str(record.id), + ) + except Exception as e: # noqa: BLE001 + logger.warning(f"LLM explainer failed for text: {e}") + + return response + + +@router.post("/screenshot", response_model=ScreenshotAnalysisResponse) +async def 
analyze_screenshot_endpoint( + file: UploadFile = File(...), + db: Session = Depends(get_db), + user: User | None = Depends(optional_current_user), +) -> ScreenshotAnalysisResponse: + start = time.perf_counter() + stages: list[str] = [] + + raw, mime = await read_upload_bytes( + file, settings.ALLOWED_IMAGE_TYPES, max_size_mb=IMAGE_MAX_MB + ) + stages.append("validation") + + pil = load_image_from_bytes(raw) + ocr_boxes = run_ocr(pil) + stages.append("ocr") + + full_text = extract_full_text(ocr_boxes) + + # Phase 13: language detection on extracted OCR text + lang = detect_language(full_text) if full_text else "en" + stages.append("language_detection") + + clf = classify_text(full_text, language=lang) if full_text else None + stages.append("classification") + + sens = score_sensationalism(full_text) + stages.append("sensationalism_analysis") + + manip = detect_manipulation_indicators(full_text) + stages.append("manipulation_detection") + + phrases = map_phrases_to_boxes(ocr_boxes, manip) + stages.append("phrase_overlay_mapping") + + layout = detect_layout_anomalies(ocr_boxes) + stages.append("layout_anomaly_detection") + + # Phase 13.1: NER-based keyword extraction + keywords = extract_entities(full_text) + stages.append("ner_keyword_extraction") + + fake_prob = clf.fake_prob if clf else 0.0 + model_conf = clf.confidence if clf else 0.0 + model_lbl = clf.label if clf else "no_text" + + # Phase 13.2: truth-override via cosine similarity + news = await search_news_full( + keywords, + original_text=full_text, + current_fake_prob=fake_prob, + ) + stages.append("news_lookup") + + effective_fake_prob = fake_prob + if news.truth_override and news.truth_override.applied: + effective_fake_prob = news.truth_override.fake_prob_after + stages.append("truth_override_applied") + + manip_penalty = min(len(manip) * 5, 30) + layout_penalty = min(len(layout) * 5, 15) + raw_score = (1.0 - effective_fake_prob) * 100.0 + weighted = ( + raw_score * 0.65 + + max(0, 100 - sens.score) * 
0.20 + + max(0, 100 - manip_penalty) * 0.10 + + max(0, 100 - layout_penalty) * 0.05 + ) + if not full_text.strip(): + weighted = 50 + score = int(round(max(0.0, min(100.0, weighted)))) + label, severity = get_verdict_label(score) + duration_ms = int((time.perf_counter() - start) * 1000) + + model_used_str = ( + f"{settings.TEXT_MULTILANG_MODEL_ID} + EasyOCR" + if (lang != "en" and settings.TEXT_MULTILANG_MODEL_ID) + else f"{settings.TEXT_MODEL_ID} + EasyOCR" + ) + + response = ScreenshotAnalysisResponse( + analysis_id=str(uuid.uuid4()), + media_type="screenshot", + timestamp=datetime.now(timezone.utc).isoformat(), + verdict=Verdict( + label=label, + severity=severity, + authenticity_score=score, + model_confidence=float(model_conf), + model_label=model_lbl, + ), + explainability=ScreenshotExplainability( + extracted_text=full_text, + ocr_boxes=[OCRBoxOut(text=b.text, bbox=b.bbox, confidence=b.confidence) for b in ocr_boxes], + fake_probability=effective_fake_prob, + sensationalism=SensationalismBreakdown( + score=sens.score, level=sens.level, + exclamation_count=sens.exclamation_count, caps_word_count=sens.caps_word_count, + clickbait_matches=sens.clickbait_matches, emotional_word_count=sens.emotional_word_count, + superlative_count=sens.superlative_count, + ), + suspicious_phrases=[ + SuspiciousPhraseOut( + text=p.text, bbox=p.bbox, pattern_type=p.pattern_type, + severity=p.severity, description=p.description, + ) for p in phrases + ], + layout_anomalies=[ + LayoutAnomalyOut( + type=la.type, severity=la.severity, + description=la.description, confidence=la.confidence, + ) for la in layout + ], + keywords=keywords, + detected_language=lang, + truth_override=news.truth_override, + ), + trusted_sources=news.trusted_sources, + contradicting_evidence=news.contradicting_evidence, + processing_summary=ProcessingSummary( + stages_completed=stages, + total_duration_ms=duration_ms, + model_used=model_used_str, + ), + ) + + record = AnalysisRecord( + user_id=user.id if user 
else None, + media_type="screenshot", + verdict=label, + authenticity_score=float(score), + result_json=json.dumps(response.model_dump()), + ) + db.add(record) + db.commit() + db.refresh(record) + response.record_id = record.id + logger.info(f"Saved AnalysisRecord id={record.id} screenshot score={score} verdict={label}") + + # Phase 12: LLM explainability card + try: + response.llm_summary = generate_llm_summary( + payload=response.model_dump(), record_id=str(record.id), + ) + except Exception as e: # noqa: BLE001 + logger.warning(f"LLM explainer failed for screenshot: {e}") + + return response diff --git a/v1/auth.py b/v1/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e61ca0cbc138c8e1e2b7420996e4892e46468e --- /dev/null +++ b/v1/auth.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, HTTPException, status +from loguru import logger +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from api.deps import get_current_user +from config import settings +from db.database import get_db +from db.models import User +from schemas.auth import LoginBody, RegisterBody, TokenResponse, UserOut +from services.auth_service import authenticate, create_access_token, register_user + +router = APIRouter(prefix="/auth", tags=["auth"]) + + +def _token_response(user: User) -> TokenResponse: + return TokenResponse( + access_token=create_access_token(user.id, user.email), + expires_in_minutes=settings.JWT_EXPIRATION_MINUTES, + user=UserOut(id=user.id, email=user.email, name=user.name, created_at=user.created_at), + ) + + +@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED) +def register(body: RegisterBody, db: Session = Depends(get_db)) -> TokenResponse: + try: + user = register_user(db, body.email, body.password, body.name) + except IntegrityError: + db.rollback() + raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered") 
+ logger.info(f"Registered user id={user.id} email={user.email}") + return _token_response(user) + + +@router.post("/login", response_model=TokenResponse) +def login(body: LoginBody, db: Session = Depends(get_db)) -> TokenResponse: + user = authenticate(db, body.email, body.password) + if not user: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid email or password") + logger.info(f"Login user id={user.id} email={user.email}") + return _token_response(user) + + +@router.get("/me", response_model=UserOut) +def me(user: User = Depends(get_current_user)) -> UserOut: + return UserOut(id=user.id, email=user.email, name=user.name, created_at=user.created_at) diff --git a/v1/health.py b/v1/health.py new file mode 100644 index 0000000000000000000000000000000000000000..b02fd9845b16997bda02ffe00e91915c4c043533 --- /dev/null +++ b/v1/health.py @@ -0,0 +1,8 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["health"]) + + +@router.get("/health") +def health(): + return {"status": "ok", "service": "deepshield-backend"} diff --git a/v1/history.py b/v1/history.py new file mode 100644 index 0000000000000000000000000000000000000000..db70c77e068a5e4f8070caddc011504868912493 --- /dev/null +++ b/v1/history.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import json +from datetime import datetime + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from api.deps import get_current_user +from db.database import get_db +from db.models import AnalysisRecord, User + +router = APIRouter(prefix="/history", tags=["history"]) + + +class HistoryItem(BaseModel): + id: int + media_type: str + verdict: str + authenticity_score: float + created_at: datetime + + +class HistoryListResponse(BaseModel): + items: list[HistoryItem] + total: int + + +@router.get("", response_model=HistoryListResponse) +def list_history( + limit: int = Query(default=50, ge=1, le=200), + offset: int = 
Query(default=0, ge=0), + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> HistoryListResponse: + q = db.query(AnalysisRecord).filter(AnalysisRecord.user_id == user.id) + total = q.count() + rows = q.order_by(AnalysisRecord.created_at.desc()).offset(offset).limit(limit).all() + items = [ + HistoryItem( + id=r.id, + media_type=r.media_type, + verdict=r.verdict, + authenticity_score=r.authenticity_score, + created_at=r.created_at, + ) + for r in rows + ] + return HistoryListResponse(items=items, total=total) + + +@router.get("/{record_id}") +def get_history_detail( + record_id: int, + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + r = db.query(AnalysisRecord).filter(AnalysisRecord.id == record_id).first() + if not r or r.user_id != user.id: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Analysis not found") + try: + return json.loads(r.result_json) + except Exception: + raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Corrupt result payload") + + +@router.delete("/{record_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_history( + record_id: int, + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + r = db.query(AnalysisRecord).filter(AnalysisRecord.id == record_id).first() + if not r or r.user_id != user.id: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Analysis not found") + db.delete(r) + db.commit() + return None diff --git a/v1/report.py b/v1/report.py new file mode 100644 index 0000000000000000000000000000000000000000..72a34c8165dbd78f8e474afdc6d9df77d6e54494 --- /dev/null +++ b/v1/report.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from pathlib import Path + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import FileResponse +from loguru import logger +from sqlalchemy.orm import Session + +from db.database import get_db +from db.models import AnalysisRecord, Report +from services.report_service import 
cleanup_expired, create_report_row, generate_report + +router = APIRouter(prefix="/report", tags=["report"]) + + +@router.post("/{analysis_id}") +def generate(analysis_id: int, db: Session = Depends(get_db)): + record = db.query(AnalysisRecord).filter(AnalysisRecord.id == analysis_id).first() + if not record: + raise HTTPException(status_code=404, detail="analysis not found") + + existing = db.query(Report).filter(Report.analysis_id == analysis_id).first() + if existing and Path(existing.file_path).exists(): + return {"report_id": existing.id, "analysis_id": analysis_id, "ready": True} + + try: + path = generate_report(record) + except Exception as e: # noqa: BLE001 + logger.exception(f"Report generation failed: {e}") + raise HTTPException(status_code=500, detail=f"report generation failed: {e}") + + if existing: + existing.file_path = str(path) + db.commit() + db.refresh(existing) + return {"report_id": existing.id, "analysis_id": analysis_id, "ready": True} + + row = create_report_row(analysis_id, path) + db.add(row) + db.commit() + db.refresh(row) + return {"report_id": row.id, "analysis_id": analysis_id, "ready": True} + + +@router.get("/{analysis_id}/download") +def download(analysis_id: int, db: Session = Depends(get_db)): + row = db.query(Report).filter(Report.analysis_id == analysis_id).first() + if not row: + raise HTTPException(status_code=404, detail="report not found — generate first") + p = Path(row.file_path) + if not p.exists(): + raise HTTPException(status_code=410, detail="report expired or missing") + return FileResponse( + path=str(p), + media_type="application/pdf", + filename=f"deepshield_report_{analysis_id}.pdf", + ) + + +@router.post("/cleanup") +def cleanup(): + n = cleanup_expired() + return {"deleted": n} diff --git a/video_service.py b/video_service.py new file mode 100644 index 0000000000000000000000000000000000000000..b1334fe682462f59c7eb3486c907ff033142b99e --- /dev/null +++ b/video_service.py @@ -0,0 +1,151 @@ +from __future__ import 
annotations + +from dataclasses import dataclass, field +from typing import List, Tuple + +import cv2 +import numpy as np +from loguru import logger +from PIL import Image + +from models.model_loader import get_model_loader +from services.image_service import classify_image + + +@dataclass +class FrameAnalysis: + index: int + timestamp_s: float + label: str + confidence: float + suspicious_prob: float # prob of the fake/manipulated class + is_suspicious: bool + has_face: bool = False + scored: bool = False # contributed to aggregate (face frames only) + + +@dataclass +class VideoAggregation: + num_frames_sampled: int + num_face_frames: int + num_suspicious_frames: int + mean_suspicious_prob: float + max_suspicious_prob: float + suspicious_ratio: float + insufficient_faces: bool + suspicious_timestamps: List[float] = field(default_factory=list) + frames: List[FrameAnalysis] = field(default_factory=list) + + +FAKE_TOKENS = ("fake", "deepfake", "manipulated", "ai", "generated", "synthetic") + + +def _is_fake_label(label: str) -> bool: + l = label.lower() + return any(tok in l for tok in FAKE_TOKENS) + + +def extract_frames(video_path: str, num_frames: int = 16) -> List[Tuple[int, float, Image.Image]]: + """Uniformly sample num_frames frames from the video. Returns list of + (frame_index, timestamp_seconds, PIL.Image). 
+ """ + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise RuntimeError(f"Failed to open video: {video_path}") + + total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0) + fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0) + if total <= 0: + cap.release() + raise RuntimeError("Video appears to have 0 frames") + + n = min(num_frames, total) + indices = np.linspace(0, max(0, total - 1), num=n, dtype=int).tolist() + + out: List[Tuple[int, float, Image.Image]] = [] + for idx in indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx)) + ok, frame_bgr = cap.read() + if not ok or frame_bgr is None: + continue + frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) + pil = Image.fromarray(frame_rgb) + ts = (idx / fps) if fps > 0 else 0.0 + out.append((int(idx), float(ts), pil)) + + cap.release() + logger.info(f"Extracted {len(out)}/{n} frames from video (total={total}, fps={fps:.2f})") + return out + + +MIN_FACE_FRAMES = 3 # below this we refuse to issue a deepfake verdict + + +def _has_face(pil: Image.Image) -> bool: + detector = get_model_loader().load_face_detector() + arr = np.array(pil) + res = detector.process(arr) + return bool(getattr(res, "multi_face_landmarks", None)) + + +def classify_frames(frames: List[Tuple[int, float, Image.Image]]) -> List[FrameAnalysis]: + results: List[FrameAnalysis] = [] + for idx, ts, pil in frames: + face = _has_face(pil) + clf = classify_image(pil) + fake_prob = 0.0 + for lbl, p in clf.all_scores.items(): + if _is_fake_label(lbl): + fake_prob = max(fake_prob, float(p)) + results.append( + FrameAnalysis( + index=idx, + timestamp_s=ts, + label=clf.label, + confidence=clf.confidence, + suspicious_prob=fake_prob, + is_suspicious=(fake_prob >= 0.5) and face, + has_face=face, + scored=face, + ) + ) + return results + + +def aggregate(frames: List[FrameAnalysis]) -> VideoAggregation: + if not frames: + return VideoAggregation(0, 0, 0, 0.0, 0.0, 0.0, True) + + scored = [f for f in frames if f.scored] + num_face = len(scored) + 
insufficient = num_face < MIN_FACE_FRAMES + + if insufficient: + mean_p = 0.0 + max_p = 0.0 + susp_ratio = 0.0 + susp: List[FrameAnalysis] = [] + else: + probs = [f.suspicious_prob for f in scored] + susp = [f for f in scored if f.is_suspicious] + mean_p = float(np.mean(probs)) + max_p = float(np.max(probs)) + susp_ratio = len(susp) / len(scored) + + return VideoAggregation( + num_frames_sampled=len(frames), + num_face_frames=num_face, + num_suspicious_frames=len(susp), + mean_suspicious_prob=mean_p, + max_suspicious_prob=max_p, + suspicious_ratio=susp_ratio, + insufficient_faces=insufficient, + suspicious_timestamps=[round(f.timestamp_s, 2) for f in susp], + frames=frames, + ) + + +def analyze_video(video_path: str, num_frames: int = 16) -> VideoAggregation: + frames = extract_frames(video_path, num_frames=num_frames) + classified = classify_frames(frames) + return aggregate(classified) diff --git a/vlm_breakdown.py b/vlm_breakdown.py new file mode 100644 index 0000000000000000000000000000000000000000..50ab212b81579d5eeeb475d0191d603fa52809ca --- /dev/null +++ b/vlm_breakdown.py @@ -0,0 +1,138 @@ +"""VLM Detailed Breakdown — Phase 14.1 + +Calls a vision-capable LLM (Gemini or OpenAI) to score 6 perceptual +components of an image for deepfake forensics. Cached per record_id. +""" +from __future__ import annotations + +import json +from io import BytesIO +from typing import Any + +from loguru import logger +from PIL import Image + +from config import settings +from schemas.common import VLMBreakdown, VLMComponentScore + +_cache: dict[str, VLMBreakdown] = {} + +_PROMPT = """\ +You are DeepShield's deepfake forensics engine. Analyze this image and score \ +each component for visual authenticity. 
+ +Output ONLY valid JSON (no markdown fences, no extra text): +{ + "facial_symmetry": {"score": <0-100>, "notes": ""}, + "skin_texture": {"score": <0-100>, "notes": ""}, + "lighting_consistency": {"score": <0-100>, "notes": ""}, + "background_coherence": {"score": <0-100>, "notes": ""}, + "anatomy_hands_eyes": {"score": <0-100>, "notes": ""}, + "context_objects": {"score": <0-100>, "notes": ""} +} + +Scoring rules: +- 100 = perfectly natural/authentic for that component +- 0 = clear manipulation artifact for that component +- Score each independently based only on visual evidence in this image +- If a component is not visible (e.g. no hands present), score 75 and note "not visible in image" +""" + + +def _parse_response(raw: str) -> dict[str, Any]: + text = raw.strip() + if text.startswith("```"): + lines = [ln for ln in text.split("\n") if not ln.strip().startswith("```")] + text = "\n".join(lines).strip() + return json.loads(text) + + +def _to_component(d: Any) -> VLMComponentScore: + if isinstance(d, dict): + return VLMComponentScore( + score=max(0, min(100, int(d.get("score", 75)))), + notes=str(d.get("notes", ""))[:200], + ) + return VLMComponentScore() + + +def _build_breakdown(data: dict[str, Any]) -> VLMBreakdown: + return VLMBreakdown( + facial_symmetry=_to_component(data.get("facial_symmetry")), + skin_texture=_to_component(data.get("skin_texture")), + lighting_consistency=_to_component(data.get("lighting_consistency")), + background_coherence=_to_component(data.get("background_coherence")), + anatomy_hands_eyes=_to_component(data.get("anatomy_hands_eyes")), + context_objects=_to_component(data.get("context_objects")), + ) + + +def generate_vlm_breakdown( + image: Image.Image, + record_id: str | None = None, +) -> VLMBreakdown | None: + """Score 6 perceptual components via vision LLM. 
Returns None when unconfigured.""" + if record_id and record_id in _cache: + cached = _cache[record_id] + cached.cached = True + return cached + + if not settings.LLM_API_KEY: + logger.debug("LLM_API_KEY not set — skipping VLM breakdown") + return None + + provider = settings.LLM_PROVIDER.lower() + model_id = settings.LLM_MODEL + + try: + if provider == "openai": + breakdown = _call_openai(image, model_id) + else: + breakdown = _call_gemini(image, model_id) + + breakdown.model_used = f"{provider}/{model_id}" + if record_id: + _cache[record_id] = breakdown + + logger.info(f"VLM breakdown generated via {provider}/{model_id}") + return breakdown + + except json.JSONDecodeError as e: + logger.error(f"VLM breakdown: unparseable JSON from LLM: {e}") + return None + except Exception as e: + logger.error(f"VLM breakdown failed: {e}") + return None + + +def _call_gemini(image: Image.Image, model_id: str) -> VLMBreakdown: + import google.generativeai as genai # type: ignore + genai.configure(api_key=settings.LLM_API_KEY) + model = genai.GenerativeModel(model_id) + response = model.generate_content([_PROMPT, image]) + return _build_breakdown(_parse_response(response.text)) + + +def _call_openai(image: Image.Image, model_id: str) -> VLMBreakdown: + import base64 + from openai import OpenAI # type: ignore + + buf = BytesIO() + img = image.convert("RGB") + img.save(buf, format="JPEG", quality=85) + b64 = base64.b64encode(buf.getvalue()).decode() + + client = OpenAI(api_key=settings.LLM_API_KEY) + response = client.chat.completions.create( + model=model_id, + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": _PROMPT}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}, + ], + }], + temperature=0.2, + max_tokens=400, + ) + return _build_breakdown(_parse_response(response.choices[0].message.content))