mpolacek commited on
Commit
83479b5
·
verified ·
1 Parent(s): e51fb3b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +84 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import BartForConditionalGeneration, DebertaV2Tokenizer
3
+ import torch
4
+ import time
5
+ from huggingface_hub import Repository
6
+
7
+ repo = Repository(
8
+ local_dir="scripts",
9
+ repo_type="model",
10
+ clone_from="AILabTUL/APCR_BART",
11
+ token=True
12
+ )
13
+ repo.git_pull()
14
+
15
+
16
+
17
+ # Nastavení stránky
18
+ st.set_page_config(page_title="Text Punctuation and Capitalization Restoration", layout="wide")
19
+
20
+ # Načtení modelu a tokenizeru do cache
21
+ @st.cache_resource
22
+ def load_model():
23
+ tokenizer = DebertaV2Tokenizer.from_pretrained("./scripts")
24
+ model = BartForConditionalGeneration.from_pretrained('./scripts')
25
+ model.load_state_dict(torch.load("./scripts/pytorch_model.bin", map_location=torch.device('cpu')))
26
+ model.eval() # Přepnutí modelu do eval režimu
27
+ return model, tokenizer
28
+
29
+ model, tokenizer = load_model()
30
+
31
+ # Titulek aplikace
32
+ st.title("Obnova interpunkce a velkých písmen v textu")
33
+
34
+ # Vstupní formulář pro uživatele
35
+ with st.form(key='input_form'):
36
+ input_text = st.text_area("Zadejte text bez interpunkce a velkých písmen:",
37
+ value="Co jde podat Sněmovny už je Sněmovně Ve zrychleném čtení chceme schválit změnu zákoníku práce která by měla platit od 1. ledna",
38
+ height=150)
39
+ submit_button = st.form_submit_button(label='Generovat')
40
+
41
+ input_text = input_text.replace("\n", " ").replace(".", " ").replace(",", " ").replace("?", " ").replace("!", " ").lower()
42
+
43
+ if submit_button:
44
+ if not input_text.strip():
45
+ st.error("Prosím, zadejte nějaký text.")
46
+ else:
47
+ # Tokenizace vstupního textu
48
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids
49
+
50
+ eos_token_id = 32001
51
+ max_length = 50
52
+ generated_ids = torch.tensor([[model.config.decoder_start_token_id]])
53
+
54
+ output_placeholder = st.empty()
55
+
56
+ for _ in range(max_length):
57
+ # Forward průchod
58
+ outputs = model(
59
+ input_ids=input_ids,
60
+ decoder_input_ids=generated_ids
61
+ )
62
+
63
+ # Extrakce logits posledního tokenu
64
+ next_token_logits = outputs.logits[:, -1, :]
65
+
66
+ # Sampling nebo argmax pro výběr dalšího tokenu
67
+ next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
68
+
69
+ # Přidání tokenu do generované sekvence
70
+ generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
71
+
72
+ # Ukončení generace při dosažení EOS tokenu
73
+ if next_token_id.item() == eos_token_id:
74
+ break
75
+
76
+ # Malá prodleva pro viditelné generování (můžete upravit podle potřeby)
77
+ #time.sleep(0.3)
78
+
79
+ # Tokeny na text
80
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
81
+ output_placeholder.text(generated_text)
82
+
83
+ st.success("Generování dokončeno!")
84
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ tqdm
3
+ numpy
4
+ sentencepiece
5
+ transformers
6
+ scikit-learn
7
+ huggingface_hub