Commit
·
c3f3890
1
Parent(s):
1b1fce9
Chooses best query from chatgpt
Browse files
app.py
CHANGED
|
@@ -131,6 +131,11 @@ def extract_db_code(text):
|
|
| 131 |
matches = re.findall(pattern, text, re.DOTALL)
|
| 132 |
return [match.strip() for match in matches]
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
def generate_dummy_db(db_info, question):
|
| 135 |
pre_prompt = """
|
| 136 |
Generate a SQLite database with dummy data for this database from the DB Layout. Your task is to generate just a database, no queries. For each input do the following:
|
|
@@ -188,6 +193,36 @@ def test_query_on_dummy_db(db_code, query):
|
|
| 188 |
print(f"Query: {query}\tError encountered: {e}")
|
| 189 |
return False
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
|
| 193 |
if num_return_sequences > num_beams:
|
|
@@ -246,15 +281,18 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
|
|
| 246 |
query = query.replace("\n", " ").replace("\t", " ").strip()
|
| 247 |
# Test against dummy database
|
| 248 |
success = test_query_on_dummy_db(db_code, query)
|
| 249 |
-
|
| 250 |
-
query = format(query) if format_sql else query
|
| 251 |
if success:
|
| 252 |
responses.append(query)
|
| 253 |
else:
|
| 254 |
responses.append(query)
|
| 255 |
|
| 256 |
-
# Choose
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
if log:
|
| 260 |
# Log the request to Firestore
|
|
|
|
| 131 |
matches = re.findall(pattern, text, re.DOTALL)
|
| 132 |
return [match.strip() for match in matches]
|
| 133 |
|
| 134 |
+
def extract_from_code_block(text):
    """Return the contents of the first triple-backtick code block in *text*.

    A language tag (e.g. "sql") is only recognised when it is immediately
    followed by a line break, as in a Markdown fenced block. This keeps
    single-line blocks intact: previously the first word of a block such as
    "SELECT * FROM t" written on the same line as the opening fence was
    mistaken for a language tag and silently dropped.

    Args:
        text: The text to search, typically a chat model reply.

    Returns:
        The stripped contents of the first code block, or an empty string
        when *text* is empty or contains no complete code block.
    """
    if not text:
        return ''

    # Optional "<tag><newline>" header, then a lazy capture up to the closing
    # fence. DOTALL lets the capture span multiple lines.
    pattern = r'```(?:[\w+#.-]*[ \t]*\r?\n)?(.*?)```'
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ''
|
| 138 |
+
|
| 139 |
def generate_dummy_db(db_info, question):
|
| 140 |
pre_prompt = """
|
| 141 |
Generate a SQLite database with dummy data for this database from the DB Layout. Your task is to generate just a database, no queries. For each input do the following:
|
|
|
|
| 193 |
print(f"Query: {query}\tError encountered: {e}")
|
| 194 |
return False
|
| 195 |
|
| 196 |
+
def choose_best_query(queries, question, max_retries=5, retry_delay=10):
    """Ask the chat model to pick the candidate query that best fits *question*.

    The candidates are sent to ``gpt-3.5-turbo`` together with the question,
    and the first code block in the reply is returned as the chosen query.

    Args:
        queries: Candidate SQL query strings.
        question: The natural-language question the queries should answer.
        max_retries: Maximum number of API attempts before giving up.
        retry_delay: Seconds to wait between failed attempts.

    Returns:
        The query extracted from the model reply. Falls back to the first
        candidate when every attempt fails or the reply contains no code
        block. Returns an empty string when *queries* is empty.
    """
    queries = list(queries)
    # Nothing to choose between: skip the API round-trip entirely.
    if not queries:
        return ''
    if len(queries) == 1:
        return queries[0]

    pre_prompt = """
    Given a list of queries. Your task is to choose just a single query which satisfies the question the most with the least amount of filters, groupings, and conditions. For each input do the following:
    1. Breakdown the list of queries into small pieces and explain what each query is doing.
    2. Explain why each query is relevant to the question.
    3. Choose the most relevant query from your explanation that aligns to the question best with the least amount of unnecessary filters or conditions. Output the best query in a single code block ``````.
    """
    prompt = pre_prompt + "\n\nQuestion: " + question + "\n\nQueries:" + "\n\n".join(queries)

    # Bounded retries: the previous unbounded loop never returned on permanent
    # failures (bad credentials, invalid request), hanging the caller.
    for attempt in range(1, max_retries + 1):
        try:
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "user", "content": prompt}
                ],
            )
            response_text = response['choices'][0]['message']['content']
        except Exception as e:
            print(f'Error occurred: {str(e)}')
            if attempt < max_retries:
                print(f'Waiting for {retry_delay} seconds before retrying...')
                time.sleep(retry_delay)
            continue

        print(response_text)
        query = extract_from_code_block(response_text)
        # A reply without a code block must not discard every candidate.
        return query if query else queries[0]

    # All attempts failed; a candidate is more useful than nothing.
    return queries[0]
|
| 225 |
+
|
| 226 |
|
| 227 |
def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
|
| 228 |
if num_return_sequences > num_beams:
|
|
|
|
| 281 |
query = query.replace("\n", " ").replace("\t", " ").strip()
|
| 282 |
# Test against dummy database
|
| 283 |
success = test_query_on_dummy_db(db_code, query)
|
| 284 |
+
|
|
|
|
| 285 |
if success:
|
| 286 |
responses.append(query)
|
| 287 |
else:
|
| 288 |
responses.append(query)
|
| 289 |
|
| 290 |
+
# Choose the best query if num_return_sequences > 1
|
| 291 |
+
if num_return_sequences > 1:
|
| 292 |
+
query = choose_best_query(responses, input_message)
|
| 293 |
+
# Format again
|
| 294 |
+
query = format(query) if format_sql else query
|
| 295 |
+
responses = [query]
|
| 296 |
|
| 297 |
if log:
|
| 298 |
# Log the request to Firestore
|