botsi commited on
Commit
78ffba0
·
verified ·
1 Parent(s): 913c48f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -12
app.py CHANGED
@@ -1,4 +1,248 @@
1
- # Original code from https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  # Modified for trust game purposes
3
 
4
  import gradio as gr
@@ -224,16 +468,15 @@ def generate(
224
  repo.push_to_hub(blocking=False, commit_message=f"Updating data at {timestamp}")
225
 
226
  chat_interface = gr.ChatInterface(
227
- fn=generate,
228
- retry_btn=None,
229
- clear_btn=None,
230
- undo_btn=None,
231
- chatbot=gr.Chatbot(avatar_images=('user.png', 'bot.png'), bubble_full_width = False),
232
- examples=[
233
- ["How many Marions are there?"],
234
- ["What is your favorite fruit?"],
235
- ["What do you think about AI in the workplace?"],
236
- ],
237
  )
238
 
239
  with gr.Blocks(css="style.css", theme=gr.themes.Default(primary_hue=gr.themes.colors.emerald,secondary_hue=gr.themes.colors.indigo)) as demo:
@@ -245,4 +488,4 @@ if __name__ == "__main__":
245
  demo.queue(max_size=20).launch()
246
  #demo.queue(max_size=20)
247
  demo.launch(share=True, debug=True)
248
-
 
1
+ import gradio as gr
2
+ import time
3
+ import random
4
+ import json
5
+ import mysql.connector
6
+ import os
7
+ import csv
8
+ import spaces
9
+ import torch
10
+
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
+ from threading import Thread
13
+ from typing import Iterator
14
+ from huggingface_hub import Repository, hf_hub_download
15
+ from datetime import datetime
16
+
17
+ # for fetch_personalized_data
18
+ import mysql.connector
19
+ import urllib.parse
20
+ import urllib.request
21
+
22
+ # for saving chat history as JSON - not used
23
+ import atexit
24
+ import os
25
+ from huggingface_hub import HfApi, HfFolder
26
+
27
+ # for saving chat history as dataset - not used
28
+ import huggingface_hub
29
+ from huggingface_hub import Repository
30
+ from datetime import datetime
31
+
32
+ # for saving chat history as dataset - used
33
+ import sqlite3
34
+ import huggingface_hub
35
+ import gradio as gr
36
+ import pandas as pd
37
+ import shutil
38
+ import os
39
+ import datetime
40
+ from apscheduler.schedulers.background import BackgroundScheduler
41
+
42
+ DATASET_REPO_URL = "https://huggingface.co/datasets/botsi/trust-game-llama-2-chat-history"
43
+ DATA_DIRECTORY = "data" # Separate directory for storing data files
44
+ DATA_FILENAME = "marion.csv" # Default filename
45
+ DATA_FILE = os.path.join("data", DATA_FILENAME)
46
+
47
+ HF_TOKEN = os.environ.get("HF_TOKEN")
48
+ print("is none?", HF_TOKEN is None)
49
+ print("hfh", huggingface_hub.__version__)
50
+
51
+ repo = Repository(
52
+ local_dir=DATA_DIRECTORY, clone_from=DATASET_REPO_URL
53
+ )
54
+
55
+ MAX_MAX_NEW_TOKENS = 2048
56
+ DEFAULT_MAX_NEW_TOKENS = 1024
57
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
58
+
59
+ DESCRIPTION = """\
60
+ # This is your personal space to chat.
61
+ You can ask anything.
62
+ """
63
+
64
+ # License and Acceptable Use Policy by Meta
65
+ LICENSE = """
66
+ <p/>
67
+
68
+ ---
69
+ This demo is governed by the [original license](https://ai.meta.com/llama/license/) and [acceptable use policy](https://ai.meta.com/llama/use-policy/).
70
+ The most recent copy of this policy can be found at ai.meta.com/llama/use-policy.
71
+ """
72
+
73
+ if not torch.cuda.is_available():
74
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
75
+
76
+
77
+ if torch.cuda.is_available():
78
+ model_id = "meta-llama/Llama-2-7b-chat-hf"
79
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
80
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
81
+ tokenizer.use_default_system_prompt = False
82
+
83
+ ## trust-game-llama-2-7b-chat
84
+ # app.py
85
+ def get_default_system_prompt():
86
+ BSYS, ESYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
87
+
88
+ DEFAULT_SYSTEM_PROMPT = f"""You are a friendly chat assistant. First, wait for a question. Do not talk before being asked. When the user says hi, politely ask what he/she wants to know.
89
+ Do not tell any more!
90
+ If asked about Marion, you can answer the following: There are only two Marions, one is currently working at Accenture and the other one will shortly begin a new job at BMW.
91
+ She recently quit Accenture. Nicknames by her friends include botsi and botsokrates. Her favorite project was Hypercare. Kathy/Kathi is her colleague at Accenture. Anita is her colleague at Accenture.
92
+ """
93
+ print(DEFAULT_SYSTEM_PROMPT)
94
+ return DEFAULT_SYSTEM_PROMPT
95
+
96
+ ## trust-game-llama-2-7b-chat
97
+ # app.py
98
+ def construct_input_prompt(chat_history, message):
99
+ input_prompt = f"<s>[INST] <<SYS>>\n{get_default_system_prompt()}\n<</SYS>>\n\n "
100
+ for user, assistant in chat_history:
101
+ input_prompt += f"{user} [/INST] {assistant} <s>[INST] "
102
+ input_prompt += f"{message} [/INST] "
103
+ return input_prompt
104
+
105
+ ## trust-game-llama-2-7b-chat
106
+ # app.py
107
+ @spaces.GPU
108
+ def generate(
109
+ request: gr.Request, # To fetch query params
110
+ message: str,
111
+ chat_history: list[tuple[str, str]],
112
+ # input_prompt: str,
113
+ max_new_tokens: int = 1024,
114
+ temperature: float = 0.6,
115
+ top_p: float = 0.9,
116
+ top_k: int = 50,
117
+ repetition_penalty: float = 1.2,
118
+ ) -> Iterator[str]: # Change return type hint to Iterator[str]
119
+
120
+ conversation = []
121
+
122
+ # Fetch query params
123
+ params = {
124
+ key: value for key, value in request.query_params.items()
125
+ }
126
+ print('those are the query params')
127
+ print(params)
128
+
129
+ print("Request headers dictionary:", request.headers)
130
+ print("IP address:", request.client.host)
131
+ print("Query parameters:", params)
132
+
133
+ # Construct the input prompt using the functions from the system_prompt_config module
134
+ input_prompt = construct_input_prompt(chat_history, message)
135
+
136
+ # Move the condition here after the assignment
137
+ if input_prompt:
138
+ conversation.append({"role": "system", "content": input_prompt})
139
+
140
+ # Convert input prompt to tensor
141
+ input_ids = tokenizer(input_prompt, return_tensors="pt").to(model.device)
142
+
143
+ for user, assistant in chat_history:
144
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
145
+ conversation.append({"role": "user", "content": message})
146
+
147
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
148
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
149
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
150
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
151
+ input_ids = input_ids.to(model.device)
152
+
153
+ # Set up the TextIteratorStreamer
154
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
155
+
156
+ # Set up the generation arguments
157
+ generate_kwargs = dict(
158
+ input_ids=input_ids,
159
+ streamer=streamer,
160
+ max_new_tokens=max_new_tokens,
161
+ do_sample=True,
162
+ top_p=top_p,
163
+ top_k=top_k,
164
+ temperature=temperature,
165
+ num_beams=1,
166
+ repetition_penalty=repetition_penalty,
167
+ )
168
+
169
+ # Start the model generation thread
170
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
171
+ t.start()
172
+
173
+ # Yield generated text chunks
174
+ outputs = []
175
+ for text in streamer:
176
+ outputs.append(text)
177
+ yield "".join(outputs)
178
+
179
+ # Fix bug that last answer is not recorded!
180
+ # Parse the outputs into a readable sentence and record them
181
+ # Filter out empty strings and join the remaining strings with spaces
182
+ readable_sentence = ' '.join(filter(lambda x: x.strip(), outputs))
183
+ # Print the readable sentence
184
+ print(readable_sentence)
185
+
186
+ # Save chat history to .csv file on HuggingFace Hub
187
+ # Generate filename with bot id and session id
188
+ filename = f"{DATA_FILENAME}"
189
+ data_file = os.path.join(DATA_DIRECTORY, filename)
190
+
191
+ # Generate timestamp
192
+ timestamp = datetime.datetime.now()
193
+
194
+ # Check if the file already exists
195
+ if os.path.exists(data_file):
196
+ # If file exists, load existing data
197
+ existing_data = pd.read_csv(data_file)
198
+
199
+ # Add timestamp column
200
+ conversation_df = pd.DataFrame(conversation)
201
+ conversation_df['ip_address'] = request.client.host
202
+ conversation_df['readable_sentence'] = readable_sentence
203
+ conversation_df['timestamp'] = timestamp
204
+
205
+ # Append new conversation to existing data
206
+ updated_data = pd.concat([existing_data, conversation_df], ignore_index=True)
207
+ updated_data.to_csv(data_file, index=False)
208
+ else:
209
+ # If file doesn't exist, create new file with conversation data
210
+ conversation_df = pd.DataFrame(conversation)
211
+ conversation_df['ip_address'] = request.client.host
212
+ conversation_df['readable_sentence'] = readable_sentence
213
+ conversation_df['timestamp'] = timestamp
214
+ conversation_df.to_csv(data_file, index=False)
215
+
216
+ print("Updating .csv")
217
+ repo.push_to_hub(blocking=False, commit_message=f"Updating data at {timestamp}")
218
+
219
+ chat_interface = gr.ChatInterface(
220
+ fn=generate,
221
+ retry_btn=None,
222
+ clear_btn=None,
223
+ undo_btn=None,
224
+ chatbot=gr.Chatbot(avatar_images=('user.png', 'bot.png'), bubble_full_width=False),
225
+ examples=[
226
+ ["What is your favorite fruit?"],
227
+ ["What do you think about AI in the workplace?"],
228
+ ],
229
+ )
230
+
231
+ with gr.Blocks(css="style.css", theme=gr.themes.Default(primary_hue=gr.themes.colors.emerald, secondary_hue=gr.themes.colors.indigo)) as demo:
232
+ gr.Markdown(DESCRIPTION)
233
+ chat_interface.render()
234
+ gr.Markdown(LICENSE)
235
+
236
+ if __name__ == "__main__":
237
+ demo.queue(max_size=20).launch() # Launching the interface with queueing and maximum size limit
238
+ # demo.launch(share=True, debug=True) # Uncomment this line if you want to launch the interface with sharing and debug mode
239
+
240
+
241
+
242
+
243
+
244
+
245
+ '''# Original code from https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat
246
  # Modified for trust game purposes
247
 
248
  import gradio as gr
 
468
  repo.push_to_hub(blocking=False, commit_message=f"Updating data at {timestamp}")
469
 
470
  chat_interface = gr.ChatInterface(
471
+ fn=generate,
472
+ retry_btn=None,
473
+ clear_btn=None,
474
+ undo_btn=None,
475
+ chatbot=gr.Chatbot(avatar_images=('user.png', 'bot.png'), bubble_full_width = False),
476
+ examples=[
477
+ ["What is your favorite fruit?"],
478
+ ["What do you think about AI in the workplace?"],
479
+ ],
 
480
  )
481
 
482
  with gr.Blocks(css="style.css", theme=gr.themes.Default(primary_hue=gr.themes.colors.emerald,secondary_hue=gr.themes.colors.indigo)) as demo:
 
488
  demo.queue(max_size=20).launch()
489
  #demo.queue(max_size=20)
490
  demo.launch(share=True, debug=True)
491
+ '''