kyrilloswahid commited on
Commit
53cc87a
Β·
verified Β·
1 Parent(s): 7b57e57

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +40 -16
streamlit_app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  import numpy as np
3
  import tensorflow as tf
@@ -6,14 +7,18 @@ from tensorflow.keras.models import load_model
6
  from tensorflow.keras.applications.xception import preprocess_input as xcp_pre
7
  from tensorflow.keras.applications.efficientnet import preprocess_input as eff_pre
8
  from huggingface_hub import hf_hub_download
 
 
 
 
 
9
 
10
- # Set Streamlit page configuration
11
  st.set_page_config(page_title="Deepfake Image Verifier", layout="centered")
12
-
13
  st.title("πŸ” Deepfake Image Verifier")
14
  st.markdown("Upload an image to classify it as **Real** or **Fake** using an ensemble of Xception and EfficientNet models.")
15
 
16
- # Load models only once and cache them
17
  @st.cache_resource
18
  def load_models():
19
  xcp_path = hf_hub_download(repo_id="Zeyadd-Mostaffa/deepfake-image-detector_final", filename="xception_model.h5")
@@ -24,8 +29,9 @@ def load_models():
24
 
25
  xcp_model, eff_model = load_models()
26
 
27
- # Prediction function
28
- def predict(image_np):
 
29
  xcp_img = cv2.resize(image_np, (299, 299))
30
  eff_img = cv2.resize(image_np, (224, 224))
31
 
@@ -36,23 +42,41 @@ def predict(image_np):
36
  eff_pred = eff_model.predict(eff_tensor, verbose=0).flatten()[0]
37
 
38
  avg_pred = (xcp_pred + eff_pred) / 2
39
- label = "🟒 Real" if avg_pred > 0.5 else "πŸ”΄ Fake"
40
  return label
41
 
42
- # Upload image
43
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
44
-
45
  if uploaded_file:
46
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
47
  image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
 
 
 
 
 
48
 
49
- if image is None:
50
- st.error("Failed to decode the image. Please try another file.")
51
- else:
52
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
53
- st.image(image_rgb, caption="Uploaded Image", use_column_width=True)
 
 
 
 
54
 
55
- with st.spinner("Analyzing..."):
56
- label = predict(image_rgb)
 
 
 
 
 
 
 
 
 
57
 
58
- st.success(f"Prediction: **{label}**")
 
 
1
+ # streamlit_app.py
2
  import streamlit as st
3
  import numpy as np
4
  import tensorflow as tf
 
7
  from tensorflow.keras.applications.xception import preprocess_input as xcp_pre
8
  from tensorflow.keras.applications.efficientnet import preprocess_input as eff_pre
9
  from huggingface_hub import hf_hub_download
10
+ from fastapi import FastAPI, File, UploadFile
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from starlette.responses import PlainTextResponse
13
+ import uvicorn
14
+ from io import BytesIO
15
 
16
+ # Set up Streamlit UI
17
  st.set_page_config(page_title="Deepfake Image Verifier", layout="centered")
 
18
  st.title("πŸ” Deepfake Image Verifier")
19
  st.markdown("Upload an image to classify it as **Real** or **Fake** using an ensemble of Xception and EfficientNet models.")
20
 
21
+ # Load models from HF Hub once
22
  @st.cache_resource
23
  def load_models():
24
  xcp_path = hf_hub_download(repo_id="Zeyadd-Mostaffa/deepfake-image-detector_final", filename="xception_model.h5")
 
29
 
30
  xcp_model, eff_model = load_models()
31
 
32
+ # Prediction logic
33
+
34
+ def run_model_prediction(image_np):
35
  xcp_img = cv2.resize(image_np, (299, 299))
36
  eff_img = cv2.resize(image_np, (224, 224))
37
 
 
42
  eff_pred = eff_model.predict(eff_tensor, verbose=0).flatten()[0]
43
 
44
  avg_pred = (xcp_pred + eff_pred) / 2
45
+ label = "Real" if avg_pred > 0.5 else "Fake"
46
  return label
47
 
48
+ # Streamlit UI
49
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
 
50
  if uploaded_file:
51
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
52
  image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
53
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
54
+ st.image(image_rgb, caption="Uploaded Image", use_column_width=True)
55
+ with st.spinner("Analyzing..."):
56
+ label = run_model_prediction(image_rgb)
57
+ st.success(f"Prediction: **{label}**")
58
 
59
+ # FastAPI for backend use (Flask calls etc.)
60
+ app = FastAPI()
61
+ app.add_middleware(
62
+ CORSMiddleware,
63
+ allow_origins=["*"],
64
+ allow_credentials=True,
65
+ allow_methods=["*"],
66
+ allow_headers=["*"],
67
+ )
68
 
69
+ @app.post("/predict")
70
+ async def predict_api(file: UploadFile = File(...)):
71
+ try:
72
+ contents = await file.read()
73
+ file_bytes = np.asarray(bytearray(contents), dtype=np.uint8)
74
+ image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
75
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
76
+ label = run_model_prediction(image_rgb)
77
+ return PlainTextResponse(label, status_code=200)
78
+ except Exception as e:
79
+ return PlainTextResponse(f"Error: {str(e)}", status_code=500)
80
 
81
+ if __name__ == "__main__":
82
+ uvicorn.run(app, host="0.0.0.0", port=7860)