import logging
from typing import List

import torch
import torchaudio
from transformers import Wav2Vec2Processor, QuantoConfig

from constants import MAX_AUDIO_DURATION_SECONDS, MONO_CHANNEL, SAMPLING_RATE
from gop_model import GOPPhonemeClassifier

logger = logging.getLogger(__name__)


def load_model_and_processor(model_repo_id: str):
    logger.info(f"Loading model and processor from Hugging Face Hub: {model_repo_id}")

    # Quantize the model weights to INT8 at load time to reduce memory usage.
    quantization_config = QuantoConfig(weights="int8")
    logger.info("Applying INT8 dynamic quantization during model loading...")
    model = GOPPhonemeClassifier.from_pretrained(
        model_repo_id,
        quantization_config=quantization_config,
        device_map="auto"
    )
    processor = Wav2Vec2Processor.from_pretrained(model_repo_id)
    model.eval()
    return model, processor


def validate_phonemes(phoneme_text, allowed_phonemes):
    # Returns an error message string if the input is invalid, otherwise None.
    if not phoneme_text.strip():
        return "Please enter the phonemes."
    phonemes = phoneme_text.strip().split()
    for phoneme in phonemes:
        if phoneme not in allowed_phonemes:
            return f"Invalid phoneme: '{phoneme}'. Please check your input."
    return None


def run_inference(audio_file_path: str, transcript: str,
                  model: GOPPhonemeClassifier, processor: Wav2Vec2Processor):
    if not audio_file_path or not transcript:
        return "Please provide both an audio file and the transcript."
    try:
        # Load the audio and enforce the maximum supported duration.
        waveform, original_sr = torchaudio.load(audio_file_path)
        duration_seconds = waveform.shape[1] / original_sr
        if duration_seconds > MAX_AUDIO_DURATION_SECONDS:
            raise ValueError(
                f"The audio recording should not be longer than {MAX_AUDIO_DURATION_SECONDS} seconds."
            )

        # Downmix to mono and resample to the model's expected sampling rate.
        if waveform.shape[0] > MONO_CHANNEL:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if original_sr != SAMPLING_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=SAMPLING_RATE)
            waveform = resampler(waveform)

        audio_input = waveform.squeeze(0)
        processed_audio = processor(
            audio_input, sampling_rate=SAMPLING_RATE, return_tensors="pt", padding=True
        )
        input_values = processed_audio.input_values.to(model.device)
        attention_mask = processed_audio.attention_mask.to(model.device)

        # Map the canonical phoneme transcript to token ids, falling back to the
        # unknown token for phonemes the tokenizer does not recognize.
        phonemes: List[str] = transcript.strip().split()
        tokenizer = processor.tokenizer
        unk_id = getattr(tokenizer, "unk_token_id", None)
        ids = tokenizer.convert_tokens_to_ids(phonemes)
        if isinstance(ids, int):
            ids = [ids]
        ids = [i if i is not None else unk_id for i in ids]

        canonical_token_ids = torch.tensor([ids], dtype=torch.long).to(model.device)
        token_lengths = torch.tensor([len(ids)], dtype=torch.long).to(model.device)
        token_mask = torch.ones_like(canonical_token_ids).to(model.device)

        with torch.no_grad():
            outputs = model(
                input_values=input_values,
                attention_mask=attention_mask,
                canonical_token_ids=canonical_token_ids,
                token_lengths=token_lengths,
                token_mask=token_mask
            )

        # The model exposes its per-phoneme scores under a named head; use the first one.
        logits = outputs.logits
        head_name = next(iter(logits))
        scores_tensor = logits[head_name]
        predicted_scores = torch.argmax(scores_tensor, dim=-1)
        tokens = processor.tokenizer.convert_ids_to_tokens(canonical_token_ids[0])
        return predicted_scores, tokens, token_lengths
    except Exception as e:
        logger.error(f"An error occurred during inference: {e}", exc_info=True)
        return f"An error occurred: {e}"