botsi commited on
Commit
9c3db7a
·
verified ·
1 Parent(s): 6e911e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -125
app.py CHANGED
@@ -67,6 +67,107 @@ if torch.cuda.is_available():
67
  tokenizer = AutoTokenizer.from_pretrained(model_id)
68
  tokenizer.use_default_system_prompt = False
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  with gr.Blocks(css="style.css") as demo:
72
  ## gradio-chatbot-read-query-param
@@ -78,6 +179,11 @@ with gr.Blocks(css="style.css") as demo:
78
  return url_params;
79
  }
80
  """
 
 
 
 
 
81
 
82
  def fetch_personalized_data(session_index):
83
  # Connect to the database
@@ -125,116 +231,6 @@ with gr.Blocks(css="style.css") as demo:
125
  print(f"Error: {err}")
126
  return None
127
 
128
- ## trust-game-llama-2-7b-chat
129
- # app.py
130
- def construct_input_prompt(chat_history, message, personalized_data):
131
- input_prompt = f"<s>[INST] <<SYS>>\n{get_default_system_prompt(personalized_data)}\n<</SYS>>\n\n "
132
-
133
- for user, assistant in chat_history:
134
- input_prompt += f"{user} [/INST] {assistant} <s>[INST] "
135
-
136
- input_prompt += f"{message} [/INST] "
137
-
138
- return input_prompt
139
-
140
-
141
- ## trust-game-llama-2-7b-chat
142
- # app.py
143
- @spaces.GPU
144
- def generate(
145
- message: str,
146
- chat_history: list[tuple[str, str]],
147
- # system_prompt: str,
148
- max_new_tokens: int = 1024,
149
- temperature: float = 0.6,
150
- top_p: float = 0.9,
151
- top_k: int = 50,
152
- repetition_penalty: float = 1.2,
153
- ) -> Iterator[str]: # Change return type hint to Iterator[str]
154
-
155
- # Fetch personalized data
156
- url_params = get_window_url_params()
157
- session_index = get_session_index(chat_history, url_params)
158
- personalized_data = fetch_personalized_data(session_index)
159
-
160
- # Construct the input prompt using the functions from the system_prompt_config module
161
- input_prompt = construct_input_prompt(chat_history, message, personalized_data)
162
-
163
- # Use the global variable to store the chat history
164
- # global global_chat_history
165
-
166
- conversation = []
167
-
168
- # Move the condition here after the assignment
169
- if input_prompt:
170
- conversation.append({"role": "system", "content": input_prompt})
171
-
172
- # Convert input prompt to tensor
173
- input_ids = tokenizer(input_prompt, return_tensors="pt").to(model.device)
174
-
175
-
176
- for user, assistant in chat_history:
177
- conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
178
- conversation.append({"role": "user", "content": message})
179
-
180
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
181
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
182
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
183
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
184
- input_ids = input_ids.to(model.device)
185
-
186
- # Set up the TextIteratorStreamer
187
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
188
-
189
- # Set up the generation arguments
190
- generate_kwargs = dict(
191
- {"input_ids": input_ids},
192
- streamer=streamer,
193
- max_new_tokens=max_new_tokens,
194
- do_sample=True,
195
- top_p=top_p,
196
- top_k=top_k,
197
- temperature=temperature,
198
- num_beams=1,
199
- repetition_penalty=repetition_penalty,
200
- )
201
-
202
- # Start the model generation thread
203
- t = Thread(target=model.generate, kwargs=generate_kwargs)
204
- t.start()
205
-
206
- # Yield generated text chunks
207
- outputs = []
208
- for text in streamer:
209
- outputs.append(text)
210
- yield "".join(outputs)
211
-
212
- # Update the global_chat_history with the current conversation
213
- # global_chat_history.append({
214
- # "message": message,
215
- # "chat_history": chat_history,
216
- # "system_prompt": input_prompt,
217
- # "output": outputs[-1], # Assuming you want to save the latest model output
218
- # })
219
-
220
- # The modification above starting with "global_chat.history.append" introduces a global_chat_history variable to store the chat history globally.
221
- # The save_chat_history function is registered to be called when the program exits
222
- # using atexit.register(save_chat_history).
223
- # It saves the chat history to a JSON file named "chat_history.json".
224
- # The generate function is updated to append the current conversation to global_chat_history
225
- # after generating each response.
226
-
227
-
228
- #gr.Markdown(DESCRIPTION)
229
- #gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
230
- ## gradio-chatbot-read-query-param
231
-
232
- def get_session_index(chat_history, url_params):
233
- if chat_history and bool(chat_history[-1][0].strip()):
234
- session_index = url_params.get('session_index')
235
- print(session_index)
236
- return session_index
237
-
238
  ## trust-game-llama-2-7b-chat
239
  # app.py
240
  def get_default_system_prompt(personalized_data):
@@ -263,21 +259,20 @@ with gr.Blocks(css="style.css") as demo:
263
  print(DEFAULT_SYSTEM_PROMPT)
264
  return DEFAULT_SYSTEM_PROMPT
265
 
266
- chat_interface = gr.ChatInterface(
267
- fn=generate,
268
- theme="soft",
269
- retry_btn=None,
270
- clear_btn=None,
271
- undo_btn=None,
272
- chatbot=gr.Chatbot(avatar_images=('user.png', 'bot.png'), bubble_full_width = False),
273
- examples=[
274
- ["Can you explain the rules very briefly again?"],
275
- ["How much should I invest in order to win?"],
276
- ["What happened in the last round?"],
277
- ["What is my probability to win if I do not share anything?"],
278
- ],
279
- )
280
 
 
 
 
 
 
281
  chat_interface.render()
282
  #gr.Markdown(LICENSE)
283
 
@@ -290,4 +285,3 @@ if __name__ == "__main__":
290
  # Register the function to be called when the program exits
291
  # atexit.register(save_chat_history)
292
 
293
-
 
67
  tokenizer = AutoTokenizer.from_pretrained(model_id)
68
  tokenizer.use_default_system_prompt = False
69
 
70
+ ## trust-game-llama-2-7b-chat
71
+ # app.py
72
+ @spaces.GPU
73
+ def generate(
74
+ message: str,
75
+ chat_history: list[tuple[str, str]],
76
+ # system_prompt: str,
77
+ max_new_tokens: int = 1024,
78
+ temperature: float = 0.6,
79
+ top_p: float = 0.9,
80
+ top_k: int = 50,
81
+ repetition_penalty: float = 1.2,
82
+ ) -> Iterator[str]: # Change return type hint to Iterator[str]
83
+
84
+
85
+ # Construct the input prompt using the functions from the system_prompt_config module
86
+ input_prompt = construct_input_prompt(chat_history, message, personalized_data)
87
+
88
+ # Use the global variable to store the chat history
89
+ # global global_chat_history
90
+
91
+ conversation = []
92
+
93
+ # Move the condition here after the assignment
94
+ if input_prompt:
95
+ conversation.append({"role": "system", "content": input_prompt})
96
+
97
+ # Convert input prompt to tensor
98
+ input_ids = tokenizer(input_prompt, return_tensors="pt").to(model.device)
99
+
100
+
101
+ for user, assistant in chat_history:
102
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
103
+ conversation.append({"role": "user", "content": message})
104
+
105
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
106
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
107
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
108
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
109
+ input_ids = input_ids.to(model.device)
110
+
111
+ # Set up the TextIteratorStreamer
112
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
113
+
114
+ # Set up the generation arguments
115
+ generate_kwargs = dict(
116
+ {"input_ids": input_ids},
117
+ streamer=streamer,
118
+ max_new_tokens=max_new_tokens,
119
+ do_sample=True,
120
+ top_p=top_p,
121
+ top_k=top_k,
122
+ temperature=temperature,
123
+ num_beams=1,
124
+ repetition_penalty=repetition_penalty,
125
+ )
126
+
127
+ # Start the model generation thread
128
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
129
+ t.start()
130
+
131
+ # Yield generated text chunks
132
+ outputs = []
133
+ for text in streamer:
134
+ outputs.append(text)
135
+ yield "".join(outputs)
136
+
137
+ # Update the global_chat_history with the current conversation
138
+ # global_chat_history.append({
139
+ # "message": message,
140
+ # "chat_history": chat_history,
141
+ # "system_prompt": input_prompt,
142
+ # "output": outputs[-1], # Assuming you want to save the latest model output
143
+ # })
144
+
145
+ # The modification above starting with "global_chat.history.append" introduces a global_chat_history variable to store the chat history globally.
146
+ # The save_chat_history function is registered to be called when the program exits
147
+ # using atexit.register(save_chat_history).
148
+ # It saves the chat history to a JSON file named "chat_history.json".
149
+ # The generate function is updated to append the current conversation to global_chat_history
150
+ # after generating each response.
151
+
152
+
153
+ #gr.Markdown(DESCRIPTION)
154
+ #gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
155
+ ## gradio-chatbot-read-query-param
156
+
157
+ chat_interface = gr.ChatInterface(
158
+ fn=generate,
159
+ theme="soft",
160
+ retry_btn=None,
161
+ clear_btn=None,
162
+ undo_btn=None,
163
+ chatbot=gr.Chatbot(avatar_images=('user.png', 'bot.png'), bubble_full_width = False),
164
+ examples=[
165
+ ["Can you explain the rules very briefly again?"],
166
+ ["How much should I invest in order to win?"],
167
+ ["What happened in the last round?"],
168
+ ["What is my probability to win if I do not share anything?"],
169
+ ],
170
+ )
171
 
172
  with gr.Blocks(css="style.css") as demo:
173
  ## gradio-chatbot-read-query-param
 
179
  return url_params;
180
  }
181
  """
182
+ def get_session_index(chat_history, url_params):
183
+ if chat_history and bool(chat_history[-1][0].strip()):
184
+ session_index = url_params.get('session_index')
185
+ print(session_index)
186
+ return session_index
187
 
188
  def fetch_personalized_data(session_index):
189
  # Connect to the database
 
231
  print(f"Error: {err}")
232
  return None
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  ## trust-game-llama-2-7b-chat
235
  # app.py
236
  def get_default_system_prompt(personalized_data):
 
259
  print(DEFAULT_SYSTEM_PROMPT)
260
  return DEFAULT_SYSTEM_PROMPT
261
 
262
+ ## trust-game-llama-2-7b-chat
263
+ # app.py
264
+ def construct_input_prompt(chat_history, message, personalized_data):
265
+ input_prompt = f"<s>[INST] <<SYS>>\n{get_default_system_prompt(personalized_data)}\n<</SYS>>\n\n "
266
+ for user, assistant in chat_history:
267
+ input_prompt += f"{user} [/INST] {assistant} <s>[INST] "
268
+ input_prompt += f"{message} [/INST] "
269
+ return input_prompt
 
 
 
 
 
 
270
 
271
+ # Fetch personalized data
272
+ url_params = get_window_url_params()
273
+ session_index = get_session_index(chat_history, url_params)
274
+ personalized_data = fetch_personalized_data(session_index)
275
+
276
  chat_interface.render()
277
  #gr.Markdown(LICENSE)
278
 
 
285
  # Register the function to be called when the program exits
286
  # atexit.register(save_chat_history)
287