|
|
import streamlit as st |
|
|
from transformers import BartForConditionalGeneration, DebertaV2Tokenizer |
|
|
import torch |
|
|
import time |
|
|
from huggingface_hub import Repository |
|
|
|
|
|
# Clone (or update) the fine-tuned APCR BART checkpoint from the Hugging Face
# Hub into ./scripts so the model files are available locally.
# NOTE(review): `Repository` is deprecated in recent huggingface_hub releases —
# consider `snapshot_download` when upgrading; confirm against the pinned version.
model_repo = Repository(
    local_dir="scripts",
    clone_from="AILabTUL/APCR_BART",
    repo_type="model",
    token=True,
)
model_repo.git_pull()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page setup must run before any other Streamlit call that renders output.
st.set_page_config(page_title="Text Punctuation and Capitalization Restoration", layout="wide")
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the fine-tuned BART model and tokenizer from ./scripts.

    Cached by Streamlit so the checkpoint is read only once per process.

    Returns:
        tuple: (encoder, decoder, lm_head, tokenizer, decoder_start_token_id)
        — the pieces needed to run greedy decoding step by step in the app.
    """
    sp_tokenizer = DebertaV2Tokenizer.from_pretrained("./scripts")
    bart = BartForConditionalGeneration.from_pretrained('./scripts')
    # Explicitly reload the checkpoint weights on CPU.
    # NOTE(review): from_pretrained should already have loaded these weights;
    # presumably this override is deliberate — confirm the files differ.
    checkpoint = torch.load("./scripts/pytorch_model.bin", map_location=torch.device('cpu'))
    bart.load_state_dict(checkpoint)
    bart.eval()
    return (
        bart.get_encoder(),
        bart.model.decoder,
        bart.lm_head,
        sp_tokenizer,
        bart.config.decoder_start_token_id,
    )
|
|
|
|
|
# Unpack the cached model components once at startup; these names are used
# by the generation loop below.
encoder, decoder, lm_head, tokenizer, start_token = load_model()


st.title("Czech punctuation and capitalization restoration (BART-Small)")
|
|
|
|
|
|
|
|
# Input form: collect raw text and normalize it into the format the model
# was trained on (no punctuation, all lowercase).
with st.form(key='input_form'):
    input_text = st.text_area(
        "Insert text without punctuation and capitalization:",
        value="Rachejtle létající nad stájí se závodními koňmi v Bílově na severním Plzeňsku popudily profesionální drezurní jezdkyni a trenérku Adélu Neumannovou.",
        height=100,
    )

    # Replace newlines and sentence punctuation with spaces, then lowercase.
    for mark in ("\n", ".", ",", "?", "!"):
        input_text = input_text.replace(mark, " ")
    input_text = input_text.lower()

    submit_button = st.form_submit_button(label='Generate')
|
|
|
|
|
|
|
|
if submit_button:
    if not input_text.strip():
        st.error("Please, insert some text.")
    else:
        # Tokenize the normalized input once; the encoder states are reused
        # for every decoding step.
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids

        # NOTE(review): hard-coded EOS id — presumably a custom token added
        # during fine-tuning; confirm it matches tokenizer.eos_token_id.
        eos_token_id = 32001
        max_length = 50

        # Greedy decoding starts from the model's decoder start token.
        generated_ids = torch.tensor([[start_token]])

        # Placeholder so the partial output can be streamed token by token.
        output_placeholder = st.empty()

        generation_start = time.time()
        tokens_count = 0

        # Inference only: disable autograd to avoid building a graph on every
        # forward pass (saves memory and time; outputs are unchanged).
        with torch.no_grad():
            encoder_outputs = encoder(input_ids=input_ids, return_dict=True)

            for _ in range(max_length):
                # Re-run the decoder over the full prefix each step (no KV
                # cache) and read the logits at the last position.
                outputs = decoder(
                    input_ids=generated_ids,
                    encoder_hidden_states=encoder_outputs.last_hidden_state,
                    return_dict=True
                )
                logits = lm_head(outputs.last_hidden_state)

                next_token_logits = logits[:, -1, :]

                # Greedy choice: argmax token, reshaped to (1, 1) so it can
                # be appended along the sequence dimension.
                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

                generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

                if next_token_id.item() == eos_token_id:
                    break

                # Stream the partial decoded text to the UI after every token.
                generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                output_placeholder.text(generated_text)
                tokens_count += 1

        duration = time.time() - generation_start
        # Guard against a zero-measured duration (coarse clock / instant EOS),
        # which would otherwise raise ZeroDivisionError.
        tokens_per_second = tokens_count / duration if duration > 0 else 0.0

        st.success(f"Generation completed! --- Generation speed: {tokens_per_second:.2f} tokens/s")