File size: 7,006 Bytes
e685c03
97c892c
 
e685c03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97c892c
e685c03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97c892c
e685c03
97c892c
e685c03
 
97c892c
e685c03
97c892c
 
 
e685c03
97c892c
 
 
e685c03
97c892c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""

Whisper-only keyword spotter for zero-shot audio keyword detection.

Uses Whisper transcription + text matching without CLAP dependencies.

"""

import torch
import numpy as np
from typing import List, Dict
import warnings
import re
from difflib import SequenceMatcher

warnings.filterwarnings("ignore")

# Whisper is an optional dependency: record availability in a flag so
# WhisperKeywordSpotter can raise a clear ImportError at construction time
# instead of crashing at module import time.
try:
    import whisper
    WHISPER_AVAILABLE = True
except ImportError:
    WHISPER_AVAILABLE = False
    print("⚠️ Whisper not available. Install with: pip install openai-whisper")


class WhisperKeywordSpotter:
    """Zero-shot keyword spotter: transcribe audio with Whisper, then score
    each keyword against the transcription using exact, word-boundary, and
    fuzzy text matching.
    """

    # Sample rate (Hz) that Whisper models expect for raw audio input.
    WHISPER_SAMPLE_RATE = 16000

    def __init__(self, model_size: str = "base"):
        """Initialize the Whisper-based keyword spotter.

        Args:
            model_size: Whisper model size ('tiny', 'base', 'small',
                'medium', 'large').

        Raises:
            ImportError: If the ``openai-whisper`` package is not installed.
            Exception: Re-raises any failure from ``whisper.load_model``.
        """
        if not WHISPER_AVAILABLE:
            raise ImportError("Whisper is not available. Install with: pip install openai-whisper")

        self.model_size = model_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading Whisper model: {model_size}")
        print(f"Using device: {self.device}")

        try:
            self.model = whisper.load_model(model_size, device=self.device)
            print("Whisper model loaded successfully!")
        except Exception as e:
            # Log for visibility, then propagate — the object is unusable
            # without a model.
            print(f"Error loading Whisper model: {e}")
            raise

    def prepare_keywords(self, keywords: str) -> List[str]:
        """Split a comma-separated keyword string into a cleaned list.

        Keywords are lowercased and stripped; empty entries (e.g. from
        trailing or doubled commas) are dropped.

        Args:
            keywords: Comma-separated keywords string.

        Returns:
            List of normalized, non-empty keywords (possibly empty).
        """
        if not keywords.strip():
            return []
        return [kw.strip().lower() for kw in keywords.split(",") if kw.strip()]

    def transcribe_audio(self, audio_tensor: torch.Tensor, input_sample_rate: int = 48000) -> str:
        """Transcribe audio using Whisper.

        Args:
            audio_tensor: Audio tensor; flattened to 1-D before use. Assumed
                to be sampled at ``input_sample_rate`` Hz.
            input_sample_rate: Sample rate of the incoming audio (default
                48000, matching the pipeline this module was written for).
                Decimation is applied only when it is an exact multiple of
                16 kHz; otherwise the audio is passed through unchanged.

        Returns:
            Lowercased transcription, or "" on any failure (best-effort).
        """
        try:
            # Detach from autograd and move to CPU before the numpy
            # conversion (tensor.numpy() raises on CUDA / grad-tracking
            # tensors); ravel() drops a possible leading channel dimension.
            audio_np = audio_tensor.detach().cpu().numpy().astype(np.float32).ravel()

            # BUG FIX: the original only decimated when the clip exceeded
            # 16000*30 samples, so shorter clips reached Whisper at the
            # wrong sample rate. Decimate unconditionally when the input
            # rate is an integer multiple of 16 kHz.
            # NOTE: plain decimation has no anti-alias filter — acceptable
            # for speech testing, but a proper resampler is preferable.
            factor = input_sample_rate // self.WHISPER_SAMPLE_RATE
            if factor > 1 and input_sample_rate % self.WHISPER_SAMPLE_RATE == 0:
                audio_np = audio_np[::factor]

            # Whisper expects samples in [-1, 1]; clamp any overshoot.
            # Guard the empty case: .max()/.min() raise on empty arrays.
            if audio_np.size and (audio_np.max() > 1.0 or audio_np.min() < -1.0):
                audio_np = np.clip(audio_np, -1.0, 1.0)

            result = self.model.transcribe(
                audio_np,
                language="es",  # Spanish
                task="transcribe",
                fp16=False,
                verbose=False
            )

            transcription = result["text"].strip().lower()
            print(f"📝 Transcription: '{transcription}'")

            return transcription

        except Exception as e:
            # Best-effort contract: callers treat "" as "nothing recognized".
            print(f"Error transcribing audio: {e}")
            return ""

    def calculate_keyword_similarity(self, transcription: str, keyword: str) -> float:
        """Score how well ``keyword`` matches ``transcription``.

        Matching strategy (first exact hit wins):
          1. Substring containment -> 1.0 (note this also matches inside
             longer words, e.g. 'cat' in 'catalog').
          2. Word-boundary regex match -> 1.0.
          3. Best fuzzy ratio between the keyword and each word of the
             transcription (punctuation stripped).
          4. Fuzzy ratio against the whole transcription, down-weighted by
             0.7 so it only dominates when per-word matches are poor.

        Args:
            transcription: Transcribed text (expected lowercased).
            keyword: Target keyword (expected lowercased).

        Returns:
            Similarity score in [0, 1].
        """
        if not transcription or not keyword:
            return 0.0

        # Exact containment / whole-word hit is a perfect score.
        if keyword in transcription:
            return 1.0
        if re.search(r'\b' + re.escape(keyword) + r'\b', transcription):
            return 1.0

        # Best per-word fuzzy similarity.
        max_similarity = 0.0
        for word in transcription.split():
            clean_word = re.sub(r'[^\w]', '', word)
            if clean_word:
                max_similarity = max(
                    max_similarity,
                    SequenceMatcher(None, clean_word, keyword).ratio(),
                )

        # Whole-transcription similarity as a down-weighted fallback.
        overall_similarity = SequenceMatcher(None, transcription, keyword).ratio()

        return max(max_similarity, overall_similarity * 0.7)

    def classify_keywords(self, audio_tensor: torch.Tensor, keywords: str) -> Dict[str, float]:
        """Perform keyword classification using transcription.

        Args:
            audio_tensor: Preprocessed audio tensor (see transcribe_audio).
            keywords: Comma-separated keywords string.

        Returns:
            Dictionary mapping each keyword to a similarity score rounded
            to 4 decimals, or {"error": <message>} on failure. When the
            transcription is empty, every keyword gets a floor score of 0.1.
        """
        try:
            keyword_list = self.prepare_keywords(keywords)
            if not keyword_list:
                return {"error": "No valid keywords provided"}

            transcription = self.transcribe_audio(audio_tensor)
            if not transcription:
                # No speech recognized: return a uniform low score rather
                # than failing, so downstream ranking still works.
                return {keyword: 0.1 for keyword in keyword_list}

            return {
                keyword: round(self.calculate_keyword_similarity(transcription, keyword), 4)
                for keyword in keyword_list
            }

        except Exception as e:
            error_msg = f"Classification error: {str(e)}"
            print(error_msg)
            return {"error": error_msg}

    def change_model(self, new_model_size: str) -> bool:
        """Swap in a different Whisper model size.

        Args:
            new_model_size: New model size to load.

        Returns:
            True if the requested model is active (including the no-op case
            where it is already loaded); False if loading failed, in which
            case the previous model remains active.
        """
        if new_model_size == self.model_size:
            return True

        print(f"Changing model from {self.model_size} to {new_model_size}")
        try:
            new_model = whisper.load_model(new_model_size, device=self.device)
        except Exception as e:
            # BUG FIX: the original updated self.model_size before loading,
            # leaving an inconsistent name/model pair on failure. Now state
            # is mutated only after a successful load.
            print(f"Error loading {new_model_size} model: {e}")
            return False

        self.model = new_model
        self.model_size = new_model_size
        print(f"Successfully loaded {new_model_size} model!")
        return True