# Hugging Face Spaces page status when captured: Runtime error
| import streamlit as st | |
| from transformers import T5Tokenizer, AutoModelForCausalLM | |
@st.cache_resource
def cached_tokenizer():
    """Load the rinna/japanese-gpt2-medium tokenizer once and reuse it.

    Streamlit re-executes the whole script on every widget interaction.
    Without a cache decorator, this function (despite its name) reloaded the
    tokenizer on every button press. ``st.cache_resource`` (Streamlit >= 1.18)
    keeps a single shared instance for the server process.

    Returns:
        The ``T5Tokenizer`` for ``rinna/japanese-gpt2-medium`` with
        ``do_lower_case`` enabled.
    """
    tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
    # Set by hand, presumably because the flag is not picked up from the
    # hub tokenizer config -- confirm against the rinna model card.
    tokenizer.do_lower_case = True
    return tokenizer
@st.cache_resource
def cached_model():
    """Load the rinna/japanese-gpt2-medium causal LM once and reuse it.

    Without a cache decorator, Streamlit's rerun-on-interaction model made
    this function (despite its name) reload the full model weights on every
    button press. ``st.cache_resource`` (Streamlit >= 1.18) keeps one shared
    instance for the server process.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained`` for
        ``rinna/japanese-gpt2-medium``.
    """
    return AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
def main():
    """Render the Streamlit UI and generate Japanese text on demand.

    The page shows:
      * a slider for how many sequences to sample (1-2);
      * a slider for the length budget (30-200). It is passed to the model as
        a token count, so it only approximates characters despite its label;
      * a text area holding the prompt.

    When the button is pressed, the tokenizer and model are obtained from
    ``cached_tokenizer``/``cached_model``. The prompt is continued by
    sampling, and each generated sequence is written to the page while a
    progress bar and status line track the steps.
    """
    st.title("GPT-2による日本語の文章生成")
    num_of_output_text = st.slider(
        label="出力する文章の数",
        min_value=1,
        max_value=2,
        value=1,
    )
    length_of_output_text = st.slider(
        label="出力する文字数",
        min_value=30,
        max_value=200,
        value=100,
    )
    prefix_text = st.text_area(
        label="テキスト入力",
        value="吾輩は猫である",
    )

    status_text = st.empty()
    progress_bar = st.progress(0)

    def set_progress(percent):
        """Update both the textual status line and the progress bar."""
        status_text.text(f"Progress: {percent}%")
        progress_bar.progress(percent)

    if not st.button("文章生成"):
        return

    # Without special tokens, an empty prompt encodes to zero tokens and
    # model.generate would fail; ask the user for input instead.
    if not prefix_text.strip():
        st.warning("テキストを入力してください")
        return

    st.text("読み込みに時間がかかります")
    set_progress(10)
    tokenizer = cached_tokenizer()
    set_progress(25)
    model = cached_model()
    set_progress(40)

    # Inference.
    # T5Tokenizer appends "</s>" (EOS) to encoded text by default. That would
    # make the model continue *after* an end-of-sequence marker instead of
    # after the prompt, so special tokens are not added here.
    input_ids = tokenizer.encode(
        prefix_text, add_special_tokens=False, return_tensors="pt"
    )
    set_progress(60)
    # max_length counts the prompt tokens as well. Adding the prompt length
    # makes the slider value the budget for *new* tokens, so a long prompt
    # cannot leave no room for a continuation.
    output = model.generate(
        input_ids,
        do_sample=True,
        max_length=input_ids.shape[-1] + length_of_output_text,
        num_return_sequences=num_of_output_text,
    )
    set_progress(90)
    # skip_special_tokens drops </s>, <unk> and the padding added to shorter
    # samples. The old manual strip used the wrong "</unk>" spelling and
    # missed padding entirely.
    output_texts = tokenizer.batch_decode(output, skip_special_tokens=True)
    set_progress(95)

    st.info("生成結果")
    # Show each sampled sequence separately instead of running them together.
    for text in output_texts:
        st.write(text)
    set_progress(100)
if __name__ == "__main__":
    # Build the UI only when this file is executed as the main script,
    # not when it is imported as a module.
    main()