Update app.py
Browse files
app.py
CHANGED
|
@@ -45,7 +45,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|
| 45 |
|
| 46 |
DATASET_REPO_URL = "https://huggingface.co/datasets/botsi/trust-game-llama-2-chat-history"
|
| 47 |
DATA_DIRECTORY = "data" # Separate directory for storing data files
|
| 48 |
-
DATA_FILENAME = "
|
| 49 |
DATA_FILE = os.path.join("data", DATA_FILENAME)
|
| 50 |
|
| 51 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
@@ -83,165 +83,26 @@ if torch.cuda.is_available():
|
|
| 83 |
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
|
| 84 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 85 |
tokenizer.use_default_system_prompt = False
|
| 86 |
-
|
| 87 |
-
def fetch_personalized_data(session_index):
|
| 88 |
-
try:
|
| 89 |
-
# Connect to the database
|
| 90 |
-
with mysql.connector.connect(
|
| 91 |
-
host="18.153.94.89",
|
| 92 |
-
user="root",
|
| 93 |
-
password="N12RXMKtKxRj",
|
| 94 |
-
database="lionessdb"
|
| 95 |
-
) as conn:
|
| 96 |
-
# Create a cursor object
|
| 97 |
-
with conn.cursor() as cursor:
|
| 98 |
-
# Query to fetch relevant data from both tables based on session_index
|
| 99 |
-
query = """
|
| 100 |
-
SELECT e5390g37504_core.playerNr,
|
| 101 |
-
e5390g37504_core.groupNrStart,
|
| 102 |
-
e5390g37504_core.subjectNr,
|
| 103 |
-
e5390g37504_core.onPage,
|
| 104 |
-
e5390g37504_decisions.session_index,
|
| 105 |
-
e5390g37504_decisions.transfer1,
|
| 106 |
-
e5390g37504_decisions.tripledAmount1,
|
| 107 |
-
e5390g37504_decisions.keptForSelf1,
|
| 108 |
-
e5390g37504_decisions.returned1,
|
| 109 |
-
e5390g37504_decisions.newCreditRound2,
|
| 110 |
-
e5390g37504_decisions.transfer2,
|
| 111 |
-
e5390g37504_decisions.tripledAmount2,
|
| 112 |
-
e5390g37504_decisions.keptForSelf2,
|
| 113 |
-
e5390g37504_decisions.returned2,
|
| 114 |
-
e5390g37504_decisions.results2rounds,
|
| 115 |
-
e5390g37504_decisions.newCreditRound3,
|
| 116 |
-
e5390g37504_decisions.transfer3,
|
| 117 |
-
e5390g37504_decisions.tripledAmount3,
|
| 118 |
-
e5390g37504_decisions.keptForSelf3,
|
| 119 |
-
e5390g37504_decisions.returned3,
|
| 120 |
-
e5390g37504_decisions.results3rounds
|
| 121 |
-
FROM e5390g37504_core
|
| 122 |
-
JOIN e5390g37504_decisions ON
|
| 123 |
-
e5390g37504_core.playerNr = e5390g37504_decisions.playerNr
|
| 124 |
-
WHERE e5390g37504_decisions.session_index = %s
|
| 125 |
-
UNION ALL
|
| 126 |
-
SELECT e5390g37504_core.playerNr,
|
| 127 |
-
e5390g37504_core.groupNrStart,
|
| 128 |
-
e5390g37504_core.subjectNr,
|
| 129 |
-
e5390g37504_core.onPage,
|
| 130 |
-
e5390g37504_decisions.session_index,
|
| 131 |
-
e5390g37504_decisions.transfer1,
|
| 132 |
-
e5390g37504_decisions.tripledAmount1,
|
| 133 |
-
e5390g37504_decisions.keptForSelf1,
|
| 134 |
-
e5390g37504_decisions.returned1,
|
| 135 |
-
e5390g37504_decisions.newCreditRound2,
|
| 136 |
-
e5390g37504_decisions.transfer2,
|
| 137 |
-
e5390g37504_decisions.tripledAmount2,
|
| 138 |
-
e5390g37504_decisions.keptForSelf2,
|
| 139 |
-
e5390g37504_decisions.returned2,
|
| 140 |
-
e5390g37504_decisions.results2rounds,
|
| 141 |
-
e5390g37504_decisions.newCreditRound3,
|
| 142 |
-
e5390g37504_decisions.transfer3,
|
| 143 |
-
e5390g37504_decisions.tripledAmount3,
|
| 144 |
-
e5390g37504_decisions.keptForSelf3,
|
| 145 |
-
e5390g37504_decisions.returned3,
|
| 146 |
-
e5390g37504_decisions.results3rounds
|
| 147 |
-
FROM e5390g37504_core
|
| 148 |
-
JOIN e5390g37504_decisions
|
| 149 |
-
ON e5390g37504_core.playerNr = e5390g37504_decisions.playerNr
|
| 150 |
-
WHERE e5390g37504_core.groupNrStart IN (
|
| 151 |
-
SELECT DISTINCT groupNrStart
|
| 152 |
-
FROM e5390g37504_core
|
| 153 |
-
JOIN e5390g37504_decisions
|
| 154 |
-
ON e5390g37504_core.playerNr = e5390g37504_decisions.playerNr
|
| 155 |
-
WHERE e5390g37504_decisions.session_index = %s
|
| 156 |
-
) AND e5390g37504_decisions.session_index != %s
|
| 157 |
-
"""
|
| 158 |
-
cursor.execute(query,(session_index, session_index, session_index))
|
| 159 |
-
# Fetch data row by row
|
| 160 |
-
data = [{
|
| 161 |
-
'playerNr': row[0],
|
| 162 |
-
'groupNrStart': row[1],
|
| 163 |
-
'subjectNr': row[2],
|
| 164 |
-
'onPage': row[3],
|
| 165 |
-
'session_index': row[4],
|
| 166 |
-
'transfer1': row[5],
|
| 167 |
-
'tripledAmount1': row[6],
|
| 168 |
-
'keptForSelf1': row[7],
|
| 169 |
-
'returned1': row[8],
|
| 170 |
-
'newCreditRound2': row[9],
|
| 171 |
-
'transfer2': row[10],
|
| 172 |
-
'tripledAmount2': row[11],
|
| 173 |
-
'keptForSelf2': row[12],
|
| 174 |
-
'returned2': row[13],
|
| 175 |
-
'results2rounds': row[14],
|
| 176 |
-
'newCreditRound3': row[15],
|
| 177 |
-
'transfer3': row[16],
|
| 178 |
-
'tripledAmount3': row[17],
|
| 179 |
-
'keptForSelf3': row[18],
|
| 180 |
-
'returned3': row[19],
|
| 181 |
-
'results3rounds': row[20]
|
| 182 |
-
} for row in cursor]
|
| 183 |
-
print(data)
|
| 184 |
-
return data
|
| 185 |
-
except mysql.connector.Error as err:
|
| 186 |
-
print(f"Error: {err}")
|
| 187 |
-
return None
|
| 188 |
-
|
| 189 |
## trust-game-llama-2-7b-chat
|
| 190 |
# app.py
|
| 191 |
-
def get_default_system_prompt(
|
| 192 |
#BOS, EOS = "<s>", "</s>"
|
| 193 |
#BINST, EINST = "[INST]", "[/INST]"
|
| 194 |
BSYS, ESYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
| 195 |
|
| 196 |
-
DEFAULT_SYSTEM_PROMPT = f"""
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
The game consists of 3 rounds. In round 1, The Investor invests between 0 coins and 10 coins.
|
| 200 |
-
This amount is tripled automatically, and The Dealer can then distribute the tripled amount. After that, round 1 is over.
|
| 201 |
-
Both go into the next round with their current asset. This approach is repeated for 3 rounds in total.
|
| 202 |
-
This is what happened in the last rounds: {personalized_data}.
|
| 203 |
-
Your goal is to guide players through the game, providing clear instructions and explanations.
|
| 204 |
-
If any question or action seems unclear, explain it rather than providing inaccurate information.
|
| 205 |
-
If you're unsure about an answer, it's better not to guess.
|
| 206 |
"""
|
| 207 |
print(DEFAULT_SYSTEM_PROMPT)
|
| 208 |
return DEFAULT_SYSTEM_PROMPT
|
| 209 |
|
| 210 |
-
def map_onPage(onPage):
    """Map an ``onPage`` value to its stage label and stage title.

    Args:
        onPage: Page file name such as ``"stage407906.php"``.

    Returns:
        A ``(onPage_filename, onPage_prompt)`` tuple, e.g.
        ``("stage 1", "Welcome")``. Values that are not in the mapping
        yield ``("unknown", "unknown")``.
    """
    onPage_mapping = {
        "stage407906.php": ("stage 1", "Welcome"),
        "stage407908.php": ("stage 2", "Trust Game Instructions 1/3"),
        "stage407909.php": ("stage 3", "Trust Game Instructions 2/3"),
        "stage407915.php": ("stage 4", "Trust Game Instructions 3/3"),
        "stage407923.php": ("stage 5", "Lobby with AI"),
        "stage407924.php": ("stage 6", "Round 1: Investor’s turn with AI"),
        "stage407925.php": ("stage 7", "Round 1: Dealer’s turn with AI"),
        "stage407926.php": ("stage 8", "Round 2: Investor’s turn with AI"),
        # Was a copy of the stage 8 title; rounds 1 and 3 both go
        # Investor -> Dealer, so stage 9 is the Dealer's turn.
        "stage407927.php": ("stage 9", "Round 2: Dealer’s turn with AI"),
        "stage407928.php": ("stage 10", "Results with AI after 2 rounds"),
        "stage407929.php": ("stage 11", "Round 3: Investor’s turn with AI"),
        "stage407930.php": ("stage 12", "Round 3: Dealer’s turn with AI"),
        "stage407931.php": ("stage 13", "Overall Questionnaire"),
        "stage407932.php": ("stage 14", "Results with AI after 3 rounds"),
        "stage407933.php": ("stage 15", "Redirect to Prolific - Dropout no compensation"),
        "stage407934.php": ("stage 16", "Not used yet: Redirect to Prolific - Win"),
        "stage407935.php": ("stage 17", "Not used yet: Redirect to Prolific - Completion"),
    }

    # Single lookup with a fallback for pages that are not in the mapping.
    return onPage_mapping.get(onPage, ("unknown", "unknown"))
|
| 240 |
|
| 241 |
## trust-game-llama-2-7b-chat
|
| 242 |
# app.py
|
| 243 |
def construct_input_prompt(chat_history, message, personalized_data):
|
| 244 |
-
input_prompt = f"<s>[INST] <<SYS>>\n{get_default_system_prompt(
|
| 245 |
for user, assistant in chat_history:
|
| 246 |
input_prompt += f"{user} [/INST] {assistant} <s>[INST] "
|
| 247 |
input_prompt += f"{message} [/INST] "
|
|
@@ -256,7 +117,7 @@ def generate(
|
|
| 256 |
chat_history: list[tuple[str, str]],
|
| 257 |
# input_prompt: str,
|
| 258 |
max_new_tokens: int = 1024,
|
| 259 |
-
temperature: float = 0.
|
| 260 |
top_p: float = 0.9,
|
| 261 |
top_k: int = 50,
|
| 262 |
repetition_penalty: float = 1.2,
|
|
@@ -268,23 +129,9 @@ def generate(
|
|
| 268 |
params = request.query_params
|
| 269 |
print('those are the query params')
|
| 270 |
print(params)
|
| 271 |
-
|
| 272 |
-
# Assuming params = request.query_params is the dictionary containing the query parameters
|
| 273 |
-
# Extract the value of the 'session_index' parameter
|
| 274 |
-
session_index = params.get('session_index')
|
| 275 |
-
|
| 276 |
-
# Check if session_index_value is None or contains a value
|
| 277 |
-
if session_index is not None:
|
| 278 |
-
print("Session index:", session_index)
|
| 279 |
-
else:
|
| 280 |
-
session_index = 'no_session_id'
|
| 281 |
-
print("Session index not found or has no value.")
|
| 282 |
-
|
| 283 |
-
# Fetch personalized data
|
| 284 |
-
personalized_data = fetch_personalized_data(session_index)
|
| 285 |
|
| 286 |
# Construct the input prompt using the functions from the system_prompt_config module
|
| 287 |
-
input_prompt = construct_input_prompt(chat_history, message
|
| 288 |
|
| 289 |
# Move the condition here after the assignment
|
| 290 |
if input_prompt:
|
|
@@ -329,44 +176,9 @@ def generate(
|
|
| 329 |
outputs.append(text)
|
| 330 |
yield "".join(outputs)
|
| 331 |
|
| 332 |
-
# Fix bug that last answer is not recorded!
|
| 333 |
-
# Parse the outputs into a readable sentence and record them
|
| 334 |
-
# Filter out empty strings and join the remaining strings with spaces
|
| 335 |
-
#readable_sentence = ' '.join(filter(lambda x: x.strip(), outputs))
|
| 336 |
-
# Print the readable sentence
|
| 337 |
-
#print(readable_sentence)
|
| 338 |
-
|
| 339 |
-
# Save chat history to .csv file on HuggingFace Hub
|
| 340 |
-
#pd.DataFrame(conversation).to_csv(DATA_FILE, index=False)
|
| 341 |
-
#print("updating conversation")
|
| 342 |
-
#repo.push_to_hub(blocking=False, commit_message=f"Updating data at {datetime.datetime.now()}")
|
| 343 |
-
#print(conversation)
|
| 344 |
-
|
| 345 |
-
# Find onPage variable in personalized_data to add it to the .csv filename to record
|
| 346 |
-
# Initialize onPage variable to None
|
| 347 |
-
onPage = None
|
| 348 |
-
|
| 349 |
-
# Iterate over each dictionary in the list
|
| 350 |
-
for entry in personalized_data:
|
| 351 |
-
# Check if the session_index matches the value in session_index variable
|
| 352 |
-
if entry['session_index'] == session_index:
|
| 353 |
-
# If a match is found, retrieve the onPage value
|
| 354 |
-
onPage = entry['onPage']
|
| 355 |
-
break # Break the loop since we found the desired entry
|
| 356 |
-
|
| 357 |
-
# Check if onPage is still None (i.e., no onPage found or session_index is None)
|
| 358 |
-
if onPage is None:
|
| 359 |
-
onPage = "no_onPage"
|
| 360 |
-
|
| 361 |
-
# Print the onPage value
|
| 362 |
-
print("onPage for session_index =", session_index, ":", onPage)
|
| 363 |
-
onPage_filename, onPage_prompt = map_onPage(onPage)
|
| 364 |
-
print("onPage_filename:", onPage_filename)
|
| 365 |
-
print("onPage_prompt:", onPage_prompt)
|
| 366 |
-
|
| 367 |
# Save chat history to .csv file on HuggingFace Hub
|
| 368 |
# Generate filename with bot id and session id
|
| 369 |
-
filename = f"{
|
| 370 |
data_file = os.path.join(DATA_DIRECTORY, filename)
|
| 371 |
|
| 372 |
# Generate timestamp
|
|
@@ -402,10 +214,9 @@ clear_btn=None,
|
|
| 402 |
undo_btn=None,
|
| 403 |
chatbot=gr.Chatbot(avatar_images=('user.png', 'bot.png'), bubble_full_width = False),
|
| 404 |
examples=[
|
| 405 |
-
["
|
| 406 |
-
["
|
| 407 |
-
["What
|
| 408 |
-
["What is my probability to win if I do not share anything?"],
|
| 409 |
],
|
| 410 |
)
|
| 411 |
|
|
|
|
| 45 |
|
| 46 |
DATASET_REPO_URL = "https://huggingface.co/datasets/botsi/trust-game-llama-2-chat-history"
|
| 47 |
DATA_DIRECTORY = "data" # Separate directory for storing data files
|
| 48 |
+
DATA_FILENAME = "marion.csv" # Default filename
|
| 49 |
DATA_FILE = os.path.join("data", DATA_FILENAME)
|
| 50 |
|
| 51 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
| 83 |
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
|
| 84 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 85 |
tokenizer.use_default_system_prompt = False
|
| 86 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
## trust-game-llama-2-7b-chat
|
| 88 |
# app.py
|
| 89 |
+
def get_default_system_prompt():
    """Return the system prompt text placed inside the <<SYS>> section.

    The prompt is also printed on every call, as before. The unused
    ``BSYS``/``ESYS`` locals were dropped: the <<SYS>> delimiters are added
    by the caller that builds the full input prompt, not here.

    Returns:
        The system prompt as a plain string.
    """
    #BOS, EOS = "<s>", "</s>"
    #BINST, EINST = "[INST]", "[/INST]"

    # Plain string literal: the text has no placeholders, so no f-prefix.
    DEFAULT_SYSTEM_PROMPT = """You are a friendly and fun chat assistant.
If asked about Marion, you can answer the following: There are only two Marions, one is currently working at Accenture and the other one will shortly begin a new job at BMW.
She recently quit Accenture. Nicknames by her friends include botsi and botsokrates. Her favorite project was Hypercare.
"""
    print(DEFAULT_SYSTEM_PROMPT)
    return DEFAULT_SYSTEM_PROMPT
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
## trust-game-llama-2-7b-chat
|
| 103 |
# app.py
|
| 104 |
def construct_input_prompt(chat_history, message, personalized_data):
|
| 105 |
+
input_prompt = f"<s>[INST] <<SYS>>\n{get_default_system_prompt()}\n<</SYS>>\n\n "
|
| 106 |
for user, assistant in chat_history:
|
| 107 |
input_prompt += f"{user} [/INST] {assistant} <s>[INST] "
|
| 108 |
input_prompt += f"{message} [/INST] "
|
|
|
|
| 117 |
chat_history: list[tuple[str, str]],
|
| 118 |
# input_prompt: str,
|
| 119 |
max_new_tokens: int = 1024,
|
| 120 |
+
temperature: float = 0.9,
|
| 121 |
top_p: float = 0.9,
|
| 122 |
top_k: int = 50,
|
| 123 |
repetition_penalty: float = 1.2,
|
|
|
|
| 129 |
params = request.query_params
|
| 130 |
print('those are the query params')
|
| 131 |
print(params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
# Construct the input prompt using the functions from the system_prompt_config module
|
| 134 |
+
input_prompt = construct_input_prompt(chat_history, message)
|
| 135 |
|
| 136 |
# Move the condition here after the assignment
|
| 137 |
if input_prompt:
|
|
|
|
| 176 |
outputs.append(text)
|
| 177 |
yield "".join(outputs)
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
# Save chat history to .csv file on HuggingFace Hub
|
| 180 |
# Generate filename with bot id and session id
|
| 181 |
+
filename = f"{DATA_FILENAME}"
|
| 182 |
data_file = os.path.join(DATA_DIRECTORY, filename)
|
| 183 |
|
| 184 |
# Generate timestamp
|
|
|
|
| 214 |
undo_btn=None,
|
| 215 |
chatbot=gr.Chatbot(avatar_images=('user.png', 'bot.png'), bubble_full_width = False),
|
| 216 |
examples=[
|
| 217 |
+
["How many Marions are there?"],
|
| 218 |
+
["What is your favorite fruit?"],
|
| 219 |
+
["What do you think about AI in the workplace?"],
|
|
|
|
| 220 |
],
|
| 221 |
)
|
| 222 |
|