# neuralninja10's picture
# Update app.py
# 45ab4d5 verified
import os
import time
import requests
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
# ============================================================
# Configuration
# ============================================================
# Where the TorchScript weights live on the Hugging Face Hub, and the local
# file they are cached in after the first successful download.
MODEL_URL = (
    "https://huggingface.co/neuralninja10/deepFakeWithCBAM"
    "/resolve/main/updatedDeepFakeModel.pt"
)
MODEL_PATH = "deepFakeWithCBAM.pt"

# Sigmoid scores strictly greater than this value are reported as "Real".
THRESHOLD = 0.68

# Prefer the GPU whenever PyTorch can see one; fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# ============================================================
# Page Configuration
# ============================================================
# Browser-tab title and icon, plus a centered (narrow) page layout.
st.set_page_config(
    layout="centered",
    page_title="DeepFake Detection",
    page_icon="🛡️",
)
# ============================================================
# Secure Model Loader
# ============================================================
def _download_model():
    """Download the model weights from ``MODEL_URL`` into ``MODEL_PATH``.

    The ``HF_TOKEN`` environment variable is sent as a bearer token. The
    payload is streamed into a temporary ``.part`` file next to
    ``MODEL_PATH``. That file is renamed into place only after the whole
    transfer has been written.

    Raises:
        RuntimeError: If ``HF_TOKEN`` is not set.
        requests.HTTPError: If the server answers with an error status.
    """
    token = os.environ.get("HF_TOKEN")
    if token is None:
        raise RuntimeError("HF_TOKEN not found in Space secrets")
    headers = {"Authorization": f"Bearer {token}"}

    # Writing straight to MODEL_PATH would let an interrupted download leave a
    # truncated file behind. load_model's exists() check would then skip the
    # download on every later run, and torch.jit.load would keep failing.
    tmp_path = MODEL_PATH + ".part"
    try:
        with st.spinner("Initializing system..."):
            # Using the response as a context manager releases the connection
            # even if writing to disk fails midway.
            with requests.get(
                MODEL_URL,
                headers=headers,
                stream=True,
                timeout=60,
            ) as response:
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(8192):
                        f.write(chunk)
        # Same directory => same filesystem, so the rename is atomic.
        os.replace(tmp_path, MODEL_PATH)
    finally:
        # No-op after a successful rename; discards a partial file otherwise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


@st.cache_resource
def load_model():
    """Load the TorchScript deepfake detector, downloading it on first use.

    If ``MODEL_PATH`` does not exist yet, the weights are fetched from
    ``MODEL_URL`` first, which requires the ``HF_TOKEN`` environment
    variable. An already-cached file is loaded without needing the token.
    ``st.cache_resource`` memoizes the returned model across Streamlit
    reruns.

    Returns:
        The TorchScript module, mapped onto ``DEVICE`` and put in eval mode.

    Raises:
        RuntimeError: If a download is needed and ``HF_TOKEN`` is unset.
        requests.HTTPError: If the download returns an error status.
    """
    if not os.path.exists(MODEL_PATH):
        _download_model()
    model = torch.jit.load(MODEL_PATH, map_location=DEVICE)
    model.eval()
    return model
# ============================================================
# Image Processing
# ============================================================
# Pipeline applied to every input:
#   1. Resize to a fixed 256x256 (bilinear).
#   2. Convert to a float tensor.
#   3. Normalize each channel with the standard ImageNet mean/std values.
_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalized, batched tensor for the model.

    Images that are not already RGB (e.g. RGBA, grayscale or palette PNGs)
    are converted first. The 3-channel ``Normalize`` step would otherwise
    fail on them.

    Args:
        image: Any PIL image.

    Returns:
        A tensor of shape ``(1, 3, 256, 256)``; the leading axis is the
        batch dimension added by ``unsqueeze(0)``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _transform(image).unsqueeze(0)
# ============================================================
# Inference
# ============================================================
def run_inference(model, image: Image.Image, threshold: float = THRESHOLD):
    """Classify ``image`` as real or fake and time the model call.

    Args:
        model: Module returned by ``load_model``. Assumes it yields a single
            logit for the one-image batch, because ``.item()`` is used on
            its sigmoid (TODO confirm against the exported model).
        image: PIL image to classify.
        threshold: Sigmoid score above which the image is labeled
            ``"Real"``. Defaults to the module-level ``THRESHOLD``.

    Returns:
        A dict with three keys:

        - ``"label"``: ``"Real"`` or ``"Fake"``.
        - ``"confidence"``: the sigmoid score for ``"Real"``, or one minus
          it for ``"Fake"``.
        - ``"latency"``: milliseconds spent on the forward pass plus
          sigmoid; preprocessing is excluded.
    """
    tensor = preprocess_image(image).to(DEVICE)

    # perf_counter is monotonic and high resolution. time.time() follows the
    # wall clock and can jump when the system clock is adjusted.
    start = time.perf_counter()
    with torch.no_grad():
        logits = model(tensor)
        # .item() copies the value to the host, which forces any pending GPU
        # work to finish, so the timing covers the whole forward pass.
        probability = torch.sigmoid(logits).item()
    latency_ms = (time.perf_counter() - start) * 1000

    is_real = probability > threshold
    confidence = probability if is_real else 1 - probability
    return {
        "label": "Real" if is_real else "Fake",
        "confidence": confidence,
        "latency": latency_ms,
    }
# ============================================================
# UI
# ============================================================
def _analyze_upload(model, uploaded_file):
    """Preview an uploaded file and, when requested, show the model verdict.

    Unreadable uploads are reported in the UI instead of crashing the script.
    """
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # The uploader only filters by extension. A corrupt, truncated or
        # non-image file makes Pillow raise UnidentifiedImageError (an
        # OSError subclass) or OSError; oversized images raise
        # DecompressionBombError.
        st.error("Could not read the uploaded file as an image.")
        return

    st.image(image, caption="Uploaded Image")
    st.divider()

    if not st.button("Analyze Image"):
        return

    with st.spinner("Analyzing..."):
        result = run_inference(model, image)

    if result["label"] == "Real":
        st.success("✔ Image appears to be authentic")
    else:
        st.error("✖ Image is likely manipulated")
    st.metric(
        label="Confidence",
        value=f"{result['confidence']:.2%}",
    )
    st.caption(
        f"Processing time: {result['latency']:.0f} ms"
    )


def main():
    """Render the DeepFake detection page.

    The page shows a title and caption, then loads the model. If loading
    fails, an error and the exception are displayed and rendering stops.
    Otherwise it offers an image uploader, analyzes an upload on request,
    and ends with a footer note.
    """
    st.title("DeepFake Detection for eKYC (Facial Images)")
    st.caption("Upload an image to verify authenticity.")

    try:
        model = load_model()
    except Exception as e:
        # Top-level UI boundary: surface any initialization failure to the user.
        st.error("System initialization failed.")
        st.exception(e)
        return

    uploaded_file = st.file_uploader(
        "Upload Image",
        type=["jpg", "jpeg", "png"],
    )
    if uploaded_file is not None:
        _analyze_upload(model, uploaded_file)

    st.divider()
    st.caption(
        "This demo caters all the available generators including Style GAN and Diffusion model variants. "
        "For further inquiries please feel free to contact uzairmughal30@gmail.com"
    )


if __name__ == "__main__":
    main()