# Digital Trust Shield — Streamlit front end for video deepfake screening.
#
# Pipeline for an uploaded video:
#   1. Sample roughly one frame per second.
#   2. Detect faces with MTCNN and associate them across samples with DeepSORT.
#   3. Classify each tracked face crop with a 2-class ViT-Small-Patch16-224.
#   4. Plot the per-identity "fake" confidence over time.
#   5. Explain the single most suspicious crop with Grad-CAM and an FFT spectrum.
#   6. Report a verdict: any identity whose mean confidence exceeds
#      FAKE_THRESHOLD marks the video as a deepfake.

import streamlit as st
import cv2
import os
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import matplotlib
matplotlib.use('Agg')  # select the headless backend before pyplot is imported
import matplotlib.pyplot as plt
import timm
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from facenet_pytorch import MTCNN
from deep_sort_realtime.deepsort_tracker import DeepSort
from frequency_analyzer import analyze_fft


# ── ViT Grad-CAM reshape helper ────────────────────────────────────────────── #
# ViT internally represents an image as a flat sequence of patches:
#   [Batch, 1 + 14×14, Embedding_dim]   (the leading 1 is the [CLS] token)
# Standard GradCAM expects [Batch, Channels, Height, Width], so we:
#   1. Drop the [CLS] token (index 0)
#   2. Reshape the 196 patch tokens → 14×14 spatial grid
#   3. Permute axes to [B, C, H, W]
# 14 = 224 (input size) / 16 (ViT patch size)
def reshape_transform(tensor, height=14, width=14):
    """Convert a ViT token sequence into a 2D feature map for Grad-CAM.

    Args:
        tensor: Activations indexed as [batch, tokens, embedding]; the first
            token is dropped and the rest must number ``height * width``.
        height: Number of patch rows in the spatial grid.
        width: Number of patch columns in the spatial grid.

    Returns:
        The patch activations arranged as [batch, embedding, height, width].
    """
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    result = result.transpose(2, 3).transpose(1, 2)  # [B, H, W, C] → [B, C, H, W]
    return result


# ─────────────────────────────── PAGE CONFIG ──────────────────────────────── #
st.set_page_config(
    page_title="Digital Trust Shield",
    page_icon="🛡️",
    layout="wide"
)

# ─────────────────────────────── HEADER UI ────────────────────────────────── #
st.title("🛡️ Digital Trust Shield: Deepfake Detection Engine")
st.markdown(
    "Upload a video to run **MTCNN + DeepSORT multi-identity tracking**, "
    "**ViT-Small-Patch16-224 inference**, **Grad-CAM explainability**, "
    "and **FFT frequency analysis** — all in one pipeline."
)
st.divider()


# ─────────────────────────────── CACHED RESOURCES ─────────────────────────── #
@st.cache_resource
def load_model():
    """Build the 2-class ViT-Small classifier and load its fine-tuned weights.

    Returns:
        A ``(model, device)`` tuple; the model is in eval mode on ``device``
        (CUDA when available, otherwise CPU).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=2)
    model.load_state_dict(
        torch.load("./models/vit_deepfake.pth", map_location=device, weights_only=True)
    )
    model.to(device)
    model.eval()
    return model, device


@st.cache_resource
def load_mtcnn(device):
    """Create the MTCNN face detector on ``device``."""
    # keep_all=True so we detect ALL faces per frame for DeepSORT
    return MTCNN(keep_all=True, device=device)


model, device = load_model()
mtcnn = load_mtcnn(device)

# ViT-Small-Patch16-224 requires exactly 224×224 input with ImageNet normalisation.
# This must match the val_transform used during training (no augmentation at inference time).
transform = transforms.Compose([
    transforms.ToPILImage(),  # cv2 ndarray → PIL
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

PROB_THRESHOLD = 0.95  # Strict face-confidence gate — no shoulders/blurs
FAKE_THRESHOLD = 50.0  # Mean fake confidence (%) above which an identity is flagged

# ─────────────────────────────── SIDEBAR ──────────────────────────────────── #
st.sidebar.header("⚙️ Controls")
uploaded_file = st.sidebar.file_uploader("Upload Video", type=['mp4', 'avi', 'mov'])
st.sidebar.markdown("---")
st.sidebar.markdown("**Detection Settings**")
st.sidebar.markdown(f"- Face prob threshold: `{PROB_THRESHOLD}`")
st.sidebar.markdown("- Tracker: `DeepSORT (max_age=30, n_init=3)`")
st.sidebar.markdown("- Detector: `MTCNN`")
st.sidebar.markdown("- Model: `ViT-Small-Patch16-224`")

# ─────────────────────────────── MAIN FLOW ────────────────────────────────── #
if uploaded_file is not None:
    st.video(uploaded_file)

    if st.button("🔍 Analyze Video", type="primary"):
        # Key: track_id → list of (frame_number, fake_confidence)
        track_history = {}
        # Only the most suspicious sample needs its image and tensor kept for
        # the explainability panel: (frame_number, confidence, rgb_img, input_tensor)
        worst_sample = None

        with st.spinner("🚀 Initializing DeepSORT + MTCNN Analysis Engine..."):
            # ── Save to temp file so OpenCV can read it ──────────────────── #
            # getvalue() returns the full upload regardless of the stream
            # position, which st.video() above may have advanced.
            suffix = os.path.splitext(uploaded_file.name)[1] or ".mp4"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
                tfile.write(uploaded_file.getvalue())
                video_path = tfile.name

            # ── Open video ───────────────────────────────────────────────── #
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    st.error("⚠️ Could not open the uploaded video. "
                             "The file may be corrupt or use an unsupported codec.")
                    st.stop()

                fps = cap.get(cv2.CAP_PROP_FPS)
                if not fps or fps <= 0:
                    fps = 30.0  # container did not report a usable frame rate
                frame_interval = max(1, int(round(fps)))  # sample 1 fps
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                frames_to_sample = max(1, total_frames // frame_interval)

                # ── DeepSORT tracker (fresh per video) ───────────────────── #
                tracker = DeepSort(max_age=30, n_init=3)

                progress_bar = st.progress(0, text="Detecting faces…")
                frame_count = 0
                sampled_count = 0

                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    if frame_count % frame_interval == 0:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        h_img, w_img = frame.shape[:2]

                        # ── MTCNN Detection ──────────────────────────────── #
                        boxes, probs = mtcnn.detect(frame_rgb)
                        formatted_detections = []

                        if boxes is not None and probs is not None:
                            for box, prob in zip(boxes, probs):
                                if box is None or prob is None:
                                    continue
                                # ── STRICT FACE FILTER: below threshold → skip ── #
                                if prob < PROB_THRESHOLD:
                                    continue
                                left = float(box[0])
                                top = float(box[1])
                                width = float(box[2] - box[0])
                                height = float(box[3] - box[1])
                                # DeepSORT expects ([left, top, w, h], confidence, class)
                                formatted_detections.append(
                                    ([left, top, width, height], float(prob), 'face')
                                )

                        # ── DeepSORT Update ──────────────────────────────── #
                        tracks = tracker.update_tracks(formatted_detections, frame=frame)

                        for track in tracks:
                            # Only score tracks that are confirmed AND were
                            # matched to a detection in this very frame.
                            if not track.is_confirmed() or track.time_since_update > 0:
                                continue

                            # Clamp BOTH corners to the frame so a box that
                            # overhangs the left/top edge is trimmed rather
                            # than shifted.
                            box_l, box_t, box_r, box_b = track.to_ltrb()
                            x1 = max(0, int(box_l))
                            y1 = max(0, int(box_t))
                            x2 = min(w_img, int(box_r))
                            y2 = min(h_img, int(box_b))
                            w = x2 - x1
                            h = y2 - y1
                            if w <= 0 or h <= 0:
                                continue

                            # ── Strict Elliptical Blackout Mask ─────────── #
                            # Draw the ellipse on a crop-sized mask instead of
                            # a full-frame one: same pixels, far less allocation.
                            face = np.ascontiguousarray(frame[y1:y2, x1:x2])
                            mask = np.zeros(face.shape, dtype=face.dtype)
                            center = (w // 2, h // 2)
                            axes = (max(1, w // 2), max(1, h // 2))
                            cv2.ellipse(mask, center, axes, 0, 0, 360, (255, 255, 255), -1)
                            face_crop = cv2.bitwise_and(face, mask)

                            # ── ViT-Small Inference ─────────────────────── #
                            # transform already handles Resize(224,224) + Normalize.
                            # We still build rgb_img for display / Grad-CAM overlay.
                            face_resized = cv2.resize(face_crop, (224, 224))
                            rgb_img = cv2.cvtColor(face_resized, cv2.COLOR_BGR2RGB)
                            input_tensor = transform(rgb_img).unsqueeze(0).to(device)

                            with torch.no_grad():
                                out = model(input_tensor)
                                probs_t = F.softmax(out, dim=1)[0]
                                fake_conf = probs_t[0].item() * 100  # class 0 = fake

                            track_history.setdefault(track.track_id, []).append(
                                (frame_count, fake_conf)
                            )
                            if worst_sample is None or fake_conf > worst_sample[1]:
                                worst_sample = (frame_count, fake_conf, rgb_img, input_tensor)

                        sampled_count += 1
                        prog = min(sampled_count / frames_to_sample, 1.0)
                        progress_bar.progress(
                            prog, text=f"Analyzed frame {frame_count} / {total_frames}"
                        )

                    frame_count += 1

                progress_bar.empty()
            finally:
                cap.release()
                # Best-effort removal of the temp copy; a failure here (e.g. a
                # lingering OS-level lock) must not mask the analysis result.
                try:
                    os.unlink(video_path)
                except OSError:
                    pass

        # ── Guard: no tracks ─────────────────────────────────────────────── #
        if not track_history:
            st.error(
                f"⚠️ No confirmed faces detected above the {PROB_THRESHOLD:.0%} confidence threshold. "
                "Try a clearer video with a visible, well-lit face."
            )
            st.stop()

        # ─────────────────────── TEMPORAL VARIANCE CHART ─────────────────── #
        st.subheader("📈 Temporal Variance Analysis — Multi-Identity Tracking")

        multi_plot_path = './data/heatmaps/multi_temporal_variance.png'
        os.makedirs('./data/heatmaps', exist_ok=True)

        fig, ax = plt.subplots(figsize=(12, 5))
        ax.axhline(y=50, color='red', linestyle='--', linewidth=1.5,
                   label="50% Threshold", zorder=2)

        for tid, history in sorted(track_history.items()):
            frames = [item[0] for item in history]
            scores = [item[1] for item in history]
            ax.plot(frames, scores, marker='o', markersize=3, linewidth=1.8,
                    label=f"ID: {tid}", zorder=3)

        ax.set_ylim(0, 105)
        ax.set_xlabel("Frame Number", fontsize=11)
        ax.set_ylabel("Fake Confidence (%)", fontsize=11)
        ax.set_title("DeepSORT Multi-Identity Deepfake Confidence Over Time",
                     fontsize=13, fontweight='bold')
        ax.legend(loc="upper right")
        ax.grid(True, linestyle='--', alpha=0.5)
        fig.tight_layout()
        fig.savefig(multi_plot_path, dpi=150)
        plt.close(fig)

        st.image(multi_plot_path, use_container_width=True)

        # ── Per-ID summary table ─────────────────────────────────────────── #
        st.markdown("**Per-Identity Confidence Summary**")

        # Statistics stay numeric so the badge below and the final verdict are
        # decided on the same unrounded average.
        id_summaries = {}
        for tid, history in track_history.items():
            scores = [item[1] for item in history]
            id_summaries[tid] = {
                "Frames Tracked": len(scores),
                "Avg Confidence (%)": sum(scores) / len(scores),
                "Max Confidence (%)": max(scores),
            }

        cols = st.columns(min(len(id_summaries), 4))
        for i, (tid, summary) in enumerate(sorted(id_summaries.items())):
            with cols[i % len(cols)]:
                avg = summary["Avg Confidence (%)"]
                badge = "🔴 FAKE" if avg > FAKE_THRESHOLD else "🟢 REAL"
                st.metric(
                    label=f"Track ID: {tid}  {badge}",
                    value=f"{avg:.1f}%",
                    delta=f"{summary['Max Confidence (%)']:.1f}% peak",
                    delta_color="inverse"
                )

        st.divider()

        # ─────────────────────── SMART WORST OFFENDER ────────────────────── #
        # The highest-confidence sample across ALL track IDs was tracked while
        # scanning the video.
        worst_frame_num, worst_conf, worst_img, worst_tensor = worst_sample

        st.subheader(f"🔬 Explainable AI — Highest Anomaly Frame (Frame #{worst_frame_num}, Conf: {worst_conf:.1f}%)")

        # ── Grad-CAM ─────────────────────────────────────────────────────── #
        # Hook norm1 of the last transformer block, i.e. BEFORE the final
        # attention. With class-token classification the patch tokens that
        # come out of the last attention never reach the logits, so any layer
        # after it (such as norm2) has zero gradient on the patches and yields
        # a blank heatmap.
        try:
            target_layers = [model.blocks[-1].norm1]
        except (AttributeError, IndexError):
            target_layers = None

        visualization = None
        if target_layers:
            targets = [ClassifierOutputTarget(0)]  # class 0 = fake
            with GradCAM(
                model=model,
                target_layers=target_layers,
                reshape_transform=reshape_transform,  # ← ViT patch-sequence → 2D spatial map
            ) as cam:
                grayscale_cam = cam(input_tensor=worst_tensor, targets=targets)[0, :]

            rgb_float = np.float32(worst_img) / 255.0
            visualization = show_cam_on_image(rgb_float, grayscale_cam, use_rgb=True)

        # ── FFT Spectrum ──────────────────────────────────────────────────── #
        fft_path = analyze_fft(worst_img, save_path='./data/heatmaps/fft_spectrum.png')

        col1, col2, col3 = st.columns(3)
        with col1:
            st.image(worst_img,
                     caption=f"Elliptical Masked Face | Conf: {worst_conf:.1f}%")
        with col2:
            if visualization is not None:
                st.image(visualization,
                         caption="Grad-CAM Heatmap | Manipulation Hotspots")
            else:
                st.warning("Grad-CAM is unavailable: the model exposes no transformer blocks to hook.")
        with col3:
            st.image(fft_path,
                     caption="FFT Magnitude Spectrum | Frequency Artifacts")

        st.caption(
            "📡 **Frequency Domain (FFT):** Grid-like patterns or abnormal bright spots "
            "in outer frequencies indicate unnatural AI generation artifacts. "
            "🔥 **Grad-CAM:** Red/yellow zones reveal the exact facial regions driving the model's decision."
        )

        # ─────────────────────── FINAL VERDICT LOGIC ─────────────────────── #
        # RULE: If ANY tracked ID has avg confidence > FAKE_THRESHOLD → DEEPFAKE DETECTED
        st.divider()

        fake_ids = [tid for tid, summary in id_summaries.items()
                    if summary["Avg Confidence (%)"] > FAKE_THRESHOLD]
        flagged = set(fake_ids)
        real_ids = [tid for tid in track_history if tid not in flagged]
        all_scores = [item[1] for history in track_history.values() for item in history]
        global_avg = sum(all_scores) / len(all_scores)

        if fake_ids:
            st.error(
                f"### ⚠️ VERDICT: DEEPFAKE DETECTED\n"
                f"- **Compromised IDs:** {', '.join(str(t) for t in fake_ids)}\n"
                f"- **Clean IDs:** {', '.join(str(t) for t in real_ids) if real_ids else 'None'}\n"
                f"- **Global Average Confidence:** {global_avg:.1f}%"
            )
        else:
            st.success(
                f"### ✅ VERDICT: AUTHENTIC VIDEO\n"
                f"- **All Tracked IDs passed:** {', '.join(str(t) for t in track_history)}\n"
                f"- **Global Average Confidence:** {global_avg:.1f}%"
            )