# File: src/app.py
# Commit: a0eab6c "Create src/app.py" (verified), author: y512
import streamlit as st
import cv2
import os
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import timm
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from facenet_pytorch import MTCNN
from deep_sort_realtime.deepsort_tracker import DeepSort
from frequency_analyzer import analyze_fft
# ── ViT Grad-CAM reshape helper ────────────────────────────────────────────── #
# ViT Π²Π½ΡƒΡ‚Ρ€ΠΈ represents an image as a flat sequence of patches:
# [Batch, 1 + 14Γ—14, Embedding_dim] (the leading 1 is the [CLS] token)
# Standard GradCAM expects [Batch, Channels, Height, Width], so we:
# 1. Drop the [CLS] token (index 0)
# 2. Reshape the 196 patch tokens β†’ 14Γ—14 spatial grid
# 3. Permute axes to [B, C, H, W]
# 14 = 224 (input size) / 16 (ViT patch size)
def reshape_transform(tensor, height=14, width=14):
    """Convert ViT token activations into a ``[B, C, H, W]`` map for Grad-CAM.

    Args:
        tensor: Activations indexed as ``[batch, token, embedding]``. Token 0
            (the [CLS] token) is discarded; the remaining ``height * width``
            patch tokens are laid out row by row on the spatial grid.
        height: Number of patch rows (14 for a 224-px input with 16-px patches).
        width: Number of patch columns.

    Returns:
        Tensor of shape ``[batch, embedding, height, width]``.
    """
    batch_size = tensor.size(0)
    embed_dim = tensor.size(2)
    patch_tokens = tensor[:, 1:, :]  # drop [CLS]
    grid = patch_tokens.reshape(batch_size, height, width, embed_dim)  # [B, H, W, C]
    return grid.permute(0, 3, 1, 2)  # [B, C, H, W]
# ─────────────────────────────── PAGE CONFIG ──────────────────────────────── #
st.set_page_config(
page_title="Digital Trust Shield",
page_icon="πŸ›‘οΈ",
layout="wide"
)
# ─────────────────────────────── HEADER UI ────────────────────────────────── #
st.title("πŸ›‘οΈ Digital Trust Shield: Deepfake Detection Engine")
st.markdown(
"Upload a video to run **MTCNN + DeepSORT multi-identity tracking**, "
"**ViT-Small-Patch16-224 inference**, **Grad-CAM explainability**, "
"and **FFT frequency analysis** β€” all in one pipeline."
)
st.divider()
# ─────────────────────────────── CACHED RESOURCES ─────────────────────────── #
@st.cache_resource
def load_model():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=2)
model.load_state_dict(
torch.load("./models/vit_deepfake.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()
return model, device
@st.cache_resource
def load_mtcnn(device):
    """Return a cached MTCNN face detector running on *device*.

    The detector is configured with ``keep_all=True`` because the app tracks
    every face in a frame with DeepSORT, not just a single one.
    """
    detector = MTCNN(keep_all=True, device=device)
    return detector
# Shared, cached inference resources.
model, device = load_model()
mtcnn = load_mtcnn(device)

# ViT-Small-Patch16-224 expects exactly 224×224 input with ImageNet normalisation.
# This must match the val_transform used during training (no augmentation at
# inference time).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.ToPILImage(),  # cv2 ndarray → PIL image
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Strict MTCNN face-confidence gate: weaker detections (shoulders, blurs) are
# discarded before tracking.
PROB_THRESHOLD = 0.95
# ─────────────────────────────── SIDEBAR ──────────────────────────────────── #
st.sidebar.header("βš™οΈ Controls")
uploaded_file = st.sidebar.file_uploader("Upload Video", type=['mp4', 'avi', 'mov'])
st.sidebar.markdown("---")
st.sidebar.markdown("**Detection Settings**")
st.sidebar.markdown(f"- Face prob threshold: `{PROB_THRESHOLD}`")
st.sidebar.markdown("- Tracker: `DeepSORT (max_age=30, n_init=3)`")
st.sidebar.markdown("- Detector: `MTCNN`")
st.sidebar.markdown("- Model: `ViT-Small-Patch16-224`")
# ─────────────────────────────── MAIN FLOW ────────────────────────────────── #
if uploaded_file is not None:
st.video(uploaded_file)
if st.button("πŸ” Analyze Video", type="primary"):
with st.spinner("πŸš€ Initializing DeepSORT + MTCNN Analysis Engine..."):
# ── Save to temp file so OpenCV can read it ──────────────────── #
tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tfile.write(uploaded_file.read())
video_path = tfile.name
tfile.close()
# ── Open video ───────────────────────────────────────────────── #
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
if not fps or fps == 0:
fps = 30.0
frame_interval = max(1, int(round(fps))) # sample 1 fps
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frames_to_sample = max(1, total_frames // frame_interval)
# ── DeepSORT tracker (fresh per video) ───────────────────────── #
tracker = DeepSort(max_age=30, n_init=3)
# Key: track_id β†’ list of (frame_number, confidence, rgb_img, input_tensor)
track_history = {}
progress_bar = st.progress(0, text="Detecting faces…")
frame_count = 0
sampled_count = 0
while True:
ret, frame = cap.read()
if not ret:
break
if frame_count % frame_interval == 0:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
h_img, w_img = frame.shape[:2]
# ── MTCNN Detection ───────────────────────────────────── #
boxes, probs = mtcnn.detect(frame_rgb)
formatted_detections = []
if boxes is not None and probs is not None:
for box, prob in zip(boxes, probs):
if box is None or prob is None:
continue
# ── STRICT FACE FILTER: prob < 0.95 β†’ skip ───── #
if prob < PROB_THRESHOLD:
continue
left = float(box[0])
top = float(box[1])
width = float(box[2] - box[0])
height = float(box[3] - box[1])
formatted_detections.append(
([left, top, width, height], float(prob), 'face')
)
# ── DeepSORT Update ───────────────────────────────────── #
tracks = tracker.update_tracks(formatted_detections, frame=frame)
for track in tracks:
if not track.is_confirmed() or track.time_since_update > 0:
continue
ltrb = track.to_ltrb()
x = max(0, int(ltrb[0]))
y = max(0, int(ltrb[1]))
w = min(w_img - x, int(ltrb[2] - ltrb[0]))
h = min(h_img - y, int(ltrb[3] - ltrb[1]))
if w <= 0 or h <= 0:
continue
# ── Strict Elliptical Blackout Mask ──────────────── #
center = (x + w // 2, y + h // 2)
axes = (max(1, w // 2), max(1, h // 2))
mask = np.zeros_like(frame)
cv2.ellipse(mask, center, axes, 0, 0, 360, (255, 255, 255), -1)
masked_frame = cv2.bitwise_and(frame, mask)
face_crop = masked_frame[y:y+h, x:x+w]
if face_crop.size == 0:
continue
# ── ViT-Small Inference ───────────────────────────── #
# transform already handles Resize(224,224) + Normalize.
# We still build rgb_img for display / Grad-CAM overlay.
face_resized = cv2.resize(face_crop, (224, 224))
rgb_img = cv2.cvtColor(face_resized, cv2.COLOR_BGR2RGB)
input_tensor = transform(rgb_img).unsqueeze(0).to(device)
with torch.no_grad():
out = model(input_tensor)
probs_t = F.softmax(out, dim=1)[0]
fake_conf = probs_t[0].item() * 100 # class 0 = fake
tid = track.track_id
if tid not in track_history:
track_history[tid] = []
track_history[tid].append(
(frame_count, fake_conf, rgb_img, input_tensor)
)
sampled_count += 1
prog = min(sampled_count / frames_to_sample, 1.0)
progress_bar.progress(prog, text=f"Analyzed frame {frame_count} / {total_frames}")
frame_count += 1
cap.release()
progress_bar.empty()
# ── Guard: no tracks ─────────────────────────────────────────────── #
if not track_history:
st.error(
"⚠️ No confirmed faces detected above the 95% confidence threshold. "
"Try a clearer video with a visible, well-lit face."
)
st.stop()
# ─────────────────────── TEMPORAL VARIANCE CHART ─────────────────── #
st.subheader("πŸ“ˆ Temporal Variance Analysis β€” Multi-Identity Tracking")
multi_plot_path = './data/heatmaps/multi_temporal_variance.png'
os.makedirs('./data/heatmaps', exist_ok=True)
fig, ax = plt.subplots(figsize=(12, 5))
ax.axhline(y=50, color='red', linestyle='--', linewidth=1.5, label="50% Threshold", zorder=2)
for tid, history in sorted(track_history.items()):
frames = [item[0] for item in history]
scores = [item[1] for item in history]
ax.plot(frames, scores, marker='o', markersize=3,
linewidth=1.8, label=f"ID: {tid}", zorder=3)
ax.set_ylim(0, 105)
ax.set_xlabel("Frame Number", fontsize=11)
ax.set_ylabel("Fake Confidence (%)", fontsize=11)
ax.set_title("DeepSORT Multi-Identity Deepfake Confidence Over Time", fontsize=13, fontweight='bold')
ax.legend(loc="upper right")
ax.grid(True, linestyle='--', alpha=0.5)
fig.tight_layout()
fig.savefig(multi_plot_path, dpi=150)
plt.close(fig)
st.image(multi_plot_path, use_container_width=True)
# ── Per-ID summary table ─────────────────────────────────────────── #
st.markdown("**Per-Identity Confidence Summary**")
id_summaries = {}
for tid, history in track_history.items():
scores = [item[1] for item in history]
id_summaries[tid] = {
"Frames Tracked": len(scores),
"Avg Confidence (%)": f"{sum(scores)/len(scores):.1f}",
"Max Confidence (%)": f"{max(scores):.1f}",
}
cols = st.columns(min(len(id_summaries), 4))
for i, (tid, summary) in enumerate(sorted(id_summaries.items())):
with cols[i % len(cols)]:
avg = float(summary["Avg Confidence (%)"])
badge = "πŸ”΄ FAKE" if avg > 50 else "🟒 REAL"
st.metric(
label=f"Track ID: {tid} {badge}",
value=f"{avg:.1f}%",
delta=f"{float(summary['Max Confidence (%)']):.1f}% peak",
delta_color="inverse"
)
st.divider()
# ─────────────────────── SMART WORST OFFENDER ────────────────────── #
# Flatten all frames across ALL track IDs, find absolute max confidence
all_frames = [item for history in track_history.values() for item in history]
all_frames.sort(key=lambda x: x[1], reverse=True)
worst_frame_num, worst_conf, worst_img, worst_tensor = all_frames[0]
st.subheader(f"πŸ”¬ Explainable AI β€” Highest Anomaly Frame (Frame #{worst_frame_num}, Conf: {worst_conf:.1f}%)")
# ── Grad-CAM ─────────────────────────────────────────────────────── #
# For ViT, we hook into the norm2 of the last transformer block.
# norm2 sits just before each block's MLP β€” it carries rich, spatially
# distributed semantic activations that Grad-CAM can localise.
try:
target_layers = [model.blocks[-1].norm2]
except (AttributeError, IndexError):
target_layers = []
targets = [ClassifierOutputTarget(0)] # class 0 = fake
with GradCAM(
model=model,
target_layers=target_layers,
reshape_transform=reshape_transform, # ← ViT patch-sequence β†’ 2D spatial map
) as cam:
grayscale_cam = cam(input_tensor=worst_tensor, targets=targets)[0, :]
rgb_float = np.float32(worst_img) / 255.0
visualization = show_cam_on_image(rgb_float, grayscale_cam, use_rgb=True)
# ── FFT Spectrum ──────────────────────────────────────────────────── #
fft_path = analyze_fft(worst_img, save_path='./data/heatmaps/fft_spectrum.png')
col1, col2, col3 = st.columns(3)
with col1:
st.image(worst_img, caption=f"Elliptical Masked Face | Conf: {worst_conf:.1f}%")
with col2:
st.image(visualization, caption="Grad-CAM Heatmap | Manipulation Hotspots")
with col3:
st.image(fft_path, caption="FFT Magnitude Spectrum | Frequency Artifacts")
st.caption(
"πŸ“‘ **Frequency Domain (FFT):** Grid-like patterns or abnormal bright spots "
"in outer frequencies indicate unnatural AI generation artifacts. "
"πŸ”₯ **Grad-CAM:** Red/yellow zones reveal the exact facial regions driving the model's decision."
)
# ─────────────────────── FINAL VERDICT LOGIC ─────────────────────── #
# RULE: If ANY tracked ID has avg confidence > 50% β†’ DEEPFAKE DETECTED
st.divider()
fake_ids = [tid for tid, h in track_history.items()
if (sum(i[1] for i in h) / len(h)) > 50.0]
real_ids = [tid for tid in track_history if tid not in fake_ids]
global_avg = sum(i[1] for i in all_frames) / len(all_frames)
if fake_ids:
st.error(
f"### ⚠️ VERDICT: DEEPFAKE DETECTED\n"
f"- **Compromised IDs:** {', '.join(str(t) for t in fake_ids)}\n"
f"- **Clean IDs:** {', '.join(str(t) for t in real_ids) if real_ids else 'None'}\n"
f"- **Global Average Confidence:** {global_avg:.1f}%"
)
else:
st.success(
f"### βœ… VERDICT: AUTHENTIC VIDEO\n"
f"- **All Tracked IDs passed:** {', '.join(str(t) for t in track_history.keys())}\n"
f"- **Global Average Confidence:** {global_avg:.1f}%"
)