File size: 4,368 Bytes
1ba1e08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45ab4d5
1ba1e08
 
 
 
 
 
eade2be
1ba1e08
 
 
eade2be
 
1ba1e08
 
 
 
 
 
 
eade2be
1ba1e08
 
 
eade2be
1ba1e08
 
 
 
eade2be
1ba1e08
 
 
 
 
 
 
 
eade2be
1ba1e08
 
 
 
 
 
 
eade2be
1ba1e08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eade2be
1ba1e08
 
eade2be
 
1ba1e08
eade2be
 
1ba1e08
 
eade2be
1ba1e08
eade2be
1ba1e08
 
 
 
 
 
 
06a525d
eade2be
1ba1e08
 
 
 
eade2be
1ba1e08
 
 
 
eade2be
1ba1e08
 
 
 
 
eade2be
1ba1e08
eade2be
1ba1e08
eade2be
 
 
1ba1e08
eade2be
 
1ba1e08
eade2be
1ba1e08
 
eade2be
1ba1e08
 
 
 
eade2be
1ba1e08
 
 
 
40896a7
28bd1e0
1ba1e08
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import time
import requests
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

# ============================================================
# Configuration
# ============================================================

# Hugging Face download ("resolve") URL for the TorchScript checkpoint.
MODEL_URL = "/".join(
    (
        "https://huggingface.co/neuralninja10/deepFakeWithCBAM",
        "resolve/main/updatedDeepFakeModel.pt",
    )
)
# Local cache file for the checkpoint, relative to the working directory.
MODEL_PATH = "deepFakeWithCBAM.pt"
# Decision cut-off on the sigmoid output: probabilities above it are "Real".
THRESHOLD = 0.68
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ============================================================
# Page Configuration
# ============================================================

# Browser-tab title and icon plus a centered single-column layout. This runs
# at module level, before main() draws any widgets.
st.set_page_config(
    layout="centered",
    page_icon="🛡️",
    page_title="DeepFake Detection",
)

# ============================================================
# Secure Model Loader
# ============================================================

@st.cache_resource
def load_model():
    """Fetch (if needed) and load the TorchScript deepfake detector.

    On first use the checkpoint is downloaded from ``MODEL_URL``, using the
    ``HF_TOKEN`` environment variable as a bearer token. It is stored at
    ``MODEL_PATH``, then loaded onto ``DEVICE`` and switched to eval mode.
    ``st.cache_resource`` shares the returned model across script reruns.

    Returns:
        The loaded TorchScript module, in eval mode.

    Raises:
        RuntimeError: the checkpoint must be downloaded but ``HF_TOKEN`` is
            unset or empty.
        requests.HTTPError: the download responded with an error status.
    """
    if not os.path.exists(MODEL_PATH):
        # The token is only needed to fetch the checkpoint. An already cached
        # file loads without it.
        token = os.environ.get("HF_TOKEN")
        if not token:
            raise RuntimeError("HF_TOKEN not found in Space secrets")
        headers = {"Authorization": f"Bearer {token}"}

        # Download to a sibling temp file and rename only on success. Without
        # this, an interrupted download leaves a truncated file at MODEL_PATH
        # that the exists() check above would treat as valid forever.
        partial_path = f"{MODEL_PATH}.part"
        with st.spinner("Initializing system..."):
            try:
                # Context manager releases the streamed connection.
                with requests.get(
                    MODEL_URL,
                    headers=headers,
                    stream=True,
                    timeout=60,
                ) as response:
                    response.raise_for_status()
                    with open(partial_path, "wb") as f:
                        for chunk in response.iter_content(8192):
                            f.write(chunk)
                # Atomic rename: both paths are in the same directory.
                os.replace(partial_path, MODEL_PATH)
            finally:
                # Leftover only if the download failed; drop it so the next
                # attempt starts clean.
                if os.path.exists(partial_path):
                    os.remove(partial_path)

    model = torch.jit.load(MODEL_PATH, map_location=DEVICE)
    model.eval()
    return model

# ============================================================
# Image Processing
# ============================================================

# Preprocessing pipeline, applied in three steps:
#   1. Resize to a fixed 256x256 with bilinear interpolation (aspect ratio is
#      not preserved).
#   2. Convert to a float tensor.
#   3. Normalize per channel with the common ImageNet mean/std.
# Presumably this mirrors the training-time preprocessing; confirm before
# changing any value.
_transform = transforms.Compose(
    [
        transforms.Resize(
            size=(256, 256),
            interpolation=InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)

def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a single-image, normalized model input batch.

    Images that are not already RGB (e.g. RGBA or palette PNGs, grayscale)
    are converted first. ``_transform`` normalizes with three per-channel
    statistics and fails on any other channel count.

    Returns:
        A float tensor of shape ``(1, 3, 256, 256)``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension.
    return _transform(image).unsqueeze(0)

# ============================================================
# Inference
# ============================================================

def run_inference(model, image: Image.Image):
    """Classify *image* as "Real" or "Fake" using *model*.

    Args:
        model: The TorchScript module from ``load_model``. Assumed (TODO
            confirm) to emit a single logit per image, because the result is
            read with ``.item()``, which requires a one-element tensor.
        image: PIL image to classify. It is preprocessed with
            ``preprocess_image`` and moved to ``DEVICE``.

    Returns:
        A dict with three keys:

        - ``"label"``: ``"Real"`` when ``sigmoid(logit) > THRESHOLD``,
          otherwise ``"Fake"``.
        - ``"confidence"``: the sigmoid probability for "Real", and
          ``1 - probability`` for "Fake".
        - ``"latency"``: elapsed milliseconds for the forward pass, sigmoid
          and read-back. It excludes preprocessing and the transfer to
          ``DEVICE``.
    """
    tensor = preprocess_image(image).to(DEVICE)

    # perf_counter is monotonic and high-resolution. time.time() can jump
    # when the system clock is adjusted and has coarse resolution on some OSes.
    start_time = time.perf_counter()
    with torch.no_grad():
        logits = model(tensor)
        # .item() copies the value to the host, which waits for any pending
        # GPU work, so the timing covers the whole forward pass.
        probability = torch.sigmoid(logits).item()
    latency = (time.perf_counter() - start_time) * 1000

    is_real = probability > THRESHOLD
    # NOTE(review): THRESHOLD is not 0.5. A "Fake" verdict with probability
    # in (0.5, THRESHOLD] therefore reports a confidence below 50%.
    confidence = probability if is_real else (1 - probability)

    return {
        "label": "Real" if is_real else "Fake",
        "confidence": confidence,
        "latency": latency,
    }

# ============================================================
# UI
# ============================================================

def _render_analysis(model, uploaded_file):
    """Show the uploaded image and, when requested, the model's verdict."""
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # The extension filter does not guarantee a decodable image. Corrupt
        # or truncated files raise OSError (including UnidentifiedImageError),
        # and oversized ones raise DecompressionBombError. Report the problem
        # instead of crashing the page with a traceback.
        st.error("The uploaded file could not be read as an image.")
        return

    st.image(image, caption="Uploaded Image")
    st.divider()

    # st.button is True only on the rerun triggered by the click.
    if not st.button("Analyze Image"):
        return

    with st.spinner("Analyzing..."):
        result = run_inference(model, image)

    if result["label"] == "Real":
        st.success("✔ Image appears to be authentic")
    else:
        st.error("✖ Image is likely manipulated")

    st.metric(
        label="Confidence",
        value=f"{result['confidence']:.2%}",
    )

    st.caption(
        f"Processing time: {result['latency']:.0f} ms"
    )


def main():
    """Render the app: load the model, accept an upload, and show the verdict."""
    st.title("DeepFake Detection for eKYC (Facial Images)")
    st.caption("Upload an image to verify authenticity.")

    try:
        model = load_model()
    except Exception as e:
        # Top-level UI boundary: show any startup failure instead of crashing.
        st.error("System initialization failed.")
        st.exception(e)
        return

    uploaded_file = st.file_uploader(
        "Upload Image",
        type=["jpg", "jpeg", "png"],
    )

    if uploaded_file is not None:
        _render_analysis(model, uploaded_file)

    # The footer is rendered whether or not a file was uploaded.
    st.divider()
    st.caption(
        "This demo caters to all available generators, including StyleGAN and Diffusion model variants. "
        "For further inquiries, please feel free to contact uzairmughal30@gmail.com"
    )

if __name__ == "__main__":
    main()