vaniv commited on
Commit
ad7b882
·
verified ·
1 Parent(s): 34aff01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -109
app.py CHANGED
@@ -1,113 +1,233 @@
1
- import os
2
- import typing as t
3
-
4
- import gradio as gr
5
  import numpy as np
6
- import tensorflow as tf
7
- from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization,
8
- MaxPooling2D, Flatten, Dropout, Dense, LeakyReLU)
9
- from tensorflow.keras.models import Model
10
- from PIL import Image
11
-
12
- # Paths
13
- CUSTOM_MODEL_PATH = "model.h5" # optional: full Keras model
14
- MESO_WEIGHTS_PATH = "weights/Meso4_DF.weights.h5" # your weights-only file
15
- LABELS = ["real", "fake"] # index 0..1 (we'll compute both scores)
16
-
17
- # Globals
18
- MODEL: t.Optional[tf.keras.Model] = None
19
- IS_MESO = False
20
- TARGET_SIZE = (256, 256) # your notebook used 256×256
21
- THRESHOLD = 0.5 # sigmoid > 0.5 => fake
22
-
23
def build_meso4() -> tf.keras.Model:
    """Build the Meso4 CNN for binary deepfake detection.

    Four Conv → BatchNorm → MaxPool stages followed by a small dense head
    ending in a single sigmoid unit; per THRESHOLD above, sigmoid > 0.5
    is interpreted as "fake".
    """
    # Input: RGB image at TARGET_SIZE (256x256 per the constant above).
    x = Input(shape=(TARGET_SIZE[0], TARGET_SIZE[1], 3))
    x1 = Conv2D(8, (3, 3), padding='same', activation='relu')(x)
    x1 = BatchNormalization()(x1)
    x1 = MaxPooling2D(pool_size=(2, 2), padding='same')(x1)

    x2 = Conv2D(8, (5, 5), padding='same', activation='relu')(x1)
    x2 = BatchNormalization()(x2)
    x2 = MaxPooling2D(pool_size=(2, 2), padding='same')(x2)

    x3 = Conv2D(16, (5, 5), padding='same', activation='relu')(x2)
    x3 = BatchNormalization()(x3)
    x3 = MaxPooling2D(pool_size=(2, 2), padding='same')(x3)

    x4 = Conv2D(16, (5, 5), padding='same', activation='relu')(x3)
    x4 = BatchNormalization()(x4)
    # Larger (4x4) pooling on the last stage shrinks the feature map
    # aggressively before flattening.
    x4 = MaxPooling2D(pool_size=(4, 4), padding='same')(x4)

    y = Flatten()(x4)
    y = Dropout(0.5)(y)
    y = Dense(16)(y)
    y = LeakyReLU(alpha=0.1)(y)
    y = Dropout(0.5)(y)
    # Single sigmoid output: score toward 1.0 means "fake" (see THRESHOLD).
    y = Dense(1, activation='sigmoid')(y)

    model = Model(inputs=x, outputs=y)
    # Compiled for completeness; inference-only use does not need it.
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
51
-
52
def _load_model():
    """Load a full Keras model if present; otherwise build Meso4 and load weights.

    Sets the module globals MODEL and IS_MESO as a side effect.

    Raises:
        RuntimeError: when neither CUSTOM_MODEL_PATH nor MESO_WEIGHTS_PATH
            exists (deliberately no silent fallback).
    """
    global MODEL, IS_MESO
    # 1) Full model (optional)
    if os.path.exists(CUSTOM_MODEL_PATH):
        try:
            MODEL = tf.keras.models.load_model(CUSTOM_MODEL_PATH, compile=False)
            IS_MESO = False
            print("Loaded custom model from model.h5")
            return
        except Exception as e:
            # Best-effort: log and fall through to the weights-only path.
            print("Failed to load model.h5:", e)

    # 2) Meso4 + weights
    if os.path.exists(MESO_WEIGHTS_PATH):
        MODEL = build_meso4()
        MODEL.load_weights(MESO_WEIGHTS_PATH)
        IS_MESO = True
        print("Loaded Meso4 with weights:", MESO_WEIGHTS_PATH)
        return

    # 3) Hard fail (don't silently switch to ImageNet; this is a deepfake app).
    # Fix: the previous message said "weights/Meso4_DF", which does not match
    # MESO_WEIGHTS_PATH ("weights/Meso4_DF.weights.h5") — reference the actual
    # configured paths so the user uploads the right files.
    raise RuntimeError(
        f"No model found. Upload either {CUSTOM_MODEL_PATH} or "
        f"{MESO_WEIGHTS_PATH} to the Space."
    )
77
-
78
def _preprocess(img: Image.Image) -> np.ndarray:
    """Convert to RGB, resize to TARGET_SIZE, scale to [0, 1], add batch axis.

    Returns a float32 array of shape (1, H, W, 3) ready for model.predict.
    """
    rgb = img.convert("RGB").resize(TARGET_SIZE)
    scaled = np.asarray(rgb, dtype="float32") / 255.0
    return scaled[np.newaxis, ...]
82
-
83
def predict(image: Image.Image):
    """Classify an uploaded image as real vs fake with the loaded model.

    Returns:
        (scores, image, msg): a dict with both class probabilities for the
        gr.Label component, the original image for preview, and a status string.
    """
    if image is None:
        # Graceful empty-input response; same arity as the success path.
        return {"real": 0.0, "fake": 0.0}, None, "Upload an image."
    x = _preprocess(image)
    # Single sigmoid output interpreted as P(fake) — see THRESHOLD comment above.
    prob_fake = float(MODEL.predict(x, verbose=0)[0][0])
    prob_real = 1.0 - prob_fake
    label = "fake" if prob_fake >= THRESHOLD else "real"
    msg = f"Prediction: {label.upper()} | fake={prob_fake:.2f}, real={prob_real:.2f}"
    # Return both scores for the Label component
    return {"real": prob_real, "fake": prob_fake}, image, msg
93
-
94
# Init
# Load the model once at import time; raises RuntimeError if no artifact exists.
_load_model()

with gr.Blocks(title="Deepfake Detector (Meso4)") as demo:
    gr.Markdown("# Deepfake Detector (Meso4)\n"
                "Upload a face image (or a frame from a video). The model outputs real vs fake.")

    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="Upload image")
            btn = gr.Button("Predict")
        with gr.Column(scale=1):
            out_label = gr.Label(num_top_classes=2, label="Scores")
            out_img = gr.Image(type="pil", label="Preview")
            out_text = gr.Markdown()

    # Output order must match predict's 3-tuple: (scores, image, msg).
    btn.click(fn=predict, inputs=inp, outputs=[out_label, out_img, out_text])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  if __name__ == "__main__":
113
  demo.launch()
 
1
+ import io, os, tempfile, math
 
 
 
2
  import numpy as np
3
+ import gradio as gr
4
+ from PIL import Image, ImageChops, ImageFilter
5
+ import cv2
6
+ from skimage import exposure
7
+
8
+ # ---------- Forensic primitives ----------
9
+
10
def error_level_analysis(pil_img: Image.Image, quality: int = 90):
    """
    Error Level Analysis (ELA).

    Re-encode the image as JPEG at the given quality, diff against the
    original, and linearly amplify the diff so compression anomalies become
    visible.

    Args:
        pil_img: input image (any mode; converted to RGB internally).
        quality: JPEG quality for the re-encode pass.

    Returns:
        (ela_image, mean_intensity): amplified diff as a PIL image, and the
        mean amplified-diff intensity normalized to [0, 1].
    """
    img = pil_img.convert("RGB")

    # Round-trip through an in-memory JPEG encode/decode.
    with io.BytesIO() as buffer:
        img.save(buffer, "JPEG", quality=quality)
        buffer.seek(0)
        # .convert() forces a full pixel load before the buffer closes.
        comp = Image.open(buffer).convert("RGB")

    diff = ImageChops.difference(img, comp)

    # Amplify so the largest per-band difference maps to 255.
    # getextrema() yields one (min, max) pair per band for RGB images.
    extrema = diff.getextrema()
    max_diff = max(m for (_, m) in extrema)
    scale = 255.0 / max(1, max_diff)  # guard divide-by-zero for identical images

    # Fix: scale with NumPy directly instead of calling the module-level
    # `ImageEnhance` helper, whose name shadows PIL's ImageEnhance module
    # and invites confusion; the math is identical (multiply, clip, uint8).
    ela_np = np.clip(np.array(diff).astype("float32") * scale, 0, 255).astype("uint8")
    ela = Image.fromarray(ela_np)

    mean_intensity = float(ela_np.mean() / 255.0)
    return ela, mean_intensity
30
+
31
def ImageEnhance(pil_img: Image.Image, scale: float):
    """Multiply pixel values by `scale`, clip to [0, 255], return a new PIL image.

    NOTE(review): the name shadows PIL's `ImageEnhance` module (that module is
    not imported in this file, so there is no clash today) and is not
    snake_case; renaming would change the module's public surface, so it is
    left as-is.
    """
    arr = np.array(pil_img).astype("float32") * scale
    arr = np.clip(arr, 0, 255).astype("uint8")
    return Image.fromarray(arr)
35
+
36
def fft_high_freq_ratio(pil_img: Image.Image):
    """
    High-frequency energy ratio of the grayscale FFT magnitude.

    Args:
        pil_img: input image (converted to grayscale internally).

    Returns:
        (spectrum_image, hf_ratio): log-magnitude spectrum as a PIL image,
        and the fraction of spectral energy outside a small central
        low-frequency disc (approximately in [0, 1]).
    """
    gray = np.array(pil_img.convert("L"), dtype=np.float32) / 255.0
    h, w = gray.shape

    # 2-D Hann window to reduce edge artifacts that would otherwise
    # smear energy across the spectrum.
    win_y = np.hanning(h)[:, None]
    win_x = np.hanning(w)[None, :]
    grayw = gray * (win_y * win_x)

    F = np.fft.fftshift(np.fft.fft2(grayw))
    mag = np.log1p(np.abs(F))

    # Normalized spectrum for display.
    # Fix: guard the constant-image case where mag.max() == 0 — dividing
    # by zero produced NaNs and garbage uint8 output.
    peak = mag.max()
    if peak > 0:
        spec = (mag / peak * 255).astype("uint8")
    else:
        spec = np.zeros((h, w), dtype="uint8")
    spec_img = Image.fromarray(spec)

    # Split energy: low-frequency disc around the (shifted) DC vs. the rest.
    cy, cx = h // 2, w // 2
    yy, xx = np.ogrid[:h, :w]
    dist = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)
    r_low = min(h, w) * 0.08  # low-frequency radius (empirical)
    mask_low = dist <= r_low
    low_energy = mag[mask_low].sum()
    high_energy = mag[~mask_low].sum()
    hf_ratio = float(high_energy / (high_energy + low_energy + 1e-9))
    return spec_img, hf_ratio
64
+
65
def noise_map_score(pil_img: Image.Image):
    """
    Laplacian-variance map as a proxy for local sharpness / noise consistency.

    Args:
        pil_img: input image (converted to grayscale internally).

    Returns:
        (heatmap, score): contrast-equalized |Laplacian| heatmap as a PIL
        image, and an inconsistency score squashed to approximately [0, 1]
        (higher = more tile-to-tile variation).
    """
    img = np.array(pil_img.convert("L"))
    lap = cv2.Laplacian(img, cv2.CV_32F, ksize=3)

    # Display heatmap: normalize |Laplacian|, then local contrast-equalize.
    lap_abs = np.abs(lap)
    heat = (lap_abs / (lap_abs.max() + 1e-9) * 255).astype("uint8")
    heat_eq = exposure.equalize_adapthist(heat, clip_limit=0.01)  # float in [0, 1]
    heat_disp = Image.fromarray((heat_eq * 255).astype("uint8"))

    # Inconsistency: coefficient of variation of per-tile Laplacian variance.
    # NumPy slicing already clamps out-of-range stops, so y:y+tile is safe
    # at the borders.
    tile = 32
    H, W = img.shape
    vars_ = np.array(
        [
            lap_abs[y:y + tile, x:x + tile].var()
            for y in range(0, H, tile)
            for x in range(0, W, tile)
            if lap_abs[y:y + tile, x:x + tile].size > 0
        ],
        dtype=np.float32,
    )
    # Fix: an empty tile list (degenerate zero-area image) previously made
    # np.mean/np.std return NaN with a RuntimeWarning.
    if vars_.size == 0:
        return heat_disp, 0.0

    score = float(vars_.std() / (vars_.mean() + 1e-9))  # higher = more inconsistent
    # squash to approx [0,1]
    score_norm = float(np.tanh(score / 5.0))
    return heat_disp, score_norm
92
+
93
+ # ---------- Simple decision rule ----------
94
+
95
def combine_scores(ela_mean, hf_ratio, noise_incons):
    """
    Fuse the three forensic signals into one manipulation confidence.

    Each raw signal is squashed into [0, 1], then combined as a weighted
    average; weights and the 0.55 decision threshold are tuned
    conservatively to avoid false alarms on clean photos.

    Returns:
        (label, suspect): "Likely Manipulated" / "Likely Authentic" plus
        the fused confidence in [0, 1].
    """
    weights = (0.4, 0.35, 0.25)  # tweakable

    # Rough per-signal normalization to [0, 1].
    normalized = (
        np.clip(ela_mean * 2.5, 0, 1),            # stronger ELA response -> more suspect
        np.clip((hf_ratio - 0.65) / 0.25, 0, 1),  # excess high-frequency energy -> suspect
        np.clip(noise_incons, 0, 1),
    )

    suspect = float(sum(w * s for w, s in zip(weights, normalized)))
    label = "Likely Manipulated" if suspect >= 0.55 else "Likely Authentic"
    return label, suspect
110
+
111
+ # ---------- Gradio handlers ----------
112
+
113
def analyze_image(pil_img: Image.Image):
    """
    Run the three forensic checks on one image and package the UI outputs.

    Returns a 6-tuple matching the Image tab's output components:
        (scores_dict, normalized_input, ela_image, spectrum_image,
         noise_heatmap, markdown_message)
    """
    if pil_img is None:
        # Fix: the empty-input branch previously returned only 5 values while
        # the success path (and the 6-element outputs list wired in the UI)
        # expects 6 — the message would have landed in the wrong component.
        return {}, None, None, None, None, "Upload an image"

    # Standardize size for stable, comparable scores.
    # NOTE: this is a plain resize — aspect ratio is NOT preserved.
    pil_img = pil_img.convert("RGB")
    pil_img = pil_img.resize((512, 512))

    ela_img, ela_mean = error_level_analysis(pil_img, quality=90)
    spec_img, hf_ratio = fft_high_freq_ratio(pil_img)
    noise_img, noise_incons = noise_map_score(pil_img)

    label, conf = combine_scores(ela_mean, hf_ratio, noise_incons)
    scores = {
        "Confidence manipulated": round(conf, 3),
        "ELA mean": round(ela_mean, 3),
        "HF ratio": round(hf_ratio, 3),
        "Noise inconsistency": round(noise_incons, 3)
    }
    msg = f"Result: **{label}** — confidence: {conf:.2f}\n\n" \
          f"*ELA={ela_mean:.3f}, HF={hf_ratio:.3f}, Noise={noise_incons:.3f}*"

    return scores, pil_img, ela_img, spec_img, noise_img, msg
136
+
137
def analyze_video(video_file):
    """
    Sample up to 8 frames (every 15th) from a short video, run the image
    checks on each, and average the scores; visual evidence is rendered
    from the first sampled frame.

    Args:
        video_file: a filesystem path (modern Gradio gr.Video passes a str),
            an object with a .name path attribute, or a file-like object
            exposing .read().

    Returns a 6-tuple matching the Video tab's output components.
    """
    if video_file is None:
        return {}, None, None, None, None, "Upload a short video (<= 10–15s)"

    # Fix: gr.Video supplies a filepath string in current Gradio, so the
    # previous unconditional video_file.read() crashed. Resolve to a path;
    # only raw file-like payloads are copied to a temp file, and only that
    # temp file is deleted afterwards (never a user-supplied path).
    cleanup_path = None
    if isinstance(video_file, str):
        path = video_file
    elif isinstance(getattr(video_file, "name", None), str):
        path = video_file.name
    else:
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        tmp.write(video_file.read()); tmp.flush(); tmp.close()
        path = cleanup_path = tmp.name

    cap = cv2.VideoCapture(path)
    frames = []
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if idx % 15 == 0:  # sample every 15th frame
            # OpenCV decodes BGR; convert to RGB for PIL.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            if len(frames) >= 8:
                break
        idx += 1
    cap.release()
    if cleanup_path:
        os.unlink(cleanup_path)

    if not frames:
        return {}, None, None, None, None, "Couldn’t read frames; try a different/shorter video."

    # Analyze the first frame for visuals, average scores across all frames.
    vis_sample = frames[0].resize((512, 512)).convert("RGB")

    ela_img, ela_mean = error_level_analysis(vis_sample)
    spec_img, hf_ratio = fft_high_freq_ratio(vis_sample)
    noise_img, noise_incons = noise_map_score(vis_sample)

    # Accumulate per-frame scores (visual outputs of later frames discarded).
    elas, hfs, noises = [ela_mean], [hf_ratio], [noise_incons]
    for f in frames[1:]:
        f = f.resize((512, 512)).convert("RGB")
        _, em = error_level_analysis(f)
        _, hr = fft_high_freq_ratio(f)
        _, ns = noise_map_score(f)
        elas.append(em); hfs.append(hr); noises.append(ns)

    ela_m = float(np.mean(elas))
    hf_m = float(np.mean(hfs))
    noi_m = float(np.mean(noises))
    label, conf = combine_scores(ela_m, hf_m, noi_m)

    scores = {
        "Confidence manipulated": round(conf, 3),
        "ELA mean (avg)": round(ela_m, 3),
        "HF ratio (avg)": round(hf_m, 3),
        "Noise inconsistency (avg)": round(noi_m, 3)
    }
    msg = f"Result: **{label}** — confidence: {conf:.2f}\n\n" \
          f"*ELA={ela_m:.3f}, HF={hf_m:.3f}, Noise={noi_m:.3f}*\n" \
          f"_Note: rule-based (no ML), indicative only._"

    return scores, vis_sample, ela_img, spec_img, noise_img, msg
193
+
194
+ # ---------- UI ----------
195
+
196
with gr.Blocks(title="Deepfake Forensics (No-ML)") as demo:
    # Header explaining the three rule-based checks.
    gr.Markdown("## Deepfake Forensics (No-ML)\n"
                "Upload an **image** or a short **video**. We run three classical forensic checks:\n"
                "- **ELA** (Error Level Analysis)\n- **Frequency Spectrum** (high-freq energy)\n- **Noise Consistency** (Laplacian map)\n"
                "Outputs a **Likely Authentic / Likely Manipulated** decision with visual evidence.")

    with gr.Tab("Image"):
        with gr.Row():
            with gr.Column(scale=1):
                img_in = gr.Image(type="pil", label="Upload image")
                btn = gr.Button("Analyze")
            with gr.Column(scale=2):
                scores = gr.Label(label="Scores")
                img_std = gr.Image(label="Normalized Input")
                img_ela = gr.Image(label="ELA Heatmap")
                img_fft = gr.Image(label="Frequency Spectrum")
                img_noise = gr.Image(label="Noise/Sharpness Map")
                msg = gr.Markdown()
        # Output order must match analyze_image's 6-tuple.
        btn.click(analyze_image, inputs=img_in,
                  outputs=[scores, img_std, img_ela, img_fft, img_noise, msg])

    with gr.Tab("Video (optional)"):
        with gr.Row():
            with gr.Column(scale=1):
                vid_in = gr.Video(label="Upload short MP4 (<=10–15s)")
                btnv = gr.Button("Analyze Video")
            with gr.Column(scale=2):
                vscores = gr.Label(label="Scores (avg over frames)")
                vimg_std = gr.Image(label="Frame Preview")
                vimg_ela = gr.Image(label="ELA Heatmap (frame)")
                vimg_fft = gr.Image(label="Frequency Spectrum (frame)")
                vimg_noise = gr.Image(label="Noise/Sharpness Map (frame)")
                vmsg = gr.Markdown()
        # Output order must match analyze_video's 6-tuple.
        btnv.click(analyze_video, inputs=vid_in,
                   outputs=[vscores, vimg_std, vimg_ela, vimg_fft, vimg_noise, vmsg])
231
 
232
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()