# app.py — Deepfake Detection Gradio app (Hugging Face Space by PiyushGPT, commit 63741f3)
import os
import tempfile
import torch
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
import gradio as gr
from huggingface_hub import hf_hub_download # To download the model from Hugging Face
# Define Model Class
class Model(torch.nn.Module):
    """Video classifier: per-frame ResNet18 features are averaged over time
    and passed through a single linear head.

    Input is a clip tensor of shape (batch, frames, channels, height, width);
    output is a (batch, num_classes) logit tensor.
    """

    def __init__(self, num_classes, latent_dim=512):
        super(Model, self).__init__()
        backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
        feature_layers = list(backbone.children())[:-1]  # drop the final FC classifier
        self.model = torch.nn.Sequential(*feature_layers)
        self.linear = torch.nn.Linear(latent_dim, num_classes)  # ResNet18 feature size is 512
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        # Fold the temporal axis into the batch so the CNN sees plain images.
        batch, frames, c, h, w = x.shape
        features = self.model(x.view(batch * frames, c, h, w))
        per_frame = features.view(batch, frames, -1)  # (batch, frames, latent_dim)
        clip_repr = per_frame.mean(dim=1)             # temporal average pooling
        return self.linear(self.dropout(clip_repr))
# Load Pre-trained Model from Hugging Face Hub
def load_model():
    """Download the deepfake-detection checkpoint from the Hugging Face Hub
    and return the model, weights loaded, in eval mode on the best device.

    Returns:
        Model: the loaded network, on CUDA if available, else CPU.

    Raises:
        RuntimeError: if the download or the weight loading fails; the
            original exception is chained for the full traceback.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model(num_classes=2)
    try:
        # Download the model file from Hugging Face
        print("Downloading model from Hugging Face...")
        model_path = hf_hub_download(
            repo_id="PiyushGPT/deepfake-detection-model",
            filename="model.pt",
            use_auth_token=os.getenv("HF_AUTH_TOKEN")  # For private repositories
        )
        print(f"Model downloaded successfully: {model_path}")
        # Load the model weights onto the target device
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()  # disable dropout for inference
        model.to(device)
        return model
    except Exception as e:
        # Chain the cause so debugging information is not lost.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
# Transforms
# Preprocessing applied to every decoded frame. The normalization constants
# are the standard ImageNet channel statistics — presumably the same
# preprocessing used at training time (TODO confirm against the training code).
im_size = 64  # frames are resized to im_size x im_size
mean = [0.485, 0.456, 0.406]  # per-channel means (RGB)
std = [0.229, 0.224, 0.225]   # per-channel std-devs (RGB)
transform = transforms.Compose([
    transforms.Resize((im_size, im_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
# Function to Extract Frames from Video
def extract_frames(video_path, max_frames=10):
    """Decode up to ``max_frames`` frames from a video and preprocess them.

    Args:
        video_path: path to a video file readable by OpenCV.
        max_frames: fixed clip length; shorter videos are padded with
            all-zero (blank) frames.

    Returns:
        torch.Tensor of shape (1, max_frames, 3, im_size, im_size).
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)   # Convert to PIL Image
            frames.append(transform(frame))  # Apply transformations
    finally:
        cap.release()  # always free the decoder, even if transform raises
    if not frames:
        # Unreadable/corrupt/empty video: seed with one blank frame instead of
        # crashing on frames[0] below (previously an IndexError).
        frames.append(torch.zeros(3, im_size, im_size))
    # Pad frames if necessary
    while len(frames) < max_frames:
        frames.append(torch.zeros_like(frames[0]))  # Add blank frames
    # Truncate frames if necessary
    frames = frames[:max_frames]
    return torch.stack(frames).unsqueeze(0)  # Add batch dimension
# Prediction Function
def predict(video_path):
    """Classify a video as REAL or FAKE.

    Args:
        video_path: path to the uploaded video file.

    Returns:
        str: human-readable prediction with confidence, or an error message
        (errors are reported as text so the Gradio UI never crashes).
    """
    try:
        # Load the model once and cache it on the function object so repeated
        # predictions don't re-download and re-load the weights every click.
        if getattr(predict, "_model", None) is None:
            predict._model = load_model()
        model = predict._model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Extract frames and preprocess
        frames = extract_frames(video_path).to(device)
        # Make prediction
        with torch.no_grad():
            outputs = model(frames)
            probabilities = torch.softmax(outputs, dim=1)
            confidence, prediction = torch.max(probabilities, 1)
        label = "REAL" if prediction.item() == 1 else "FAKE"
        confidence = confidence.item() * 100
        # Return result
        return f"Prediction: {label}, Confidence: {confidence:.2f}%"
    except Exception as e:
        return f"Error during prediction: {str(e)}"
# Gradio Interface
# Simple two-column layout: file upload on the left, result text on the right,
# with a single button wiring the upload through predict().
with gr.Blocks() as demo:
    gr.Markdown("# Deepfake Detection App")
    gr.Markdown("Upload a video file to detect whether it is REAL or FAKE.")
    with gr.Row():
        uploaded_video = gr.File(label="Upload Video", file_types=["video"])
        result_box = gr.Textbox(label="Prediction Result")
    run_button = gr.Button("Predict")
    run_button.click(predict, inputs=uploaded_video, outputs=result_box)

# Launch the app
demo.launch()