"""wav2vec2 CTC speech-to-text inference (PyTorch and ONNX Runtime) plus ONNX export helpers."""

import contextlib
import os
import sys
import time
import warnings

import librosa
import numpy as np
import onnxruntime as rt
import torch
from transformers import (
    AutoModelForCTC,
    AutoProcessor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

warnings.filterwarnings("ignore")

# Sample rate (Hz) that audio files are resampled to and that is declared to the processor.
TARGET_SAMPLE_RATE = 16_000


def _load_audio(filename):
    """Load ``filename`` resampled to TARGET_SAMPLE_RATE as a float32 array.

    Returns the waveform, or ``None`` (after printing the error) if loading fails.
    Only loading is guarded here so that inference errors are not mislabelled
    as file errors.
    """
    try:
        audio, _ = librosa.load(filename, sr=TARGET_SAMPLE_RATE, dtype=np.float32)
    except Exception as e:  # decoding backends raise many unrelated exception types
        print(f"Error loading audio file {filename}: {e}")
        return None
    return audio


class Wave2Vec2Inference:
    """Transcribe audio with a Hugging Face wav2vec2 CTC model using PyTorch."""

    def __init__(
        self,
        model_name,
        hotwords=None,
        use_lm_if_possible=True,
        use_gpu=True,
        enable_optimizations=True,
    ):
        """Load the processor and model, choose a device and optionally optimize.

        Args:
            model_name: Hugging Face model id or local path.
            hotwords: Hotwords forwarded to the LM decoder (only used when the
                processor exposes a ``decoder``). ``None`` means no hotwords.
            use_lm_if_possible: Load the processor via ``AutoProcessor`` and use its
                LM decoder when one is present; otherwise use ``Wav2Vec2Processor``
                and greedy (argmax) decoding.
            use_gpu: Prefer MPS, then CUDA, when available; otherwise use CPU.
            enable_optimizations: Try ``torch.compile`` (or TorchScript as a
                fallback) and run a warm-up inference.
        """
        self.device = self._select_device(use_gpu)
        print(f"Using device: {self.device}")
        self._configure_backend()

        if use_lm_if_possible:
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)

        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        # Keep the un-optimized model so we can fall back to it: torch.compile
        # compiles lazily, so its errors appear on the first forward pass rather
        # than inside torch.compile() itself.
        self._eager_model = self.model

        if enable_optimizations:
            self._optimize_model()
        else:
            print("Model optimizations disabled")

        self.hotwords = [] if hotwords is None else hotwords
        self.use_lm_if_possible = use_lm_if_possible
        # Not used by this class; kept so existing attribute access keeps working.
        self.tensor_cache = {}

        if enable_optimizations:
            self._warmup_model()

    @staticmethod
    def _select_device(use_gpu):
        """Return "mps", "cuda" or "cpu" depending on ``use_gpu`` and availability."""
        if not use_gpu:
            return "cpu"
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _configure_backend(self):
        """Apply device-specific (process-wide) torch settings for inference.

        Gradient tracking is not disabled globally; every forward pass in this
        class runs under ``torch.no_grad()`` instead.
        """
        if self.device == "cpu":
            torch.set_float32_matmul_precision("high")
        elif self.device == "cuda":
            torch.backends.cudnn.benchmark = True  # pick fastest conv algorithms
            torch.backends.cudnn.deterministic = False
        # NOTE: MPS CPU fallback for unsupported ops is enabled through the
        # PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable (set it before the
        # process starts); there is no torch attribute to toggle it here.

    def _optimize_model(self):
        """Best-effort model optimization; leaves the eager model on failure."""
        try:
            # torch.compile is skipped on MPS, where it was treated as unsupported.
            if hasattr(torch, "compile") and self.device != "mps":
                # NOTE: input lengths vary between calls, which can trigger
                # recompilation for new shapes.
                self.model = torch.compile(self.model, mode="reduce-overhead")
                print("Model compiled with torch.compile for faster inference")
            else:
                try:
                    scripted_model = torch.jit.script(self.model)
                    if hasattr(torch.jit, "optimize_for_inference"):
                        scripted_model = torch.jit.optimize_for_inference(scripted_model)
                    self.model = scripted_model
                    print("Model optimized with JIT scripting")
                except Exception as jit_e:
                    print(f"JIT optimization failed, using regular model: {jit_e}")
        except Exception as e:
            print(f"Model optimization failed, using regular model: {e}")

    def _warmup_model(self):
        """Run one second of silence through the real inference path.

        Triggers lazy compilation up front. If it fails, revert to the eager
        model so later calls do not hit the same failure.
        """
        try:
            self._forward_logits(np.zeros(TARGET_SAMPLE_RATE, dtype=np.float32))
            print("Model warmed up successfully")
        except Exception as e:
            print(f"Warmup failed: {e}")
            eager = getattr(self, "_eager_model", None)
            if eager is not None and self.model is not eager:
                self.model = eager
                print("Falling back to the unoptimized model")

    def _forward_logits(self, audio_buffer):
        """Preprocess one waveform and return the model's CTC logits tensor."""
        audio = np.asarray(audio_buffer, dtype=np.float32)
        inputs = self.processor(
            audio,
            sampling_rate=TARGET_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        input_values = inputs.input_values.to(self.device, non_blocking=True)
        attention_mask = (
            inputs.attention_mask.to(self.device, non_blocking=True)
            if "attention_mask" in inputs
            else None
        )

        # Mixed precision only on CUDA; MPS and CPU run in full precision.
        autocast = (
            torch.autocast(device_type="cuda")
            if self.device == "cuda"
            else contextlib.nullcontext()
        )
        with torch.no_grad(), autocast:
            if attention_mask is not None:
                return self.model(input_values, attention_mask=attention_mask).logits
            return self.model(input_values).logits

    def buffer_to_text(self, audio_buffer):
        """Transcribe a mono waveform sampled at TARGET_SAMPLE_RATE.

        Args:
            audio_buffer: 1-D NumPy array or sequence of samples.

        Returns:
            Lower-cased, stripped transcription; ``""`` for an empty buffer.
        """
        if len(audio_buffer) == 0:
            return ""

        logits = self._forward_logits(audio_buffer)

        if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
            # The LM decoder works on NumPy arrays; upcast in case autocast gave fp16.
            logits_cpu = logits[0].float().cpu().numpy()
            transcription = self.processor.decode(
                logits_cpu,
                hotwords=self.hotwords,
            ).text
        else:
            # Greedy CTC decoding: argmax on the device, then decode on CPU.
            predicted_ids = torch.argmax(logits, dim=-1).cpu()
            transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription.lower().strip()

    def confidence_score(self, logits, predicted_ids):
        """Mean softmax probability of predicted tokens, excluding delimiter and pad.

        Args:
            logits: Model logits indexed as (batch, frames, vocab).
            predicted_ids: Predicted token ids indexed as (batch, frames).

        Returns:
            A 0-d tensor; ``0.0`` when every predicted token is a delimiter or pad
            token (this previously produced NaN from 0/0).
        """
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        tokenizer = self.processor.tokenizer
        mask = torch.logical_and(
            predicted_ids.not_equal(tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(tokenizer.pad_token_id),
        )
        character_scores = pred_scores.masked_select(mask)
        if character_scores.numel() == 0:
            return torch.zeros((), dtype=pred_scores.dtype, device=pred_scores.device)
        return torch.sum(character_scores) / character_scores.numel()

    def file_to_text(self, filename):
        """Load and transcribe an audio file; ``""`` if the file cannot be loaded."""
        audio_input = _load_audio(filename)
        if audio_input is None:
            return ""
        return self.buffer_to_text(audio_input)


class Wave2Vec2ONNXInference:
    """Transcribe audio with a wav2vec2 model exported to ONNX (greedy CTC decoding)."""

    def __init__(self, model_name, onnx_path):
        """Load the processor for ``model_name`` and an ONNX Runtime session for ``onnx_path``."""
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)

        options = rt.SessionOptions()
        options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.execution_mode = rt.ExecutionMode.ORT_PARALLEL
        # 0 lets ONNX Runtime choose its default thread counts.
        options.inter_op_num_threads = 0
        options.intra_op_num_threads = 0

        # Only request CUDA when this onnxruntime build actually provides it.
        providers = ["CPUExecutionProvider"]
        if "CUDAExecutionProvider" in rt.get_available_providers():
            providers.insert(0, "CUDAExecutionProvider")

        self.model = rt.InferenceSession(onnx_path, options, providers=providers)
        # Looked up once; the exported graph has a single audio input.
        self.input_name = self.model.get_inputs()[0].name
        print(f"ONNX model loaded with providers: {self.model.get_providers()}")

    def buffer_to_text(self, audio_buffer):
        """Transcribe a mono waveform sampled at TARGET_SAMPLE_RATE.

        Returns the lower-cased, stripped transcription; ``""`` for an empty buffer.
        """
        if len(audio_buffer) == 0:
            return ""

        inputs = self.processor(
            np.asarray(audio_buffer, dtype=np.float32),
            sampling_rate=TARGET_SAMPLE_RATE,
            return_tensors="np",
            padding=True,
        )
        input_values = inputs.input_values.astype(np.float32)
        logits = self.model.run(None, {self.input_name: input_values})[0]

        # Index the single batch item explicitly; squeeze() would also collapse
        # a length-1 time axis and hand decode() a scalar.
        predicted_ids = np.argmax(logits, axis=-1)
        transcription = self.processor.decode(predicted_ids[0].tolist())
        return transcription.lower().strip()

    def file_to_text(self, filename):
        """Load and transcribe an audio file; ``""`` if the file cannot be loaded."""
        audio_input = _load_audio(filename)
        if audio_input is None:
            return ""
        return self.buffer_to_text(audio_input)


class OptimizedWave2Vec2Factory:
    """Factory class to create the most optimized Wave2Vec2 inference instance"""

    @staticmethod
    def create_optimized_inference(model_name, onnx_path=None, safe_mode=False, **kwargs):
        """
        Create the most optimized inference instance based on available resources

        Args:
            model_name: HuggingFace model name
            onnx_path: Path to ONNX model (optional, for maximum speed). When the
                file exists, the ONNX backend is used and ``safe_mode``/``kwargs``
                are ignored.
            safe_mode: If True, disable aggressive optimizations that might cause issues
            **kwargs: Additional arguments for Wave2Vec2Inference

        Returns:
            Optimized inference instance
        """
        if onnx_path and os.path.exists(onnx_path):
            print("Using ONNX model for maximum speed")
            return Wave2Vec2ONNXInference(model_name, onnx_path)

        print("Using PyTorch model with optimizations")
        if safe_mode:
            kwargs["enable_optimizations"] = False
            print("Running in safe mode - optimizations disabled")
        return Wave2Vec2Inference(model_name, **kwargs)

    @staticmethod
    def create_safe_inference(model_name, **kwargs):
        """Create a safe inference instance without aggressive optimizations"""
        kwargs["enable_optimizations"] = False
        return Wave2Vec2Inference(model_name, **kwargs)


# took that script from: https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py
def convert_to_onnx(model_id_or_path, onnx_model_name, audio_len=250000):
    """Export a Wav2Vec2ForCTC model to ONNX with a dynamic audio-length axis.

    Args:
        model_id_or_path: Hugging Face model id or local path.
        onnx_model_name: Output file path.
        audio_len: Number of samples in the dummy tracing input.
    """
    print(f"Converting {model_id_or_path} to onnx")
    model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
    model.eval()
    x = torch.randn(1, audio_len)  # tracing input; no gradients needed
    torch.onnx.export(
        model,  # model being run
        x,  # model input (or a tuple for multiple inputs)
        onnx_model_name,  # where to save the model (can be a file or file-like object)
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=14,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=["input"],  # the model's input names
        output_names=["output"],  # the model's output names
        dynamic_axes={
            "input": {1: "audio_len"},  # variable number of samples
            # Distinct symbol: output frames are not equal to input samples.
            "output": {1: "frames"},
        },
    )


def quantize_onnx_model(onnx_model_path, quantized_model_path):
    """Dynamically quantize an ONNX model's weights to unsigned 8-bit integers."""
    print("Starting quantization...")
    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantize_dynamic(
        onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
    )
    print(f"Quantized model saved to: {quantized_model_path}")


def export_to_onnx(
    model: str = "facebook/wav2vec2-large-960h-lv60-self", quantize: bool = False
):
    """Export ``model`` to ``<name>.onnx`` and optionally ``<name>.quant.onnx``."""
    onnx_model_name = model.split("/")[-1] + ".onnx"
    convert_to_onnx(model, onnx_model_name)
    if quantize:
        quantized_model_name = model.split("/")[-1] + ".quant.onnx"
        quantize_onnx_model(onnx_model_name, quantized_model_name)


def main():
    """Benchmark transcription of ``test.wav`` with the best available backend."""
    test_file = "test.wav"
    # Check the input first so we don't load a large model only to exit.
    if not os.path.exists(test_file):
        print(f"Test file {test_file} not found. Please provide a valid audio file.")
        sys.exit(1)

    asr = OptimizedWave2Vec2Factory.create_optimized_inference(
        "facebook/wav2vec2-large-960h-lv60-self"
    )

    # Warm up runs (model already warmed up during initialization)
    print("Running additional warm-up...")
    for i in range(2):
        asr.file_to_text(test_file)
        print(f"Warm up {i+1} completed")

    print("Running optimized performance tests...")
    times = []
    for i in range(10):
        start_time = time.perf_counter()  # monotonic, high-resolution timer
        text = asr.file_to_text(test_file)
        execution_time = time.perf_counter() - start_time
        times.append(execution_time)
        print(f"Test {i+1}: {execution_time:.3f}s - {text}")

    average_time = sum(times) / len(times)
    min_time = min(times)
    max_time = max(times)
    std_time = np.std(times)

    print("\n=== Performance Statistics ===")
    print(f"Average execution time: {average_time:.3f}s")
    print(f"Min time: {min_time:.3f}s")
    print(f"Max time: {max_time:.3f}s")
    print(f"Standard deviation: {std_time:.3f}s")
    # (max - min) / max is run-to-run spread, not a speed-up over any baseline.
    print(f"Spread (max vs min): {((max_time - min_time) / max_time * 100):.1f}% of max time")
    print(f"Average throughput: {1.0 / average_time:.2f} inferences/second")


if __name__ == "__main__":
    main()