"""Deepfake Detection API.

Serves a ViT-Large binary classifier over FastAPI: clients POST a video,
every frame is classified as real (class 0) vs. manipulated (class 1),
and a majority vote over frames decides the overall verdict.
"""

import os
import shutil

import cv2
import timm
import torch
import torch.nn as nn  # noqa: F401 -- kept for compatibility with prior revisions
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from torchvision import transforms

# FastAPI app instance
app = FastAPI()

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# /tmp is writable in restricted deployments (e.g. HF Spaces), unlike ./models,
# which avoids a PermissionError at startup.
MODEL_DIR = "/tmp/models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Checkpoint hosted on Hugging Face.
MODEL_URL = (
    "https://huggingface.co/Maddy21/deepfake-detection-api/resolve/main/"
    "best_vit_model.pth"
)
model_path = os.path.join(MODEL_DIR, "best_vit_model.pth")

# Download the checkpoint once; later startups reuse the cached file.
if not os.path.exists(model_path):
    print("Downloading model...")
    torch.hub.download_url_to_file(MODEL_URL, model_path)
    print("Model downloaded successfully.")

# Binary classifier: class 0 = real, class 1 = manipulated.
model = timm.create_model("vit_large_patch16_224", pretrained=False, num_classes=2)
# weights_only=True refuses arbitrary pickled objects in the checkpoint --
# it was fetched from a remote URL; a plain state_dict loads fine this way.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.to(device)
model.eval()

# Standard ImageNet preprocessing expected by the ViT backbone.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def predict_video(video_path):
    """Classify every frame of a video and majority-vote the verdict.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        dict with keys ``total_frames``, ``real_frames``,
        ``manipulated_frames`` and ``result`` ("Real" or "Manipulated").

    Raises:
        ValueError: If OpenCV cannot open the file (previously this
            silently produced a 0-frame "Manipulated" verdict).
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    frame_count = 0
    real_count = 0
    manipulated_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            # OpenCV decodes BGR; PIL/torchvision expect RGB.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            image = transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(image)
                _, predicted = torch.max(outputs, 1)
            if predicted.item() == 0:
                real_count += 1
            else:
                manipulated_count += 1
    finally:
        # Release the capture even if a frame raises mid-loop.
        cap.release()

    # Ties count as "Manipulated": a strict majority of real frames is
    # required for a "Real" verdict (matches the original comparison).
    result = "Real" if real_count > manipulated_count else "Manipulated"
    return {
        "total_frames": frame_count,
        "real_frames": real_count,
        "manipulated_frames": manipulated_count,
        "result": result,
    }


@app.get("/")
def read_root():
    """Health-check endpoint confirming the API is up."""
    return {"message": "Deepfake Detection API is running!"}


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Receive a video upload, classify it frame-by-frame, return counts."""
    # basename strips any client-supplied directory components so the upload
    # cannot escape /tmp (path-traversal hardening); filename may be None.
    safe_name = os.path.basename(file.filename or "upload.bin")
    file_path = os.path.join("/tmp", safe_name)

    # Save the uploaded video to disk so OpenCV can read it.
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        result = predict_video(file_path)
    finally:
        # Always delete the temp file, even when prediction fails.
        if os.path.exists(file_path):
            os.remove(file_path)
    return result