Upload sql_command.py
Browse files- sql_command.py +50 -0
sql_command.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
|
| 3 |
+
'''
|
| 4 |
+
##TODO:
|
| 5 |
+
|
| 6 |
+
import sqlite3
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
## run the following function to set up a dedicated database using SQLite.
|
| 10 |
+
def construct_db(data_file=None):
|
| 11 |
+
excel_file = "/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/Text2SQL/模拟数据.csv" # Replace with your actual file path
|
| 12 |
+
df = pd.read_csv(excel_file)
|
| 13 |
+
print('df:', df.head())
|
| 14 |
+
|
| 15 |
+
conn = sqlite3.connect('myexcelDB.db') # Replace 'mydatabase.db' with your desired name
|
| 16 |
+
# Create a cursor object to execute SQL commands
|
| 17 |
+
cursor = conn.cursor()
|
| 18 |
+
|
| 19 |
+
##NOTE: Insert data from DataFrame into the table. 注意这里的if_exists选项,考虑是否要覆盖原始内容。这里如果需要指定index,那么就需要指定index_label,且index=False,否则会选择默认的index,这样就会产生duplicate错误。
|
| 20 |
+
df.to_sql('table01', conn, if_exists='replace', index=False, index_label="产品ID")
|
| 21 |
+
|
| 22 |
+
return None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def llm_query(sql_command):
|
| 26 |
+
# Connect to the database (or create it if it doesn't exist)
|
| 27 |
+
conn = sqlite3.connect('./myexcelDB.db') # Replace 'myexcelDB.db' with your desired name
|
| 28 |
+
|
| 29 |
+
# Create a cursor object to execute SQL commands
|
| 30 |
+
cursor = conn.cursor()
|
| 31 |
+
## SQL command
|
| 32 |
+
cursor.execute(sql_command)
|
| 33 |
+
query_result = cursor.fetchall() # Fetch all rows as a list of tuples
|
| 34 |
+
# print('query_result:', query_result)
|
| 35 |
+
|
| 36 |
+
## 将列名也取出来。
|
| 37 |
+
column_names = [description[0] for description in cursor.description]
|
| 38 |
+
|
| 39 |
+
query_df = pd.DataFrame(query_result, columns=column_names)
|
| 40 |
+
# query_df.set_index("产品ID", inplace=True)
|
| 41 |
+
print('query_df:', query_df)
|
| 42 |
+
|
| 43 |
+
return query_df
|
| 44 |
+
|
| 45 |
+
# llm_query("SELECT * FROM table01 WHERE 宽度 > 300") ## sample function call.
|
| 46 |
+
|
| 47 |
+
# construct_db()
|
| 48 |
+
# sql = "SELECT 产品ID FROM table01 WHERE 长度 > 50"
|
| 49 |
+
# res = llm_query(sql)
|
| 50 |
+
# res
|