sayest / utils.py
Aleksei Žavoronkov
enhance run_inference to support token masking and adjust logits handling
ed26f9c
from typing import List
import torch
import torchaudio
from transformers import Wav2Vec2Processor, QuantoConfig
from constants import MAX_AUDIO_DURATION_SECONDS, MONO_CHANNEL, SAMPLING_RATE
from gop_model import GOPPhonemeClassifier
import logging
logger = logging.getLogger(__name__)
def load_model_and_processor(model_repo_id: str):
    """Load the INT8-quantized GOP phoneme classifier and its processor.

    Args:
        model_repo_id: Hugging Face Hub repository id (e.g. "org/model").

    Returns:
        Tuple of (model, processor) where the model is already in eval mode
        and device-mapped automatically.
    """
    # Lazy %-style logging args: the string is only formatted if the
    # INFO level is actually enabled.
    logger.info("Loading model and processor from Hugging Face Hub: %s", model_repo_id)
    quantization_config = QuantoConfig(weights="int8")
    logger.info("Applying INT8 dynamic quantization during model loading...")
    model = GOPPhonemeClassifier.from_pretrained(
        model_repo_id,
        quantization_config=quantization_config,
        device_map="auto",
    )
    processor = Wav2Vec2Processor.from_pretrained(model_repo_id)
    model.eval()  # disable dropout/batch-norm updates for inference
    return model, processor
def validate_phonemes(phoneme_text, allowed_phonemes):
    """Validate a whitespace-separated phoneme string against an allowed set.

    Returns an HTML error snippet (str) describing the first problem found,
    or None when every phoneme is allowed.
    """
    stripped = phoneme_text.strip()
    if not stripped:
        return "<p style='text-align:center; color:red;'>Please enter the phonemes.</p>"
    # Find the first phoneme not present in the allowed collection, if any.
    invalid = next((p for p in stripped.split() if p not in allowed_phonemes), None)
    if invalid is not None:
        return f"<p style='text-align:center; color:red;'>Invalid phoneme: '{invalid}'. Please check your input.</p>"
    return None
def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassifier, processor: Wav2Vec2Processor):
    """Score the phonemes of `transcript` against the audio recording.

    Loads and normalizes the audio (mono, SAMPLING_RATE), tokenizes the
    transcript phonemes, and runs the GOP model under no_grad.

    Returns:
        On success, a tuple (predicted_scores, tokens, token_lengths).
        On any validation or runtime failure, an HTML-formatted error string.
    """
    if not audio_file_path or not transcript:
        return "<p style='text-align:center; color:red;'>Please provide both an audio file and the transcript.</p>"
    try:
        # waveform has shape (channels, samples); original_sr is the file's rate.
        waveform, original_sr = torchaudio.load(audio_file_path)
        duration_seconds = waveform.shape[1] / original_sr
        if duration_seconds > MAX_AUDIO_DURATION_SECONDS:
            # Raised here so the boundary handler below converts it to HTML.
            raise ValueError(f"The audio recording should not be longer than {MAX_AUDIO_DURATION_SECONDS} seconds.")
        # Downmix multi-channel audio to mono by averaging across channels.
        if waveform.shape[0] > MONO_CHANNEL:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Resample to the rate the model/processor expects, if needed.
        if original_sr != SAMPLING_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=SAMPLING_RATE)
            waveform = resampler(waveform)
        audio_input = waveform.squeeze(0)  # drop channel dim -> 1-D samples
        processed_audio = processor(audio_input, sampling_rate=SAMPLING_RATE, return_tensors="pt", padding=True)
        input_values = processed_audio.input_values.to(model.device)
        # NOTE(review): some Wav2Vec2 feature extractors only emit attention_mask
        # when return_attention_mask=True — confirm the processor config supplies it.
        attention_mask = processed_audio.attention_mask.to(model.device)
        # Transcript is treated as whitespace-separated phoneme tokens.
        phonemes: List[str] = transcript.strip().split()
        tokenizer = processor.tokenizer
        unk_id = getattr(tokenizer, "unk_token_id", None)
        ids = tokenizer.convert_tokens_to_ids(phonemes)
        # A single token can come back as a bare int; normalize to a list.
        if isinstance(ids, int):
            ids = [ids]
        # Map unresolved tokens (None) to the unk id so the tensor stays well-formed.
        ids = [i if i is not None else unk_id for i in ids]
        canonical_token_ids = torch.tensor([ids], dtype=torch.long).to(model.device)  # (1, num_tokens)
        token_lengths = torch.tensor([len(ids)], dtype=torch.long).to(model.device)
        # All positions are real tokens (batch of one, no padding), so the mask is all ones.
        token_mask = torch.ones_like(canonical_token_ids).to(model.device)
        with torch.no_grad():
            outputs = model(
                input_values=input_values,
                attention_mask=attention_mask,
                canonical_token_ids=canonical_token_ids,
                token_lengths=token_lengths,
                token_mask=token_mask
            )
        # assumes outputs.logits is a mapping of head name -> score tensor;
        # only the first head is consumed — TODO confirm head ordering is intended.
        logits = outputs.logits
        head_name = next(iter(logits))
        scores_tensor = logits[head_name]
        # argmax over the class dim yields one discrete score per phoneme position.
        predicted_scores = torch.argmax(scores_tensor, dim=-1)
        tokens = processor.tokenizer.convert_ids_to_tokens(canonical_token_ids[0])
        return predicted_scores, tokens, token_lengths
    except Exception as e:
        # Boundary handler: log with traceback, surface an HTML error to the UI.
        logger.error(f"An error occurred during inference: {e}", exc_info=True)
        return f"<p style='text-align:center; color:red;'>An error occurred: {e}</p>"