IvanLayer7 committed on
Commit
e685c03
·
verified ·
1 Parent(s): 4414822

Upload 5 files

Browse files
Files changed (5) hide show
  1. app_hf.py +284 -0
  2. audio_processor.py +137 -0
  3. config.py +98 -0
  4. requirements_hf.txt +9 -0
  5. whisper_classifier.py +230 -0
app_hf.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Spaces version of the Keyword Spotting App.
3
+ Simplified for deployment without local authentication.
4
+ """
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ import os
10
+ from typing import Dict, Any, Tuple, Optional
11
+ import warnings
12
+
13
+ # Import our custom modules
14
+ from audio_processor import AudioProcessor
15
+ from whisper_classifier import HybridKeywordSpotter
16
+
17
+ warnings.filterwarnings("ignore")
18
+
19
+
20
class KeywordSpottingApp:
    """Main application class for the keyword spotting interface."""

    def __init__(self):
        """Build the pipeline: audio preprocessing front-end plus hybrid classifier."""
        print("Initializing Keyword Spotting App for Hugging Face...")

        self.audio_processor = AudioProcessor(target_sample_rate=48000, max_duration=30.0)
        self.classifier = HybridKeywordSpotter()

        print("App initialized successfully!")

    @staticmethod
    def _to_float32(samples: np.ndarray) -> np.ndarray:
        """Convert integer PCM samples (as delivered by Gradio) to float32 in [-1, 1)."""
        if samples.dtype == np.int16:
            return samples.astype(np.float32) / 32768.0
        if samples.dtype == np.int32:
            return samples.astype(np.float32) / 2147483648.0
        return samples

    def process_audio_and_classify(
        self,
        audio_input: Optional[Tuple[int, np.ndarray]],
        audio_file: Optional[str],
        keywords: str
    ) -> Tuple[Dict[str, float], str]:
        """
        Pick an audio source, preprocess it, and classify the keywords.

        Args:
            audio_input: (sample_rate, samples) pair from the microphone widget.
            audio_file: Filesystem path of an uploaded audio file.
            keywords: Comma-separated keywords string.

        Returns:
            Tuple of (classification results dict, status message).
        """
        try:
            # Nothing to classify without at least one keyword.
            if not keywords or not keywords.strip():
                return {}, "❌ Por favor, ingrese al menos una palabra clave."

            # An uploaded file takes precedence over microphone input.
            if audio_file is not None:
                try:
                    tensor = self.audio_processor.process_audio_file(audio_file)
                    origin = f"📁 Archivo: {os.path.basename(audio_file)}"
                except Exception as exc:
                    return {}, f"❌ Error procesando archivo: {exc}"
            elif audio_input is not None:
                try:
                    rate, samples = audio_input
                    samples = self._to_float32(samples)
                    tensor = self.audio_processor.process_audio_array(samples, rate)
                    origin = "🎤 Micrófono"
                except Exception as exc:
                    return {}, f"❌ Error procesando audio del micrófono: {exc}"
            else:
                return {}, "❌ Por favor, grabe audio o suba un archivo de audio."

            results = self.classifier.classify_keywords(tensor, keywords)
            if "error" in results:
                return {}, f"❌ Error en clasificación: {results['error']}"

            keyword_count = sum(1 for part in keywords.split(",") if part.strip())
            status = f"✅ Clasificación completada | {origin} | {keyword_count} palabra(s) clave"
            return results, status

        except Exception as exc:
            message = f"❌ Error inesperado: {exc}"
            print(message)
            return {}, message

    def format_results_for_display(self, results: Dict[str, float]) -> str:
        """
        Render classification results as a Markdown scoreboard.

        Args:
            results: Mapping of keyword -> probability (or an "error" entry).

        Returns:
            Markdown string with one probability-bar line per keyword.
        """
        if not results:
            return "No hay resultados para mostrar."
        if "error" in results:
            return f"Error: {results['error']}"

        lines = ["📊 **Resultados de Clasificación:**\n"]
        width = 20
        # Highest-probability keywords first.
        for keyword, prob in sorted(results.items(), key=lambda item: item[1], reverse=True):
            filled = int(width * prob)
            bar = "█" * filled + "░" * (width - filled)
            # Traffic-light coding: green >= 0.7, yellow >= 0.4, red below.
            emoji = "🟢" if prob >= 0.7 else ("🟡" if prob >= 0.4 else "🔴")
            lines.append(f"{emoji} **{keyword.upper()}**: {prob * 100:.1f}% [{bar}]")

        return "\n".join(lines)
142
+
143
+
144
def create_gradio_interface():
    """Create and configure the Gradio interface for Hugging Face.

    Builds a two-column Blocks layout: keyword entry plus microphone/upload
    tabs on the left, Markdown results and a status box on the right.

    Returns:
        The assembled gr.Blocks interface (not yet launched).
    """

    # Initialize the app (constructs the audio processor and classifier once,
    # up front, so models are loaded before the UI is served).
    app = KeywordSpottingApp()

    def classify_audio(audio_input, audio_file, keywords):
        """Wrapper function for Gradio interface.

        Adapts the app's (results dict, status) return value into the
        (markdown string, status string) pair the output widgets expect.
        """
        results, status = app.process_audio_and_classify(audio_input, audio_file, keywords)
        formatted_results = app.format_results_for_display(results)
        return formatted_results, status

    # Create the interface
    with gr.Blocks(
        title="🎯 Zero-Shot Audio Keyword Spotting",
        theme=gr.themes.Soft(),
        # Custom CSS: center and cap the page width; .status-box styles the
        # status Textbox below the results.
        css="""
        .gradio-container {
            max-width: 900px !important;
            margin: auto !important;
        }
        .status-box {
            padding: 10px;
            border-radius: 5px;
            margin: 10px 0;
        }
        """
    ) as interface:

        gr.Markdown("""
        # 🎯 Zero-Shot Audio Keyword Spotting

        Detect keywords in Spanish audio using AI **without prior training**.
        Uses Whisper + CLAP models for accurate keyword detection.

        ## 📋 Instructions:
        1. **Enter keywords** you want to detect (comma-separated)
        2. **Record audio** using microphone OR **upload audio file**
        3. **Click "Analyze Audio"** to get probability results

        ### 💡 Example Keywords:
        `hola, gracias, adiós, sí, no, por favor`
        """)

        with gr.Row():
            # Left column: inputs (keywords + audio source tabs + button).
            with gr.Column(scale=1):
                gr.Markdown("### 🔤 Keywords")
                gr.Markdown("*Example: hola, gracias, adiós*")
                keywords_input = gr.Textbox(
                    label="Keywords (comma-separated)",
                    placeholder="hola, gracias, adiós, sí, no",
                    lines=2
                )

                gr.Markdown("### 🎵 Audio Input")

                # Microphone tab: delivers (sample_rate, ndarray) via type="numpy".
                with gr.Tab("🎤 Record Audio"):
                    gr.Markdown("*Click to record (max 30 seconds)*")
                    audio_input = gr.Audio(
                        sources=["microphone"],
                        type="numpy",
                        label="Record your audio here"
                    )

                # Upload tab: delivers a filesystem path via type="filepath".
                with gr.Tab("📁 Upload File"):
                    gr.Markdown("*Supported: WAV, MP3, M4A, etc.*")
                    audio_file = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload audio file"
                    )

                analyze_btn = gr.Button(
                    "🔍 Analyze Audio",
                    variant="primary",
                    size="lg"
                )

            # Right column: outputs (formatted results + status line).
            with gr.Column(scale=1):
                gr.Markdown("### 📊 Results")

                results_output = gr.Markdown(
                    value="Results will appear here after analysis...",
                    label="Classification Results"
                )

                status_output = gr.Textbox(
                    label="Status",
                    value="Ready to analyze",
                    interactive=False,
                    elem_classes=["status-box"]
                )

        # Event handlers
        analyze_btn.click(
            fn=classify_audio,
            inputs=[audio_input, audio_file, keywords_input],
            outputs=[results_output, status_output]
        )

        # Examples section
        gr.Markdown("""
        ## 💡 Usage Examples:

        **Suggested Spanish keywords:**
        - Greetings: `hola, buenos días, buenas tardes, adiós`
        - Courtesy: `gracias, por favor, disculpe, perdón`
        - Responses: `sí, no, tal vez, claro`
        - Numbers: `uno, dos, tres, cuatro, cinco`
        - Colors: `rojo, azul, verde, amarillo`

        **Tips:**
        - Use clear audio without background noise
        - Speak at normal speed
        - Keywords can appear anywhere in the audio
        - Works best with common Spanish words

        ## 🔧 Technical Details:
        - **Models**: Whisper (transcription) + CLAP (audio-text similarity)
        - **Languages**: Optimized for Spanish, works with others
        - **Processing**: Up to 30 seconds, 48kHz sampling rate
        - **Approach**: Hybrid zero-shot classification
        """)

    return interface
269
+
270
+
271
# Main execution for Hugging Face Spaces
if __name__ == "__main__":
    print("🚀 Starting Keyword Spotting App on Hugging Face Spaces...")

    # Build the Blocks UI and serve it. HF Spaces expects port 7860 and
    # provides its own access control, so no auth is configured here.
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
audio_processor.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio processing module for zero-shot keyword spotting.
3
+ Handles audio loading, preprocessing, and feature extraction.
4
+ """
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import torch
9
+ import torchaudio
10
+ from typing import Union, Tuple
11
+ import warnings
12
+
13
+ warnings.filterwarnings("ignore")
14
+
15
+
16
class AudioProcessor:
    """Handles audio preprocessing for the keyword spotting model."""

    def __init__(self, target_sample_rate: int = 48000, max_duration: float = 30.0):
        """
        Initialize the audio processor.

        Args:
            target_sample_rate: Target sampling rate for audio processing
            max_duration: Maximum audio duration in seconds
        """
        self.target_sample_rate = target_sample_rate
        self.max_duration = max_duration
        # Fixed output length: every processed clip is trimmed/padded to this.
        self.max_samples = int(target_sample_rate * max_duration)

    def load_audio(self, audio_path: str) -> Tuple[np.ndarray, int]:
        """
        Load audio file and return waveform and sample rate.

        Args:
            audio_path: Path to the audio file

        Returns:
            Tuple of (waveform, sample_rate)

        Raises:
            ValueError: If the file cannot be decoded.
        """
        try:
            # Use librosa for robust audio loading (sr=None keeps native rate).
            waveform, sr = librosa.load(audio_path, sr=None)
            return waveform, sr
        except Exception as e:
            raise ValueError(f"Error loading audio file: {str(e)}")

    def preprocess_audio(self, waveform: np.ndarray, sample_rate: int) -> torch.Tensor:
        """
        Preprocess audio waveform for model input.

        Args:
            waveform: Audio waveform as numpy array; 1-D, or 2-D in either
                (samples, channels) layout (Gradio microphone) or
                (channels, samples) layout (librosa-style)
            sample_rate: Original sample rate

        Returns:
            Mono float32 tensor of exactly `max_samples` samples, normalized.
        """
        # Convert to float32 if needed
        if waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)

        # BUG FIX: collapse to mono BEFORE resampling. The original converted
        # after resampling using librosa.to_mono, which expects channel-first
        # (channels, n); Gradio microphone arrays are (samples, channels), so
        # stereo input was resampled along the wrong axis and mishandled.
        if waveform.ndim > 1:
            # The channel axis is the shorter one for any real recording
            # (2 channels vs thousands of samples).
            channel_axis = int(np.argmin(waveform.shape))
            waveform = waveform.mean(axis=channel_axis).astype(np.float32)

        # Resample if necessary
        if sample_rate != self.target_sample_rate:
            waveform = librosa.resample(
                waveform,
                orig_sr=sample_rate,
                target_sr=self.target_sample_rate
            )

        # Trim or zero-pad to exactly max_duration.
        if len(waveform) > self.max_samples:
            waveform = waveform[:self.max_samples]
        elif len(waveform) < self.max_samples:
            padding = self.max_samples - len(waveform)
            waveform = np.pad(waveform, (0, padding), mode='constant', constant_values=0)

        # Normalize audio (after padding, matching the original pipeline order)
        waveform = self._normalize_audio(waveform)

        # Convert to tensor
        return torch.from_numpy(waveform).float()

    def _normalize_audio(self, waveform: np.ndarray) -> np.ndarray:
        """
        Normalize audio waveform to a target RMS of 0.1, then clip.

        Args:
            waveform: Input waveform

        Returns:
            Normalized waveform in [-1, 1]
        """
        # RMS normalization; the factor 10 scales the signal down to
        # RMS ~= 0.1 to leave headroom and prevent clipping.
        rms = np.sqrt(np.mean(waveform**2))
        if rms > 0:
            waveform = waveform / (rms * 10)

        # Clip to [-1, 1] range
        return np.clip(waveform, -1.0, 1.0)

    def process_audio_file(self, audio_path: str) -> torch.Tensor:
        """
        Complete audio processing pipeline from file to tensor.

        Args:
            audio_path: Path to audio file

        Returns:
            Preprocessed audio tensor ready for model input
        """
        waveform, sample_rate = self.load_audio(audio_path)
        return self.preprocess_audio(waveform, sample_rate)

    def process_audio_array(self, audio_array: np.ndarray, sample_rate: int) -> torch.Tensor:
        """
        Process audio from numpy array (e.g., from Gradio microphone input).

        Args:
            audio_array: Audio data as numpy array
            sample_rate: Sample rate of the audio

        Returns:
            Preprocessed audio tensor
        """
        return self.preprocess_audio(audio_array, sample_rate)
config.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration file for the Keyword Spotting App.
3
+ Contains authentication and app settings.
4
+ """
5
+
6
+ import os
7
+ from typing import Tuple, Optional
8
+
9
+
10
class AppConfig:
    """Configuration class for the app."""

    # Fallback credentials used when no env vars override them.
    DEFAULT_USERNAME = "admin"
    DEFAULT_PASSWORD = "kws2024"

    # Server defaults.
    DEFAULT_PORT = 7860
    DEFAULT_HOST = "0.0.0.0"

    # Env var values (lowercased) treated as "enabled" for boolean flags.
    _TRUTHY = ("true", "1", "yes")

    @staticmethod
    def get_auth_credentials() -> Optional[Tuple[str, str]]:
        """
        Return (username, password), or None when authentication is disabled.

        KWS_USERNAME / KWS_PASSWORD override the defaults (both must be set);
        KWS_NO_AUTH=true disables authentication entirely.
        """
        username = os.getenv("KWS_USERNAME")
        password = os.getenv("KWS_PASSWORD")

        # Opt-out flag wins over any configured credentials.
        if os.getenv("KWS_NO_AUTH", "").lower() in AppConfig._TRUTHY:
            return None

        if username and password:
            return (username, password)
        return (AppConfig.DEFAULT_USERNAME, AppConfig.DEFAULT_PASSWORD)

    @staticmethod
    def get_server_config() -> dict:
        """
        Return server launch settings assembled from environment variables.

        Returns:
            Dict with server_name, server_port, share and debug entries.
        """
        truthy = AppConfig._TRUTHY
        return {
            "server_name": os.getenv("KWS_HOST", AppConfig.DEFAULT_HOST),
            "server_port": int(os.getenv("KWS_PORT", AppConfig.DEFAULT_PORT)),
            "share": os.getenv("KWS_SHARE", "false").lower() in truthy,
            "debug": os.getenv("KWS_DEBUG", "false").lower() in truthy,
        }

    @staticmethod
    def print_config_info():
        """Print the effective server and auth configuration to stdout."""
        auth = AppConfig.get_auth_credentials()
        server = AppConfig.get_server_config()

        print("🔧 Configuración de la aplicación:")
        print(f" Host: {server['server_name']}")
        print(f" Puerto: {server['server_port']}")
        print(f" Compartir públicamente: {server['share']}")
        print(f" Modo debug: {server['debug']}")

        # NOTE(review): this echoes the password in plaintext to stdout —
        # convenient for local debugging, risky in shared logs.
        if auth:
            print("🔐 Autenticación habilitada:")
            print(f" Usuario: {auth[0]}")
            print(f" Contraseña: {auth[1]}")
        else:
            print("🔓 Autenticación deshabilitada")

        print("\n💡 Para cambiar la configuración, use variables de entorno:")
        print(" KWS_USERNAME=tu_usuario")
        print(" KWS_PASSWORD=tu_contraseña")
        print(" KWS_NO_AUTH=true (para deshabilitar autenticación)")
        print(" KWS_HOST=127.0.0.1 (para acceso local únicamente)")
        print(" KWS_PORT=8080 (para cambiar puerto)")
        print(" KWS_SHARE=true (para crear enlace público)")
        print(" KWS_DEBUG=true (para modo debug)")
85
+
86
+
87
# Quick access functions
def get_auth() -> Optional[Tuple[str, str]]:
    """Module-level shortcut for AppConfig.get_auth_credentials()."""
    credentials = AppConfig.get_auth_credentials()
    return credentials
91
+
92
def get_server_config() -> dict:
    """Module-level shortcut for AppConfig.get_server_config()."""
    config = AppConfig.get_server_config()
    return config
95
+
96
def print_config() -> None:
    """Module-level shortcut for AppConfig.print_config_info()."""
    AppConfig.print_config_info()
requirements_hf.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Optimized requirements for Hugging Face Spaces
2
+ gradio==4.44.0
3
+ torch>=2.0.0
4
+ transformers>=4.30.0
5
+ librosa>=0.10.0
6
+ numpy>=1.21.0
7
+ soundfile>=0.12.0
8
+ openai-whisper>=20231117
9
+ scipy>=1.7.0
whisper_classifier.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Alternative keyword spotter using Whisper for transcription + text matching.
3
+ This approach transcribes the audio first, then matches keywords in the text.
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from typing import List, Dict
9
+ import warnings
10
+ import re
11
+ from difflib import SequenceMatcher
12
+
13
+ warnings.filterwarnings("ignore")
14
+
15
# Optional dependency guard: openai-whisper may be absent. The flag lets
# the classes below (and HybridKeywordSpotter) degrade gracefully instead
# of failing at import time.
try:
    import whisper
    WHISPER_AVAILABLE = True
except ImportError:
    WHISPER_AVAILABLE = False
    print("⚠️ Whisper not available. Install with: pip install openai-whisper")
21
+
22
+
23
class WhisperKeywordSpotter:
    """Keyword spotter using Whisper transcription + text matching."""

    # Whisper models consume 16 kHz mono audio.
    WHISPER_SAMPLE_RATE = 16000

    def __init__(self, model_size: str = "base"):
        """
        Initialize the Whisper-based keyword spotter.

        Args:
            model_size: Whisper model size ('tiny', 'base', 'small', 'medium', 'large')

        Raises:
            ImportError: If openai-whisper is not installed.
        """
        if not WHISPER_AVAILABLE:
            raise ImportError("Whisper is not available. Install with: pip install openai-whisper")

        self.model_size = model_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading Whisper model: {model_size}")
        print(f"Using device: {self.device}")

        try:
            self.model = whisper.load_model(model_size, device=self.device)
            print("Whisper model loaded successfully!")
        except Exception as e:
            print(f"Error loading Whisper model: {e}")
            raise

    def prepare_keywords(self, keywords: str) -> List[str]:
        """Split a comma-separated string into lowercase, non-empty keywords."""
        if not keywords.strip():
            return []

        keyword_list = [kw.strip().lower() for kw in keywords.split(",")]
        return [kw for kw in keyword_list if kw]

    def transcribe_audio(self, audio_tensor: torch.Tensor, input_sample_rate: int = 48000) -> str:
        """
        Transcribe audio using Whisper.

        Args:
            audio_tensor: 1-D audio tensor.
            input_sample_rate: Sample rate of audio_tensor. Defaults to 48000
                because the upstream AudioProcessor produces 48 kHz audio.

        Returns:
            Lowercased transcription, or "" on failure.
        """
        try:
            # Convert to numpy and ensure it's float32
            audio_np = audio_tensor.numpy().astype(np.float32)

            # BUG FIX: the original code only downsampled clips longer than
            # 30 s worth of 16 kHz samples, so short 48 kHz audio reached
            # Whisper interpreted as 16 kHz (3x slowed). Always decimate to
            # Whisper's 16 kHz when the input rate is a higher multiple.
            factor = input_sample_rate // self.WHISPER_SAMPLE_RATE
            if factor > 1:
                # Simple decimation without an anti-alias filter; adequate
                # for speech, but a proper resampler would be better.
                audio_np = audio_np[::factor]

            # Whisper expects samples in [-1, 1].
            if audio_np.max() > 1.0 or audio_np.min() < -1.0:
                audio_np = np.clip(audio_np, -1.0, 1.0)

            # Transcribe
            result = self.model.transcribe(
                audio_np,
                language="es",  # Spanish
                task="transcribe",
                fp16=False,
                verbose=False
            )

            transcription = result["text"].strip().lower()
            print(f"📝 Transcription: '{transcription}'")

            return transcription

        except Exception as e:
            print(f"Error transcribing audio: {e}")
            return ""

    def calculate_keyword_similarity(self, transcription: str, keyword: str) -> float:
        """
        Calculate similarity between transcription and keyword.

        Exact or word-boundary matches score 1.0; otherwise the best fuzzy
        per-word ratio is combined with a down-weighted whole-string ratio.

        Args:
            transcription: Transcribed text
            keyword: Target keyword

        Returns:
            Similarity score (0-1)
        """
        if not transcription or not keyword:
            return 0.0

        # Method 1: Exact substring match
        if keyword in transcription:
            return 1.0

        # Method 2: Word boundary match
        word_pattern = r'\b' + re.escape(keyword) + r'\b'
        if re.search(word_pattern, transcription):
            return 1.0

        # Method 3: Fuzzy matching against each word in the transcription
        max_similarity = 0.0
        for word in transcription.split():
            clean_word = re.sub(r'[^\w]', '', word)  # strip punctuation
            if clean_word:
                similarity = SequenceMatcher(None, clean_word, keyword).ratio()
                max_similarity = max(max_similarity, similarity)

        # Method 4: Overall sequence similarity as fallback (weighted less)
        overall_similarity = SequenceMatcher(None, transcription, keyword).ratio()
        return max(max_similarity, overall_similarity * 0.7)

    def classify_keywords(self, audio_tensor: torch.Tensor, keywords: str) -> Dict[str, float]:
        """
        Perform keyword classification using transcription.

        Args:
            audio_tensor: Preprocessed audio tensor
            keywords: Comma-separated keywords string

        Returns:
            Mapping keyword -> probability score, or {"error": message}.
        """
        try:
            keyword_list = self.prepare_keywords(keywords)
            if not keyword_list:
                return {"error": "No valid keywords provided"}

            transcription = self.transcribe_audio(audio_tensor)
            if not transcription:
                # No speech recognized: return uniformly low scores.
                return {keyword: 0.1 for keyword in keyword_list}

            return {
                keyword: round(self.calculate_keyword_similarity(transcription, keyword), 4)
                for keyword in keyword_list
            }

        except Exception as e:
            error_msg = f"Classification error: {str(e)}"
            print(error_msg)
            return {"error": error_msg}
177
+
178
+
179
class HybridKeywordSpotter:
    """Hybrid approach combining multiple methods."""

    def __init__(self):
        """Bring up the available back-ends; any that fail to load stay None."""
        self.whisper_spotter = None
        self.clap_spotter = None

        # Preferred back-end: Whisper transcription matching.
        try:
            if WHISPER_AVAILABLE:
                self.whisper_spotter = WhisperKeywordSpotter("base")
        except Exception as exc:
            print(f"⚠️ Could not initialize Whisper: {exc}")

        # Secondary back-end: CLAP audio-text similarity.
        try:
            from improved_classifier import ImprovedZeroShotKeywordSpotter
            self.clap_spotter = ImprovedZeroShotKeywordSpotter()
        except Exception as exc:
            print(f"⚠️ Could not initialize CLAP: {exc}")

    def classify_keywords(self, audio_tensor: torch.Tensor, keywords: str) -> Dict[str, float]:
        """
        Classify with the best available back-end.

        Order: Whisper first (usually stronger on speech), CLAP as fallback,
        and finally an all-zero result so callers always get a dict.

        Args:
            audio_tensor: Preprocessed audio tensor
            keywords: Comma-separated keywords string

        Returns:
            Mapping keyword -> probability score.
        """
        if self.whisper_spotter:
            try:
                scores = self.whisper_spotter.classify_keywords(audio_tensor, keywords)
                if "error" not in scores:
                    return scores
            except Exception as exc:
                print(f"Whisper failed: {exc}")

        if self.clap_spotter:
            try:
                return self.clap_spotter.classify_keywords_simple(audio_tensor, keywords)
            except Exception as exc:
                print(f"CLAP failed: {exc}")

        # Last resort: zero score for every requested keyword.
        return {
            token.strip(): 0.0
            for token in keywords.split(",")
            if token.strip()
        }