kimnamjoon0007 commited on
Commit
3c18ee6
·
verified ·
1 Parent(s): dcd9e51

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +231 -0
app.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AI Voice Detection - Hugging Face Spaces Demo
3
+ Detects AI-generated vs Human voices in multilingual audio
4
+ """
5
+
6
+ import os
7
+ import tempfile
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import gradio as gr
12
+ from transformers import Wav2Vec2Model
13
+ from pydub import AudioSegment
14
+ import librosa
15
+
# Configuration
MODEL_REPO = "kimnamjoon0007/lkht-v440"  # HF Hub repo that hosts the fine-tuned checkpoint (best_model.pt)
TARGET_SR = 16000  # sample rate (Hz) fed to the wav2vec2 backbone; inputs are resampled to this
MAX_DURATION = 10.0  # seconds of audio analyzed; longer clips are truncated in load_audio()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when available
21
+
# Model architecture (must match training)
class W2VBertDeepfakeDetector(nn.Module):
    """Binary voice-deepfake classifier: speech backbone + linear head.

    The backbone's last hidden states are mean-pooled over the time axis,
    passed through dropout, and projected to ``num_labels`` logits.
    """

    def __init__(self, backbone, num_labels=2):
        super().__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(0.1)
        # Head width is taken from the backbone's own config.
        self.classifier = nn.Linear(backbone.config.hidden_size, num_labels)

    def forward(self, input_values, attention_mask=None):
        encoded = self.backbone(input_values=input_values, attention_mask=attention_mask)
        # Mean-pool the (batch, time, hidden) states over time, then classify.
        pooled = self.dropout(encoded.last_hidden_state.mean(dim=1))
        return self.classifier(pooled)
38
+
39
+
# Load model: build the architecture, then try to fetch trained weights.
print("Loading model...")
# Pretrained multilingual backbone; the classifier head starts randomly
# initialized and is overwritten by the checkpoint below.
backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")
model = W2VBertDeepfakeDetector(backbone, num_labels=2)

# Try to load from HF Hub
try:
    from huggingface_hub import hf_hub_download
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.pt")
    # NOTE(review): torch.load unpickles arbitrary objects; consider
    # weights_only=True for untrusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    print(f"✓ Loaded model from {MODEL_REPO}")
except Exception as e:
    print(f"Warning: Could not load from HF Hub: {e}")
    # Fallback to local file
    if os.path.exists("best_model.pt"):
        model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
        print("✓ Loaded model from local file")
    # NOTE(review): if both loads fail, the app continues with a randomly
    # initialized classifier head — predictions would be meaningless.

model.to(DEVICE)
model.eval()  # inference mode: disables dropout
print(f"Model ready on {DEVICE}")
62
+
63
+
def load_audio(audio_path):
    """Load an audio file and return a mono float32 waveform tensor.

    Decodes via pydub (mp3/wav/ogg/...), downmixes to mono, normalizes to
    [-1, 1] based on the source bit depth, resamples to TARGET_SR, and
    truncates to MAX_DURATION seconds.

    Args:
        audio_path: path to any audio file pydub/ffmpeg can decode.

    Returns:
        1-D ``torch.FloatTensor`` of at most ``MAX_DURATION * TARGET_SR`` samples.

    Raises:
        gr.Error: if the file cannot be decoded or processed.
    """
    try:
        audio_segment = AudioSegment.from_file(audio_path)
        samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32)

        # Downmix interleaved multi-channel audio to a single channel.
        if audio_segment.channels > 1:
            samples = samples.reshape(-1, audio_segment.channels).mean(axis=1)

        # Normalize by the actual bit depth instead of assuming 16-bit:
        # 8/24/32-bit sources would otherwise be scaled wrongly (e.g. 32-bit
        # samples divided by 32767 stay far outside [-1, 1]).
        full_scale = float(1 << (8 * audio_segment.sample_width - 1))
        samples /= full_scale
        sr = audio_segment.frame_rate

        if sr != TARGET_SR:
            samples = librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SR)

        # Truncate to max duration
        max_len = int(MAX_DURATION * TARGET_SR)
        if len(samples) > max_len:
            samples = samples[:max_len]

        return torch.from_numpy(samples).float()
    except Exception as e:
        raise gr.Error(f"Error loading audio: {e}")
87
+
88
+
def classify_audio(audio_input):
    """Main classification function for Gradio.

    Accepts either a filepath string (upload) or a ``(sample_rate, ndarray)``
    tuple (raw microphone data), runs the detector, and returns a markdown
    summary plus a label->probability dict for the confidence chart.

    Returns:
        (markdown_text, confidence_dict) — confidence_dict is None on error
        or when no audio was provided.
    """
    if audio_input is None:
        return "Please upload or record an audio file.", None

    temp_path = None
    # Handle both file upload and microphone input
    if isinstance(audio_input, tuple):
        # Microphone input: (sample_rate, numpy_array)
        sr, audio_data = audio_input
        import scipy.io.wavfile as wav
        # mkstemp + close avoids leaking the open handle that
        # NamedTemporaryFile(delete=False) would keep around (and the
        # sharing violation writing to a still-open file on Windows).
        fd, temp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        wav.write(temp_path, sr, audio_data)
        audio_path = temp_path
    else:
        # File upload
        audio_path = audio_input

    try:
        # Load and preprocess
        waveform = load_audio(audio_path)
        input_values = waveform.unsqueeze(0).to(DEVICE)  # add batch dim

        # Inference
        with torch.no_grad():
            logits = model(input_values)
            probs = torch.softmax(logits, dim=-1)
            pred_class = torch.argmax(probs, dim=-1).item()
            confidence = probs[0, pred_class].item()

        # Class 1 = AI-generated, class 0 = human (must match training labels)
        label = "🤖 AI-GENERATED" if pred_class == 1 else "👤 HUMAN"

        # Create detailed result
        result_text = f"""
## Classification Result

**Verdict:** {label}

**Confidence:** {confidence:.1%}

---

### Probability Breakdown
- Human: {probs[0, 0].item():.1%}
- AI-Generated: {probs[0, 1].item():.1%}
"""

        # Create confidence bar data
        confidence_data = {
            "Human": float(probs[0, 0].item()),
            "AI-Generated": float(probs[0, 1].item())
        }

        return result_text, confidence_data

    except Exception as e:
        # Surface the error in the UI rather than crashing the handler.
        return f"Error: {str(e)}", None

    finally:
        # Cleanup temp file if created
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
151
+
152
+
# Gradio Interface: two-column layout — input/controls on the left,
# results on the right.
with gr.Blocks(
    title="AI Voice Detection",
    theme=gr.themes.Soft(primary_hue="blue"),
    css="""
.gradio-container { max-width: 800px; margin: auto; }
.result-box { font-size: 1.2em; }
"""
) as demo:

    gr.Markdown("""
# 🎤 AI Voice Detection

Detect whether an audio clip is **AI-generated** or spoken by a **human**.

### Supported Languages
Tamil • English • Hindi • Malayalam • Telugu

---
""")

    with gr.Row():
        with gr.Column(scale=1):
            # type="filepath" makes Gradio hand classify_audio a path even
            # for microphone recordings.
            audio_input = gr.Audio(
                label="Upload or Record Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )

            submit_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")

            gr.Markdown("""
**Tips:**
- Upload MP3, WAV, or other audio formats
- Or use microphone to record directly
- Audio will be analyzed up to 10 seconds
""")

        with gr.Column(scale=1):
            # Markdown verdict produced by classify_audio.
            result_output = gr.Markdown(
                label="Result",
                elem_classes=["result-box"]
            )

            # Bar chart fed by the {label: probability} dict.
            confidence_chart = gr.Label(
                label="Confidence Scores",
                num_top_classes=2
            )

    # Event handlers
    submit_btn.click(
        fn=classify_audio,
        inputs=[audio_input],
        outputs=[result_output, confidence_chart]
    )

    # Also run automatically whenever the audio input changes, so users
    # get a result without pressing the button.
    audio_input.change(
        fn=classify_audio,
        inputs=[audio_input],
        outputs=[result_output, confidence_chart]
    )

    gr.Markdown("""
---

### About

This model uses **Wav2Vec2-large-xlsr-53** as the backbone, fine-tuned for AI voice detection.

- **Accuracy:** 99.69%
- **AUROC:** 1.0
- **EER:** 0.25%

[View Model on Hugging Face](https://huggingface.co/kimnamjoon0007/lkht-v440)
""")

# Launch
if __name__ == "__main__":
    demo.launch()