Update handler.py
Browse files- handler.py +11 -9
handler.py
CHANGED
|
@@ -3,6 +3,7 @@ import re
|
|
| 3 |
from itertools import groupby
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from typing import Optional, Tuple, Union, Dict, List, Any
|
|
|
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
from transformers.modeling_outputs import ModelOutput
|
|
@@ -47,8 +48,11 @@ ONSETS = {
|
|
| 47 |
|
| 48 |
class SpeechToJyutpingPipeline(Pipeline):
|
| 49 |
def _sanitize_parameters(self, **kwargs):
|
|
|
|
|
|
|
|
|
|
| 50 |
self.tone_tokenizer = Wav2Vec2CTCTokenizer(
|
| 51 |
-
|
| 52 |
unk_token="[UNK]",
|
| 53 |
pad_token="[PAD]",
|
| 54 |
word_delimiter_token="|",
|
|
@@ -95,7 +99,6 @@ class SpeechToJyutpingPipeline(Pipeline):
|
|
| 95 |
|
| 96 |
sample_rate = 16000
|
| 97 |
symbols = [w for w in transcription.split(" ") if len(w) > 0]
|
| 98 |
-
duration_sec = model_outputs["duration"] / sample_rate
|
| 99 |
|
| 100 |
ids_w_index = [(i, _id.item()) for i, _id in enumerate(predicted_ids[0])]
|
| 101 |
# remove entries which are just "padding" (i.e. no characers are recognized)
|
|
@@ -151,7 +154,7 @@ class SpeechToJyutpingPipeline(Pipeline):
|
|
| 151 |
transcription = re.sub(
|
| 152 |
r"\s+", " ", "".join(transcription).replace("_", " ").strip()
|
| 153 |
)
|
| 154 |
-
tone_probs = torch.stack(tone_probs).cpu().
|
| 155 |
|
| 156 |
return {"transcription": transcription, "tone_probs": tone_probs}
|
| 157 |
|
|
@@ -388,15 +391,14 @@ class Wav2Vec2BertForCantonese(Wav2Vec2BertPreTrainedModel):
|
|
| 388 |
|
| 389 |
|
| 390 |
class EndpointHandler:
|
| 391 |
-
def __init__(self, path="
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
)
|
| 396 |
|
| 397 |
self.pipeline = pipeline(
|
| 398 |
task="speech-to-jyutping",
|
| 399 |
-
model=Wav2Vec2BertForCantonese.from_pretrained(
|
| 400 |
feature_extractor=feature_extractor,
|
| 401 |
tokenizer=tokenizer,
|
| 402 |
)
|
|
|
|
| 3 |
from itertools import groupby
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from typing import Optional, Tuple, Union, Dict, List, Any
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
from transformers.modeling_outputs import ModelOutput
|
|
|
|
| 48 |
|
| 49 |
class SpeechToJyutpingPipeline(Pipeline):
|
| 50 |
def _sanitize_parameters(self, **kwargs):
|
| 51 |
+
tone_vocab_file = hf_hub_download(
|
| 52 |
+
repo_id="hon9kon9ize/wav2vec2bert-jyutping", filename="tone_vocab.json"
|
| 53 |
+
)
|
| 54 |
self.tone_tokenizer = Wav2Vec2CTCTokenizer(
|
| 55 |
+
tone_vocab_file,
|
| 56 |
unk_token="[UNK]",
|
| 57 |
pad_token="[PAD]",
|
| 58 |
word_delimiter_token="|",
|
|
|
|
| 99 |
|
| 100 |
sample_rate = 16000
|
| 101 |
symbols = [w for w in transcription.split(" ") if len(w) > 0]
|
|
|
|
| 102 |
|
| 103 |
ids_w_index = [(i, _id.item()) for i, _id in enumerate(predicted_ids[0])]
|
| 104 |
# remove entries which are just "padding" (i.e. no characers are recognized)
|
|
|
|
| 154 |
transcription = re.sub(
|
| 155 |
r"\s+", " ", "".join(transcription).replace("_", " ").strip()
|
| 156 |
)
|
| 157 |
+
tone_probs = torch.stack(tone_probs).cpu().tolist()
|
| 158 |
|
| 159 |
return {"transcription": transcription, "tone_probs": tone_probs}
|
| 160 |
|
|
|
|
| 391 |
|
| 392 |
|
| 393 |
class EndpointHandler:
|
| 394 |
+
def __init__(self, path="."):
|
| 395 |
+
model_path = "hon9kon9ize/wav2vec2bert-jyutping"
|
| 396 |
+
feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(model_path)
|
| 397 |
+
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_path)
|
|
|
|
| 398 |
|
| 399 |
self.pipeline = pipeline(
|
| 400 |
task="speech-to-jyutping",
|
| 401 |
+
model=Wav2Vec2BertForCantonese.from_pretrained(model_path),
|
| 402 |
feature_extractor=feature_extractor,
|
| 403 |
tokenizer=tokenizer,
|
| 404 |
)
|