import streamlit as st
import time

import torch
from transformers import BartForConditionalGeneration, DebertaV2Tokenizer
from huggingface_hub import Repository

# Clone (or update) the model repository into ./scripts
repo = Repository(
    local_dir="scripts",
    repo_type="model",
    clone_from="AILabTUL/APCR_BART",
    token=True
)
repo.git_pull()

# Page configuration
st.set_page_config(page_title="Text Punctuation and Capitalization Restoration", layout="wide")

# Load and cache the model and tokenizer
@st.cache_resource
def load_model():
    tokenizer = DebertaV2Tokenizer.from_pretrained("./scripts")
    model = BartForConditionalGeneration.from_pretrained("./scripts")
    model.load_state_dict(torch.load("./scripts/pytorch_model.bin", map_location=torch.device("cpu")))
    model.eval()  # Switch the model to evaluation mode
    encoder = model.get_encoder()
    decoder = model.model.decoder
    lm_head = model.lm_head
    start_token = model.config.decoder_start_token_id
    return encoder, decoder, lm_head, tokenizer, start_token

encoder, decoder, lm_head, tokenizer, start_token = load_model()

# Title
st.title("Czech punctuation and capitalization restoration (BART-Small)")

# Input form
with st.form(key="input_form"):
    input_text = st.text_area(
        "Insert text without punctuation and capitalization:",
        value="Rachejtle létající nad stájí se závodními koňmi v Bílově na severním Plzeňsku popudily profesionální drezurní jezdkyni a trenérku Adélu Neumannovou.",
        height=100
    )
    # Strip punctuation, collapse newlines and lowercase the input text
    input_text = input_text.replace("\n", " ").replace(".", " ").replace(",", " ").replace("?", " ").replace("!", " ").lower()
    submit_button = st.form_submit_button(label="Generate")

if submit_button:
    if not input_text.strip():
        st.error("Please insert some text.")
    else:
        # Tokenization
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        eos_token_id = 32001  # EOS token id of the custom tokenizer
        max_length = 50

        # Start the generated sequence with the decoder start token
        generated_ids = torch.tensor([[start_token]])
        output_placeholder = st.empty()
        generation_start = time.time()
        tokens_count = 0

        # Encode the input once; the decoder attends to it at every step
        encoder_outputs = encoder(input_ids=input_ids, return_dict=True)

        for _ in range(max_length):
            # Forward pass through the decoder
            outputs = decoder(
                input_ids=generated_ids,
                encoder_hidden_states=encoder_outputs.last_hidden_state,
                return_dict=True
            )
            logits = lm_head(outputs.last_hidden_state)

            # Logits of the last position only
            next_token_logits = logits[:, -1, :]

            # Greedy decoding: pick the most likely next token (argmax)
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

            # Append the token to the generated sequence
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

            # Stop when the EOS token is reached
            if next_token_id.item() == eos_token_id:
                break

            # Decode the tokens to text and stream the partial output
            generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            output_placeholder.text(generated_text)
            tokens_count += 1

        duration = time.time() - generation_start
        tokens_per_second = tokens_count / duration
        st.success(f"Generation completed! --- Generation speed: {tokens_per_second:.2f} tokens/s")