# Hugging Face Space: deepfake-detection demo
# (scraped "Spaces / Sleeping" status banner removed)
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download  # fetches model weights from the Hugging Face Hub
from PIL import Image
from torchvision import transforms
# Define Model Class
class Model(torch.nn.Module):
    """Video classifier: per-frame ResNet18 features, averaged over time, then a linear head.

    Args:
        num_classes: number of output classes (2 here: FAKE / REAL).
        latent_dim: feature size produced by the backbone (512 for ResNet18).
    """

    def __init__(self, num_classes, latent_dim=512):
        super(Model, self).__init__()
        backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
        # Keep everything up to (but not including) the final fc classification layer.
        self.model = torch.nn.Sequential(*list(backbone.children())[:-1])
        # ResNet18's pooled feature vector has 512 channels.
        self.linear = torch.nn.Linear(latent_dim, num_classes)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        # x is assumed to be (batch, frames, channels, height, width) — TODO confirm with caller.
        b, t, c, h, w = x.shape
        feats = self.model(x.view(b * t, c, h, w))   # (b*t, latent_dim, 1, 1)
        feats = feats.view(b, t, -1).mean(dim=1)     # average the per-frame features
        return self.linear(self.dropout(feats))
# Load Pre-trained Model from Hugging Face Hub
def load_model():
    """Download the deepfake-detection checkpoint from the Hub and return the model in eval mode.

    Returns:
        Model: the 2-class model, moved to CUDA when available, set to eval().

    Raises:
        RuntimeError: if the download or the weight loading fails (original cause is chained).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model(num_classes=2)
    try:
        # Download the model file from Hugging Face (hf_hub_download caches locally).
        print("Downloading model from Hugging Face...")
        model_path = hf_hub_download(
            repo_id="PiyushGPT/deepfake-detection-model",
            filename="model.pt",
            # `use_auth_token` is deprecated in huggingface_hub; `token` is the current parameter.
            token=os.getenv("HF_AUTH_TOKEN"),  # only needed for private repositories
        )
        print(f"Model downloaded successfully: {model_path}")
        # weights_only=True (torch >= 1.13) refuses arbitrary pickled code in the
        # downloaded file — safe here because the checkpoint is a plain state_dict.
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        model.eval()
        model.to(device)
        return model
    except Exception as e:
        # Chain the original exception so the traceback keeps the root cause.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
# Transforms
# Per-frame preprocessing: resize to the model's input size, convert to a tensor,
# then normalize with the standard ImageNet statistics used by the ResNet backbone.
im_size = 64
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize((im_size, im_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
transform = transforms.Compose(_preprocess_steps)
# Function to Extract Frames from Video
def extract_frames(video_path, max_frames=10):
    """Decode up to `max_frames` frames from a video and preprocess them for the model.

    Args:
        video_path: path to a video file readable by OpenCV.
        max_frames: fixed clip length; short clips are padded with blank frames.

    Returns:
        torch.Tensor of shape (1, max_frames, C, H, W) — batch dimension included.

    Raises:
        ValueError: if not a single frame could be decoded (previously this
            crashed with an opaque IndexError at the padding step).
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes BGR; PIL/torchvision expect RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(transform(Image.fromarray(frame)))
    finally:
        cap.release()  # release the capture even if decoding/transform raises
    if not frames:
        raise ValueError(f"Could not read any frames from video: {video_path!r}")
    # Pad short clips with blank (zero) frames so every clip has exactly max_frames.
    # (No truncation needed: the read loop above never collects more than max_frames.)
    while len(frames) < max_frames:
        frames.append(torch.zeros_like(frames[0]))
    return torch.stack(frames).unsqueeze(0)  # add batch dimension
# Prediction Function
def predict(video_path):
    """Classify a video as REAL or FAKE and return a human-readable result string.

    Args:
        video_path: filesystem path to the video, or a Gradio file wrapper with a
            `.name` attribute.

    Returns:
        str: "Prediction: <label>, Confidence: <pct>%" on success, or an
        "Error during prediction: ..." message (Gradio displays the string either way).
    """
    try:
        # Cache the loaded model on the function object — the original called
        # load_model() (a full Hub download + weight load) on every request.
        if not hasattr(predict, "_model"):
            predict._model = load_model()
        model = predict._model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # gr.File may hand us a tempfile wrapper instead of a plain path; use its .name then.
        path = getattr(video_path, "name", video_path)
        # Extract frames and preprocess
        frames = extract_frames(path).to(device)
        # Make prediction
        with torch.no_grad():
            outputs = model(frames)
            probabilities = torch.softmax(outputs, dim=1)
            confidence, prediction = torch.max(probabilities, 1)
        # Class index 1 == REAL, 0 == FAKE — as in the original; confirm against training labels.
        label = "REAL" if prediction.item() == 1 else "FAKE"
        confidence = confidence.item() * 100
        # Return result
        return f"Prediction: {label}, Confidence: {confidence:.2f}%"
    except Exception as e:
        # Surface the error as the output string instead of crashing the UI.
        return f"Error during prediction: {str(e)}"
# Gradio Interface
# Simple one-page UI: a file upload, a result textbox, and a button that runs predict().
with gr.Blocks() as demo:
    gr.Markdown("# Deepfake Detection App")
    gr.Markdown("Upload a video file to detect whether it is REAL or FAKE.")
    with gr.Row():
        uploaded_video = gr.File(label="Upload Video", file_types=["video"])
        result_box = gr.Textbox(label="Prediction Result")
    run_button = gr.Button("Predict")
    run_button.click(predict, inputs=uploaded_video, outputs=result_box)

# Launch the app
demo.launch()