Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import cv2 | |
| import timm | |
| from fastapi import FastAPI, File, UploadFile | |
| import shutil | |
| import os | |
| # FastAPI app instance | |
| app = FastAPI() | |
| # Device configuration | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # β Use /tmp/models instead of ./models | |
| MODEL_DIR = "/tmp/models" | |
| os.makedirs(MODEL_DIR, exist_ok=True) # β Fix PermissionError | |
| # Model URL from Hugging Face | |
| MODEL_URL = "https://huggingface.co/Maddy21/deepfake-detection-api/resolve/main/best_vit_model.pth" | |
| # Define model path | |
| model_path = os.path.join(MODEL_DIR, "best_vit_model.pth") | |
| # Download model if not already present | |
| if not os.path.exists(model_path): | |
| print("Downloading model...") | |
| torch.hub.download_url_to_file(MODEL_URL, model_path) | |
| print("Model downloaded successfully.") | |
| # Load the trained ViT model | |
| model = timm.create_model('vit_large_patch16_224', pretrained=False, num_classes=2) | |
| model.load_state_dict(torch.load(model_path, map_location=device)) | |
| model.to(device) | |
| model.eval() | |
| # Define image transformations | |
| transform = transforms.Compose([ | |
| transforms.Resize((224, 224)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
| ]) | |
| # Function to process video and classify frames | |
| def predict_video(video_path): | |
| cap = cv2.VideoCapture(video_path) | |
| frame_count = 0 | |
| real_count = 0 | |
| manipulated_count = 0 | |
| while cap.isOpened(): | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| frame_count += 1 | |
| print(frame_count) | |
| image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) | |
| image = transform(image).unsqueeze(0).to(device) | |
| with torch.no_grad(): | |
| outputs = model(image) | |
| _, predicted = torch.max(outputs, 1) | |
| print(predicted) | |
| if predicted.item() == 0: | |
| real_count += 1 | |
| else: | |
| manipulated_count += 1 | |
| cap.release() | |
| result = "Real" if real_count > manipulated_count else "Manipulated" | |
| return {"total_frames": frame_count, "real_frames": real_count, "manipulated_frames": manipulated_count, "result": result} | |
| # API Endpoint to check API status | |
| def read_root(): | |
| return {"message": "Deepfake Detection API is running!"} | |
| # API Endpoint to receive and process video | |
| async def predict(file: UploadFile = File(...)): | |
| file_path = f"/tmp/{file.filename}" # β Use /tmp/ instead of current directory | |
| # Save uploaded video | |
| with open(file_path, "wb") as buffer: | |
| shutil.copyfileobj(file.file, buffer) | |
| # Run prediction | |
| result = predict_video(file_path) | |
| # Delete temp file after processing | |
| os.remove(file_path) | |
| return result | |