kimnamjoon0007 committed on
Commit
166d169
·
verified ·
1 Parent(s): 6517dff

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +83 -142
app.py CHANGED
@@ -19,7 +19,7 @@ TARGET_SR = 16000
19
  MAX_DURATION = 10.0
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
- # Model architecture (must match training)
23
  class W2VBertDeepfakeDetector(nn.Module):
24
  def __init__(self, backbone, num_labels=2):
25
  super().__init__()
@@ -42,7 +42,6 @@ print("Loading model...")
42
  backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")
43
  model = W2VBertDeepfakeDetector(backbone, num_labels=2)
44
 
45
- # Try to load from HF Hub
46
  try:
47
  from huggingface_hub import hf_hub_download
48
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.pt")
@@ -51,7 +50,6 @@ try:
51
  print(f"✓ Loaded model from {MODEL_REPO}")
52
  except Exception as e:
53
  print(f"Warning: Could not load from HF Hub: {e}")
54
- # Fallback to local file
55
  if os.path.exists("best_model.pt"):
56
  model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
57
  print("✓ Loaded model from local file")
@@ -63,48 +61,42 @@ print(f"Model ready on {DEVICE}")
63
 
64
def load_audio(audio_path):
    """Load an audio file and return a mono 16 kHz float32 torch tensor.

    Decodes with pydub, downmixes to mono, normalizes to [-1, 1],
    resamples to TARGET_SR, and truncates to MAX_DURATION seconds.

    Args:
        audio_path: Path to any audio file pydub/ffmpeg can decode.

    Returns:
        1-D ``torch.FloatTensor`` of at most ``MAX_DURATION * TARGET_SR`` samples.

    Raises:
        gr.Error: If the file cannot be decoded or processed.
    """
    try:
        audio_segment = AudioSegment.from_file(audio_path)
        samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32)

        # Downmix multi-channel audio to mono by averaging channels.
        if audio_segment.channels > 1:
            samples = samples.reshape(-1, audio_segment.channels).mean(axis=1)

        # Normalize by the full-scale value for the actual sample width.
        # The original divided by a hard-coded 32767.0, which is only correct
        # for 16-bit audio and mis-scales 8-/24-/32-bit input.
        full_scale = float(1 << (8 * audio_segment.sample_width - 1))
        samples /= full_scale
        sr = audio_segment.frame_rate

        if sr != TARGET_SR:
            samples = librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SR)

        # Truncate to max duration to bound inference cost.
        max_len = int(MAX_DURATION * TARGET_SR)
        if len(samples) > max_len:
            samples = samples[:max_len]

        return torch.from_numpy(samples).float()
    except Exception as e:
        # Surface decoding/processing failures to the Gradio UI, keeping
        # the original exception chained for server-side debugging.
        raise gr.Error(f"Error loading audio: {e}") from e
87
 
88
 
89
  def classify_audio(audio_input):
90
- """Main classification function for Gradio."""
91
  if audio_input is None:
92
- return "Please upload or record an audio file.", None
93
-
94
- # Handle both file upload and microphone input
95
- if isinstance(audio_input, tuple):
96
- # Microphone input: (sample_rate, numpy_array)
97
- sr, audio_data = audio_input
98
- temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
99
- import scipy.io.wavfile as wav
100
- wav.write(temp_file.name, sr, audio_data)
101
- audio_path = temp_file.name
102
- else:
103
- # File upload
104
- audio_path = audio_input
105
 
106
  try:
107
- # Load and preprocess
 
 
 
 
 
 
 
 
 
 
108
  waveform = load_audio(audio_path)
109
  input_values = waveform.unsqueeze(0).to(DEVICE)
110
 
@@ -115,117 +107,66 @@ def classify_audio(audio_input):
115
  pred_class = torch.argmax(probs, dim=-1).item()
116
  confidence = probs[0, pred_class].item()
117
 
118
- # Result
119
- label = "🤖 AI-GENERATED" if pred_class == 1 else "👤 HUMAN"
120
 
121
- # Create detailed result
122
- result_text = f"""
123
- ## Classification Result
124
-
125
- **Verdict:** {label}
 
 
 
 
126
 
127
- **Confidence:** {confidence:.1%}
128
 
129
  ---
130
 
131
- ### Probability Breakdown
132
- - Human: {probs[0, 0].item():.1%}
133
- - AI-Generated: {probs[0, 1].item():.1%}
 
 
 
 
134
  """
135
-
136
- # Create confidence bar data
137
- confidence_data = {
138
- "Human": float(probs[0, 0].item()),
139
- "AI-Generated": float(probs[0, 1].item())
140
- }
141
-
142
- return result_text, confidence_data
143
 
144
  except Exception as e:
145
- return f"Error: {str(e)}", None
146
 
147
  finally:
148
- # Cleanup temp file if created
149
- if isinstance(audio_input, tuple) and os.path.exists(audio_path):
150
- os.remove(audio_path)
151
-
152
-
153
# Gradio Interface — two-column layout: input controls on the left,
# verdict text and confidence scores on the right.
with gr.Blocks(
    title="AI Voice Detection",
    theme=gr.themes.Soft(primary_hue="blue"),
    css="""
    .gradio-container { max-width: 800px; margin: auto; }
    .result-box { font-size: 1.2em; }
    """
) as demo:

    # Page header and supported-language note.
    gr.Markdown("""
    # 🎤 AI Voice Detection

    Detect whether an audio clip is **AI-generated** or spoken by a **human**.

    ### Supported Languages
    Tamil • English • Hindi • Malayalam • Telugu

    ---
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Accepts either a file upload or a live microphone recording;
            # both arrive in classify_audio as a filepath.
            audio_input = gr.Audio(
                label="Upload or Record Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )

            analyze_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")

            gr.Markdown("""
            **Tips:**
            - Upload MP3, WAV, or other audio formats
            - Or use microphone to record directly
            - Audio will be analyzed up to 10 seconds
            """)

        with gr.Column(scale=1):
            result_output = gr.Markdown(
                label="Result",
                elem_classes=["result-box"]
            )

            confidence_chart = gr.Label(
                label="Confidence Scores",
                num_top_classes=2
            )

    # Both the button click and a new audio selection trigger the same
    # classification pipeline with identical wiring.
    handler = dict(
        fn=classify_audio,
        inputs=[audio_input],
        outputs=[result_output, confidence_chart]
    )
    analyze_btn.click(**handler)
    audio_input.change(**handler)

    # Footer: model card and headline metrics.
    gr.Markdown("""
    ---

    ### About

    This model uses **Wav2Vec2-large-xlsr-53** as the backbone, fine-tuned for AI voice detection.

    - **Accuracy:** 99.69%
    - **AUROC:** 1.0
    - **EER:** 0.25%

    [View Model on Hugging Face](https://huggingface.co/kimnamjoon0007/lkht-v440)
    """)

# Launch
if __name__ == "__main__":
    demo.launch()
 
19
  MAX_DURATION = 10.0
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+
23
  class W2VBertDeepfakeDetector(nn.Module):
24
  def __init__(self, backbone, num_labels=2):
25
  super().__init__()
 
42
  backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")
43
  model = W2VBertDeepfakeDetector(backbone, num_labels=2)
44
 
 
45
  try:
46
  from huggingface_hub import hf_hub_download
47
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.pt")
 
50
  print(f"✓ Loaded model from {MODEL_REPO}")
51
  except Exception as e:
52
  print(f"Warning: Could not load from HF Hub: {e}")
 
53
  if os.path.exists("best_model.pt"):
54
  model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
55
  print("✓ Loaded model from local file")
 
61
 
62
def load_audio(audio_path):
    """Load and preprocess an audio file into a mono 16 kHz float32 tensor.

    Decodes with pydub, downmixes to mono, normalizes to [-1, 1],
    resamples to TARGET_SR, and truncates to MAX_DURATION seconds.
    Exceptions propagate to the caller (classify_audio wraps this call).

    Args:
        audio_path: Path to any audio file pydub/ffmpeg can decode.

    Returns:
        1-D ``torch.FloatTensor`` of at most ``MAX_DURATION * TARGET_SR`` samples.
    """
    audio_segment = AudioSegment.from_file(audio_path)
    samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32)

    # Downmix multi-channel audio to mono by averaging channels.
    if audio_segment.channels > 1:
        samples = samples.reshape(-1, audio_segment.channels).mean(axis=1)

    # Normalize by the full-scale value for the actual sample width.
    # The original divided by a hard-coded 32767.0, which is only correct
    # for 16-bit audio and mis-scales 8-/24-/32-bit input.
    full_scale = float(1 << (8 * audio_segment.sample_width - 1))
    samples /= full_scale
    sr = audio_segment.frame_rate

    if sr != TARGET_SR:
        samples = librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SR)

    # Truncate to max duration to bound inference cost.
    max_len = int(MAX_DURATION * TARGET_SR)
    if len(samples) > max_len:
        samples = samples[:max_len]

    return torch.from_numpy(samples).float()
 
 
 
 
81
 
82
 
83
  def classify_audio(audio_input):
84
+ """Main classification function."""
85
  if audio_input is None:
86
+ return "⚠️ Please upload or record an audio file."
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  try:
89
+ # Handle tuple input from microphone (sample_rate, audio_array)
90
+ if isinstance(audio_input, tuple):
91
+ import scipy.io.wavfile as wav
92
+ sr, audio_data = audio_input
93
+ temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
94
+ wav.write(temp_file.name, sr, audio_data)
95
+ audio_path = temp_file.name
96
+ else:
97
+ audio_path = audio_input
98
+
99
+ # Load and process
100
  waveform = load_audio(audio_path)
101
  input_values = waveform.unsqueeze(0).to(DEVICE)
102
 
 
107
  pred_class = torch.argmax(probs, dim=-1).item()
108
  confidence = probs[0, pred_class].item()
109
 
110
+ human_prob = probs[0, 0].item() * 100
111
+ ai_prob = probs[0, 1].item() * 100
112
 
113
+ if pred_class == 1:
114
+ verdict = "🤖 AI-GENERATED"
115
+ color = "red"
116
+ else:
117
+ verdict = "👤 HUMAN"
118
+ color = "green"
119
+
120
+ result = f"""
121
+ ## Result: {verdict}
122
 
123
+ **Confidence: {confidence:.1%}**
124
 
125
  ---
126
 
127
+ | Category | Probability |
128
+ |----------|-------------|
129
+ | 👤 Human | {human_prob:.1f}% |
130
+ | 🤖 AI-Generated | {ai_prob:.1f}% |
131
+
132
+ ---
133
+ *Model: Wav2Vec2-large-xlsr-53 fine-tuned for voice detection*
134
  """
135
+ return result
 
 
 
 
 
 
 
136
 
137
  except Exception as e:
138
+ return f"Error processing audio: {str(e)}"
139
 
140
  finally:
141
+ if isinstance(audio_input, tuple) and 'audio_path' in locals():
142
+ try:
143
+ os.remove(audio_path)
144
+ except:
145
+ pass
146
+
147
+
148
# Simple Gradio Interface
# Components are built first, then handed to gr.Interface; behavior is
# identical to constructing them inline.
audio_component = gr.Audio(
    label="Upload or Record Audio",
    type="filepath",
    sources=["upload", "microphone"]
)
result_component = gr.Markdown(label="Result")

demo = gr.Interface(
    fn=classify_audio,
    inputs=audio_component,
    outputs=result_component,
    title="🎤 AI Voice Detection",
    description="""
    **Detect if audio is AI-generated or Human speech**

    Supported languages: Tamil, English, Hindi, Malayalam, Telugu

    Upload an audio file (MP3, WAV, etc.) or record directly using your microphone.
    """,
    examples=[],
    theme=gr.themes.Soft(),
    allow_flagging="never"
)

# Launch for HuggingFace Spaces
if __name__ == "__main__":
    # Bind to all interfaces on the Spaces default port.
    demo.launch(server_name="0.0.0.0", server_port=7860)