Files changed (1) hide show
  1. app.py +62 -1
app.py CHANGED
@@ -18,6 +18,34 @@ SAMPLE_IDX = []
18
  RANDOM_POSITION = [(145 + 200 * i + 400 * (i//2), j * 110 + 900) for i in range(4) for j in range(4) ]
19
  CURRENT_POSITION = []
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  print (NUM_THREADS)
22
 
23
  def exception_handler(exception_type, exception, traceback):
@@ -128,7 +156,40 @@ def parse_codeblock(text):
128
  # print (f'error found: {e}')
129
  # yield [(parse_codeblock(history[i]), parse_codeblock(history[i + 1])) for i in range(0, len(history) - 1, 2) ], history, chat_counter, response, gr.update(interactive=True), gr.update(interactive=True)
130
  # print(json.dumps({"chat_counter": chat_counter, "payload": payload, "partial_words": partial_words, "token_counter": token_counter, "counter": counter}))
131
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  def reset_textbox():
134
  return gr.update(value='', interactive=False), gr.update(interactive=False)
 
18
  RANDOM_POSITION = [(145 + 200 * i + 400 * (i//2), j * 110 + 900) for i in range(4) for j in range(4) ]
19
  CURRENT_POSITION = []
20
 
21
+
22
+ import os
23
+ import numpy as np
24
+ import torch
25
+
26
+
27
+
28
+ ###########
29
+
30
+ from thinkgpt.llm import ThinkGPT
31
+
32
+ os.environ['OPENAI_API_KEY']='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'
33
+ llm = ThinkGPT(model_name="gpt-3.5-turbo")
34
+
35
+ import gradio as gr
36
+ from transformers import AutoModelForCausalLM, AutoTokenizer
37
+
38
+
39
+ title = """<h1 align="center"> 🏛️ Courtroom with AI juries</h1>"""
40
+ description = "Building open-domain chatbots is a challenging area for machine learning research."
41
+ examples = [["How are you?"]]
42
+ os.environ['OPENAI_API_KEY']='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'
43
+
44
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
45
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
46
+
47
+ ###########
48
+
49
  print (NUM_THREADS)
50
 
51
  def exception_handler(exception_type, exception, traceback):
 
156
  # print (f'error found: {e}')
157
  # yield [(parse_codeblock(history[i]), parse_codeblock(history[i + 1])) for i in range(0, len(history) - 1, 2) ], history, chat_counter, response, gr.update(interactive=True), gr.update(interactive=True)
158
  # print(json.dumps({"chat_counter": chat_counter, "payload": payload, "partial_words": partial_words, "token_counter": token_counter, "counter": counter}))
159
+
160
+
161
+ def predict(input, history=[]):
162
+ # tokenize the new input sentence
163
+ new_user_input_ids = tokenizer.encode(
164
+ input + tokenizer.eos_token, return_tensors="pt"
165
+ )
166
+
167
+ llm.memorize(memorize_list)
168
+ prediction = llm.predict(question, remember=llm.remember('철수', limit = 30))
169
+ print(prediction)
170
+ print('-'*100)
171
+
172
+
173
+ # append the new user input tokens to the chat history
174
+ bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
175
+
176
+ # generate a response
177
+ history = model.generate(
178
+ bot_input_ids, max_length=4000, pad_token_id=tokenizer.eos_token_id
179
+ ).tolist()
180
+
181
+ # convert the tokens to text, and then split the responses into lines
182
+ response = tokenizer.decode(history[0]).split("<|endoftext|>")
183
+ # print('decoded_response-->>'+str(response))
184
+ response = [
185
+ (response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)
186
+ ] # convert to tuples of list
187
+ # print('response-->>'+str(response))
188
+
189
+ return response, prediction
190
+
191
+
192
+
193
 
194
  def reset_textbox():
195
  return gr.update(value='', interactive=False), gr.update(interactive=False)