SanskarModi committed on
Commit
6445597
·
verified ·
1 Parent(s): caddf8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -61
app.py CHANGED
import os
import tempfile
import traceback
from io import BytesIO
from typing import Optional

import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, File, Form, Query, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.middleware.cors import CORSMiddleware

from prediction import Prediction

app = FastAPI(
    title="Deepfake Detection API",
    description="Upload a video to check if it's real or a manipulated deepfake (Face2Face, FaceShifter, FaceSwap, or NeuralTextures).",
)

# CORS (optional if using frontend)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize model once at startup so every request reuses it.
predictor = Prediction()


@app.post("/predict/")
async def predict_deepfake(
    video: UploadFile = File(...),
    sequence_length: Optional[int] = Query(
        None, description="Number of frames to use for prediction"
    ),
):
    """Classify an uploaded video as real or deepfake.

    Saves the upload to a temp file, runs the predictor, and returns either
    a JPEG explanation image (prediction in the ``X-Prediction-Result``
    header) or a JSON body with the prediction and details.
    """
    temp_video_path = None
    try:
        # Save video to a temporary file so the predictor can read it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
            temp_video.write(await video.read())
            temp_video_path = temp_video.name

        # Get prediction and explanation image.
        prediction_str, explanation_image, details = predictor.predict(
            temp_video_path, sequence_length
        )

        response = {"prediction": prediction_str, "details": details}

        # Convert explanation image (np array) to JPEG bytes if available.
        if explanation_image is not None:
            _, img_encoded = cv2.imencode(".jpg", explanation_image)
            img_bytes = BytesIO(img_encoded.tobytes())
            return StreamingResponse(
                content=img_bytes,
                media_type="image/jpeg",
                headers={"X-Prediction-Result": prediction_str},
            )
        return JSONResponse(content=response)

    except Exception as e:
        # NOTE(review): returning the full traceback to the client leaks
        # server internals — useful for debugging, gate this in production.
        error_detail = traceback.format_exc()
        return JSONResponse(
            status_code=500, content={"error": str(e), "detail": error_detail}
        )
    finally:
        # delete=False means the temp file survives the `with`; remove it
        # here so we don't leak one file per request.
        if temp_video_path:
            try:
                os.remove(temp_video_path)
            except OSError:
                pass


@app.get("/")
def root():
    """Health/usage endpoint."""
    return {
        "message": "Deepfake Detection API is running!",
        "usage": "POST to /predict/ with a video file and optional sequence_length parameter",
    }
 
import os
import tempfile

import cv2
import gradio as gr

from prediction import Prediction

# Initialize the predictor once at module load so every request reuses it.
predictor = Prediction()


# Define your inference function
def detect_deepfake(video_file, sequence_length=None):
    """Run deepfake detection on an uploaded video.

    Parameters
    ----------
    video_file : str | file-like
        The upload from ``gr.File``. Depending on the Gradio version /
        configuration this is either a filepath string or a temp-file
        object with ``.read()`` — both are handled.
    sequence_length : float | int | None
        Optional frame count. ``gr.Number`` delivers a float, so it is
        coerced to ``int`` before reaching the predictor.

    Returns
    -------
    tuple
        (prediction label, RGB explanation image or None, details string).
        On failure, ("Error: ...", None, "").
    """
    temp_video_path = None
    owns_temp = False
    try:
        if isinstance(video_file, str):
            # Modern gr.File hands us a path on disk directly.
            temp_video_path = video_file
        else:
            # Older gr.File hands us a file-like object: copy it to a
            # temporary .mp4 so the predictor can read it by path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
                temp_video.write(video_file.read())
                temp_video_path = temp_video.name
            owns_temp = True

        # gr.Number yields a float (or None when left blank); the frame
        # count must be an int.
        seq_len = int(sequence_length) if sequence_length else None

        # Run prediction
        prediction_str, explanation_image, details = predictor.predict(
            temp_video_path, seq_len
        )

        # OpenCV images are BGR; Gradio expects RGB.
        explanation_img = None
        if explanation_image is not None:
            explanation_img = cv2.cvtColor(explanation_image, cv2.COLOR_BGR2RGB)

        return prediction_str, explanation_img, str(details)

    except Exception as e:
        # Surface the error in the UI instead of crashing the app.
        return f"Error: {str(e)}", None, ""
    finally:
        # Only delete temp files we created; leave Gradio-owned paths alone.
        # (delete=False would otherwise leak one file per request.)
        if owns_temp and temp_video_path:
            try:
                os.remove(temp_video_path)
            except OSError:
                pass


# Gradio UI
demo = gr.Interface(
    fn=detect_deepfake,
    inputs=[
        gr.File(label="Upload Video (.mp4)", file_types=[".mp4"]),
        gr.Number(label="Sequence Length (Optional)", value=None),
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(label="Explanation Image"),
        gr.Textbox(label="Details"),
    ],
    title="Deepdetect",
    description="Upload a video to detect deepfakes using Face2Face, FaceSwap, FaceShifter, and NeuralTextures models.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()