IvanLayer7 committed on
Commit
97c892c
·
verified ·
1 Parent(s): 7e8036f

Upload 5 files

Browse files
Files changed (4) hide show
  1. .gitattributes +34 -35
  2. app.py +57 -17
  3. requirements.txt +1 -3
  4. whisper_classifier.py +15 -50
.gitattributes CHANGED
@@ -1,35 +1,34 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
 
app.py CHANGED
@@ -12,7 +12,7 @@ import warnings
12
 
13
  # Import our custom modules
14
  from audio_processor import AudioProcessor
15
- from whisper_classifier import HybridKeywordSpotter
16
 
17
  warnings.filterwarnings("ignore")
18
 
@@ -20,16 +20,27 @@ warnings.filterwarnings("ignore")
20
  class KeywordSpottingApp:
21
  """Main application class for the keyword spotting interface."""
22
 
23
- def __init__(self):
24
  """Initialize the application components."""
25
  print("Initializing Keyword Spotting App for Hugging Face...")
26
 
27
  # Initialize components
28
  self.audio_processor = AudioProcessor(target_sample_rate=48000, max_duration=30.0)
29
- self.classifier = HybridKeywordSpotter()
30
 
31
  print("App initialized successfully!")
32
 
 
 
 
 
 
 
 
 
 
 
 
33
  def process_audio_and_classify(
34
  self,
35
  audio_input: Optional[Tuple[int, np.ndarray]],
@@ -144,14 +155,21 @@ class KeywordSpottingApp:
144
  def create_gradio_interface():
145
  """Create and configure the Gradio interface for Hugging Face."""
146
 
147
- # Initialize the app
148
- app = KeywordSpottingApp()
149
 
150
- def classify_audio(audio_input, audio_file, keywords):
151
  """Wrapper function for Gradio interface."""
 
 
 
152
  results, status = app.process_audio_and_classify(audio_input, audio_file, keywords)
153
  formatted_results = app.format_results_for_display(results)
154
- return formatted_results, status
 
 
 
 
155
 
156
  # Create the interface
157
  with gr.Blocks(
@@ -173,13 +191,14 @@ def create_gradio_interface():
173
  gr.Markdown("""
174
  # 🎯 Zero-Shot Audio Keyword Spotting
175
 
176
- Detect keywords in Spanish audio using AI **without prior training**.
177
- Uses Whisper + CLAP models for accurate keyword detection.
178
 
179
  ## 📋 Instructions:
180
- 1. **Enter keywords** you want to detect (comma-separated)
181
- 2. **Record audio** using microphone OR **upload audio file**
182
- 3. **Click "Analyze Audio"** to get probability results
 
183
 
184
  ### 💡 Example Keywords:
185
  `hola, gracias, adiós, sí, no, por favor`
@@ -187,6 +206,14 @@ def create_gradio_interface():
187
 
188
  with gr.Row():
189
  with gr.Column(scale=1):
 
 
 
 
 
 
 
 
190
  gr.Markdown("### 🔤 Keywords")
191
  gr.Markdown("*Example: hola, gracias, adiós*")
192
  keywords_input = gr.Textbox(
@@ -233,12 +260,19 @@ def create_gradio_interface():
233
  interactive=False,
234
  elem_classes=["status-box"]
235
  )
 
 
 
 
 
 
 
236
 
237
  # Event handlers
238
  analyze_btn.click(
239
  fn=classify_audio,
240
- inputs=[audio_input, audio_file, keywords_input],
241
- outputs=[results_output, status_output]
242
  )
243
 
244
  # Examples section
@@ -259,10 +293,16 @@ def create_gradio_interface():
259
  - Works best with common Spanish words
260
 
261
  ## 🔧 Technical Details:
262
- - **Models**: Whisper (transcription) + CLAP (audio-text similarity)
263
- - **Languages**: Optimized for Spanish, works with others
264
  - **Processing**: Up to 30 seconds, 48kHz sampling rate
265
- - **Approach**: Hybrid zero-shot classification
 
 
 
 
 
 
266
  """)
267
 
268
  return interface
 
12
 
13
  # Import our custom modules
14
  from audio_processor import AudioProcessor
15
+ from whisper_classifier import WhisperKeywordSpotter
16
 
17
  warnings.filterwarnings("ignore")
18
 
 
20
  class KeywordSpottingApp:
21
  """Main application class for the keyword spotting interface."""
22
 
23
+ def __init__(self, model_size: str = "base"):
24
  """Initialize the application components."""
25
  print("Initializing Keyword Spotting App for Hugging Face...")
26
 
27
  # Initialize components
28
  self.audio_processor = AudioProcessor(target_sample_rate=48000, max_duration=30.0)
29
+ self.classifier = WhisperKeywordSpotter(model_size=model_size)
30
 
31
  print("App initialized successfully!")
32
 
33
def change_model(self, new_model_size: str) -> str:
    """Ask the classifier to switch Whisper model size and report the outcome.

    Args:
        new_model_size: Name of the Whisper model size to switch to.

    Returns:
        A human-readable status message: a success message when the
        classifier reports a truthy result, a failure message when it
        reports a falsy one, or an error message carrying the exception
        text when the classifier raises.
    """
    try:
        switched = self.classifier.change_model(new_model_size)
    except Exception as exc:
        # Surface the failure in the UI instead of letting it propagate.
        return f"❌ Error changing model: {exc}"

    if not switched:
        return f"❌ Failed to change to {new_model_size} model"
    return f"✅ Successfully changed to {new_model_size} model"
43
+
44
  def process_audio_and_classify(
45
  self,
46
  audio_input: Optional[Tuple[int, np.ndarray]],
 
155
  def create_gradio_interface():
156
  """Create and configure the Gradio interface for Hugging Face."""
157
 
158
+ # Initialize the app with default model
159
+ app = KeywordSpottingApp(model_size="base")
160
 
161
+ def classify_audio(audio_input, audio_file, keywords, model_size):
162
  """Wrapper function for Gradio interface."""
163
+ # Change model if needed
164
+ model_change_msg = app.change_model(model_size)
165
+
166
  results, status = app.process_audio_and_classify(audio_input, audio_file, keywords)
167
  formatted_results = app.format_results_for_display(results)
168
+
169
+ # Add model info to status
170
+ status_with_model = f"{status} | Model: {model_size}"
171
+
172
+ return formatted_results, status_with_model, model_change_msg
173
 
174
  # Create the interface
175
  with gr.Blocks(
 
191
  gr.Markdown("""
192
  # 🎯 Zero-Shot Audio Keyword Spotting
193
 
194
+ Detect keywords in Spanish audio using **Whisper AI** without prior training.
195
+ Transcribes audio and matches keywords with high accuracy.
196
 
197
  ## 📋 Instructions:
198
+ 1. **Select Whisper model** (tiny=fastest, medium=most accurate)
199
+ 2. **Enter keywords** you want to detect (comma-separated)
200
+ 3. **Record audio** using microphone OR **upload audio file**
201
+ 4. **Click "Analyze Audio"** to get results
202
 
203
  ### 💡 Example Keywords:
204
  `hola, gracias, adiós, sí, no, por favor`
 
206
 
207
  with gr.Row():
208
  with gr.Column(scale=1):
209
+ gr.Markdown("### 🤖 Model Selection")
210
+ model_selector = gr.Dropdown(
211
+ choices=["tiny", "base", "small", "medium"],
212
+ value="base",
213
+ label="Whisper Model",
214
+ info="tiny=fastest, base=balanced, small=better accuracy, medium=best accuracy"
215
+ )
216
+
217
  gr.Markdown("### 🔤 Keywords")
218
  gr.Markdown("*Example: hola, gracias, adiós*")
219
  keywords_input = gr.Textbox(
 
260
  interactive=False,
261
  elem_classes=["status-box"]
262
  )
263
+
264
+ model_status_output = gr.Textbox(
265
+ label="Model Status",
266
+ value="Current model: base",
267
+ interactive=False,
268
+ elem_classes=["status-box"]
269
+ )
270
 
271
  # Event handlers
272
  analyze_btn.click(
273
  fn=classify_audio,
274
+ inputs=[audio_input, audio_file, keywords_input, model_selector],
275
+ outputs=[results_output, status_output, model_status_output]
276
  )
277
 
278
  # Examples section
 
293
  - Works best with common Spanish words
294
 
295
  ## 🔧 Technical Details:
296
+ - **Model**: OpenAI Whisper (speech transcription)
297
+ - **Languages**: Optimized for Spanish, works with others
298
  - **Processing**: Up to 30 seconds, 48kHz sampling rate
299
+ - **Approach**: Transcription + text matching
300
+
301
+ ## 🤖 Model Comparison:
302
+ - **tiny**: Fastest, basic accuracy (72MB)
303
+ - **base**: Balanced speed/accuracy (139MB)
304
+ - **small**: Better accuracy, slower (461MB)
305
+ - **medium**: Best accuracy, slowest (1.46GB)
306
  """)
307
 
308
  return interface
requirements.txt CHANGED
@@ -1,9 +1,7 @@
1
- # Optimized requirements for Hugging Face Spaces
2
  gradio==4.44.0
3
  torch>=2.0.0
4
- transformers>=4.30.0
5
  librosa>=0.10.0
6
  numpy>=1.21.0
7
  soundfile>=0.12.0
8
  openai-whisper>=20231117
9
- scipy>=1.7.0
 
1
+ # Optimized requirements for Hugging Face Spaces - Whisper only
2
  gradio==4.44.0
3
  torch>=2.0.0
 
4
  librosa>=0.10.0
5
  numpy>=1.21.0
6
  soundfile>=0.12.0
7
  openai-whisper>=20231117
 
whisper_classifier.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
- Alternative keyword spotter using Whisper for transcription + text matching.
3
- This approach transcribes the audio first, then matches keywords in the text.
4
  """
5
 
6
  import torch
@@ -61,7 +61,7 @@ class WhisperKeywordSpotter:
61
  Transcribe audio using Whisper.
62
 
63
  Args:
64
- audio_tensor: Audio tensor (should be 16kHz for Whisper)
65
 
66
  Returns:
67
  Transcribed text
@@ -174,57 +174,22 @@ class WhisperKeywordSpotter:
174
  error_msg = f"Classification error: {str(e)}"
175
  print(error_msg)
176
  return {"error": error_msg}
177
-
178
-
179
- class HybridKeywordSpotter:
180
- """Hybrid approach combining multiple methods."""
181
-
182
- def __init__(self):
183
- """Initialize hybrid classifier."""
184
- self.whisper_spotter = None
185
- self.clap_spotter = None
186
-
187
- # Try to initialize Whisper
188
- try:
189
- if WHISPER_AVAILABLE:
190
- self.whisper_spotter = WhisperKeywordSpotter("base")
191
- except Exception as e:
192
- print(f"⚠️ Could not initialize Whisper: {e}")
193
-
194
- # Try to initialize CLAP as fallback
195
- try:
196
- from improved_classifier import ImprovedZeroShotKeywordSpotter
197
- self.clap_spotter = ImprovedZeroShotKeywordSpotter()
198
- except Exception as e:
199
- print(f"⚠️ Could not initialize CLAP: {e}")
200
 
201
- def classify_keywords(self, audio_tensor: torch.Tensor, keywords: str) -> Dict[str, float]:
202
  """
203
- Classify using the best available method.
204
 
205
  Args:
206
- audio_tensor: Preprocessed audio tensor
207
- keywords: Comma-separated keywords string
208
-
209
- Returns:
210
- Dictionary mapping keywords to probability scores
211
  """
212
- # Try Whisper first (usually more accurate for speech)
213
- if self.whisper_spotter:
 
214
  try:
215
- results = self.whisper_spotter.classify_keywords(audio_tensor, keywords)
216
- if "error" not in results:
217
- return results
218
  except Exception as e:
219
- print(f"Whisper failed: {e}")
220
-
221
- # Fallback to CLAP
222
- if self.clap_spotter:
223
- try:
224
- return self.clap_spotter.classify_keywords_simple(audio_tensor, keywords)
225
- except Exception as e:
226
- print(f"CLAP failed: {e}")
227
-
228
- # If all else fails
229
- keyword_list = keywords.split(",")
230
- return {kw.strip(): 0.0 for kw in keyword_list if kw.strip()}
 
1
  """
2
+ Whisper-only keyword spotter for zero-shot audio keyword detection.
3
+ Uses Whisper transcription + text matching without CLAP dependencies.
4
  """
5
 
6
  import torch
 
61
  Transcribe audio using Whisper.
62
 
63
  Args:
64
+ audio_tensor: Audio tensor (will be resampled for Whisper)
65
 
66
  Returns:
67
  Transcribed text
 
174
  error_msg = f"Classification error: {str(e)}"
175
  print(error_msg)
176
  return {"error": error_msg}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
def change_model(self, new_model_size: str):
    """
    Change the Whisper model size.

    Args:
        new_model_size: New model size to load

    Returns:
        True if the requested size is active after the call (it was
        already the current size, or it loaded successfully); False if
        loading raised. On failure ``self.model`` and ``self.model_size``
        are left untouched.
    """
    # Nothing to do when the requested size is already the current one.
    if new_model_size == self.model_size:
        return True

    print(f"Changing model from {self.model_size} to {new_model_size}")
    try:
        new_model = whisper.load_model(new_model_size, device=self.device)
    except Exception as e:
        print(f"Error loading {new_model_size} model: {e}")
        return False

    # Commit the switch only after the load succeeded. Updating model_size
    # before loading would leave it out of sync with self.model on failure,
    # and a retry with the same size would then be skipped as "already loaded".
    self.model = new_model
    self.model_size = new_model_size
    print(f"Successfully loaded {new_model_size} model!")
    return True