import os
import time

import requests
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

# ============================================================
# Configuration
# ============================================================

# Remote TorchScript model; downloaded with the HF_TOKEN bearer token.
MODEL_URL = (
    "https://huggingface.co/neuralninja10/deepFakeWithCBAM/"
    "resolve/main/updatedDeepFakeModel.pt"
)
# Local copy; the download is skipped when this file already exists.
MODEL_PATH = "deepFakeWithCBAM.pt"
# Decision threshold on the sigmoid output: values above it are labeled "Real".
THRESHOLD = 0.68
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================================================
# Page Configuration
# ============================================================

st.set_page_config(
    page_title="DeepFake Detection",
    page_icon="🛡️",
    layout="centered",
)

# ============================================================
# Secure Model Loader
# ============================================================


def _download_model(token: str) -> None:
    """Download the model from MODEL_URL to MODEL_PATH.

    The payload is streamed into a sibling ``.part`` file. That file is moved
    to MODEL_PATH only after the download finishes, so an interrupted
    download never leaves a truncated file for later runs to mistake for a
    valid model.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on network failures or timeouts.
    """
    headers = {"Authorization": f"Bearer {token}"}
    tmp_path = f"{MODEL_PATH}.part"
    try:
        # Using the response as a context manager releases the connection
        # even when raise_for_status() or the write loop raises.
        with requests.get(
            MODEL_URL,
            headers=headers,
            stream=True,
            timeout=60,
        ) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(8192):
                    f.write(chunk)
        os.replace(tmp_path, MODEL_PATH)
    finally:
        # The temp file still exists only if something failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


@st.cache_resource
def load_model():
    """Return the TorchScript detector in eval mode, downloading it if needed.

    Wrapped in ``st.cache_resource`` so the model is loaded once and reused
    across Streamlit reruns instead of being deserialized on every interaction.

    Raises:
        RuntimeError: if the model must be downloaded and ``HF_TOKEN`` is unset
            or empty.
    """
    if not os.path.exists(MODEL_PATH):
        # The token is only needed for the download, so check it only here.
        token = os.environ.get("HF_TOKEN")
        if not token:
            raise RuntimeError("HF_TOKEN not found in Space secrets")
        with st.spinner("Initializing system..."):
            _download_model(token)

    model = torch.jit.load(MODEL_PATH, map_location=DEVICE)
    model.eval()
    return model


# ============================================================
# Image Processing
# ============================================================

# Resize to 256x256 (presumably the training resolution - confirm) and apply
# the standard ImageNet mean/std normalization.
_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized ``(1, 3, 256, 256)`` float tensor."""
    # Normalize uses three-channel mean/std, so force RGB (handles RGBA, L, P, ...).
    return _transform(image.convert("RGB")).unsqueeze(0)
# ============================================================
# Inference
# ============================================================


def run_inference(model, image: Image.Image):
    """Classify *image* as real or fake.

    Returns:
        A dict with keys:
          - ``"label"``: ``"Real"`` when sigmoid(output) > THRESHOLD,
            otherwise ``"Fake"``;
          - ``"confidence"``: the sigmoid output for "Real", or one minus it
            for "Fake";
          - ``"latency"``: time spent in the forward pass, in milliseconds.
    """
    tensor = preprocess_image(image).to(DEVICE)

    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start_time = time.perf_counter()
    with torch.no_grad():
        logits = model(tensor)
        # .item() requires a single-element output. It also waits for any
        # pending CUDA work, so the timing below covers the whole forward pass.
        probability = torch.sigmoid(logits).item()
    latency = (time.perf_counter() - start_time) * 1000

    is_real = probability > THRESHOLD
    # NOTE(review): THRESHOLD is not 0.5, so a "Fake" verdict can report a
    # confidence below 50% (e.g. probability 0.6 -> "Fake" at 40%).
    confidence = probability if is_real else (1 - probability)

    return {
        "label": "Real" if is_real else "Fake",
        "confidence": confidence,
        "latency": latency,
    }


# ============================================================
# UI
# ============================================================


def _read_upload(uploaded_file):
    """Decode an uploaded file into an RGB PIL image, or return None.

    On unreadable input this shows an error in the page instead of raising.
    """
    try:
        return Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError (not an image / corrupt data) subclasses
        # OSError; DecompressionBombError is raised for absurdly large images.
        st.error("The uploaded file could not be read as an image.")
        return None


def _show_result(result):
    """Render the verdict, confidence and latency produced by run_inference."""
    if result["label"] == "Real":
        st.success("✔ Image appears to be authentic")
    else:
        st.error("✖ Image is likely manipulated")

    st.metric(
        label="Confidence",
        value=f"{result['confidence']:.2%}",
    )
    st.caption(
        f"Processing time: {result['latency']:.0f} ms"
    )


def main():
    """Streamlit entry point: load the model, accept an upload, show the verdict."""
    st.title("DeepFake Detection for eKYC (Facial Images)")
    st.caption("Upload an image to verify authenticity.")

    try:
        model = load_model()
    except Exception as e:
        # Top-level boundary: show the failure in the page instead of a blank app.
        st.error("System initialization failed.")
        st.exception(e)
        return

    uploaded_file = st.file_uploader(
        "Upload Image",
        type=["jpg", "jpeg", "png"],
    )

    if uploaded_file is not None:
        image = _read_upload(uploaded_file)
        if image is not None:
            st.image(image, caption="Uploaded Image")
            st.divider()

            if st.button("Analyze Image"):
                with st.spinner("Analyzing..."):
                    result = run_inference(model, image)
                _show_result(result)

    st.divider()
    st.caption(
        "This demo caters all the available generators including Style GAN and Diffusion model variants. "
        "For further inquiries please feel free to contact uzairmughal30@gmail.com"
    )


if __name__ == "__main__":
    main()