Spaces:
Runtime error
Runtime error
| import gradio as gr | |
def update(name):
    """Return a greeting for *name* (original placeholder callback, kept for callers)."""
    return f"Welcome to Gradio, {name}!"


def build_info(topic, user_addr, user_age, user_sex, bot_addr, bot_age, bot_sex):
    """Format the form fields as the conversation-info prompt prefix.

    Produces the same layout as the terminal chat's prefix:
    ``일상 대화 <topic><sep>P01:<region> <age> <sex><sep>P02:<region> <age> <sex><sep>``
    where P01 is the user and P02 is the chatbot.

    Args:
        topic: Conversation topic.
        user_addr, user_age, user_sex: The user's (P01) region, age band, sex.
        bot_addr, bot_age, bot_sex: The chatbot's (P02) region, age band, sex.

    Returns:
        The formatted prefix string.
    """
    return (
        f"일상 대화 {topic}<sep>"
        f"P01:{user_addr} {user_age} {user_sex}<sep>"
        f"P02:{bot_addr} {bot_age} {bot_sex}<sep>"
    )


demo = gr.Blocks()
with demo:
    gr.Markdown("각 질문에 대답 후 Enter 해주세요.\n\n")
    with gr.Row():
        topic = gr.Textbox(
            label="Topic",
            placeholder="대화 주제를 정해주세요 (e.g. 여가 생활, 일과 직업, 개인 및 관계, etc...)",
        )
    with gr.Row():
        # P01: the user's own profile. Distinct variable names are required;
        # reusing one name for both columns made this column unreachable.
        with gr.Column():
            gr.Markdown("당신에 대해 알려주세요.")
            user_addr = gr.Textbox(label="지역", placeholder="e.g. 서울특별시, 제주도, etc...")
            user_age = gr.Textbox(label="나이", placeholder="e.g. 20대 미만, 40대, 70대 이상, etc...")
            user_sex = gr.Textbox(label="성별", placeholder="e.g. 남성, 여성, etc...")
        # P02: the chatbot's persona.
        with gr.Column():
            gr.Markdown("챗봇에 대해 알려주세요.")
            bot_addr = gr.Textbox(label="지역", placeholder="e.g. 서울특별시, 제주도, etc...")
            bot_age = gr.Textbox(label="나이", placeholder="e.g. 20대 미만, 40대, 70대 이상, etc...")
            bot_sex = gr.Textbox(label="성별", placeholder="e.g. 남성, 여성, etc...")
    out = gr.Textbox()
    btn = gr.Button("Run")
    # The original passed the undefined name `inp` here (NameError at import).
    btn.click(
        fn=build_info,
        inputs=[topic, user_addr, user_age, user_sex, bot_addr, bot_age, bot_sex],
        outputs=out,
    )
demo.launch()
# Tokens added on top of the base KoGPT vocabulary: structural markers of the
# prompt format ('<sep>', '<sos>', '<eos>') plus '#@...#' tags, presumably the
# de-identification placeholders of the fine-tuning corpus (TODO confirm).
_SPECIAL_TOKENS = [
    '<sep>', '<eos>', '<sos>',
    '#@이름#', '#@계정#', '#@신원#', '#@전번#', '#@금융#',
    '#@번호#', '#@주소#', '#@소속#', '#@기타#',
]

# Candidate values drawn at random when the user leaves a question blank.
_TOPICS = (
    '여가 생활', '시사/교육', '미용과 건강', '식음료', '상거래(쇼핑)',
    '일과 직업', '주거와 생활', '개인 및 관계', '행사',
)
_REGIONS = (
    '서울특별시', '경기도', '부산광역시', '대전광역시', '광주광역시', '울산광역시',
    '경상남도', '인천광역시', '충청북도', '제주도', '강원도', '충청남도', '전라북도',
    '대구광역시', '전라남도', '경상북도', '세종특별자치시', '기타',
)
_AGES = ('20대', '30대', '50대', '20대 미만', '60대', '40대', '70대 이상')
_SEXES = ('남성', '여성')


def _ask_with_default(prompt, choices):
    """Read one stripped answer from stdin; if blank, pick and echo a random choice."""
    import random

    answer = input(prompt).strip()
    if answer == "":
        answer = random.choice(choices)
        print(answer)
    return answer


def _ask_info(who, ment):
    """Ask for one speaker's profile and return ``'<who>:<region> <age> <sex><sep>'``."""
    print(ment)
    addr = _ask_with_default("어디 사세요? (e.g. 서울특별시, 제주도, etc...) :", _REGIONS)
    age = _ask_with_default("나이가? (e.g. 20대, 70대 이상, etc...) :", _AGES)
    sex = _ask_with_default("성별이? (e.g. 남성, 여성, etc... (?)) :", _SEXES)
    return f"{who}:{addr} {age} {sex}<sep>"


def main(model_name, max_length=512):
    """Run an interactive terminal chat with a KoGPT-based dialogue model.

    The tokenizer is the base ``kakaobrain/kogpt`` one extended with
    ``_SPECIAL_TOKENS``; weights come from *model_name* (presumably fine-tuned
    with those same tokens). Before each conversation the user gives a topic
    and two speaker profiles (P01 = the user, P02 = the chatbot); blank
    answers are filled at random. The text fed to the model is::

        일상 대화 <topic><sep>P01:<profile><sep>P02:<profile><sep>
        P01<sos><utterance><eos>P02<sos><reply><eos>...

    Typing ``end`` finishes a conversation; answering ``no`` afterwards exits.

    Args:
        model_name: Hugging Face hub id or local path of the causal-LM weights.
        max_length: Token budget for prompt plus generated reply (passed to
            ``generate``). When the prompt reaches it, the oldest utterances
            are dropped one at a time.
    """
    import time
    import warnings

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    warnings.filterwarnings("ignore")

    tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
    tokenizer.add_special_tokens({'additional_special_tokens': _SPECIAL_TOKENS})
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # The embedding matrix must cover the ids of the newly added tokens.
    model.resize_token_embeddings(len(tokenizer))
    # The original hard-coded .cuda(); fall back to CPU when no GPU exists.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Direct lookup of the added token, rather than encode(), whose output
    # depends on the tokenizer's add-special-tokens settings.
    eos_id = tokenizer.convert_tokens_to_ids('<eos>')
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id

    info = ""
    while True:
        if info == "":
            print(
                "지금부터 대화 정보를 입력 받겠습니다.\n"
                "각 질문에 대답 후 Enter 해주세요.\n"
                "아무 입력 없이 Enter 할 경우, 미리 지정된 값 중 랜덤으로 정하게 됩니다.\n"
            )
            time.sleep(1)
            yon = "no"
        else:
            yon = input("이전 대화 정보를 그대로 유지할까요? (yes : 유지, no : 새로 생성) :").strip()
        if yon == "no":
            topic = _ask_with_default(
                "대화 주제를 정해주세요 (e.g. 여가 생활, 일과 직업, 개인 및 관계, etc...) :",
                _TOPICS,
            )
            info = "일상 대화 " + topic + "<sep>"
            info += _ask_info(who="P01", ment="\n당신에 대해 알려주세요.\n")
            info += _ask_info(who="P02", ment="\n챗봇에 대해 알려주세요.\n")

        pp = info.replace('<sep>', '\n')
        print(
            "\n----------------\n"
            "<입력 정보 확인> (P01 : 당신, P02 : 챗봇)\n"
            f"{pp}"
            "----------------\n"
            "대화를 종료하고 싶으면 언제든지 'end' 라고 말해주세요~\n"
        )

        talk = []
        warned_near_limit = False
        warned_trim = False
        while True:
            myinp = input("당신 : ")
            if myinp.strip() == "end":
                print("대화 종료!")
                break
            talk.append("P01<sos>" + myinp + "<eos>")
            talk.append("P02<sos>")

            # Drop the oldest utterances until the prompt fits. The original
            # trimmed at most once per conversation, after which over-long
            # prompts went to generate() unchanged.
            while True:
                prompt = info + "".join(talk)
                encoded = tokenizer(prompt, return_tensors='pt')
                seq_len = encoded.input_ids.size(1)
                if seq_len > max_length * 0.8 and not warned_near_limit:
                    print(f"<주의> 현재 대화 길이가 곧 최대 길이에 도달합니다. ({seq_len} / {max_length})")
                    warned_near_limit = True
                # Always keep the latest user turn and the open "P02<sos>" slot.
                if seq_len < max_length or len(talk) <= 2:
                    break
                if not warned_trim:
                    print("<주의> 대화 길이가 너무 길어졌기 때문에, 이후 대화는 맨 앞의 발화를 조금씩 지우면서 진행됩니다.")
                    warned_trim = True
                del talk[0]

            if seq_len >= max_length:
                # Even the latest turn alone leaves no room for a reply.
                print("<주의> 입력이 너무 길어서 답변할 수 없습니다. 더 짧게 말해주세요.")
                del talk[-2:]
                continue

            out = model.generate(
                inputs=encoded.input_ids.to(device),
                attention_mask=encoded.attention_mask.to(device),
                max_length=max_length,
                do_sample=True,
                pad_token_id=pad_id,
                eos_token_id=eos_id,
            )
            # Decode only the new ids: slicing the decoded text by len(prompt)
            # assumed an exact decode round-trip, and the fixed [:-5] chopped
            # real characters whenever '<eos>' was not generated.
            generated = tokenizer.decode(out[0, seq_len:])
            reply = generated.partition('<eos>')[0]
            print("챗봇 : " + reply)
            # Close the turn even when generation stopped at max_length.
            talk[-1] += reply + '<eos>'

        again = input("다른 대화를 시작할까요? (yes : 새로운 시작, no : 종료) :").strip()
        if again == "no":
            break