Maddy21 commited on
Commit
e3ac0d6
·
1 Parent(s): cd03405
Files changed (2) hide show
  1. .gitattributes +2 -0
  2. app.py +41 -19
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.pth filter=lfs diff=lfs merge=lfs -text
37
+
app.py CHANGED
@@ -1,18 +1,41 @@
1
  import torch
 
 
 
 
2
  import timm
3
- import uvicorn
4
  from fastapi import FastAPI, File, UploadFile
5
  import shutil
6
  import os
7
- import cv2
8
- from PIL import Image
9
- from torchvision import transforms
10
 
 
11
  app = FastAPI()
12
 
13
  # Device configuration
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # Define image transformations
17
  transform = transforms.Compose([
18
  transforms.Resize((224, 224)),
@@ -20,17 +43,12 @@ transform = transforms.Compose([
20
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
21
  ])
22
 
23
- # Load model from Hugging Face Hub
24
- MODEL_URL = "https://huggingface.co/Maddy21/deepfake-detection-api/blob/main/best_vit_model.pth"
25
-
26
- model = timm.create_model('vit_large_patch16_224', pretrained=False, num_classes=2)
27
- model.load_state_dict(torch.hub.load_state_dict_from_url(MODEL_URL, map_location=device))
28
- model.to(device)
29
- model.eval()
30
-
31
  def predict_video(video_path):
32
  cap = cv2.VideoCapture(video_path)
33
- frame_count, real_count, manipulated_count = 0, 0, 0
 
 
34
 
35
  while cap.isOpened():
36
  ret, frame = cap.read()
@@ -51,22 +69,26 @@ def predict_video(video_path):
51
  manipulated_count += 1
52
 
53
  cap.release()
54
- return {"frames": frame_count, "real": real_count, "manipulated": manipulated_count,
55
- "result": "Real" if real_count > manipulated_count else "Fake"}
56
 
 
57
  @app.get("/")
58
- def home():
59
  return {"message": "Deepfake Detection API is running!"}
60
 
 
61
  @app.post("/predict/")
62
  async def predict(file: UploadFile = File(...)):
63
  file_path = f"temp_{file.filename}"
 
 
64
  with open(file_path, "wb") as buffer:
65
  shutil.copyfileobj(file.file, buffer)
66
 
 
67
  result = predict_video(file_path)
 
 
68
  os.remove(file_path)
69
  return result
70
-
71
- if __name__ == "__main__":
72
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import cv2
6
  import timm
 
7
  from fastapi import FastAPI, File, UploadFile
8
  import shutil
9
  import os
 
 
 
10
 
11
+ # FastAPI app instance
12
  app = FastAPI()
13
 
14
  # Device configuration
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+ # Create model storage directory
18
+ MODEL_DIR = "./models"
19
+ os.makedirs(MODEL_DIR, exist_ok=True)
20
+
21
+ # Model URL from Hugging Face
22
+ MODEL_URL = "https://huggingface.co/Maddy21/deepfake-detection-api/resolve/main/best_vit_model.pth"
23
+
24
+ # Define model path
25
+ model_path = os.path.join(MODEL_DIR, "best_vit_model.pth")
26
+
27
+ # Download model if not already present
28
+ if not os.path.exists(model_path):
29
+ print("Downloading model...")
30
+ torch.hub.download_url_to_file(MODEL_URL, model_path)
31
+ print("Model downloaded successfully.")
32
+
33
+ # Load the trained ViT model
34
+ model = timm.create_model('vit_large_patch16_224', pretrained=False, num_classes=2)
35
+ model.load_state_dict(torch.load(model_path, map_location=device))
36
+ model.to(device)
37
+ model.eval()
38
+
39
  # Define image transformations
40
  transform = transforms.Compose([
41
  transforms.Resize((224, 224)),
 
43
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
44
  ])
45
 
46
+ # Function to process video and classify frames
 
 
 
 
 
 
 
47
  def predict_video(video_path):
48
  cap = cv2.VideoCapture(video_path)
49
+ frame_count = 0
50
+ real_count = 0
51
+ manipulated_count = 0
52
 
53
  while cap.isOpened():
54
  ret, frame = cap.read()
 
69
  manipulated_count += 1
70
 
71
  cap.release()
72
+ result = "Real" if real_count > manipulated_count else "Manipulated"
73
+ return {"total_frames": frame_count, "real_frames": real_count, "manipulated_frames": manipulated_count, "result": result}
74
 
75
+ # API Endpoint to check API status
76
  @app.get("/")
77
+ def read_root():
78
  return {"message": "Deepfake Detection API is running!"}
79
 
80
+ # API Endpoint to receive and process video
81
  @app.post("/predict/")
82
  async def predict(file: UploadFile = File(...)):
83
  file_path = f"temp_{file.filename}"
84
+
85
+ # Save uploaded video
86
  with open(file_path, "wb") as buffer:
87
  shutil.copyfileobj(file.file, buffer)
88
 
89
+ # Run prediction
90
  result = predict_video(file_path)
91
+
92
+ # Delete temp file after processing
93
  os.remove(file_path)
94
  return result