# Deepfake Detection API — Hugging Face Space by Maddy21 (revision 96ebbe7)
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import cv2
import timm
from fastapi import FastAPI, File, UploadFile
import shutil
import os
# FastAPI app instance
app = FastAPI()

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cache the model under /tmp/models rather than ./models: the working
# directory may be read-only in hosted environments (avoids PermissionError).
MODEL_DIR = "/tmp/models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Fine-tuned ViT checkpoint hosted on Hugging Face.
MODEL_URL = "https://huggingface.co/Maddy21/deepfake-detection-api/resolve/main/best_vit_model.pth"
model_path = os.path.join(MODEL_DIR, "best_vit_model.pth")

# Download the checkpoint only once; later restarts reuse the cached file.
if not os.path.exists(model_path):
    print("Downloading model...")
    torch.hub.download_url_to_file(MODEL_URL, model_path)
    print("Model downloaded successfully.")

# Build the ViT architecture (2 output classes) and load the fine-tuned
# weights onto the selected device.
model = timm.create_model('vit_large_patch16_224', pretrained=False, num_classes=2)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates

# Preprocessing pipeline for each frame: 224x224 input with ImageNet
# normalization statistics, matching the ViT's expected input format.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Function to process video and classify frames
def predict_video(video_path):
    """Classify every decodable frame of a video as real or manipulated.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        dict with keys "total_frames", "real_frames",
        "manipulated_frames", and "result" — a majority vote of
        "Real" or "Manipulated" across all frames.
    """
    cap = cv2.VideoCapture(video_path)
    frame_count = 0
    real_count = 0
    manipulated_count = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # end of stream or decode failure
                break
            frame_count += 1
            # OpenCV decodes frames as BGR; PIL/torchvision expect RGB.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            image = transform(image).unsqueeze(0).to(device)
            with torch.no_grad():  # inference only; no gradients needed
                outputs = model(image)
                _, predicted = torch.max(outputs, 1)
            # Class index 0 = real, 1 = manipulated (see tally below).
            if predicted.item() == 0:
                real_count += 1
            else:
                manipulated_count += 1
    finally:
        # Release the decoder even if preprocessing/inference raises.
        cap.release()
    # Majority vote across frames. Ties and unreadable (zero-frame)
    # videos report "Manipulated" — the conservative default.
    result = "Real" if real_count > manipulated_count else "Manipulated"
    return {"total_frames": frame_count, "real_frames": real_count, "manipulated_frames": manipulated_count, "result": result}
# API Endpoint to check API status
@app.get("/")
def read_root():
    """Health-check endpoint confirming the service is up."""
    status_message = "Deepfake Detection API is running!"
    return {"message": status_message}
# API Endpoint to receive and process video
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Accept an uploaded video and run frame-level deepfake detection.

    Returns:
        The dict produced by predict_video: frame counts plus a
        majority-vote "Real"/"Manipulated" verdict.
    """
    # basename() strips any directory components from the client-supplied
    # filename (path-traversal guard); /tmp is used because the current
    # directory may be read-only in hosted environments.
    file_path = os.path.join("/tmp", os.path.basename(file.filename or "upload.bin"))
    # Save the uploaded video to disk so OpenCV can open it by path.
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    try:
        # Run prediction
        result = predict_video(file_path)
    finally:
        # Always delete the temp file, even if prediction raises.
        os.remove(file_path)
    return result