vaniv committed on
Commit
b0d371d
·
verified ·
1 Parent(s): 8f07fef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -134
app.py CHANGED
@@ -1,13 +1,71 @@
1
- import io
2
- import numpy as np
3
- import gradio as gr
4
  from PIL import Image, ImageChops, ImageDraw
5
  import cv2
6
  from skimage import exposure
7
  import mediapipe as mp
8
 
9
- # ====================== ELA (compression residual) ======================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def _enhance_for_display(pil_img, scale: float):
12
  arr = np.array(pil_img).astype("float32") * scale
13
  arr = np.clip(arr, 0, 255).astype("uint8")
@@ -16,12 +74,10 @@ def _enhance_for_display(pil_img, scale: float):
16
  def error_level_analysis(pil_img: Image.Image, quality: int = 90):
17
  img = pil_img.convert("RGB")
18
  with io.BytesIO() as buf:
19
- img.save(buf, "JPEG", quality=quality)
20
- buf.seek(0)
21
  comp = Image.open(buf).convert("RGB")
22
  diff = ImageChops.difference(img, comp)
23
- extrema = diff.getextrema()
24
- max_diff = max([m for (_, m) in extrema])
25
  scale = 255.0 / max(1, max_diff)
26
  ela_vis = _enhance_for_display(diff, scale)
27
  ela_np = np.array(ela_vis, dtype=np.float32)
@@ -31,117 +87,51 @@ def error_level_analysis(pil_img: Image.Image, quality: int = 90):
31
  def ela_sweep_mean(pil_img, qualities=(95, 90, 85)):
32
  vals = []
33
  for q in qualities:
34
- _, m = error_level_analysis(pil_img, quality=q)
35
- vals.append(m)
36
  return float(max(vals)), float(np.mean(vals))
37
 
38
- # ====================== Frequency & Noise (support face masks) ======================
39
-
40
def fft_high_freq_ratio(pil_img: Image.Image, mask=None):
    """Share of log-spectrum energy outside a small low-frequency disc.

    `mask` (optional float array matching the Y-channel shape) restricts the
    analysis to a region, e.g. a face oval. Returns (None, ratio); the None
    slot keeps the (debug, score) tuple shape used by the other detectors.
    """
    luma = pil_img.convert("YCbCr").split()[0]
    gray = np.array(luma, dtype=np.float32) / 255.0
    if mask is not None:
        gray = gray * mask

    h, w = gray.shape
    # Hann window in both dimensions suppresses FFT edge artifacts.
    window = np.hanning(h)[:, None] * np.hanning(w)[None, :]
    mag = np.log1p(np.abs(np.fft.fftshift(np.fft.fft2(gray * window))))

    yy, xx = np.ogrid[:h, :w]
    dist = np.sqrt((yy - h // 2) ** 2 + (xx - w // 2) ** 2)
    inner = dist <= min(h, w) * 0.08  # low-frequency disc radius
    low_energy = float(mag[inner].sum())
    high_energy = float(mag[~inner].sum())
    return None, float(high_energy / (high_energy + low_energy + 1e-9))
60
-
61
def noise_inconsistency(pil_img: Image.Image, mask=None):
    """Score in [0, 1] of how unevenly Laplacian noise varies across 32px tiles.

    Spliced regions often carry different sensor-noise statistics; a high
    spread of per-tile Laplacian variance is suspicious. `mask` (optional
    float array, same shape as the Y channel) restricts the measurement.
    Returns (None, score) to match the other detectors' tuple shape.
    """
    y = pil_img.convert("YCbCr").split()[0]
    img = np.array(y, dtype=np.float32)
    if mask is not None:
        img = img * mask

    lap_abs = np.abs(cv2.Laplacian(img, cv2.CV_32F, ksize=3))

    # FIX: removed a dead `_ = exposure.equalize_adapthist(...)` call whose
    # result was discarded — it cost a full CLAHE pass with no effect on the
    # returned score.

    tile = 32
    H, W = lap_abs.shape
    vals = []
    for yy in range(0, H, tile):
        for xx in range(0, W, tile):
            # Slicing auto-clamps at the array edge, so no min() is needed.
            patch = lap_abs[yy:yy + tile, xx:xx + tile]
            if patch.size:
                vals.append(patch.var())
    if not vals:
        return None, 0.0
    vals = np.array(vals, dtype=np.float32)
    # Coefficient of variation of tile variances, squashed to [0, 1).
    score = float(vals.std() / (vals.mean() + 1e-9))
    return None, float(np.tanh(score / 5.0))
87
 
88
- # ====================== Face crop + oval mask ======================
89
-
90
- _mp_face = mp.solutions.face_detection.FaceDetection(
91
- model_selection=0, min_detection_confidence=0.4
92
- )
93
-
94
- def crop_face(pil_img, pad=0.25):
95
- img = np.array(pil_img.convert("RGB"))
96
- h, w = img.shape[:2]
97
- res = _mp_face.process(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
98
- if not res.detections:
99
- return pil_img
100
- det = max(res.detections, key=lambda d: d.location_data.relative_bounding_box.width)
101
- b = det.location_data.relative_bounding_box
102
- x, y, bw, bh = b.xmin, b.ymin, b.width, b.height
103
- x1 = int(max(0, (x - pad*bw) * w)); y1 = int(max(0, (y - pad*bh) * h))
104
- x2 = int(min(w, (x + bw + pad*bw) * w)); y2 = int(min(h, (y + bh + pad*bh) * h))
105
- face = Image.fromarray(img[y1:y2, x1:x2])
106
- return face if face.size[0] > 20 and face.size[1] > 20 else pil_img
107
-
108
def face_oval_mask(img_pil, shrink=0.80):
    """Build a soft face-region selector: a filled ellipse spanning `shrink`
    of each image dimension, centered in the frame.

    Returns a float32 array with the image's (h, w) shape, 1.0 inside the
    ellipse and 0.0 outside.
    """
    w, h = img_pil.size
    canvas = Image.new("L", (w, h), 0)
    inset_x = int((1 - shrink) * w / 2)
    inset_y = int((1 - shrink) * h / 2)
    ImageDraw.Draw(canvas).ellipse((inset_x, inset_y, w - inset_x, h - inset_y), fill=255)
    return np.array(canvas, dtype=np.float32) / 255.0
115
-
116
- # ====================== Natural texture correction ======================
117
-
118
def natural_texture_correction(pil_img: Image.Image):
    """Correction factor in [0.7, 1.0] that discounts very smooth images.

    A low edge-strength-to-contrast ratio suggests naturally smooth
    (realistic) content, so the combined suspicion score is scaled down
    by up to 30%.
    """
    gray = np.array(pil_img.convert("L"), dtype=np.float32) / 255.0
    gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    edge_strength = np.mean(np.sqrt(gx ** 2 + gy ** 2))
    contrast = np.std(gray)
    ratio = edge_strength / (contrast + 1e-6)  # small -> smooth/realistic
    corr = 1.0 - np.clip((0.15 - ratio) * 2.5, 0, 0.3)
    return float(np.clip(corr, 0.7, 1.0))
127
-
128
- # ====================== Decision layer ======================
129
-
130
def combine_scores(ela_mean, hf_ratio, noise_incons_score, texture_corr=1.0):
    """Fuse the heuristic signals into a (label, suspicion) pair.

    Each signal is squashed to [0, 1], combined with fixed weights
    (ELA 0.30, FFT 0.40, noise 0.30), scaled by `texture_corr`, and
    thresholded at 0.65.

    Returns:
        (label, suspect): label is "Likely Manipulated" or
        "Likely Authentic"; suspect is the weighted score in [0, 1].
    """
    w1, w2, w3 = 0.30, 0.40, 0.30
    s_ela = np.clip(ela_mean * 3.0, 0, 1)
    s_hf = np.clip((hf_ratio - 0.65) / 0.25, 0, 1)
    # FIX: removed the dead `np.clip(noise_inconsistency, 0, 1) if False
    # else ...` branch — it referenced the *function object* and could
    # never execute.
    s_noi = np.clip(noise_incons_score, 0, 1)
    suspect = float((w1 * s_ela + w2 * s_hf + w3 * s_noi) * texture_corr)
    label = "Likely Manipulated" if suspect >= 0.65 else "Likely Authentic"
    return label, suspect
138
-
139
- # ====================== Gradio handler ======================
140
 
141
- def _result_card(label: str, conf: float) -> str:
 
142
  pct = max(0.0, min(1.0, conf)) * 100.0
143
  color = "#d84a4a" if label.startswith("Likely Manipulated") else "#2e7d32"
144
  bar_bg = "#e9ecef"
 
145
  return f"""
146
  <div style="max-width:860px;margin:0 auto;">
147
  <div style="border:1px solid #e5e7eb;border-radius:14px;padding:18px 20px;background:#fff;
@@ -154,63 +144,51 @@ def _result_card(label: str, conf: float) -> str:
154
  <div style="height:100%;width:{pct:.4f}%;background:{color};"></div>
155
  </div>
156
  </div>
 
157
  </div>
158
  """
159
 
160
def analyze_simple(pil_img: Image.Image):
    """Full heuristic pipeline: face-crop, oval mask, ELA sweep, FFT, noise, fuse.

    Returns the HTML result card; a blank input yields an authentic card at 0.
    """
    if pil_img is None:
        return _result_card("Likely Authentic", 0.0)

    face = crop_face(pil_img).convert("RGB").resize((512, 512))
    oval = face_oval_mask(face, shrink=0.80)

    ela_peak, ela_avg = ela_sweep_mean(face)
    # Discount the peak ELA reading when the sweep average is very low.
    ela_mean = ela_peak * (0.85 if ela_avg < 0.06 else 1.0)

    _, hf_ratio = fft_high_freq_ratio(face, mask=oval)
    _, noi_score = noise_inconsistency(face, mask=oval)

    texture_corr = natural_texture_correction(face)
    label, conf = combine_scores(ela_mean, hf_ratio, noi_score, texture_corr)
    return _result_card(label, conf)
179
-
180
- # ====================== UI ======================
181
-
182
  CUSTOM_CSS = """
183
  .gradio-container {max-width: 980px !important;}
184
- /* Card-like uploader */
185
  .sleek-card {
186
  border: 1px solid #e5e7eb; border-radius: 16px; background: #fff;
187
  box-shadow: 0 2px 10px rgba(16,24,40,.04); padding: 18px;
188
  }
189
  """
190
-
191
- with gr.Blocks(title="Deepfake Detector", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
192
- gr.Markdown(
193
- "<h2 style='text-align:center;margin-bottom:6px;'>Deepfake Detector</h2>"
194
- "<p style='text-align:center;color:#6b7280;'>Upload an image and get a single, clean likelihood estimate.</p>"
195
- )
196
-
197
  with gr.Row():
198
  with gr.Column(scale=6, elem_classes=["sleek-card"]):
199
- inp = gr.Image(
200
- type="pil",
201
- label="Upload / Paste Image",
202
- sources=["upload", "webcam", "clipboard"], # <-- fixed; 'url' not supported in your build
203
- height=420,
204
- show_label=True,
205
- interactive=True,
206
- )
207
  btn = gr.Button("Analyze", variant="primary", size="lg")
208
-
209
  with gr.Column(scale=6):
210
  out = gr.HTML()
211
-
212
- btn.click(analyze_simple, inputs=inp, outputs=out)
213
- inp.change(analyze_simple, inputs=inp, outputs=out)
214
 
215
  if __name__ == "__main__":
216
  demo.launch()
 
1
+ import io, os, numpy as np, gradio as gr
 
 
2
  from PIL import Image, ImageChops, ImageDraw
3
  import cv2
4
  from skimage import exposure
5
  import mediapipe as mp
6
 
7
+ # ====== HF model choice (pick one) ======
8
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "prithivMLmods/Deep-Fake-Detector-v2-Model") # ViT 224
9
+ HF_IMAGE_SIZE = int(os.getenv("HF_IMAGE_SIZE", "224")) # 224 for v2 ViT, 512 for v1 SigLIP
10
+
11
+ # ====== HF imports (lazy so app can start even if transformers missing) ======
12
+ _hf_loaded = False
13
+ _hf_processor = None
14
+ _hf_model = None
15
+ def _try_load_hf():
16
+ global _hf_loaded, _hf_processor, _hf_model
17
+ if _hf_loaded:
18
+ return True
19
+ try:
20
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
21
+ _hf_processor = AutoImageProcessor.from_pretrained(HF_MODEL_ID)
22
+ _hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
23
+ _hf_model.eval()
24
+ _hf_loaded = True
25
+ return True
26
+ except Exception as e:
27
+ print("HF load failed:", e)
28
+ _hf_loaded = False
29
+ return False
30
+
31
def _hf_predict_proba(pil_rgb_face):
    """Return the probability, in [0, 1], that `pil_rgb_face` is a deepfake.

    Assumes `_try_load_hf()` has succeeded so `_hf_processor` / `_hf_model`
    are initialized. The face is resized to the model's input size,
    classified, and the softmax probability of the "fake" class returned.
    """
    import torch  # local import: module stays importable without torch

    with torch.no_grad():
        inputs = _hf_processor(
            images=pil_rgb_face.resize((HF_IMAGE_SIZE, HF_IMAGE_SIZE)),
            return_tensors="pt",
        )
        logits = _hf_model(**inputs).logits[0]
        probs = torch.softmax(logits, dim=-1).cpu().numpy()

    # Map label name -> class index; models commonly use
    # ["Deepfake", "Realism"] or ["fake", "real"].
    id2label = _hf_model.config.id2label
    # BUG FIX: the previous code built `{v.lower(): k for k, v in
    # label2id.items()}` — label2id maps name -> id, so `.lower()` was
    # called on the *integer* id and raised AttributeError. Build the
    # lowercase name -> index map from id2label instead.
    lab2idx = {name.lower(): int(idx) for idx, name in id2label.items()}
    deep_idx = lab2idx.get("deepfake")
    if deep_idx is None:
        deep_idx = lab2idx.get("fake")
    if deep_idx is None:
        # Heuristic: first class whose label contains 'fake'; else index 0.
        deep_idx = next((i for i, name in id2label.items() if "fake" in name.lower()), 0)
    return float(probs[int(deep_idx)])
50
+
51
# ====== Face detect / crop (your pipeline) ======
_mp_face = mp.solutions.face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.4)

def crop_face(pil_img, pad=0.25):
    """Crop the widest detected face, padded by `pad` of the box size per side.

    Falls back to the original image when no face is found or when the crop
    is degenerate (20 px or less on either side).
    """
    rgb = np.array(pil_img.convert("RGB"))
    height, width = rgb.shape[:2]
    detections = _mp_face.process(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)).detections
    if not detections:
        return pil_img
    # Rank detections by relative bounding-box width; widest face wins.
    best = max(detections, key=lambda d: d.location_data.relative_bounding_box.width)
    box = best.location_data.relative_bounding_box
    left = int(max(0, (box.xmin - pad * box.width) * width))
    top = int(max(0, (box.ymin - pad * box.height) * height))
    right = int(min(width, (box.xmin + box.width + pad * box.width) * width))
    bottom = int(min(height, (box.ymin + box.height + pad * box.height) * height))
    face = Image.fromarray(rgb[top:bottom, left:right])
    return face if face.size[0] > 20 and face.size[1] > 20 else pil_img
67
+
68
+ # ====== Heuristic fallback (unchanged core) ======
69
  def _enhance_for_display(pil_img, scale: float):
70
  arr = np.array(pil_img).astype("float32") * scale
71
  arr = np.clip(arr, 0, 255).astype("uint8")
 
74
  def error_level_analysis(pil_img: Image.Image, quality: int = 90):
75
  img = pil_img.convert("RGB")
76
  with io.BytesIO() as buf:
77
+ img.save(buf, "JPEG", quality=quality); buf.seek(0)
 
78
  comp = Image.open(buf).convert("RGB")
79
  diff = ImageChops.difference(img, comp)
80
+ extrema = diff.getextrema(); max_diff = max([m for (_, m) in extrema])
 
81
  scale = 255.0 / max(1, max_diff)
82
  ela_vis = _enhance_for_display(diff, scale)
83
  ela_np = np.array(ela_vis, dtype=np.float32)
 
87
def ela_sweep_mean(pil_img, qualities=(95, 90, 85)):
    """Run ELA at several JPEG qualities.

    Returns (peak mean residual, average mean residual) across the sweep.
    """
    means = [error_level_analysis(pil_img, quality=q)[1] for q in qualities]
    return float(max(means)), float(np.mean(means))
92
 
93
def fft_high_freq_ratio(pil_img: Image.Image):
    """Share of log-spectrum energy outside a small low-frequency disc.

    Returns (None, ratio); the None slot keeps the (debug, score) tuple
    shape shared by the other detectors.
    """
    luma = pil_img.convert("YCbCr").split()[0]
    gray = np.array(luma, dtype=np.float32) / 255.0
    h, w = gray.shape
    # Hann window in both dimensions suppresses FFT edge artifacts.
    window = np.hanning(h)[:, None] * np.hanning(w)[None, :]
    mag = np.log1p(np.abs(np.fft.fftshift(np.fft.fft2(gray * window))))
    yy, xx = np.ogrid[:h, :w]
    dist = np.sqrt((yy - h // 2) ** 2 + (xx - w // 2) ** 2)
    inner = dist <= min(h, w) * 0.08  # low-frequency disc radius
    low = float(mag[inner].sum())
    high = float(mag[~inner].sum())
    return None, float(high / (high + low + 1e-9))
105
 
106
def noise_inconsistency(pil_img: Image.Image):
    """Score in [0, 1] of how unevenly Laplacian noise varies across 32px tiles.

    Returns (None, score); the None slot mirrors the other detectors' shape.
    """
    luma = pil_img.convert("YCbCr").split()[0]
    plane = np.array(luma, dtype=np.float32)
    residual = np.abs(cv2.Laplacian(plane, cv2.CV_32F, ksize=3))
    step = 32
    height, width = residual.shape
    variances = []
    for row in range(0, height, step):
        for col in range(0, width, step):
            # Slicing auto-clamps at the array edge; no min() needed.
            patch = residual[row:row + step, col:col + step]
            if patch.size:
                variances.append(patch.var())
    if not variances:
        return None, 0.0
    variances = np.array(variances, dtype=np.float32)
    # Coefficient of variation of tile variances, squashed to [0, 1).
    spread = float(variances.std() / (variances.mean() + 1e-9))
    return None, float(np.tanh(spread / 5.0))
119
 
120
def combine_scores(ela_mean, hf_ratio, noise_incons_score):
    """Fuse the three heuristic signals into a (label, confidence) pair.

    Each signal is squashed to [0, 1], combined with fixed weights
    (ELA 0.30, FFT 0.40, noise 0.30), and thresholded at 0.65.
    """
    weights = (0.30, 0.40, 0.30)
    signals = (
        np.clip(ela_mean * 3.0, 0, 1),             # ELA residual strength
        np.clip((hf_ratio - 0.65) / 0.25, 0, 1),   # excess high-frequency energy
        np.clip(noise_incons_score, 0, 1),          # tile-noise inconsistency
    )
    conf = float(sum(w * s for w, s in zip(weights, signals)))
    if conf >= 0.65:
        return "Likely Manipulated", conf
    return "Likely Authentic", conf
 
 
128
 
129
+ # ====== Result card ======
130
+ def _result_card(label: str, conf: float, note: str | None = None) -> str:
131
  pct = max(0.0, min(1.0, conf)) * 100.0
132
  color = "#d84a4a" if label.startswith("Likely Manipulated") else "#2e7d32"
133
  bar_bg = "#e9ecef"
134
+ extra = f"<div style='color:#6b7280;font-size:12px;margin-top:10px;text-align:center;'>{note}</div>" if note else ""
135
  return f"""
136
  <div style="max-width:860px;margin:0 auto;">
137
  <div style="border:1px solid #e5e7eb;border-radius:14px;padding:18px 20px;background:#fff;
 
144
  <div style="height:100%;width:{pct:.4f}%;background:{color};"></div>
145
  </div>
146
  </div>
147
+ {extra}
148
  </div>
149
  """
150
 
151
+ # ====== Inference ======
152
+ def analyze(pil_img: Image.Image):
153
  if pil_img is None:
154
  return _result_card("Likely Authentic", 0.0)
155
+ face = crop_face(pil_img).convert("RGB")
156
+
157
+ if _try_load_hf():
158
+ prob_fake = _hf_predict_proba(face)
159
+ label = "Likely Manipulated" if prob_fake >= 0.5 else "Likely Authentic"
160
+ note = f"HF model: {HF_MODEL_ID}"
161
+ return _result_card(label, prob_fake, note=note)
162
+
163
+ # Fallback heuristic (if HF model failed)
164
+ face = face.resize((512, 512))
165
+ _, ela_mean = error_level_analysis(face, quality=90)
166
+ _, hf_ratio = fft_high_freq_ratio(face)
167
+ _, noi_score = noise_inconsistency(face)
168
+ label, conf = combine_scores(ela_mean, hf_ratio, noi_score)
169
+ return _result_card(label, conf, note="Heuristic fallback")
170
+
171
# ====== UI ======
CUSTOM_CSS = """
.gradio-container {max-width: 980px !important;}
.sleek-card {
  border: 1px solid #e5e7eb; border-radius: 16px; background: #fff;
  box-shadow: 0 2px 10px rgba(16,24,40,.04); padding: 18px;
}
"""

# Two-column layout: card-styled uploader on the left, result card on the right.
with gr.Blocks(title="Deepfake Detector (Pretrained HF Model)", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "<h2 style='text-align:center;margin-bottom:6px;'>Deepfake Detector</h2>"
        "<p style='text-align:center;color:#6b7280;'>Face-crop → pretrained classifier → single likelihood.</p>"
    )
    with gr.Row():
        with gr.Column(scale=6, elem_classes=["sleek-card"]):
            inp = gr.Image(
                type="pil",
                label="Upload / Paste Image",
                sources=["upload", "webcam", "clipboard"],
                height=420,
                show_label=True,
                interactive=True,
            )
            btn = gr.Button("Analyze", variant="primary", size="lg")
        with gr.Column(scale=6):
            out = gr.HTML()
    # Run on explicit click and whenever a new image is provided.
    btn.click(analyze, inputs=inp, outputs=out)
    inp.change(analyze, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()