import os
import logging

import numpy as np
import torch
import librosa
import soundfile as sf
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    pipeline,
)

logger = logging.getLogger(__name__)


class InferenceRecipe:
    """Local ASR -> chat -> TTS voice-reply pipeline."""

    def __init__(self, model_path='./models', device='cuda'):
        self.device = device
        self.model_path = model_path
        # Populated by initialize_models().
        self.asr_processor = None
        self.asr_model = None
        self.chat_tokenizer = None
        self.chat_model = None
        self.tts_model = None
        # Fallback output rate; overwritten with the rate reported by the TTS
        # pipeline at synthesis time.
        self.tts_sample_rate = 22050
        self.initialize_models()

    def initialize_models(self):
        """Load the ASR, chat, and TTS models from the local cache."""
        # Speech recognition (Whisper).
        asr_path = os.path.join(self.model_path, 'asr')
        logger.info(f"Loading ASR model from {asr_path}")
        self.asr_processor = WhisperProcessor.from_pretrained(asr_path, local_files_only=True)
        self.asr_model = WhisperForConditionalGeneration.from_pretrained(asr_path, local_files_only=True)
        self.asr_model = self.asr_model.to(self.device)

        # Pin Whisper to English transcription without timestamp tokens.
        self.asr_model.generation_config.no_timestamps_token_id = (
            self.asr_processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>")
        )
        self.asr_model.config.forced_decoder_ids = self.asr_processor.get_decoder_prompt_ids(
            language="english", task="transcribe"
        )

        # Response generation (causal LM).
        llm_path = os.path.join(self.model_path, "llm")
        logger.info(f"Loading chat model from {llm_path}")
        self.chat_tokenizer = AutoTokenizer.from_pretrained(llm_path, local_files_only=True)
        self.chat_model = AutoModelForCausalLM.from_pretrained(llm_path, local_files_only=True)
        self.chat_model = self.chat_model.to(self.device)

        # Speech synthesis via the transformers text-to-speech pipeline.
        tts_path = os.path.join(self.model_path, "tts")
        logger.info(f"Loading TTS model from {tts_path}")
        self.tts_model = pipeline(
            "text-to-speech",
            model=tts_path,
            device=self.device,
            torch_dtype=torch.float32,
        )
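
    # The loaders above assume a pre-populated local layout, inferred from the
    # paths in initialize_models: a Whisper checkpoint under ./models/asr, a
    # causal LM plus tokenizer (e.g. DialoGPT) under ./models/llm, and a
    # text-to-speech checkpoint under ./models/tts. A hedged sketch of
    # populating the ASR directory ahead of time ("openai/whisper-small" is an
    # example model id, not taken from this file):
    #
    #   WhisperProcessor.from_pretrained("openai/whisper-small").save_pretrained("./models/asr")
    #   WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").save_pretrained("./models/asr")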

    def inference(self, audio_array, sample_rate):
        """Run ASR -> chat -> TTS on input audio and return the spoken reply."""
        logger.info(f"Running inference with audio shape: {audio_array.shape}")
        if len(audio_array.shape) == 2:
            audio_array = audio_array.squeeze()

        # --- ASR ---
        logger.info(f"Running ASR with audio shape: {audio_array.shape}")
        # Whisper's feature extractor expects 16 kHz input; resample if needed.
        if sample_rate != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
        input_features = self.asr_processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.to(self.device)

        generated_ids = self.asr_model.generate(input_features)
        text = self.asr_processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0]

        # --- Chat ---
        logger.info(f"Running chat with text: {text}")
        input_ids = self.chat_tokenizer.encode(text + self.chat_tokenizer.eos_token, return_tensors="pt")
        attention_mask = torch.ones_like(input_ids)
        chat_output = self.chat_model.generate(
            input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
            max_length=1000,
            pad_token_id=self.chat_tokenizer.eos_token_id
        )
        # Keep only the newly generated tokens, not the echoed prompt.
        reply = self.chat_tokenizer.decode(chat_output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

        # --- TTS ---
        logger.info(f"Running TTS with text: {reply}")
        tts_output = self.tts_model(reply)
        audio_array = tts_output["audio"]
        # Prefer the rate reported by the pipeline over the hardcoded default.
        self.tts_sample_rate = tts_output.get("sampling_rate", self.tts_sample_rate)

        # Flatten before any resampling: scipy.signal.resample operates on the
        # first axis, so a (1, n) array must be squeezed to (n,) first.
        if len(audio_array.shape) > 1:
            audio_array = audio_array.squeeze()

        logger.info("Ensuring audio is in the correct format")
        audio_array = audio_array.astype(np.float32)
        audio_array = np.clip(audio_array, -1.0, 1.0)

        # Resample the synthesized reply back to the caller's sample rate.
        if sample_rate != self.tts_sample_rate:
            logger.info("Resampling audio to match the input rate")
            from scipy import signal
            new_samples = int(len(audio_array) * sample_rate / self.tts_sample_rate)
            audio_array = signal.resample(audio_array, new_samples)

        return {"audio": audio_array, "text": reply}


if __name__ == "__main__":
    recipe = InferenceRecipe(model_path="./models")

    # Smoke test: 3 seconds of silence at 16 kHz.
    sr = 16000
    duration = 3
    audio = np.zeros(int(sr * duration), dtype=np.float32)
    response = recipe.inference(audio, sr)

    print(f"Audio shape: {response['audio'].shape}, Range: [{response['audio'].min()}, {response['audio'].max()}]")
    print(f"Generated text: {response['text']}")

    sf.write(
        "response.wav",
        response['audio'],
        sr,
        format='WAV',
        subtype='FLOAT'
    )
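
    # Optionally exercise the pipeline on a real recording instead of silence.
    # A minimal sketch, assuming a mono "input.wav" next to this script (the
    # filename is an example, not part of the original recipe):
    if os.path.exists("input.wav"):
        real_audio, real_sr = sf.read("input.wav", dtype="float32")
        real_response = recipe.inference(real_audio, real_sr)
        sf.write("response_real.wav", real_response["audio"], real_sr, format="WAV", subtype="FLOAT")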