# Deepfake / app.py
# aa1223's picture — Update app.py (commit d058672, verified)
import gradio as gr
import webbrowser
from threading import Timer
import torch
from facenet_pytorch import InceptionResnetV1
import cv2
from PIL import Image, ImageOps
import numpy as np
import warnings
warnings.filterwarnings("ignore")
# Prefer GPU inference when CUDA is available; otherwise run on CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# InceptionResnetV1 backbone (pretrained on VGGFace2) with a single-logit
# classification head, moved to the target device in eval mode.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1).to(DEVICE).eval()

checkpoint_path = "resnetinceptionv1_epoch_32.pth"
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

# Checkpoints may wrap the weights under 'model_state_dict' or store them
# at the top level; accept either layout.
state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint

try:
    model.load_state_dict(state_dict)
    print("Model weights loaded successfully.")
except RuntimeError as e:
    print(f"Error loading model weights: {e}")
def create_montage(frames, size=(512, 512)):
    """Tile the given frames into a square-grid montage image.

    Parameters
    ----------
    frames : list
        PIL images to tile. May be empty, in which case a blank (black)
        canvas of ``size`` is returned.
    size : tuple[int, int]
        (width, height) of the output montage in pixels.

    Returns
    -------
    PIL.Image.Image
        RGB montage of the requested size.
    """
    montage = Image.new('RGB', size)
    num_images = len(frames)
    if num_images == 0:
        # Guard: ceil(sqrt(0)) == 0 would trigger a ZeroDivisionError in the
        # thumbnail-size computation below. Return the blank canvas instead.
        return montage
    # Smallest square grid that fits every frame (e.g. 9 frames -> 3x3).
    montage_grid = int(np.ceil(np.sqrt(num_images)))
    thumb_size = (size[0] // montage_grid, size[1] // montage_grid)
    for i, frame in enumerate(frames):
        # ImageOps.fit crops/resizes each frame to exactly fill its cell.
        thumbnail = ImageOps.fit(frame, thumb_size, Image.Resampling.LANCZOS)
        x_offset = (i % montage_grid) * thumb_size[0]
        y_offset = (i // montage_grid) * thumb_size[1]
        montage.paste(thumbnail, (x_offset, y_offset))
    return montage
def predict(input_video):
    """Sample up to 9 evenly spaced frames from a video and return a
    placeholder classification plus a montage of the sampled frames.

    NOTE(review): the globally loaded deepfake model is never invoked here —
    the label and ratio below are random placeholders; only the montage is
    derived from the actual video.

    Parameters
    ----------
    input_video : str
        Filesystem path to the uploaded video (supplied by gr.Video).

    Returns
    -------
    tuple
        (label str, detail str, PIL.Image montage).
    """
    cap = cv2.VideoCapture(input_video)
    selected_frames = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Guard against unreadable/empty videos: linspace over a non-positive
        # range would yield meaningless negative indices.
        if total_frames > 0:
            frame_indices = set(np.linspace(0, total_frames - 1, 9, dtype=int))
            last_needed = max(frame_indices)
            for i in range(total_frames):
                ret, frame = cap.read()
                if not ret:
                    break
                if i in frame_indices:
                    # OpenCV decodes BGR; PIL expects RGB.
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    selected_frames.append(Image.fromarray(frame))
                if i >= last_needed:
                    # All sampled frames collected — skip decoding the rest.
                    break
    finally:
        # Release the capture even if decoding raises mid-loop.
        cap.release()
    # Placeholder result until the model is wired into this function.
    video_label = "Fake" if np.random.rand() > 0.5 else "Real"
    fake_ratio = np.random.rand()
    detail = f"Placeholder ratio: {fake_ratio*100:.2f}%"
    montage = create_montage(selected_frames)
    return video_label, detail, montage
# Assemble the demo UI: one video input, three outputs (label, detail text,
# and the frame montage), then launch it immediately with a public share link.
_output_widgets = [
    gr.Text(label="Classification"),
    gr.Text(label="Details"),
    gr.Image(label="Montage of Selected Frames"),
]
interface = gr.Interface(
    fn=predict,
    inputs=gr.Video(label="Input Video"),
    outputs=_output_widgets,
    description="This model uses a smaller architecture, which means it may not always show accurate results. It's designed for demonstration purposes and might not perform optimally on every input."
).launch(debug=True, share=True)