Spaces:
No application file
No application file
| import streamlit as st | |
| import cv2 | |
| import os | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import timm | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from facenet_pytorch import MTCNN | |
| from deep_sort_realtime.deepsort_tracker import DeepSort | |
| from frequency_analyzer import analyze_fft | |
# ── ViT Grad-CAM reshape helper ──────────────────────────────────────────── #
def reshape_transform(tensor, height=14, width=14):
    """Turn a ViT token sequence into a CNN-style activation map for Grad-CAM.

    ViT internally represents an image as a flat sequence of patch tokens of
    shape ``[B, 1 + height*width, C]`` where the leading token is [CLS].
    Grad-CAM expects ``[B, C, H, W]``, so this helper:

    1. drops the [CLS] token (index 0 on the token axis),
    2. folds the remaining tokens into a ``height × width`` grid,
    3. moves the embedding axis in front of the spatial axes.

    The defaults (14 × 14) correspond to a 224-pixel input split into
    16-pixel patches (224 / 16 = 14).

    Args:
        tensor: Activations of shape ``[B, 1 + height*width, C]``.
        height: Number of patch rows.
        width: Number of patch columns.

    Returns:
        A view of the activations with shape ``[B, C, height, width]``.
    """
    batch_size = tensor.size(0)
    embed_dim = tensor.size(2)
    patch_tokens = tensor[:, 1:, :]  # everything except [CLS]
    grid = patch_tokens.reshape(batch_size, height, width, embed_dim)
    # [B, H, W, C] → [B, C, H, W]
    return grid.permute(0, 3, 1, 2)
# ─────────────────────────────── PAGE CONFIG ──────────────────────────────── #
# NOTE: st.set_page_config has to run before any other Streamlit command,
# so keep it ahead of every other `st.` call in this script.
st.set_page_config(
    page_title="Digital Trust Shield",
    page_icon="🛡️",
    layout="wide",
)

# ─────────────────────────────── HEADER UI ────────────────────────────────── #
st.title("🛡️ Digital Trust Shield: Deepfake Detection Engine")
st.markdown(
    "Upload a video to run **MTCNN + DeepSORT multi-identity tracking**, "
    "**ViT-Small-Patch16-224 inference**, **Grad-CAM explainability**, "
    "and **FFT frequency analysis** — all in one pipeline."
)
st.divider()
# ─────────────────────────────── CACHED RESOURCES ─────────────────────────── #
@st.cache_resource
def load_model():
    """Build the ViT-Small classifier, load the fine-tuned weights and cache it.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` keeps a single model instance alive across reruns
    instead of re-reading the checkpoint from disk each time.

    Returns:
        tuple: ``(model, device)`` — the model in eval mode, already moved to
        ``device`` (CUDA when available, otherwise CPU).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=2)
    # weights_only=True restricts unpickling to tensors/primitive containers.
    state_dict = torch.load(
        "./models/vit_deepfake.pth", map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, device
# torch.device is hashed through its string form ("cpu", "cuda", ...) so the
# cache key is stable and Streamlit does not have to introspect the object.
@st.cache_resource(hash_funcs={torch.device: str})
def load_mtcnn(device):
    """Create (once per device) the MTCNN face detector.

    Args:
        device: ``torch.device`` the detector should run on.

    Returns:
        An ``MTCNN`` instance configured with ``keep_all=True`` so that every
        face in a frame is reported (needed for multi-identity tracking).
    """
    # keep_all=True so we detect ALL faces per frame for DeepSORT
    return MTCNN(keep_all=True, device=device)
# Shared inference objects: classifier + device first, then the face detector
# placed on that same device.
model, device = load_model()
mtcnn = load_mtcnn(device)

# Pre-processing for ViT-Small-Patch16-224: a fixed 224×224 input with ImageNet
# normalisation and no augmentation. Per the original note this is meant to
# mirror the validation transform used at training time.
_preprocess_steps = [
    transforms.ToPILImage(),  # cv2 ndarray → PIL
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# Minimum MTCNN detection probability for a box to be treated as a face
# (strict gate — filters out shoulders and blurry detections).
PROB_THRESHOLD = 0.95
# ─────────────────────────────── SIDEBAR ──────────────────────────────────── #
st.sidebar.header("⚙️ Controls")
uploaded_file = st.sidebar.file_uploader("Upload Video", type=['mp4', 'avi', 'mov'])
st.sidebar.markdown("---")
st.sidebar.markdown("**Detection Settings**")
# Read-only summary of the pipeline configuration, one markdown element each.
for _setting_line in (
    f"- Face prob threshold: `{PROB_THRESHOLD}`",
    "- Tracker: `DeepSORT (max_age=30, n_init=3)`",
    "- Detector: `MTCNN`",
    "- Model: `ViT-Small-Patch16-224`",
):
    st.sidebar.markdown(_setting_line)
# ─────────────────────────────── MAIN FLOW ────────────────────────────────── #
# Average fake-confidence (%) above which a tracked identity is flagged.
FAKE_THRESHOLD = 50.0
HEATMAP_DIR = './data/heatmaps'


def _detect_faces(frame_rgb):
    """Run MTCNN on an RGB frame and return DeepSORT-style detections.

    Each detection is ``([left, top, width, height], probability, 'face')``.
    Boxes whose probability is below ``PROB_THRESHOLD`` are dropped.
    """
    boxes, probs = mtcnn.detect(frame_rgb)
    if boxes is None or probs is None:
        return []
    detections = []
    for box, prob in zip(boxes, probs):
        if box is None or prob is None or prob < PROB_THRESHOLD:
            continue
        left, top, right, bottom = (float(v) for v in box[:4])
        detections.append(
            ([left, top, right - left, bottom - top], float(prob), 'face')
        )
    return detections


def _clip_box(ltrb, frame_w, frame_h):
    """Clip a (left, top, right, bottom) box to the frame; return (x, y, w, h).

    All four edges are clipped independently, so a box that starts off-screen
    keeps its true right/bottom edge instead of being shifted into the frame.
    ``w``/``h`` can be <= 0 when the box lies entirely outside the frame.
    """
    x1 = max(0, int(ltrb[0]))
    y1 = max(0, int(ltrb[1]))
    x2 = min(frame_w, int(ltrb[2]))
    y2 = min(frame_h, int(ltrb[3]))
    return x1, y1, x2 - x1, y2 - y1


def _elliptical_face_crop(frame, x, y, w, h):
    """Return the BGR crop ``frame[y:y+h, x:x+w]`` blacked out outside an ellipse.

    The ellipse is inscribed in the crop; building the mask on the crop rather
    than on the full frame yields the same pixels without a per-track
    full-frame allocation.
    """
    crop = frame[y:y + h, x:x + w]
    mask = np.zeros_like(crop)
    centre = (w // 2, h // 2)
    semi_axes = (max(1, w // 2), max(1, h // 2))
    cv2.ellipse(mask, centre, semi_axes, 0, 0, 360, (255, 255, 255), -1)
    return cv2.bitwise_and(crop, mask)


def _score_face(face_bgr):
    """Classify one BGR face crop.

    Returns:
        tuple: ``(fake_confidence_percent, rgb_224_image, input_tensor)`` where
        the image is the 224×224 RGB array used for display / overlays and the
        tensor is the normalised model input on ``device``.
    """
    rgb_img = cv2.cvtColor(cv2.resize(face_bgr, (224, 224)), cv2.COLOR_BGR2RGB)
    input_tensor = transform(rgb_img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_tensor)
    fake_conf = F.softmax(logits, dim=1)[0, 0].item() * 100  # class 0 = fake
    return fake_conf, rgb_img, input_tensor


def _analyze_video(cap, progress_bar):
    """Sample ~1 frame per second, track faces and score every confirmed track.

    Args:
        cap: An opened ``cv2.VideoCapture``.
        progress_bar: Streamlit progress element updated after each sampled frame.

    Returns:
        tuple: ``(track_history, worst)``.
        ``track_history`` maps track_id → list of ``(frame_number, fake_conf)``.
        ``worst`` is ``(frame_number, fake_conf, rgb_img, input_tensor)`` for the
        highest-scoring face seen, or ``None`` when nothing was scored. Only
        this single frame's pixels/tensor are retained, which keeps memory
        flat regardless of video length.
    """
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not np.isfinite(fps) or fps <= 0:
        fps = 30.0  # container did not report a usable frame rate
    frame_interval = max(1, int(round(fps)))  # sample 1 fps
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames_to_sample = max(1, total_frames // frame_interval)

    tracker = DeepSort(max_age=30, n_init=3)  # fresh tracker per video
    track_history = {}
    worst = None
    frame_count = 0
    sampled_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count % frame_interval == 0:
            frame_h, frame_w = frame.shape[:2]
            detections = _detect_faces(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            for track in tracker.update_tracks(detections, frame=frame):
                # Only tracks that are confirmed AND matched in this very frame.
                if not track.is_confirmed() or track.time_since_update > 0:
                    continue
                x, y, w, h = _clip_box(track.to_ltrb(), frame_w, frame_h)
                if w <= 0 or h <= 0:
                    continue
                face_crop = _elliptical_face_crop(frame, x, y, w, h)
                fake_conf, rgb_img, input_tensor = _score_face(face_crop)
                track_history.setdefault(track.track_id, []).append(
                    (frame_count, fake_conf)
                )
                if worst is None or fake_conf > worst[1]:
                    worst = (frame_count, fake_conf, rgb_img, input_tensor)
            sampled_count += 1
            progress_bar.progress(
                min(sampled_count / frames_to_sample, 1.0),
                text=f"Analyzed frame {frame_count} / {total_frames}",
            )
        frame_count += 1
    return track_history, worst


def _render_temporal_chart(track_history):
    """Plot per-identity fake confidence over frame number and display it."""
    multi_plot_path = f'{HEATMAP_DIR}/multi_temporal_variance.png'
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.axhline(y=FAKE_THRESHOLD, color='red', linestyle='--', linewidth=1.5,
               label="50% Threshold", zorder=2)
    for tid, history in sorted(track_history.items()):
        frames = [frame_no for frame_no, _ in history]
        scores = [score for _, score in history]
        ax.plot(frames, scores, marker='o', markersize=3,
                linewidth=1.8, label=f"ID: {tid}", zorder=3)
    ax.set_ylim(0, 105)
    ax.set_xlabel("Frame Number", fontsize=11)
    ax.set_ylabel("Fake Confidence (%)", fontsize=11)
    ax.set_title("DeepSORT Multi-Identity Deepfake Confidence Over Time",
                 fontsize=13, fontweight='bold')
    ax.legend(loc="upper right")
    ax.grid(True, linestyle='--', alpha=0.5)
    fig.tight_layout()
    fig.savefig(multi_plot_path, dpi=150)
    plt.close(fig)
    st.image(multi_plot_path, use_container_width=True)


def _render_identity_summary(stats):
    """Show one metric tile per track (up to four columns, wrapping)."""
    st.markdown("**Per-Identity Confidence Summary**")
    cols = st.columns(min(len(stats), 4))
    for i, (tid, (avg, peak)) in enumerate(sorted(stats.items())):
        with cols[i % len(cols)]:
            # Same raw-average comparison as the final verdict, so the badge
            # and the verdict can never disagree.
            badge = "🔴 FAKE" if avg > FAKE_THRESHOLD else "🟢 REAL"
            st.metric(
                label=f"Track ID: {tid} {badge}",
                value=f"{avg:.1f}%",
                delta=f"{peak:.1f}% peak",
                delta_color="inverse",
            )


def _render_explainability(worst):
    """Show the worst-scoring face next to its Grad-CAM overlay and FFT spectrum."""
    worst_frame_num, worst_conf, worst_img, worst_tensor = worst
    st.subheader(
        f"🔬 Explainable AI — Highest Anomaly Frame "
        f"(Frame #{worst_frame_num}, Conf: {worst_conf:.1f}%)"
    )

    # ── Grad-CAM ─────────────────────────────────────────────────────────── #
    # Hook norm1 of the LAST transformer block. Everything after that block's
    # attention (norm2 → MLP → final norm) is token-wise and, with timm's
    # default [CLS]-token pooling, the head reads only the [CLS] token — so
    # gradients w.r.t. patch tokens at norm2 are zero and the map is blank.
    # At norm1 the patch tokens still reach [CLS] through the attention layer.
    try:
        target_layers = [model.blocks[-1].norm1]
    except (AttributeError, IndexError):
        target_layers = []

    visualization = None
    if target_layers:
        targets = [ClassifierOutputTarget(0)]  # class 0 = fake
        with GradCAM(
            model=model,
            target_layers=target_layers,
            reshape_transform=reshape_transform,  # ViT patch sequence → 2D map
        ) as cam:
            grayscale_cam = cam(input_tensor=worst_tensor, targets=targets)[0, :]
        rgb_float = np.float32(worst_img) / 255.0
        visualization = show_cam_on_image(rgb_float, grayscale_cam, use_rgb=True)

    # ── FFT Spectrum ─────────────────────────────────────────────────────── #
    fft_path = analyze_fft(worst_img, save_path=f'{HEATMAP_DIR}/fft_spectrum.png')

    col1, col2, col3 = st.columns(3)
    with col1:
        st.image(worst_img, caption=f"Elliptical Masked Face | Conf: {worst_conf:.1f}%")
    with col2:
        if visualization is None:
            st.warning("Grad-CAM is unavailable: the model exposes no transformer blocks to hook.")
        else:
            st.image(visualization, caption="Grad-CAM Heatmap | Manipulation Hotspots")
    with col3:
        st.image(fft_path, caption="FFT Magnitude Spectrum | Frequency Artifacts")
    st.caption(
        "💡 **Frequency Domain (FFT):** Grid-like patterns or abnormal bright spots "
        "in outer frequencies indicate unnatural AI generation artifacts. "
        "🔥 **Grad-CAM:** Red/yellow zones reveal the exact facial regions driving the model's decision."
    )


def _render_verdict(stats, global_avg):
    """RULE: if ANY tracked ID has an average confidence > 50% → deepfake."""
    fake_ids = [tid for tid, (avg, _) in stats.items() if avg > FAKE_THRESHOLD]
    real_ids = [tid for tid, (avg, _) in stats.items() if avg <= FAKE_THRESHOLD]
    if fake_ids:
        st.error(
            f"### ⚠️ VERDICT: DEEPFAKE DETECTED\n"
            f"- **Compromised IDs:** {', '.join(str(t) for t in fake_ids)}\n"
            f"- **Clean IDs:** {', '.join(str(t) for t in real_ids) if real_ids else 'None'}\n"
            f"- **Global Average Confidence:** {global_avg:.1f}%"
        )
    else:
        st.success(
            f"### ✅ VERDICT: AUTHENTIC VIDEO\n"
            f"- **All Tracked IDs passed:** {', '.join(str(t) for t in stats)}\n"
            f"- **Global Average Confidence:** {global_avg:.1f}%"
        )


if uploaded_file is not None:
    st.video(uploaded_file)
    if st.button("🚀 Analyze Video", type="primary"):
        with st.spinner("🔍 Initializing DeepSORT + MTCNN Analysis Engine..."):
            # ── Save to a temp file so OpenCV can read it ────────────────── #
            # Keep the upload's own extension; getvalue() is independent of
            # the buffer's current read position.
            suffix = os.path.splitext(uploaded_file.name)[1] or ".mp4"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
                tfile.write(uploaded_file.getvalue())
                video_path = tfile.name

            cap = cv2.VideoCapture(video_path)
            progress_bar = st.progress(0, text="Detecting faces…")
            try:
                if not cap.isOpened():
                    st.error("Could not open the uploaded video. The file may be corrupt or use an unsupported codec.")
                    st.stop()
                track_history, worst = _analyze_video(cap, progress_bar)
            finally:
                # Runs on success, on errors and on st.stop(): release the
                # decoder first, then delete the temp copy of the upload.
                cap.release()
                progress_bar.empty()
                if os.path.exists(video_path):
                    os.remove(video_path)

        # ── Guard: no tracks ─────────────────────────────────────────────── #
        if not track_history:
            st.error(
                f"⚠️ No confirmed faces detected above the {PROB_THRESHOLD:.0%} confidence threshold. "
                "Try a clearer video with a visible, well-lit face."
            )
            st.stop()

        os.makedirs(HEATMAP_DIR, exist_ok=True)

        # ─────────────────────── TEMPORAL VARIANCE CHART ─────────────────── #
        st.subheader("📈 Temporal Variance Analysis — Multi-Identity Tracking")
        _render_temporal_chart(track_history)

        # ── Per-ID summary ───────────────────────────────────────────────── #
        # track_id → (average fake confidence, peak fake confidence)
        track_stats = {}
        all_scores = []
        for tid, history in track_history.items():
            scores = [score for _, score in history]
            track_stats[tid] = (sum(scores) / len(scores), max(scores))
            all_scores.extend(scores)
        _render_identity_summary(track_stats)
        st.divider()

        # ─────────────────────── SMART WORST OFFENDER ────────────────────── #
        _render_explainability(worst)

        # ─────────────────────── FINAL VERDICT LOGIC ─────────────────────── #
        st.divider()
        _render_verdict(track_stats, sum(all_scores) / len(all_scores))