|
|
from functools import lru_cache
from typing import List

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from sentencepiece import SentencePieceProcessor
|
|
|
|
|
@lru_cache(maxsize=1)
def _load_model_resources():
    """Load and cache the tokenizer, ONNX session, and config.

    Loading the SentencePiece model and building an ONNX InferenceSession are
    expensive; caching lets repeated process_text() calls reuse them instead
    of re-reading all three files from disk on every invocation.
    """
    tokenizer = SentencePieceProcessor("sp.model")
    ort_session = ort.InferenceSession("model.onnx")
    config = OmegaConf.load("config.yaml")
    return tokenizer, ort_session, config


def process_text(input_text: str) -> str:
    """Restore punctuation, capitalization, and sentence breaks in text.

    Tokenizes ``input_text`` with SentencePiece, runs an ONNX model that
    emits four prediction heads (pre-token punctuation, post-token
    punctuation, per-character capitalization, sentence-boundary detection),
    and rebuilds the text from those predictions.

    Args:
        input_text: Raw input text to punctuate/truecase.

    Returns:
        The processed text with one detected sentence per line.
    """
    tokenizer, ort_session, config = _load_model_resources()

    pre_labels: List[str] = config.pre_labels
    post_labels: List[str] = config.post_labels
    null_token = config.get("null_token", "<NULL>")
    acronym_token = config.get("acronym_token", "<ACRONYM>")

    # Wrap the encoded ids with BOS/EOS, as the model expects.
    input_ids = [tokenizer.bos_id()] + tokenizer.EncodeAsIds(input_text) + [tokenizer.eos_id()]
    # Batch dimension of 1; np.ndarray is the correct annotation
    # (np.array is a factory function, not a type).
    input_ids_arr: np.ndarray = np.array([input_ids])

    # Four heads: pre-punct, post-punct, capitalization, sentence boundary.
    pre_preds, post_preds, cap_preds, sbd_preds = ort_session.run(None, {"input_ids": input_ids_arr})

    # Drop the batch dimension.
    pre_preds = pre_preds[0].tolist()
    post_preds = post_preds[0].tolist()
    cap_preds = cap_preds[0].tolist()
    sbd_preds = sbd_preds[0].tolist()

    output_texts: List[str] = []
    current_chars: List[str] = []
    # Iterate real tokens only: skip BOS (index 0) and EOS (last index).
    for token_idx in range(1, len(input_ids) - 1):
        token = tokenizer.IdToPiece(input_ids[token_idx])
        # "▁" marks the start of a new word in SentencePiece; emit a space
        # before it unless we're at the start of a sentence.
        if token.startswith("▁") and current_chars:
            current_chars.append(" ")

        pre_label = pre_labels[pre_preds[token_idx]]
        post_label = post_labels[post_preds[token_idx]]

        # Punctuation predicted before the token (e.g. "¿").
        if pre_label != null_token:
            current_chars.append(pre_label)

        # Skip the leading "▁" marker when emitting the token's characters;
        # cap_preds is indexed by position within the raw piece, so keep
        # the original offsets via enumerate(start=char_start).
        char_start = 1 if token.startswith("▁") else 0
        for token_char_idx, char in enumerate(token[char_start:], start=char_start):
            if cap_preds[token_idx][token_char_idx]:
                char = char.upper()
            current_chars.append(char)

        # Punctuation predicted after the token; the acronym label means
        # "append a period" rather than a literal label string.
        if post_label == acronym_token:
            current_chars.append(".")
        elif post_label != null_token:
            current_chars.append(post_label)

        # Sentence boundary: flush the accumulated sentence.
        if sbd_preds[token_idx]:
            output_texts.append("".join(current_chars))
            current_chars.clear()

    # Flush any trailing text not terminated by a sentence boundary.
    output_texts.append("".join(current_chars))

    return "\n".join(output_texts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|