# NOTE: removed Hugging Face file-viewer scrape residue (file-size line,
# commit hashes, and line-number gutter) that was not part of the source.
import streamlit as st
from transformers import BartForConditionalGeneration, DebertaV2Tokenizer
import torch
import time
from huggingface_hub import Repository

# Clone the model repository into ./scripts (or reuse an existing clone)
# and pull the latest revision so load_model() reads current weights.
# token=True presumably picks up the locally cached HF credentials — verify.
# NOTE(review): huggingface_hub.Repository is deprecated in recent
# huggingface_hub releases; consider migrating to snapshot_download/HfApi.
repo = Repository(
    local_dir="scripts",
    repo_type="model",
    clone_from="AILabTUL/APCR_BART",
    token=True
)
repo.git_pull()



# Page setup
st.set_page_config(page_title="Text Punctuation and Capitalization Restoration", layout="wide")

# Load the model and tokenizer once; Streamlit caches the resource across reruns.
@st.cache_resource
def load_model():
    """Load the tokenizer and BART model from ./scripts and return the
    components needed for manual step-by-step decoding:
    (encoder, decoder, lm_head, tokenizer, decoder_start_token_id)."""
    tok = DebertaV2Tokenizer.from_pretrained("./scripts")
    bart = BartForConditionalGeneration.from_pretrained('./scripts')
    # Explicitly (re-)apply the checkpoint weights, mapped onto the CPU.
    # NOTE(review): torch.load unpickles arbitrary objects — the checkpoint
    # must come from a trusted source; from_pretrained may already have
    # loaded these same weights — confirm whether this step is redundant.
    state = torch.load("./scripts/pytorch_model.bin", map_location=torch.device('cpu'))
    bart.load_state_dict(state)
    # Switch modules to inference mode (e.g. disables dropout).
    bart.eval()
    return (
        bart.get_encoder(),
        bart.model.decoder,
        bart.lm_head,
        tok,
        bart.config.decoder_start_token_id,
    )

encoder, decoder, lm_head, tokenizer, start_token = load_model()

# Page heading
st.title("Czech punctuation and capitalization restoration (BART-Small)")

# Input form: one text area plus a submit button.
with st.form(key='input_form'):
    input_text = st.text_area("Insert text without punctuation and capitalization:", 
                              value="Rachejtle létající nad stájí se závodními koňmi v Bílově na severním Plzeňsku popudily profesionální drezurní jezdkyni a trenérku Adélu Neumannovou.",
                              height=100)

    # Normalize the input: map newlines and sentence punctuation to spaces
    # in a single pass, then lowercase everything.
    _strip_table = str.maketrans({"\n": " ", ".": " ", ",": " ", "?": " ", "!": " "})
    input_text = input_text.translate(_strip_table).lower()

    submit_button = st.form_submit_button(label='Generate')


if submit_button:
    if not input_text.strip():
        st.error("Please, insert some text.")
    else:
        # Tokenize the normalized input for the encoder.
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids

        # NOTE(review): hard-coded EOS id — presumably matches this
        # checkpoint's vocabulary; confirm against the tokenizer/model config.
        eos_token_id = 32001
        max_length = 50

        # Decoder input starts from the configured decoder-start token.
        generated_ids = torch.tensor([[start_token]])

        output_placeholder = st.empty()

        generation_start = time.time()
        tokens_count = 0

        # Inference only: disable autograd so no computation graph is built.
        # (The original ran with gradients enabled, wasting memory and time.)
        with torch.no_grad():
            # Encode the source once; reused for every decoding step.
            encoder_outputs = encoder(input_ids=input_ids, return_dict=True)

            # Greedy token-by-token decoding with live output streaming.
            for _ in range(max_length):
                # Forward pass through the decoder over the full prefix.
                outputs = decoder(
                    input_ids=generated_ids,
                    encoder_hidden_states=encoder_outputs.last_hidden_state,
                    return_dict=True
                )
                logits = lm_head(outputs.last_hidden_state)

                # Greedy choice: argmax over the last position's logits.
                next_token_logits = logits[:, -1, :]
                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

                # Append the chosen token to the generated sequence.
                generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

                # Stop on EOS; the EOS token itself is never rendered.
                if next_token_id.item() == eos_token_id:
                    break

                # Stream the partial hypothesis to the UI.
                generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                output_placeholder.text(generated_text)
                tokens_count += 1

        duration = time.time() - generation_start
        # Guard against a zero duration (degenerate timer reading) — the
        # original could raise ZeroDivisionError here.
        tokens_per_second = tokens_count / duration if duration > 0 else 0.0

        st.success(f"Generation completed! --- Generation speed: {tokens_per_second:.2f} tokens/s")