File size: 6,783 Bytes
8cf218e
 
 
 
ebd2e0f
6c0fcb4
 
2420a56
8cf218e
4d77193
6c0fcb4
a53ba49
 
6c0fcb4
 
 
8cf218e
6c0fcb4
 
 
8cf218e
6c0fcb4
 
 
 
 
 
4d77193
 
 
 
 
 
 
 
 
 
 
8cf218e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f41c01
 
8cf218e
 
 
0f41c01
8cf218e
 
0f41c01
 
 
 
 
8cf218e
 
 
 
 
 
0f41c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e03eae6
6c0fcb4
 
 
ebd2e0f
 
6c0fcb4
 
 
 
 
 
 
 
 
 
 
 
 
e03eae6
8cf218e
 
 
 
 
 
 
 
6c0fcb4
 
 
 
 
 
 
8cf218e
6c0fcb4
 
0f41c01
 
 
 
 
 
 
6c0fcb4
8cf218e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import html
import logging
from typing import Dict, List, Tuple

import soundfile as sf
import torch
import torchaudio
from transformers import Wav2Vec2Processor

from constants import MAX_AUDIO_DURATION_SECONDS, MONO_CHANNEL, SAMPLING_RATE
from gop_model import GOPPhonemeClassifier

# Module-level logger used by the loading and inference helpers in this file.
logger = logging.getLogger(__name__)


def load_model_and_processor(model_repo_id: str):
    """Load the GOP phoneme classifier and its Wav2Vec2 processor.

    Both objects are loaded via ``from_pretrained`` from the same repository id.
    The model is loaded with ``device_map="auto"`` and put into evaluation
    mode before it is returned.

    Args:
        model_repo_id: Hugging Face Hub repository id (or local path) shared by
            the model and the processor.

    Returns:
        A ``(model, processor)`` tuple.
    """
    logger.info("Loading model and processor from Hugging Face Hub: %s", model_repo_id)

    loaded_model = GOPPhonemeClassifier.from_pretrained(model_repo_id, device_map="auto")
    # This module only runs inference, so switch the model to evaluation mode right away.
    loaded_model.eval()

    loaded_processor = Wav2Vec2Processor.from_pretrained(model_repo_id)
    return loaded_model, loaded_processor


def validate_phonemes(phoneme_text, allowed_phonemes):
    """Check a whitespace-separated phoneme string against an allowed inventory.

    Args:
        phoneme_text: User-entered phonemes separated by whitespace. ``None`` is
            treated the same as an empty string.
        allowed_phonemes: Container (supporting ``in``) of valid phoneme symbols.

    Returns:
        ``None`` when the input is non-empty and every phoneme is allowed.
        Otherwise an HTML ``<p>`` error message suitable for direct display.
    """
    # split() with no argument already discards surrounding whitespace, so an
    # empty list means the input was empty or whitespace-only.
    phonemes = (phoneme_text or "").split()
    if not phonemes:
        return "<p style='text-align:center; color:red;'>Please enter the phonemes.</p>"

    for phoneme in phonemes:
        if phoneme not in allowed_phonemes:
            # The phoneme is raw user input that ends up inside HTML; escape it
            # so markup such as "<script>" is shown as text instead of rendered.
            safe_phoneme = html.escape(phoneme)
            return f"<p style='text-align:center; color:red;'>Invalid phoneme: '{safe_phoneme}'. Please check your input.</p>"
    return None


def _prepare_canonical_tokens(transcript: str, processor: Wav2Vec2Processor, device: torch.device):
    phonemes: List[str] = transcript.strip().split()
    if not phonemes:
        raise ValueError("Please enter at least one phoneme.")

    token_mask_values = [token != "|" for token in phonemes]
    if not any(token_mask_values):
        raise ValueError("The phoneme sequence must contain at least one non-boundary token.")

    tokenizer = processor.tokenizer
    unk_id = getattr(tokenizer, "unk_token_id", None)
    ids = tokenizer.convert_tokens_to_ids(phonemes)
    if isinstance(ids, int):
        ids = [ids]
    ids = [token_id if token_id is not None else unk_id for token_id in ids]

    canonical_token_ids = torch.tensor([ids], dtype=torch.long, device=device)
    token_lengths = torch.tensor([len(ids)], dtype=torch.long, device=device)
    token_mask = torch.tensor([token_mask_values], dtype=torch.bool, device=device)

    display_tokens = [token for token, is_active in zip(phonemes, token_mask_values) if is_active]
    return canonical_token_ids, token_lengths, token_mask, display_tokens


def _extract_head_predictions(
    logits_by_head: Dict[str, torch.Tensor],
    token_mask: torch.Tensor,
    display_tokens: List[str],
    deltas: Dict[str, float] | None = None,
    correct_index: int = 0,
) -> Dict[str, Tuple[List[int], List[str]]]:
    """Decode every head's logits into predictions for the active token positions.

    Args:
        logits_by_head: Mapping from head name to that head's score tensor.
            Only batch item 0 of each head's predictions is used.
        token_mask: Mask whose first row (cast to bool) selects which token
            positions are kept.
        display_tokens: Token strings returned alongside every head's scores.
        deltas: Optional per-head margin forwarded to ``_predict_scores``.
            Heads missing from the mapping get ``None``.
        correct_index: Class index forwarded to ``_predict_scores``.

    Returns:
        Dict (in ``logits_by_head`` iteration order) mapping each head name to
        ``(kept predictions as a Python list, display_tokens)``.
    """
    keep_positions = token_mask[0].bool()
    delta_lookup = deltas if deltas else {}

    def _kept_predictions(name: str, logits: torch.Tensor) -> List[int]:
        # Take batch item 0 of the predictions, then keep only the active positions.
        batch_predictions = _predict_scores(
            logits,
            delta=delta_lookup.get(name),
            correct_index=correct_index,
        )
        return batch_predictions[0][keep_positions].detach().cpu().tolist()

    return {
        name: (_kept_predictions(name, logits), display_tokens)
        for name, logits in logits_by_head.items()
    }


def parse_delta_value(value):
    """Parse a user-supplied delta margin into a strictly positive float.

    Args:
        value: Raw value (e.g. from a UI field); anything ``float()`` accepts.

    Returns:
        The value as a ``float`` when it is strictly positive. Otherwise
        ``None``, which covers:

        - ``None`` and ``""``;
        - zero, negatives and NaN;
        - unparseable input, which is also logged as a warning.
    """
    if value is None or value == "":
        return None
    try:
        delta = float(value)
    except (TypeError, ValueError):
        logger.warning("Invalid delta value %r; ignoring.", value)
        return None
    # Written as ``not (delta > 0)`` rather than ``delta <= 0`` so that NaN,
    # for which every comparison is False, is rejected as well.
    if not (delta > 0):
        return None
    return delta


def _predict_scores(scores_tensor, delta=None, correct_index=0):
    num_classes = scores_tensor.size(-1)
    use_delta = delta is not None and delta > 0 and 0 <= correct_index < num_classes

    if use_delta:
        probs = torch.softmax(scores_tensor, dim=-1)
        correct_probs = probs[..., correct_index]
        incorrect_probs = probs.clone()
        incorrect_probs[..., correct_index] = -float("inf")
        max_incorrect_probs, _ = incorrect_probs.max(dim=-1)
        argmax_scores = probs.argmax(dim=-1)
        within_delta = (max_incorrect_probs > correct_probs) & (
            (max_incorrect_probs - correct_probs) <= delta
        )
        predicted_scores = torch.where(
            within_delta,
            torch.tensor(correct_index, device=scores_tensor.device),
            argmax_scores,
        )
    else:
        if delta is not None and (correct_index < 0 or correct_index >= num_classes):
            logger.warning(
                "Delta provided but correct_index=%s is out of range for %s classes.",
                correct_index,
                num_classes,
            )
        predicted_scores = torch.argmax(scores_tensor, dim=-1)

    return predicted_scores


def _load_waveform(audio_file_path: str) -> torch.Tensor:
    """Read an audio file into a ``(channels, samples)`` float32 tensor at ``SAMPLING_RATE``.

    Channels are averaged into one when there are more than ``MONO_CHANNEL``.

    Raises:
        ValueError: If the recording is longer than ``MAX_AUDIO_DURATION_SECONDS``.
    """
    # always_2d=True yields (frames, channels) even for single-channel files;
    # transpose to (channels, frames).
    samples, original_sr = sf.read(audio_file_path, dtype="float32", always_2d=True)
    waveform = torch.from_numpy(samples.T)

    duration_seconds = waveform.shape[1] / original_sr
    if duration_seconds > MAX_AUDIO_DURATION_SECONDS:
        raise ValueError(f"The audio recording should not be longer than {MAX_AUDIO_DURATION_SECONDS} seconds.")

    if waveform.shape[0] > MONO_CHANNEL:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if original_sr != SAMPLING_RATE:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=SAMPLING_RATE)
        waveform = resampler(waveform)

    return waveform


def run_inference(
    audio_file_path: str,
    transcript: str,
    model: GOPPhonemeClassifier,
    processor: Wav2Vec2Processor,
    deltas: Dict[str, float] | None = None,
    correct_index: int = 0,
):
    """Score a recording against a canonical phoneme transcript.

    Args:
        audio_file_path: Path to the audio file to score.
        transcript: Whitespace-separated canonical phonemes (``"|"`` marks boundaries).
        model: Loaded GOP classifier; its first parameter's device is used for all inputs.
        processor: Processor used for audio features and phoneme tokenization.
        deltas: Optional per-head margins forwarded to the prediction step.
        correct_index: Class index treated as "correct" when applying deltas.

    Returns:
        On success, a dict mapping each head name to a tuple of that head's
        predictions for the non-``"|"`` tokens and those tokens. On missing
        input or any error, an HTML ``<p>`` error message instead.
    """
    if not audio_file_path or not transcript:
        return "<p style='text-align:center; color:red;'>Please provide both an audio file and the transcript.</p>"

    try:
        waveform = _load_waveform(audio_file_path)
        processed_audio = processor(
            waveform.squeeze(0), sampling_rate=SAMPLING_RATE, return_tensors="pt", padding=True
        )

        model_device = next(model.parameters()).device
        input_values = processed_audio.input_values.to(model_device)
        # Feature extractors configured with return_attention_mask=False omit
        # this key (attribute access would raise). A single unpadded clip is
        # fully valid, so an all-ones mask is the equivalent.
        attention_mask = processed_audio.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(processed_audio.input_values, dtype=torch.long)
        attention_mask = attention_mask.to(model_device)

        canonical_token_ids, token_lengths, token_mask, display_tokens = _prepare_canonical_tokens(
            transcript, processor, model_device
        )

        with torch.no_grad():
            outputs = model(
                input_values=input_values,
                attention_mask=attention_mask,
                canonical_token_ids=canonical_token_ids,
                token_lengths=token_lengths,
                token_mask=token_mask,
            )

        return _extract_head_predictions(
            outputs.logits,
            token_mask,
            display_tokens,
            deltas=deltas,
            correct_index=correct_index,
        )

    except Exception as exc:
        # Top-level UI boundary: log the traceback and show a message instead of crashing.
        logger.exception("An error occurred during inference: %s", exc)
        # Exception text can echo user input (file names, tokens); escape it before embedding in HTML.
        return f"<p style='text-align:center; color:red;'>An error occurred: {html.escape(str(exc))}</p>"