Spaces:
Runtime error
Runtime error
Caleb Fahlgren
commited on
Commit
·
467c2a7
1
Parent(s):
9fc2d21
fix pickle issue by using dict instead of pydantic model
Browse files
app.py
CHANGED
|
@@ -86,7 +86,7 @@ CREATE TABLE {} (
|
|
| 86 |
|
| 87 |
|
| 88 |
@spaces.GPU
|
| 89 |
-
def generate_query(dataset_id: str, query: str) ->
|
| 90 |
ddl = get_dataset_ddl(dataset_id)
|
| 91 |
|
| 92 |
system_prompt = f"""
|
|
@@ -118,37 +118,38 @@ def generate_query(dataset_id: str, query: str) -> str:
|
|
| 118 |
|
| 119 |
print("Received Response: ", resp)
|
| 120 |
|
| 121 |
-
return resp
|
| 122 |
|
| 123 |
|
| 124 |
def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
|
| 125 |
-
response
|
| 126 |
|
| 127 |
print("Querying Parquet...")
|
| 128 |
-
df = conn.execute(response.sql).fetchdf()
|
| 129 |
|
| 130 |
plot = None
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
response.data_key = None
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
plt.xticks(rotation=45, ha="right")
|
| 143 |
plt.tight_layout()
|
| 144 |
-
elif
|
| 145 |
-
plot = df.plot(
|
| 146 |
-
kind="bar", x=response.label_key, y=response.data_key
|
| 147 |
-
).get_figure()
|
| 148 |
plt.xticks(rotation=45, ha="right")
|
| 149 |
plt.tight_layout()
|
| 150 |
|
| 151 |
-
markdown_output = f"""```sql\n{
|
| 152 |
return df, markdown_output, plot
|
| 153 |
|
| 154 |
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
@spaces.GPU
|
| 89 |
+
def generate_query(dataset_id: str, query: str) -> dict:
|
| 90 |
ddl = get_dataset_ddl(dataset_id)
|
| 91 |
|
| 92 |
system_prompt = f"""
|
|
|
|
| 118 |
|
| 119 |
print("Received Response: ", resp)
|
| 120 |
|
| 121 |
+
return resp.model_dump()
|
| 122 |
|
| 123 |
|
| 124 |
def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
|
| 125 |
+
response = generate_query(dataset_id, query)
|
| 126 |
|
| 127 |
print("Querying Parquet...")
|
| 128 |
+
df = conn.execute(response.get("sql")).fetchdf()
|
| 129 |
|
| 130 |
plot = None
|
| 131 |
|
| 132 |
+
label_key = response.get("label_key")
|
| 133 |
+
data_key = response.get("data_key")
|
| 134 |
+
viz_type = response.get("visualization_type")
|
| 135 |
+
sql = response.get("sql")
|
|
|
|
| 136 |
|
| 137 |
+
# handle incorrect data and label keys
|
| 138 |
+
if label_key and label_key not in df.columns:
|
| 139 |
+
label_key = None
|
| 140 |
+
if data_key and data_key not in df.columns:
|
| 141 |
+
data_key = None
|
| 142 |
+
|
| 143 |
+
if viz_type == OutputTypes.LINECHART:
|
| 144 |
+
plot = df.plot(kind="line", x=label_key, y=data_key).get_figure()
|
| 145 |
plt.xticks(rotation=45, ha="right")
|
| 146 |
plt.tight_layout()
|
| 147 |
+
elif viz_type == OutputTypes.BARCHART:
|
| 148 |
+
plot = df.plot(kind="bar", x=label_key, y=data_key).get_figure()
|
|
|
|
|
|
|
| 149 |
plt.xticks(rotation=45, ha="right")
|
| 150 |
plt.tight_layout()
|
| 151 |
|
| 152 |
+
markdown_output = f"""```sql\n{sql}\n```"""
|
| 153 |
return df, markdown_output, plot
|
| 154 |
|
| 155 |
|