Maddy21 commited on
Commit
cd03405
·
1 Parent(s): 6415ec4
Files changed (1) hide show
  1. app.py +13 -37
app.py CHANGED
@@ -1,20 +1,17 @@
1
  import torch
2
- import torch.nn as nn
3
- from torchvision import transforms
4
- from PIL import Image
5
- import cv2
6
  import timm
 
7
  from fastapi import FastAPI, File, UploadFile
8
  import shutil
9
  import os
10
- from huggingface_hub import hf_hub_download
 
 
11
 
12
- # FastAPI app instance
13
  app = FastAPI()
14
 
15
  # Device configuration
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- print("Initializing model...")
18
 
19
  # Define image transformations
20
  transform = transforms.Compose([
@@ -23,27 +20,17 @@ transform = transforms.Compose([
23
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
24
  ])
25
 
26
- # Hugging Face model repository details
27
- HF_MODEL_REPO = "Maddy21/vit-deepfake-model" # Replace with your repo
28
- HF_MODEL_FILE = "best_vit_model.pth"
29
-
30
- # Download the model from Hugging Face
31
- model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=HF_MODEL_FILE)
32
 
33
- # Load the trained ViT model
34
  model = timm.create_model('vit_large_patch16_224', pretrained=False, num_classes=2)
35
- model.load_state_dict(torch.load(model_path, map_location=device))
36
  model.to(device)
37
  model.eval()
38
- print("Model loaded successfully.")
39
 
40
- # Function to process video and classify frames
41
  def predict_video(video_path):
42
- print("Processing video for deepfake detection...")
43
  cap = cv2.VideoCapture(video_path)
44
- frame_count = 0
45
- real_count = 0
46
- manipulated_count = 0
47
 
48
  while cap.isOpened():
49
  ret, frame = cap.read()
@@ -64,33 +51,22 @@ def predict_video(video_path):
64
  manipulated_count += 1
65
 
66
  cap.release()
 
 
67
 
68
- result = "Real" if real_count > manipulated_count else "Manipulated"
69
- return {
70
- "total_frames": frame_count,
71
- "real_frames": real_count,
72
- "manipulated_frames": manipulated_count,
73
- "result": result
74
- }
75
-
76
- # API Endpoint to check if API is running
77
  @app.get("/")
78
- def read_root():
79
  return {"message": "Deepfake Detection API is running!"}
80
 
81
- # API Endpoint to upload a video and get predictions
82
  @app.post("/predict/")
83
  async def predict(file: UploadFile = File(...)):
84
  file_path = f"temp_{file.filename}"
85
-
86
- # Save uploaded video
87
  with open(file_path, "wb") as buffer:
88
  shutil.copyfileobj(file.file, buffer)
89
 
90
- # Run prediction
91
  result = predict_video(file_path)
92
-
93
- # Delete temp file after processing
94
  os.remove(file_path)
95
  return result
96
 
 
 
 
1
  import torch
 
 
 
 
2
  import timm
3
+ import uvicorn
4
  from fastapi import FastAPI, File, UploadFile
5
  import shutil
6
  import os
7
+ import cv2
8
+ from PIL import Image
9
+ from torchvision import transforms
10
 
 
11
  app = FastAPI()
12
 
13
  # Device configuration
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
15
 
16
  # Define image transformations
17
  transform = transforms.Compose([
 
20
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
21
  ])
22
 
23
+ # Load model from Hugging Face Hub
24
+ MODEL_URL = "https://huggingface.co/Maddy21/deepfake-detection-api/blob/main/best_vit_model.pth"
 
 
 
 
25
 
 
26
  model = timm.create_model('vit_large_patch16_224', pretrained=False, num_classes=2)
27
+ model.load_state_dict(torch.hub.load_state_dict_from_url(MODEL_URL, map_location=device))
28
  model.to(device)
29
  model.eval()
 
30
 
 
31
  def predict_video(video_path):
 
32
  cap = cv2.VideoCapture(video_path)
33
+ frame_count, real_count, manipulated_count = 0, 0, 0
 
 
34
 
35
  while cap.isOpened():
36
  ret, frame = cap.read()
 
51
  manipulated_count += 1
52
 
53
  cap.release()
54
+ return {"frames": frame_count, "real": real_count, "manipulated": manipulated_count,
55
+ "result": "Real" if real_count > manipulated_count else "Fake"}
56
 
 
 
 
 
 
 
 
 
 
57
  @app.get("/")
58
+ def home():
59
  return {"message": "Deepfake Detection API is running!"}
60
 
 
61
  @app.post("/predict/")
62
  async def predict(file: UploadFile = File(...)):
63
  file_path = f"temp_{file.filename}"
 
 
64
  with open(file_path, "wb") as buffer:
65
  shutil.copyfileobj(file.file, buffer)
66
 
 
67
  result = predict_video(file_path)
 
 
68
  os.remove(file_path)
69
  return result
70
 
71
+ if __name__ == "__main__":
72
+ uvicorn.run(app, host="0.0.0.0", port=7860)