# APCR_BART / app.py
# (Hugging Face Space page header preserved as a comment:
#  "mpolacek's picture — Update app.py — commit 0d38908, verified")
import streamlit as st
from transformers import BartForConditionalGeneration, DebertaV2Tokenizer
import torch
import time
from huggingface_hub import snapshot_download

# Download (or refresh) the model repository into ./scripts so the tokenizer
# and model weights can be loaded from a local path below.
# `snapshot_download` replaces the deprecated git-based `Repository` workflow
# and only re-fetches files that changed upstream.
snapshot_download(
    repo_id="AILabTUL/APCR_BART",
    repo_type="model",
    local_dir="scripts",
    token=True,  # use the locally stored Hugging Face auth token
)

# Page setup
st.set_page_config(
    page_title="Text Punctuation and Capitalization Restoration",
    layout="wide",
)
# Load the model and tokenizer once and cache them across Streamlit reruns.
@st.cache_resource
def load_model():
    """Load the BART model + tokenizer from ./scripts for manual decoding.

    Returns:
        tuple: (encoder, decoder, lm_head, tokenizer, start_token) where
        ``start_token`` is ``model.config.decoder_start_token_id``, used to
        seed the decoder input in the generation loop below.
    """
    tokenizer = DebertaV2Tokenizer.from_pretrained("./scripts")
    model = BartForConditionalGeneration.from_pretrained("./scripts")
    # Re-load the raw checkpoint explicitly. `weights_only=True` prevents
    # torch.load from executing arbitrary pickled code in the file.
    # NOTE(review): from_pretrained above should already have loaded these
    # weights, so this step looks redundant — confirm the .bin file is not a
    # different (e.g. fine-tuned) checkpoint before removing it.
    state_dict = torch.load(
        "./scripts/pytorch_model.bin",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()  # eval mode: disables dropout etc. for inference
    encoder = model.get_encoder()
    decoder = model.model.decoder
    lm_head = model.lm_head
    start_token = model.config.decoder_start_token_id
    return encoder, decoder, lm_head, tokenizer, start_token


encoder, decoder, lm_head, tokenizer, start_token = load_model()
# Title
st.title("Czech punctuation and capitalization restoration (BART-Small)")


def _normalize_input(text):
    """Strip punctuation and casing so the input matches the model's training data.

    Newlines and the punctuation marks ``. , ? !`` are each replaced by a
    space, and the whole string is lower-cased.
    """
    return (
        text.replace("\n", " ")
        .replace(".", " ")
        .replace(",", " ")
        .replace("?", " ")
        .replace("!", " ")
        .lower()
    )


# Input form
with st.form(key='input_form'):
    input_text = st.text_area(
        "Insert text without punctuation and capitalization:",
        value="Rachejtle létající nad stájí se závodními koňmi v Bílově na severním Plzeňsku popudily profesionální drezurní jezdkyni a trenérku Adélu Neumannovou.",
        height=100,
    )
    # Formatting: remove punctuation/casing before feeding the model
    input_text = _normalize_input(input_text)
    submit_button = st.form_submit_button(label='Generate')
if submit_button:
    if not input_text.strip():
        st.error("Please, insert some text.")
    else:
        # Tokenization
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        # NOTE(review): hard-coded EOS id — confirm it matches
        # tokenizer.eos_token_id for this checkpoint before changing.
        eos_token_id = 32001
        max_length = 50
        # Decoder input starts with the decoder_start_token only.
        generated_ids = torch.tensor([[start_token]])
        output_placeholder = st.empty()
        generation_start = time.time()
        tokens_count = 0
        # Greedy token-by-token decoding so partial output can be streamed to
        # the UI. no_grad avoids building autograd graphs during inference.
        with torch.no_grad():
            # Encoder runs once; its hidden states are reused every step.
            encoder_outputs = encoder(input_ids=input_ids, return_dict=True)
            for _ in range(max_length):
                # Forward pass: decoder with cross-attention to the encoder
                outputs = decoder(
                    input_ids=generated_ids,
                    encoder_hidden_states=encoder_outputs.last_hidden_state,
                    return_dict=True,
                )
                logits = lm_head(outputs.last_hidden_state)
                # Logits for the last position only
                next_token_logits = logits[:, -1, :]
                # Greedy (argmax) selection of the next token
                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
                # Append the chosen token to the generated sequence
                generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
                # End generation when the EOS token is reached
                if next_token_id.item() == eos_token_id:
                    break
                # Decode and stream the partial result to the page
                generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                output_placeholder.text(generated_text)
                tokens_count += 1
        # Final decode so the complete text is shown even when the loop
        # exited on EOS right after the last streaming update.
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        output_placeholder.text(generated_text)
        duration = time.time() - generation_start
        # Guard against a near-zero duration (e.g. immediate EOS).
        tokens_per_second = tokens_count / duration if duration > 0 else 0.0
        st.success(f"Generation completed! --- Generation speed: {tokens_per_second:.2f} tokens/s")