JYYong committed on
Commit
4cb51c5
·
1 Parent(s): 73dae5e

maybe complete

Browse files
Files changed (2) hide show
  1. app.py +106 -136
  2. flagged/log.csv +4 -0
app.py CHANGED
@@ -1,144 +1,114 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def update(name):
4
- return f"Welcome to Gradio, {name}!"
5
-
6
- demo = gr.Blocks()
7
-
8
- with demo:
9
- gr.Markdown(f"๊ฐ ์งˆ๋ฌธ์— ๋Œ€๋‹ต ํ›„ Enter ํ•ด์ฃผ์„ธ์š”.\n\n")
10
- with gr.Row():
11
- topic = gr.Textbox(label="Topic", placeholder="๋Œ€ํ™” ์ฃผ์ œ๋ฅผ ์ •ํ•ด์ฃผ์„ธ์š” (e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...)")
12
- with gr.Row():
13
- with gr.Column():
14
- addr = gr.Textbox(label="์ง€์—ญ", placeholder="e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...")
15
- age = gr.Textbox(label="๋‚˜์ด", placeholder="e.g. 20๋Œ€ ๋ฏธ๋งŒ, 40๋Œ€, 70๋Œ€ ์ด์ƒ, etc...")
16
- sex = gr.Textbox(label="์„ฑ๋ณ„", placeholder="e.g. ๋‚จ์„ฑ, ์—ฌ์„ฑ, etc...")
17
- with gr.Column():
18
- addr = gr.Textbox(label="์ง€์—ญ", placeholder="e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...")
19
- age = gr.Textbox(label="๋‚˜์ด", placeholder="e.g. 20๋Œ€ ๋ฏธ๋งŒ, 40๋Œ€, 70๋Œ€ ์ด์ƒ, etc...")
20
- sex = gr.Textbox(label="์„ฑ๋ณ„", placeholder="e.g. ๋‚จ์„ฑ, ์—ฌ์„ฑ, etc...")
21
- out = gr.Textbox()
22
- btn = gr.Button("Run")
23
- # btn.click(fn=update, inputs=inp, outputs=out)
24
-
25
- demo.launch()
 
 
26
 
27
 
28
- def main(model_name):
29
  warnings.filterwarnings("ignore")
30
 
31
- tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
32
- special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@์ด๋ฆ„#', '#@๊ณ„์ •#', '#@์‹ ์›#', '#@์ „๋ฒˆ#', '#@๊ธˆ์œต#', '#@๋ฒˆํ˜ธ#', '#@์ฃผ์†Œ#', '#@์†Œ์†#', '#@๊ธฐํƒ€#']}
33
- num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
34
-
35
- model = AutoModelForCausalLM.from_pretrained(model_name)
36
- model.resize_token_embeddings(len(tokenizer))
37
- model = model.cuda()
38
-
39
- info = ""
40
-
41
- while True:
42
- if info == "":
43
- print(
44
- f"์ง€๊ธˆ๋ถ€ํ„ฐ ๋Œ€ํ™” ์ •๋ณด๋ฅผ ์ž…๋ ฅ ๋ฐ›๊ฒ ์Šต๋‹ˆ๋‹ค.\n"
45
- f"๊ฐ ์งˆ๋ฌธ์— ๋Œ€๋‹ต ํ›„ Enter ํ•ด์ฃผ์„ธ์š”.\n"
46
- f"์•„๋ฌด ์ž…๋ ฅ ์—†์ด Enter ํ•  ๊ฒฝ์šฐ, ๋ฏธ๋ฆฌ ์ง€์ •๋œ ๊ฐ’ ์ค‘ ๋žœ๋ค์œผ๋กœ ์ •ํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.\n"
47
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- time.sleep(1)
50
 
51
- yon = "no"
52
- else:
53
- yon = input(
54
- f"์ด์ „ ๋Œ€ํ™” ์ •๋ณด๋ฅผ ๊ทธ๋Œ€๋กœ ์œ ์ง€ํ• ๊นŒ์š”? (yes : ์œ ์ง€, no : ์ƒˆ๋กœ ์ž‘์„ฑ) :"
55
- )
56
-
57
- if yon == "no":
58
- info = "์ผ์ƒ ๋Œ€ํ™” "
59
-
60
- topic = input("๋Œ€ํ™” ์ฃผ์ œ๋ฅผ ์ •ํ•ด์ฃผ์„ธ์š” (e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...) :")
61
- if topic == "":
62
- topic = random.choice(['์—ฌ๊ฐ€ ์ƒํ™œ', '์‹œ์‚ฌ/๊ต์œก', '๋ฏธ์šฉ๊ณผ ๊ฑด๊ฐ•', '์‹์Œ๋ฃŒ', '์ƒ๊ฑฐ๋ž˜(์‡ผํ•‘)', '์ผ๊ณผ ์ง์—…', '์ฃผ๊ฑฐ์™€ ์ƒํ™œ', '๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„', 'ํ–‰์‚ฌ'])
63
- print(topic)
64
- info += topic + "<sep>"
65
-
66
- def ask_info(who, ment):
67
- print(ment)
68
- text = who + ":"
69
- addr = input("์–ด๋”” ์‚ฌ์„ธ์š”? (e.g. ์„œ์šธํŠน๋ณ„์‹œ, ์ œ์ฃผ๋„, etc...) :").strip()
70
- if addr == "":
71
- addr = random.choice(['์„œ์šธํŠน๋ณ„์‹œ', '๊ฒฝ๊ธฐ๋„', '๋ถ€์‚ฐ๊ด‘์—ญ์‹œ', '๋Œ€์ „๊ด‘์—ญ์‹œ', '๊ด‘์ฃผ๊ด‘์—ญ์‹œ', '์šธ์‚ฐ๊ด‘์—ญ์‹œ', '๊ฒฝ์ƒ๋‚จ๋„', '์ธ์ฒœ๊ด‘์—ญ์‹œ', '์ถฉ์ฒญ๋ถ๋„', '์ œ์ฃผ๋„', '๊ฐ•์›๋„', '์ถฉ์ฒญ๋‚จ๋„', '์ „๋ผ๋ถ๋„', '๋Œ€๊ตฌ๊ด‘์—ญ์‹œ', '์ „๋ผ๋‚จ๋„', '๊ฒฝ์ƒ๋ถ๋„', '์„ธ์ข…ํŠน๋ณ„์ž์น˜์‹œ', '๊ธฐํƒ€'])
72
- print(addr)
73
- text += addr + " "
74
-
75
- age = input("๋‚˜์ด๊ฐ€? (e.g. 20๋Œ€, 70๋Œ€ ์ด์ƒ, etc...) :").strip()
76
- if age == "":
77
- age = random.choice(['20๋Œ€', '30๋Œ€', '50๋Œ€', '20๋Œ€ ๋ฏธ๋งŒ', '60๋Œ€', '40๋Œ€', '70๋Œ€ ์ด์ƒ'])
78
- print(age)
79
- text += age + " "
80
-
81
- sex = input("์„ฑ๋ณ„์ด? (e.g. ๋‚จ์„ฑ, ์—ฌ์„ฑ, etc... (?)) :").strip()
82
- if sex == "":
83
- sex = random.choice(['๋‚จ์„ฑ', '์—ฌ์„ฑ'])
84
- print(sex)
85
- text += sex + "<sep>"
86
- return text
87
-
88
- info += ask_info(who="P01", ment=f"\n๋‹น์‹ ์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”.\n")
89
- info += ask_info(who="P02", ment=f"\n์ฑ—๋ด‡์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”.\n")
90
-
91
- pp = info.replace('<sep>', '\n')
92
- print(
93
- f"\n----------------\n"
94
- f"<์ž…๋ ฅ ์ •๋ณด ํ™•์ธ> (P01 : ๋‹น์‹ , P02 : ์ฑ—๋ด‡)\n"
95
- f"{pp}"
96
- f"----------------\n"
97
- f"๋Œ€ํ™”๋ฅผ ์ข…๋ฃŒํ•˜๊ณ  ์‹ถ์œผ๋ฉด ์–ธ์ œ๋“ ์ง€ 'end' ๋ผ๊ณ  ๋งํ•ด์ฃผ์„ธ์š”~\n"
98
- )
99
- talk = []
100
- switch = True
101
- switch2 = True
102
- while True:
103
- inp = "P01<sos>"
104
- myinp = input("๋‹น์‹  : ")
105
- if myinp == "end":
106
- print("๋Œ€ํ™” ์ข…๋ฃŒ!")
107
- break
108
- inp += myinp + "<eos>"
109
- talk.append(inp)
110
- talk.append("P02<sos>")
111
-
112
- while True:
113
- now_inp = info + "".join(talk)
114
- inpu = tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
115
- seq_len = inpu.input_ids.size(1)
116
- if seq_len > 512 * 0.8 and switch:
117
- print(
118
- f"<์ฃผ์˜> ํ˜„์žฌ ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๊ณง ์ตœ๋Œ€ ๊ธธ์ด์— ๋„๋‹ฌํ•ฉ๋‹ˆ๋‹ค. ({seq_len} / 512)"
119
- )
120
- switch = False
121
-
122
- if seq_len >= 512 and switch2:
123
- print("<์ฃผ์˜> ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๋„ˆ๋ฌด ๊ธธ์–ด์กŒ๊ธฐ ๋•Œ๋ฌธ์—, ์ดํ›„ ๋Œ€ํ™”๋Š” ๋งจ ์•ž์˜ ๋ฐœํ™”๋ฅผ ์กฐ๊ธˆ์”ฉ ์ง€์šฐ๋ฉด์„œ ์ง„ํ–‰๋ฉ๋‹ˆ๋‹ค.")
124
- talk = talk[1:]
125
- switch2 = False
126
- else:
127
- break
128
-
129
- out = model.generate(
130
- inputs=inpu.input_ids.cuda(),
131
- attention_mask=inpu.attention_mask.cuda(),
132
- max_length=512,
133
- do_sample=True,
134
- pad_token_id=tokenizer.pad_token_id,
135
- eos_token_id=tokenizer.encode('<eos>')[0]
136
- )
137
- output = tokenizer.batch_decode(out)
138
- print("์ฑ—๋ด‡ : " + output[0][len(now_inp):-5])
139
- talk[-1] += output[0][len(now_inp):]
140
-
141
- again = input(f"๋‹ค๋ฅธ ๋Œ€ํ™”๋ฅผ ์‹œ์ž‘ํ• ๊นŒ์š”? (yes : ์ƒˆ๋กœ์šด ์‹œ์ž‘, no : ์ข…๋ฃŒ) :")
142
- if again == "no":
143
- break
144
-
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import warnings
4
+
5
+
6
class Chatbot:
    """Persona-conditioned Korean chit-chat bot built on a causal language model.

    The prompt fed to the model is ``self.info`` (topic plus both personas,
    ``<sep>``-delimited) followed by the conversation history in ``self.talk``.
    History entries alternate between user turns (``P01<sos>...<eos>``) and
    bot turns (``P02<sos>...<eos>``), so the list always has even length.
    """

    # Extra tokens registered with the tokenizer: segment/turn delimiters plus
    # the "#@...#" tokens (presumably anonymisation placeholders used in the
    # training data -- TODO confirm).
    SPECIAL_TOKENS = ['<sep>', '<eos>', '<sos>', '#@이름#', '#@계정#', '#@신원#', '#@전번#',
                      '#@금융#', '#@번호#', '#@주소#', '#@소속#', '#@기타#']
    # Token budget shared by the prompt and the generated reply.
    MAX_LENGTH = 512
    USER_PREFIX = "P01<sos>"
    BOT_PREFIX = "P02<sos>"
    EOS = "<eos>"

    def __init__(self, model_path="/workspace/test_trainer/checkpoint-10000"):
        """Load tokenizer and model and start with an empty conversation.

        Args:
            model_path: Checkpoint handed to
                ``AutoModelForCausalLM.from_pretrained``. Defaults to the path
                that used to be hard-coded here.
        """
        self.tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
        self.tokenizer.add_special_tokens({'additional_special_tokens': list(self.SPECIAL_TOKENS)})

        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        # The embedding matrix has to grow to cover the tokens added above.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model = self.model.cuda()

        self.info = None    # prompt header; set by initialize()
        self.talk = []      # alternating user/bot utterances
        self.state = None   # warning produced by the latest test() call, if any

    def initialize(self, topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex):
        """Build the prompt header from the topic and both personas.

        Ages are bucketed into the decade labels used in the prompt
        ("20대 미만", "20대" ... "60대", "70대 이상").

        Returns:
            The human-readable summary produced by ``info_check()``.
        """
        def encode(age):
            # Sliders may deliver floats; without int() 45.0 would render as "40.0대".
            age = int(age)
            if age < 20:
                return "20대 미만"
            if age >= 70:
                return "70대 이상"
            return f"{age // 10 * 10}대"

        bot_age = encode(bot_age)
        my_age = encode(my_age)
        self.info = f"일상 대화 {topic}<sep>P01:{my_addr} {my_age} {my_sex}<sep>P02:{bot_addr} {bot_age} {bot_sex}<sep>"
        return self.info_check()

    def info_check(self):
        """Return the prompt header with separators and speaker ids made readable."""
        return self.info.replace('<sep>', '\n').replace('P01', '당신').replace('P02', '챗봇')

    def reset_talk(self):
        """Forget the conversation history (the persona header is kept)."""
        self.talk = []

    def test(self, myinp):
        """Send one user utterance to the model and return the conversation so far.

        Args:
            myinp: The user's message text.

        Returns:
            A list of ``(user_text, bot_text)`` tuples, oldest first, covering
            the history that is still kept in ``self.talk``.

        Raises:
            RuntimeError: If ``initialize()`` has not been called yet.
        """
        if self.info is None:
            # Previously this surfaced as "NoneType + str" after the history
            # had already been modified.
            raise RuntimeError("initialize() must be called before test()")

        # Work on a copy so self.talk is only replaced once a reply exists.
        history = self.talk + [self.USER_PREFIX + myinp + self.EOS, self.BOT_PREFIX]
        trimmed = False

        while True:
            now_inp = self.info + "".join(history)
            inputs = self.tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
            seq_len = inputs.input_ids.size(1)
            # Stop once the prompt fits, or when only the current exchange is
            # left (nothing more can be dropped, so looping would never end).
            if seq_len < self.MAX_LENGTH or len(history) <= 2:
                break
            # Drop the oldest user/bot pair together so that even indices
            # remain user turns and odd indices remain bot turns.
            del history[:2]
            trimmed = True

        if trimmed:
            self.state = "<주의> 대화 길이가 너무 길어졌기 때문에, 이후 대화는 맨 앞의 발화를 조금씩 지우면서 진행됩니다."
        elif seq_len > self.MAX_LENGTH * 0.8:
            self.state = f"<주의> 현재 대화 길이가 곧 최대 길이에 도달합니다. ({seq_len} / 512)"
        else:
            self.state = None

        out = self.model.generate(
            inputs=inputs.input_ids.cuda(),
            attention_mask=inputs.attention_mask.cuda(),
            max_length=self.MAX_LENGTH,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.encode('<eos>')[0]
        )
        decoded = self.tokenizer.batch_decode(out)[0]
        # The decoded text starts with the prompt; what follows is the reply.
        reply = decoded[len(now_inp):]
        if not reply.endswith(self.EOS):
            # Generation can stop at max_length without emitting <eos>; add it
            # so every stored turn has the same framing and the slice below
            # never cuts into real text.
            reply += self.EOS
        history[-1] += reply
        self.talk = history

        head = len(self.USER_PREFIX)   # both speaker prefixes have this length
        tail = len(self.EOS)
        return [
            (user_turn[head:-tail], bot_turn[head:-tail])
            for user_turn, bot_turn in zip(history[::2], history[1::2])
        ]
70
 
71
 
72
if __name__ == "__main__":
    # Script entry point: build the Gradio UI around a single Chatbot instance
    # and serve it. NOTE(review): the nesting below is reconstructed from the
    # statement order; confirm the layout against the running app.
    warnings.filterwarnings("ignore")

    chatbot = Chatbot()
    demo = gr.Blocks()

    # Option lists shared by the persona widgets (previously repeated inline).
    topics = ['여가 생활', '시사/교육', '미용과 건강', '식음료', '상거래(쇼핑)', '일과 직업', '주거와 생활', '개인 및 관계', '행사']
    regions = ['서울특별시', '경기도', '부산광역시', '대전광역시', '광주광역시', '울산광역시', '경상남도', '인천광역시', '충청북도',
               '제주도', '강원도', '충청남도', '전라북도', '대구광역시', '전라남도', '경상북도', '세종특별자치시', '기타']
    sexes = ["남성", "여성"]

    with demo:
        gr.Markdown("# <center>MINDs Lab Brain's Fast Neural Chit-Chatbot</center>")
        with gr.Row():
            with gr.Column():
                topic = gr.Radio(label="Topic", choices=list(topics))
            with gr.Column():
                gr.Markdown("Bot's persona")
                bot_addr = gr.Dropdown(label="지역", choices=list(regions))
                bot_age = gr.Slider(label="나이", minimum=10, maximum=80, value=45, step=1)
                bot_sex = gr.Radio(label="성별", choices=list(sexes))
            with gr.Column():
                gr.Markdown("Your persona")
                my_addr = gr.Dropdown(label="지역", choices=list(regions))
                my_age = gr.Slider(label="나이", minimum=10, maximum=80, value=45, step=1)
                my_sex = gr.Radio(label="성별", choices=list(sexes))
        with gr.Row():
            # A Button's caption is its first positional argument (`value`);
            # the `label=` keyword used before does not set it.
            apply_btn = gr.Button("적용")
            state = gr.Textbox(label="상태")
            # Applying the settings shows the resulting persona summary.
            apply_btn.click(
                fn=chatbot.initialize,
                inputs=[topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex],
                outputs=state
            )

        with gr.Column():
            screen = gr.Chatbot(label="익명의 상대")
            with gr.Row():
                speak = gr.Textbox(label="입력창")
                talk_btn = gr.Button("Talk")
                # Each click sends the typed message and redraws the chat log.
                talk_btn.click(
                    fn=chatbot.test,
                    inputs=speak,
                    outputs=screen
                )
    demo.launch(share=True)
113
 
 
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flagged/log.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ 'self','output','flag','username','timestamp'
2
+ '뭐하고 계세요?','[(''안녕하세요'', ''넵''), (''뭐하고 계세요?'', ''저 게임하면서 있어용'')]','','','2022-06-29 07:59:03.609856'
3
+ '뭐하고 계세요?','[(''안녕하세요'', ''넵''), (''뭐하고 계세요?'', ''저 게임하면서 있어용'')]','','','2022-06-29 07:59:07.265460'
4
+ '안녕하세요?','[[''안녕하세요?'', ''아니'']]','','','2022-06-29 08:15:33.284872'