Aleksei Žavoronkov commited on
Commit ·
8cf218e
1
Parent(s): ed26f9c
update model architecture to the latest
Browse files- app.py +37 -74
- constants.py +3 -3
- gop_model.py +209 -153
- models.py +40 -20
- utils.py +62 -32
app.py
CHANGED
|
@@ -1,58 +1,60 @@
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
|
|
|
| 3 |
from utils import load_model_and_processor, run_inference, validate_phonemes
|
| 4 |
-
from multiprocessing import Process, Queue, set_start_method
|
| 5 |
-
import logging
|
| 6 |
|
| 7 |
logging.basicConfig(
|
| 8 |
level=logging.INFO,
|
| 9 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 10 |
-
handlers=[logging.StreamHandler()]
|
| 11 |
)
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
-
logger.info("Loading
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
logger.info("Models loaded successfully.")
|
| 18 |
|
| 19 |
css = """
|
| 20 |
.phoneme-scores { display: flex; flex-wrap: wrap; justify-content: center; gap: 15px; }
|
| 21 |
.phoneme-container { text-align: center; padding: 10px; border: 1px solid #ddd; border-radius: 8px; }
|
| 22 |
.phoneme { font-size: 1.5em; font-weight: bold; margin-bottom: 5px; }
|
| 23 |
.score { padding: 8px 12px; border-radius: 5px; color: white; font-weight: bold; }
|
| 24 |
-
.good { background-color: #28a745; }
|
| 25 |
-
.medium { background-color: #ffc107; }
|
| 26 |
-
.bad { background-color: #dc3545; }
|
| 27 |
"""
|
| 28 |
|
| 29 |
|
| 30 |
def get_score_class(score, score_type):
|
| 31 |
if score_type == "quality":
|
| 32 |
-
if score == 1:
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
return
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
def generate_html_output(result, score_type):
|
| 40 |
if isinstance(result, str):
|
| 41 |
return result
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
score = int(score) + 1
|
| 50 |
-
score_class = get_score_class(score, score_type)
|
| 51 |
|
|
|
|
|
|
|
|
|
|
| 52 |
html_output += f"""
|
| 53 |
<div class='phoneme-container'>
|
| 54 |
<div class='phoneme'>{token}</div>
|
| 55 |
-
<div class='score {score_class}'>{
|
| 56 |
</div>
|
| 57 |
"""
|
| 58 |
|
|
@@ -60,24 +62,7 @@ def generate_html_output(result, score_type):
|
|
| 60 |
return html_output
|
| 61 |
|
| 62 |
|
| 63 |
-
def
|
| 64 |
-
result = run_inference(audio_path, transcript, model, processor)
|
| 65 |
-
|
| 66 |
-
if isinstance(result, str):
|
| 67 |
-
queue.put((model_type, result))
|
| 68 |
-
return
|
| 69 |
-
|
| 70 |
-
predicted_scores, tokens, token_lengths = result
|
| 71 |
-
|
| 72 |
-
scores_list = predicted_scores.cpu().tolist()
|
| 73 |
-
lengths_list = token_lengths.cpu().tolist()
|
| 74 |
-
|
| 75 |
-
safe_result = (scores_list, tokens, lengths_list)
|
| 76 |
-
|
| 77 |
-
queue.put((model_type, safe_result))
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def score_phonemes_in_parallel(phoneme_text, audio_file):
|
| 81 |
if audio_file is None:
|
| 82 |
return "<p style='text-align:center; color:red;'>Please upload a .wav audio file.</p>", ""
|
| 83 |
|
|
@@ -85,27 +70,9 @@ def score_phonemes_in_parallel(phoneme_text, audio_file):
|
|
| 85 |
if phonemes_validation_error:
|
| 86 |
return phonemes_validation_error, ""
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
target=inference_wrapper,
|
| 92 |
-
args=("quality", phoneme_model, phoneme_processor, audio_file, phoneme_text, results_queue)
|
| 93 |
-
)
|
| 94 |
-
duration_process = Process(
|
| 95 |
-
target=inference_wrapper,
|
| 96 |
-
args=("duration", duration_model, duration_processor, audio_file, phoneme_text, results_queue)
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
quality_process.start()
|
| 100 |
-
duration_process.start()
|
| 101 |
-
|
| 102 |
-
quality_process.join()
|
| 103 |
-
duration_process.join()
|
| 104 |
-
|
| 105 |
-
results = {}
|
| 106 |
-
while not results_queue.empty():
|
| 107 |
-
key, value = results_queue.get()
|
| 108 |
-
results[key] = value
|
| 109 |
|
| 110 |
quality_result = results.get("quality")
|
| 111 |
duration_result = results.get("duration")
|
|
@@ -120,8 +87,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
|
|
| 120 |
gr.Markdown(
|
| 121 |
"""
|
| 122 |
# Phoneme Pronunciation and Duration Scorer
|
| 123 |
-
Enter
|
| 124 |
-
|
|
|
|
| 125 |
The application will provide a pronunciation (quality) and duration score for each phoneme.
|
| 126 |
|
| 127 |
Scores legend:
|
|
@@ -144,7 +112,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
|
|
| 144 |
["aja (L2 speaker)", "a j a", "./audio/L2/03ac-e45b-ec8a-6fa0_aja_take1.wav"],
|
| 145 |
["sõpra (L2 speaker)", "s õ pp r a", "./audio/L2/4071-0c77-e1d3-9587_sõpra_take1.wav"],
|
| 146 |
],
|
| 147 |
-
inputs=[word_input, phoneme_text_input, audio_input]
|
| 148 |
)
|
| 149 |
|
| 150 |
gr.Markdown("---")
|
|
@@ -155,16 +123,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
|
|
| 155 |
duration_output_html = gr.HTML()
|
| 156 |
|
| 157 |
btn.click(
|
| 158 |
-
fn=
|
| 159 |
inputs=[phoneme_text_input, audio_input],
|
| 160 |
-
outputs=[phoneme_output_html, duration_output_html]
|
|
|
|
| 161 |
)
|
| 162 |
|
| 163 |
if __name__ == "__main__":
|
| 164 |
-
try:
|
| 165 |
-
set_start_method("fork", force=True)
|
| 166 |
-
logger.info("Multiprocessing start method set to 'fork'.")
|
| 167 |
-
except RuntimeError:
|
| 168 |
-
logger.warning("Start method has already been set.")
|
| 169 |
-
|
| 170 |
demo.queue(default_concurrency_limit=2).launch()
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
import gradio as gr
|
| 4 |
+
|
| 5 |
+
from constants import ALL_PHONEMES, MODEL_REPO_ID
|
| 6 |
from utils import load_model_and_processor, run_inference, validate_phonemes
|
|
|
|
|
|
|
| 7 |
|
| 8 |
logging.basicConfig(
|
| 9 |
level=logging.INFO,
|
| 10 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 11 |
+
handlers=[logging.StreamHandler()],
|
| 12 |
)
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
+
logger.info("Loading model into memory globally")
|
| 16 |
+
model, processor = load_model_and_processor(MODEL_REPO_ID)
|
| 17 |
+
logger.info("Model loaded successfully")
|
|
|
|
| 18 |
|
| 19 |
css = """
|
| 20 |
.phoneme-scores { display: flex; flex-wrap: wrap; justify-content: center; gap: 15px; }
|
| 21 |
.phoneme-container { text-align: center; padding: 10px; border: 1px solid #ddd; border-radius: 8px; }
|
| 22 |
.phoneme { font-size: 1.5em; font-weight: bold; margin-bottom: 5px; }
|
| 23 |
.score { padding: 8px 12px; border-radius: 5px; color: white; font-weight: bold; }
|
| 24 |
+
.good { background-color: #28a745; }
|
| 25 |
+
.medium { background-color: #ffc107; }
|
| 26 |
+
.bad { background-color: #dc3545; }
|
| 27 |
"""
|
| 28 |
|
| 29 |
|
| 30 |
def get_score_class(score, score_type):
|
| 31 |
if score_type == "quality":
|
| 32 |
+
if score == 1:
|
| 33 |
+
return "good"
|
| 34 |
+
if score == 2:
|
| 35 |
+
return "medium"
|
| 36 |
+
return "bad"
|
| 37 |
+
return "good" if score == 1 else "bad"
|
| 38 |
|
| 39 |
|
| 40 |
def generate_html_output(result, score_type):
|
| 41 |
if isinstance(result, str):
|
| 42 |
return result
|
| 43 |
|
| 44 |
+
if not result:
|
| 45 |
+
return "<p style='text-align:center; color:red;'>No scores were produced.</p>"
|
| 46 |
|
| 47 |
+
predicted_scores, tokens = result
|
| 48 |
+
title = "Quality Scores" if score_type == "quality" else "Duration Scores"
|
| 49 |
+
html_output = f"<div class='phoneme-section'><h3 class='scores-title'>{title}</h3></div><div class='phoneme-scores'>"
|
|
|
|
|
|
|
| 50 |
|
| 51 |
+
for token, score in zip(tokens, predicted_scores):
|
| 52 |
+
display_score = int(score) + 1
|
| 53 |
+
score_class = get_score_class(display_score, score_type)
|
| 54 |
html_output += f"""
|
| 55 |
<div class='phoneme-container'>
|
| 56 |
<div class='phoneme'>{token}</div>
|
| 57 |
+
<div class='score {score_class}'>{display_score}</div>
|
| 58 |
</div>
|
| 59 |
"""
|
| 60 |
|
|
|
|
| 62 |
return html_output
|
| 63 |
|
| 64 |
|
| 65 |
+
def score_phonemes(phoneme_text, audio_file):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
if audio_file is None:
|
| 67 |
return "<p style='text-align:center; color:red;'>Please upload a .wav audio file.</p>", ""
|
| 68 |
|
|
|
|
| 70 |
if phonemes_validation_error:
|
| 71 |
return phonemes_validation_error, ""
|
| 72 |
|
| 73 |
+
results = run_inference(audio_file, phoneme_text, model, processor)
|
| 74 |
+
if isinstance(results, str):
|
| 75 |
+
return results, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
quality_result = results.get("quality")
|
| 78 |
duration_result = results.get("duration")
|
|
|
|
| 87 |
gr.Markdown(
|
| 88 |
"""
|
| 89 |
# Phoneme Pronunciation and Duration Scorer
|
| 90 |
+
Enter phonemes directly into the text box, separated by spaces.
|
| 91 |
+
Use `|` between words if you want to score a multi-word sequence.
|
| 92 |
+
Then upload a `.wav` file or record the audio of the pronounced word.
|
| 93 |
The application will provide a pronunciation (quality) and duration score for each phoneme.
|
| 94 |
|
| 95 |
Scores legend:
|
|
|
|
| 112 |
["aja (L2 speaker)", "a j a", "./audio/L2/03ac-e45b-ec8a-6fa0_aja_take1.wav"],
|
| 113 |
["sõpra (L2 speaker)", "s õ pp r a", "./audio/L2/4071-0c77-e1d3-9587_sõpra_take1.wav"],
|
| 114 |
],
|
| 115 |
+
inputs=[word_input, phoneme_text_input, audio_input],
|
| 116 |
)
|
| 117 |
|
| 118 |
gr.Markdown("---")
|
|
|
|
| 123 |
duration_output_html = gr.HTML()
|
| 124 |
|
| 125 |
btn.click(
|
| 126 |
+
fn=score_phonemes,
|
| 127 |
inputs=[phoneme_text_input, audio_input],
|
| 128 |
+
outputs=[phoneme_output_html, duration_output_html],
|
| 129 |
+
api_name=False,
|
| 130 |
)
|
| 131 |
|
| 132 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
demo.queue(default_concurrency_limit=2).launch()
|
constants.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
-
|
| 2 |
-
DURATION_MODEL_REPO_ID = "alzavo/sayest-duration"
|
| 3 |
SAMPLING_RATE = 16000
|
| 4 |
MONO_CHANNEL = 1
|
| 5 |
MAX_AUDIO_DURATION_SECONDS = 5
|
|
@@ -94,4 +93,5 @@ ALL_PHONEMES = {
|
|
| 94 |
"u:",
|
| 95 |
"w",
|
| 96 |
"j:j",
|
| 97 |
-
|
|
|
|
|
|
| 1 |
+
MODEL_REPO_ID = "alzavo/sayest-latest"
|
|
|
|
| 2 |
SAMPLING_RATE = 16000
|
| 3 |
MONO_CHANNEL = 1
|
| 4 |
MAX_AUDIO_DURATION_SECONDS = 5
|
|
|
|
| 93 |
"u:",
|
| 94 |
"w",
|
| 95 |
"j:j",
|
| 96 |
+
"|",
|
| 97 |
+
}
|
gop_model.py
CHANGED
|
@@ -12,9 +12,6 @@ from models import OrdinalLogLoss
|
|
| 12 |
|
| 13 |
|
| 14 |
class GOPWav2Vec2Config(PretrainedConfig):
|
| 15 |
-
"""
|
| 16 |
-
Configuration for GOP-enhanced model that wraps a Wav2Vec2ForCTC backbone.
|
| 17 |
-
"""
|
| 18 |
model_type = "gop-wav2vec2"
|
| 19 |
|
| 20 |
def __init__(
|
|
@@ -33,6 +30,7 @@ class GOPWav2Vec2Config(PretrainedConfig):
|
|
| 33 |
unk_id: Optional[int] = None,
|
| 34 |
bos_id: Optional[int] = None,
|
| 35 |
eos_id: Optional[int] = None,
|
|
|
|
| 36 |
token_id_vocab: Optional[List[int]] = None,
|
| 37 |
ctc_config: Optional[dict] = None,
|
| 38 |
**kwargs,
|
|
@@ -57,21 +55,17 @@ class GOPWav2Vec2Config(PretrainedConfig):
|
|
| 57 |
self.unk_id = unk_id
|
| 58 |
self.bos_id = bos_id
|
| 59 |
self.eos_id = eos_id
|
|
|
|
| 60 |
self.token_id_vocab = token_id_vocab
|
| 61 |
self.ctc_config = ctc_config
|
| 62 |
|
| 63 |
|
| 64 |
class GOPPhonemeClassifier(PreTrainedModel):
|
| 65 |
-
"""
|
| 66 |
-
GOP classifier that wraps a pretrained Wav2Vec2ForCTC backbone.
|
| 67 |
-
Computes per-phoneme scores using GOP-derived features + a small Transformer + classifier head.
|
| 68 |
-
"""
|
| 69 |
-
|
| 70 |
config_class = GOPWav2Vec2Config
|
| 71 |
-
|
| 72 |
def __init__(self, config: GOPWav2Vec2Config, load_pretrained_backbone: bool = False):
|
| 73 |
super().__init__(config)
|
| 74 |
-
|
| 75 |
if config.ctc_config is not None:
|
| 76 |
backbone_config = Wav2Vec2Config.from_dict(config.ctc_config)
|
| 77 |
elif config.base_model_name_or_path is not None:
|
|
@@ -82,35 +76,43 @@ class GOPPhonemeClassifier(PreTrainedModel):
|
|
| 82 |
self.ctc_model = Wav2Vec2ForCTC(backbone_config)
|
| 83 |
self.config.ctc_config = backbone_config.to_dict()
|
| 84 |
|
| 85 |
-
# Special ids
|
| 86 |
self.blank_id = config.pad_id
|
| 87 |
self.unk_id = config.unk_id
|
| 88 |
self.bos_id = config.bos_id
|
| 89 |
self.eos_id = config.eos_id
|
| 90 |
self.pad_id = self.blank_id
|
|
|
|
| 91 |
|
| 92 |
special_ids = {self.blank_id, self.unk_id, self.bos_id, self.eos_id, self.pad_id}
|
| 93 |
self.special_ids = {i for i in special_ids if i is not None}
|
| 94 |
|
| 95 |
vocab_size = int(self.ctc_model.config.vocab_size)
|
| 96 |
self.token_id_vocab = (
|
| 97 |
-
config.token_id_vocab
|
|
|
|
|
|
|
| 98 |
)
|
| 99 |
|
| 100 |
self.gop_feature_dim = 1 + len(self.token_id_vocab) + 1
|
| 101 |
self.embedding_dim = int(config.gop_embedding_dim)
|
| 102 |
-
self.token_embedding = nn.Embedding(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
self.combined_feature_dim = self.gop_feature_dim + self.embedding_dim
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
| 111 |
batch_first=True,
|
| 112 |
)
|
| 113 |
-
self.gop_transformer_encoder = nn.TransformerEncoder(enc_layer, num_layers=config.gop_transformer_nlayers)
|
| 114 |
|
| 115 |
head_label_config = getattr(config, "gop_head_labels", None)
|
| 116 |
if head_label_config is None:
|
|
@@ -118,13 +120,24 @@ class GOPPhonemeClassifier(PreTrainedModel):
|
|
| 118 |
raise ValueError("Config must provide gop_head_labels or num_gop_labels for the classifier.")
|
| 119 |
head_label_config = {"default": int(config.num_gop_labels)}
|
| 120 |
self.head_label_config = {str(k): int(v) for k, v in head_label_config.items()}
|
| 121 |
-
self.
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
self._init_losses()
|
| 127 |
-
|
| 128 |
self.post_init()
|
| 129 |
|
| 130 |
if load_pretrained_backbone:
|
|
@@ -152,11 +165,15 @@ class GOPPhonemeClassifier(PreTrainedModel):
|
|
| 152 |
elif weights is not None:
|
| 153 |
head_weights = weights
|
| 154 |
if head_weights is not None:
|
| 155 |
-
head_weights =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
loss_modules[head] = OrdinalLogLoss(
|
| 157 |
num_classes=int(num_labels),
|
| 158 |
alpha=alpha,
|
| 159 |
-
reduction=
|
| 160 |
class_weights=head_weights,
|
| 161 |
)
|
| 162 |
self.loss_fns = nn.ModuleDict(loss_modules)
|
|
@@ -168,7 +185,6 @@ class GOPPhonemeClassifier(PreTrainedModel):
|
|
| 168 |
target_ids: torch.Tensor,
|
| 169 |
target_lengths: torch.Tensor,
|
| 170 |
) -> torch.Tensor:
|
| 171 |
-
"""CTC log p(target|input) per item for a batch."""
|
| 172 |
target_ids_cpu = target_ids.cpu()
|
| 173 |
target_lengths_cpu = target_lengths.cpu()
|
| 174 |
log_probs_cpu = log_probs_TNC.cpu()
|
|
@@ -176,23 +192,22 @@ class GOPPhonemeClassifier(PreTrainedModel):
|
|
| 176 |
|
| 177 |
targets_flat = []
|
| 178 |
for i in range(target_ids_cpu.size(0)):
|
| 179 |
-
valid_targets = target_ids_cpu[i, :target_lengths_cpu[i]]
|
| 180 |
targets_flat.append(valid_targets)
|
| 181 |
targets_cat = torch.cat(targets_flat) if targets_flat else torch.tensor([], dtype=torch.long)
|
| 182 |
|
| 183 |
if target_lengths_cpu.sum() == 0:
|
| 184 |
-
return torch.full((log_probs_TNC.size(1),), -float(
|
| 185 |
|
| 186 |
-
ctc_loss_fn = torch.nn.CTCLoss(blank=self.blank_id, reduction=
|
| 187 |
try:
|
| 188 |
loss_per_item = ctc_loss_fn(log_probs_cpu, targets_cat, input_lengths_cpu, target_lengths_cpu)
|
| 189 |
return -loss_per_item.to(log_probs_TNC.device)
|
| 190 |
-
except Exception as
|
| 191 |
-
warnings.warn(f"CTCLoss calculation failed: {
|
| 192 |
-
return torch.full((log_probs_TNC.size(1),), -float(
|
| 193 |
|
| 194 |
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
| 195 |
-
"""Compute time dimension after backbone feature extractor."""
|
| 196 |
def _conv_out_length(input_length, kernel_size, stride):
|
| 197 |
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
|
| 198 |
|
|
@@ -200,6 +215,92 @@ class GOPPhonemeClassifier(PreTrainedModel):
|
|
| 200 |
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
| 201 |
return input_lengths
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
def forward(
|
| 204 |
self,
|
| 205 |
input_values: torch.Tensor,
|
|
@@ -212,10 +313,11 @@ class GOPPhonemeClassifier(PreTrainedModel):
|
|
| 212 |
return_dict: Optional[bool] = None,
|
| 213 |
labels: Optional[torch.Tensor] = None,
|
| 214 |
) -> SequenceClassifierOutput:
|
| 215 |
-
|
| 216 |
device = input_values.device
|
|
|
|
| 217 |
|
| 218 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
| 219 |
|
| 220 |
if self.training or labels is not None:
|
| 221 |
if canonical_token_ids is None or token_lengths is None or token_mask is None:
|
|
@@ -224,7 +326,15 @@ class GOPPhonemeClassifier(PreTrainedModel):
|
|
| 224 |
if token_mask is None:
|
| 225 |
raise ValueError("`token_mask` must be provided to GOPPhonemeClassifier.forward.")
|
| 226 |
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
outputs = self.ctc_model.wav2vec2(
|
| 229 |
input_values,
|
| 230 |
attention_mask=attention_mask,
|
|
@@ -234,150 +344,96 @@ class GOPPhonemeClassifier(PreTrainedModel):
|
|
| 234 |
)
|
| 235 |
hidden_states = outputs.last_hidden_state
|
| 236 |
|
| 237 |
-
# 2) Frame-level logits for CTC
|
| 238 |
logits_ctc = self.ctc_model.lm_head(hidden_states)
|
| 239 |
log_probs_ctc = F.log_softmax(logits_ctc, dim=-1)
|
| 240 |
log_probs_TNC = log_probs_ctc.permute(1, 0, 2).contiguous()
|
| 241 |
|
| 242 |
-
# 3) Frame lengths
|
| 243 |
batch_size = input_values.size(0)
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
input_lengths_samples = attention_mask.sum(dim=-1)
|
| 248 |
-
input_lengths_frames = self._get_feat_extract_output_lengths(input_lengths_samples)
|
| 249 |
-
input_lengths_frames = torch.clamp(input_lengths_frames, max=log_probs_TNC.size(0))
|
| 250 |
|
| 251 |
-
# 4) GOP feature calculation over tokens
|
| 252 |
max_token_len = canonical_token_ids.size(1) if canonical_token_ids is not None else 0
|
| 253 |
-
batch_combined_features_list = [
|
| 254 |
-
|
|
|
|
| 255 |
|
| 256 |
lpp_log_prob_batch = self._calculate_log_prob(
|
| 257 |
log_probs_TNC, input_lengths_frames, canonical_token_ids, token_lengths
|
| 258 |
)
|
| 259 |
|
| 260 |
for token_idx in range(max_token_len):
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
mask_column = token_mask_bool[:, token_idx] if token_mask_bool.dim() == 2 else token_mask_bool
|
| 265 |
-
skip_mask = token_out_of_bounds_mask | ~mask_column
|
| 266 |
-
|
| 267 |
-
if skip_mask.all():
|
| 268 |
-
continue
|
| 269 |
-
|
| 270 |
-
active_mask = ~skip_mask
|
| 271 |
|
| 272 |
all_sub_log_probs = []
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
sub_lpr_batch[skip_mask, :] = 0.0
|
| 287 |
-
else:
|
| 288 |
-
sub_lpr_batch = torch.zeros((batch_size, 0), device=device)
|
| 289 |
|
| 290 |
-
# Deletion GOP component
|
| 291 |
del_lpr_list = []
|
| 292 |
for b_idx in range(batch_size):
|
| 293 |
if skip_mask[b_idx]:
|
| 294 |
del_lpr_list.append(torch.tensor(-1e10, device=device))
|
| 295 |
-
continue
|
| 296 |
-
item_tokens = canonical_token_ids[b_idx, : token_lengths[b_idx]].tolist()
|
| 297 |
-
del_tokens_list = item_tokens[:token_idx] + item_tokens[token_idx + 1:]
|
| 298 |
-
if not del_tokens_list:
|
| 299 |
-
log_prob_del_item = torch.tensor(-float('inf'), device=device)
|
| 300 |
else:
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
del_lpr_batch = torch.stack(del_lpr_list)
|
| 314 |
|
| 315 |
-
gop_part = torch.cat(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
combined_features = torch.cat([gop_part, current_token_embeddings], dim=1)
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
feature_lengths_list = [len(seq_list) for seq_list in batch_combined_features_list]
|
| 323 |
-
if feature_lengths_list:
|
| 324 |
-
feature_lengths = torch.tensor(feature_lengths_list, dtype=torch.long, device=device)
|
| 325 |
-
target_pad_len = int(feature_lengths.max().item())
|
| 326 |
-
else:
|
| 327 |
-
feature_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
| 328 |
-
target_pad_len = 0
|
| 329 |
-
|
| 330 |
-
padded_sequences = []
|
| 331 |
-
for seq_list in batch_combined_features_list:
|
| 332 |
-
if seq_list:
|
| 333 |
-
seq_tensor = torch.stack(seq_list, dim=0)
|
| 334 |
-
pad_len = target_pad_len - seq_tensor.size(0)
|
| 335 |
-
if pad_len > 0:
|
| 336 |
-
seq_tensor = F.pad(seq_tensor, (0, 0, 0, pad_len))
|
| 337 |
-
padded_sequences.append(seq_tensor)
|
| 338 |
-
else:
|
| 339 |
-
padded_sequences.append(torch.zeros((target_pad_len, self.combined_feature_dim), device=device))
|
| 340 |
-
transformer_input = torch.stack(padded_sequences, dim=0) if padded_sequences else torch.zeros((0, 0, self.combined_feature_dim), device=device)
|
| 341 |
-
transformer_padding_mask = torch.arange(target_pad_len, device=device)[None, :] >= feature_lengths[:, None]
|
| 342 |
-
|
| 343 |
-
# 6) GOP transformer
|
| 344 |
-
gop_transformer_output = self.gop_transformer_encoder(
|
| 345 |
-
transformer_input,
|
| 346 |
-
src_key_padding_mask=transformer_padding_mask
|
| 347 |
-
)
|
| 348 |
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
}
|
| 354 |
|
| 355 |
-
# 8) Loss
|
| 356 |
loss = None
|
| 357 |
if labels is not None:
|
| 358 |
-
|
| 359 |
-
label_map = {next(iter(final_logits.keys())): labels}
|
| 360 |
-
elif isinstance(labels, dict):
|
| 361 |
-
label_map = labels
|
| 362 |
-
else:
|
| 363 |
-
raise TypeError("labels must be a Tensor or a dict of Tensors when provided.")
|
| 364 |
-
|
| 365 |
-
active_mask = ~transformer_padding_mask.view(-1)
|
| 366 |
for head, head_logits in final_logits.items():
|
| 367 |
-
head_labels =
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
logits_flat = head_logits.view(-1, head_logits.size(-1))
|
| 371 |
-
labels_flat = head_labels.view(-1)
|
| 372 |
-
active_logits = logits_flat[active_mask]
|
| 373 |
-
#breakpoint()
|
| 374 |
-
active_labels = labels_flat[active_mask]
|
| 375 |
-
if active_labels.numel() == 0:
|
| 376 |
-
continue
|
| 377 |
-
head_loss = self.loss_fns[head](active_logits, active_labels)
|
| 378 |
-
loss = head_loss if loss is None else loss + head_loss
|
| 379 |
-
if loss is None:
|
| 380 |
-
loss = torch.tensor(0.0, device=device, requires_grad=True)
|
| 381 |
|
| 382 |
if not return_dict:
|
| 383 |
output = (final_logits,)
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class GOPWav2Vec2Config(PretrainedConfig):
|
|
|
|
|
|
|
|
|
|
| 15 |
model_type = "gop-wav2vec2"
|
| 16 |
|
| 17 |
def __init__(
|
|
|
|
| 30 |
unk_id: Optional[int] = None,
|
| 31 |
bos_id: Optional[int] = None,
|
| 32 |
eos_id: Optional[int] = None,
|
| 33 |
+
word_boundary_id: Optional[int] = None,
|
| 34 |
token_id_vocab: Optional[List[int]] = None,
|
| 35 |
ctc_config: Optional[dict] = None,
|
| 36 |
**kwargs,
|
|
|
|
| 55 |
self.unk_id = unk_id
|
| 56 |
self.bos_id = bos_id
|
| 57 |
self.eos_id = eos_id
|
| 58 |
+
self.word_boundary_id = word_boundary_id
|
| 59 |
self.token_id_vocab = token_id_vocab
|
| 60 |
self.ctc_config = ctc_config
|
| 61 |
|
| 62 |
|
| 63 |
class GOPPhonemeClassifier(PreTrainedModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
config_class = GOPWav2Vec2Config
|
| 65 |
+
|
| 66 |
def __init__(self, config: GOPWav2Vec2Config, load_pretrained_backbone: bool = False):
|
| 67 |
super().__init__(config)
|
| 68 |
+
|
| 69 |
if config.ctc_config is not None:
|
| 70 |
backbone_config = Wav2Vec2Config.from_dict(config.ctc_config)
|
| 71 |
elif config.base_model_name_or_path is not None:
|
|
|
|
| 76 |
self.ctc_model = Wav2Vec2ForCTC(backbone_config)
|
| 77 |
self.config.ctc_config = backbone_config.to_dict()
|
| 78 |
|
|
|
|
| 79 |
self.blank_id = config.pad_id
|
| 80 |
self.unk_id = config.unk_id
|
| 81 |
self.bos_id = config.bos_id
|
| 82 |
self.eos_id = config.eos_id
|
| 83 |
self.pad_id = self.blank_id
|
| 84 |
+
self.word_boundary_id = config.word_boundary_id
|
| 85 |
|
| 86 |
special_ids = {self.blank_id, self.unk_id, self.bos_id, self.eos_id, self.pad_id}
|
| 87 |
self.special_ids = {i for i in special_ids if i is not None}
|
| 88 |
|
| 89 |
vocab_size = int(self.ctc_model.config.vocab_size)
|
| 90 |
self.token_id_vocab = (
|
| 91 |
+
config.token_id_vocab
|
| 92 |
+
if config.token_id_vocab is not None
|
| 93 |
+
else [i for i in range(vocab_size) if i not in self.special_ids]
|
| 94 |
)
|
| 95 |
|
| 96 |
self.gop_feature_dim = 1 + len(self.token_id_vocab) + 1
|
| 97 |
self.embedding_dim = int(config.gop_embedding_dim)
|
| 98 |
+
self.token_embedding = nn.Embedding(
|
| 99 |
+
vocab_size,
|
| 100 |
+
self.embedding_dim,
|
| 101 |
+
padding_idx=self.pad_id if self.pad_id is not None else 0,
|
| 102 |
+
)
|
| 103 |
self.combined_feature_dim = self.gop_feature_dim + self.embedding_dim
|
| 104 |
+
self.gop_part_dropout = nn.Dropout(config.gop_transformer_dropout)
|
| 105 |
+
self.gop_part_norm = nn.LayerNorm(self.gop_feature_dim)
|
| 106 |
+
|
| 107 |
+
self.lstm_hidden_size = int(config.gop_transformer_dim_feedforward)
|
| 108 |
+
self.gop_rnn = nn.LSTM(
|
| 109 |
+
input_size=self.combined_feature_dim,
|
| 110 |
+
hidden_size=self.lstm_hidden_size,
|
| 111 |
+
num_layers=config.gop_transformer_nlayers,
|
| 112 |
+
dropout=config.gop_transformer_dropout if config.gop_transformer_nlayers > 1 else 0.0,
|
| 113 |
+
bidirectional=True,
|
| 114 |
batch_first=True,
|
| 115 |
)
|
|
|
|
| 116 |
|
| 117 |
head_label_config = getattr(config, "gop_head_labels", None)
|
| 118 |
if head_label_config is None:
|
|
|
|
| 120 |
raise ValueError("Config must provide gop_head_labels or num_gop_labels for the classifier.")
|
| 121 |
head_label_config = {"default": int(config.num_gop_labels)}
|
| 122 |
self.head_label_config = {str(k): int(v) for k, v in head_label_config.items()}
|
| 123 |
+
self.head_hidden_dim = self.lstm_hidden_size * 2
|
| 124 |
+
self.head_mlps = nn.ModuleDict(
|
| 125 |
+
{
|
| 126 |
+
head: nn.Sequential(
|
| 127 |
+
nn.Linear(self.head_hidden_dim, self.head_hidden_dim),
|
| 128 |
+
nn.LeakyReLU(),
|
| 129 |
+
)
|
| 130 |
+
for head in self.head_label_config.keys()
|
| 131 |
+
}
|
| 132 |
+
)
|
| 133 |
+
self.classifiers = nn.ModuleDict(
|
| 134 |
+
{
|
| 135 |
+
head: nn.Linear(self.head_hidden_dim, num_labels)
|
| 136 |
+
for head, num_labels in self.head_label_config.items()
|
| 137 |
+
}
|
| 138 |
+
)
|
| 139 |
|
| 140 |
self._init_losses()
|
|
|
|
| 141 |
self.post_init()
|
| 142 |
|
| 143 |
if load_pretrained_backbone:
|
|
|
|
| 165 |
elif weights is not None:
|
| 166 |
head_weights = weights
|
| 167 |
if head_weights is not None:
|
| 168 |
+
head_weights = (
|
| 169 |
+
head_weights
|
| 170 |
+
if isinstance(head_weights, torch.Tensor)
|
| 171 |
+
else torch.tensor(head_weights, dtype=torch.float)
|
| 172 |
+
)
|
| 173 |
loss_modules[head] = OrdinalLogLoss(
|
| 174 |
num_classes=int(num_labels),
|
| 175 |
alpha=alpha,
|
| 176 |
+
reduction="mean",
|
| 177 |
class_weights=head_weights,
|
| 178 |
)
|
| 179 |
self.loss_fns = nn.ModuleDict(loss_modules)
|
|
|
|
| 185 |
target_ids: torch.Tensor,
|
| 186 |
target_lengths: torch.Tensor,
|
| 187 |
) -> torch.Tensor:
|
|
|
|
| 188 |
target_ids_cpu = target_ids.cpu()
|
| 189 |
target_lengths_cpu = target_lengths.cpu()
|
| 190 |
log_probs_cpu = log_probs_TNC.cpu()
|
|
|
|
| 192 |
|
| 193 |
targets_flat = []
|
| 194 |
for i in range(target_ids_cpu.size(0)):
|
| 195 |
+
valid_targets = target_ids_cpu[i, : target_lengths_cpu[i]]
|
| 196 |
targets_flat.append(valid_targets)
|
| 197 |
targets_cat = torch.cat(targets_flat) if targets_flat else torch.tensor([], dtype=torch.long)
|
| 198 |
|
| 199 |
if target_lengths_cpu.sum() == 0:
|
| 200 |
+
return torch.full((log_probs_TNC.size(1),), -float("inf"), device=log_probs_TNC.device)
|
| 201 |
|
| 202 |
+
ctc_loss_fn = torch.nn.CTCLoss(blank=self.blank_id, reduction="none", zero_infinity=True)
|
| 203 |
try:
|
| 204 |
loss_per_item = ctc_loss_fn(log_probs_cpu, targets_cat, input_lengths_cpu, target_lengths_cpu)
|
| 205 |
return -loss_per_item.to(log_probs_TNC.device)
|
| 206 |
+
except Exception as exc:
|
| 207 |
+
warnings.warn(f"CTCLoss calculation failed: {exc}. Returning -inf for batch.")
|
| 208 |
+
return torch.full((log_probs_TNC.size(1),), -float("inf"), device=log_probs_TNC.device)
|
| 209 |
|
| 210 |
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
|
|
|
| 211 |
def _conv_out_length(input_length, kernel_size, stride):
|
| 212 |
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
|
| 213 |
|
|
|
|
| 215 |
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
| 216 |
return input_lengths
|
| 217 |
|
| 218 |
+
def _prepare_labels(
|
| 219 |
+
self, labels: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]
|
| 220 |
+
) -> Optional[Dict[str, torch.Tensor]]:
|
| 221 |
+
if labels is None:
|
| 222 |
+
return None
|
| 223 |
+
if isinstance(labels, torch.Tensor):
|
| 224 |
+
if len(self.head_label_config) != 1:
|
| 225 |
+
raise ValueError("Multi-head setup requires `labels` to be a dict keyed by head name.")
|
| 226 |
+
head_name = next(iter(self.head_label_config))
|
| 227 |
+
return {head_name: labels}
|
| 228 |
+
if not isinstance(labels, dict):
|
| 229 |
+
raise ValueError("`labels` must be a Tensor for single-head setups or a dict for multi-head setups.")
|
| 230 |
+
return labels
|
| 231 |
+
|
| 232 |
+
    def _validate_inputs(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor,
        canonical_token_ids: torch.Tensor,
        token_lengths: torch.Tensor,
        token_mask: torch.Tensor,
        labels: Optional[Dict[str, torch.Tensor]],
    ) -> torch.Tensor:
        """Validate the shapes/consistency of all forward() inputs.

        Returns `token_mask` converted to a bool tensor on the device of
        `canonical_token_ids`. Raises ValueError on any inconsistency, so
        shape bugs surface here with a precise message rather than as an
        opaque failure deeper in the forward pass.
        """
        # Audio must be a (batch, time) waveform batch with a matching mask.
        if input_values.dim() != 2:
            raise ValueError(f"`input_values` must be 2D (batch, time); got shape {tuple(input_values.shape)}.")
        if attention_mask is None or attention_mask.shape != input_values.shape:
            raise ValueError("`attention_mask` must be provided and match the shape of `input_values`.")

        if canonical_token_ids is None or token_lengths is None or token_mask is None:
            raise ValueError("`canonical_token_ids`, `token_lengths`, and `token_mask` are required.")

        if canonical_token_ids.dim() != 2:
            raise ValueError(
                f"`canonical_token_ids` must be 2D (batch, tokens); got shape {tuple(canonical_token_ids.shape)}."
            )
        batch_size, max_tokens = canonical_token_ids.shape
        if batch_size != input_values.shape[0]:
            raise ValueError("Batch size mismatch between `input_values` and `canonical_token_ids`.")

        if token_mask.dim() != 2 or token_mask.shape != canonical_token_ids.shape:
            raise ValueError("`token_mask` must be the same shape as `canonical_token_ids`.")

        # Per-item token counts must be 1D, non-negative, and fit in the pad width.
        if token_lengths.dim() != 1 or token_lengths.shape[0] != batch_size:
            raise ValueError("`token_lengths` must be 1D with length equal to batch size.")
        if torch.any(token_lengths < 0):
            raise ValueError("`token_lengths` must be non-negative.")
        if torch.any(token_lengths > max_tokens):
            raise ValueError("`token_lengths` cannot exceed the number of provided tokens.")

        # The mask may only mark positions inside each item's real token span:
        # positions at index >= token_lengths[i] are padding and must be False.
        token_mask_bool = token_mask.to(device=canonical_token_ids.device, dtype=torch.bool)
        arange_positions = torch.arange(max_tokens, device=canonical_token_ids.device)
        padded_active = token_mask_bool & (arange_positions.unsqueeze(0) >= token_lengths.unsqueeze(1))
        if torch.any(padded_active):
            raise ValueError("`token_mask` marks padded positions as valid (indices >= token_lengths).")
        if torch.any(token_mask_bool.sum(dim=1) > token_lengths):
            raise ValueError("`token_mask` has more active positions than `token_lengths` for some batch items.")

        if labels is not None:
            # `labels` arrives here already normalized by _prepare_labels, so it
            # must be a dict whose keys exactly match the configured heads.
            if not isinstance(labels, dict):
                raise ValueError("`labels` must be a dict keyed by head name after normalization.")
            expected_heads = set(self.head_label_config.keys())
            label_heads = set(labels.keys())
            unknown_heads = label_heads - expected_heads
            missing_heads = expected_heads - label_heads
            if unknown_heads:
                raise ValueError(f"Unexpected label heads provided: {sorted(unknown_heads)}.")
            if missing_heads:
                raise ValueError(f"Missing label heads: {sorted(missing_heads)}.")
            for head, head_labels in labels.items():
                # Each head's labels align 1:1 with canonical tokens.
                if head_labels.shape != canonical_token_ids.shape:
                    raise ValueError(
                        f"Labels for head '{head}' must match `canonical_token_ids` shape "
                        f"{tuple(canonical_token_ids.shape)}; got {tuple(head_labels.shape)}."
                    )
                if head_labels.dtype not in (torch.int64, torch.long):
                    raise ValueError(f"Labels for head '{head}' must be integer tensors; got dtype {head_labels.dtype}.")
                # Masked-out positions must carry the ignore value (-100) so the
                # loss never trains on them.
                masked_positions = token_mask_bool.logical_not()
                bad_mask = masked_positions & (head_labels != -100)
                if torch.any(bad_mask):
                    bad_count = int(bad_mask.sum().item())
                    raise ValueError(
                        f"Labels for head '{head}' must be -100 at masked positions; found {bad_count} mismatches."
                    )

        return token_mask_bool
|
| 303 |
+
|
| 304 |
def forward(
|
| 305 |
self,
|
| 306 |
input_values: torch.Tensor,
|
|
|
|
| 313 |
return_dict: Optional[bool] = None,
|
| 314 |
labels: Optional[torch.Tensor] = None,
|
| 315 |
) -> SequenceClassifierOutput:
|
|
|
|
| 316 |
device = input_values.device
|
| 317 |
+
assert attention_mask is not None
|
| 318 |
|
| 319 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 320 |
+
labels = self._prepare_labels(labels)
|
| 321 |
|
| 322 |
if self.training or labels is not None:
|
| 323 |
if canonical_token_ids is None or token_lengths is None or token_mask is None:
|
|
|
|
| 326 |
if token_mask is None:
|
| 327 |
raise ValueError("`token_mask` must be provided to GOPPhonemeClassifier.forward.")
|
| 328 |
|
| 329 |
+
token_mask_bool = self._validate_inputs(
|
| 330 |
+
input_values=input_values,
|
| 331 |
+
attention_mask=attention_mask,
|
| 332 |
+
canonical_token_ids=canonical_token_ids,
|
| 333 |
+
token_lengths=token_lengths,
|
| 334 |
+
token_mask=token_mask,
|
| 335 |
+
labels=labels,
|
| 336 |
+
).to(device=device)
|
| 337 |
+
|
| 338 |
outputs = self.ctc_model.wav2vec2(
|
| 339 |
input_values,
|
| 340 |
attention_mask=attention_mask,
|
|
|
|
| 344 |
)
|
| 345 |
hidden_states = outputs.last_hidden_state
|
| 346 |
|
|
|
|
| 347 |
logits_ctc = self.ctc_model.lm_head(hidden_states)
|
| 348 |
log_probs_ctc = F.log_softmax(logits_ctc, dim=-1)
|
| 349 |
log_probs_TNC = log_probs_ctc.permute(1, 0, 2).contiguous()
|
| 350 |
|
|
|
|
| 351 |
batch_size = input_values.size(0)
|
| 352 |
+
input_lengths_samples = attention_mask.sum(dim=-1)
|
| 353 |
+
input_lengths_frames = self._get_feat_extract_output_lengths(input_lengths_samples)
|
| 354 |
+
input_lengths_frames = torch.clamp(input_lengths_frames, max=log_probs_TNC.size(0))
|
|
|
|
|
|
|
|
|
|
| 355 |
|
|
|
|
| 356 |
max_token_len = canonical_token_ids.size(1) if canonical_token_ids is not None else 0
|
| 357 |
+
batch_combined_features_list = []
|
| 358 |
+
|
| 359 |
+
token_embeddings = self.token_embedding(canonical_token_ids)
|
| 360 |
|
| 361 |
lpp_log_prob_batch = self._calculate_log_prob(
|
| 362 |
log_probs_TNC, input_lengths_frames, canonical_token_ids, token_lengths
|
| 363 |
)
|
| 364 |
|
| 365 |
for token_idx in range(max_token_len):
|
| 366 |
+
current_token_embeddings = token_embeddings[:, token_idx, :]
|
| 367 |
+
active_mask = token_mask_bool[:, token_idx]
|
| 368 |
+
skip_mask = ~active_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
all_sub_log_probs = []
|
| 371 |
+
for sub_token_id in self.token_id_vocab:
|
| 372 |
+
sub_ids_batch = canonical_token_ids.clone()
|
| 373 |
+
sub_ids_batch[active_mask, token_idx] = sub_token_id
|
| 374 |
+
log_prob_sub_batch = self._calculate_log_prob(
|
| 375 |
+
log_probs_TNC, input_lengths_frames, sub_ids_batch, token_lengths
|
| 376 |
+
)
|
| 377 |
+
all_sub_log_probs.append(log_prob_sub_batch)
|
| 378 |
+
|
| 379 |
+
sub_log_probs_batch = torch.stack(all_sub_log_probs, dim=1)
|
| 380 |
+
sub_log_probs_batch = F.log_softmax(sub_log_probs_batch, dim=1)
|
| 381 |
+
sub_log_probs_batch = torch.nan_to_num(sub_log_probs_batch, nan=0.0, posinf=1e10, neginf=-1e10)
|
| 382 |
+
if skip_mask.any():
|
| 383 |
+
sub_log_probs_batch[skip_mask, :] = 0.0
|
|
|
|
|
|
|
|
|
|
| 384 |
|
|
|
|
| 385 |
del_lpr_list = []
|
| 386 |
for b_idx in range(batch_size):
|
| 387 |
if skip_mask[b_idx]:
|
| 388 |
del_lpr_list.append(torch.tensor(-1e10, device=device))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
else:
|
| 390 |
+
item_tokens = canonical_token_ids[b_idx, : token_lengths[b_idx]].tolist()
|
| 391 |
+
del_tokens_list = item_tokens[:token_idx] + item_tokens[token_idx + 1 :]
|
| 392 |
+
if not del_tokens_list:
|
| 393 |
+
log_prob_del_item = torch.tensor(-float("inf"), device=device)
|
| 394 |
+
else:
|
| 395 |
+
del_ids_tensor = torch.tensor(
|
| 396 |
+
[del_tokens_list], dtype=torch.long, device=canonical_token_ids.device
|
| 397 |
+
)
|
| 398 |
+
del_len_tensor = torch.tensor(
|
| 399 |
+
[len(del_tokens_list)], dtype=torch.long, device=canonical_token_ids.device
|
| 400 |
+
)
|
| 401 |
+
log_probs_item_TNC = log_probs_TNC[:, b_idx : b_idx + 1, :]
|
| 402 |
+
input_len_item = input_lengths_frames[b_idx : b_idx + 1]
|
| 403 |
+
log_prob_del_item = self._calculate_log_prob(
|
| 404 |
+
log_probs_item_TNC, input_len_item, del_ids_tensor, del_len_tensor
|
| 405 |
+
)
|
| 406 |
+
if log_prob_del_item.dim() > 0:
|
| 407 |
+
log_prob_del_item = log_prob_del_item[0]
|
| 408 |
+
lpr_del_item = lpp_log_prob_batch[b_idx] - log_prob_del_item
|
| 409 |
+
lpr_del_item = torch.nan_to_num(lpr_del_item, nan=0.0, posinf=1e10, neginf=-1e10)
|
| 410 |
+
del_lpr_list.append(lpr_del_item)
|
| 411 |
del_lpr_batch = torch.stack(del_lpr_list)
|
| 412 |
|
| 413 |
+
gop_part = torch.cat(
|
| 414 |
+
[lpp_log_prob_batch.unsqueeze(1), sub_log_probs_batch, del_lpr_batch.unsqueeze(1)], dim=1
|
| 415 |
+
)
|
| 416 |
+
gop_part = self.gop_part_norm(gop_part)
|
| 417 |
+
gop_part = self.gop_part_dropout(gop_part)
|
| 418 |
combined_features = torch.cat([gop_part, current_token_embeddings], dim=1)
|
| 419 |
+
batch_combined_features_list.append(combined_features)
|
| 420 |
+
|
| 421 |
+
transformer_input = torch.stack(batch_combined_features_list, dim=1)
|
| 422 |
+
|
| 423 |
+
gop_rnn_output, _ = self.gop_rnn(transformer_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
|
| 425 |
+
final_logits = {}
|
| 426 |
+
for head, classifier in self.classifiers.items():
|
| 427 |
+
head_features = self.head_mlps[head](gop_rnn_output)
|
| 428 |
+
final_logits[head] = classifier(head_features)
|
|
|
|
| 429 |
|
|
|
|
| 430 |
loss = None
|
| 431 |
if labels is not None:
|
| 432 |
+
loss = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
for head, head_logits in final_logits.items():
|
| 434 |
+
head_labels = labels[head]
|
| 435 |
+
head_loss = self.loss_fns[head](head_logits, head_labels)
|
| 436 |
+
loss += head_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
if not return_dict:
|
| 439 |
output = (final_logits,)
|
models.py
CHANGED
|
@@ -4,19 +4,21 @@ import torch.nn as nn
|
|
| 4 |
|
| 5 |
class OrdinalLogLoss(nn.Module):
|
| 6 |
def __init__(
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
| 14 |
):
|
| 15 |
super(OrdinalLogLoss, self).__init__()
|
| 16 |
self.num_classes = num_classes
|
| 17 |
self.alpha = alpha
|
| 18 |
self.reduction = reduction
|
| 19 |
self.eps = eps
|
|
|
|
| 20 |
|
| 21 |
if distance_matrix is not None:
|
| 22 |
assert distance_matrix.shape == (num_classes, num_classes), \
|
|
@@ -35,21 +37,39 @@ class OrdinalLogLoss(nn.Module):
|
|
| 35 |
self.class_weights = None
|
| 36 |
|
| 37 |
def forward(self, logits, target):
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
if self.class_weights is not None:
|
| 45 |
-
sample_weights = self.class_weights[
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
# Apply reduction
|
| 49 |
if self.reduction == 'mean':
|
| 50 |
-
return
|
| 51 |
elif self.reduction == 'sum':
|
| 52 |
-
return
|
| 53 |
else:
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 4 |
|
| 5 |
class OrdinalLogLoss(nn.Module):
|
| 6 |
def __init__(
|
| 7 |
+
self,
|
| 8 |
+
num_classes,
|
| 9 |
+
alpha=1.0,
|
| 10 |
+
reduction='mean',
|
| 11 |
+
distance_matrix=None,
|
| 12 |
+
class_weights=None,
|
| 13 |
+
eps=1e-8,
|
| 14 |
+
ignore_index=-100,
|
| 15 |
):
|
| 16 |
super(OrdinalLogLoss, self).__init__()
|
| 17 |
self.num_classes = num_classes
|
| 18 |
self.alpha = alpha
|
| 19 |
self.reduction = reduction
|
| 20 |
self.eps = eps
|
| 21 |
+
self.ignore_index = ignore_index
|
| 22 |
|
| 23 |
if distance_matrix is not None:
|
| 24 |
assert distance_matrix.shape == (num_classes, num_classes), \
|
|
|
|
| 37 |
self.class_weights = None
|
| 38 |
|
| 39 |
def forward(self, logits, target):
|
| 40 |
+
if logits.numel() == 0:
|
| 41 |
+
return logits.new_tensor(0.0)
|
| 42 |
+
|
| 43 |
+
probs = torch.softmax(logits, dim=-1).clamp(max=1 - self.eps)
|
| 44 |
+
|
| 45 |
+
if self.ignore_index is not None:
|
| 46 |
+
valid_mask = target != self.ignore_index
|
| 47 |
+
else:
|
| 48 |
+
valid_mask = torch.ones_like(target, dtype=torch.bool)
|
| 49 |
+
|
| 50 |
+
if not valid_mask.any():
|
| 51 |
+
if self.reduction == 'none':
|
| 52 |
+
return logits.new_zeros(target.shape, dtype=logits.dtype)
|
| 53 |
+
return logits.new_tensor(0.0)
|
| 54 |
+
|
| 55 |
+
active_probs = probs[valid_mask]
|
| 56 |
+
active_target = target[valid_mask]
|
| 57 |
+
distances = self.distance_matrix[active_target] ** self.alpha
|
| 58 |
+
per_class_loss = -torch.log(1 - active_probs + self.eps)
|
| 59 |
+
loss_active = (per_class_loss * distances).sum(dim=-1)
|
| 60 |
+
|
| 61 |
if self.class_weights is not None:
|
| 62 |
+
sample_weights = self.class_weights[active_target]
|
| 63 |
+
loss_active = loss_active * sample_weights
|
| 64 |
+
|
| 65 |
+
if self.reduction == 'none':
|
| 66 |
+
full_loss = logits.new_zeros(target.shape, dtype=logits.dtype)
|
| 67 |
+
full_loss[valid_mask] = loss_active
|
| 68 |
+
return full_loss
|
| 69 |
|
|
|
|
| 70 |
if self.reduction == 'mean':
|
| 71 |
+
return loss_active.mean()
|
| 72 |
elif self.reduction == 'sum':
|
| 73 |
+
return loss_active.sum()
|
| 74 |
else:
|
| 75 |
+
raise ValueError(f"Unsupported reduction: {self.reduction}")
|
|
|
utils.py
CHANGED
|
@@ -1,24 +1,27 @@
|
|
| 1 |
-
from typing import List
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
import torchaudio
|
| 4 |
-
from transformers import
|
|
|
|
| 5 |
from constants import MAX_AUDIO_DURATION_SECONDS, MONO_CHANNEL, SAMPLING_RATE
|
| 6 |
from gop_model import GOPPhonemeClassifier
|
| 7 |
-
import logging
|
| 8 |
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
| 11 |
|
| 12 |
def load_model_and_processor(model_repo_id: str):
|
| 13 |
-
logger.info(
|
| 14 |
|
| 15 |
quantization_config = QuantoConfig(weights="int8")
|
| 16 |
-
logger.info("Applying INT8 dynamic quantization during model loading
|
| 17 |
|
| 18 |
model = GOPPhonemeClassifier.from_pretrained(
|
| 19 |
model_repo_id,
|
| 20 |
quantization_config=quantization_config,
|
| 21 |
-
device_map="auto"
|
| 22 |
)
|
| 23 |
processor = Wav2Vec2Processor.from_pretrained(model_repo_id)
|
| 24 |
model.eval()
|
|
@@ -36,6 +39,46 @@ def validate_phonemes(phoneme_text, allowed_phonemes):
|
|
| 36 |
return None
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassifier, processor: Wav2Vec2Processor):
|
| 40 |
if not audio_file_path or not transcript:
|
| 41 |
return "<p style='text-align:center; color:red;'>Please provide both an audio file and the transcript.</p>"
|
|
@@ -56,19 +99,14 @@ def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassi
|
|
| 56 |
|
| 57 |
audio_input = waveform.squeeze(0)
|
| 58 |
processed_audio = processor(audio_input, sampling_rate=SAMPLING_RATE, return_tensors="pt", padding=True)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
ids = [ids]
|
| 68 |
-
ids = [i if i is not None else unk_id for i in ids]
|
| 69 |
-
canonical_token_ids = torch.tensor([ids], dtype=torch.long).to(model.device)
|
| 70 |
-
token_lengths = torch.tensor([len(ids)], dtype=torch.long).to(model.device)
|
| 71 |
-
token_mask = torch.ones_like(canonical_token_ids).to(model.device)
|
| 72 |
|
| 73 |
with torch.no_grad():
|
| 74 |
outputs = model(
|
|
@@ -76,19 +114,11 @@ def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassi
|
|
| 76 |
attention_mask=attention_mask,
|
| 77 |
canonical_token_ids=canonical_token_ids,
|
| 78 |
token_lengths=token_lengths,
|
| 79 |
-
token_mask=token_mask
|
| 80 |
)
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
head_name = next(iter(logits))
|
| 85 |
-
scores_tensor = logits[head_name]
|
| 86 |
-
predicted_scores = torch.argmax(scores_tensor, dim=-1)
|
| 87 |
-
|
| 88 |
-
tokens = processor.tokenizer.convert_ids_to_tokens(canonical_token_ids[0])
|
| 89 |
-
|
| 90 |
-
return predicted_scores, tokens, token_lengths
|
| 91 |
|
| 92 |
-
except Exception as
|
| 93 |
-
logger.error(
|
| 94 |
-
return f"<p style='text-align:center; color:red;'>An error occurred: {
|
|
|
|
| 1 |
+
from typing import Dict, List, Tuple
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
import torch
|
| 6 |
import torchaudio
|
| 7 |
+
from transformers import QuantoConfig, Wav2Vec2Processor
|
| 8 |
+
|
| 9 |
from constants import MAX_AUDIO_DURATION_SECONDS, MONO_CHANNEL, SAMPLING_RATE
|
| 10 |
from gop_model import GOPPhonemeClassifier
|
|
|
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
|
| 15 |
def load_model_and_processor(model_repo_id: str):
|
| 16 |
+
logger.info("Loading model and processor from Hugging Face Hub: %s", model_repo_id)
|
| 17 |
|
| 18 |
quantization_config = QuantoConfig(weights="int8")
|
| 19 |
+
logger.info("Applying INT8 dynamic quantization during model loading")
|
| 20 |
|
| 21 |
model = GOPPhonemeClassifier.from_pretrained(
|
| 22 |
model_repo_id,
|
| 23 |
quantization_config=quantization_config,
|
| 24 |
+
device_map="auto",
|
| 25 |
)
|
| 26 |
processor = Wav2Vec2Processor.from_pretrained(model_repo_id)
|
| 27 |
model.eval()
|
|
|
|
| 39 |
return None
|
| 40 |
|
| 41 |
|
| 42 |
+
def _prepare_canonical_tokens(transcript: str, processor: Wav2Vec2Processor, device: torch.device):
    """Convert a space-separated phoneme transcript into model-ready tensors.

    ``"|"`` tokens (word boundaries) stay in the id sequence but are switched
    off in ``token_mask`` so they are never scored. Returns a 4-tuple
    ``(canonical_token_ids, token_lengths, token_mask, display_tokens)``,
    each tensor shaped for a batch of one.
    """
    phonemes: List[str] = transcript.strip().split()
    if not phonemes:
        raise ValueError("Please enter at least one phoneme.")

    # True for scoreable phonemes, False for "|" boundary markers.
    token_mask_values = [token != "|" for token in phonemes]
    if not any(token_mask_values):
        raise ValueError("The phoneme sequence must contain at least one non-boundary token.")

    vocab = processor.tokenizer
    fallback_id = getattr(vocab, "unk_token_id", None)
    raw_ids = vocab.convert_tokens_to_ids(phonemes)
    if isinstance(raw_ids, int):
        # A single token can come back as a bare int rather than a list.
        raw_ids = [raw_ids]
    resolved_ids = [fallback_id if tok_id is None else tok_id for tok_id in raw_ids]

    canonical_token_ids = torch.tensor([resolved_ids], dtype=torch.long, device=device)
    token_lengths = torch.tensor([len(resolved_ids)], dtype=torch.long, device=device)
    token_mask = torch.tensor([token_mask_values], dtype=torch.bool, device=device)

    # Only the mask-active phonemes are shown to the user.
    display_tokens = [tok for tok, keep in zip(phonemes, token_mask_values) if keep]
    return canonical_token_ids, token_lengths, token_mask, display_tokens
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _extract_head_predictions(
|
| 67 |
+
logits_by_head: Dict[str, torch.Tensor],
|
| 68 |
+
token_mask: torch.Tensor,
|
| 69 |
+
display_tokens: List[str],
|
| 70 |
+
) -> Dict[str, Tuple[List[int], List[str]]]:
|
| 71 |
+
active_mask = token_mask[0].bool()
|
| 72 |
+
results: Dict[str, Tuple[List[int], List[str]]] = {}
|
| 73 |
+
|
| 74 |
+
for head_name, head_logits in logits_by_head.items():
|
| 75 |
+
predicted_scores = torch.argmax(head_logits, dim=-1)[0]
|
| 76 |
+
filtered_scores = predicted_scores[active_mask].detach().cpu().tolist()
|
| 77 |
+
results[head_name] = (filtered_scores, display_tokens)
|
| 78 |
+
|
| 79 |
+
return results
|
| 80 |
+
|
| 81 |
+
|
| 82 |
def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassifier, processor: Wav2Vec2Processor):
|
| 83 |
if not audio_file_path or not transcript:
|
| 84 |
return "<p style='text-align:center; color:red;'>Please provide both an audio file and the transcript.</p>"
|
|
|
|
| 99 |
|
| 100 |
audio_input = waveform.squeeze(0)
|
| 101 |
processed_audio = processor(audio_input, sampling_rate=SAMPLING_RATE, return_tensors="pt", padding=True)
|
| 102 |
+
|
| 103 |
+
model_device = next(model.parameters()).device
|
| 104 |
+
input_values = processed_audio.input_values.to(model_device)
|
| 105 |
+
attention_mask = processed_audio.attention_mask.to(model_device)
|
| 106 |
+
|
| 107 |
+
canonical_token_ids, token_lengths, token_mask, display_tokens = _prepare_canonical_tokens(
|
| 108 |
+
transcript, processor, model_device
|
| 109 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
with torch.no_grad():
|
| 112 |
outputs = model(
|
|
|
|
| 114 |
attention_mask=attention_mask,
|
| 115 |
canonical_token_ids=canonical_token_ids,
|
| 116 |
token_lengths=token_lengths,
|
| 117 |
+
token_mask=token_mask,
|
| 118 |
)
|
| 119 |
|
| 120 |
+
return _extract_head_predictions(outputs.logits, token_mask, display_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
except Exception as exc:
|
| 123 |
+
logger.error("An error occurred during inference: %s", exc, exc_info=True)
|
| 124 |
+
return f"<p style='text-align:center; color:red;'>An error occurred: {exc}</p>"
|