# Simple-KWS — whisper_classifier.py
# (Hugging Face upload-page residue removed: "IvanLayer7's picture / Upload 5 files / 97c892c verified")
"""
Whisper-only keyword spotter for zero-shot audio keyword detection.
Uses Whisper transcription + text matching without CLAP dependencies.
"""
import torch
import numpy as np
from typing import List, Dict
import warnings
import re
from difflib import SequenceMatcher
warnings.filterwarnings("ignore")
# Optional dependency: openai-whisper. The module stays importable without it;
# WhisperKeywordSpotter.__init__ raises a clear ImportError instead.
try:
    import whisper
    WHISPER_AVAILABLE = True
except ImportError:
    WHISPER_AVAILABLE = False
    print("⚠️ Whisper not available. Install with: pip install openai-whisper")
class WhisperKeywordSpotter:
    """Zero-shot keyword spotter: Whisper transcription + fuzzy text matching.

    Audio is transcribed with openai-whisper (Spanish, per the hard-coded
    ``language="es"``) and each keyword is scored against the transcription
    with exact, word-boundary, and fuzzy (SequenceMatcher) matching.
    """

    def __init__(self, model_size: str = "base"):
        """
        Initialize the Whisper-based keyword spotter.

        Args:
            model_size: Whisper model size ('tiny', 'base', 'small', 'medium', 'large')

        Raises:
            ImportError: if openai-whisper is not installed.
            Exception: re-raised from whisper.load_model on load failure.
        """
        if not WHISPER_AVAILABLE:
            raise ImportError("Whisper is not available. Install with: pip install openai-whisper")
        self.model_size = model_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Whisper model: {model_size}")
        print(f"Using device: {self.device}")
        try:
            self.model = whisper.load_model(model_size, device=self.device)
            print("Whisper model loaded successfully!")
        except Exception as e:
            print(f"Error loading Whisper model: {e}")
            raise

    def prepare_keywords(self, keywords: str) -> List[str]:
        """Split a comma-separated string into a cleaned, lower-cased keyword list.

        Blank entries (e.g. from "a,,b" or trailing commas) are dropped.
        """
        if not keywords.strip():
            return []
        return [kw for kw in (part.strip().lower() for part in keywords.split(",")) if kw]

    def transcribe_audio(self, audio_tensor: torch.Tensor, sample_rate: int = 48000) -> str:
        """
        Transcribe audio using Whisper.

        Args:
            audio_tensor: mono audio tensor, values expected in [-1, 1].
            sample_rate: sample rate of audio_tensor (default 48000, matching
                the pipeline this module was written for). Whisper requires
                16 kHz; any integer multiple of 16000 is decimated down.

        Returns:
            Lower-cased transcribed text, or "" on failure.
        """
        try:
            # .detach().cpu() so CUDA tensors / tensors with requires_grad
            # don't crash .numpy(); .ravel() flattens a (1, n) channel dim.
            audio_np = audio_tensor.detach().cpu().numpy().astype(np.float32).ravel()
            if audio_np.size == 0:
                # .max()/.min() on an empty array raises; nothing to transcribe.
                return ""
            # BUG FIX: the old code only decimated clips longer than a length
            # threshold, so short 48 kHz clips reached Whisper at the wrong
            # rate. Decimate unconditionally whenever the rate is an exact
            # multiple of 16 kHz (naive decimation — aliasing possible, but
            # acceptable for testing; proper resampling needs a low-pass).
            if sample_rate != 16000 and sample_rate % 16000 == 0:
                audio_np = audio_np[::sample_rate // 16000]
            # Clamp out-of-range samples so Whisper sees valid [-1, 1] audio.
            if audio_np.max() > 1.0 or audio_np.min() < -1.0:
                audio_np = np.clip(audio_np, -1.0, 1.0)
            result = self.model.transcribe(
                audio_np,
                language="es",  # Spanish
                task="transcribe",
                fp16=False,
                verbose=False
            )
            transcription = result["text"].strip().lower()
            print(f"📝 Transcription: '{transcription}'")
            return transcription
        except Exception as e:
            # Best-effort: callers treat "" as "no transcription available".
            print(f"Error transcribing audio: {e}")
            return ""

    def calculate_keyword_similarity(self, transcription: str, keyword: str) -> float:
        """
        Calculate similarity between transcription and keyword.

        Scoring cascade (first hit wins at 1.0, otherwise best fuzzy score):
          1. substring match; 2. word-boundary regex match;
          3. per-word SequenceMatcher ratio; 4. whole-string ratio, down-weighted.

        Args:
            transcription: Transcribed text (expected lower-cased).
            keyword: Target keyword (expected lower-cased).

        Returns:
            Similarity score in [0, 1].
        """
        if not transcription or not keyword:
            return 0.0
        # Method 1: exact substring match.
        if keyword in transcription:
            return 1.0
        # Method 2: word-boundary match (redundant after 1, kept for clarity).
        word_pattern = r'\b' + re.escape(keyword) + r'\b'
        if re.search(word_pattern, transcription):
            return 1.0
        # Method 3: best fuzzy match against each individual word.
        max_similarity = 0.0
        for word in transcription.split():
            clean_word = re.sub(r'[^\w]', '', word)  # strip punctuation
            if clean_word:
                similarity = SequenceMatcher(None, clean_word, keyword).ratio()
                max_similarity = max(max_similarity, similarity)
        # Method 4: whole-transcription similarity, down-weighted so it only
        # matters when no single word is a decent match.
        overall_similarity = SequenceMatcher(None, transcription, keyword).ratio()
        return max(max_similarity, overall_similarity * 0.7)

    def classify_keywords(self, audio_tensor: torch.Tensor, keywords: str,
                          sample_rate: int = 48000) -> Dict[str, float]:
        """
        Perform keyword classification using transcription.

        Args:
            audio_tensor: Preprocessed mono audio tensor.
            keywords: Comma-separated keywords string.
            sample_rate: sample rate of audio_tensor (forwarded to
                transcribe_audio; default 48000).

        Returns:
            Dictionary mapping keywords to scores in [0, 1], or a dict with
            a single "error" key describing the failure.
        """
        try:
            keyword_list = self.prepare_keywords(keywords)
            if not keyword_list:
                return {"error": "No valid keywords provided"}
            transcription = self.transcribe_audio(audio_tensor, sample_rate=sample_rate)
            if not transcription:
                # No transcription: return a uniform low (not zero) score.
                return {keyword: 0.1 for keyword in keyword_list}
            return {
                keyword: round(self.calculate_keyword_similarity(transcription, keyword), 4)
                for keyword in keyword_list
            }
        except Exception as e:
            error_msg = f"Classification error: {str(e)}"
            print(error_msg)
            return {"error": error_msg}

    def change_model(self, new_model_size: str) -> bool:
        """
        Change the Whisper model size.

        Args:
            new_model_size: New model size to load.

        Returns:
            True on success (or when the size is unchanged), False if loading
            failed — in which case the previously loaded model stays active.
        """
        if new_model_size == self.model_size:
            return True
        print(f"Changing model from {self.model_size} to {new_model_size}")
        try:
            new_model = whisper.load_model(new_model_size, device=self.device)
        except Exception as e:
            print(f"Error loading {new_model_size} model: {e}")
            return False
        # BUG FIX: only commit model/size after a successful load, so a
        # failed load no longer leaves model_size pointing at a model that
        # was never actually loaded.
        self.model = new_model
        self.model_size = new_model_size
        print(f"Successfully loaded {new_model_size} model!")
        return True