Spaces:
Sleeping
Sleeping
File size: 7,006 Bytes
e685c03 97c892c e685c03 97c892c e685c03 97c892c e685c03 97c892c e685c03 97c892c e685c03 97c892c e685c03 97c892c e685c03 97c892c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
"""
Whisper-only keyword spotter for zero-shot audio keyword detection.
Uses Whisper transcription + text matching without CLAP dependencies.
"""
import torch
import numpy as np
from typing import List, Dict
import warnings
import re
from difflib import SequenceMatcher
# NOTE(review): with no module/category arguments this installs a global
# "ignore" filter, silencing warnings for the whole process, not just this module.
warnings.filterwarnings("ignore")

# Whisper is an optional dependency: record whether it imported so the
# spotter can raise a clear error at construction time instead of here.
try:
    import whisper
except ImportError:
    WHISPER_AVAILABLE = False
    print("⚠️ Whisper not available. Install with: pip install openai-whisper")
else:
    WHISPER_AVAILABLE = True
class WhisperKeywordSpotter:
    """Keyword spotter using Whisper transcription + text matching.

    Audio is transcribed with a Whisper model, then each keyword is scored
    against the transcription: 1.0 for a whole-word / whole-phrase match,
    otherwise a fuzzy ``difflib.SequenceMatcher`` ratio in [0, 1].
    """

    # Whisper's transcribe() treats raw numpy input as 16 kHz audio.
    WHISPER_SAMPLE_RATE = 16000

    # Down-weights the whole-transcription similarity so it only wins when no
    # window of individual words is a better match.
    OVERALL_SIMILARITY_WEIGHT = 0.7

    # Score reported for every keyword when nothing could be transcribed.
    NO_TRANSCRIPTION_SCORE = 0.1

    def __init__(self, model_size: str = "base", language: str = "es",
                 input_sample_rate: int = 48000):
        """
        Initialize the Whisper-based keyword spotter.

        Args:
            model_size: Whisper model size ('tiny', 'base', 'small', 'medium', 'large')
            language: Language code passed to Whisper's transcribe() (default 'es', Spanish).
            input_sample_rate: Sample rate in Hz of the audio tensors that will be
                given to transcribe_audio / classify_keywords. Defaults to 48 kHz,
                the rate this module was written for; audio is resampled to
                16 kHz before transcription.

        Raises:
            ImportError: if the ``whisper`` package could not be imported.
            ValueError: if ``input_sample_rate`` is not positive.
            Exception: anything raised by ``whisper.load_model`` is re-raised.
        """
        if not WHISPER_AVAILABLE:
            raise ImportError("Whisper is not available. Install with: pip install openai-whisper")
        if input_sample_rate <= 0:
            raise ValueError(f"input_sample_rate must be positive, got {input_sample_rate}")
        self.model_size = model_size
        self.language = language
        self.input_sample_rate = input_sample_rate
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Whisper model: {model_size}")
        print(f"Using device: {self.device}")
        try:
            self.model = whisper.load_model(model_size, device=self.device)
            print("Whisper model loaded successfully!")
        except Exception as e:
            print(f"Error loading Whisper model: {e}")
            raise

    def prepare_keywords(self, keywords: str) -> List[str]:
        """Split a comma-separated keyword string into a cleaned list.

        Each keyword is stripped and lower-cased; empty entries are dropped.
        ``None`` or a blank string yields an empty list.
        """
        if not keywords or not keywords.strip():
            return []
        cleaned = (kw.strip().lower() for kw in keywords.split(","))
        return [kw for kw in cleaned if kw]

    @classmethod
    def _resample_to_whisper_rate(cls, audio: np.ndarray, sample_rate: int) -> np.ndarray:
        """Resample 1-D ``audio`` from ``sample_rate`` Hz to Whisper's 16 kHz.

        Raises:
            ValueError: if ``sample_rate`` is not positive.
        """
        target = cls.WHISPER_SAMPLE_RATE
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")
        if sample_rate == target or audio.size == 0:
            return audio
        if sample_rate % target == 0:
            # Integer ratio (e.g. 48 kHz -> 16 kHz): simple decimation. There is no
            # anti-aliasing filter, so content above 8 kHz folds back; adequate
            # for speech recognition, not a high-fidelity resampler.
            return audio[:: sample_rate // target]
        # Any other ratio (including upsampling): linear interpolation onto a
        # 16 kHz time grid covering the same duration.
        n_out = max(1, int(round(audio.shape[0] * target / sample_rate)))
        src_t = np.arange(audio.shape[0], dtype=np.float64) / sample_rate
        dst_t = np.arange(n_out, dtype=np.float64) / target
        return np.interp(dst_t, src_t, audio).astype(np.float32)

    def transcribe_audio(self, audio_tensor: torch.Tensor) -> str:
        """
        Transcribe audio using Whisper.

        Args:
            audio_tensor: Audio samples at ``self.input_sample_rate`` Hz. Singleton
                dimensions are squeezed away; anything still not 1-D is rejected
                (assumed mono — TODO confirm callers never pass multi-channel audio).

        Returns:
            Lower-cased, stripped transcription, or "" for empty input or on any error
            (the error is printed).
        """
        try:
            # detach()/cpu() so tensors on GPU or requiring grad convert as well.
            audio_np = np.atleast_1d(
                np.squeeze(audio_tensor.detach().cpu().numpy())
            ).astype(np.float32)
            if audio_np.ndim != 1:
                raise ValueError(
                    f"Expected mono audio after squeezing, got shape {audio_np.shape}"
                )
            if audio_np.size == 0:
                print("⚠️ Empty audio; nothing to transcribe")
                return ""
            # Whisper assumes 16 kHz input whatever the clip length, so resample
            # based on the declared input rate, never on how long the clip is.
            audio_np = self._resample_to_whisper_rate(audio_np, self.input_sample_rate)
            # Keep samples in [-1, 1]; a no-op for audio that is already in range.
            audio_np = np.clip(audio_np, -1.0, 1.0)
            result = self.model.transcribe(
                audio_np,
                language=self.language,
                task="transcribe",
                fp16=False,
                verbose=False,
            )
            transcription = result["text"].strip().lower()
            print(f"📝 Transcription: '{transcription}'")
            return transcription
        except Exception as e:
            print(f"Error transcribing audio: {e}")
            return ""

    def calculate_keyword_similarity(self, transcription: str, keyword: str) -> float:
        """
        Score how well ``keyword`` occurs in ``transcription``.

        Args:
            transcription: Transcribed text (expected lower-case).
            keyword: Target keyword or multi-word phrase (expected lower-case).

        Returns:
            0.0 if either argument is empty; 1.0 if the keyword appears as a whole
            word/phrase; otherwise the larger of (a) the best SequenceMatcher ratio
            between the keyword and any run of consecutive transcription words with
            the same word count (punctuation stripped) and (b) the
            whole-transcription ratio scaled by OVERALL_SIMILARITY_WEIGHT.
        """
        if not transcription or not keyword:
            return 0.0

        # Whole-word / whole-phrase match. Lookarounds (rather than \b) also work
        # for keywords starting or ending with non-word characters such as "c++".
        # A bare substring test is deliberately not used: it would score "no"
        # inside "nosotros" as a perfect hit.
        pattern = r"(?<!\w)" + re.escape(keyword) + r"(?!\w)"
        if re.search(pattern, transcription):
            return 1.0

        # Fuzzy match against sliding windows of words, sized to the keyword's
        # word count so multi-word phrases are compared like-for-like.
        words = [w for w in (re.sub(r"[^\w]", "", tok) for tok in transcription.split()) if w]
        best = 0.0
        if words:
            span = min(max(1, len(keyword.split())), len(words))
            for start in range(len(words) - span + 1):
                candidate = " ".join(words[start:start + span])
                best = max(best, SequenceMatcher(None, candidate, keyword).ratio())

        # Whole-transcription similarity as a down-weighted fallback.
        overall = SequenceMatcher(None, transcription, keyword).ratio()
        return max(best, overall * self.OVERALL_SIMILARITY_WEIGHT)

    def classify_keywords(self, audio_tensor: torch.Tensor, keywords: str) -> Dict[str, float]:
        """
        Score each keyword against the transcription of ``audio_tensor``.

        Args:
            audio_tensor: Audio tensor (see transcribe_audio for expectations).
            keywords: Comma-separated keywords string.

        Returns:
            Mapping of cleaned keyword -> similarity rounded to 4 decimals;
            NO_TRANSCRIPTION_SCORE for every keyword if nothing was transcribed;
            or ``{"error": <message>}`` (a string value) when no valid keywords
            are given or an unexpected exception occurs.
        """
        try:
            keyword_list = self.prepare_keywords(keywords)
            if not keyword_list:
                return {"error": "No valid keywords provided"}
            transcription = self.transcribe_audio(audio_tensor)
            if not transcription:
                return {keyword: self.NO_TRANSCRIPTION_SCORE for keyword in keyword_list}
            return {
                keyword: round(self.calculate_keyword_similarity(transcription, keyword), 4)
                for keyword in keyword_list
            }
        except Exception as e:
            error_msg = f"Classification error: {str(e)}"
            print(error_msg)
            return {"error": error_msg}

    def change_model(self, new_model_size: str) -> bool:
        """
        Swap in a different Whisper model size.

        Args:
            new_model_size: Model size to load.

        Returns:
            True if the requested model is active (including when it already was);
            False if loading failed, in which case the previous model and
            ``model_size`` are left unchanged.
        """
        if new_model_size == self.model_size:
            return True
        print(f"Changing model from {self.model_size} to {new_model_size}")
        try:
            new_model = whisper.load_model(new_model_size, device=self.device)
        except Exception as e:
            print(f"Error loading {new_model_size} model: {e}")
            return False
        # Commit only after a successful load so model and model_size never disagree.
        self.model = new_model
        self.model_size = new_model_size
        print(f"Successfully loaded {new_model_size} model!")
        return True