bhoumik12 commited on
Commit
a7c6634
·
verified ·
1 Parent(s): ed5d264

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +153 -0
  2. audio_inference.py +57 -0
  3. requirements.txt +0 -0
  4. web_backend.py +101 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ # ---- IMPORT BACKENDS ----
4
+ from web_backend import predict_image_pil
5
+ from audio_inference import predict_audio
6
+
7
+
8
+ # =========================
9
+ # IMAGE LOGIC (UNCHANGED)
10
+ # =========================
11
def analyze_image(image):
    """Run the image deepfake detector and map its confidence to a risk message.

    Returns (label, "<confidence> %", risk_text, heatmap) for the Gradio outputs.
    """
    label, confidence, heatmap = predict_image_pil(image)

    # Risk wording depends on both the predicted class and the model's confidence.
    if label == "Fake":
        tiers = [
            (90, "🚨 High likelihood of Deepfake"),
            (60, "⚠️ Possibly Deepfake"),
        ]
        fallback = "⚠️ Uncertain Deepfake"
    else:
        tiers = [
            (90, "✅ Likely Real"),
            (60, "⚠️ Possibly Real"),
        ]
        fallback = "⚠️ Uncertain – Needs Review"

    # First threshold the confidence clears wins; otherwise the low-confidence text.
    risk = next((text for cutoff, text in tiers if confidence >= cutoff), fallback)

    return label, f"{confidence} %", risk, heatmap
30
+
31
+
32
+ # =========================
33
+ # AUDIO LOGIC (UNCHANGED)
34
+ # =========================
35
def analyze_audio(audio_path):
    """Run the audio deepfake detector and map its confidence to a risk message.

    Returns (capitalized label, "<confidence> %", risk_text) for the Gradio outputs.
    """
    label, confidence = predict_audio(audio_path)
    is_fake = label == "fake"

    # Pick the message by confidence band; wording differs per predicted class,
    # except below 60 % where both classes fall back to the same "uncertain" text.
    if confidence >= 90:
        risk = "🚨 High likelihood of Deepfake" if is_fake else "✅ Likely Real"
    elif confidence >= 60:
        risk = "⚠️ Possibly Deepfake" if is_fake else "⚠️ Possibly Real"
    else:
        risk = "⚠️ Uncertain – Needs Review"

    return label.capitalize(), f"{confidence} %", risk
54
+
55
+
56
+ # =========================
57
+ # UI
58
+ # =========================
59
# =========================
# UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Unified Deepfake Detection System")

    with gr.Tabs():

        # =====================
        # HOME TAB
        # =====================
        with gr.Tab("🏠 Home"):
            gr.Markdown(
                """
                ## Welcome 👋
                Select the type of media you want to analyze:
                """
            )

            gr.Markdown("### 🔍 Choose Detection Mode")
            gr.Markdown("- 🖼 **Image Deepfake Detection**\n- 🎧 **Audio Deepfake Detection**")

            gr.Markdown(
                """
                👉 Use the tabs above to switch between Image and Audio detection.
                """
            )

        # =====================
        # IMAGE TAB
        # =====================
        with gr.Tab("🖼 Image Deepfake"):
            gr.Markdown("# 🖼 Deepfake Image Detection System")

            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=280
                    )
                    img_submit = gr.Button("Submit")
                    img_clear = gr.Button("Clear")

                with gr.Column(scale=2):
                    img_pred = gr.Text(label="Prediction")
                    img_conf = gr.Text(label="Confidence")
                    img_risk = gr.Text(label="Risk Assessment")
                    img_heatmap = gr.Image(
                        label="Explainability Heatmap",
                        height=280
                    )

            img_submit.click(
                fn=analyze_image,
                inputs=image_input,
                outputs=[img_pred, img_conf, img_risk, img_heatmap]
            )

            # Bug fix: the heatmap was previously left on screen after Clear;
            # reset every output component, including img_heatmap.
            img_clear.click(
                fn=lambda: (None, "", "", "", None),
                inputs=None,
                outputs=[image_input, img_pred, img_conf, img_risk, img_heatmap]
            )

        # =====================
        # AUDIO TAB
        # =====================
        with gr.Tab("🎧 Audio Deepfake"):
            gr.Markdown("# 🎧 Deepfake Audio Detection System")

            with gr.Row():
                with gr.Column(scale=1):
                    audio_input = gr.Audio(
                        label="Upload Audio (.wav)",
                        type="filepath"
                    )
                    aud_submit = gr.Button("Submit")
                    aud_clear = gr.Button("Clear")

                with gr.Column(scale=2):
                    aud_pred = gr.Text(label="Prediction")
                    aud_conf = gr.Text(label="Confidence")
                    aud_risk = gr.Text(label="Risk Assessment")

            aud_submit.click(
                fn=analyze_audio,
                inputs=audio_input,
                outputs=[aud_pred, aud_conf, aud_risk]
            )

            # Bug fix: the risk text was previously left on screen after Clear;
            # reset every output component, including aud_risk.
            aud_clear.click(
                fn=lambda: (None, "", "", ""),
                inputs=None,
                outputs=[audio_input, aud_pred, aud_conf, aud_risk]
            )

demo.launch()
audio_inference.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import librosa
3
+ import numpy as np
4
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
5
+
6
# =====================
# CONFIG
# =====================
MODEL_DIR = "exported_audio_model"  # local directory holding the fine-tuned checkpoint
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = 16000  # target sample rate passed to both librosa and the processor
MAX_SAMPLES = 8 * SR  # 8 seconds — longer clips are truncated in predict_audio()

# =====================
# LOAD MODEL + PROCESSOR (ONCE)
# =====================
# Loaded at import time so every predict_audio() call reuses the same weights
# instead of reloading the checkpoint per request.
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR)
model.to(DEVICE)
model.eval()  # inference mode (disables dropout etc.)
21
+
22
+ # =====================
23
+ # PREDICT FUNCTION
24
+ # =====================
25
def predict_audio(wav_path):
    """Classify an audio file as real/fake.

    Returns (label, confidence) where label comes from the model's id2label
    mapping and confidence is a percentage rounded to two decimals.
    """
    # Decode to mono at the configured sample rate and keep at most 8 s of signal
    # (slicing is a no-op for shorter clips).
    waveform, _ = librosa.load(wav_path, sr=SR, mono=True)
    waveform = waveform[:MAX_SAMPLES]

    # The processor takes care of padding and builds the attention mask.
    batch = processor(
        waveform,
        sampling_rate=SR,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    )

    with torch.no_grad():
        logits = model(
            input_values=batch.input_values.to(DEVICE),
            attention_mask=batch.attention_mask.to(DEVICE),
        ).logits

    probabilities = torch.softmax(logits, dim=1)[0]
    best = torch.argmax(probabilities).item()

    label = model.config.id2label[best]
    confidence = probabilities[best].item() * 100
    return label, round(confidence, 2)
requirements.txt ADDED
File without changes
web_backend.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from transformers import ViTForImageClassification, ViTConfig
4
+ from PIL import Image
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import io
8
+ import os
9
+
10
+ # -----------------------------
11
+ # Device
12
+ # -----------------------------
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+ # -----------------------------
16
+ # Model Setup (SAME AS CMD)
17
+ # -----------------------------
18
+ config = ViTConfig.from_pretrained(
19
+ "google/vit-base-patch16-224",
20
+ num_labels=2,
21
+ output_attentions=True
22
+ )
23
+
24
+ model = ViTForImageClassification.from_pretrained(
25
+ "google/vit-base-patch16-224",
26
+ config=config,
27
+ ignore_mismatched_sizes=True
28
+ )
29
+
30
+ if os.path.exists("model/vit_real_fake_best.pth"):
31
+ model.load_state_dict(
32
+ torch.load("model/vit_real_fake_best.pth", map_location=device)
33
+ )
34
+
35
+ model.to(device)
36
+ model.eval()
37
+
38
+ # -----------------------------
39
+ # Image Preprocessing (IDENTICAL)
40
+ # -----------------------------
41
+ transform = transforms.Compose([
42
+ transforms.Resize((224, 224)),
43
+ transforms.ToTensor(),
44
+ transforms.Normalize(
45
+ [0.485, 0.456, 0.406],
46
+ [0.229, 0.224, 0.225]
47
+ )
48
+ ])
49
+
50
+ # -----------------------------
51
+ # Attention Heatmap (IDENTICAL)
52
+ # -----------------------------
53
def get_attention_map(model, img_tensor):
    """Return a (grid x grid) CLS-attention map from the model's last layer.

    The map is min-max normalized to [0, 1].  Fix: the original divided by
    (max - min), which yields NaNs when the attention map is constant; a
    constant map now normalizes to all zeros instead.
    """
    with torch.no_grad():
        outputs = model(img_tensor, output_attentions=True)
        # Average the heads of the last attention layer; take batch item 0.
        attn = outputs.attentions[-1].mean(dim=1)[0]
        # Attention from the CLS token to every patch token (skip CLS itself).
        cls_attn = attn[0, 1:]

    grid_size = int(cls_attn.size(0) ** 0.5)
    cls_attn = cls_attn.reshape(grid_size, grid_size).cpu().numpy()

    # Min-max normalize, guarding the degenerate constant-map case.
    value_range = cls_attn.max() - cls_attn.min()
    if value_range > 0:
        cls_attn = (cls_attn - cls_attn.min()) / value_range
    else:
        cls_attn = np.zeros_like(cls_attn)

    return cls_attn
64
+
65
def overlay_heatmap_on_image(image, heatmap):
    """Blend a normalized heatmap over *image* and return the result as a PIL image."""
    # Scale the [0, 1] map to 8-bit and stretch it to the image dimensions.
    scaled = Image.fromarray(np.uint8(255 * heatmap)).resize(image.size)

    # Render the photo with a semi-transparent jet-colored overlay.
    figure, axes = plt.subplots(figsize=(4, 4))
    axes.imshow(image)
    axes.imshow(np.array(scaled), cmap="jet", alpha=0.5)
    axes.axis("off")

    # Round-trip through an in-memory PNG so a plain PIL image is returned.
    buffer = io.BytesIO()
    plt.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(figure)
    buffer.seek(0)

    return Image.open(buffer)
81
+
82
+ # -----------------------------
83
+ # Prediction Function (SOURCE OF TRUTH)
84
+ # -----------------------------
85
def predict_image_pil(image):
    """Classify a PIL image as Real/Fake.

    Returns (label, confidence %, heatmap image) where confidence is rounded
    to two decimals and the heatmap visualizes the model's attention.
    """
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch).logits
        pred_idx = torch.argmax(logits, dim=1).item()

    # Class index 1 maps to "Real", index 0 to "Fake" in this model's head.
    label = "Real" if pred_idx == 1 else "Fake"

    heatmap_img = overlay_heatmap_on_image(rgb, get_attention_map(model, batch))
    confidence = torch.softmax(logits, dim=1)[0][pred_idx].item() * 100

    return label, round(confidence, 2), heatmap_img