Aleksei Žavoronkov
enhance run_inference to support token masking and adjust logits handling
ed26f9c
from typing import List
import torch
import torchaudio
from transformers import Wav2Vec2Processor, QuantoConfig
from constants import MAX_AUDIO_DURATION_SECONDS, MONO_CHANNEL, SAMPLING_RATE
from gop_model import GOPPhonemeClassifier
import logging

logger = logging.getLogger(__name__)
def load_model_and_processor(model_repo_id: str):
    """Load the GOP phoneme classifier and its Wav2Vec2 processor from the Hub.

    Applies INT8 weight quantization (Quanto) at load time, auto-places the
    model on available devices, and puts it in eval mode.

    Args:
        model_repo_id: Hugging Face Hub repository id hosting both the model
            weights and the processor files.

    Returns:
        Tuple of (model, processor).
    """
    # Lazy %-style args: the message is only formatted if this level is enabled.
    logger.info("Loading model and processor from Hugging Face Hub: %s", model_repo_id)
    quantization_config = QuantoConfig(weights="int8")
    logger.info("Applying INT8 dynamic quantization during model loading...")
    model = GOPPhonemeClassifier.from_pretrained(
        model_repo_id,
        quantization_config=quantization_config,
        device_map="auto",
    )
    processor = Wav2Vec2Processor.from_pretrained(model_repo_id)
    model.eval()  # inference only: disables dropout and other train-time behavior
    return model, processor
def validate_phonemes(phoneme_text, allowed_phonemes):
    """Validate a whitespace-separated phoneme string against an allowed set.

    Returns an HTML error message (red, centered) when the input is empty or
    contains an unknown phoneme, or None when every phoneme is allowed.
    """
    if not phoneme_text.strip():
        return "<p style='text-align:center; color:red;'>Please enter the phonemes.</p>"
    # Find the first phoneme outside the allowed set, if any.
    offending = next(
        (p for p in phoneme_text.strip().split() if p not in allowed_phonemes),
        None,
    )
    if offending is not None:
        return f"<p style='text-align:center; color:red;'>Invalid phoneme: '{offending}'. Please check your input.</p>"
    return None
def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassifier, processor: Wav2Vec2Processor):
    """Score a recording against its canonical phoneme transcript.

    Loads and normalizes the audio (mono, resampled to SAMPLING_RATE),
    tokenizes the transcript phonemes, and runs the GOP classifier.

    Args:
        audio_file_path: Path to the audio file to evaluate.
        transcript: Whitespace-separated canonical phoneme sequence.
        model: Quantized GOP classifier, already in eval mode.
        processor: Wav2Vec2 processor supplying feature extractor + tokenizer.

    Returns:
        (predicted_scores, tokens, token_lengths) on success, or an HTML
        error string on invalid input / failure (the caller renders it).
    """
    # Robustness: also reject whitespace-only transcripts, which previously
    # slipped past the guard and produced an empty token tensor downstream.
    if not audio_file_path or not transcript or not transcript.strip():
        return "<p style='text-align:center; color:red;'>Please provide both an audio file and the transcript.</p>"
    try:
        waveform, original_sr = torchaudio.load(audio_file_path)
        duration_seconds = waveform.shape[1] / original_sr
        if duration_seconds > MAX_AUDIO_DURATION_SECONDS:
            raise ValueError(f"The audio recording should not be longer than {MAX_AUDIO_DURATION_SECONDS} seconds.")
        # Downmix multi-channel audio to mono by averaging channels.
        if waveform.shape[0] > MONO_CHANNEL:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if original_sr != SAMPLING_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=SAMPLING_RATE)
            waveform = resampler(waveform)
        audio_input = waveform.squeeze(0)
        processed_audio = processor(audio_input, sampling_rate=SAMPLING_RATE, return_tensors="pt", padding=True)
        device = model.device
        input_values = processed_audio.input_values.to(device)
        # NOTE(review): assumes the feature extractor is configured to return
        # an attention mask; some base wav2vec2 configs do not — confirm.
        attention_mask = processed_audio.attention_mask.to(device)
        phonemes: List[str] = transcript.strip().split()
        tokenizer = processor.tokenizer
        unk_id = getattr(tokenizer, "unk_token_id", None)
        ids = tokenizer.convert_tokens_to_ids(phonemes)
        if isinstance(ids, int):  # single-token transcripts can come back as a bare int
            ids = [ids]
        # Some tokenizers return None for unknown tokens; map those to <unk>.
        ids = [i if i is not None else unk_id for i in ids]
        canonical_token_ids = torch.tensor([ids], dtype=torch.long, device=device)
        token_lengths = torch.tensor([len(ids)], dtype=torch.long, device=device)
        # ones_like inherits device and dtype, so no extra .to() is needed.
        token_mask = torch.ones_like(canonical_token_ids)
        with torch.no_grad():
            outputs = model(
                input_values=input_values,
                attention_mask=attention_mask,
                canonical_token_ids=canonical_token_ids,
                token_lengths=token_lengths,
                token_mask=token_mask,
            )
        logits = outputs.logits
        # logits is a mapping of head name -> tensor; use the first
        # (presumably only) head — TODO confirm against the model class.
        head_name = next(iter(logits))
        scores_tensor = logits[head_name]
        predicted_scores = torch.argmax(scores_tensor, dim=-1)
        tokens = tokenizer.convert_ids_to_tokens(canonical_token_ids[0])
        return predicted_scores, tokens, token_lengths
    except Exception as e:
        # Boundary handler: log full traceback, surface a user-facing message.
        logger.error("An error occurred during inference: %s", e, exc_info=True)
        return f"<p style='text-align:center; color:red;'>An error occurred: {e}</p>"