Spaces:
Build error
Build error
Update ai_agent.py
Browse files- ai_agent.py +31 -7
ai_agent.py
CHANGED
|
@@ -168,6 +168,26 @@ def create_chart(df, selected_columns=None, chart_type="bar", title=None):
|
|
| 168 |
except:
|
| 169 |
return None
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
|
| 172 |
plot_spec = plan.get("plot")
|
| 173 |
if not plot_spec:
|
|
@@ -175,9 +195,12 @@ def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
|
|
| 175 |
ptype = plot_spec.get("type", "bar")
|
| 176 |
title = plot_spec.get("title", "Chart")
|
| 177 |
plot_df = df if describe_stats else dfw
|
| 178 |
-
x = plot_spec.get("x")
|
| 179 |
-
y = plot_spec.get("y")
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
if not x and len(plot_df.columns) > 0:
|
| 182 |
categorical_cols = plot_df.select_dtypes(include=['object', 'category']).columns
|
| 183 |
x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]
|
|
@@ -187,16 +210,16 @@ def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
|
|
| 187 |
|
| 188 |
try:
|
| 189 |
if ptype == "pie" and x and x in plot_df.columns:
|
| 190 |
-
value_counts = plot_df[x].value_counts()
|
| 191 |
fig = go.Figure(data=[go.Pie(labels=value_counts.index, values=value_counts.values, hole=0.3)])
|
| 192 |
fig.update_layout(title=title, width=900, height=500)
|
| 193 |
return fig
|
| 194 |
elif ptype == "bar" and x and x in plot_df.columns and y and y in plot_df.columns:
|
| 195 |
-
fig = px.bar(plot_df, x=x, y=y, title=title)
|
| 196 |
fig.update_layout(width=900, height=500)
|
| 197 |
return fig
|
| 198 |
elif ptype == "line" and x and x in plot_df.columns and y and y in plot_df.columns:
|
| 199 |
-
fig = px.line(plot_df, x=x, y=y, title=title, markers=True)
|
| 200 |
fig.update_layout(width=900, height=500)
|
| 201 |
return fig
|
| 202 |
elif ptype == "hist" and y and y in plot_df.columns:
|
|
@@ -204,10 +227,11 @@ def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
|
|
| 204 |
fig.update_layout(width=900, height=500)
|
| 205 |
return fig
|
| 206 |
elif ptype == "scatter" and x and x in plot_df.columns and y and y in plot_df.columns:
|
| 207 |
-
fig = px.scatter(plot_df, x=x, y=y, title=title)
|
| 208 |
fig.update_layout(width=900, height=500)
|
| 209 |
return fig
|
| 210 |
-
except:
|
|
|
|
| 211 |
pass
|
| 212 |
return None
|
| 213 |
|
|
|
|
| 168 |
except:
|
| 169 |
return None
|
| 170 |
|
| 171 |
+
def _normalize_column_name(name):
    """Lower-case *name* and strip underscores and spaces for loose comparison."""
    return str(name).lower().replace('_', '').replace(' ', '')


def find_column(df, col_name):
    """Find a column in a DataFrame using progressively looser matching.

    Matching is attempted in this order, and the first hit wins:

    1. Exact label match against ``df.columns``.
    2. Normalized equality (case-insensitive, ignoring spaces/underscores).
    3. Normalized substring containment in either direction.

    Args:
        df: Object exposing a ``columns`` iterable of column labels.
        col_name: Requested column name; may be ``None`` or empty.

    Returns:
        The matching column label exactly as it appears in ``df.columns``,
        or ``None`` when nothing matches.
    """
    if not col_name:
        return None

    # Exact match first. Unhashable values (e.g. a list) cannot be column
    # labels and make the membership test raise, so treat them as "not found".
    try:
        if col_name in df.columns:
            return col_name
    except TypeError:
        return None

    # Normalize the search term; str() inside the helper keeps non-string
    # requests (e.g. an int) from crashing on .lower().
    search_term = _normalize_column_name(col_name)
    if not search_term:
        # A name made only of spaces/underscores would otherwise be a
        # substring of every column and match the first one arbitrarily.
        return None

    candidates = [(col, _normalize_column_name(col)) for col in df.columns]

    # Prefer a full normalized match over any partial one, so "sales" picks
    # "Sales" rather than an earlier "total_sales_2020".
    for col, normalized_col in candidates:
        if normalized_col == search_term:
            return col

    # Fall back to substring matching, skipping columns that normalize to ""
    # because the empty string is contained in every search term.
    for col, normalized_col in candidates:
        if normalized_col and (
            search_term in normalized_col or normalized_col in search_term
        ):
            return col

    return None
|
| 190 |
+
|
| 191 |
def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
|
| 192 |
plot_spec = plan.get("plot")
|
| 193 |
if not plot_spec:
|
|
|
|
| 195 |
ptype = plot_spec.get("type", "bar")
|
| 196 |
title = plot_spec.get("title", "Chart")
|
| 197 |
plot_df = df if describe_stats else dfw
|
|
|
|
|
|
|
| 198 |
|
| 199 |
+
# Use fuzzy matching to find columns
|
| 200 |
+
x = find_column(plot_df, plot_spec.get("x"))
|
| 201 |
+
y = find_column(plot_df, plot_spec.get("y"))
|
| 202 |
+
|
| 203 |
+
# Fallback: auto-select columns if not found
|
| 204 |
if not x and len(plot_df.columns) > 0:
|
| 205 |
categorical_cols = plot_df.select_dtypes(include=['object', 'category']).columns
|
| 206 |
x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]
|
|
|
|
| 210 |
|
| 211 |
try:
|
| 212 |
if ptype == "pie" and x and x in plot_df.columns:
|
| 213 |
+
value_counts = plot_df[x].value_counts().head(10) # Limit to top 10 for readability
|
| 214 |
fig = go.Figure(data=[go.Pie(labels=value_counts.index, values=value_counts.values, hole=0.3)])
|
| 215 |
fig.update_layout(title=title, width=900, height=500)
|
| 216 |
return fig
|
| 217 |
elif ptype == "bar" and x and x in plot_df.columns and y and y in plot_df.columns:
|
| 218 |
+
fig = px.bar(plot_df.head(50), x=x, y=y, title=title) # Limit rows for performance
|
| 219 |
fig.update_layout(width=900, height=500)
|
| 220 |
return fig
|
| 221 |
elif ptype == "line" and x and x in plot_df.columns and y and y in plot_df.columns:
|
| 222 |
+
fig = px.line(plot_df.head(100), x=x, y=y, title=title, markers=True)
|
| 223 |
fig.update_layout(width=900, height=500)
|
| 224 |
return fig
|
| 225 |
elif ptype == "hist" and y and y in plot_df.columns:
|
|
|
|
| 227 |
fig.update_layout(width=900, height=500)
|
| 228 |
return fig
|
| 229 |
elif ptype == "scatter" and x and x in plot_df.columns and y and y in plot_df.columns:
|
| 230 |
+
fig = px.scatter(plot_df.head(200), x=x, y=y, title=title) # Limit for performance
|
| 231 |
fig.update_layout(width=900, height=500)
|
| 232 |
return fig
|
| 233 |
+
except Exception as e:
|
| 234 |
+
print(f"Error creating plot: {str(e)}")
|
| 235 |
pass
|
| 236 |
return None
|
| 237 |
|