Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import ( | |
| AutoModelForCTC, | |
| AutoProcessor, | |
| Wav2Vec2Processor, | |
| Wav2Vec2ForCTC, | |
| ) | |
| import onnxruntime as rt | |
| import numpy as np | |
| import librosa | |
| import warnings | |
| import os | |
| warnings.filterwarnings("ignore") | |
class Wave2Vec2Inference:
    """Speech-to-text wrapper around a Hugging Face CTC model run with PyTorch.

    The constructor picks a device, loads the processor and model, optionally
    applies ``torch.compile`` / TorchScript and runs one warm-up pass.
    ``buffer_to_text`` and ``file_to_text`` return lower-cased, stripped
    transcriptions of 16 kHz audio.
    """

    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True, enable_optimizations=True):
        """Load ``model_name`` and prepare it for inference.

        Args:
            model_name: Identifier passed to the ``from_pretrained`` loaders.
            hotwords: Optional list forwarded to the processor's ``decode``
                when the language-model decoding path is used. ``None``
                (the default) means no hotwords.
            use_lm_if_possible: Load the processor through ``AutoProcessor``
                and use its ``decoder`` attribute for decoding when present.
            use_gpu: Prefer ``mps``, then ``cuda``, falling back to ``cpu``.
            enable_optimizations: Attempt model compilation and run a
                warm-up inference.
        """
        # torch.backends.mps does not exist on older PyTorch builds, so look
        # it up defensively instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if not use_gpu:
            self.device = "cpu"
        elif mps_backend is not None and mps_backend.is_available():
            self.device = "mps"
        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        print(f"Using device: {self.device}")

        # NOTE: this disables autograd for the whole process, not just for
        # this instance.
        torch.set_grad_enabled(False)
        if self.device == "cpu":
            # Re-applies the current thread count (effectively a no-op).
            torch.set_num_threads(torch.get_num_threads())
            torch.set_float32_matmul_precision('high')
        elif self.device == "cuda":
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False
        elif self.device == "mps":
            # NOTE(review): PyTorch documents CPU fallback for MPS via the
            # PYTORCH_ENABLE_MPS_FALLBACK environment variable; confirm that
            # setting this attribute has any effect.
            torch.backends.mps.enable_fallback = True

        if use_lm_if_possible:
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

        if enable_optimizations:
            try:
                # torch.compile is skipped on MPS; that case falls through to
                # the TorchScript attempt below.
                if hasattr(torch, 'compile') and self.device != "mps":
                    self.model = torch.compile(self.model, mode="reduce-overhead")
                    print("Model compiled with torch.compile for faster inference")
                else:
                    try:
                        scripted_model = torch.jit.script(self.model)
                        if hasattr(torch.jit, 'optimize_for_inference'):
                            scripted_model = torch.jit.optimize_for_inference(scripted_model)
                        self.model = scripted_model
                        print("Model optimized with JIT scripting")
                    except Exception as jit_e:
                        print(f"JIT optimization failed, using regular model: {jit_e}")
            except Exception as e:
                # Best effort: any failure leaves the eager model in place.
                print(f"Model optimization failed, using regular model: {e}")
        else:
            print("Model optimizations disabled")

        # A fresh list per instance; a mutable default argument would be
        # shared between every instance created without hotwords.
        self.hotwords = hotwords if hotwords is not None else []
        self.use_lm_if_possible = use_lm_if_possible
        # Kept for compatibility; nothing in this class reads it.
        self.tensor_cache = {}

        if enable_optimizations:
            self._warmup_model()

    def _warmup_model(self):
        """Run one inference on a second of silence; failures are only printed."""
        try:
            # The processor does its work on the host, so the dummy audio is
            # created as a NumPy array instead of a tensor on the accelerator.
            dummy_audio = np.zeros(16000, dtype=np.float32)
            dummy_inputs = self.processor(
                dummy_audio,
                sampling_rate=16_000,
                return_tensors="pt",
                padding=True,
            )
            dummy_inputs = {k: v.to(self.device) for k, v in dummy_inputs.items()}
            with torch.no_grad():
                _ = self.model(
                    dummy_inputs["input_values"],
                    attention_mask=dummy_inputs.get("attention_mask"),
                )
            print("Model warmed up successfully")
        except Exception as e:
            print(f"Warmup failed: {e}")

    def buffer_to_text(self, audio_buffer):
        """Transcribe a buffer of 16 kHz audio samples.

        Args:
            audio_buffer: ``numpy.ndarray`` or any sequence accepted by
                ``torch.tensor``; an empty buffer yields ``""``.

        Returns:
            The transcription, lower-cased and stripped.
        """
        if len(audio_buffer) == 0:
            return ""

        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        else:
            audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)

        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )
        input_values = inputs.input_values.to(self.device, non_blocking=True)
        attention_mask = (
            inputs.attention_mask.to(self.device, non_blocking=True)
            if "attention_mask" in inputs
            else None
        )

        # Mixed precision is enabled on CUDA only. Everywhere else the
        # autocast context is created disabled, so a single code path serves
        # all devices.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.autocast(
            device_type="cuda" if use_amp else "cpu", enabled=use_amp
        ):
            logits = self.model(input_values, attention_mask=attention_mask).logits

        if self.use_lm_if_possible and hasattr(self.processor, "decoder"):
            # The decoder consumes a host NumPy array. Cast to float32 first
            # because autocast may have produced half-precision logits.
            frame_logits = logits[0].float().cpu().numpy()
            decoded = self.processor.decode(
                frame_logits,
                hotwords=self.hotwords,
                output_word_offsets=True,
            )
            text = decoded.text
        else:
            predicted_ids = torch.argmax(logits, dim=-1).cpu()
            text = self.processor.batch_decode(predicted_ids)[0]
        return text.lower().strip()

    def confidence_score(self, logits, predicted_ids):
        """Average softmax probability of the predicted non-blank tokens.

        Args:
            logits: Tensor indexed as (batch, frame, vocab).
            predicted_ids: Tensor of token ids indexed as (batch, frame).

        Returns:
            A 0-d tensor holding the mean probability over all frames whose
            predicted id is neither the tokenizer's word delimiter nor its
            pad token; ``0`` when no such frame exists (instead of NaN from a
            division by zero).
        """
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        tokenizer = self.processor.tokenizer
        mask = torch.logical_and(
            predicted_ids.not_equal(tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(tokenizer.pad_token_id),
        )
        character_scores = pred_scores.masked_select(mask)
        if character_scores.numel() == 0:
            return character_scores.new_zeros(())
        return character_scores.mean()

    def file_to_text(self, filename):
        """Load ``filename`` resampled to 16 kHz and transcribe it.

        Any error while loading or transcribing is printed and ``""`` is
        returned (deliberate best-effort behaviour).
        """
        try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
            return self.buffer_to_text(audio_input)
        except Exception as e:
            print(f"Error loading audio file {filename}: {e}")
            return ""
class Wave2Vec2ONNXInference:
    """Speech-to-text wrapper that runs an exported wav2vec2 model with ONNX Runtime.

    Decoding is a plain argmax over the model output followed by the
    processor's ``decode``.
    """

    def __init__(self, model_name, onnx_path):
        """Load the processor for ``model_name`` and open ``onnx_path``.

        Args:
            model_name: Identifier passed to ``Wav2Vec2Processor.from_pretrained``.
            onnx_path: Path of the ONNX model file.
        """
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)

        options = rt.SessionOptions()
        options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.execution_mode = rt.ExecutionMode.ORT_PARALLEL
        # 0 lets ONNX Runtime choose the thread counts itself.
        options.inter_op_num_threads = 0
        options.intra_op_num_threads = 0

        # Ask the runtime which providers are really usable rather than
        # inferring CUDA support from the build flavour.
        providers = []
        if 'CUDAExecutionProvider' in rt.get_available_providers():
            providers.append('CUDAExecutionProvider')
        providers.append('CPUExecutionProvider')

        self.model = rt.InferenceSession(
            onnx_path,
            options,
            providers=providers,
        )
        # Look the input name up once instead of on every call.
        self.input_name = self.model.get_inputs()[0].name
        print(f"ONNX model loaded with providers: {self.model.get_providers()}")

    def buffer_to_text(self, audio_buffer):
        """Transcribe a buffer of 16 kHz audio samples.

        Args:
            audio_buffer: ``numpy.ndarray`` or any sequence accepted by
                ``torch.tensor``; an empty buffer yields ``""``.

        Returns:
            The transcription, lower-cased and stripped.
        """
        if len(audio_buffer) == 0:
            return ""

        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        else:
            audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)

        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="np",
            padding=True,
        )
        input_values = inputs.input_values.astype(np.float32)
        onnx_outputs = self.model.run(
            None,
            {self.input_name: input_values},
        )[0]

        prediction = np.argmax(onnx_outputs, axis=-1)
        # Index the single batch entry explicitly: squeeze() would collapse a
        # one-frame result to a bare int, which decode() cannot iterate.
        transcription = self.processor.decode(prediction[0].tolist())
        return transcription.lower().strip()

    def file_to_text(self, filename):
        """Load ``filename`` resampled to 16 kHz and transcribe it.

        Any error while loading or transcribing is printed and ``""`` is
        returned (deliberate best-effort behaviour).
        """
        try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
            return self.buffer_to_text(audio_input)
        except Exception as e:
            print(f"Error loading audio file {filename}: {e}")
            return ""
# The ONNX conversion script (see convert_to_onnx) was taken from: https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py
class OptimizedWave2Vec2Factory:
    """Factory class to create the most optimized Wave2Vec2 inference instance"""

    # Both factories are static: they take no instance or class argument, and
    # without the decorator they would fail when called on an instance.

    @staticmethod
    def create_optimized_inference(model_name, onnx_path=None, safe_mode=False, **kwargs):
        """
        Create the most optimized inference instance based on available resources

        Args:
            model_name: HuggingFace model name
            onnx_path: Path to ONNX model (optional, for maximum speed)
            safe_mode: If True, disable aggressive optimizations that might cause issues
            **kwargs: Additional arguments for Wave2Vec2Inference

        Returns:
            A Wave2Vec2ONNXInference when ``onnx_path`` points to an existing
            file, otherwise a Wave2Vec2Inference.
        """
        if onnx_path and os.path.exists(onnx_path):
            print("Using ONNX model for maximum speed")
            return Wave2Vec2ONNXInference(model_name, onnx_path)

        print("Using PyTorch model with optimizations")
        if safe_mode:
            # Safe mode overrides any caller-supplied enable_optimizations.
            kwargs['enable_optimizations'] = False
            print("Running in safe mode - optimizations disabled")
        return Wave2Vec2Inference(model_name, **kwargs)

    @staticmethod
    def create_safe_inference(model_name, **kwargs):
        """Create a safe inference instance without aggressive optimizations"""
        kwargs['enable_optimizations'] = False
        return Wave2Vec2Inference(model_name, **kwargs)
def convert_to_onnx(model_id_or_path, onnx_model_name):
    """Export a ``Wav2Vec2ForCTC`` checkpoint to an ONNX file.

    Args:
        model_id_or_path: Identifier or path passed to ``from_pretrained``.
        onnx_model_name: Destination of the exported model (a file name or
            file-like object).
    """
    print(f"Converting {model_id_or_path} to onnx")
    ctc_model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)

    # The model is traced with a random single-item batch of this many
    # samples; axis 1 is declared dynamic below so other lengths are accepted.
    sample_count = 250000
    dummy_waveform = torch.randn(1, sample_count, requires_grad=True)

    variable_axes = {
        "input": {1: "audio_len"},
        "output": {1: "audio_len"},
    }

    torch.onnx.export(
        ctc_model,
        dummy_waveform,
        onnx_model_name,
        export_params=True,        # keep the trained weights in the file
        opset_version=14,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=variable_axes,
    )
def quantize_onnx_model(onnx_model_path, quantized_model_path):
    """Dynamically quantize an ONNX model with unsigned 8-bit weights.

    Args:
        onnx_model_path: Path of the model to quantize.
        quantized_model_path: Path the quantized model is written to.
    """
    print("Starting quantization...")
    # Imported lazily so the quantization tooling is only required when this
    # function is actually used.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    weight_format = QuantType.QUInt8
    quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=weight_format)
    print(f"Quantized model saved to: {quantized_model_path}")
def export_to_onnx(
    model: str = "facebook/wav2vec2-large-960h-lv60-self", quantize: bool = False
):
    """Export ``model`` to ``<name>.onnx`` and optionally quantize it.

    ``<name>`` is the part of ``model`` after its last ``/``. When
    ``quantize`` is true, a ``<name>.quant.onnx`` file is produced as well.
    """
    stem = model.split("/")[-1]
    exported_path = f"{stem}.onnx"
    convert_to_onnx(model, exported_path)
    if not quantize:
        return
    quantize_onnx_model(exported_path, f"{stem}.quant.onnx")
if __name__ == "__main__":
    # Small benchmark: transcribe test.wav repeatedly and print timing stats.
    import time

    asr = OptimizedWave2Vec2Factory.create_optimized_inference(
        "facebook/wav2vec2-large-960h-lv60-self"
    )

    test_file = "test.wav"
    if not os.path.exists(test_file):
        print(f"Test file {test_file} not found. Please provide a valid audio file.")
        # SystemExit instead of exit(): the latter is injected by the `site`
        # module and is not guaranteed to exist.
        raise SystemExit(1)

    # Extra warm-up runs on the real file, outside the timed loop.
    print("Running additional warm-up...")
    for i in range(2):
        asr.file_to_text(test_file)
        print(f"Warm up {i+1} completed")

    print("Running optimized performance tests...")
    times = []
    for i in range(10):
        # perf_counter is monotonic and high resolution, unlike time.time().
        start_time = time.perf_counter()
        text = asr.file_to_text(test_file)
        execution_time = time.perf_counter() - start_time
        times.append(execution_time)
        print(f"Test {i+1}: {execution_time:.3f}s - {text}")

    average_time = sum(times) / len(times)
    min_time = min(times)
    max_time = max(times)
    std_time = np.std(times)

    print("\n=== Performance Statistics ===")
    print(f"Average execution time: {average_time:.3f}s")
    print(f"Min time: {min_time:.3f}s")
    print(f"Max time: {max_time:.3f}s")
    print(f"Standard deviation: {std_time:.3f}s")
    # This figure only compares the fastest and slowest timed runs.
    print(f"Speed improvement: ~{((max_time - min_time) / max_time * 100):.1f}% faster (min vs max)")

    throughput = 1.0 / average_time
    print(f"Average throughput: {throughput:.2f} inferences/second")